[tiktoken] hello world
commit a1a9f16826
.gitignore (vendored, new file)
@@ -0,0 +1,42 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Environments
.env
.venv

# Tools
.mypy_cache
.coverage
htmlcov

# General
.DS_Store

Cargo.lock
target/
Cargo.toml (new file)
@@ -0,0 +1,21 @@
[package]
name = "tiktoken"
version = "0.1.0"
edition = "2021"
rust-version = "1.57.0"

[lib]
name = "_tiktoken"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.17.3", features = ["extension-module"] }

# tiktoken dependencies
fancy-regex = "0.10.0"
regex = "1.7.0"
rustc-hash = "1.1.0"
bstr = "1.0.1"

[profile.release]
incremental = true
LICENSE (new file)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 OpenAI, Shantanu Jain

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
MANIFEST.in (new file)
@@ -0,0 +1,5 @@
include *.svg
include *.toml
include Makefile
recursive-include scripts *.py
recursive-include src *.rs
Makefile (new file)
@@ -0,0 +1,49 @@
PROJECT := tiktoken

.PHONY: default
default: editable_install

.PHONY: install_rust
install_rust:
	which cargo >/dev/null || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain 1.62

.PHONY: clean
clean:
	cargo clean
	pip uninstall -y $(PROJECT)
	find . | grep -E '__pycache__|\.pyc' | xargs rm -rf
	find . | grep -E '\.so' | xargs rm -rf
	rm -rf dist/ build/
	rm -rf $(PROJECT).egg-info/

.PHONY: format
format:
	@ which black >/dev/null || python3 -m pip install black
	@ which isort >/dev/null || python3 -m pip install isort
	cargo fmt -- --config group_imports=StdExternalCrate
	black --line-length 100 --skip-magic-trailing-comma --quiet .
	isort --line-length 100 --profile black --quiet .


.PHONY: format_check
format_check:
	@ which black >/dev/null || python3 -m pip install black
	@ which isort >/dev/null || python3 -m pip install isort
	cargo fmt --check -- --config group_imports=StdExternalCrate
	black --check --line-length 100 --skip-magic-trailing-comma --quiet .
	isort --check --line-length 100 --profile black --quiet .

.PHONY: lint
lint:
	cargo clippy --all -- -D warnings
	@ which flake8 >/dev/null || python3 -m pip install flake8==5 flake8-bugbear==22.9.11
	flake8 --ignore=E203,E501,W503,E731 --per-file-ignores="$(PROJECT)/__init__.py:F401 setup.py:E402" --exclude=build .

.PHONY: editable_install
editable_install:
	@ if [ -f $(PROJECT).egg-info ]; then \
		pip install --disable-pip-version-check --progress-bar=off setuptools wheel setuptools-rust ; \
		pip install --disable-pip-version-check --no-build-isolation -e . ; \
	else \
		pip install --disable-pip-version-check --no-deps --no-build-isolation --ignore-installed -e . ; \
	fi
README.md (new file)
@@ -0,0 +1,28 @@
# ⏳ tiktoken

tiktoken is a fast tokeniser.

```python
import tiktoken
enc = tiktoken.get_encoding("gpt2")
print(enc.encode("hello world"))
```

The open source version of `tiktoken` can be installed from PyPI:
```
pip install tiktoken
```

The tokeniser API is documented in `tiktoken/core.py`.


## Performance

`tiktoken` is between 3-6x faster than huggingface's tokeniser:

![image](perf.svg)

Performance measured on 1GB of text using the GPT-2 tokeniser, using `GPT2TokenizerFast` from
`tokenizers==0.13.2` and `transformers==4.24.0`.

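A quick round-trip sketch building on the README example above (an editor's illustration, not part of the commit; `decode` and the token ids are taken from the docstrings in `tiktoken/core.py` later in this diff):

```python
# Minimal round-trip sketch using the API shown in the README and documented in tiktoken/core.py.
import tiktoken

enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode("hello world")  # [31373, 995] for the GPT-2 vocabulary
text = enc.decode(tokens)           # default decoding is lossy only for invalid UTF-8
assert text == "hello world"
```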
perf.svg (new file)
@@ -0,0 +1,373 @@
[SVG source of the performance chart embedded in README.md (373 lines, 15 KiB): a bar chart of throughput in MB/s against thread count (1, 2, 4, 8, 16, 32, 64), comparing tiktoken (blue bars) with huggingface (green bars); tiktoken is faster at every thread count shown.]
pyproject.toml (new file)
@@ -0,0 +1,8 @@
[project]
name = "tiktoken"
dependencies = ["blobfile>=2", "regex>=2022.1.18"]
dynamic = ["version"]

[build-system]
requires = ["setuptools", "wheel", "setuptools-rust"]

scripts/benchmark.py (new file)
@@ -0,0 +1,39 @@
import base64
import functools
import gzip
import json
import os
import random
import time
from typing import Any, cast

import blobfile

import tiktoken


def benchmark_batch(documents: list[str]) -> None:
    num_threads = int(os.environ["RAYON_NUM_THREADS"])
    num_bytes = sum(map(len, map(str.encode, documents)))
    print(f"num_threads: {num_threads}, num_bytes: {num_bytes}")

    enc = tiktoken.get_encoding("gpt2")
    enc.encode("warmup")

    start = time.perf_counter_ns()
    enc.encode_ordinary_batch(documents, num_threads=num_threads)
    end = time.perf_counter_ns()
    print(f"tiktoken \t{num_bytes / (end - start) * 1e9} bytes / s")

    import transformers

    hf_enc = cast(Any, transformers).GPT2TokenizerFast.from_pretrained("gpt2")
    hf_enc.model_max_length = 1e30  # silence!
    hf_enc.encode("warmup")

    start = time.perf_counter_ns()
    hf_enc(documents)
    end = time.perf_counter_ns()
    print(f"huggingface \t{num_bytes / (end - start) * 1e9} bytes / s")

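The committed script only defines `benchmark_batch`; a hypothetical driver (not part of the commit) could look like the following, with the thread count supplied through the `RAYON_NUM_THREADS` environment variable that the function reads:

```python
# Hypothetical driver for scripts/benchmark.py -- not part of this commit.
# Assumes a local directory of .txt files to tokenise; adjust the path as needed.
import os
from pathlib import Path

os.environ.setdefault("RAYON_NUM_THREADS", "8")  # read by benchmark_batch

documents = [p.read_text() for p in Path("data").glob("*.txt")]  # hypothetical corpus
benchmark_batch(documents)
```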
scripts/redact.py (new file)
@@ -0,0 +1,65 @@
import argparse
import re
import subprocess
from pathlib import Path


def redact_file(path: Path, dry_run: bool) -> None:
    if not path.exists() or path.is_dir():
        return

    text = path.read_text()

    first_line = text.splitlines()[0]
    if "redact" in first_line:
        if not dry_run:
            path.unlink()
        print(f"Deleted {path}")
        return

    pattern = "|".join(
        re.escape(x)
        for x in [
            "# ===== redact-beg =====\n",
            "# ===== redact-end =====\n",
            "<!--- redact-beg -->\n",
            "<!--- redact-end -->\n",
        ]
    )

    if re.search(pattern, text):
        redacted_text = "".join(re.split(pattern, text)[::2])
        if not dry_run:
            path.write_text(redacted_text)
        print(f"Redacted {path}")
        return

    print(f"Skipped {path}")


def redact(dry_run: bool) -> None:
    tiktoken_root = Path(__file__).parent.parent
    assert tiktoken_root.name == "tiktoken"
    assert (tiktoken_root / "pyproject.toml").exists()

    try:
        output = subprocess.check_output(["git", "ls-files"], cwd=tiktoken_root, text=True)
        paths = [Path(p) for p in output.splitlines()]
    except subprocess.CalledProcessError:
        paths = list(tiktoken_root.glob("**/*"))

    for path in paths:
        redact_file(path, dry_run=dry_run)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--dry-run", type=lambda x: not x or x[0].lower() != "f", default=True)
    args = parser.parse_args()
    redact(args.dry_run)
    if args.dry_run:
        print("Dry run, use --dry-run=false to actually redact files")


if __name__ == "__main__":
    main()
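As the argparse setup above shows, the script dry-runs by default and only rewrites files when invoked with `--dry-run=false`. A programmatic sketch of the same flow (hypothetical, for illustration only):

```python
# Hypothetical use of the helpers defined in scripts/redact.py.
from pathlib import Path

redact(dry_run=True)                           # walk the repo, only report what would change
redact_file(Path("README.md"), dry_run=False)  # rewrite a single file in place
```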
setup.py (new file)
@@ -0,0 +1,23 @@
from setuptools import setup
from setuptools_rust import Binding, RustExtension

public = True

if public:
    version = "0.1"

setup(
    name="tiktoken",
    version=version,
    rust_extensions=[
        RustExtension(
            "tiktoken._tiktoken",
            binding=Binding.PyO3,
            # Between our use of editable installs and wanting to use Rust for performance sensitive
            # code, it makes sense to just always use --release
            debug=False,
        )
    ],
    packages=["tiktoken", "tiktoken_ext"],
    zip_safe=False,
)
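For orientation (an editor's sketch, not part of the commit): the compiled `tiktoken._tiktoken` extension built here exposes `CoreBPE`, which the `Encoding` class in `tiktoken/core.py` below wraps. A toy `Encoding` can be constructed directly from a made-up rank table and split pattern:

```python
# Toy illustration of the layering: Encoding -> _tiktoken.CoreBPE (built from src/lib.rs).
# The ranks, pattern, and special token id below are invented for illustration.
import tiktoken

enc = tiktoken.Encoding(
    name="toy",
    pat_str=r"\S+|\s+",                               # hypothetical split pattern
    mergeable_ranks={b"a": 0, b"b": 1, b"ab": 2},
    special_tokens={"<|endoftext|>": 3},
)
assert enc.encode_single_token(b"ab") == 2
```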
src/lib.rs (new file)
@@ -0,0 +1,559 @@
// This check is new and seems buggy (possibly with PyO3 interaction)
#![allow(clippy::borrow_deref_ref)]

use std::collections::HashSet;
use std::thread;

use fancy_regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList, PyTuple};
use pyo3::PyResult;
use rustc_hash::FxHashMap as HashMap;

fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
    let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();

    // If you have n parts and m merges, this does O(mn) work
    // We could do something with a heap and do O(m log n) work

    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
    // currently do, this is equivalent. An easy way to break this would be to decouple
    // merge priority from token index or to prevent specific token merges.
    loop {
        if parts.len() == 1 {
            break;
        }
        let mut min_rank: Option<(usize, usize)> = None;
        for i in 0..parts.len() - 1 {
            let rank = if let Some(r) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) {
                *r
            } else {
                continue;
            };
            if min_rank.is_none() || rank < min_rank.unwrap().0 {
                min_rank = Some((rank, i));
            }
        }
        if let Some((_, i)) = min_rank {
            parts[i] = parts[i].start..parts[i + 1].end;
            parts.remove(i + 1);
        } else {
            break;
        }
    }
    parts
}

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
    if piece.len() == 1 {
        return vec![ranks[piece]];
    }
    _byte_pair_merge(piece, ranks)
        .iter()
        .map(|p| ranks[&piece[p.start..p.end]])
        .collect()
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
    if piece.len() == 1 {
        return vec![piece];
    }
    _byte_pair_merge(piece, ranks)
        .iter()
        .map(|p| &piece[p.start..p.end])
        .collect()
}

// Various performance notes:
//
// Regex
// =====
// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
// the usual regex we use.
//
// However, given that we're using a regex parse-able by `regex`, there isn't much difference
// between using the `regex` crate and using the `fancy_regex` crate.
//
// There is an important interaction between threading, `regex` and `fancy_regex`.
// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
// old `regex`, we don't hit this, because `find_iter` has a different code path.
// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
// each thread.
//
// Threading
// =========
// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
// So goodbye `rayon`! Let thread count etc be in control of our Python users.
//
// Caching
// =======
// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
// Originally, we had one too! Without it, we were only vaguely faster than Python.
// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
// about interior mutability, memory use, or cloning.
//
// Hashing
// =======
// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.

use std::num::NonZeroU64;
pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can layout a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    let x = unsafe {
        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
    };
    u64::from(x) as usize
}

const MAX_NUM_THREADS: usize = 128;
#[pyclass]
struct CoreBPE {
    encoder: HashMap<Vec<u8>, usize>,
    special_tokens_encoder: HashMap<String, usize>,
    decoder: HashMap<usize, Vec<u8>>,
    special_tokens_decoder: HashMap<usize, Vec<u8>>,
    regex_tls: Vec<Regex>,
    special_regex_tls: Vec<Regex>,
    sorted_token_bytes: Vec<Vec<u8>>,
}

impl CoreBPE {
    fn _get_tl_regex(&self) -> &Regex {
        // See performance notes above for what this is about
        // It's also a little janky, please make a better version of it!
        // However, it's nice that this doesn't leak memory to short-lived threads
        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    fn _get_tl_special_regex(&self) -> &Regex {
        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
        let mut ret = Vec::with_capacity(tokens.len() * 2);
        for token in tokens {
            let token_bytes = self
                .decoder
                .get(token)
                .unwrap_or_else(|| &self.special_tokens_decoder[token]);
            ret.extend(token_bytes);
        }
        ret
    }

    fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
        // This is the core of the encoding logic; the other functions in here
        // just make things complicated :-)
        let regex = self._get_tl_regex();
        let mut ret = vec![];
        for mat in regex.find_iter(text) {
            let piece = mat.unwrap().as_str().as_bytes();
            if let Some(token) = self.encoder.get(piece) {
                ret.push(*token);
                continue;
            }
            ret.extend(&byte_pair_encode(piece, &self.encoder));
        }
        ret
    }

    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
        let special_regex = self._get_tl_special_regex();
        let regex = self._get_tl_regex();
        let mut ret = vec![];

        let mut start = 0;
        let mut last_piece_token_len = 0;
        loop {
            let mut next_special;
            let mut start_find = start;
            loop {
                // Find the next allowed special token, if any
                next_special = special_regex.find_from_pos(text, start_find).unwrap();
                match next_special {
                    Some(m) => {
                        if allowed_special.contains(&text[m.start()..m.end()]) {
                            break;
                        }
                        start_find = m.start() + 1;
                    }
                    None => break,
                }
            }
            let end = next_special.map_or(text.len(), |m| m.start());

            // Okay, here we go, compare this logic to _encode_ordinary_native
            for mat in regex.find_iter(&text[start..end]) {
                let piece = mat.unwrap().as_str().as_bytes();
                if let Some(token) = self.encoder.get(piece) {
                    last_piece_token_len = 1;
                    ret.push(*token);
                    continue;
                }
                let tokens = byte_pair_encode(piece, &self.encoder);
                last_piece_token_len = tokens.len();
                ret.extend(&tokens);
            }

            match next_special {
                // And here we push the special token
                Some(m) => {
                    let piece = m.as_str();
                    let token = self.special_tokens_encoder[piece];
                    ret.push(token);
                    start = m.end();
                    last_piece_token_len = 0;
                }
                None => break,
            }
        }

        // last_piece_token_len is how many tokens came from the last regex split. This is used
        // for determining unstable tokens, since you can't merge across (stable) regex splits
        (ret, last_piece_token_len)
    }

    fn _increase_last_piece_token_len(
        &self,
        tokens: Vec<usize>,
        mut last_piece_token_len: usize,
    ) -> (Vec<usize>, usize) {
        // Unfortunately, the locations where our regex splits can be unstable.
        // For the purposes of determining unstable tokens, unstable regex splitting
        // is only a problem if a split that was present disappears, since this can
        // lead to merging of tokens otherwise thought to be stable.
        // cl100k_base makes our life hard by including the \s*[\r\n]+
        // pattern. This can e.g. cause "\n" + " " to become "\n \n".
        // Here is a quick and dirty fix:
        {
            let token_is_all_space = |token| {
                self.decoder
                    .get(token)
                    .map(|token_bytes| {
                        token_bytes
                            .iter()
                            .rev()
                            .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
                    })
                    .unwrap_or(false)
            };
            if last_piece_token_len > 0
                && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
            {
                while (last_piece_token_len < tokens.len())
                    && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
                {
                    last_piece_token_len += 1;
                }
            }
        }
        debug_assert!(last_piece_token_len <= tokens.len());

        (tokens, last_piece_token_len)
    }

    fn _encode_unstable_native(
        &self,
        text: &str,
        allowed_special: &HashSet<&str>,
    ) -> (Vec<usize>, HashSet<Vec<usize>>) {
        let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
        if last_piece_token_len == 0 {
            // If last_piece_token_len is zero, the last token was a special token and we have
            // no unstable bytes
            return (tokens, HashSet::new());
        }
        let (mut tokens, last_piece_token_len) =
            self._increase_last_piece_token_len(tokens, last_piece_token_len);

        let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
        tokens.truncate(tokens.len() - last_piece_token_len);

        // TODO: we should try harder to find additional stable tokens
        // This would reduce the amount of retokenising when determining completions
        // Refer to the logic in an older version of this file

        let mut completions = HashSet::new();
        if unstable_bytes.is_empty() {
            return (tokens, completions);
        }

        // This is the easy bit. Just find all single tokens that start with unstable_bytes
        // (including tokens that exactly match unstable_bytes)
        // Separating this from the loop below helps with performance in a common case.
        let mut point = self
            .sorted_token_bytes
            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
        while point < self.sorted_token_bytes.len()
            && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
        {
            completions.insert(vec![
                self.encoder[self.sorted_token_bytes[point].as_slice()],
            ]);
            point += 1;
        }

        // Now apply even more brute force. At every (other) possible position for the straddling
        // token, concatenate additional bytes from that token (if any) to unstable_bytes,
        // and retokenise the whole thing and see what we get.
        for i in 1..unstable_bytes.len() {
            let prefix = &unstable_bytes[..i];
            let suffix = &unstable_bytes[i..];
            let mut point = self
                .sorted_token_bytes
                .partition_point(|x| x.as_slice() < suffix);
            // TODO: Perf optimisation if suffix starts with " "?
            while point < self.sorted_token_bytes.len()
                && self.sorted_token_bytes[point].starts_with(suffix)
            {
                let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
                let encoded = match std::str::from_utf8(&possibility) {
                    // Morally, this is byte_pair_encode(&possibility, &self.encoder)
                    // But we might have introduced a regex split which would prevent merges.
                    // (particularly possible in the presence of unstable regex splits)
                    // So convert to UTF-8 and do regex splitting.
                    // E.g. with cl100k_base " !" gets split to " " + " !",
                    // but byte_pair_encode(" !") != byte_pair_encode(" ")
                    Ok(s) => self._encode_ordinary_native(s),

                    // Technically, whether or not this arm is correct depends on whether there
                    // would be a regex split before the UTF-8 truncation point.
                    // Probably niche enough that no one will ever notice (after all, people didn't
                    // notice all the big holes in the previous unstable token implementation)
                    Err(_) => byte_pair_encode(&possibility, &self.encoder),
                    // Something like the following is intriguing but incorrect:
                    // Err(e) => self._encode_ordinary_native(unsafe {
                    //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
                    // }),
                };
                let mut seq = Vec::new();
                let mut seq_len = 0;
                for token in encoded {
                    seq.push(token);
                    seq_len += self.decoder[&token].len();
                    if seq_len >= unstable_bytes.len() {
                        break;
                    }
                }
                completions.insert(seq);
                point += 1;
            }
        }

        // This is also not straightforward. While we generally assume that regex splits are stable,
        // unfortunately, they are not. That is, if adding bytes were to make a split appear in
        // unstable_bytes, this could make tokens possible which our logic would otherwise think
        // would be merged.
        // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
        // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
        // Here is a quick and dirty fix:
        // This isn't right if we ever remove \s+(?!\S)
        if unstable_bytes.len() > 1 {
            let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
            if unstable_bytes.len() - last_decoded.1 > 0
                && last_decoded.0.map_or(false, |c| c.is_whitespace())
            {
                let mut reencoded = byte_pair_encode(
                    &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
                    &self.encoder,
                );
                reencoded.extend(byte_pair_encode(
                    &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
                    &self.encoder,
                ));
                completions.insert(reencoded);
            }
        }

        (tokens, completions)
    }
}

#[pymethods]
impl CoreBPE {
    #[new]
    fn new(
        encoder: HashMap<Vec<u8>, usize>,
        special_tokens_encoder: HashMap<String, usize>,
        pattern: &str,
    ) -> PyResult<Self> {
        let regex = Regex::new(pattern)
            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;

        let special_regex = {
            let _parts = special_tokens_encoder
                .keys()
                .map(|s| fancy_regex::escape(s))
                .collect::<Vec<_>>();
            Regex::new(&_parts.join("|"))
                .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
        };

        let decoder: HashMap<usize, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

        assert!(encoder.len() == decoder.len());

        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();

        // Clone because I don't know how to tell Rust I'm not going to change the map
        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort();

        Ok(CoreBPE {
            encoder,
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
            special_regex_tls: (0..MAX_NUM_THREADS)
                .map(|_| special_regex.clone())
                .collect(),
            sorted_token_bytes,
        })
    }

    // ====================
    // Encoding
    // ====================

    fn encode_ordinary(&self, py: Python, text: &str) -> Vec<usize> {
        py.allow_threads(|| self._encode_ordinary_native(text))
    }

    fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
        py.allow_threads(|| self._encode_native(text, &allowed_special).0)
    }

    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
        py.allow_threads(|| {
            match std::str::from_utf8(bytes) {
                Ok(text) => self._encode_ordinary_native(text),
                Err(e) => {
                    let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
                    let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new());
                    let (mut tokens, last_piece_token_len) =
                        self._increase_last_piece_token_len(tokens, last_piece_token_len);
                    if !tokens.is_empty() && last_piece_token_len > 0 {
                        // Lop off the tokens from the last piece and run BPE on the remaining bytes
                        // Somewhat niche, but this may not be correct if we'd have had a regex
                        // split between the valid UTF-8 and the invalid bytes, which is why this
                        // method is private
                        let mut unstable_bytes =
                            self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
                        unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

                        tokens.truncate(tokens.len() - last_piece_token_len);
                        tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
                    }
                    tokens
                }
            }
        })
    }

    fn encode_with_unstable(
        &self,
        py: Python,
        text: &str,
        allowed_special: HashSet<&str>,
    ) -> Py<PyTuple> {
        let (tokens, completions) =
            py.allow_threads(|| self._encode_unstable_native(text, &allowed_special));
        let py_completions =
            PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..])));
        (tokens, py_completions).into_py(py)
    }

    fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> {
        if let Some(token) = self.encoder.get(piece).copied() {
            return Ok(token);
        }
        if let Ok(piece_str) = std::str::from_utf8(piece) {
            if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
                return Ok(token);
            }
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
    }

    fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
        if let Some(token) = self.encoder.get(piece) {
            return vec![*token];
        }
        byte_pair_encode(piece, &self.encoder)
    }

    // ====================
    // Decoding
    // ====================

    fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> {
        let bytes = py.allow_threads(|| self._decode_native(&tokens));
        PyBytes::new(py, &bytes).into()
    }

    fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> {
        if let Some(bytes) = self.decoder.get(&token) {
            return Ok(PyBytes::new(py, bytes).into());
        }
        if let Some(bytes) = self.special_tokens_decoder.get(&token) {
            return Ok(PyBytes::new(py, bytes).into());
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
    }

    // ====================
    // Miscellaneous
    // ====================

    fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
        self.sorted_token_bytes
            .iter()
            .map(|x| PyBytes::new(py, x).into())
            .collect()
    }
}

#[pymodule]
fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<CoreBPE>()?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use rustc_hash::FxHashMap as HashMap;

    use crate::byte_pair_split;

    #[test]
    fn very_simple_test() {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 1);
        ranks.insert(b"cd".to_vec(), 2);

        let res = byte_pair_split(b"abcd", &ranks);
        assert_eq!(res, vec![b"ab", b"cd"]);
    }
}
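A rough Python sketch (editor's illustration with a made-up rank table, not part of the commit) of the greedy merge loop implemented by `_byte_pair_merge` and `byte_pair_encode` above: repeatedly merge the adjacent pair whose concatenated bytes have the lowest rank until no adjacent pair is mergeable, which is the O(mn) behaviour the comment describes.

```python
# Sketch of the greedy lowest-rank byte-pair merge from src/lib.rs, in Python.
def byte_pair_encode(piece: bytes, ranks: dict[bytes, int]) -> list[int]:
    parts = [piece[i : i + 1] for i in range(len(piece))]
    while len(parts) > 1:
        best = None  # (rank, index) of the cheapest adjacent merge
        for i in range(len(parts) - 1):
            rank = ranks.get(parts[i] + parts[i + 1])
            if rank is not None and (best is None or rank < best[0]):
                best = (rank, i)
        if best is None:
            break  # nothing left to merge
        i = best[1]
        parts[i : i + 2] = [parts[i] + parts[i + 1]]
    return [ranks[part] for part in parts]

# With made-up ranks {b"a": 0, b"b": 1, b"c": 2, b"d": 3, b"ab": 4, b"cd": 5}:
# byte_pair_encode(b"abcd", ranks) -> [4, 5], mirroring the very_simple_test above.
```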
tiktoken/__init__.py (new file)
@@ -0,0 +1,3 @@
from .core import Encoding as Encoding
from .registry import get_encoding as get_encoding
from .registry import list_encoding_names as list_encoding_names
tiktoken/core.py (new file)
@@ -0,0 +1,310 @@
|
||||
import functools
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import AbstractSet, Collection, Literal, NoReturn, Optional, Union
|
||||
|
||||
import regex
|
||||
|
||||
from tiktoken import _tiktoken
|
||||
|
||||
|
||||
class Encoding:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
pat_str: str,
|
||||
mergeable_ranks: dict[bytes, int],
|
||||
special_tokens: dict[str, int],
|
||||
explicit_n_vocab: Optional[int] = None,
|
||||
):
|
||||
self.name = name
|
||||
|
||||
self._pat_str = pat_str
|
||||
self._mergeable_ranks = mergeable_ranks
|
||||
self._special_tokens = special_tokens
|
||||
|
||||
self.max_token_value = max(
|
||||
max(mergeable_ranks.values()), max(special_tokens.values(), default=0)
|
||||
)
|
||||
if explicit_n_vocab:
|
||||
assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab
|
||||
assert self.max_token_value == explicit_n_vocab - 1
|
||||
|
||||
self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Encoding {self.name!r}>"
|
||||
|
||||
# ====================
|
||||
# Encoding
|
||||
# ====================
|
||||
|
||||
def encode_ordinary(self, text: str) -> list[int]:
|
||||
"""Encodes a string into tokens, ignoring special tokens.
|
||||
|
||||
This is equivalent to `encode(text, disallowed_special=())` (but slightly faster).
|
||||
|
||||
```
|
||||
>>> enc.encode_ordinary("hello world")
|
||||
[31373, 995]
|
||||
"""
|
||||
return self._core_bpe.encode_ordinary(text)
|
||||
|
||||
def encode(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
) -> list[int]:
|
||||
"""Encodes a string into tokens.
|
||||
|
||||
Special tokens are tokens are artificial tokens used to unlock capabilities from a model,
|
||||
such as fill-in-the-middle. So we want to be careful about accidentally encoding special
|
||||
tokens, since they can be used to trick a model into doing something we don't want it to do.
|
||||
|
||||
Hence, by default, encode will raise an error if it encounters text that corresponds
|
||||
to a special token. This can be controlled on a per-token level using the `allowed_special`
|
||||
and `disallowed_special` parameters. In particular:
|
||||
- Setting `disallowed_special` to () will prevent this function from raising errors and
|
||||
cause all text corresponding to special tokens to be encoded as natural text.
|
||||
- Setting `allowed_special` to "all" will allow cause this function to treat all text
|
||||
corresponding to special tokens to be encoded as special tokens.
|
||||
|
||||
```
|
||||
>>> enc.encode("hello world")
|
||||
[31373, 995]
|
||||
>>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
|
||||
[50256]
|
||||
>>> enc.encode("<|endoftext|>", allowed_special="all")
|
||||
[50256]
|
||||
>>> enc.encode("<|endoftext|>")
|
||||
# Raises ValueError
|
||||
>>> enc.encode("<|endoftext|>", disallowed_special=())
|
||||
[27, 91, 437, 1659, 5239, 91, 29]
|
||||
```
|
||||
"""
|
||||
if allowed_special == "all":
|
||||
allowed_special = self.special_tokens_set
|
||||
if disallowed_special == "all":
|
||||
disallowed_special = self.special_tokens_set - allowed_special
|
||||
if disallowed_special:
|
||||
if not isinstance(disallowed_special, frozenset):
|
||||
disallowed_special = frozenset(disallowed_special)
|
||||
if match := _special_token_regex(disallowed_special).search(text):
|
||||
raise_disallowed_special_token(match.group())
|
||||
|
||||
return self._core_bpe.encode(text, allowed_special)
|
||||
|
||||
def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]:
|
||||
"""Encodes a list of strings into tokens, in parallel, ignoring special tokens.
|
||||
|
||||
This is equivalent to `encode_batch(text, disallowed_special=())` (but slightly faster).
|
||||
|
||||
```
|
||||
>>> enc.encode_batch(["hello world", "goodbye world"])
|
||||
[[31373, 995], [11274, 16390, 995]]
|
||||
```
|
||||
"""
|
||||
encoder = functools.partial(self.encode_ordinary)
|
||||
with ThreadPoolExecutor(num_threads) as e:
|
||||
return list(e.map(encoder, text))
|
||||
|
||||
def encode_batch(
|
||||
self,
|
||||
text: list[str],
|
||||
*,
|
||||
num_threads: int = 8,
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
) -> list[list[int]]:
|
||||
"""Encodes a list of strings into tokens, in parallel.
|
||||
|
||||
See `encode` for more details on `allowed_special` and `disallowed_special`.
|
||||
|
||||
```
|
||||
>>> enc.encode_batch(["hello world", "goodbye world"])
|
||||
[[31373, 995], [11274, 16390, 995]]
|
||||
```
|
||||
"""
|
||||
if allowed_special == "all":
|
||||
allowed_special = self.special_tokens_set
|
||||
if disallowed_special == "all":
|
||||
disallowed_special = self.special_tokens_set - allowed_special
|
||||
if not isinstance(disallowed_special, frozenset):
|
||||
disallowed_special = frozenset(disallowed_special)
|
||||
|
||||
encoder = functools.partial(
|
||||
self.encode, allowed_special=allowed_special, disallowed_special=disallowed_special
|
||||
)
|
||||
with ThreadPoolExecutor(num_threads) as e:
|
||||
return list(e.map(encoder, text))
|
||||
|
||||
    def encode_with_unstable(
        self,
        text: str,
        *,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
    ) -> tuple[list[int], list[list[int]]]:
        """Encodes a string into stable tokens and possible completion sequences.

        Note that the stable tokens will only represent a substring of `text`.

        See `encode` for more details on `allowed_special` and `disallowed_special`.

        ```
        >>> enc.encode_with_unstable("hello fanta")
        ([31373], [(277, 4910), (5113, 265), ..., (8842,)])

        >>> text = "..."
        >>> stable_tokens, completions = enc.encode_with_unstable(text)
        >>> assert text.encode().startswith(enc.decode_bytes(stable_tokens))
        >>> assert all(enc.decode_bytes(stable_tokens + seq).startswith(text.encode()) for seq in completions)
        ```
        """
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - allowed_special
        if disallowed_special:
            if not isinstance(disallowed_special, frozenset):
                disallowed_special = frozenset(disallowed_special)
            if match := _special_token_regex(disallowed_special).search(text):
                raise_disallowed_special_token(match.group())

        return self._core_bpe.encode_with_unstable(text, allowed_special)

    def encode_single_token(self, text_or_bytes: Union[str, bytes]) -> int:
        """Encodes text corresponding to a single token to its token value.

        NOTE: this will encode all special tokens.

        Raises `KeyError` if the token is not in the vocabulary.

        ```
        >>> enc.encode_single_token("hello")
        31373
        ```
        """
        if isinstance(text_or_bytes, str):
            text_or_bytes = text_or_bytes.encode("utf-8")
        return self._core_bpe.encode_single_token(text_or_bytes)

    # ====================
    # Decoding
    # ====================

    def decode_bytes(self, tokens: list[int]) -> bytes:
        """Decodes a list of tokens into bytes.

        ```
        >>> enc.decode_bytes([31373, 995])
        b'hello world'
        ```
        """
        return self._core_bpe.decode_bytes(tokens)

    def decode(self, tokens: list[int], errors: str = "replace") -> str:
        """Decodes a list of tokens into a string.

        WARNING: the default behaviour of this function is lossy, since decoded bytes are not
        guaranteed to be valid UTF-8. You can control this behaviour using the `errors` parameter,
        for instance, setting `errors="strict"`.

        ```
        >>> enc.decode([31373, 995])
        'hello world'
        ```
        """
        return self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)

    def decode_single_token_bytes(self, token: int) -> bytes:
        """Decodes a token into bytes.

        NOTE: this will decode all special tokens.

        Raises `KeyError` if the token is not in the vocabulary.

        ```
        >>> enc.decode_single_token_bytes(31373)
        b'hello'
        ```
        """
        return self._core_bpe.decode_single_token_bytes(token)

    def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
        """Decodes a list of tokens into a list of bytes.

        Useful for visualising tokenisation.

        ```
        >>> enc.decode_tokens_bytes([31373, 995])
        [b'hello', b' world']
        ```
        """
        return [self.decode_single_token_bytes(token) for token in tokens]

    # ====================
    # Miscellaneous
    # ====================

    def token_byte_values(self) -> list[bytes]:
        """Returns the list of all token byte values."""
        return self._core_bpe.token_byte_values()

    @property
    def eot_token(self) -> int:
        return self._special_tokens["<|endoftext|>"]

    @functools.cached_property
    def special_tokens_set(self) -> set[str]:
        return set(self._special_tokens.keys())

    @property
    def n_vocab(self) -> int:
        """For backwards compatibility. Prefer to use `enc.max_token_value + 1`."""
        return self.max_token_value + 1

    # ====================
    # Private
    # ====================

    def _encode_single_piece(self, text_or_bytes: Union[str, bytes]) -> list[int]:
        """Encodes text corresponding to bytes without a regex split.

        NOTE: this will not encode any special tokens.

        ```
        >>> enc._encode_single_piece("helloqqqq")
        [31373, 38227, 38227]
        ```
        """
        if isinstance(text_or_bytes, str):
            text_or_bytes = text_or_bytes.encode("utf-8")
        return self._core_bpe.encode_single_piece(text_or_bytes)

    def _encode_only_native_bpe(self, text: str) -> list[int]:
        """Encodes a string into tokens, but does the regex splitting in Python."""
        _unused_pat = regex.compile(self._pat_str)
        ret = []
        for piece in regex.findall(_unused_pat, text):
            ret.extend(self._core_bpe.encode_single_piece(piece))
        return ret

    def _encode_bytes(self, text: bytes) -> list[int]:
        return self._core_bpe._encode_bytes(text)


@functools.lru_cache(maxsize=128)
def _special_token_regex(tokens: frozenset[str]) -> "regex.Pattern[str]":
    inner = "|".join(regex.escape(token) for token in tokens)
    return regex.compile(f"({inner})")


def raise_disallowed_special_token(token: str) -> NoReturn:
    raise ValueError(
        f"Encountered text corresponding to disallowed special token {token!r}.\n"
        "If you want this text to be encoded as a special token, "
        f"pass it to `allowed_special`, e.g. `allowed_special={{{token!r}, ...}}`.\n"
        f"If you want this text to be encoded as normal text, disable the check for this token "
        f"by passing `disallowed_special=(enc.special_tokens_set - {{{token!r}}})`.\n"
        "To disable this check for all special tokens, pass `disallowed_special=()`.\n"
    )
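A minimal usage sketch for the `Encoding` methods above, assuming the package is built and installed and the "gpt2" encoding files can be fetched (they are downloaded and cached on first use by `tiktoken/load.py`):

```
# Sketch only: exercises encode/decode and the special-token checks defined above.
from tiktoken.registry import get_encoding

enc = get_encoding("gpt2")

# Ordinary text round-trips through encode/decode.
tokens = enc.encode("hello world")
assert enc.decode(tokens) == "hello world"

# Special tokens are rejected by default; opt in via allowed_special,
# or encode them as plain text by disabling the check entirely.
assert enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}) == [enc.eot_token]
plain_tokens = enc.encode("<|endoftext|>", disallowed_special=())

# decode_tokens_bytes is handy for visualising how a string was split.
print(enc.decode_tokens_bytes(tokens))  # [b'hello', b' world']
```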
97
tiktoken/load.py
Normal file
@@ -0,0 +1,97 @@
import base64
import hashlib
import json
import os
import uuid

import blobfile


def read_file_cached(blobpath: str) -> bytes:
    if "TIKTOKEN_CACHE_DIR" in os.environ:
        cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
    elif "DATA_GYM_CACHE_DIR" in os.environ:
        cache_dir = os.environ["DATA_GYM_CACHE_DIR"]
    else:
        cache_dir = "/tmp/data-gym-cache"

    if cache_dir == "":
        # disable caching
        with blobfile.BlobFile(blobpath, "rb") as f:
            return f.read()

    cache_key = hashlib.sha1(blobpath.encode()).hexdigest()

    cache_path = os.path.join(cache_dir, cache_key)
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return f.read()

    with blobfile.BlobFile(blobpath, "rb") as f:
        contents = f.read()

    os.makedirs(cache_dir, exist_ok=True)
    tmp_filename = cache_path + "." + str(uuid.uuid4()) + ".tmp"
    with open(tmp_filename, "wb") as f:
        f.write(contents)
    os.rename(tmp_filename, cache_path)

    return contents


def data_gym_to_mergeable_bpe_ranks(
    vocab_bpe_file: str, encoder_json_file: str
) -> dict[bytes, int]:
    # NB: do not add caching to this function
    rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]

    data_gym_byte_to_byte = {chr(b): b for b in rank_to_intbyte}
    n = 0
    for b in range(2**8):
        if b not in rank_to_intbyte:
            rank_to_intbyte.append(b)
            data_gym_byte_to_byte[chr(2**8 + n)] = b
            n += 1
    assert len(rank_to_intbyte) == 2**8

    # vocab_bpe contains the merges along with associated ranks
    vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode()
    bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]

    def decode_data_gym(value: str) -> bytes:
        return bytes(data_gym_byte_to_byte[b] for b in value)

    # add the single byte tokens
    bpe_ranks = {bytes([b]): i for i, b in enumerate(rank_to_intbyte)}
    # add the merged tokens
    n = len(bpe_ranks)
    for first, second in bpe_merges:
        bpe_ranks[decode_data_gym(first) + decode_data_gym(second)] = n
        n += 1

    # check that the encoder file matches the merges file
    # this sanity check is important since tiktoken assumes that ranks are ordered the same
    # as merge priority
    encoder_json = json.loads(read_file_cached(encoder_json_file))
    encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
    # drop these two special tokens if present, since they're not mergeable bpe tokens
    encoder_json_loaded.pop(b"<|endoftext|>", None)
    encoder_json_loaded.pop(b"<|startoftext|>", None)
    assert bpe_ranks == encoder_json_loaded

    return bpe_ranks


def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
    with blobfile.BlobFile(tiktoken_bpe_file, "wb") as f:
        for token, rank in sorted(bpe_ranks.items(), key=lambda x: x[1]):
            f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
    # NB: do not add caching to this function
    contents = read_file_cached(tiktoken_bpe_file)
    return {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in contents.splitlines() if line)
    }
71
tiktoken/registry.py
Normal file
@@ -0,0 +1,71 @@
import importlib
import pkgutil
import threading
from typing import Any, Callable, Optional

import tiktoken_ext

from tiktoken.core import Encoding

_lock = threading.RLock()
ENCODINGS: dict[str, Encoding] = {}
ENCODING_CONSTRUCTORS: Optional[dict[str, Callable[[], dict[str, Any]]]] = None


def _find_constructors() -> None:
    global ENCODING_CONSTRUCTORS
    with _lock:
        if ENCODING_CONSTRUCTORS is not None:
            return
        ENCODING_CONSTRUCTORS = {}

        # tiktoken_ext is a namespace package
        # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes
        # - we use namespace package pattern so `pkgutil.iter_modules` is fast
        # - it's a separate top-level package because namespace subpackages of non-namespace
        #   packages don't quite do what you want with editable installs
        plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".")

        for _, mod_name, _ in plugin_mods:
            mod = importlib.import_module(mod_name)
            try:
                constructors = mod.ENCODING_CONSTRUCTORS
            except AttributeError as e:
                raise ValueError(
                    f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS"
                ) from e
            for enc_name, constructor in constructors.items():
                if enc_name in ENCODING_CONSTRUCTORS:
                    raise ValueError(
                        f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}"
                    )
                ENCODING_CONSTRUCTORS[enc_name] = constructor


def get_encoding(encoding_name: str) -> Encoding:
    if encoding_name in ENCODINGS:
        return ENCODINGS[encoding_name]

    with _lock:
        if encoding_name in ENCODINGS:
            return ENCODINGS[encoding_name]

        if ENCODING_CONSTRUCTORS is None:
            _find_constructors()
            assert ENCODING_CONSTRUCTORS is not None

        if encoding_name not in ENCODING_CONSTRUCTORS:
            raise ValueError(f"Unknown encoding {encoding_name}")

        constructor = ENCODING_CONSTRUCTORS[encoding_name]
        enc = Encoding(**constructor())
        ENCODINGS[encoding_name] = enc
        return enc


def list_encoding_names() -> list[str]:
    with _lock:
        if ENCODING_CONSTRUCTORS is None:
            _find_constructors()
            assert ENCODING_CONSTRUCTORS is not None
        return list(ENCODING_CONSTRUCTORS)
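A quick sketch of the registry API above; the names it returns come from whatever plugins are installed, here the `tiktoken_ext/openai_public.py` module shown next:

```
# Sketch only: list registered encodings and show that get_encoding caches instances.
from tiktoken.registry import get_encoding, list_encoding_names

print(list_encoding_names())  # e.g. ['gpt2', 'cl100k_base']

enc = get_encoding("gpt2")
# Encodings are constructed once and stored in ENCODINGS, so repeated
# lookups return the same object.
assert get_encoding("gpt2") is enc
```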
41
tiktoken_ext/openai_public.py
Normal file
@@ -0,0 +1,41 @@
from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe

ENDOFTEXT = "<|endoftext|>"
FIM_PREFIX = "<|fim_prefix|>"
FIM_MIDDLE = "<|fim_middle|>"
FIM_SUFFIX = "<|fim_suffix|>"
ENDOFPROMPT = "<|endofprompt|>"


def gpt2():
    mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
        vocab_bpe_file="az://openaipublic/gpt-2/encodings/main/vocab.bpe",
        encoder_json_file="az://openaipublic/gpt-2/encodings/main/encoder.json",
    )
    return {
        "name": "gpt2",
        "explicit_n_vocab": 50257,
        "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": {"<|endoftext|>": 50256},
    }


def cl100k_base():
    mergeable_ranks = load_tiktoken_bpe("az://openaipublic/encodings/cl100k_base.tiktoken")
    special_tokens = {
        ENDOFTEXT: 100257,
        FIM_PREFIX: 100258,
        FIM_MIDDLE: 100259,
        FIM_SUFFIX: 100260,
        ENDOFPROMPT: 100276,
    }
    return {
        "name": "cl100k_base",
        "pat_str": r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": special_tokens,
    }


ENCODING_CONSTRUCTORS = {"gpt2": gpt2, "cl100k_base": cl100k_base}
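For completeness, a hypothetical third-party plugin following the contract that `tiktoken.registry._find_constructors` expects: any module dropped into the `tiktoken_ext` namespace package that defines an `ENCODING_CONSTRUCTORS` dict, where each constructor returns the `Encoding` kwargs in the same shape as `gpt2()` and `cl100k_base()` above. The module name, encoding name, and vocabulary below are made up for illustration:

```
# Hypothetical file: tiktoken_ext/my_encodings.py (sketch only, not part of this commit).

def my_base():
    # A toy byte-level vocabulary: every single byte gets its own rank.
    mergeable_ranks = {bytes([b]): b for b in range(2**8)}
    return {
        "name": "my_base",
        "pat_str": r"""[^\s]+|\s+""",  # simplistic split: runs of non-space or space
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": {"<|endoftext|>": 2**8},
    }


ENCODING_CONSTRUCTORS = {"my_base": my_base}
```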