Update codebase

This commit is contained in:
Shantanu Jain 2023-01-03 13:20:35 -08:00
parent 0f8ec705e2
commit 40d9b1f14e
12 changed files with 57 additions and 13 deletions

View File

@ -16,7 +16,7 @@ jobs:
# cibuildwheel builds linux wheels inside a manylinux container # cibuildwheel builds linux wheels inside a manylinux container
# it also takes care of procuring the correct python version for us # it also takes care of procuring the correct python version for us
os: [ubuntu-latest, windows-latest, macos-latest] os: [ubuntu-latest, windows-latest, macos-latest]
python-version: [39, 310, 311] python-version: [38, 39, 310, 311]
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3

12
CHANGELOG.md Normal file
View File

@ -0,0 +1,12 @@
# Changelog
This is the changelog for the open source version of tiktoken.
## [v0.1.2]
- Avoid use of `blobfile` for public files
- Add support for Python 3.8
- Add py.typed
- Improve the public tests
## [v0.1.1]
- Initial release

View File

@ -1,6 +1,8 @@
include *.svg include *.svg
include *.toml include *.toml
include *.md
include Makefile include Makefile
global-include py.typed
recursive-include scripts *.py recursive-include scripts *.py
recursive-include tests *.py recursive-include tests *.py
recursive-include src *.rs recursive-include src *.rs

View File

@ -1,8 +1,8 @@
[project] [project]
name = "tiktoken" name = "tiktoken"
dependencies = ["blobfile>=2", "regex>=2022.1.18"] dependencies = ["blobfile>=2", "regex>=2022.1.18", "requests>=2.26.0"]
dynamic = ["version"] dynamic = ["version"]
requires-python = ">=3.9" requires-python = ">=3.8"
[build-system] [build-system]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@ -9,6 +9,8 @@ def redact_file(path: Path, dry_run: bool) -> None:
return return
text = path.read_text() text = path.read_text()
if not text:
return
first_line = text.splitlines()[0] first_line = text.splitlines()[0]
if "redact" in first_line: if "redact" in first_line:

View File

@ -4,7 +4,7 @@ from setuptools_rust import Binding, RustExtension
public = True public = True
if public: if public:
version = "0.1.1" version = "0.1.2"
setup( setup(
name="tiktoken", name="tiktoken",
@ -18,6 +18,7 @@ setup(
debug=False, debug=False,
) )
], ],
package_data={"tiktoken": ["py.typed"]},
packages=["tiktoken", "tiktoken_ext"], packages=["tiktoken", "tiktoken_ext"],
zip_safe=False, zip_safe=False,
) )

View File

@ -2,10 +2,18 @@ import tiktoken
def test_simple(): def test_simple():
# Note that there are more actual tests, they're just not currently public :-)
enc = tiktoken.get_encoding("gpt2") enc = tiktoken.get_encoding("gpt2")
assert enc.encode("hello world") == [31373, 995] assert enc.encode("hello world") == [31373, 995]
assert enc.decode([31373, 995]) == "hello world" assert enc.decode([31373, 995]) == "hello world"
assert enc.encode("hello <|endoftext|>", allowed_special="all") == [31373, 220, 50256]
enc = tiktoken.get_encoding("cl100k_base") enc = tiktoken.get_encoding("cl100k_base")
assert enc.encode("hello world") == [15339, 1917] assert enc.encode("hello world") == [15339, 1917]
assert enc.decode([15339, 1917]) == "hello world" assert enc.decode([15339, 1917]) == "hello world"
assert enc.encode("hello <|endoftext|>", allowed_special="all") == [15339, 220, 100257]
for enc_name in tiktoken.list_encoding_names():
enc = tiktoken.get_encoding(enc_name)
for token in range(10_000):
assert enc.encode_single_token(enc.decode_single_token_bytes(token)) == token

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import functools import functools
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import AbstractSet, Collection, Literal, NoReturn, Optional, Union from typing import AbstractSet, Collection, Literal, NoReturn, Optional, Union

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import base64 import base64
import hashlib import hashlib
import json import json
@ -5,6 +7,15 @@ import os
import uuid import uuid
import blobfile import blobfile
import requests
def read_file(blobpath: str) -> bytes:
    """Return the raw bytes stored at ``blobpath``.

    Paths that start with ``http://`` or ``https://`` are downloaded with
    ``requests``; every other path is opened through ``blobfile``.

    Args:
        blobpath: Location of the file to read.

    Returns:
        The complete contents of the file as bytes.

    Raises:
        requests.HTTPError: If the HTTP(S) download answers with a 4xx or
            5xx status code.
    """
    # str.startswith accepts a tuple, so one call covers both URL schemes.
    if not blobpath.startswith(("http://", "https://")):
        with blobfile.BlobFile(blobpath, "rb") as f:
            return f.read()
    # avoiding blobfile for public files helps avoid auth issues, like MFA prompts
    resp = requests.get(blobpath)
    # Fail loudly on an error status. Without this check the body of an error
    # page would be returned as if it were the requested file, and a caller
    # that caches the result would persist the bad payload.
    resp.raise_for_status()
    return resp.content
