Fix tests

parent 85a4f9dbb0
commit caf916236f
@@ -13,6 +13,9 @@ regex = "1.7.0"
rustc-hash = "1.1.0"
bstr = "1.0.1"
anyhow = "1.0.70"
base64 = "0.21.0"
reqwest = "0.11.15"
tokio = { version = "1.26.0", features = ["full"] }

[profile.release]
incremental = true
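Judging from the hunk counts, the three lines added here are base64, reqwest and tokio; they back the model-loading module introduced below: reqwest and tokio (the "full" feature set, so the runtime and #[tokio::test] are available) handle async downloads, and base64 decodes the .tiktoken vocabulary lines. A minimal sketch of the fetch pattern these dependencies enable, mirroring get_model() in the new src/model.rs (the fetch/main names here are illustrative only):

    // Sketch only: async download of an encoding file over HTTPS.
    async fn fetch(url: &str) -> anyhow::Result<String> {
        let body = reqwest::get(url).await?.text().await?;
        Ok(body)
    }

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let text = fetch("https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken").await?;
        println!("{} bytes downloaded", text.len());
        Ok(())
    }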
51 src/lib.rs
@@ -1,6 +1,8 @@
// This check is new and seems buggy (possibly with PyO3 interaction)
#![allow(clippy::borrow_deref_ref)]

pub mod model;

use std::collections::HashSet;
use std::thread;
@@ -165,7 +167,7 @@ fn hash_current_thread() -> usize {
}

const MAX_NUM_THREADS: usize = 128;
struct CoreBPE {
pub struct CoreBPE {
    encoder: HashMap<Vec<u8>, usize>,
    special_tokens_encoder: HashMap<String, usize>,
    decoder: HashMap<usize, Vec<u8>>,
@@ -545,6 +547,9 @@ impl CoreBPE {
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use bstr::ByteSlice;
    use rustc_hash::FxHashMap as HashMap;

    use crate::byte_pair_split;
@@ -558,4 +563,48 @@ mod tests {
        let res = byte_pair_split(b"abcd", &ranks);
        assert_eq!(res, vec![b"ab", b"cd"]);
    }

    #[tokio::test]
    async fn test_load_model() {
        let model = crate::model::model_gpt2().await;
        assert!(model.is_ok(), "Could not download model (model_gpt2): {:?}", model);

        let model = crate::model::gpt2(model.unwrap());
        assert!(model.is_ok(), "Could not load model (gpt2): {:?}", model.err().unwrap());

        let model = crate::model::model_cl100k_base().await;
        assert!(model.is_ok(), "Could not download model (model_cl100k_base): {:?}", model);

        let model = crate::model::cl100k_base(model.unwrap());
        assert!(model.is_ok(), "Could not load model (cl100k_base): {:?}", model.err().unwrap());

        let model = crate::model::model_p50k_base().await;
        assert!(model.is_ok(), "Could not download model (model_p50k_base): {:?}", model);

        let model = crate::model::p50k_base(model.unwrap());
        assert!(model.is_ok(), "Could not load model (p50k_base): {:?}", model.err().unwrap());

        let model = crate::model::model_p50k_edit().await;
        assert!(model.is_ok(), "Could not download model (model_p50k_edit): {:?}", model);

        let model = crate::model::p50k_edit(model.unwrap());
        assert!(model.is_ok(), "Could not load model (p50k_edit): {:?}", model.err().unwrap());

        let model = crate::model::model_r50k_base().await;
        assert!(model.is_ok(), "Could not download model (model_r50k_base): {:?}", model);

        let model = crate::model::r50k_base(model.unwrap());
        assert!(model.is_ok(), "Could not load model (r50k_base): {:?}", model.err().unwrap());
    }

    #[tokio::test]
    async fn test_model_encode_decode() {
        let model = crate::model::cl100k_base(crate::model::model_cl100k_base().await.unwrap()).unwrap();
        let input = "This is a test";
        let (encoded, _) = model.encode(input, model.special_tokens_encoder.keys().map(|entry| entry.as_str()).collect::<HashSet<&str>>());
        let decoded = model.decode_bytes(&encoded);
        let decoded_string = decoded.to_str();
        assert!(decoded_string.is_ok(), "Decoding failed: {:?}", decoded);
        assert_eq!(input.to_string(), decoded_string.unwrap().to_string());
    }
}
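These tests exercise the new surface end to end: each model_*() helper downloads an encoding file over the network and the matching constructor builds a CoreBPE from it, so running cargo test requires network access. From inside the crate, the same flow the tests follow would look roughly like this (a sketch only; the demo function is hypothetical and reuses the crate-internal paths and fields seen in the tests above):

    // Sketch: load cl100k_base and round-trip a string, mirroring test_model_encode_decode.
    use std::collections::HashSet;
    use bstr::ByteSlice;

    async fn demo() -> anyhow::Result<()> {
        let raw = crate::model::model_cl100k_base().await?; // download the .tiktoken file
        let bpe = crate::model::cl100k_base(raw)?;          // build the CoreBPE from it
        let allowed = bpe.special_tokens_encoder.keys().map(|s| s.as_str()).collect::<HashSet<&str>>();
        let (tokens, _) = bpe.encode("This is a test", allowed);
        let text = bpe.decode_bytes(&tokens).to_str()?.to_string();
        assert_eq!(text, "This is a test");
        Ok(())
    }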
235 src/model.rs (new file)
@@ -0,0 +1,235 @@
use base64::{prelude::BASE64_STANDARD, Engine};
use bstr::ByteSlice;

use rustc_hash::FxHashMap as HashMap;
use crate::CoreBPE;

const ENDOFTEXT: &str = "<|endoftext|>";
const FIM_PREFIX: &str = "<|fim_prefix|>";
const FIM_MIDDLE: &str = "<|fim_middle|>";
const FIM_SUFFIX: &str = "<|fim_suffix|>";
const ENDOFPROMPT: &str = "<|endofprompt|>";

// Rebuilds mergeable BPE ranks from a GPT-2 "data gym" vocab.bpe file,
// first recovering the printable-unicode byte encoding used by that format.
pub fn data_gym_to_mergeable_bpe_ranks(vocab_bpe: &str) -> Option<HashMap<Vec<u8>, usize>> {
    let mut bpe_ranks = HashMap::<Vec<u8>, usize>::default();
    let mut data_gym_byte_to_byte = HashMap::<char, u8>::default();

    // Printable bytes are encoded as themselves.
    for chr in '!'..='~' {
        data_gym_byte_to_byte.insert(chr, chr as u8);
        bpe_ranks.insert(vec![chr as u8], chr as usize);
    }

    for chr in '¡'..='¬' {
        data_gym_byte_to_byte.insert(chr, chr as u8);
        bpe_ranks.insert(vec![chr as u8], chr as usize);
    }

    for chr in '®'..='ÿ' {
        data_gym_byte_to_byte.insert(chr, chr as u8);
        bpe_ranks.insert(vec![chr as u8], chr as usize);
    }

    // Non-printable bytes were remapped to code points 256, 257, ... in the data gym encoding.
    let mut n = 0;

    for chr in '\0'..=' ' {
        data_gym_byte_to_byte.insert(char::from_u32(n as u32 + 256).unwrap(), chr as u8);
        bpe_ranks.insert(vec![chr as u8], chr as usize);
        n += 1;
    }

    for chr in char::from_u32(127).unwrap()..=char::from_u32(160).unwrap() {
        data_gym_byte_to_byte.insert(char::from_u32(n as u32 + 256).unwrap(), chr as u8);
        bpe_ranks.insert(vec![chr as u8], chr as usize);
        n += 1;
    }

    let del = char::from_u32(173).unwrap();
    data_gym_byte_to_byte.insert(char::from_u32(n as u32 + 256).unwrap(), del as u8);
    bpe_ranks.insert(vec![del as u8], del as usize);

    // Every remaining line (after the version header) is a merge pair;
    // its rank continues after the 256 single-byte entries.
    let mut error = false;
    vocab_bpe
        .split("\n")
        .skip(1)
        .take_while(|line| !line.is_empty())
        .enumerate()
        .map_while(|(index, line)| {
            if line.len() == 0 {
                return None;
            }

            let space_index = line.find(" ");
            if space_index.is_none() {
                error = true;
                println!("No space in: {}", line);
                return None;
            }
            let space_index = space_index.unwrap();

            let mut inner_error = false;
            let key = line[..space_index]
                .chars()
                .map_while(|c| {
                    if data_gym_byte_to_byte.contains_key(&c) {
                        return Some(data_gym_byte_to_byte[&c]);
                    }
                    println!("Missing key for: {} ({})", c, c as u32);
                    error = true;
                    return None;
                })
                .chain(
                    line[space_index + 1..]
                        .chars()
                        .map_while(|c| {
                            if data_gym_byte_to_byte.contains_key(&c) {
                                return Some(data_gym_byte_to_byte[&c]);
                            }
                            inner_error = true;
                            println!("Missing key for: {} ({})", c, c as u32);
                            return None;
                        })
                )
                .collect::<Vec<u8>>();

            if inner_error || error {
                return None;
            }

            bpe_ranks.insert(
                key,
                index + 256
            );

            return Some(());
        })
        .for_each(|_| {});

    if error {
        return None;
    }

    return Some(bpe_ranks);
}

// Parses a .tiktoken file: each non-empty line is "<base64-encoded token bytes> <rank>".
pub fn load_tiktoken_bpe(tiktoken_bpe: &str) -> Option<HashMap<Vec<u8>, usize>> {
    let mut error = false;
    let result = tiktoken_bpe
        .split("\n")
        .map_while(|line: &str| {
            if line.is_empty() {
                return None;
            }

            let space_index = line.find(" ");
            if space_index.is_none() {
                error = true;
                return None;
            }
            let space_index = space_index.unwrap();
            let b64 = BASE64_STANDARD.decode(&line[..space_index]).ok();
            if b64.is_none() {
                error = true;
                return None;
            }
            let size = usize::from_str_radix(&line[space_index + 1..], 10).ok();
            if size.is_none() {
                error = true;
                return None;
            }

            return Some((b64.unwrap(), size.unwrap()));
        })
        .collect();

    if error {
        return None;
    }

    return Some(result);
}

// Downloads an encoding file and returns its body as a UTF-8 string.
async fn get_model(url: &str) -> anyhow::Result<String> {
    Ok(reqwest::get(url).await?.bytes().await?.to_str()?.to_string())
}

pub async fn model_gpt2() -> anyhow::Result<String> {
    get_model("https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe").await
}

pub fn gpt2(model_file: String) -> anyhow::Result<CoreBPE> {
    let mut special_tokens = HashMap::<String, usize>::default();
    special_tokens.insert(ENDOFTEXT.to_string(), 50256);

    return CoreBPE::new(
        data_gym_to_mergeable_bpe_ranks(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
        special_tokens,
        &"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
    );
}

pub async fn model_r50k_base() -> anyhow::Result<String> {
    get_model("https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken").await
}

pub fn r50k_base(model_file: String) -> anyhow::Result<CoreBPE> {
    let mut special_tokens = HashMap::<String, usize>::default();
    special_tokens.insert(ENDOFTEXT.to_string(), 50256);

    return CoreBPE::new(
        load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
        special_tokens,
        &"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
    );
}

pub async fn model_p50k_base() -> anyhow::Result<String> {
    get_model("https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken").await
}

pub fn p50k_base(model_file: String) -> anyhow::Result<CoreBPE> {
    let mut special_tokens = HashMap::<String, usize>::default();
    special_tokens.insert(ENDOFTEXT.to_string(), 50256);

    return CoreBPE::new(
        load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
        special_tokens,
        &"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
    );
}

pub async fn model_p50k_edit() -> anyhow::Result<String> {
    // p50k_edit uses the same mergeable ranks as p50k_base.
    get_model("https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken").await
}

pub fn p50k_edit(model_file: String) -> anyhow::Result<CoreBPE> {
    let mut special_tokens = HashMap::<String, usize>::default();
    special_tokens.insert(ENDOFTEXT.to_string(), 50256);
    special_tokens.insert(FIM_PREFIX.to_string(), 50281);
    special_tokens.insert(FIM_MIDDLE.to_string(), 50282);
    special_tokens.insert(FIM_SUFFIX.to_string(), 50283);

    return CoreBPE::new(
        load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
        special_tokens,
        &"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
    );
}

pub async fn model_cl100k_base() -> anyhow::Result<String> {
    get_model("https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken").await
}

pub fn cl100k_base(model_file: String) -> anyhow::Result<CoreBPE> {
    let mut special_tokens = HashMap::<String, usize>::default();
    // Special-token ranks follow tiktoken's cl100k_base definition; they sit above the ~100k mergeable ranks.
    special_tokens.insert(ENDOFTEXT.to_string(), 100257);
    special_tokens.insert(FIM_PREFIX.to_string(), 100258);
    special_tokens.insert(FIM_MIDDLE.to_string(), 100259);
    special_tokens.insert(FIM_SUFFIX.to_string(), 100260);
    special_tokens.insert(ENDOFPROMPT.to_string(), 100276);

    return CoreBPE::new(
        load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
        special_tokens,
        &"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+"
    );
}
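For reference, the .tiktoken files fetched above are plain text, one token per line: the base64-encoded token bytes, a space, and the decimal rank. A hypothetical test-style check of load_tiktoken_bpe against a hand-written two-line input (the sample data is illustrative, not taken from a real encoding):

    #[test]
    fn load_tiktoken_bpe_parses_base64_rank_lines() {
        // "IQ==" is base64 for the single byte b"!", "IiE=" for the two bytes b"\"!".
        let sample = "IQ== 0\nIiE= 1\n";
        let ranks = crate::model::load_tiktoken_bpe(sample).unwrap();
        assert_eq!(ranks.get(&b"!"[..]), Some(&0));
        assert_eq!(ranks.get(&b"\"!"[..]), Some(&1));
    }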