[tiktoken] hello world
This commit is contained in:
commit
a1a9f16826
42
.gitignore
vendored
Normal file
42
.gitignore
vendored
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
.mypy_cache
|
||||||
|
.coverage
|
||||||
|
htmlcov
|
||||||
|
|
||||||
|
# General
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
Cargo.lock
|
||||||
|
target/
|
21
Cargo.toml
Normal file
21
Cargo.toml
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
[package]
|
||||||
|
name = "tiktoken"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
rust-version = "1.57.0"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "_tiktoken"
|
||||||
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
pyo3 = { version = "0.17.3", features = ["extension-module"] }
|
||||||
|
|
||||||
|
# tiktoken dependencies
|
||||||
|
fancy-regex = "0.10.0"
|
||||||
|
regex = "1.7.0"
|
||||||
|
rustc-hash = "1.1.0"
|
||||||
|
bstr = "1.0.1"
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
incremental = true
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2022 OpenAI, Shantanu Jain
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
5
MANIFEST.in
Normal file
5
MANIFEST.in
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
include *.svg
|
||||||
|
include *.toml
|
||||||
|
include Makefile
|
||||||
|
recursive-include scripts *.py
|
||||||
|
recursive-include src *.rs
|
49
Makefile
Normal file
49
Makefile
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
PROJECT := tiktoken
|
||||||
|
|
||||||
|
.PHONY: default
|
||||||
|
default: editable_install
|
||||||
|
|
||||||
|
.PHONY: install_rust
|
||||||
|
install_rust:
|
||||||
|
which cargo >/dev/null || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain 1.62
|
||||||
|
|
||||||
|
.PHONY: clean
|
||||||
|
clean:
|
||||||
|
cargo clean
|
||||||
|
pip uninstall -y $(PROJECT)
|
||||||
|
find . | grep -E '__pycache__|\.pyc' | xargs rm -rf
|
||||||
|
find . | grep -E '\.so' | xargs rm -rf
|
||||||
|
rm -rf dist/ build/
|
||||||
|
rm -rf $(PROJECT).egg-info/
|
||||||
|
|
||||||
|
.PHONY: format
|
||||||
|
format:
|
||||||
|
@ which black >/dev/null || python3 -m pip install black
|
||||||
|
@ which isort >/dev/null || python3 -m pip install isort
|
||||||
|
cargo fmt -- --config group_imports=StdExternalCrate
|
||||||
|
black --line-length 100 --skip-magic-trailing-comma --quiet .
|
||||||
|
isort --line-length 100 --profile black --quiet .
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: format_check
|
||||||
|
format_check:
|
||||||
|
@ which black >/dev/null || python3 -m pip install black
|
||||||
|
@ which isort >/dev/null || python3 -m pip install isort
|
||||||
|
cargo fmt --check -- --config group_imports=StdExternalCrate
|
||||||
|
black --check --line-length 100 --skip-magic-trailing-comma --quiet .
|
||||||
|
isort --check --line-length 100 --profile black --quiet .
|
||||||
|
|
||||||
|
.PHONY: lint
|
||||||
|
lint:
|
||||||
|
cargo clippy --all -- -D warnings
|
||||||
|
@ which flake8 >/dev/null || python3 -m pip install flake8==5 flake8-bugbear==22.9.11
|
||||||
|
flake8 --ignore=E203,E501,W503,E731 --per-file-ignores="$(PROJECT)/__init__.py:F401 setup.py:E402" --exclude=build .
|
||||||
|
|
||||||
|
.PHONY: editable_install
|
||||||
|
editable_install:
|
||||||
|
@ if [ -f $(PROJECT).egg-info ]; then \
|
||||||
|
pip install --disable-pip-version-check --progress-bar=off setuptools wheel setuptools-rust ; \
|
||||||
|
pip install --disable-pip-version-check --no-build-isolation -e . ; \
|
||||||
|
else \
|
||||||
|
pip install --disable-pip-version-check --no-deps --no-build-isolation --ignore-installed -e . ; \
|
||||||
|
fi
|
28
README.md
Normal file
28
README.md
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# ⏳ tiktoken
|
||||||
|
|
||||||
|
tiktoken is a fast tokeniser.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import tiktoken
|
||||||
|
enc = tiktoken.get_encoding("gpt2")
|
||||||
|
print(enc.encode("hello world"))
|
||||||
|
```
|
||||||
|
|
||||||
|
The open source version of `tiktoken` can be installed from PyPI:
|
||||||
|
```
|
||||||
|
pip install tiktoken
|
||||||
|
```
|
||||||
|
|
||||||
|
The tokeniser API is documented in `tiktoken/core.py`.
|
||||||
|
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
`tiktoken` is between 3-6x faster than huggingface's tokeniser:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
Performance measured on 1GB of text using the GPT-2 tokeniser, using `GPT2TokenizerFast` from
|
||||||
|
`tokenizers==0.13.2` and `transformers==4.24.0`.
|
||||||
|
|
||||||
|
|
373
perf.svg
Normal file
373
perf.svg
Normal file
@ -0,0 +1,373 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||||
|
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.1" width="569.334pt" height="328.0869pt" viewBox="0 0 569.334 328.0869">
|
||||||
|
<g enable-background="new">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp0">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 0 0 L 569.334 0 L 569.334 328.0869 L 0 328.0869 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp0)">
|
||||||
|
<path stroke-width=".5" stroke-linecap="butt" stroke-miterlimit="4" stroke-linejoin="miter" fill="none" stroke="#b8b8b8" d="M 79 219.5869 L 569 219.5869 "/>
|
||||||
|
<path stroke-width=".5" stroke-linecap="butt" stroke-miterlimit="4" stroke-linejoin="miter" fill="none" stroke="#b8b8b8" d="M 79 150.5869 L 569 150.5869 "/>
|
||||||
|
<path stroke-width=".5" stroke-linecap="butt" stroke-miterlimit="4" stroke-linejoin="miter" fill="none" stroke="#b8b8b8" d="M 79 82.58685 L 569 82.58685 "/>
|
||||||
|
<path stroke-width=".5" stroke-linecap="butt" stroke-miterlimit="4" stroke-linejoin="miter" fill="none" stroke="#b8b8b8" d="M 79 13.58685 L 569 13.58685 "/>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp1">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 0 39.64996 L -.0000120107 314.4229 L 20.496 314.4229 L 20.49601 39.64996 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp1)">
|
||||||
|
<clipPath id="cp2">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 0 0 L 569.334 0 L 569.334 328.0869 L 0 328.0869 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp2)">
|
||||||
|
<text xml:space="preserve" transform="matrix(0 -1 1 0 11.10803 182.06452)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 6.888 13.560001 17.340003 24.228003 30.900004 37.788003 44.460004 51.576005 58.248006">Throughput</tspan></text>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp3">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 26.168 31.81796 L 67.843997 31.81796 L 67.843997 47.48196 L 26.168 47.48196 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp3)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 27.668 292.71299)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 6.672001 10.008001 20.460003 28.680005 32.676004">0 MB/s</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp4">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 19.496 100.5112 L 67.844 100.5112 L 67.844 116.1752 L 19.496 116.1752 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp4)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 20.996 224.01972)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 6.672001 13.344002 16.680003 27.132004 35.352006 39.348005">10 MB/s</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp5">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 19.496 169.2044 L 67.844 169.2044 L 67.844 184.86841 L 19.496 184.86841 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp5)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 20.996 155.3265)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 6.672001 13.344002 16.680003 27.132004 35.352006 39.348005">20 MB/s</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp6">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 19.496 237.8976 L 67.844 237.8976 L 67.844 253.5616 L 19.496 253.5616 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp6)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 20.996 86.633319)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 6.672001 13.344002 16.680003 27.132004 35.352006 39.348005">30 MB/s</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp7">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 19.496 306.5909 L 67.844 306.5909 L 67.844 322.2549 L 19.496 322.2549 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp7)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 20.996 17.940125)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 6.672001 13.344002 16.680003 27.132004 35.352006 39.348005">40 MB/s</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp8">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 0 0 L 569.334 0 L 569.334 328.0869 L 0 328.0869 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp8)">
|
||||||
|
<path stroke-width="1" stroke-linecap="square" stroke-miterlimit="4" stroke-linejoin="miter" fill="none" stroke="#000000" d="M 78.5 288.5869 L 568.5 288.5869 "/>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp9">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 78.83396 0 L 568.834 0 L 568.834 20.496 L 78.83396 20.496 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp9)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 288.266 325.53096)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 6.888 13.560001 17.340003 23.784003 30.228003 37.344 40.68 47.124 54.012 60.684003 67.356">Thread count</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp10">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 108.998 19.496 L 118.67 19.496 L 118.67 35.16 L 108.998 35.16 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp10)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 110.498 305.03495)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0">1</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp11">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 178.998 19.496 L 188.67 19.496 L 188.67 35.16 L 178.998 35.16 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp11)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 180.498 305.03495)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0">2</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp12">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 248.998 19.496 L 258.67 19.496 L 258.67 35.16 L 248.998 35.16 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp12)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 250.498 305.03495)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0">4</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp13">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 318.998 19.496 L 328.66999 19.496 L 328.66999 35.16 L 318.998 35.16 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp13)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 320.498 305.03495)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0">8</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp14">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 385.662 19.496 L 402.00599 19.496 L 402.00599 35.16 L 385.662 35.16 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp14)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 387.162 305.03495)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 6.672001">16</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp15">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 455.662 19.496 L 472.00599 19.496 L 472.00599 35.16 L 455.662 35.16 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp15)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 457.162 305.03495)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 6.672001">32</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp16">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 525.662 19.496 L 542.006 19.496 L 542.006 35.16 L 525.662 35.16 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp16)">
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 527.162 305.03495)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 6.672001">64</tspan></text>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp17">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 115 40 L 143 40 L 143 52 L 115 52 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp17)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp18">
|
||||||
|
<path transform="matrix(1,0,0,1,115,276.0869)" d="M 0 0 L 28 0 L 28 12 L 0 12 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp18)">
|
||||||
|
<clipPath id="cp19">
|
||||||
|
<path transform="matrix(1,0,0,1,115,276.0869)" d="M 0 0 L 28 0 L 28 12 L 0 12 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp19)">
|
||||||
|
<path transform="matrix(1,0,0,1,115,276.0869)" d="M 0 0 L 28 0 L 28 12 L 0 12 Z " fill="#61d836"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp20">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 185 40 L 213 40 L 213 61 L 185 61 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp20)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp21">
|
||||||
|
<path transform="matrix(1,0,0,1,185,267.0869)" d="M 0 0 L 28 0 L 28 21 L 0 21 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp21)">
|
||||||
|
<clipPath id="cp22">
|
||||||
|
<path transform="matrix(1,0,0,1,185,267.0869)" d="M 0 0 L 28 0 L 28 21 L 0 21 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp22)">
|
||||||
|
<path transform="matrix(1,0,0,1,185,267.0869)" d="M 0 0 L 28 0 L 28 21 L 0 21 Z " fill="#61d836"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp23">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 255 40 L 283 40 L 283 74 L 255 74 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp23)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp24">
|
||||||
|
<path transform="matrix(1,0,0,1,255,254.08692)" d="M 0 0 L 28 0 L 28 34 L 0 34 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp24)">
|
||||||
|
<clipPath id="cp25">
|
||||||
|
<path transform="matrix(1,0,0,1,255,254.08692)" d="M 0 0 L 28 0 L 28 34 L 0 34 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp25)">
|
||||||
|
<path transform="matrix(1,0,0,1,255,254.08692)" d="M 0 0 L 28 0 L 28 34 L 0 34 Z " fill="#61d836"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp26">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 325 40 L 353 40 L 353 80 L 325 80 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp26)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp27">
|
||||||
|
<path transform="matrix(1,0,0,1,325,248.08692)" d="M 0 0 L 28 0 L 28 40 L 0 40 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp27)">
|
||||||
|
<clipPath id="cp28">
|
||||||
|
<path transform="matrix(1,0,0,1,325,248.08692)" d="M 0 0 L 28 0 L 28 40 L 0 40 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp28)">
|
||||||
|
<path transform="matrix(1,0,0,1,325,248.08692)" d="M 0 0 L 28 0 L 28 40 L 0 40 Z " fill="#61d836"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp29">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 395 40 L 423 40 L 423 83 L 395 83 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp29)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp30">
|
||||||
|
<path transform="matrix(1,0,0,1,395,245.08692)" d="M 0 0 L 28 0 L 28 43 L 0 43 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp30)">
|
||||||
|
<clipPath id="cp31">
|
||||||
|
<path transform="matrix(1,0,0,1,395,245.08692)" d="M 0 0 L 28 0 L 28 43 L 0 43 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp31)">
|
||||||
|
<path transform="matrix(1,0,0,1,395,245.08692)" d="M 0 0 L 28 0 L 28 43 L 0 43 Z " fill="#61d836"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp32">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 465 40 L 493 40 L 493 85 L 465 85 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp32)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp33">
|
||||||
|
<path transform="matrix(1,0,0,1,465,243.08692)" d="M 0 0 L 28 0 L 28 45 L 0 45 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp33)">
|
||||||
|
<clipPath id="cp34">
|
||||||
|
<path transform="matrix(1,0,0,1,465,243.08692)" d="M 0 0 L 28 0 L 28 45 L 0 45 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp34)">
|
||||||
|
<path transform="matrix(1,0,0,1,465,243.08692)" d="M 0 0 L 28 0 L 28 45 L 0 45 Z " fill="#61d836"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp35">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 535 40 L 563 40 L 563 88 L 535 88 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp35)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp36">
|
||||||
|
<path transform="matrix(1,0,0,1,535,240.08692)" d="M 0 0 L 28 0 L 28 48 L 0 48 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp36)">
|
||||||
|
<clipPath id="cp37">
|
||||||
|
<path transform="matrix(1,0,0,1,535,240.08692)" d="M 0 0 L 28 0 L 28 48 L 0 48 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp37)">
|
||||||
|
<path transform="matrix(1,0,0,1,535,240.08692)" d="M 0 0 L 28 0 L 28 48 L 0 48 Z " fill="#61d836"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp38">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 84 40 L 112 40 L 112 84 L 84 84 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp38)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp39">
|
||||||
|
<path transform="matrix(1,0,0,1,84,244.08692)" d="M 0 0 L 28 0 L 28 44 L 0 44 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp39)">
|
||||||
|
<clipPath id="cp40">
|
||||||
|
<path transform="matrix(1,0,0,1,84,244.08692)" d="M 0 0 L 28 0 L 28 44 L 0 44 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp40)">
|
||||||
|
<path transform="matrix(1,0,0,1,84,244.08692)" d="M 0 0 L 28 0 L 28 44 L 0 44 Z " fill="#00a2ff"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp41">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 154 40 L 182 40 L 182 123 L 154 123 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp41)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp42">
|
||||||
|
<path transform="matrix(1,0,0,1,154,205.08692)" d="M 0 0 L 28 0 L 28 83 L 0 83 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp42)">
|
||||||
|
<clipPath id="cp43">
|
||||||
|
<path transform="matrix(1,0,0,1,154,205.08692)" d="M 0 0 L 28 0 L 28 83 L 0 83 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp43)">
|
||||||
|
<path transform="matrix(1,0,0,1,154,205.08692)" d="M 0 0 L 28 0 L 28 83 L 0 83 Z " fill="#00a2ff"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp44">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 224 40 L 252 40 L 252 189 L 224 189 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp44)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp45">
|
||||||
|
<path transform="matrix(1,0,0,1,224,139.08692)" d="M 0 0 L 28 0 L 28 149 L 0 149 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp45)">
|
||||||
|
<clipPath id="cp46">
|
||||||
|
<path transform="matrix(1,0,0,1,224,139.08692)" d="M 0 0 L 28 0 L 28 149 L 0 149 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp46)">
|
||||||
|
<path transform="matrix(1,0,0,1,224,139.08692)" d="M 0 0 L 28 0 L 28 149 L 0 149 Z " fill="#00a2ff"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp47">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 294 40 L 322 40 L 322 258 L 294 258 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp47)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp48">
|
||||||
|
<path transform="matrix(1,0,0,1,294,70.086917)" d="M 0 0 L 28 0 L 28 218 L 0 218 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp48)">
|
||||||
|
<clipPath id="cp49">
|
||||||
|
<path transform="matrix(1,0,0,1,294,70.086917)" d="M 0 0 L 28 0 L 28 218 L 0 218 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp49)">
|
||||||
|
<path transform="matrix(1,0,0,1,294,70.086917)" d="M 0 0 L 28 0 L 28 218 L 0 218 Z " fill="#00a2ff"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp50">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 364 40 L 392 40 L 392 303 L 364 303 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp50)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp51">
|
||||||
|
<path transform="matrix(1,0,0,1,364,25.086915)" d="M 0 0 L 28 0 L 28 263 L 0 263 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp51)">
|
||||||
|
<clipPath id="cp52">
|
||||||
|
<path transform="matrix(1,0,0,1,364,25.086915)" d="M 0 0 L 28 0 L 28 263 L 0 263 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp52)">
|
||||||
|
<path transform="matrix(1,0,0,1,364,25.086915)" d="M 0 0 L 28 0 L 28 263 L 0 263 Z " fill="#00a2ff"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp53">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 434 40 L 462 40 L 462 203 L 434 203 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp53)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp54">
|
||||||
|
<path transform="matrix(1,0,0,1,434,125.086917)" d="M 0 0 L 28 0 L 28 163 L 0 163 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp54)">
|
||||||
|
<clipPath id="cp55">
|
||||||
|
<path transform="matrix(1,0,0,1,434,125.086917)" d="M 0 0 L 28 0 L 28 163 L 0 163 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp55)">
|
||||||
|
<path transform="matrix(1,0,0,1,434,125.086917)" d="M 0 0 L 28 0 L 28 163 L 0 163 Z " fill="#00a2ff"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp56">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 504 40 L 532 40 L 532 195 L 504 195 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp56)">
|
||||||
|
<g>
|
||||||
|
<clipPath id="cp57">
|
||||||
|
<path transform="matrix(1,0,0,1,504,133.08692)" d="M 0 0 L 28 0 L 28 155 L 0 155 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp57)">
|
||||||
|
<clipPath id="cp58">
|
||||||
|
<path transform="matrix(1,0,0,1,504,133.08692)" d="M 0 0 L 28 0 L 28 155 L 0 155 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp58)">
|
||||||
|
<path transform="matrix(1,0,0,1,504,133.08692)" d="M 0 0 L 28 0 L 28 155 L 0 155 Z " fill="#00a2ff"/>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
<clipPath id="cp59">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 0 0 L 569.334 0 L 569.334 328.0869 L 0 328.0869 Z "/>
|
||||||
|
</clipPath>
|
||||||
|
<g clip-path="url(#cp59)">
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 459 291 L 471 291 L 471 280 L 459 280 Z " fill="#00a2ff"/>
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 477.8753 46.772127)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 3.7800003 6.4440004 12.672001 16.452002 23.340003 29.568003 36.012">tiktoken</tspan></text>
|
||||||
|
<path transform="matrix(1,0,0,-1,0,328.0869)" d="M 459 278 L 471 278 L 471 266 L 459 266 Z " fill="#61d836"/>
|
||||||
|
<text xml:space="preserve" transform="matrix(1 0 -0 1 477.8753 60.436128)" font-size="12" font-family="HelveticaNeue"><tspan y="0" x="0 6.672001 13.344002 20.232003 27.120003 29.784003 36.456 43.344 46.896005 53.340005 59.784006">huggingface</tspan></text>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
After Width: | Height: | Size: 15 KiB |
8
pyproject.toml
Normal file
8
pyproject.toml
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
[project]
|
||||||
|
name = "tiktoken"
|
||||||
|
dependencies = ["blobfile>=2", "regex>=2022.1.18"]
|
||||||
|
dynamic = ["version"]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools", "wheel", "setuptools-rust"]
|
||||||
|
|
39
scripts/benchmark.py
Normal file
39
scripts/benchmark.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import base64
|
||||||
|
import functools
|
||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import blobfile
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_batch(documents: list[str]) -> None:
|
||||||
|
num_threads = int(os.environ["RAYON_NUM_THREADS"])
|
||||||
|
num_bytes = sum(map(len, map(str.encode, documents)))
|
||||||
|
print(f"num_threads: {num_threads}, num_bytes: {num_bytes}")
|
||||||
|
|
||||||
|
enc = tiktoken.get_encoding("gpt2")
|
||||||
|
enc.encode("warmup")
|
||||||
|
|
||||||
|
start = time.perf_counter_ns()
|
||||||
|
enc.encode_ordinary_batch(documents, num_threads=num_threads)
|
||||||
|
end = time.perf_counter_ns()
|
||||||
|
print(f"tiktoken \t{num_bytes / (end - start) * 1e9} bytes / s")
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
hf_enc = cast(Any, transformers).GPT2TokenizerFast.from_pretrained("gpt2")
|
||||||
|
hf_enc.model_max_length = 1e30 # silence!
|
||||||
|
hf_enc.encode("warmup")
|
||||||
|
|
||||||
|
start = time.perf_counter_ns()
|
||||||
|
hf_enc(documents)
|
||||||
|
end = time.perf_counter_ns()
|
||||||
|
print(f"huggingface \t{num_bytes / (end - start) * 1e9} bytes / s")
|
||||||
|
|
||||||
|
|
65
scripts/redact.py
Normal file
65
scripts/redact.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import argparse
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def redact_file(path: Path, dry_run: bool) -> None:
|
||||||
|
if not path.exists() or path.is_dir():
|
||||||
|
return
|
||||||
|
|
||||||
|
text = path.read_text()
|
||||||
|
|
||||||
|
first_line = text.splitlines()[0]
|
||||||
|
if "redact" in first_line:
|
||||||
|
if not dry_run:
|
||||||
|
path.unlink()
|
||||||
|
print(f"Deleted {path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
pattern = "|".join(
|
||||||
|
re.escape(x)
|
||||||
|
for x in [
|
||||||
|
"# ===== redact-beg =====\n",
|
||||||
|
"# ===== redact-end =====\n",
|
||||||
|
"<!--- redact-beg -->\n",
|
||||||
|
"<!--- redact-end -->\n",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if re.search(pattern, text):
|
||||||
|
redacted_text = "".join(re.split(pattern, text)[::2])
|
||||||
|
if not dry_run:
|
||||||
|
path.write_text(redacted_text)
|
||||||
|
print(f"Redacted {path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Skipped {path}")
|
||||||
|
|
||||||
|
|
||||||
|
def redact(dry_run: bool) -> None:
|
||||||
|
tiktoken_root = Path(__file__).parent.parent
|
||||||
|
assert tiktoken_root.name == "tiktoken"
|
||||||
|
assert (tiktoken_root / "pyproject.toml").exists()
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output(["git", "ls-files"], cwd=tiktoken_root, text=True)
|
||||||
|
paths = [Path(p) for p in output.splitlines()]
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
paths = list(tiktoken_root.glob("**/*"))
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
redact_file(path, dry_run=dry_run)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--dry-run", type=lambda x: not x or x[0].lower() != "f", default=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
redact(args.dry_run)
|
||||||
|
if args.dry_run:
|
||||||
|
print("Dry run, use --dry-run=false to actually redact files")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
23
setup.py
Normal file
23
setup.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from setuptools import setup
|
||||||
|
from setuptools_rust import Binding, RustExtension
|
||||||
|
|
||||||
|
public = True
|
||||||
|
|
||||||
|
if public:
|
||||||
|
version = "0.1"
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="tiktoken",
|
||||||
|
version=version,
|
||||||
|
rust_extensions=[
|
||||||
|
RustExtension(
|
||||||
|
"tiktoken._tiktoken",
|
||||||
|
binding=Binding.PyO3,
|
||||||
|
# Between our use of editable installs and wanting to use Rust for performance sensitive
|
||||||
|
# code, it makes sense to just always use --release
|
||||||
|
debug=False,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
packages=["tiktoken", "tiktoken_ext"],
|
||||||
|
zip_safe=False,
|
||||||
|
)
|
559
src/lib.rs
Normal file
559
src/lib.rs
Normal file
@ -0,0 +1,559 @@
|
|||||||
|
// This check is new and seems buggy (possibly with PyO3 interaction)
|
||||||
|
#![allow(clippy::borrow_deref_ref)]
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::thread;
|
||||||
|
|
||||||
|
use fancy_regex::Regex;
|
||||||
|
use pyo3::exceptions;
|
||||||
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::types::{PyBytes, PyList, PyTuple};
|
||||||
|
use pyo3::PyResult;
|
||||||
|
use rustc_hash::FxHashMap as HashMap;
|
||||||
|
|
||||||
|
/// Greedily merge adjacent byte ranges of `piece`, always applying the
/// lowest-ranked (highest-priority) merge in `ranks` first, until no merge
/// applies. Returns the byte ranges of the resulting parts.
fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
    // Start with one single-byte range per byte of the piece.
    let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();

    // If you have n parts and m merges, this does O(mn) work
    // We could do something with a heap and do O(m log n) work

    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
    // currently do, this is equivalent. An easy way to break this would be to decouple
    // merge priority from token index or to prevent specific token merges.
    loop {
        if parts.len() == 1 {
            break;
        }
        // Scan all adjacent pairs for the one whose concatenation has the lowest rank.
        let mut min_rank: Option<(usize, usize)> = None;
        for i in 0..parts.len() - 1 {
            let rank = if let Some(r) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) {
                *r
            } else {
                continue;
            };
            if min_rank.is_none() || rank < min_rank.unwrap().0 {
                min_rank = Some((rank, i));
            }
        }
        // Apply the best merge found, or stop if no adjacent pair is mergeable.
        if let Some((_, i)) = min_rank {
            parts[i] = parts[i].start..parts[i + 1].end;
            parts.remove(i + 1);
        } else {
            break;
        }
    }
    parts
}
|
||||||
|
|
||||||
|
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
|
||||||
|
if piece.len() == 1 {
|
||||||
|
return vec![ranks[piece]];
|
||||||
|
}
|
||||||
|
_byte_pair_merge(piece, ranks)
|
||||||
|
.iter()
|
||||||
|
.map(|p| ranks[&piece[p.start..p.end]])
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
|
||||||
|
if piece.len() == 1 {
|
||||||
|
return vec![piece];
|
||||||
|
}
|
||||||
|
_byte_pair_merge(piece, ranks)
|
||||||
|
.iter()
|
||||||
|
.map(|p| &piece[p.start..p.end])
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Various performance notes:
|
||||||
|
//
|
||||||
|
// Regex
|
||||||
|
// =====
|
||||||
|
// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
|
||||||
|
// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
|
||||||
|
// the usual regex we use.
|
||||||
|
//
|
||||||
|
// However, given that we're using a regex parse-able by `regex`, there isn't much difference
|
||||||
|
// between using the `regex` crate and using the `fancy_regex` crate.
|
||||||
|
//
|
||||||
|
// There is an important interaction between threading, `regex` and `fancy_regex`.
|
||||||
|
// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
|
||||||
|
// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
|
||||||
|
// old `regex`, we don't hit this, because `find_iter` has a different code path.
|
||||||
|
// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
|
||||||
|
// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
|
||||||
|
// each thread.
|
||||||
|
//
|
||||||
|
// Threading
|
||||||
|
// =========
|
||||||
|
// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
|
||||||
|
// So goodbye `rayon`! Let thread count etc be in control of our Python users.
|
||||||
|
//
|
||||||
|
// Caching
|
||||||
|
// =======
|
||||||
|
// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
|
||||||
|
// Originally, we had one too! Without it, we were only vaguely faster than Python.
|
||||||
|
// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
|
||||||
|
// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
|
||||||
|
// multi-threaded performance even when I only had readers (maybe I messed something up?).
|
||||||
|
// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
|
||||||
|
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
|
||||||
|
// about interior mutability, memory use, or cloning.
|
||||||
|
//
|
||||||
|
// Hashing
|
||||||
|
// =======
|
||||||
|
// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
|
||||||
|
// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
|
||||||
|
// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
|
||||||
|
|
||||||
|
use std::num::NonZeroU64;
// Mirrors the layout of the (private) `std::thread::ThreadId` so we can
// transmute one into a number; see `hash_current_thread` below.
pub struct FakeThreadId(NonZeroU64);
|
||||||
|
|
||||||
|
/// Return a small numeric id for the current thread, used to pick a
/// (mostly) thread-local slot in the per-thread regex arrays.
fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can layout a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939
    // Compile-time size checks: both types must be exactly 8 bytes for the
    // transmute below to be layout-compatible.
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    let x = unsafe {
        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
    };
    u64::from(x) as usize
}
|
||||||
|
|
||||||
|
// Size of the per-thread regex slot arrays; thread ids are hashed modulo this.
const MAX_NUM_THREADS: usize = 128;
#[pyclass]
struct CoreBPE {
    // Token bytes -> token id (rank).
    encoder: HashMap<Vec<u8>, usize>,
    // Special token string -> token id.
    special_tokens_encoder: HashMap<String, usize>,
    // Token id -> token bytes (inverse of `encoder`).
    decoder: HashMap<usize, Vec<u8>>,
    // Token id -> special token bytes (inverse of `special_tokens_encoder`).
    special_tokens_decoder: HashMap<usize, Vec<u8>>,
    // Per-thread clones of the main and special-token regexes; see the
    // performance notes above for why these exist.
    regex_tls: Vec<Regex>,
    special_regex_tls: Vec<Regex>,
    // All ordinary token byte strings, sorted; used for prefix searches in
    // the unstable-token logic.
    sorted_token_bytes: Vec<Vec<u8>>,
}
|
||||||
|
|
||||||
|
impl CoreBPE {
    /// Pick the (mostly) thread-local clone of the main regex.
    fn _get_tl_regex(&self) -> &Regex {
        // See performance notes above for what this is about
        // It's also a little janky, please make a better version of it!
        // However, it's nice that this doesn't leak memory to short-lived threads
        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    /// Pick the (mostly) thread-local clone of the special-token regex.
    fn _get_tl_special_regex(&self) -> &Regex {
        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    /// Concatenate the byte values of `tokens`, consulting the ordinary
    /// decoder first and the special-token decoder as a fallback.
    /// Panics if a token is in neither map.
    fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
        let mut ret = Vec::with_capacity(tokens.len() * 2);
        for token in tokens {
            let token_bytes = self
                .decoder
                .get(token)
                .unwrap_or_else(|| &self.special_tokens_decoder[token]);
            ret.extend(token_bytes);
        }
        ret
    }

    /// Encode `text` with no special-token handling at all.
    fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
        // This is the core of the encoding logic; the other functions in here
        // just make things complicated :-)
        let regex = self._get_tl_regex();
        let mut ret = vec![];
        for mat in regex.find_iter(text) {
            let piece = mat.unwrap().as_str().as_bytes();
            // Fast path: the whole regex piece is already a single token.
            if let Some(token) = self.encoder.get(piece) {
                ret.push(*token);
                continue;
            }
            ret.extend(&byte_pair_encode(piece, &self.encoder));
        }
        ret
    }

    /// Encode `text`, treating occurrences of `allowed_special` as special
    /// tokens. Returns the tokens plus how many tokens came from the last
    /// regex piece (0 if the text ended with a special token).
    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
        let special_regex = self._get_tl_special_regex();
        let regex = self._get_tl_regex();
        let mut ret = vec![];

        let mut start = 0;
        let mut last_piece_token_len = 0;
        loop {
            let mut next_special;
            let mut start_find = start;
            loop {
                // Find the next allowed special token, if any
                next_special = special_regex.find_from_pos(text, start_find).unwrap();
                match next_special {
                    Some(m) => {
                        if allowed_special.contains(&text[m.start()..m.end()]) {
                            break;
                        }
                        // A special-looking token that isn't allowed: skip past
                        // its first character and keep searching.
                        start_find = m.start() + 1;
                    }
                    None => break,
                }
            }
            let end = next_special.map_or(text.len(), |m| m.start());

            // Okay, here we go, compare this logic to _encode_ordinary_native
            for mat in regex.find_iter(&text[start..end]) {
                let piece = mat.unwrap().as_str().as_bytes();
                if let Some(token) = self.encoder.get(piece) {
                    last_piece_token_len = 1;
                    ret.push(*token);
                    continue;
                }
                let tokens = byte_pair_encode(piece, &self.encoder);
                last_piece_token_len = tokens.len();
                ret.extend(&tokens);
            }

            match next_special {
                // And here we push the special token
                Some(m) => {
                    let piece = m.as_str();
                    let token = self.special_tokens_encoder[piece];
                    ret.push(token);
                    start = m.end();
                    last_piece_token_len = 0;
                }
                None => break,
            }
        }

        // last_piece_token_len is how many tokens came from the last regex split. This is used
        // for determining unstable tokens, since you can't merge across (stable) regex splits
        (ret, last_piece_token_len)
    }

    /// Extend `last_piece_token_len` backwards over preceding all-whitespace
    /// tokens, compensating for unstable regex splits (see comments below).
    fn _increase_last_piece_token_len(
        &self,
        tokens: Vec<usize>,
        mut last_piece_token_len: usize,
    ) -> (Vec<usize>, usize) {
        // Unfortunately, the locations where our regex splits can be unstable.
        // For the purposes of determining unstable tokens, unstable regex splitting
        // is only a problem if a split that was present disappears, since this can
        // lead to merging of tokens otherwise thought to be stable.
        // cl100k_base makes our life hard by including the \s*[\r\n]+
        // pattern. This can e.g. cause "\n" + " " to become "\n \n".
        // Here is a quick and dirty fix:
        {
            let token_is_all_space = |token| {
                self.decoder
                    .get(token)
                    .map(|token_bytes| {
                        token_bytes
                            .iter()
                            .rev()
                            .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
                    })
                    .unwrap_or(false)
            };
            if last_piece_token_len > 0
                && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
            {
                while (last_piece_token_len < tokens.len())
                    && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
                {
                    last_piece_token_len += 1;
                }
            }
        }
        debug_assert!(last_piece_token_len <= tokens.len());

        (tokens, last_piece_token_len)
    }

    /// Encode `text`, returning the stable tokens plus the set of token
    /// sequences that could complete the unstable suffix.
    fn _encode_unstable_native(
        &self,
        text: &str,
        allowed_special: &HashSet<&str>,
    ) -> (Vec<usize>, HashSet<Vec<usize>>) {
        let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
        if last_piece_token_len == 0 {
            // If last_piece_token_len is zero, the last token was a special token and we have
            // no unstable bytes
            return (tokens, HashSet::new());
        }
        let (mut tokens, last_piece_token_len) =
            self._increase_last_piece_token_len(tokens, last_piece_token_len);

        // The unstable suffix is the bytes of the last piece's tokens; those
        // tokens are removed from the stable result.
        let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
        tokens.truncate(tokens.len() - last_piece_token_len);

        // TODO: we should try harder to find additional stable tokens
        // This would reduce the amount of retokenising when determining completions
        // Refer to the logic in an older version of this file

        let mut completions = HashSet::new();
        if unstable_bytes.is_empty() {
            return (tokens, completions);
        }

        // This is the easy bit. Just find all single tokens that start with unstable_bytes
        // (including tokens that exactly match unstable_bytes)
        // Separating this from the loop below helps with performance in a common case.
        let mut point = self
            .sorted_token_bytes
            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
        while point < self.sorted_token_bytes.len()
            && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
        {
            completions.insert(vec![
                self.encoder[self.sorted_token_bytes[point].as_slice()],
            ]);
            point += 1;
        }

        // Now apply even more brute force. At every (other) possible position for the straddling
        // token, concatenate additional bytes from that token (if any) to unstable_bytes,
        // and retokenise the whole thing and see what we get.
        for i in 1..unstable_bytes.len() {
            let prefix = &unstable_bytes[..i];
            let suffix = &unstable_bytes[i..];
            let mut point = self
                .sorted_token_bytes
                .partition_point(|x| x.as_slice() < suffix);
            // TODO: Perf optimisation if suffix starts with " "?
            while point < self.sorted_token_bytes.len()
                && self.sorted_token_bytes[point].starts_with(suffix)
            {
                let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
                let encoded = match std::str::from_utf8(&possibility) {
                    // Morally, this is byte_pair_encode(&possibility, &self.encoder)
                    // But we might have introduced a regex split which would prevent merges.
                    // (particularly possible in the presence of unstable regex splits)
                    // So convert to UTF-8 and do regex splitting.
                    // E.g. with cl100k_base " !" gets split to " " + " !",
                    // but byte_pair_encode(" !") != byte_pair_encode(" ")
                    Ok(s) => self._encode_ordinary_native(s),

                    // Technically, whether or not this arm is correct depends on whether there
                    // would be a regex split before the UTF-8 truncation point.
                    // Probably niche enough that no one will ever notice (after all, people didn't
                    // notice all the big holes in the previous unstable token implementation)
                    Err(_) => byte_pair_encode(&possibility, &self.encoder),
                    // Something like the following is intriguing but incorrect:
                    // Err(e) => self._encode_ordinary_native(unsafe {
                    //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
                    // }),
                };
                // Keep only as many tokens as needed to cover the unstable suffix.
                let mut seq = Vec::new();
                let mut seq_len = 0;
                for token in encoded {
                    seq.push(token);
                    seq_len += self.decoder[&token].len();
                    if seq_len >= unstable_bytes.len() {
                        break;
                    }
                }
                completions.insert(seq);
                point += 1;
            }
        }

        // This is also not straightforward. While we generally assume that regex splits are stable,
        // unfortunately, they are not. That is, if adding bytes were to make a split appear in
        // unstable_bytes, this could make tokens possible which our logic would otherwise think
        // would be merged.
        // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
        // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
        // Here is a quick and dirty fix:
        // This isn't right if we ever remove \s+(?!\S)
        if unstable_bytes.len() > 1 {
            let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
            if unstable_bytes.len() - last_decoded.1 > 0
                && last_decoded.0.map_or(false, |c| c.is_whitespace())
            {
                let mut reencoded = byte_pair_encode(
                    &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
                    &self.encoder,
                );
                reencoded.extend(byte_pair_encode(
                    &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
                    &self.encoder,
                ));
                completions.insert(reencoded);
            }
        }

        (tokens, completions)
    }
}
|
||||||
|
|
||||||
|
#[pymethods]
impl CoreBPE {
    /// Build a CoreBPE from the byte-level encoder, the special-token map,
    /// and the regex pattern used to split text before BPE.
    #[new]
    fn new(
        encoder: HashMap<Vec<u8>, usize>,
        special_tokens_encoder: HashMap<String, usize>,
        pattern: &str,
    ) -> PyResult<Self> {
        let regex = Regex::new(pattern)
            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;

        // Special tokens are matched literally: escape each one and alternate.
        let special_regex = {
            let _parts = special_tokens_encoder
                .keys()
                .map(|s| fancy_regex::escape(s))
                .collect::<Vec<_>>();
            Regex::new(&_parts.join("|"))
                .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
        };

        let decoder: HashMap<usize, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

        // If this fails, the encoder contained duplicate token ids.
        assert!(encoder.len() == decoder.len());

        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();

        // Clone because I don't know how to tell Rust I'm not going to change the map
        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort();

        Ok(CoreBPE {
            encoder,
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
            special_regex_tls: (0..MAX_NUM_THREADS)
                .map(|_| special_regex.clone())
                .collect(),
            sorted_token_bytes,
        })
    }

    // ====================
    // Encoding
    // ====================

    /// Encode ignoring special tokens; releases the GIL while working.
    fn encode_ordinary(&self, py: Python, text: &str) -> Vec<usize> {
        py.allow_threads(|| self._encode_ordinary_native(text))
    }

    /// Encode, treating occurrences of `allowed_special` as special tokens.
    fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
        py.allow_threads(|| self._encode_native(text, &allowed_special).0)
    }

    /// Encode arbitrary bytes, even when not valid UTF-8. Private because the
    /// handling around the invalid-UTF-8 boundary is only approximately right
    /// (see comment below).
    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
        py.allow_threads(|| {
            match std::str::from_utf8(bytes) {
                Ok(text) => self._encode_ordinary_native(text),
                Err(e) => {
                    // Encode the valid UTF-8 prefix, then BPE the rest raw.
                    let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
                    let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new());
                    let (mut tokens, last_piece_token_len) =
                        self._increase_last_piece_token_len(tokens, last_piece_token_len);
                    if !tokens.is_empty() && last_piece_token_len > 0 {
                        // Lop off the tokens from the last piece and run BPE on the remaining bytes
                        // Somewhat niche, but this may not be correct if we'd have had a regex
                        // split between the valid UTF-8 and the invalid bytes, which is why this
                        // method is private
                        let mut unstable_bytes =
                            self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
                        unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

                        tokens.truncate(tokens.len() - last_piece_token_len);
                        tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
                    }
                    tokens
                }
            }
        })
    }

    /// Encode and also return the possible completion sequences for the
    /// unstable suffix, as a Python (tokens, completions) tuple.
    fn encode_with_unstable(
        &self,
        py: Python,
        text: &str,
        allowed_special: HashSet<&str>,
    ) -> Py<PyTuple> {
        let (tokens, completions) =
            py.allow_threads(|| self._encode_unstable_native(text, &allowed_special));
        let py_completions =
            PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..])));
        (tokens, py_completions).into_py(py)
    }

    /// Look up the token id for a piece that is exactly one token, checking
    /// the ordinary vocabulary first, then special tokens. Raises KeyError
    /// (on the Python side) if neither matches.
    fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> {
        if let Some(token) = self.encoder.get(piece).copied() {
            return Ok(token);
        }
        if let Ok(piece_str) = std::str::from_utf8(piece) {
            if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
                return Ok(token);
            }
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
    }

    /// BPE-encode a single regex piece (no special-token handling).
    fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
        if let Some(token) = self.encoder.get(piece) {
            return vec![*token];
        }
        byte_pair_encode(piece, &self.encoder)
    }

    // ====================
    // Decoding
    // ====================

    /// Decode tokens to bytes; releases the GIL while working.
    fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> {
        let bytes = py.allow_threads(|| self._decode_native(&tokens));
        PyBytes::new(py, &bytes).into()
    }

    /// Decode one token (ordinary or special) to bytes.
    /// Raises KeyError (on the Python side) if the token is unknown.
    fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> {
        if let Some(bytes) = self.decoder.get(&token) {
            return Ok(PyBytes::new(py, bytes).into());
        }
        if let Some(bytes) = self.special_tokens_decoder.get(&token) {
            return Ok(PyBytes::new(py, bytes).into());
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
    }

    // ====================
    // Miscellaneous
    // ====================

    /// All ordinary token byte strings, in sorted order.
    fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
        self.sorted_token_bytes
            .iter()
            .map(|x| PyBytes::new(py, x).into())
            .collect()
    }
}
|
||||||
|
|
||||||
|
#[pymodule]
|
||||||
|
fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
|
m.add_class::<CoreBPE>()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use rustc_hash::FxHashMap as HashMap;

    use crate::byte_pair_split;

    /// Sanity check: with ranks for "ab" and "cd", "abcd" splits into those pairs.
    #[test]
    fn very_simple_test() {
        let mut merge_ranks = HashMap::default();
        merge_ranks.insert(b"ab".to_vec(), 1);
        merge_ranks.insert(b"cd".to_vec(), 2);

        let pieces = byte_pair_split(b"abcd", &merge_ranks);
        assert_eq!(pieces, vec![b"ab", b"cd"]);
    }
}
|
3
tiktoken/__init__.py
Normal file
3
tiktoken/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .core import Encoding as Encoding
|
||||||
|
from .registry import get_encoding as get_encoding
|
||||||
|
from .registry import list_encoding_names as list_encoding_names
|
310
tiktoken/core.py
Normal file
310
tiktoken/core.py
Normal file
@ -0,0 +1,310 @@
|
|||||||
|
import functools
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import AbstractSet, Collection, Literal, NoReturn, Optional, Union
|
||||||
|
|
||||||
|
import regex
|
||||||
|
|
||||||
|
from tiktoken import _tiktoken
|
||||||
|
|
||||||
|
|
||||||
|
class Encoding:
|
||||||
|
    def __init__(
        self,
        name: str,
        *,
        pat_str: str,
        mergeable_ranks: dict[bytes, int],
        special_tokens: dict[str, int],
        explicit_n_vocab: Optional[int] = None,
    ):
        """Create an Encoding.

        Args:
            name: Display name of the encoding.
            pat_str: Regex pattern used to split text before BPE.
            mergeable_ranks: Map from token bytes to token id (rank).
            special_tokens: Map from special token string to token id.
            explicit_n_vocab: If given, asserts that mergeable plus special
                tokens total exactly this many and that the highest token id
                is ``explicit_n_vocab - 1``.
        """
        self.name = name

        self._pat_str = pat_str
        self._mergeable_ranks = mergeable_ranks
        self._special_tokens = special_tokens

        # Highest token id across both vocabularies (special_tokens may be empty).
        self.max_token_value = max(
            max(mergeable_ranks.values()), max(special_tokens.values(), default=0)
        )
        if explicit_n_vocab:
            assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab
            assert self.max_token_value == explicit_n_vocab - 1

        # The Rust-backed core that does the actual encoding/decoding work.
        self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<Encoding {self.name!r}>"
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# Encoding
|
||||||
|
# ====================
|
||||||
|
|
||||||
|
    def encode_ordinary(self, text: str) -> list[int]:
        """Encodes a string into tokens, ignoring special tokens.

        This is equivalent to `encode(text, disallowed_special=())` (but slightly faster).

        ```
        >>> enc.encode_ordinary("hello world")
        [31373, 995]
        ```
        """
        return self._core_bpe.encode_ordinary(text)
|
||||||
|
|
||||||
|
    def encode(
        self,
        text: str,
        *,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
    ) -> list[int]:
        """Encodes a string into tokens.

        Special tokens are artificial tokens used to unlock capabilities from a model,
        such as fill-in-the-middle. So we want to be careful about accidentally encoding special
        tokens, since they can be used to trick a model into doing something we don't want it to do.

        Hence, by default, encode will raise an error if it encounters text that corresponds
        to a special token. This can be controlled on a per-token level using the `allowed_special`
        and `disallowed_special` parameters. In particular:
        - Setting `disallowed_special` to () will prevent this function from raising errors and
          cause all text corresponding to special tokens to be encoded as natural text.
        - Setting `allowed_special` to "all" will cause this function to treat all text
          corresponding to special tokens to be encoded as special tokens.

        ```
        >>> enc.encode("hello world")
        [31373, 995]
        >>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
        [50256]
        >>> enc.encode("<|endoftext|>", allowed_special="all")
        [50256]
        >>> enc.encode("<|endoftext|>")
        # Raises ValueError
        >>> enc.encode("<|endoftext|>", disallowed_special=())
        [27, 91, 437, 1659, 5239, 91, 29]
        ```
        """
        # Resolve the "all" shorthands against the encoding's special tokens.
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - allowed_special
        # Reject any disallowed special token appearing in the input text.
        if disallowed_special:
            if not isinstance(disallowed_special, frozenset):
                disallowed_special = frozenset(disallowed_special)
            if match := _special_token_regex(disallowed_special).search(text):
                raise_disallowed_special_token(match.group())

        return self._core_bpe.encode(text, allowed_special)
|
||||||
|
|
||||||
|
def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]:
|
||||||
|
"""Encodes a list of strings into tokens, in parallel, ignoring special tokens.
|
||||||
|
|
||||||
|
This is equivalent to `encode_batch(text, disallowed_special=())` (but slightly faster).
|
||||||
|
|
||||||
|
```
|
||||||
|
>>> enc.encode_batch(["hello world", "goodbye world"])
|
||||||
|
[[31373, 995], [11274, 16390, 995]]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
encoder = functools.partial(self.encode_ordinary)
|
||||||
|
with ThreadPoolExecutor(num_threads) as e:
|
||||||
|
return list(e.map(encoder, text))
|
||||||
|
|
||||||
|
def encode_batch(
|
||||||
|
self,
|
||||||
|
text: list[str],
|
||||||
|
*,
|
||||||
|
num_threads: int = 8,
|
||||||
|
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
|
||||||
|
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||||
|
) -> list[list[int]]:
|
||||||
|
"""Encodes a list of strings into tokens, in parallel.
|
||||||
|
|
||||||
|
See `encode` for more details on `allowed_special` and `disallowed_special`.
|
||||||
|
|
||||||
|
```
|
||||||
|
>>> enc.encode_batch(["hello world", "goodbye world"])
|
||||||
|
[[31373, 995], [11274, 16390, 995]]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if allowed_special == "all":
|
||||||
|
allowed_special = self.special_tokens_set
|
||||||
|
if disallowed_special == "all":
|
||||||
|
disallowed_special = self.special_tokens_set - allowed_special
|
||||||
|
if not isinstance(disallowed_special, frozenset):
|
||||||
|
disallowed_special = frozenset(disallowed_special)
|
||||||
|
|
||||||
|
encoder = functools.partial(
|
||||||
|
self.encode, allowed_special=allowed_special, disallowed_special=disallowed_special
|
||||||
|
)
|
||||||
|
with ThreadPoolExecutor(num_threads) as e:
|
||||||
|
return list(e.map(encoder, text))
|
||||||
|
|
||||||
|
    def encode_with_unstable(
        self,
        text: str,
        *,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
    ) -> tuple[list[int], list[list[int]]]:
        """Encodes a string into stable tokens and possible completion sequences.

        Note that the stable tokens will only represent a substring of `text`.

        See `encode` for more details on `allowed_special` and `disallowed_special`.

        ```
        >>> enc.encode_with_unstable("hello fanta")
        ([31373], [(277, 4910), (5113, 265), ..., (8842,)])

        >>> text = "..."
        >>> stable_tokens, completions = enc.encode_with_unstable(text)
        >>> assert text.encode().startswith(enc.decode_bytes(stable_tokens))
        >>> assert all(enc.decode_bytes(stable_tokens + seq).startswith(text.encode()) for seq in completions)
        ```
        """
        # Same special-token validation as `encode`.
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - allowed_special
        if disallowed_special:
            if not isinstance(disallowed_special, frozenset):
                disallowed_special = frozenset(disallowed_special)
            if match := _special_token_regex(disallowed_special).search(text):
                raise_disallowed_special_token(match.group())

        return self._core_bpe.encode_with_unstable(text, allowed_special)
|
||||||
|
|
||||||
|
def encode_single_token(self, text_or_bytes: Union[str, bytes]) -> int:
|
||||||
|
"""Encodes text corresponding to a single token to its token value.
|
||||||
|
|
||||||
|
NOTE: this will encode all special tokens.
|
||||||
|
|
||||||
|
Raises `KeyError` if the token is not in the vocabulary.
|
||||||
|
|
||||||
|
```
|
||||||
|
>>> enc.encode_single_token("hello")
|
||||||
|
31373
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if isinstance(text_or_bytes, str):
|
||||||
|
text_or_bytes = text_or_bytes.encode("utf-8")
|
||||||
|
return self._core_bpe.encode_single_token(text_or_bytes)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# Decoding
|
||||||
|
# ====================
|
||||||
|
|
||||||
|
def decode_bytes(self, tokens: list[int]) -> bytes:
    """Converts a list of tokens back into the raw bytes they encode.

    ```
    >>> enc.decode_bytes([31373, 995])
    b'hello world'
    ```
    """
    decoded = self._core_bpe.decode_bytes(tokens)
    return decoded
|
||||||
|
|
||||||
|
def decode(self, tokens: list[int], errors: str = "replace") -> str:
    """Decodes a list of tokens into a string.

    WARNING: decoded bytes are not guaranteed to be valid UTF-8, so by default
    this function is lossy. Control this behaviour via the `errors` parameter,
    for instance by setting `errors="strict"`.

    ```
    >>> enc.decode([31373, 995])
    'hello world'
    ```
    """
    raw = self._core_bpe.decode_bytes(tokens)
    return raw.decode("utf-8", errors=errors)
|
||||||
|
|
||||||
|
def decode_single_token_bytes(self, token: int) -> bytes:
    """Returns the byte representation of a single token.

    NOTE: this will decode all special tokens.

    Raises `KeyError` if the token is not in the vocabulary.

    ```
    >>> enc.decode_single_token_bytes(31373)
    b'hello'
    ```
    """
    token_bytes = self._core_bpe.decode_single_token_bytes(token)
    return token_bytes
|
||||||
|
|
||||||
|
def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
    """Decodes a list of tokens into a list of bytes, one entry per token.

    Useful for visualising tokenisation.
    >>> enc.decode_tokens_bytes([31373, 995])
    [b'hello', b' world']
    """
    return list(map(self.decode_single_token_bytes, tokens))
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# Miscellaneous
|
||||||
|
# ====================
|
||||||
|
|
||||||
|
def token_byte_values(self) -> list[bytes]:
    """Returns the byte values of every token in the vocabulary."""
    values = self._core_bpe.token_byte_values()
    return values
|
||||||
|
|
||||||
|
@property
def eot_token(self) -> int:
    # The token id of the `<|endoftext|>` special token.
    eot = self._special_tokens["<|endoftext|>"]
    return eot
|
||||||
|
|
||||||
|
@functools.cached_property
def special_tokens_set(self) -> set[str]:
    """The set of all special token strings known to this encoding (cached)."""
    return {token for token in self._special_tokens}
|
||||||
|
|
||||||
|
@property
def n_vocab(self) -> int:
    """For backwards compatibility. Prefer to use `enc.max_token_value + 1`."""
    return 1 + self.max_token_value
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# Private
|
||||||
|
# ====================
|
||||||
|
|
||||||
|
def _encode_single_piece(self, text_or_bytes: Union[str, bytes]) -> list[int]:
    """Encodes text corresponding to bytes without a regex split.

    NOTE: this will not encode any special tokens.

    ```
    >>> enc.encode_single_piece("helloqqqq")
    [31373, 38227, 38227]
    ```
    """
    # Normalise to bytes before handing off to the native encoder.
    raw = text_or_bytes.encode("utf-8") if isinstance(text_or_bytes, str) else text_or_bytes
    return self._core_bpe.encode_single_piece(raw)
|
||||||
|
|
||||||
|
def _encode_only_native_bpe(self, text: str) -> list[int]:
    """Encodes a string into tokens, but do regex splitting in Python.

    The regex split runs in Python while BPE merging still happens in the
    native extension. Does not handle special tokens.

    NOTE: the return annotation was fixed from `list[str]` to `list[int]`:
    `encode_single_piece` yields token ids, matching `_encode_single_piece`.
    """
    # The pattern IS used below; the old name `_unused_pat` was misleading.
    pat = regex.compile(self._pat_str)
    ret: list[int] = []
    for piece in regex.findall(pat, text):
        ret.extend(self._core_bpe.encode_single_piece(piece))
    return ret
|
||||||
|
|
||||||
|
def _encode_bytes(self, text: bytes) -> list[int]:
    # Internal passthrough: encode raw bytes directly via the native BPE,
    # bypassing the str-based entry points (no regex split, no special tokens
    # handling here — delegated entirely to the extension's private method).
    return self._core_bpe._encode_bytes(text)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=128)
def _special_token_regex(tokens: frozenset[str]) -> "regex.Pattern[str]":
    # Compile an alternation matching any of the given special tokens.
    # Cached (the argument must be a frozenset to be hashable) because the
    # same token set is typically passed on every encode call.
    escaped = map(regex.escape, tokens)
    return regex.compile("(" + "|".join(escaped) + ")")
|
||||||
|
|
||||||
|
|
||||||
|
def raise_disallowed_special_token(token: str) -> NoReturn:
    """Raise a ValueError explaining how to allow or ignore `token`."""
    message = (
        f"Encountered text corresponding to disallowed special token {token!r}.\n"
        f"If you want this text to be encoded as a special token, "
        f"pass it to `allowed_special`, e.g. `allowed_special={{{token!r}, ...}}`.\n"
        f"If you want this text to be encoded as normal text, disable the check for this token "
        f"by passing `disallowed_special=(enc.special_tokens_set - {{{token!r}}})`.\n"
        f"To disable this check for all special tokens, pass `disallowed_special=()`.\n"
    )
    raise ValueError(message)
|
97
tiktoken/load.py
Normal file
97
tiktoken/load.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import blobfile
|
||||||
|
|
||||||
|
|
||||||
|
def read_file_cached(blobpath: str) -> bytes:
    """Read `blobpath` via blobfile, caching the contents on local disk.

    The cache directory is taken from TIKTOKEN_CACHE_DIR, then
    DATA_GYM_CACHE_DIR, falling back to /tmp/data-gym-cache. Setting the
    chosen variable to an empty string disables caching entirely.
    """
    if "TIKTOKEN_CACHE_DIR" in os.environ:
        cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
    elif "DATA_GYM_CACHE_DIR" in os.environ:
        cache_dir = os.environ["DATA_GYM_CACHE_DIR"]
    else:
        cache_dir = "/tmp/data-gym-cache"

    if cache_dir == "":
        # disable caching
        with blobfile.BlobFile(blobpath, "rb") as f:
            return f.read()

    # NOTE(review): the cache key depends only on the path, not the contents —
    # a changed remote file will NOT invalidate an existing cache entry.
    cache_key = hashlib.sha1(blobpath.encode()).hexdigest()

    cache_path = os.path.join(cache_dir, cache_key)
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return f.read()

    with blobfile.BlobFile(blobpath, "rb") as f:
        contents = f.read()

    os.makedirs(cache_dir, exist_ok=True)
    # Write to a unique temp file first, then rename into place, so concurrent
    # readers never observe a partially written cache entry.
    tmp_filename = cache_path + "." + str(uuid.uuid4()) + ".tmp"
    with open(tmp_filename, "wb") as f:
        f.write(contents)
    os.rename(tmp_filename, cache_path)

    return contents
|
||||||
|
|
||||||
|
|
||||||
|
def data_gym_to_mergeable_bpe_ranks(
    vocab_bpe_file: str, encoder_json_file: str
) -> dict[bytes, int]:
    """Convert GPT-2-style vocab.bpe / encoder.json files into tiktoken ranks.

    Returns a mapping from token bytes to merge rank (lower rank = merged
    earlier). Raises AssertionError if the two input files disagree.
    """
    # NB: do not add caching to this function
    # The "data gym" format represents bytes as printable unicode characters;
    # printable non-space bytes map to themselves.
    rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]

    data_gym_byte_to_byte = {chr(b): b for b in rank_to_intbyte}
    n = 0
    for b in range(2**8):
        if b not in rank_to_intbyte:
            # Unprintable/space bytes are remapped to code points above 255.
            rank_to_intbyte.append(b)
            data_gym_byte_to_byte[chr(2**8 + n)] = b
            n += 1
    # Every one of the 256 byte values must be accounted for.
    assert len(rank_to_intbyte) == 2**8

    # vocab_bpe contains the merges along with associated ranks
    vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode()
    # [1:-1] skips the first line (presumably a version header) and the
    # trailing empty line — TODO confirm against the vocab.bpe format.
    bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]

    def decode_data_gym(value: str) -> bytes:
        # Map a data-gym string back to the raw bytes it represents.
        return bytes(data_gym_byte_to_byte[b] for b in value)

    # add the single byte tokens
    bpe_ranks = {bytes([b]): i for i, b in enumerate(rank_to_intbyte)}
    # add the merged tokens
    n = len(bpe_ranks)
    for first, second in bpe_merges:
        bpe_ranks[decode_data_gym(first) + decode_data_gym(second)] = n
        n += 1

    # check that the encoder file matches the merges file
    # this sanity check is important since tiktoken assumes that ranks are ordered the same
    # as merge priority
    encoder_json = json.loads(read_file_cached(encoder_json_file))
    encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
    # drop these two special tokens if present, since they're not mergeable bpe tokens
    encoder_json_loaded.pop(b"<|endoftext|>", None)
    encoder_json_loaded.pop(b"<|startoftext|>", None)
    assert bpe_ranks == encoder_json_loaded

    return bpe_ranks
|
||||||
|
|
||||||
|
|
||||||
|
def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
    """Write BPE ranks to `tiktoken_bpe_file` as "<base64 token> <rank>" lines."""
    # Emit in ascending rank order so the file preserves merge priority.
    by_rank = sorted(bpe_ranks.items(), key=lambda item: item[1])
    with blobfile.BlobFile(tiktoken_bpe_file, "wb") as f:
        for token, rank in by_rank:
            line = base64.b64encode(token) + b" " + str(rank).encode() + b"\n"
            f.write(line)
|
||||||
|
|
||||||
|
|
||||||
|
def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
    """Parse a tiktoken BPE file ("<base64 token> <rank>" lines) into ranks."""
    # NB: do not add caching to this function
    contents = read_file_cached(tiktoken_bpe_file)
    ranks: dict[bytes, int] = {}
    for line in contents.splitlines():
        if not line:
            continue
        token, rank = line.split()
        ranks[base64.b64decode(token)] = int(rank)
    return ranks
|
71
tiktoken/registry.py
Normal file
71
tiktoken/registry.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import importlib
|
||||||
|
import pkgutil
|
||||||
|
import threading
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import tiktoken_ext
|
||||||
|
|
||||||
|
from tiktoken.core import Encoding
|
||||||
|
|
||||||
|
# Guards lazy plugin discovery and mutation of the caches below.
_lock = threading.RLock()
# Cache of fully-constructed Encoding objects, keyed by encoding name.
ENCODINGS: dict[str, Encoding] = {}
# None until tiktoken_ext plugins have been scanned; afterwards maps
# encoding name -> zero-arg callable returning Encoding constructor kwargs.
ENCODING_CONSTRUCTORS: Optional[dict[str, Callable[[], dict[str, Any]]]] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _find_constructors() -> None:
    """Populate ENCODING_CONSTRUCTORS from tiktoken_ext plugin modules.

    Thread-safe and idempotent: once the registry is built, later calls return
    immediately. Raises ValueError if a plugin lacks ENCODING_CONSTRUCTORS or
    two plugins register the same encoding name.
    """
    global ENCODING_CONSTRUCTORS
    with _lock:
        if ENCODING_CONSTRUCTORS is not None:
            return
        ENCODING_CONSTRUCTORS = {}

        try:
            # tiktoken_ext is a namespace package
            # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes
            # - we use namespace package pattern so `pkgutil.iter_modules` is fast
            # - it's a separate top-level package because namespace subpackages of non-namespace
            #   packages don't quite do what you want with editable installs
            plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".")

            for _, mod_name, _ in plugin_mods:
                mod = importlib.import_module(mod_name)
                try:
                    constructors = mod.ENCODING_CONSTRUCTORS
                except AttributeError as e:
                    raise ValueError(
                        f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS"
                    ) from e
                for enc_name, constructor in constructors.items():
                    if enc_name in ENCODING_CONSTRUCTORS:
                        raise ValueError(
                            f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}"
                        )
                    ENCODING_CONSTRUCTORS[enc_name] = constructor
        except Exception:
            # Bug fix: don't leave a partially-populated registry behind.
            # Callers treat a non-None ENCODING_CONSTRUCTORS as complete and
            # would otherwise never rescan after a transient import failure.
            ENCODING_CONSTRUCTORS = None
            raise
|
||||||
|
|
||||||
|
|
||||||
|
def get_encoding(encoding_name: str) -> Encoding:
    """Return the (cached) Encoding registered under `encoding_name`.

    Raises ValueError if no installed plugin provides the requested encoding.
    """
    # Fast path: lock-free read of the cache. Entries are only ever added
    # under the lock, never mutated or removed, so a hit here is safe.
    if encoding_name in ENCODINGS:
        return ENCODINGS[encoding_name]

    with _lock:
        # Re-check under the lock: another thread may have built it meanwhile.
        if encoding_name in ENCODINGS:
            return ENCODINGS[encoding_name]

        # Lazily scan tiktoken_ext plugins on first use.
        if ENCODING_CONSTRUCTORS is None:
            _find_constructors()
            assert ENCODING_CONSTRUCTORS is not None

        if encoding_name not in ENCODING_CONSTRUCTORS:
            raise ValueError(f"Unknown encoding {encoding_name}")

        # Constructors are zero-arg callables returning Encoding kwargs;
        # construction is deferred to here since it may load vocabulary data.
        constructor = ENCODING_CONSTRUCTORS[encoding_name]
        enc = Encoding(**constructor())
        ENCODINGS[encoding_name] = enc
        return enc
|
||||||
|
|
||||||
|
|
||||||
|
def list_encoding_names() -> list[str]:
    """Return the names of all encodings offered by installed plugins."""
    with _lock:
        if ENCODING_CONSTRUCTORS is None:
            _find_constructors()
        assert ENCODING_CONSTRUCTORS is not None
        return [name for name in ENCODING_CONSTRUCTORS]
|
41
tiktoken_ext/openai_public.py
Normal file
41
tiktoken_ext/openai_public.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe
|
||||||
|
|
||||||
|
# Special token strings shared by the encodings defined in this module.
ENDOFTEXT = "<|endoftext|>"
# Fill-in-the-middle (FIM) markers used by cl100k_base below.
FIM_PREFIX = "<|fim_prefix|>"
FIM_MIDDLE = "<|fim_middle|>"
FIM_SUFFIX = "<|fim_suffix|>"
ENDOFPROMPT = "<|endofprompt|>"
|
||||||
|
|
||||||
|
|
||||||
|
def gpt2():
    """Constructor kwargs for the GPT-2 encoding (loaded from data-gym files)."""
    mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
        vocab_bpe_file="az://openaipublic/gpt-2/encodings/main/vocab.bpe",
        encoder_json_file="az://openaipublic/gpt-2/encodings/main/encoder.json",
    )
    return {
        "name": "gpt2",
        # Explicit vocab size: <|endoftext|> sits at id 50256, the last slot.
        "explicit_n_vocab": 50257,
        # Splits contractions, letter runs, digit runs, punctuation, whitespace.
        "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": {"<|endoftext|>": 50256},
    }
|
||||||
|
|
||||||
|
|
||||||
|
def cl100k_base():
    """Constructor kwargs for the cl100k_base encoding (tiktoken-format file)."""
    mergeable_ranks = load_tiktoken_bpe("az://openaipublic/encodings/cl100k_base.tiktoken")
    special_tokens = {
        ENDOFTEXT: 100257,
        FIM_PREFIX: 100258,
        FIM_MIDDLE: 100259,
        FIM_SUFFIX: 100260,
        ENDOFPROMPT: 100276,
    }
    return {
        "name": "cl100k_base",
        # Case-insensitive contractions, digit groups of up to 3, and
        # newline-aware punctuation/whitespace handling.
        "pat_str": r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": special_tokens,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# Plugin registry entry point: tiktoken.registry reads this attribute when
# scanning tiktoken_ext submodules for available encodings.
ENCODING_CONSTRUCTORS = {"gpt2": gpt2, "cl100k_base": cl100k_base}
|
Loading…
x
Reference in New Issue
Block a user