Fix tests
This commit is contained in:
parent
85a4f9dbb0
commit
caf916236f
@ -13,6 +13,9 @@ regex = "1.7.0"
|
|||||||
rustc-hash = "1.1.0"
|
rustc-hash = "1.1.0"
|
||||||
bstr = "1.0.1"
|
bstr = "1.0.1"
|
||||||
anyhow = "1.0.70"
|
anyhow = "1.0.70"
|
||||||
|
base64 = "0.21.0"
|
||||||
|
reqwest = "0.11.15"
|
||||||
|
tokio = { version = "1.26.0", features = ["full"] }
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
incremental = true
|
incremental = true
|
||||||
|
51
src/lib.rs
51
src/lib.rs
@ -1,6 +1,8 @@
|
|||||||
// This check is new and seems buggy (possibly with PyO3 interaction)
|
// This check is new and seems buggy (possibly with PyO3 interaction)
|
||||||
#![allow(clippy::borrow_deref_ref)]
|
#![allow(clippy::borrow_deref_ref)]
|
||||||
|
|
||||||
|
pub mod model;
|
||||||
|
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::thread;
|
use std::thread;
|
||||||
|
|
||||||
@ -165,7 +167,7 @@ fn hash_current_thread() -> usize {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const MAX_NUM_THREADS: usize = 128;
|
const MAX_NUM_THREADS: usize = 128;
|
||||||
struct CoreBPE {
|
pub struct CoreBPE {
|
||||||
encoder: HashMap<Vec<u8>, usize>,
|
encoder: HashMap<Vec<u8>, usize>,
|
||||||
special_tokens_encoder: HashMap<String, usize>,
|
special_tokens_encoder: HashMap<String, usize>,
|
||||||
decoder: HashMap<usize, Vec<u8>>,
|
decoder: HashMap<usize, Vec<u8>>,
|
||||||
@ -545,6 +547,9 @@ impl CoreBPE {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use bstr::ByteSlice;
|
||||||
use rustc_hash::FxHashMap as HashMap;
|
use rustc_hash::FxHashMap as HashMap;
|
||||||
|
|
||||||
use crate::byte_pair_split;
|
use crate::byte_pair_split;
|
||||||
@ -558,4 +563,48 @@ mod tests {
|
|||||||
let res = byte_pair_split(b"abcd", &ranks);
|
let res = byte_pair_split(b"abcd", &ranks);
|
||||||
assert_eq!(res, vec![b"ab", b"cd"]);
|
assert_eq!(res, vec![b"ab", b"cd"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_load_model() {
|
||||||
|
let model = crate::model::model_gpt2().await;
|
||||||
|
assert!(model.is_ok(), "Could not download model (model_gpt2): {:?}", model);
|
||||||
|
|
||||||
|
let model = crate::model::gpt2(model.unwrap());
|
||||||
|
assert!(model.is_ok(), "Could not load model (gpt2): {:?}", model.err().unwrap());
|
||||||
|
|
||||||
|
let model = crate::model::model_cl100k_base().await;
|
||||||
|
assert!(model.is_ok(), "Could not download model (model_cl100k_base): {:?}", model);
|
||||||
|
|
||||||
|
let model = crate::model::cl100k_base(model.unwrap());
|
||||||
|
assert!(model.is_ok(), "Could not load model (cl100k_base): {:?}", model.err().unwrap());
|
||||||
|
|
||||||
|
let model = crate::model::model_p50k_base().await;
|
||||||
|
assert!(model.is_ok(), "Could not download model (model_p50k_base): {:?}", model);
|
||||||
|
|
||||||
|
let model = crate::model::p50k_base(model.unwrap());
|
||||||
|
assert!(model.is_ok(), "Could not load model (p50k_base): {:?}", model.err().unwrap());
|
||||||
|
|
||||||
|
let model = crate::model::model_p50k_edit().await;
|
||||||
|
assert!(model.is_ok(), "Could not download model (model_p50k_edit): {:?}", model);
|
||||||
|
|
||||||
|
let model = crate::model::p50k_edit(model.unwrap());
|
||||||
|
assert!(model.is_ok(), "Could not load model (p50k_edit): {:?}", model.err().unwrap());
|
||||||
|
|
||||||
|
let model = crate::model::model_r50k_base().await;
|
||||||
|
assert!(model.is_ok(), "Could not download model (model_r50k_base): {:?}", model);
|
||||||
|
|
||||||
|
let model = crate::model::r50k_base(model.unwrap());
|
||||||
|
assert!(model.is_ok(), "Could not load model (r50k_base): {:?}", model.err().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_model_encode_decode() {
|
||||||
|
let model = crate::model::cl100k_base(crate::model::model_cl100k_base().await.unwrap()).unwrap();
|
||||||
|
let input = "This is a test";
|
||||||
|
let (encoded, _) = model.encode(input, model.special_tokens_encoder.keys().map(|entry| entry.as_str()).collect::<HashSet<&str>>());
|
||||||
|
let decoded = model.decode_bytes(&encoded);
|
||||||
|
let decoded_string = decoded.to_str();
|
||||||
|
assert!(decoded_string.is_ok(), "Decoding failed: {:?}", decoded);
|
||||||
|
assert_eq!(input.to_string(), decoded_string.unwrap().to_string());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
235
src/model.rs
Normal file
235
src/model.rs
Normal file
@ -0,0 +1,235 @@
|
|||||||
|
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||||
|
use bstr::ByteSlice;
|
||||||
|
|
||||||
|
use rustc_hash::FxHashMap as HashMap;
|
||||||
|
use crate::CoreBPE;
|
||||||
|
|
||||||
|
const ENDOFTEXT: &str = "<|endoftext|>";
|
||||||
|
const FIM_PREFIX: &str = "<|fim_prefix|>";
|
||||||
|
const FIM_MIDDLE: &str = "<|fim_middle|>";
|
||||||
|
const FIM_SUFFIX: &str = "<|fim_suffix|>";
|
||||||
|
const ENDOFPROMPT: &str = "<|endofprompt|>";
|
||||||
|
|
||||||
|
/// Converts the text of a GPT-2 "data gym" `vocab.bpe` file into a map from token bytes
/// to merge rank.
///
/// The data gym format writes every byte as a single printable character: bytes that are
/// printable (and not a space) stand for themselves, and every other byte is remapped, in
/// ascending order, to the code points starting at U+0100.
///
/// Ranks are assigned as follows:
/// * `0..=187`: the self-representing bytes `!..=~`, `0xA1..=0xAC`, `0xAE..=0xFF`, in
///   that order;
/// * `188..=255`: the remapped bytes `0x00..=0x20`, `0x7F..=0xA0`, `0xAD`, in that order;
/// * `256..`: one rank per merge line, in file order.
///
/// The first line of `vocab_bpe` (the version header) is skipped and parsing stops at the
/// first empty line. Each merge line must have the form `<first> <second>`.
///
/// Returns `None` (after printing a diagnostic) when a merge line contains no space or
/// uses a character that is not part of the data gym alphabet.
pub fn data_gym_to_mergeable_bpe_ranks(vocab_bpe: &str) -> Option<HashMap<Vec<u8>, usize>> {
    // Bytes that represent themselves, listed in rank order.
    let mut rank_to_byte: Vec<u8> = (b'!'..=b'~')
        .chain(0xA1..=0xAC)
        .chain(0xAE..=0xFF)
        .collect();
    let mut data_gym_byte_to_byte: HashMap<char, u8> = rank_to_byte
        .iter()
        .map(|&byte| (char::from(byte), byte))
        .collect();

    // Every remaining byte is written as U+0100 + n, where n counts the remapped bytes in
    // ascending byte order. These bytes take the ranks after the self-representing ones.
    let remapped_bytes = (0u8..=b' ')
        .chain(0x7F..=0xA0)
        .chain(std::iter::once(0xAD));
    for (code_point, byte) in (256u32..).zip(remapped_bytes) {
        let symbol = char::from_u32(code_point)
            .expect("U+0100..=U+0143 are valid Unicode scalar values");
        data_gym_byte_to_byte.insert(symbol, byte);
        rank_to_byte.push(byte);
    }

    // A single-byte token's rank is its position in `rank_to_byte`, not its byte value.
    let mut bpe_ranks: HashMap<Vec<u8>, usize> = rank_to_byte
        .iter()
        .enumerate()
        .map(|(rank, &byte)| (vec![byte], rank))
        .collect();
    let first_merge_rank = rank_to_byte.len();

    let merge_lines = vocab_bpe
        .split('\n')
        .skip(1)
        .take_while(|line| !line.is_empty());
    for (index, line) in merge_lines.enumerate() {
        let (first, second) = match line.split_once(' ') {
            Some(halves) => halves,
            None => {
                println!("No space in: {}", line);
                return None;
            }
        };

        // The merged token is the decoded first half followed by the decoded second half.
        let mut key = Vec::with_capacity(line.len());
        for symbol in first.chars().chain(second.chars()) {
            match data_gym_byte_to_byte.get(&symbol) {
                Some(&byte) => key.push(byte),
                None => {
                    // An unknown symbol in either half invalidates the whole file.
                    println!("Missing key for: {} ({})", symbol, symbol as u32);
                    return None;
                }
            }
        }

        bpe_ranks.insert(key, first_merge_rank + index);
    }

    Some(bpe_ranks)
}
|
||||||
|
|
||||||
|
pub fn load_tiktoken_bpe(tiktoken_bpe: &str) -> Option<HashMap<Vec<u8>, usize>> {
|
||||||
|
let mut error = false;
|
||||||
|
let result = tiktoken_bpe
|
||||||
|
.split("\n")
|
||||||
|
.map_while(|line: &str| {
|
||||||
|
if line.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let space_index = line.find(" ");
|
||||||
|
if space_index.is_none() {
|
||||||
|
error = true;
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let space_index = space_index.unwrap();
|
||||||
|
let b64 = BASE64_STANDARD.decode(&line[..space_index]).ok();
|
||||||
|
if b64.is_none() {
|
||||||
|
error = true;
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let size = usize::from_str_radix(&line[space_index + 1..], 10).ok();
|
||||||
|
if size.is_none() {
|
||||||
|
error = true;
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Some((b64.unwrap(), size.unwrap()));
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if error {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Some(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_model(url: &str) -> anyhow::Result<String> {
|
||||||
|
Ok(reqwest::get(url).await?.bytes().await?.to_str()?.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn model_gpt2() -> anyhow::Result<String> {
|
||||||
|
get_model("https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe").await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn gpt2(model_file: String) -> anyhow::Result<CoreBPE> {
|
||||||
|
let mut special_tokens = HashMap::<String, usize>::default();
|
||||||
|
special_tokens.insert(ENDOFTEXT.to_string(), 50256);
|
||||||
|
|
||||||
|
return CoreBPE::new(
|
||||||
|
data_gym_to_mergeable_bpe_ranks(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
|
||||||
|
special_tokens,
|
||||||
|
&"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn model_r50k_base() -> anyhow::Result<String> {
|
||||||
|
get_model("https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken").await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn r50k_base(model_file: String) -> anyhow::Result<CoreBPE> {
|
||||||
|
let mut special_tokens = HashMap::<String, usize>::default();
|
||||||
|
special_tokens.insert(ENDOFTEXT.to_string(), 50256);
|
||||||
|
|
||||||
|
return CoreBPE::new(
|
||||||
|
load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
|
||||||
|
special_tokens,
|
||||||
|
&"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn model_p50k_base() -> anyhow::Result<String> {
|
||||||
|
get_model("https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken").await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn p50k_base(model_file: String) -> anyhow::Result<CoreBPE> {
|
||||||
|
let mut special_tokens = HashMap::<String, usize>::default();
|
||||||
|
special_tokens.insert(ENDOFTEXT.to_string(), 50256);
|
||||||
|
|
||||||
|
return CoreBPE::new(
|
||||||
|
load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
|
||||||
|
special_tokens,
|
||||||
|
&"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn model_p50k_edit() -> anyhow::Result<String> {
|
||||||
|
get_model("https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken").await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn p50k_edit(model_file: String) -> anyhow::Result<CoreBPE> {
|
||||||
|
let mut special_tokens = HashMap::<String, usize>::default();
|
||||||
|
special_tokens.insert(ENDOFTEXT.to_string(), 50256);
|
||||||
|
special_tokens.insert(FIM_PREFIX.to_string(), 50281);
|
||||||
|
special_tokens.insert(FIM_MIDDLE.to_string(), 50282);
|
||||||
|
special_tokens.insert(FIM_SUFFIX.to_string(), 50283);
|
||||||
|
|
||||||
|
return CoreBPE::new(
|
||||||
|
load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
|
||||||
|
special_tokens,
|
||||||
|
&"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn model_cl100k_base() -> anyhow::Result<String> {
|
||||||
|
get_model("https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken").await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn cl100k_base(model_file: String) -> anyhow::Result<CoreBPE> {
|
||||||
|
let mut special_tokens = HashMap::<String, usize>::default();
|
||||||
|
special_tokens.insert(ENDOFTEXT.to_string(), 50257);
|
||||||
|
special_tokens.insert(FIM_PREFIX.to_string(), 50258);
|
||||||
|
special_tokens.insert(FIM_MIDDLE.to_string(), 50259);
|
||||||
|
special_tokens.insert(FIM_SUFFIX.to_string(), 50260);
|
||||||
|
special_tokens.insert(ENDOFPROMPT.to_string(), 50276);
|
||||||
|
|
||||||
|
return CoreBPE::new(
|
||||||
|
load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
|
||||||
|
special_tokens,
|
||||||
|
&"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+"
|
||||||
|
);
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user