[tiktoken] hello world

Shantanu Jain 2022-12-12 11:27:27 -08:00
commit a1a9f16826
17 changed files with 1755 additions and 0 deletions

42
.gitignore vendored Normal file

@ -0,0 +1,42 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# Environments
.env
.venv
# Tools
.mypy_cache
.coverage
htmlcov
# General
.DS_Store
Cargo.lock
target/

21
Cargo.toml Normal file

@ -0,0 +1,21 @@
[package]
name = "tiktoken"
version = "0.1.0"
edition = "2021"
rust-version = "1.57.0"
[lib]
name = "_tiktoken"
crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.17.3", features = ["extension-module"] }
# tiktoken dependencies
fancy-regex = "0.10.0"
regex = "1.7.0"
rustc-hash = "1.1.0"
bstr = "1.0.1"
[profile.release]
incremental = true

21
LICENSE Normal file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 OpenAI, Shantanu Jain
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

5
MANIFEST.in Normal file

@ -0,0 +1,5 @@
include *.svg
include *.toml
include Makefile
recursive-include scripts *.py
recursive-include src *.rs

49
Makefile Normal file

@ -0,0 +1,49 @@
PROJECT := tiktoken
.PHONY: default
default: editable_install
.PHONY: install_rust
install_rust:
which cargo >/dev/null || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain 1.62
.PHONY: clean
clean:
cargo clean
pip uninstall -y $(PROJECT)
find . | grep -E '__pycache__|\.pyc' | xargs rm -rf
find . | grep -E '\.so' | xargs rm -rf
rm -rf dist/ build/
rm -rf $(PROJECT).egg-info/
.PHONY: format
format:
@ which black >/dev/null || python3 -m pip install black
@ which isort >/dev/null || python3 -m pip install isort
cargo fmt -- --config group_imports=StdExternalCrate
black --line-length 100 --skip-magic-trailing-comma --quiet .
isort --line-length 100 --profile black --quiet .
.PHONY: format_check
format_check:
@ which black >/dev/null || python3 -m pip install black
@ which isort >/dev/null || python3 -m pip install isort
cargo fmt --check -- --config group_imports=StdExternalCrate
black --check --line-length 100 --skip-magic-trailing-comma --quiet .
isort --check --line-length 100 --profile black --quiet .
.PHONY: lint
lint:
cargo clippy --all -- -D warnings
@ which flake8 >/dev/null || python3 -m pip install flake8==5 flake8-bugbear==22.9.11
flake8 --ignore=E203,E501,W503,E731 --per-file-ignores="$(PROJECT)/__init__.py:F401 setup.py:E402" --exclude=build .
.PHONY: editable_install
editable_install:
@ if [ -f $(PROJECT).egg-info ]; then \
pip install --disable-pip-version-check --progress-bar=off setuptools wheel setuptools-rust ; \
pip install --disable-pip-version-check --no-build-isolation -e . ; \
else \
pip install --disable-pip-version-check --no-deps --no-build-isolation --ignore-installed -e . ; \
fi

28
README.md Normal file

@ -0,0 +1,28 @@
# ⏳ tiktoken
tiktoken is a fast BPE tokeniser.
```python
import tiktoken
enc = tiktoken.get_encoding("gpt2")
print(enc.encode("hello world"))
```
The open source version of `tiktoken` can be installed from PyPI:
```
pip install tiktoken
```
The tokeniser API is documented in `tiktoken/core.py`.
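For example, a minimal round trip through that API (a sketch; it assumes the `gpt2` encoding files can be downloaded) looks like:
```python
import tiktoken

enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode("hello world")          # [31373, 995]
assert enc.decode(tokens) == "hello world"
# decode_bytes returns raw bytes and stays lossless even when tokens don't decode to valid UTF-8
assert enc.decode_bytes(tokens) == b"hello world"
```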
## Performance
`tiktoken` is 3-6x faster than Hugging Face's tokeniser:
![image](./perf.svg)
Performance measured on 1GB of text using the GPT-2 tokeniser, using `GPT2TokenizerFast` from
`tokenizers==0.13.2` and `transformers==4.24.0`.

373
perf.svg Normal file

@ -0,0 +1,373 @@
[perf.svg content omitted: a bar chart of tokeniser throughput (0-40 MB/s) versus thread count (1, 2, 4, 8, 16, 32, 64), comparing tiktoken against huggingface.]


8
pyproject.toml Normal file

@ -0,0 +1,8 @@
[project]
name = "tiktoken"
dependencies = ["blobfile>=2", "regex>=2022.1.18"]
dynamic = ["version"]
[build-system]
requires = ["setuptools", "wheel", "setuptools-rust"]

39
scripts/benchmark.py Normal file

@ -0,0 +1,39 @@
import base64
import functools
import gzip
import json
import os
import random
import time
from typing import Any, cast
import blobfile
import tiktoken
def benchmark_batch(documents: list[str]) -> None:
num_threads = int(os.environ["RAYON_NUM_THREADS"])
num_bytes = sum(map(len, map(str.encode, documents)))
print(f"num_threads: {num_threads}, num_bytes: {num_bytes}")
enc = tiktoken.get_encoding("gpt2")
enc.encode("warmup")
start = time.perf_counter_ns()
enc.encode_ordinary_batch(documents, num_threads=num_threads)
end = time.perf_counter_ns()
print(f"tiktoken \t{num_bytes / (end - start) * 1e9} bytes / s")
import transformers
hf_enc = cast(Any, transformers).GPT2TokenizerFast.from_pretrained("gpt2")
hf_enc.model_max_length = 1e30 # silence!
hf_enc.encode("warmup")
start = time.perf_counter_ns()
hf_enc(documents)
end = time.perf_counter_ns()
print(f"huggingface \t{num_bytes / (end - start) * 1e9} bytes / s")

65
scripts/redact.py Normal file

@ -0,0 +1,65 @@
import argparse
import re
import subprocess
from pathlib import Path
def redact_file(path: Path, dry_run: bool) -> None:
if not path.exists() or path.is_dir():
return
text = path.read_text()
first_line = text.splitlines()[0]
if "redact" in first_line:
if not dry_run:
path.unlink()
print(f"Deleted {path}")
return
pattern = "|".join(
re.escape(x)
for x in [
"# ===== redact-beg =====\n",
"# ===== redact-end =====\n",
"<!--- redact-beg -->\n",
"<!--- redact-end -->\n",
]
)
if re.search(pattern, text):
redacted_text = "".join(re.split(pattern, text)[::2])
if not dry_run:
path.write_text(redacted_text)
print(f"Redacted {path}")
return
print(f"Skipped {path}")
def redact(dry_run: bool) -> None:
tiktoken_root = Path(__file__).parent.parent
assert tiktoken_root.name == "tiktoken"
assert (tiktoken_root / "pyproject.toml").exists()
try:
output = subprocess.check_output(["git", "ls-files"], cwd=tiktoken_root, text=True)
paths = [Path(p) for p in output.splitlines()]
except subprocess.CalledProcessError:
paths = list(tiktoken_root.glob("**/*"))
for path in paths:
redact_file(path, dry_run=dry_run)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--dry-run", type=lambda x: not x or x[0].lower() != "f", default=True)
args = parser.parse_args()
redact(args.dry_run)
if args.dry_run:
print("Dry run, use --dry-run=false to actually redact files")
if __name__ == "__main__":
main()

23
setup.py Normal file

@ -0,0 +1,23 @@
from setuptools import setup
from setuptools_rust import Binding, RustExtension
public = True
if public:
version = "0.1"
setup(
name="tiktoken",
version=version,
rust_extensions=[
RustExtension(
"tiktoken._tiktoken",
binding=Binding.PyO3,
# Between our use of editable installs and wanting to use Rust for performance sensitive
# code, it makes sense to just always use --release
debug=False,
)
],
packages=["tiktoken", "tiktoken_ext"],
zip_safe=False,
)

559
src/lib.rs Normal file

@ -0,0 +1,559 @@
// This check is new and seems buggy (possibly with PyO3 interaction)
#![allow(clippy::borrow_deref_ref)]
use std::collections::HashSet;
use std::thread;
use fancy_regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList, PyTuple};
use pyo3::PyResult;
use rustc_hash::FxHashMap as HashMap;
fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();
// If you have n parts and m merges, this does O(mn) work
// We could do something with a heap and do O(m log n) work
// Note that we hash bytes, not token pairs. As long as we train BPE the way we
// currently do, this is equivalent. An easy way to break this would be to decouple
// merge priority from token index or to prevent specific token merges.
loop {
if parts.len() == 1 {
break;
}
let mut min_rank: Option<(usize, usize)> = None;
for i in 0..parts.len() - 1 {
let rank = if let Some(r) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) {
*r
} else {
continue;
};
if min_rank.is_none() || rank < min_rank.unwrap().0 {
min_rank = Some((rank, i));
}
}
if let Some((_, i)) = min_rank {
parts[i] = parts[i].start..parts[i + 1].end;
parts.remove(i + 1);
} else {
break;
}
}
parts
}
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
if piece.len() == 1 {
return vec![ranks[piece]];
}
_byte_pair_merge(piece, ranks)
.iter()
.map(|p| ranks[&piece[p.start..p.end]])
.collect()
}
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
if piece.len() == 1 {
return vec![piece];
}
_byte_pair_merge(piece, ranks)
.iter()
.map(|p| &piece[p.start..p.end])
.collect()
}
// Various performance notes:
//
// Regex
// =====
// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
// regex features. For instance, using a regex parse-able by the `regex` crate is 3x faster than
// the usual regex we use.
//
// However, given that we're using a regex parse-able by `regex`, there isn't much difference
// between using the `regex` crate and using the `fancy_regex` crate.
//
// There is an important interaction between threading, `regex` and `fancy_regex`.
// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
// old `regex`, we don't hit this, because `find_iter` has a different code path.
// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
// Anyway, the way we get around this is by having a (mostly) thread-local clone of the regex for
// each thread.
//
// Threading
// =========
// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
// So goodbye `rayon`! Let thread count etc. be controlled by our Python users.
//
// Caching
// =======
// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
// Originally, we had one too! Without it, we were only vaguely faster than Python.
// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
// about interior mutability, memory use, or cloning.
//
// Hashing
// =======
// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
use std::num::NonZeroU64;
pub struct FakeThreadId(NonZeroU64);
fn hash_current_thread() -> usize {
// It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
// that works great for our use case of avoiding collisions in our array. Unfortunately,
// it's private. However, there are only so many ways you can layout a u64, so just transmute
// https://github.com/rust-lang/rust/issues/67939
const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
let x = unsafe {
std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
};
u64::from(x) as usize
}
const MAX_NUM_THREADS: usize = 128;
#[pyclass]
struct CoreBPE {
encoder: HashMap<Vec<u8>, usize>,
special_tokens_encoder: HashMap<String, usize>,
decoder: HashMap<usize, Vec<u8>>,
special_tokens_decoder: HashMap<usize, Vec<u8>>,
regex_tls: Vec<Regex>,
special_regex_tls: Vec<Regex>,
sorted_token_bytes: Vec<Vec<u8>>,
}
impl CoreBPE {
fn _get_tl_regex(&self) -> &Regex {
// See performance notes above for what this is about
// It's also a little janky, please make a better version of it!
// However, it's nice that this doesn't leak memory to short-lived threads
&self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
}
fn _get_tl_special_regex(&self) -> &Regex {
&self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
}
fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
let mut ret = Vec::with_capacity(tokens.len() * 2);
for token in tokens {
let token_bytes = self
.decoder
.get(token)
.unwrap_or_else(|| &self.special_tokens_decoder[token]);
ret.extend(token_bytes);
}
ret
}
fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
// This is the core of the encoding logic; the other functions in here
// just make things complicated :-)
let regex = self._get_tl_regex();
let mut ret = vec![];
for mat in regex.find_iter(text) {
let piece = mat.unwrap().as_str().as_bytes();
if let Some(token) = self.encoder.get(piece) {
ret.push(*token);
continue;
}
ret.extend(&byte_pair_encode(piece, &self.encoder));
}
ret
}
fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
let special_regex = self._get_tl_special_regex();
let regex = self._get_tl_regex();
let mut ret = vec![];
let mut start = 0;
let mut last_piece_token_len = 0;
loop {
let mut next_special;
let mut start_find = start;
loop {
// Find the next allowed special token, if any
next_special = special_regex.find_from_pos(text, start_find).unwrap();
match next_special {
Some(m) => {
if allowed_special.contains(&text[m.start()..m.end()]) {
break;
}
start_find = m.start() + 1;
}
None => break,
}
}
let end = next_special.map_or(text.len(), |m| m.start());
// Okay, here we go, compare this logic to _encode_ordinary_native
for mat in regex.find_iter(&text[start..end]) {
let piece = mat.unwrap().as_str().as_bytes();
if let Some(token) = self.encoder.get(piece) {
last_piece_token_len = 1;
ret.push(*token);
continue;
}
let tokens = byte_pair_encode(piece, &self.encoder);
last_piece_token_len = tokens.len();
ret.extend(&tokens);
}
match next_special {
// And here we push the special token
Some(m) => {
let piece = m.as_str();
let token = self.special_tokens_encoder[piece];
ret.push(token);
start = m.end();
last_piece_token_len = 0;
}
None => break,
}
}
// last_piece_token_len is how many tokens came from the last regex split. This is used
// for determining unstable tokens, since you can't merge across (stable) regex splits
(ret, last_piece_token_len)
}
fn _increase_last_piece_token_len(
&self,
tokens: Vec<usize>,
mut last_piece_token_len: usize,
) -> (Vec<usize>, usize) {
// Unfortunately, the locations where our regex splits can be unstable.
// For the purposes of determining unstable tokens, unstable regex splitting
// is only a problem if a split that was present disappears, since this can
// lead to merging of tokens otherwise thought to be stable.
// cl100k_base makes our life hard by including the \s*[\r\n]+
// pattern. This can e.g. cause "\n" + " " to become "\n \n".
// Here is a quick and dirty fix:
{
let token_is_all_space = |token| {
self.decoder
.get(token)
.map(|token_bytes| {
token_bytes
.iter()
.rev()
.all(|&b| [b' ', b'\n', b'\t'].contains(&b))
})
.unwrap_or(false)
};
if last_piece_token_len > 0
&& token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
{
while (last_piece_token_len < tokens.len())
&& token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
{
last_piece_token_len += 1;
}
}
}
debug_assert!(last_piece_token_len <= tokens.len());
(tokens, last_piece_token_len)
}
fn _encode_unstable_native(
&self,
text: &str,
allowed_special: &HashSet<&str>,
) -> (Vec<usize>, HashSet<Vec<usize>>) {
let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
if last_piece_token_len == 0 {
// If last_piece_token_len is zero, the last token was a special token and we have
// no unstable bytes
return (tokens, HashSet::new());
}
let (mut tokens, last_piece_token_len) =
self._increase_last_piece_token_len(tokens, last_piece_token_len);
let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
tokens.truncate(tokens.len() - last_piece_token_len);
// TODO: we should try harder to find additional stable tokens
// This would reduce the amount of retokenising when determining completions
// Refer to the logic in an older version of this file
let mut completions = HashSet::new();
if unstable_bytes.is_empty() {
return (tokens, completions);
}
// This is the easy bit. Just find all single tokens that start with unstable_bytes
// (including tokens that exactly match unstable_bytes)
// Separating this from the loop below helps with performance in a common case.
let mut point = self
.sorted_token_bytes
.partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(&unstable_bytes)
{
completions.insert(vec![
self.encoder[self.sorted_token_bytes[point].as_slice()],
]);
point += 1;
}
// Now apply even more brute force. At every (other) possible position for the straddling
// token, concatenate additional bytes from that token (if any) to unstable_bytes,
// and retokenise the whole thing and see what we get.
for i in 1..unstable_bytes.len() {
let prefix = &unstable_bytes[..i];
let suffix = &unstable_bytes[i..];
let mut point = self
.sorted_token_bytes
.partition_point(|x| x.as_slice() < suffix);
// TODO: Perf optimisation if suffix starts with " "?
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(suffix)
{
let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
let encoded = match std::str::from_utf8(&possibility) {
// Morally, this is byte_pair_encode(&possibility, &self.encoder)
// But we might have introduced a regex split which would prevent merges.
// (particularly possible in the presence of unstable regex splits)
// So convert to UTF-8 and do regex splitting.
// E.g. with cl100k_base " !" gets split to " " + " !",
// but byte_pair_encode(" !") != byte_pair_encode(" ")
Ok(s) => self._encode_ordinary_native(s),
// Technically, whether or not this arm is correct depends on whether there
// would be a regex split before the UTF-8 truncation point.
// Probably niche enough that no one will ever notice (after all, people didn't
// notice all the big holes in the previous unstable token implementation)
Err(_) => byte_pair_encode(&possibility, &self.encoder),
// Something like the following is intriguing but incorrect:
// Err(e) => self._encode_ordinary_native(unsafe {
// std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
// }),
};
let mut seq = Vec::new();
let mut seq_len = 0;
for token in encoded {
seq.push(token);
seq_len += self.decoder[&token].len();
if seq_len >= unstable_bytes.len() {
break;
}
}
completions.insert(seq);
point += 1;
}
}
// This is also not straightforward. While we generally assume that regex splits are stable,
// unfortunately, they are not. That is, if adding bytes were to make a split appear in
// unstable_bytes, this could make tokens possible which our logic would otherwise think
// would be merged.
// For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
// develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
// Here is a quick and dirty fix:
// This isn't right if we ever remove \s+(?!\S)
if unstable_bytes.len() > 1 {
let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
if unstable_bytes.len() - last_decoded.1 > 0
&& last_decoded.0.map_or(false, |c| c.is_whitespace())
{
let mut reencoded = byte_pair_encode(
&unstable_bytes[..unstable_bytes.len() - last_decoded.1],
&self.encoder,
);
reencoded.extend(byte_pair_encode(
&unstable_bytes[unstable_bytes.len() - last_decoded.1..],
&self.encoder,
));
completions.insert(reencoded);
}
}
(tokens, completions)
}
}
#[pymethods]
impl CoreBPE {
#[new]
fn new(
encoder: HashMap<Vec<u8>, usize>,
special_tokens_encoder: HashMap<String, usize>,
pattern: &str,
) -> PyResult<Self> {
let regex = Regex::new(pattern)
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
let special_regex = {
let _parts = special_tokens_encoder
.keys()
.map(|s| fancy_regex::escape(s))
.collect::<Vec<_>>();
Regex::new(&_parts.join("|"))
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
};
let decoder: HashMap<usize, Vec<u8>> =
encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
assert!(encoder.len() == decoder.len());
let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
.iter()
.map(|(k, v)| (*v, k.as_bytes().to_vec()))
.collect();
// Clone because I don't know how to tell Rust I'm not going to change the map
let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
sorted_token_bytes.sort();
Ok(CoreBPE {
encoder,
special_tokens_encoder,
decoder,
special_tokens_decoder,
regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
special_regex_tls: (0..MAX_NUM_THREADS)
.map(|_| special_regex.clone())
.collect(),
sorted_token_bytes,
})
}
// ====================
// Encoding
// ====================
fn encode_ordinary(&self, py: Python, text: &str) -> Vec<usize> {
py.allow_threads(|| self._encode_ordinary_native(text))
}
fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
py.allow_threads(|| self._encode_native(text, &allowed_special).0)
}
fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
py.allow_threads(|| {
match std::str::from_utf8(bytes) {
Ok(text) => self._encode_ordinary_native(text),
Err(e) => {
let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new());
let (mut tokens, last_piece_token_len) =
self._increase_last_piece_token_len(tokens, last_piece_token_len);
if !tokens.is_empty() && last_piece_token_len > 0 {
// Lop off the tokens from the last piece and run BPE on the remaining bytes
// Somewhat niche, but this may not be correct if we'd have had a regex
// split between the valid UTF-8 and the invalid bytes, which is why this
// method is private
let mut unstable_bytes =
self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
tokens.truncate(tokens.len() - last_piece_token_len);
tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
}
tokens
}
}
})
}
fn encode_with_unstable(
&self,
py: Python,
text: &str,
allowed_special: HashSet<&str>,
) -> Py<PyTuple> {
let (tokens, completions) =
py.allow_threads(|| self._encode_unstable_native(text, &allowed_special));
let py_completions =
PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..])));
(tokens, py_completions).into_py(py)
}
fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> {
if let Some(token) = self.encoder.get(piece).copied() {
return Ok(token);
}
if let Ok(piece_str) = std::str::from_utf8(piece) {
if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
return Ok(token);
}
}
Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
}
fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
if let Some(token) = self.encoder.get(piece) {
return vec![*token];
}
byte_pair_encode(piece, &self.encoder)
}
// ====================
// Decoding
// ====================
fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> {
let bytes = py.allow_threads(|| self._decode_native(&tokens));
PyBytes::new(py, &bytes).into()
}
fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> {
if let Some(bytes) = self.decoder.get(&token) {
return Ok(PyBytes::new(py, bytes).into());
}
if let Some(bytes) = self.special_tokens_decoder.get(&token) {
return Ok(PyBytes::new(py, bytes).into());
}
Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
}
// ====================
// Miscellaneous
// ====================
fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
self.sorted_token_bytes
.iter()
.map(|x| PyBytes::new(py, x).into())
.collect()
}
}
#[pymodule]
fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<CoreBPE>()?;
Ok(())
}
#[cfg(test)]
mod tests {
use rustc_hash::FxHashMap as HashMap;
use crate::byte_pair_split;
#[test]
fn very_simple_test() {
let mut ranks = HashMap::default();
ranks.insert(b"ab".to_vec(), 1);
ranks.insert(b"cd".to_vec(), 2);
let res = byte_pair_split(b"abcd", &ranks);
assert_eq!(res, vec![b"ab", b"cd"]);
}
}

3
tiktoken/__init__.py Normal file

@ -0,0 +1,3 @@
from .core import Encoding as Encoding
from .registry import get_encoding as get_encoding
from .registry import list_encoding_names as list_encoding_names

310
tiktoken/core.py Normal file

@ -0,0 +1,310 @@
import functools
from concurrent.futures import ThreadPoolExecutor
from typing import AbstractSet, Collection, Literal, NoReturn, Optional, Union
import regex
from tiktoken import _tiktoken
class Encoding:
def __init__(
self,
name: str,
*,
pat_str: str,
mergeable_ranks: dict[bytes, int],
special_tokens: dict[str, int],
explicit_n_vocab: Optional[int] = None,
):
self.name = name
self._pat_str = pat_str
self._mergeable_ranks = mergeable_ranks
self._special_tokens = special_tokens
self.max_token_value = max(
max(mergeable_ranks.values()), max(special_tokens.values(), default=0)
)
if explicit_n_vocab:
assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab
assert self.max_token_value == explicit_n_vocab - 1
self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str)
def __repr__(self) -> str:
return f"<Encoding {self.name!r}>"
# ====================
# Encoding
# ====================
def encode_ordinary(self, text: str) -> list[int]:
"""Encodes a string into tokens, ignoring special tokens.
This is equivalent to `encode(text, disallowed_special=())` (but slightly faster).
```
>>> enc.encode_ordinary("hello world")
[31373, 995]
"""
return self._core_bpe.encode_ordinary(text)
def encode(
self,
text: str,
*,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[int]:
"""Encodes a string into tokens.
Special tokens are artificial tokens used to unlock capabilities from a model,
such as fill-in-the-middle. So we want to be careful about accidentally encoding special
tokens, since they can be used to trick a model into doing something we don't want it to do.
Hence, by default, encode will raise an error if it encounters text that corresponds
to a special token. This can be controlled on a per-token level using the `allowed_special`
and `disallowed_special` parameters. In particular:
- Setting `disallowed_special` to () will prevent this function from raising errors and
cause all text corresponding to special tokens to be encoded as natural text.
- Setting `allowed_special` to "all" will cause this function to treat all text corresponding
to special tokens as special tokens.
```
>>> enc.encode("hello world")
[31373, 995]
>>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
[50256]
>>> enc.encode("<|endoftext|>", allowed_special="all")
[50256]
>>> enc.encode("<|endoftext|>")
# Raises ValueError
>>> enc.encode("<|endoftext|>", disallowed_special=())
[27, 91, 437, 1659, 5239, 91, 29]
```
"""
if allowed_special == "all":
allowed_special = self.special_tokens_set
if disallowed_special == "all":
disallowed_special = self.special_tokens_set - allowed_special
if disallowed_special:
if not isinstance(disallowed_special, frozenset):
disallowed_special = frozenset(disallowed_special)
if match := _special_token_regex(disallowed_special).search(text):
raise_disallowed_special_token(match.group())
return self._core_bpe.encode(text, allowed_special)
def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]:
"""Encodes a list of strings into tokens, in parallel, ignoring special tokens.
This is equivalent to `encode_batch(text, disallowed_special=())` (but slightly faster).
```
>>> enc.encode_ordinary_batch(["hello world", "goodbye world"])
[[31373, 995], [11274, 16390, 995]]
```
"""
encoder = functools.partial(self.encode_ordinary)
with ThreadPoolExecutor(num_threads) as e:
return list(e.map(encoder, text))
def encode_batch(
self,
text: list[str],
*,
num_threads: int = 8,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[list[int]]:
"""Encodes a list of strings into tokens, in parallel.
See `encode` for more details on `allowed_special` and `disallowed_special`.
```
>>> enc.encode_batch(["hello world", "goodbye world"])
[[31373, 995], [11274, 16390, 995]]
```
"""
if allowed_special == "all":
allowed_special = self.special_tokens_set
if disallowed_special == "all":
disallowed_special = self.special_tokens_set - allowed_special
if not isinstance(disallowed_special, frozenset):
disallowed_special = frozenset(disallowed_special)
encoder = functools.partial(
self.encode, allowed_special=allowed_special, disallowed_special=disallowed_special
)
with ThreadPoolExecutor(num_threads) as e:
return list(e.map(encoder, text))
def encode_with_unstable(
self,
text: str,
*,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> tuple[list[int], list[list[int]]]:
"""Encodes a string into stable tokens and possible completion sequences.
Note that the stable tokens will only represent a substring of `text`.
See `encode` for more details on `allowed_special` and `disallowed_special`.
```
>>> enc.encode_with_unstable("hello fanta")
([31373], [(277, 4910), (5113, 265), ..., (8842,)])
>>> text = "..."
>>> stable_tokens, completions = enc.encode_with_unstable(text)
>>> assert text.encode().startswith(enc.decode_bytes(stable_tokens))
>>> assert all(enc.decode_bytes(stable_tokens + seq).startswith(text.encode()) for seq in completions)
```
"""
if allowed_special == "all":
allowed_special = self.special_tokens_set
if disallowed_special == "all":
disallowed_special = self.special_tokens_set - allowed_special
if disallowed_special:
if not isinstance(disallowed_special, frozenset):
disallowed_special = frozenset(disallowed_special)
if match := _special_token_regex(disallowed_special).search(text):
raise_disallowed_special_token(match.group())
return self._core_bpe.encode_with_unstable(text, allowed_special)
def encode_single_token(self, text_or_bytes: Union[str, bytes]) -> int:
"""Encodes text corresponding to a single token to its token value.
NOTE: this will encode all special tokens.
Raises `KeyError` if the token is not in the vocabulary.
```
>>> enc.encode_single_token("hello")
31373
```
"""
if isinstance(text_or_bytes, str):
text_or_bytes = text_or_bytes.encode("utf-8")
return self._core_bpe.encode_single_token(text_or_bytes)
# ====================
# Decoding
# ====================
def decode_bytes(self, tokens: list[int]) -> bytes:
"""Decodes a list of tokens into bytes.
```
>>> enc.decode_bytes([31373, 995])
b'hello world'
```
"""
return self._core_bpe.decode_bytes(tokens)
def decode(self, tokens: list[int], errors: str = "replace") -> str:
"""Decodes a list of tokens into a string.
WARNING: the default behaviour of this function is lossy, since decoded bytes are not
guaranteed to be valid UTF-8. You can control this behaviour using the `errors` parameter,
for instance, setting `errors="strict"`.
```
>>> enc.decode([31373, 995])
'hello world'
```
"""
return self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)
def decode_single_token_bytes(self, token: int) -> bytes:
"""Decodes a token into bytes.
NOTE: this will decode all special tokens.
Raises `KeyError` if the token is not in the vocabulary.
```
>>> enc.decode_single_token_bytes(31373)
b'hello'
```
"""
return self._core_bpe.decode_single_token_bytes(token)
def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
"""Decodes a list of tokens into a list of bytes.
Useful for visualising tokenisation.
```
>>> enc.decode_tokens_bytes([31373, 995])
[b'hello', b' world']
```
"""
return [self.decode_single_token_bytes(token) for token in tokens]
# ====================
# Miscellaneous
# ====================
def token_byte_values(self) -> list[bytes]:
"""Returns the list of all token byte values."""
return self._core_bpe.token_byte_values()
@property
def eot_token(self) -> int:
return self._special_tokens["<|endoftext|>"]
@functools.cached_property
def special_tokens_set(self) -> set[str]:
return set(self._special_tokens.keys())
@property
def n_vocab(self) -> int:
"""For backwards compatibility. Prefer to use `enc.max_token_value + 1`."""
return self.max_token_value + 1
# ====================
# Private
# ====================
def _encode_single_piece(self, text_or_bytes: Union[str, bytes]) -> list[int]:
"""Encodes text corresponding to bytes without a regex split.
NOTE: this will not encode any special tokens.
```
>>> enc.encode_single_piece("helloqqqq")
[31373, 38227, 38227]
```
"""
if isinstance(text_or_bytes, str):
text_or_bytes = text_or_bytes.encode("utf-8")
return self._core_bpe.encode_single_piece(text_or_bytes)
def _encode_only_native_bpe(self, text: str) -> list[int]:
"""Encodes a string into tokens, but does regex splitting in Python."""
_unused_pat = regex.compile(self._pat_str)
ret = []
for piece in regex.findall(_unused_pat, text):
ret.extend(self._core_bpe.encode_single_piece(piece))
return ret
def _encode_bytes(self, text: bytes) -> list[int]:
return self._core_bpe._encode_bytes(text)
@functools.lru_cache(maxsize=128)
def _special_token_regex(tokens: frozenset[str]) -> "regex.Pattern[str]":
inner = "|".join(regex.escape(token) for token in tokens)
return regex.compile(f"({inner})")
def raise_disallowed_special_token(token: str) -> NoReturn:
raise ValueError(
f"Encountered text corresponding to disallowed special token {token!r}.\n"
"If you want this text to be encoded as a special token, "
f"pass it to `allowed_special`, e.g. `allowed_special={{{token!r}, ...}}`.\n"
f"If you want this text to be encoded as normal text, disable the check for this token "
f"by passing `disallowed_special=(enc.special_tokens_set - {{{token!r}}})`.\n"
"To disable this check for all special tokens, pass `disallowed_special=()`.\n"
)

97
tiktoken/load.py Normal file

@ -0,0 +1,97 @@
import base64
import hashlib
import json
import os
import uuid
import blobfile
def read_file_cached(blobpath: str) -> bytes:
if "TIKTOKEN_CACHE_DIR" in os.environ:
cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
elif "DATA_GYM_CACHE_DIR" in os.environ:
cache_dir = os.environ["DATA_GYM_CACHE_DIR"]
else:
cache_dir = "/tmp/data-gym-cache"
if cache_dir == "":
# disable caching
with blobfile.BlobFile(blobpath, "rb") as f:
return f.read()
cache_key = hashlib.sha1(blobpath.encode()).hexdigest()
cache_path = os.path.join(cache_dir, cache_key)
if os.path.exists(cache_path):
with open(cache_path, "rb") as f:
return f.read()
with blobfile.BlobFile(blobpath, "rb") as f:
contents = f.read()
os.makedirs(cache_dir, exist_ok=True)
tmp_filename = cache_path + "." + str(uuid.uuid4()) + ".tmp"
with open(tmp_filename, "wb") as f:
f.write(contents)
os.rename(tmp_filename, cache_path)
return contents
def data_gym_to_mergeable_bpe_ranks(
vocab_bpe_file: str, encoder_json_file: str
) -> dict[bytes, int]:
# NB: do not add caching to this function
rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
data_gym_byte_to_byte = {chr(b): b for b in rank_to_intbyte}
n = 0
for b in range(2**8):
if b not in rank_to_intbyte:
rank_to_intbyte.append(b)
data_gym_byte_to_byte[chr(2**8 + n)] = b
n += 1
assert len(rank_to_intbyte) == 2**8
# vocab_bpe contains the merges along with associated ranks
vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode()
bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]
def decode_data_gym(value: str) -> bytes:
return bytes(data_gym_byte_to_byte[b] for b in value)
# add the single byte tokens
bpe_ranks = {bytes([b]): i for i, b in enumerate(rank_to_intbyte)}
# add the merged tokens
n = len(bpe_ranks)
for first, second in bpe_merges:
bpe_ranks[decode_data_gym(first) + decode_data_gym(second)] = n
n += 1
# check that the encoder file matches the merges file
# this sanity check is important since tiktoken assumes that ranks are ordered the same
# as merge priority
encoder_json = json.loads(read_file_cached(encoder_json_file))
encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
# drop these two special tokens if present, since they're not mergeable bpe tokens
encoder_json_loaded.pop(b"<|endoftext|>", None)
encoder_json_loaded.pop(b"<|startoftext|>", None)
assert bpe_ranks == encoder_json_loaded
return bpe_ranks
def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
with blobfile.BlobFile(tiktoken_bpe_file, "wb") as f:
for token, rank in sorted(bpe_ranks.items(), key=lambda x: x[1]):
f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")
def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
# NB: do not add caching to this function
contents = read_file_cached(tiktoken_bpe_file)
return {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
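For concreteness, a small round trip through these helpers (the local path below is made up; `blobfile` also accepts local filesystem paths) might look like:
```python
from tiktoken.load import dump_tiktoken_bpe, load_tiktoken_bpe

# Toy rank table: dump_tiktoken_bpe writes one "<base64 token> <rank>" pair per line, sorted by rank.
ranks = {b"a": 0, b"b": 1, b"ab": 2}
dump_tiktoken_bpe(ranks, "/tmp/toy.tiktoken")
assert load_tiktoken_bpe("/tmp/toy.tiktoken") == ranks
```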

71
tiktoken/registry.py Normal file

@ -0,0 +1,71 @@
import importlib
import pkgutil
import threading
from typing import Any, Callable, Optional
import tiktoken_ext
from tiktoken.core import Encoding
_lock = threading.RLock()
ENCODINGS: dict[str, Encoding] = {}
ENCODING_CONSTRUCTORS: Optional[dict[str, Callable[[], dict[str, Any]]]] = None
def _find_constructors() -> None:
global ENCODING_CONSTRUCTORS
with _lock:
if ENCODING_CONSTRUCTORS is not None:
return
ENCODING_CONSTRUCTORS = {}
# tiktoken_ext is a namespace package
# submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes
# - we use namespace package pattern so `pkgutil.iter_modules` is fast
# - it's a separate top-level package because namespace subpackages of non-namespace
# packages don't quite do what you want with editable installs
plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".")
for _, mod_name, _ in plugin_mods:
mod = importlib.import_module(mod_name)
try:
constructors = mod.ENCODING_CONSTRUCTORS
except AttributeError as e:
raise ValueError(
f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS"
) from e
for enc_name, constructor in constructors.items():
if enc_name in ENCODING_CONSTRUCTORS:
raise ValueError(
f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}"
)
ENCODING_CONSTRUCTORS[enc_name] = constructor
def get_encoding(encoding_name: str) -> Encoding:
if encoding_name in ENCODINGS:
return ENCODINGS[encoding_name]
with _lock:
if encoding_name in ENCODINGS:
return ENCODINGS[encoding_name]
if ENCODING_CONSTRUCTORS is None:
_find_constructors()
assert ENCODING_CONSTRUCTORS is not None
if encoding_name not in ENCODING_CONSTRUCTORS:
raise ValueError(f"Unknown encoding {encoding_name}")
constructor = ENCODING_CONSTRUCTORS[encoding_name]
enc = Encoding(**constructor())
ENCODINGS[encoding_name] = enc
return enc
def list_encoding_names() -> list[str]:
with _lock:
if ENCODING_CONSTRUCTORS is None:
_find_constructors()
assert ENCODING_CONSTRUCTORS is not None
return list(ENCODING_CONSTRUCTORS)

41
tiktoken_ext/openai_public.py Normal file

@ -0,0 +1,41 @@
from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe
ENDOFTEXT = "<|endoftext|>"
FIM_PREFIX = "<|fim_prefix|>"
FIM_MIDDLE = "<|fim_middle|>"
FIM_SUFFIX = "<|fim_suffix|>"
ENDOFPROMPT = "<|endofprompt|>"
def gpt2():
mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
vocab_bpe_file="az://openaipublic/gpt-2/encodings/main/vocab.bpe",
encoder_json_file="az://openaipublic/gpt-2/encodings/main/encoder.json",
)
return {
"name": "gpt2",
"explicit_n_vocab": 50257,
"pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"mergeable_ranks": mergeable_ranks,
"special_tokens": {"<|endoftext|>": 50256},
}
def cl100k_base():
mergeable_ranks = load_tiktoken_bpe("az://openaipublic/encodings/cl100k_base.tiktoken")
special_tokens = {
ENDOFTEXT: 100257,
FIM_PREFIX: 100258,
FIM_MIDDLE: 100259,
FIM_SUFFIX: 100260,
ENDOFPROMPT: 100276,
}
return {
"name": "cl100k_base",
"pat_str": r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
"mergeable_ranks": mergeable_ranks,
"special_tokens": special_tokens,
}
ENCODING_CONSTRUCTORS = {"gpt2": gpt2, "cl100k_base": cl100k_base}
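Putting the registry and these constructors together, encodings are resolved lazily by name (a sketch; it assumes the encoding files referenced above are reachable):
```python
import tiktoken

print(tiktoken.list_encoding_names())    # e.g. ['gpt2', 'cl100k_base']
enc = tiktoken.get_encoding("cl100k_base")
print(enc.encode("hello world"))
```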