commit a1a9f16826f3f2d8ba80b6c5fd270c1c340d6d67 Author: Shantanu Jain Date: Mon Dec 12 11:27:27 2022 -0800 [tiktoken] hello world diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9e090c8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,42 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Environments +.env +.venv + +# Tools +.mypy_cache +.coverage +htmlcov + +# General +.DS_Store + +Cargo.lock +target/ diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..24b42fd --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tiktoken" +version = "0.1.0" +edition = "2021" +rust-version = "1.57.0" + +[lib] +name = "_tiktoken" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.17.3", features = ["extension-module"] } + +# tiktoken dependencies +fancy-regex = "0.10.0" +regex = "1.7.0" +rustc-hash = "1.1.0" +bstr = "1.0.1" + +[profile.release] +incremental = true diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..83ed103 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 OpenAI, Shantanu Jain + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..cb017cd --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,5 @@ +include *.svg +include *.toml +include Makefile +recursive-include scripts *.py +recursive-include src *.rs diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..92aec0f --- /dev/null +++ b/Makefile @@ -0,0 +1,49 @@ +PROJECT := tiktoken + +.PHONY: default +default: editable_install + +.PHONY: install_rust +install_rust: + which cargo >/dev/null || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain 1.62 + +.PHONY: clean +clean: + cargo clean + pip uninstall -y $(PROJECT) + find . | grep -E '__pycache__|\.pyc' | xargs rm -rf + find . | grep -E '\.so' | xargs rm -rf + rm -rf dist/ build/ + rm -rf $(PROJECT).egg-info/ + +.PHONY: format +format: + @ which black >/dev/null || python3 -m pip install black + @ which isort >/dev/null || python3 -m pip install isort + cargo fmt -- --config group_imports=StdExternalCrate + black --line-length 100 --skip-magic-trailing-comma --quiet . + isort --line-length 100 --profile black --quiet . 
+ + +.PHONY: format_check +format_check: + @ which black >/dev/null || python3 -m pip install black + @ which isort >/dev/null || python3 -m pip install isort + cargo fmt --check -- --config group_imports=StdExternalCrate + black --check --line-length 100 --skip-magic-trailing-comma --quiet . + isort --check --line-length 100 --profile black --quiet . + +.PHONY: lint +lint: + cargo clippy --all -- -D warnings + @ which flake8 >/dev/null || python3 -m pip install flake8==5 flake8-bugbear==22.9.11 + flake8 --ignore=E203,E501,W503,E731 --per-file-ignores="$(PROJECT)/__init__.py:F401 setup.py:E402" --exclude=build . + +.PHONY: editable_install +editable_install: + @ if [ -f $(PROJECT).egg-info ]; then \ + pip install --disable-pip-version-check --progress-bar=off setuptools wheel setuptools-rust ; \ + pip install --disable-pip-version-check --no-build-isolation -e . ; \ + else \ + pip install --disable-pip-version-check --no-deps --no-build-isolation --ignore-installed -e . ; \ + fi diff --git a/README.md b/README.md new file mode 100644 index 0000000..f0ea386 --- /dev/null +++ b/README.md @@ -0,0 +1,28 @@ +# ⏳ tiktoken + +tiktoken is a fast tokeniser. + +```python +import tiktoken +enc = tiktoken.get_encoding("gpt2") +print(enc.encode("hello world")) +``` + +The open source version of `tiktoken` can be installed from PyPI: +``` +pip install tiktoken +``` + +The tokeniser API is documented in `tiktoken/core.py`. + + +## Performance + +`tiktoken` is between 3-6x faster than huggingface's tokeniser: + +![image](./perf.svg) + +Performance measured on 1GB of text using the GPT-2 tokeniser, using `GPT2TokenizerFast` from +`tokenizers==0.13.2` and `transformers==4.24.0`. + + diff --git a/perf.svg b/perf.svg new file mode 100644 index 0000000..7157ef9 --- /dev/null +++ b/perf.svg @@ -0,0 +1,373 @@ +[perf.svg: SVG markup omitted. Bar chart of tokeniser throughput ("Throughput", 0-40 MB/s) against "Thread count" (1, 2, 4, 8, 16, 32, 64), comparing the "tiktoken" and "huggingface" series.] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..bb9aeeb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,8 @@ +[project] +name = "tiktoken" +dependencies = ["blobfile>=2", "regex>=2022.1.18"] +dynamic = ["version"] + +[build-system] +requires = ["setuptools", "wheel", "setuptools-rust"] + diff --git a/scripts/benchmark.py b/scripts/benchmark.py new file mode 100644 index 0000000..4d679fa --- /dev/null +++ b/scripts/benchmark.py @@ -0,0 +1,39 @@ +import base64 +import functools +import gzip +import json +import os +import random +import time +from typing import Any, cast + +import blobfile + +import tiktoken + + +def benchmark_batch(documents: list[str]) -> None: + num_threads = int(os.environ["RAYON_NUM_THREADS"]) + num_bytes = sum(map(len, map(str.encode, documents))) + print(f"num_threads: {num_threads}, num_bytes: {num_bytes}") + + enc =
tiktoken.get_encoding("gpt2") + enc.encode("warmup") + + start = time.perf_counter_ns() + enc.encode_ordinary_batch(documents, num_threads=num_threads) + end = time.perf_counter_ns() + print(f"tiktoken \t{num_bytes / (end - start) * 1e9} bytes / s") + + import transformers + + hf_enc = cast(Any, transformers).GPT2TokenizerFast.from_pretrained("gpt2") + hf_enc.model_max_length = 1e30 # silence! + hf_enc.encode("warmup") + + start = time.perf_counter_ns() + hf_enc(documents) + end = time.perf_counter_ns() + print(f"huggingface \t{num_bytes / (end - start) * 1e9} bytes / s") + + diff --git a/scripts/redact.py b/scripts/redact.py new file mode 100644 index 0000000..bcf8ef1 --- /dev/null +++ b/scripts/redact.py @@ -0,0 +1,65 @@ +import argparse +import re +import subprocess +from pathlib import Path + + +def redact_file(path: Path, dry_run: bool) -> None: + if not path.exists() or path.is_dir(): + return + + text = path.read_text() + + first_line = text.splitlines()[0] + if "redact" in first_line: + if not dry_run: + path.unlink() + print(f"Deleted {path}") + return + + pattern = "|".join( + re.escape(x) + for x in [ + "# ===== redact-beg =====\n", + "# ===== redact-end =====\n", + "\n", + "\n", + ] + ) + + if re.search(pattern, text): + redacted_text = "".join(re.split(pattern, text)[::2]) + if not dry_run: + path.write_text(redacted_text) + print(f"Redacted {path}") + return + + print(f"Skipped {path}") + + +def redact(dry_run: bool) -> None: + tiktoken_root = Path(__file__).parent.parent + assert tiktoken_root.name == "tiktoken" + assert (tiktoken_root / "pyproject.toml").exists() + + try: + output = subprocess.check_output(["git", "ls-files"], cwd=tiktoken_root, text=True) + paths = [Path(p) for p in output.splitlines()] + except subprocess.CalledProcessError: + paths = list(tiktoken_root.glob("**/*")) + + for path in paths: + redact_file(path, dry_run=dry_run) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--dry-run", type=lambda x: not x or x[0].lower() != "f", default=True) + args = parser.parse_args() + redact(args.dry_run) + if args.dry_run: + print("Dry run, use --dry-run=false to actually redact files") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..df18eda --- /dev/null +++ b/setup.py @@ -0,0 +1,23 @@ +from setuptools import setup +from setuptools_rust import Binding, RustExtension + +public = True + +if public: + version = "0.1" + +setup( + name="tiktoken", + version=version, + rust_extensions=[ + RustExtension( + "tiktoken._tiktoken", + binding=Binding.PyO3, + # Between our use of editable installs and wanting to use Rust for performance sensitive + # code, it makes sense to just always use --release + debug=False, + ) + ], + packages=["tiktoken", "tiktoken_ext"], + zip_safe=False, +) diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..8235dbb --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,559 @@ +// This check is new and seems buggy (possibly with PyO3 interaction) +#![allow(clippy::borrow_deref_ref)] + +use std::collections::HashSet; +use std::thread; + +use fancy_regex::Regex; +use pyo3::exceptions; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyList, PyTuple}; +use pyo3::PyResult; +use rustc_hash::FxHashMap as HashMap; + +fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> { + let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect(); + + // If you have n parts and m merges, this does O(mn) work + // We could do something with a
heap and do O(m log n) work + // (A small worked example of this merge loop follows the performance notes below.) + + // Note that we hash bytes, not token pairs. As long as we train BPE the way we + // currently do, this is equivalent. An easy way to break this would be to decouple + // merge priority from token index or to prevent specific token merges. + loop { + if parts.len() == 1 { + break; + } + let mut min_rank: Option<(usize, usize)> = None; + for i in 0..parts.len() - 1 { + let rank = if let Some(r) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) { + *r + } else { + continue; + }; + if min_rank.is_none() || rank < min_rank.unwrap().0 { + min_rank = Some((rank, i)); + } + } + if let Some((_, i)) = min_rank { + parts[i] = parts[i].start..parts[i + 1].end; + parts.remove(i + 1); + } else { + break; + } + } + parts +} + +pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> { + if piece.len() == 1 { + return vec![ranks[piece]]; + } + _byte_pair_merge(piece, ranks) + .iter() + .map(|p| ranks[&piece[p.start..p.end]]) + .collect() +} + +pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> { + if piece.len() == 1 { + return vec![piece]; + } + _byte_pair_merge(piece, ranks) + .iter() + .map(|p| &piece[p.start..p.end]) + .collect() +} + +// Various performance notes: +// +// Regex +// ===== +// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy +// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than +// the usual regex we use. +// +// However, given that we're using a regex parse-able by `regex`, there isn't much difference +// between using the `regex` crate and using the `fancy_regex` crate. +// +// There is an important interaction between threading, `regex` and `fancy_regex`. +// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on +// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain +// old `regex`, we don't hit this, because `find_iter` has a different code path. +// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md +// Anyway, the way we get around this is by having a (mostly) thread local clone of the regex for +// each thread. +// +// Threading +// ========= +// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL. +// So goodbye `rayon`! Let thread count etc. remain under the control of our Python users. +// +// Caching +// ======= +// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`. +// Originally, we had one too! Without it, we were only vaguely faster than Python. +// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance +// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect +// multi-threaded performance even when I only had readers (maybe I messed something up?). +// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache! +// These are exactly the set of merges that are likely to be hot. And now we don't have to think +// about interior mutability, memory use, or cloning. +// +// Hashing +// ======= +// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win? +// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made +// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
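+// Worked example for `_byte_pair_merge` / `byte_pair_encode` (illustrative only, using the same toy +// ranks as the test at the bottom of this file): with ranks {b"ab": 1, b"cd": 2} and piece b"abcd", +// the parts start as ["a", "b", "c", "d"]. The candidate pairs are "ab" (rank 1), "bc" (no rank) and +// "cd" (rank 2); the lowest rank wins, so "ab" merges first, then "cd", and no remaining pair has a +// rank, so the loop stops. `byte_pair_split` therefore returns [b"ab", b"cd"] and `byte_pair_encode` +// returns [1, 2]. In the vocabularies we load, every single byte has a rank (see the single byte +// tokens added in tiktoken/load.py), so the final rank lookups always succeed.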
+ +use std::num::NonZeroU64; +pub struct FakeThreadId(NonZeroU64); + +fn hash_current_thread() -> usize { + // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter + // that works great for our use case of avoiding collisions in our array. Unfortunately, + // it's private. However, there are only so many ways you can layout a u64, so just transmute + // https://github.com/rust-lang/rust/issues/67939 + const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()]; + const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()]; + let x = unsafe { + std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0 + }; + u64::from(x) as usize +} + +const MAX_NUM_THREADS: usize = 128; +#[pyclass] +struct CoreBPE { + encoder: HashMap<Vec<u8>, usize>, + special_tokens_encoder: HashMap<String, usize>, + decoder: HashMap<usize, Vec<u8>>, + special_tokens_decoder: HashMap<usize, Vec<u8>>, + regex_tls: Vec<Regex>, + special_regex_tls: Vec<Regex>, + sorted_token_bytes: Vec<Vec<u8>>, +} + +impl CoreBPE { + fn _get_tl_regex(&self) -> &Regex { + // See performance notes above for what this is about + // It's also a little janky, please make a better version of it! + // However, it's nice that this doesn't leak memory to short-lived threads + &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS] + } + + fn _get_tl_special_regex(&self) -> &Regex { + &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] + } + + fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> { + let mut ret = Vec::with_capacity(tokens.len() * 2); + for token in tokens { + let token_bytes = self + .decoder + .get(token) + .unwrap_or_else(|| &self.special_tokens_decoder[token]); + ret.extend(token_bytes); + } + ret + } + + fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> { + // This is the core of the encoding logic; the other functions in here + // just make things complicated :-) + let regex = self._get_tl_regex(); + let mut ret = vec![]; + for mat in regex.find_iter(text) { + let piece = mat.unwrap().as_str().as_bytes(); + if let Some(token) = self.encoder.get(piece) { + ret.push(*token); + continue; + } + ret.extend(&byte_pair_encode(piece, &self.encoder)); + } + ret + } + + fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) { + let special_regex = self._get_tl_special_regex(); + let regex = self._get_tl_regex(); + let mut ret = vec![]; + + let mut start = 0; + let mut last_piece_token_len = 0; + loop { + let mut next_special; + let mut start_find = start; + loop { + // Find the next allowed special token, if any + next_special = special_regex.find_from_pos(text, start_find).unwrap(); + match next_special { + Some(m) => { + if allowed_special.contains(&text[m.start()..m.end()]) { + break; + } + start_find = m.start() + 1; + } + None => break, + } + } + let end = next_special.map_or(text.len(), |m| m.start()); + + // Okay, here we go, compare this logic to _encode_ordinary_native + for mat in regex.find_iter(&text[start..end]) { + let piece = mat.unwrap().as_str().as_bytes(); + if let Some(token) = self.encoder.get(piece) { + last_piece_token_len = 1; + ret.push(*token); + continue; + } + let tokens = byte_pair_encode(piece, &self.encoder); + last_piece_token_len = tokens.len(); + ret.extend(&tokens); + } + + match next_special { + // And here we push the special token + Some(m) => { + let piece = m.as_str(); + let token = self.special_tokens_encoder[piece]; + ret.push(token); + start = m.end(); + last_piece_token_len = 0; + } + None => break, + } + } + + // last_piece_token_len is how many tokens came from the last regex split.
This is used + // for determining unstable tokens, since you can't merge across (stable) regex splits + (ret, last_piece_token_len) + } + + fn _increase_last_piece_token_len( + &self, + tokens: Vec<usize>, + mut last_piece_token_len: usize, + ) -> (Vec<usize>, usize) { + // Unfortunately, the locations where our regex splits can be unstable. + // For the purposes of determining unstable tokens, unstable regex splitting + // is only a problem if a split that was present disappears, since this can + // lead to merging of tokens otherwise thought to be stable. + // cl100k_base makes our life hard by including the \s*[\r\n]+ + // pattern. This can e.g. cause "\n" + " " to become "\n \n". + // Here is a quick and dirty fix: + { + let token_is_all_space = |token| { + self.decoder + .get(token) + .map(|token_bytes| { + token_bytes + .iter() + .rev() + .all(|&b| [b' ', b'\n', b'\t'].contains(&b)) + }) + .unwrap_or(false) + }; + if last_piece_token_len > 0 + && token_is_all_space(&tokens[tokens.len() - last_piece_token_len]) + { + while (last_piece_token_len < tokens.len()) + && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1]) + { + last_piece_token_len += 1; + } + } + } + debug_assert!(last_piece_token_len <= tokens.len()); + + (tokens, last_piece_token_len) + } + + fn _encode_unstable_native( + &self, + text: &str, + allowed_special: &HashSet<&str>, + ) -> (Vec<usize>, HashSet<Vec<usize>>) { + let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special); + if last_piece_token_len == 0 { + // If last_piece_token_len is zero, the last token was a special token and we have + // no unstable bytes + return (tokens, HashSet::new()); + } + let (mut tokens, last_piece_token_len) = + self._increase_last_piece_token_len(tokens, last_piece_token_len); + + let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); + tokens.truncate(tokens.len() - last_piece_token_len); + + // TODO: we should try harder to find additional stable tokens + // This would reduce the amount of retokenising when determining completions + // Refer to the logic in an older version of this file + + let mut completions = HashSet::new(); + if unstable_bytes.is_empty() { + return (tokens, completions); + } + + // This is the easy bit. Just find all single tokens that start with unstable_bytes + // (including tokens that exactly match unstable_bytes) + // Separating this from the loop below helps with performance in a common case. + let mut point = self + .sorted_token_bytes + .partition_point(|x| x.as_slice() < unstable_bytes.as_slice()); + while point < self.sorted_token_bytes.len() + && self.sorted_token_bytes[point].starts_with(&unstable_bytes) + { + completions.insert(vec![ + self.encoder[self.sorted_token_bytes[point].as_slice()], + ]); + point += 1; + } + + // Now apply even more brute force. At every (other) possible position for the straddling + // token, concatenate additional bytes from that token (if any) to unstable_bytes, + // and retokenise the whole thing and see what we get. + for i in 1..unstable_bytes.len() { + let prefix = &unstable_bytes[..i]; + let suffix = &unstable_bytes[i..]; + let mut point = self + .sorted_token_bytes + .partition_point(|x| x.as_slice() < suffix); + // TODO: Perf optimisation if suffix starts with " "?
+ while point < self.sorted_token_bytes.len() + && self.sorted_token_bytes[point].starts_with(suffix) + { + let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); + let encoded = match std::str::from_utf8(&possibility) { + // Morally, this is byte_pair_encode(&possibility, &self.encoder) + // But we might have introduced a regex split which would prevent merges. + // (particularly possible in the presence of unstable regex splits) + // So convert to UTF-8 and do regex splitting. + // E.g. with cl100k_base "  !" gets split to " " + " !", + // but byte_pair_encode(" !") != byte_pair_encode(" ") + Ok(s) => self._encode_ordinary_native(s), + + // Technically, whether or not this arm is correct depends on whether there + // would be a regex split before the UTF-8 truncation point. + // Probably niche enough that no one will ever notice (after all, people didn't + // notice all the big holes in the previous unstable token implementation) + Err(_) => byte_pair_encode(&possibility, &self.encoder), + // Something like the following is intriguing but incorrect: + // Err(e) => self._encode_ordinary_native(unsafe { + // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) + // }), + }; + let mut seq = Vec::new(); + let mut seq_len = 0; + for token in encoded { + seq.push(token); + seq_len += self.decoder[&token].len(); + if seq_len >= unstable_bytes.len() { + break; + } + } + completions.insert(seq); + point += 1; + } + } + + // This is also not straightforward. While we generally assume that regex splits are stable, + // unfortunately, they are not. That is, if adding bytes were to make a split appear in + // unstable_bytes, this could make tokens possible which our logic would otherwise think + // would be merged. + // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could + // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token. + // Here is a quick and dirty fix: + // This isn't right if we ever remove \s+(?!\S) + if unstable_bytes.len() > 1 { + let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice()); + if unstable_bytes.len() - last_decoded.1 > 0 + && last_decoded.0.map_or(false, |c| c.is_whitespace()) + { + let mut reencoded = byte_pair_encode( + &unstable_bytes[..unstable_bytes.len() - last_decoded.1], + &self.encoder, + ); + reencoded.extend(byte_pair_encode( + &unstable_bytes[unstable_bytes.len() - last_decoded.1..], + &self.encoder, + )); + completions.insert(reencoded); + } + } + + (tokens, completions) + } +} + +#[pymethods] +impl CoreBPE { + #[new] + fn new( + encoder: HashMap<Vec<u8>, usize>, + special_tokens_encoder: HashMap<String, usize>, + pattern: &str, + ) -> PyResult<Self> { + let regex = Regex::new(pattern) + .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?; + + let special_regex = { + let _parts = special_tokens_encoder + .keys() + .map(|s| fancy_regex::escape(s)) + .collect::<Vec<_>>(); + Regex::new(&_parts.join("|")) + .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
+ }; + + let decoder: HashMap> = + encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); + + assert!(encoder.len() == decoder.len()); + + let special_tokens_decoder: HashMap> = special_tokens_encoder + .iter() + .map(|(k, v)| (*v, k.as_bytes().to_vec())) + .collect(); + + // Clone because I don't know how to tell Rust I'm not going to change the map + let mut sorted_token_bytes: Vec> = encoder.keys().cloned().collect(); + sorted_token_bytes.sort(); + + Ok(CoreBPE { + encoder, + special_tokens_encoder, + decoder, + special_tokens_decoder, + regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), + special_regex_tls: (0..MAX_NUM_THREADS) + .map(|_| special_regex.clone()) + .collect(), + sorted_token_bytes, + }) + } + + // ==================== + // Encoding + // ==================== + + fn encode_ordinary(&self, py: Python, text: &str) -> Vec { + py.allow_threads(|| self._encode_ordinary_native(text)) + } + + fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec { + py.allow_threads(|| self._encode_native(text, &allowed_special).0) + } + + fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec { + py.allow_threads(|| { + match std::str::from_utf8(bytes) { + Ok(text) => self._encode_ordinary_native(text), + Err(e) => { + let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; + let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new()); + let (mut tokens, last_piece_token_len) = + self._increase_last_piece_token_len(tokens, last_piece_token_len); + if !tokens.is_empty() && last_piece_token_len > 0 { + // Lop off the tokens from the last piece and run BPE on the remaining bytes + // Somewhat niche, but this may not be correct if we'd have had a regex + // split between the valid UTF-8 and the invalid bytes, which is why this + // method is private + let mut unstable_bytes = + self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); + unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); + + tokens.truncate(tokens.len() - last_piece_token_len); + tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder)); + } + tokens + } + } + }) + } + + fn encode_with_unstable( + &self, + py: Python, + text: &str, + allowed_special: HashSet<&str>, + ) -> Py { + let (tokens, completions) = + py.allow_threads(|| self._encode_unstable_native(text, &allowed_special)); + let py_completions = + PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..]))); + (tokens, py_completions).into_py(py) + } + + fn encode_single_token(&self, piece: &[u8]) -> PyResult { + if let Some(token) = self.encoder.get(piece).copied() { + return Ok(token); + } + if let Ok(piece_str) = std::str::from_utf8(piece) { + if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() { + return Ok(token); + } + } + Err(PyErr::new::(piece.to_owned())) + } + + fn encode_single_piece(&self, piece: &[u8]) -> Vec { + if let Some(token) = self.encoder.get(piece) { + return vec![*token]; + } + byte_pair_encode(piece, &self.encoder) + } + + // ==================== + // Decoding + // ==================== + + fn decode_bytes(&self, py: Python, tokens: Vec) -> Py { + let bytes = py.allow_threads(|| self._decode_native(&tokens)); + PyBytes::new(py, &bytes).into() + } + + fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult> { + if let Some(bytes) = self.decoder.get(&token) { + return Ok(PyBytes::new(py, bytes).into()); + } + if let Some(bytes) = self.special_tokens_decoder.get(&token) { + return 
Ok(PyBytes::new(py, bytes).into()); + } + Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string())) + } + + // ==================== + // Miscellaneous + // ==================== + + fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> { + self.sorted_token_bytes + .iter() + .map(|x| PyBytes::new(py, x).into()) + .collect() + } +} + +#[pymodule] +fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::<CoreBPE>()?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use rustc_hash::FxHashMap as HashMap; + + use crate::byte_pair_split; + + #[test] + fn very_simple_test() { + let mut ranks = HashMap::default(); + ranks.insert(b"ab".to_vec(), 1); + ranks.insert(b"cd".to_vec(), 2); + + let res = byte_pair_split(b"abcd", &ranks); + assert_eq!(res, vec![b"ab", b"cd"]); + } +} diff --git a/tiktoken/__init__.py b/tiktoken/__init__.py new file mode 100644 index 0000000..f4b5065 --- /dev/null +++ b/tiktoken/__init__.py @@ -0,0 +1,3 @@ +from .core import Encoding as Encoding +from .registry import get_encoding as get_encoding +from .registry import list_encoding_names as list_encoding_names diff --git a/tiktoken/core.py b/tiktoken/core.py new file mode 100644 index 0000000..e200c29 --- /dev/null +++ b/tiktoken/core.py @@ -0,0 +1,310 @@ +import functools +from concurrent.futures import ThreadPoolExecutor +from typing import AbstractSet, Collection, Literal, NoReturn, Optional, Union + +import regex + +from tiktoken import _tiktoken + + +class Encoding: + def __init__( + self, + name: str, + *, + pat_str: str, + mergeable_ranks: dict[bytes, int], + special_tokens: dict[str, int], + explicit_n_vocab: Optional[int] = None, + ): + self.name = name + + self._pat_str = pat_str + self._mergeable_ranks = mergeable_ranks + self._special_tokens = special_tokens + + self.max_token_value = max( + max(mergeable_ranks.values()), max(special_tokens.values(), default=0) + ) + if explicit_n_vocab: + assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab + assert self.max_token_value == explicit_n_vocab - 1 + + self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str) + + def __repr__(self) -> str: + return f"<Encoding {self.name!r}>" + + # ==================== + # Encoding + # ==================== + + def encode_ordinary(self, text: str) -> list[int]: + """Encodes a string into tokens, ignoring special tokens. + + This is equivalent to `encode(text, disallowed_special=())` (but slightly faster). + + ``` + >>> enc.encode_ordinary("hello world") + [31373, 995] + ``` + """ + return self._core_bpe.encode_ordinary(text) + + def encode( + self, + text: str, + *, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + ) -> list[int]: + """Encodes a string into tokens. + + Special tokens are artificial tokens used to unlock capabilities from a model, + such as fill-in-the-middle. So we want to be careful about accidentally encoding special + tokens, since they can be used to trick a model into doing something we don't want it to do. + + Hence, by default, encode will raise an error if it encounters text that corresponds + to a special token. This can be controlled on a per-token level using the `allowed_special` + and `disallowed_special` parameters. In particular: + - Setting `disallowed_special` to () will prevent this function from raising errors and + cause all text corresponding to special tokens to be encoded as natural text.
- Setting `allowed_special` to "all" will cause this function to treat all text corresponding + to special tokens as special tokens. + + ``` + >>> enc.encode("hello world") + [31373, 995] + >>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}) + [50256] + >>> enc.encode("<|endoftext|>", allowed_special="all") + [50256] + >>> enc.encode("<|endoftext|>") + # Raises ValueError + >>> enc.encode("<|endoftext|>", disallowed_special=()) + [27, 91, 437, 1659, 5239, 91, 29] + ``` + """ + if allowed_special == "all": + allowed_special = self.special_tokens_set + if disallowed_special == "all": + disallowed_special = self.special_tokens_set - allowed_special + if disallowed_special: + if not isinstance(disallowed_special, frozenset): + disallowed_special = frozenset(disallowed_special) + if match := _special_token_regex(disallowed_special).search(text): + raise_disallowed_special_token(match.group()) + + return self._core_bpe.encode(text, allowed_special) + + def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]: + """Encodes a list of strings into tokens, in parallel, ignoring special tokens. + + This is equivalent to `encode_batch(text, disallowed_special=())` (but slightly faster). + + ``` + >>> enc.encode_batch(["hello world", "goodbye world"]) + [[31373, 995], [11274, 16390, 995]] + ``` + """ + encoder = functools.partial(self.encode_ordinary) + with ThreadPoolExecutor(num_threads) as e: + return list(e.map(encoder, text)) + + def encode_batch( + self, + text: list[str], + *, + num_threads: int = 8, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + ) -> list[list[int]]: + """Encodes a list of strings into tokens, in parallel. + + See `encode` for more details on `allowed_special` and `disallowed_special`. + + ``` + >>> enc.encode_batch(["hello world", "goodbye world"]) + [[31373, 995], [11274, 16390, 995]] + ``` + """ + if allowed_special == "all": + allowed_special = self.special_tokens_set + if disallowed_special == "all": + disallowed_special = self.special_tokens_set - allowed_special + if not isinstance(disallowed_special, frozenset): + disallowed_special = frozenset(disallowed_special) + + encoder = functools.partial( + self.encode, allowed_special=allowed_special, disallowed_special=disallowed_special + ) + with ThreadPoolExecutor(num_threads) as e: + return list(e.map(encoder, text)) + + def encode_with_unstable( + self, + text: str, + *, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + ) -> tuple[list[int], list[list[int]]]: + """Encodes a string into stable tokens and possible completion sequences. + + Note that the stable tokens will only represent a substring of `text`. + + See `encode` for more details on `allowed_special` and `disallowed_special`. + + ``` + >>> enc.encode_with_unstable("hello fanta") + ([31373], [(277, 4910), (5113, 265), ..., (8842,)]) + + >>> text = "..."
+ >>> stable_tokens, completions = enc.encode_with_unstable(text) + >>> assert text.encode().startswith(enc.decode_bytes(stable_tokens)) + >>> assert all(enc.decode_bytes(stable_tokens + seq).startswith(text.encode()) for seq in completions) + ``` + """ + if allowed_special == "all": + allowed_special = self.special_tokens_set + if disallowed_special == "all": + disallowed_special = self.special_tokens_set - allowed_special + if disallowed_special: + if not isinstance(disallowed_special, frozenset): + disallowed_special = frozenset(disallowed_special) + if match := _special_token_regex(disallowed_special).search(text): + raise_disallowed_special_token(match.group()) + + return self._core_bpe.encode_with_unstable(text, allowed_special) + + def encode_single_token(self, text_or_bytes: Union[str, bytes]) -> int: + """Encodes text corresponding to a single token to its token value. + + NOTE: this will encode all special tokens. + + Raises `KeyError` if the token is not in the vocabulary. + + ``` + >>> enc.encode_single_token("hello") + 31373 + ``` + """ + if isinstance(text_or_bytes, str): + text_or_bytes = text_or_bytes.encode("utf-8") + return self._core_bpe.encode_single_token(text_or_bytes) + + # ==================== + # Decoding + # ==================== + + def decode_bytes(self, tokens: list[int]) -> bytes: + """Decodes a list of tokens into bytes. + + ``` + >>> enc.decode_bytes([31373, 995]) + b'hello world' + ``` + """ + return self._core_bpe.decode_bytes(tokens) + + def decode(self, tokens: list[int], errors: str = "replace") -> str: + """Decodes a list of tokens into a string. + + WARNING: the default behaviour of this function is lossy, since decoded bytes are not + guaranteed to be valid UTF-8. You can control this behaviour using the `errors` parameter, + for instance, setting `errors=strict`. + + ``` + >>> enc.decode([31373, 995]) + 'hello world' + ``` + """ + return self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors) + + def decode_single_token_bytes(self, token: int) -> bytes: + """Decodes a token into bytes. + + NOTE: this will decode all special tokens. + + Raises `KeyError` if the token is not in the vocabulary. + + ``` + >>> enc.decode_single_token_bytes(31373) + b'hello' + ``` + """ + return self._core_bpe.decode_single_token_bytes(token) + + def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]: + """Decodes a list of tokens into a list of bytes. + + Useful for visualising tokenisation. + >>> enc.decode_tokens_bytes([31373, 995]) + [b'hello', b' world'] + """ + return [self.decode_single_token_bytes(token) for token in tokens] + + # ==================== + # Miscellaneous + # ==================== + + def token_byte_values(self) -> list[bytes]: + """Returns the list of all token byte values.""" + return self._core_bpe.token_byte_values() + + @property + def eot_token(self) -> int: + return self._special_tokens["<|endoftext|>"] + + @functools.cached_property + def special_tokens_set(self) -> set[str]: + return set(self._special_tokens.keys()) + + @property + def n_vocab(self) -> int: + """For backwards compatibility. Prefer to use `enc.max_token_value + 1`.""" + return self.max_token_value + 1 + + # ==================== + # Private + # ==================== + + def _encode_single_piece(self, text_or_bytes: Union[str, bytes]) -> list[int]: + """Encodes text corresponding to bytes without a regex split. + + NOTE: this will not encode any special tokens. 
+ + ``` + >>> enc.encode_single_piece("helloqqqq") + [31373, 38227, 38227] + ``` + """ + if isinstance(text_or_bytes, str): + text_or_bytes = text_or_bytes.encode("utf-8") + return self._core_bpe.encode_single_piece(text_or_bytes) + + def _encode_only_native_bpe(self, text: str) -> list[str]: + """Encodes a string into tokens, but do regex splitting in Python.""" + _unused_pat = regex.compile(self._pat_str) + ret = [] + for piece in regex.findall(_unused_pat, text): + ret.extend(self._core_bpe.encode_single_piece(piece)) + return ret + + def _encode_bytes(self, text: bytes) -> list[int]: + return self._core_bpe._encode_bytes(text) + + +@functools.lru_cache(maxsize=128) +def _special_token_regex(tokens: frozenset[str]) -> "regex.Pattern[str]": + inner = "|".join(regex.escape(token) for token in tokens) + return regex.compile(f"({inner})") + + +def raise_disallowed_special_token(token: str) -> NoReturn: + raise ValueError( + f"Encountered text corresponding to disallowed special token {token!r}.\n" + "If you want this text to be encoded as a special token, " + f"pass it to `allowed_special`, e.g. `allowed_special={{{token!r}, ...}}`.\n" + f"If you want this text to be encoded as normal text, disable the check for this token " + f"by passing `disallowed_special=(enc.special_tokens_set - {{{token!r}}})`.\n" + "To disable this check for all special tokens, pass `disallowed_special=()`.\n" + ) diff --git a/tiktoken/load.py b/tiktoken/load.py new file mode 100644 index 0000000..06e51cc --- /dev/null +++ b/tiktoken/load.py @@ -0,0 +1,97 @@ +import base64 +import hashlib +import json +import os +import uuid + +import blobfile + + +def read_file_cached(blobpath: str) -> bytes: + if "TIKTOKEN_CACHE_DIR" in os.environ: + cache_dir = os.environ["TIKTOKEN_CACHE_DIR"] + elif "DATA_GYM_CACHE_DIR" in os.environ: + cache_dir = os.environ["DATA_GYM_CACHE_DIR"] + else: + cache_dir = "/tmp/data-gym-cache" + + if cache_dir == "": + # disable caching + with blobfile.BlobFile(blobpath, "rb") as f: + return f.read() + + cache_key = hashlib.sha1(blobpath.encode()).hexdigest() + + cache_path = os.path.join(cache_dir, cache_key) + if os.path.exists(cache_path): + with open(cache_path, "rb") as f: + return f.read() + + with blobfile.BlobFile(blobpath, "rb") as f: + contents = f.read() + + os.makedirs(cache_dir, exist_ok=True) + tmp_filename = cache_path + "." 
+ str(uuid.uuid4()) + ".tmp" + with open(tmp_filename, "wb") as f: + f.write(contents) + os.rename(tmp_filename, cache_path) + + return contents + + +def data_gym_to_mergeable_bpe_ranks( + vocab_bpe_file: str, encoder_json_file: str +) -> dict[bytes, int]: + # NB: do not add caching to this function + rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "] + + data_gym_byte_to_byte = {chr(b): b for b in rank_to_intbyte} + n = 0 + for b in range(2**8): + if b not in rank_to_intbyte: + rank_to_intbyte.append(b) + data_gym_byte_to_byte[chr(2**8 + n)] = b + n += 1 + assert len(rank_to_intbyte) == 2**8 + + # vocab_bpe contains the merges along with associated ranks + vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode() + bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]] + + def decode_data_gym(value: str) -> bytes: + return bytes(data_gym_byte_to_byte[b] for b in value) + + # add the single byte tokens + bpe_ranks = {bytes([b]): i for i, b in enumerate(rank_to_intbyte)} + # add the merged tokens + n = len(bpe_ranks) + for first, second in bpe_merges: + bpe_ranks[decode_data_gym(first) + decode_data_gym(second)] = n + n += 1 + + # check that the encoder file matches the merges file + # this sanity check is important since tiktoken assumes that ranks are ordered the same + # as merge priority + encoder_json = json.loads(read_file_cached(encoder_json_file)) + encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()} + # drop these two special tokens if present, since they're not mergeable bpe tokens + encoder_json_loaded.pop(b"<|endoftext|>", None) + encoder_json_loaded.pop(b"<|startoftext|>", None) + assert bpe_ranks == encoder_json_loaded + + return bpe_ranks + + +def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None: + with blobfile.BlobFile(tiktoken_bpe_file, "wb") as f: + for token, rank in sorted(bpe_ranks.items(), key=lambda x: x[1]): + f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n") + + +def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]: + # NB: do not add caching to this function + contents = read_file_cached(tiktoken_bpe_file) + return { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in contents.splitlines() if line) + } diff --git a/tiktoken/registry.py b/tiktoken/registry.py new file mode 100644 index 0000000..24bb173 --- /dev/null +++ b/tiktoken/registry.py @@ -0,0 +1,71 @@ +import importlib +import pkgutil +import threading +from typing import Any, Callable, Optional + +import tiktoken_ext + +from tiktoken.core import Encoding + +_lock = threading.RLock() +ENCODINGS: dict[str, Encoding] = {} +ENCODING_CONSTRUCTORS: Optional[dict[str, Callable[[], dict[str, Any]]]] = None + + +def _find_constructors() -> None: + global ENCODING_CONSTRUCTORS + with _lock: + if ENCODING_CONSTRUCTORS is not None: + return + ENCODING_CONSTRUCTORS = {} + + # tiktoken_ext is a namespace package + # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes + # - we use namespace package pattern so `pkgutil.iter_modules` is fast + # - it's a separate top-level package because namespace subpackages of non-namespace + # packages don't quite do what you want with editable installs + plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".") + + for _, mod_name, _ in plugin_mods: + mod = importlib.import_module(mod_name) + try: + constructors = 
mod.ENCODING_CONSTRUCTORS + except AttributeError as e: + raise ValueError( + f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS" + ) from e + for enc_name, constructor in constructors.items(): + if enc_name in ENCODING_CONSTRUCTORS: + raise ValueError( + f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}" + ) + ENCODING_CONSTRUCTORS[enc_name] = constructor + + +def get_encoding(encoding_name: str) -> Encoding: + if encoding_name in ENCODINGS: + return ENCODINGS[encoding_name] + + with _lock: + if encoding_name in ENCODINGS: + return ENCODINGS[encoding_name] + + if ENCODING_CONSTRUCTORS is None: + _find_constructors() + assert ENCODING_CONSTRUCTORS is not None + + if encoding_name not in ENCODING_CONSTRUCTORS: + raise ValueError(f"Unknown encoding {encoding_name}") + + constructor = ENCODING_CONSTRUCTORS[encoding_name] + enc = Encoding(**constructor()) + ENCODINGS[encoding_name] = enc + return enc + + +def list_encoding_names() -> list[str]: + with _lock: + if ENCODING_CONSTRUCTORS is None: + _find_constructors() + assert ENCODING_CONSTRUCTORS is not None + return list(ENCODING_CONSTRUCTORS) diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py new file mode 100644 index 0000000..cc6ad3c --- /dev/null +++ b/tiktoken_ext/openai_public.py @@ -0,0 +1,41 @@ +from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe + +ENDOFTEXT = "<|endoftext|>" +FIM_PREFIX = "<|fim_prefix|>" +FIM_MIDDLE = "<|fim_middle|>" +FIM_SUFFIX = "<|fim_suffix|>" +ENDOFPROMPT = "<|endofprompt|>" + + +def gpt2(): + mergeable_ranks = data_gym_to_mergeable_bpe_ranks( + vocab_bpe_file="az://openaipublic/gpt-2/encodings/main/vocab.bpe", + encoder_json_file="az://openaipublic/gpt-2/encodings/main/encoder.json", + ) + return { + "name": "gpt2", + "explicit_n_vocab": 50257, + "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", + "mergeable_ranks": mergeable_ranks, + "special_tokens": {"<|endoftext|>": 50256}, + } + + +def cl100k_base(): + mergeable_ranks = load_tiktoken_bpe("az://openaipublic/encodings/cl100k_base.tiktoken") + special_tokens = { + ENDOFTEXT: 100257, + FIM_PREFIX: 100258, + FIM_MIDDLE: 100259, + FIM_SUFFIX: 100260, + ENDOFPROMPT: 100276, + } + return { + "name": "cl100k_base", + "pat_str": r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""", + "mergeable_ranks": mergeable_ranks, + "special_tokens": special_tokens, + } + + +ENCODING_CONSTRUCTORS = {"gpt2": gpt2, "cl100k_base": cl100k_base}
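As a sketch of how the `tiktoken_ext` plugin mechanism above is meant to be used: any module in the `tiktoken_ext` namespace package that defines an `ENCODING_CONSTRUCTORS` dict is picked up by `tiktoken/registry.py`, and each constructor simply returns keyword arguments for `Encoding`. The module name `my_encodings`, the `byte_level` encoding, its toy single-byte ranks, and the `<|myeot|>` special token below are hypothetical and not part of this commit; only the shape mirrors `openai_public.py`.

```python
# tiktoken_ext/my_encodings.py  (hypothetical plugin module, not in this commit)
# Discovered by tiktoken.registry._find_constructors via pkgutil.iter_modules.


def byte_level():
    # Toy vocabulary: one rank per byte, so BPE never merges and emits one token
    # per byte. Real encodings load ranks with tiktoken.load.load_tiktoken_bpe or
    # data_gym_to_mergeable_bpe_ranks instead.
    mergeable_ranks = {bytes([b]): b for b in range(2**8)}
    return {
        "name": "byte_level",
        # Split pattern borrowed from the gpt2 constructor above.
        "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": {"<|myeot|>": 2**8},
        "explicit_n_vocab": 2**8 + 1,
    }


ENCODING_CONSTRUCTORS = {"byte_level": byte_level}
```

With such a module on the path, `tiktoken.get_encoding("byte_level")` would build the `Encoding` on first use and cache it in `ENCODINGS`, and `enc.encode("hi")` would return `[104, 105]` (one token per byte for this toy vocabulary).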