def read_file_cached(blobpath: str) -> bytes: def read_file_cached(blobpath: str) -> bytes:
@ -17,8 +28,7 @@ def read_file_cached(blobpath: str) -> bytes:
if cache_dir == "": if cache_dir == "":
# disable caching # disable caching
with blobfile.BlobFile(blobpath, "rb") as f: return read_file(blobpath)
return f.read()
cache_key = hashlib.sha1(blobpath.encode()).hexdigest() cache_key = hashlib.sha1(blobpath.encode()).hexdigest()
@ -27,8 +37,7 @@ def read_file_cached(blobpath: str) -> bytes:
with open(cache_path, "rb") as f: with open(cache_path, "rb") as f:
return f.read() return f.read()
with blobfile.BlobFile(blobpath, "rb") as f: contents = read_file(blobpath)
contents = f.read()
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)
tmp_filename = cache_path + "." + str(uuid.uuid4()) + ".tmp" tmp_filename = cache_path + "." + str(uuid.uuid4()) + ".tmp"

0
tiktoken/py.typed Normal file
View File

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import importlib import importlib
import pkgutil import pkgutil
import threading import threading

View File

@ -9,8 +9,8 @@ ENDOFPROMPT = "<|endofprompt|>"
def gpt2(): def gpt2():
mergeable_ranks = data_gym_to_mergeable_bpe_ranks( mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
vocab_bpe_file="az://openaipublic/gpt-2/encodings/main/vocab.bpe", vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
encoder_json_file="az://openaipublic/gpt-2/encodings/main/encoder.json", encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
) )
return { return {
"name": "gpt2", "name": "gpt2",
@ -22,7 +22,9 @@ def gpt2():
def r50k_base(): def r50k_base():
mergeable_ranks = load_tiktoken_bpe("az://openaipublic/encodings/r50k_base.tiktoken") mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
)
return { return {
"name": "r50k_base", "name": "r50k_base",
"explicit_n_vocab": 50257, "explicit_n_vocab": 50257,
@ -33,7 +35,9 @@ def r50k_base():
def p50k_base(): def p50k_base():
mergeable_ranks = load_tiktoken_bpe("az://openaipublic/encodings/p50k_base.tiktoken") mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
)
return { return {
"name": "p50k_base", "name": "p50k_base",
"explicit_n_vocab": 50281, "explicit_n_vocab": 50281,
@ -44,7 +48,9 @@ def p50k_base():
def cl100k_base(): def cl100k_base():
mergeable_ranks = load_tiktoken_bpe("az://openaipublic/encodings/cl100k_base.tiktoken") mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
)
special_tokens = { special_tokens = {
ENDOFTEXT: 100257, ENDOFTEXT: 100257,
FIM_PREFIX: 100258, FIM_PREFIX: 100258,