Fix tests

This commit is contained in:
Gabriel Tofvesson 2023-03-23 19:52:17 +01:00
parent 85a4f9dbb0
commit caf916236f
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
3 changed files with 288 additions and 1 deletions

View File

@ -13,6 +13,9 @@ regex = "1.7.0"
rustc-hash = "1.1.0"
bstr = "1.0.1"
anyhow = "1.0.70"
base64 = "0.21.0"
reqwest = "0.11.15"
tokio = { version = "1.26.0", features = ["full"] }
[profile.release]
incremental = true

View File

@ -1,6 +1,8 @@
// This check is new and seems buggy (possibly with PyO3 interaction)
#![allow(clippy::borrow_deref_ref)]
pub mod model;
use std::collections::HashSet;
use std::thread;
@ -165,7 +167,7 @@ fn hash_current_thread() -> usize {
}
const MAX_NUM_THREADS: usize = 128;
struct CoreBPE {
pub struct CoreBPE {
encoder: HashMap<Vec<u8>, usize>,
special_tokens_encoder: HashMap<String, usize>,
decoder: HashMap<usize, Vec<u8>>,
@ -545,6 +547,9 @@ impl CoreBPE {
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use bstr::ByteSlice;
use rustc_hash::FxHashMap as HashMap;
use crate::byte_pair_split;
@ -558,4 +563,48 @@ mod tests {
let res = byte_pair_split(b"abcd", &ranks);
assert_eq!(res, vec![b"ab", b"cd"]);
}
#[tokio::test]
async fn test_load_model() {
let model = crate::model::model_gpt2().await;
assert!(model.is_ok(), "Could not download model (model_gpt2): {:?}", model);
let model = crate::model::gpt2(model.unwrap());
assert!(model.is_ok(), "Could not load model (gpt2): {:?}", model.err().unwrap());
let model = crate::model::model_cl100k_base().await;
assert!(model.is_ok(), "Could not download model (model_cl100k_base): {:?}", model);
let model = crate::model::cl100k_base(model.unwrap());
assert!(model.is_ok(), "Could not load model (cl100k_base): {:?}", model.err().unwrap());
let model = crate::model::model_p50k_base().await;
assert!(model.is_ok(), "Could not download model (model_p50k_base): {:?}", model);
let model = crate::model::p50k_base(model.unwrap());
assert!(model.is_ok(), "Could not load model (p50k_base): {:?}", model.err().unwrap());
let model = crate::model::model_p50k_edit().await;
assert!(model.is_ok(), "Could not download model (model_p50k_edit): {:?}", model);
let model = crate::model::p50k_edit(model.unwrap());
assert!(model.is_ok(), "Could not load model (p50k_edit): {:?}", model.err().unwrap());
let model = crate::model::model_r50k_base().await;
assert!(model.is_ok(), "Could not download model (model_r50k_base): {:?}", model);
let model = crate::model::r50k_base(model.unwrap());
assert!(model.is_ok(), "Could not load model (r50k_base): {:?}", model.err().unwrap());
}
#[tokio::test]
async fn test_model_encode_decode() {
    // Round-trip: encode a string with cl100k_base, decode the tokens back,
    // and check the text survives unchanged.
    let data = crate::model::model_cl100k_base().await.unwrap();
    let bpe = crate::model::cl100k_base(data).unwrap();
    let input = "This is a test";
    let allowed_special = bpe
        .special_tokens_encoder
        .keys()
        .map(String::as_str)
        .collect::<HashSet<&str>>();
    let (encoded, _) = bpe.encode(input, allowed_special);
    let decoded = bpe.decode_bytes(&encoded);
    let decoded_string = decoded.to_str();
    assert!(decoded_string.is_ok(), "Decoding failed: {:?}", decoded);
    assert_eq!(input.to_string(), decoded_string.unwrap().to_string());
}
}

235
src/model.rs Normal file
View File

@ -0,0 +1,235 @@
use base64::{prelude::BASE64_STANDARD, Engine};
use bstr::ByteSlice;
use rustc_hash::FxHashMap as HashMap;
use crate::CoreBPE;
const ENDOFTEXT: &str = "<|endoftext|>";
const FIM_PREFIX: &str = "<|fim_prefix|>";
const FIM_MIDDLE: &str = "<|fim_middle|>";
const FIM_SUFFIX: &str = "<|fim_suffix|>";
const ENDOFPROMPT: &str = "<|endofprompt|>";
/// Convert a GPT-2 "data gym" vocab file (`vocab.bpe`) into mergeable BPE ranks.
///
/// Mirrors tiktoken's reference loader: the 256 single-byte tokens are ranked
/// by their position in GPT-2's byte-encoder ordering (printable bytes first,
/// in byte order, then the non-printable bytes), and every merge line in
/// `vocab_bpe` gets the next rank starting at 256.
///
/// Fixes two defects in the previous version:
/// * single-byte tokens were ranked by their raw byte value instead of their
///   position in the printable-first ordering, producing token ids that
///   disagree with GPT-2's actual vocabulary;
/// * a malformed merge line set only a flag that was never checked at the
///   end, so the function could silently return a truncated table.
///
/// Returns `None` (after logging the offending line) when a merge line has no
/// space separator or references a character outside the data-gym alphabet.
pub fn data_gym_to_mergeable_bpe_ranks(vocab_bpe: &str) -> Option<HashMap<Vec<u8>, usize>> {
    let mut bpe_ranks = HashMap::<Vec<u8>, usize>::default();
    let mut data_gym_byte_to_byte = HashMap::<char, u8>::default();

    // Printable bytes represent themselves in the data-gym encoding and take
    // ranks 0..=187 in byte order ('!'..='~' = 33..=126, '¡'..='¬' = 161..=172,
    // '®'..='ÿ' = 174..=255).
    let mut rank: usize = 0;
    for chr in ('!'..='~').chain('¡'..='¬').chain('®'..='ÿ') {
        data_gym_byte_to_byte.insert(chr, chr as u8);
        bpe_ranks.insert(vec![chr as u8], rank);
        rank += 1;
    }

    // Non-printable bytes are written as char 256 + n (n counting upward in
    // byte order: 0..=32, 127..=160, then 173); they take ranks 188..=255.
    let mut n: u32 = 0;
    let unprintables = ('\0'..=' ')
        .chain(char::from_u32(127).unwrap()..=char::from_u32(160).unwrap())
        .chain(std::iter::once(char::from_u32(173).unwrap()));
    for chr in unprintables {
        data_gym_byte_to_byte.insert(char::from_u32(n + 256).unwrap(), chr as u8);
        bpe_ranks.insert(vec![chr as u8], rank);
        rank += 1;
        n += 1;
    }

    // Each merge line ("left right") forms one token whose rank continues
    // after the 256 single-byte tokens. The first line of the file is a
    // version header; parsing stops at the first empty line.
    for (index, line) in vocab_bpe
        .split('\n')
        .skip(1)
        .take_while(|line| !line.is_empty())
        .enumerate()
    {
        let (left, right) = match line.split_once(' ') {
            Some(halves) => halves,
            None => {
                println!("No space in: {}", line);
                return None;
            }
        };
        let mut token = Vec::with_capacity(left.len() + right.len());
        for c in left.chars().chain(right.chars()) {
            match data_gym_byte_to_byte.get(&c) {
                Some(&byte) => token.push(byte),
                None => {
                    // A character with no byte mapping means the file is
                    // corrupt; fail instead of silently truncating.
                    println!("Missing key for: {} ({})", c, c as u32);
                    return None;
                }
            }
        }
        bpe_ranks.insert(token, index + 256);
    }
    Some(bpe_ranks)
}
/// Parse a `.tiktoken` vocabulary file into mergeable BPE ranks.
///
/// Each non-empty line has the form `<base64 token> <decimal rank>`.
/// Returns `None` if any line is missing its space separator, is not valid
/// base64, or has a non-numeric rank.
///
/// Empty lines anywhere in the file (including a trailing newline) are
/// skipped, matching the reference tiktoken loader; previously an empty line
/// in the middle of the file silently truncated the vocabulary. Using
/// `str::lines` also tolerates CRLF line endings, which the old
/// `split("\n")` approach did not.
pub fn load_tiktoken_bpe(tiktoken_bpe: &str) -> Option<HashMap<Vec<u8>, usize>> {
    tiktoken_bpe
        .lines()
        .filter(|line| !line.is_empty())
        .map(|line| {
            let (b64, rank) = line.split_once(' ')?;
            let token = BASE64_STANDARD.decode(b64).ok()?;
            let rank = rank.parse::<usize>().ok()?;
            Some((token, rank))
        })
        // `Option` short-circuits the collect on the first malformed line.
        .collect()
}
/// Download the resource at `url` and return its body as a UTF-8 `String`.
async fn get_model(url: &str) -> anyhow::Result<String> {
    let response = reqwest::get(url).await?;
    let body = response.bytes().await?;
    Ok(body.to_str()?.to_string())
}
/// Fetch the GPT-2 data-gym vocabulary (`vocab.bpe`) from OpenAI's public
/// blob storage.
pub async fn model_gpt2() -> anyhow::Result<String> {
    const URL: &str =
        "https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe";
    get_model(URL).await
}
pub fn gpt2(model_file: String) -> anyhow::Result<CoreBPE> {
let mut special_tokens = HashMap::<String, usize>::default();
special_tokens.insert(ENDOFTEXT.to_string(), 50256);
return CoreBPE::new(
data_gym_to_mergeable_bpe_ranks(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
special_tokens,
&"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
);
}
/// Fetch the r50k_base tiktoken vocabulary from OpenAI's public blob storage.
pub async fn model_r50k_base() -> anyhow::Result<String> {
    const URL: &str =
        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken";
    get_model(URL).await
}
pub fn r50k_base(model_file: String) -> anyhow::Result<CoreBPE> {
let mut special_tokens = HashMap::<String, usize>::default();
special_tokens.insert(ENDOFTEXT.to_string(), 50256);
return CoreBPE::new(
load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
special_tokens,
&"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
);
}
/// Fetch the p50k_base tiktoken vocabulary from OpenAI's public blob storage.
pub async fn model_p50k_base() -> anyhow::Result<String> {
    const URL: &str =
        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken";
    get_model(URL).await
}
pub fn p50k_base(model_file: String) -> anyhow::Result<CoreBPE> {
let mut special_tokens = HashMap::<String, usize>::default();
special_tokens.insert(ENDOFTEXT.to_string(), 50256);
return CoreBPE::new(
load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
special_tokens,
&"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
);
}
/// Fetch the vocabulary used by the p50k_edit encoding.
///
/// p50k_edit shares the **p50k_base** mergeable ranks (its FIM special tokens
/// 50281..=50283 sit directly after p50k_base's vocabulary); the previous URL
/// incorrectly downloaded `r50k_base.tiktoken`.
pub async fn model_p50k_edit() -> anyhow::Result<String> {
    get_model("https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken").await
}
pub fn p50k_edit(model_file: String) -> anyhow::Result<CoreBPE> {
let mut special_tokens = HashMap::<String, usize>::default();
special_tokens.insert(ENDOFTEXT.to_string(), 50256);
special_tokens.insert(FIM_PREFIX.to_string(), 50281);
special_tokens.insert(FIM_MIDDLE.to_string(), 50282);
special_tokens.insert(FIM_SUFFIX.to_string(), 50283);
return CoreBPE::new(
load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
special_tokens,
&"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
);
}
/// Fetch the cl100k_base tiktoken vocabulary from OpenAI's public blob storage.
pub async fn model_cl100k_base() -> anyhow::Result<String> {
    const URL: &str =
        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken";
    get_model(URL).await
}
/// Build a `CoreBPE` for the cl100k_base encoding from a downloaded tiktoken
/// vocabulary file (see [`model_cl100k_base`]).
///
/// Special-token ids fixed to match the official tiktoken definition:
/// cl100k_base has 100256 mergeable tokens, so its special tokens start at
/// 100257. The previous 502xx ids collided with ordinary BPE token ids,
/// corrupting both the encoder and the decoder maps.
pub fn cl100k_base(model_file: String) -> anyhow::Result<CoreBPE> {
    let mut special_tokens = HashMap::<String, usize>::default();
    special_tokens.insert(ENDOFTEXT.to_string(), 100257);
    special_tokens.insert(FIM_PREFIX.to_string(), 100258);
    special_tokens.insert(FIM_MIDDLE.to_string(), 100259);
    special_tokens.insert(FIM_SUFFIX.to_string(), 100260);
    special_tokens.insert(ENDOFPROMPT.to_string(), 100276);
    CoreBPE::new(
        load_tiktoken_bpe(&model_file).ok_or(anyhow::anyhow!("Failed to load model"))?,
        special_tokens,
        &"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+"
    )
}