diff --git a/src/chat_context.rs b/src/chat_context.rs
index 1f6e63a..64561e0 100644
--- a/src/chat_context.rs
+++ b/src/chat_context.rs
@@ -1,4 +1,4 @@
-use std::{error::Error, ops::Add};
+use std::error::Error;
 use openai_rs::{chat::{ChatMessage, Role, ChatHistoryBuilder}, context::Context, edits::EditRequestBuilder};
 use tiktoken::{CoreBPE, model::{model_cl100k_base, cl100k_base}};
 
diff --git a/src/message.rs b/src/message.rs
index 9865ebc..fd85f94 100644
--- a/src/message.rs
+++ b/src/message.rs
@@ -1,7 +1,7 @@
 use std::{ptr::NonNull, marker::PhantomPinned, pin::Pin, cmp::min, num::NonZeroUsize, fmt::Display, error::Error};
 use openai_rs::{chat::{ChatMessage, Role, ChatHistoryBuilder}, context::Context as OpenAIContext};
-use tiktoken::CoreBPE;
+use tiktoken::{CoreBPE, model::{model_cl100k_base, cl100k_base}};
 
 const PROMPT_COMPRESS: &str = "Summarize the chat history precisely and concisely";
 
@@ -35,6 +35,71 @@ impl ContextOverrunError {
     }
 }
 
+#[derive(Debug)]
+struct MissingModelError {
+    model: String
+}
+
+impl Display for MissingModelError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&format!("Missing model information for {}", self.model))?;
+        Ok(())
+    }
+}
+
+impl Error for MissingModelError {}
+
+#[derive(Debug)]
+struct InvalidModelTokenInformation {
+    model: String
+}
+
+impl Display for InvalidModelTokenInformation {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&format!("Max tokens for model \"{}\" misreported as 0", self.model))?;
+        Ok(())
+    }
+}
+
+impl Error for InvalidModelTokenInformation {}
+
+#[derive(Debug)]
+enum ContextCreationError {
+    ContextOverrunError(ContextOverrunError),
+    MissingModelError(MissingModelError),
+    InvalidModelTokenInformation(InvalidModelTokenInformation),
+}
+
+impl Display for ContextCreationError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            ContextCreationError::ContextOverrunError(inner) => inner.fmt(f),
+            ContextCreationError::MissingModelError(inner) => inner.fmt(f),
+            ContextCreationError::InvalidModelTokenInformation(inner) => inner.fmt(f)
+        }
+    }
+}
+
+impl Error for ContextCreationError {}
+
+impl From<ContextOverrunError> for ContextCreationError {
+    fn from(value: ContextOverrunError) -> Self {
+        Self::ContextOverrunError(value)
+    }
+}
+
+impl From<MissingModelError> for ContextCreationError {
+    fn from(value: MissingModelError) -> Self {
+        Self::MissingModelError(value)
+    }
+}
+
+impl From<InvalidModelTokenInformation> for ContextCreationError {
+    fn from(value: InvalidModelTokenInformation) -> Self {
+        Self::InvalidModelTokenInformation(value)
+    }
+}
+
 pub struct UserList {
     pub users: Vec<User>,
     _pin: PhantomPinned
@@ -110,7 +175,20 @@ impl UserList {
     }
 }
 impl Context {
-    fn new(max_tokens: NonZeroUsize, model: String, encoding: CoreBPE, openai_context: OpenAIContext, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, ContextOverrunError> {
+    async fn new_from_api(model: String, openai_api_key: String, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, ContextCreationError> {
+        let encoding = get_model(&model).await.ok_or(MissingModelError { model: model.clone() })?;
+        Ok(Context::new(
+            NonZeroUsize::new(get_max_tokens(&model).ok_or(MissingModelError { model: model.clone() })? as usize).ok_or(InvalidModelTokenInformation { model: model.clone() })?,
+            model,
+            encoding,
+            OpenAIContext::new(openai_api_key),
+            summary_budget,
+            history_target,
+            alias_budget
+        )?)
+    }
+
+    fn new(max_tokens: NonZeroUsize, model: String, encoding: CoreBPE, openai_context: OpenAIContext, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, ContextOverrunError> {
         let summary_instruction_budget = count_message_tokens(&get_summary_instruction(), &encoding, &model) as usize;
         let summary_budget = summary_budget.get() + count_message_tokens(&get_summary_message(None), &encoding, &model) as usize;
         if history_target.get() + summary_budget + alias_budget.get() + summary_instruction_budget >= max_tokens.get() {
@@ -315,4 +393,29 @@ fn count_message_tokens(message: &ChatMessage, encoding: &CoreBPE, model: &str)
     return tpm + encoding.encode_ordinary(&message.content).len() as i64
         + encoding.encode_ordinary(role_str(&message.role)).len() as i64
         + if let Some(ref name) = message.name { tpn + encoding.encode_ordinary(name).len() as i64 } else { 0i64 };
-}
\ No newline at end of file
+}
+
+async fn get_model(model: &str) -> Option<CoreBPE> {
+    match model {
+        "gpt-4" | "gpt-4-32k" | "gpt-3.5-turbo" | "text-embedding-ada-002" => {
+            let model = model_cl100k_base().await;
+            assert!(model.is_ok(), "Could not download model (model_cl100k_base): {:?}", model);
+
+            let model = cl100k_base(model.unwrap());
+            assert!(model.is_ok(), "Could not load model (cl100k_base): {:?}", model.err().unwrap());
+
+            Some(model.unwrap())
+        }
+        _ => None
+    }
+}
+
+fn get_max_tokens(model: &str) -> Option<u32> {
+    match model {
+        "gpt-4" => Some(8192),
+        "gpt-4-32k" => Some(32768),
+        "gpt-3.5-turbo" => Some(4096),
+        "code-davinci-002" => Some(8001),
+        _ => None
+    }
+}
\ No newline at end of file
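
Usage sketch (not part of the diff above): a minimal example of how the new constructor might be exercised. It assumes a Tokio runtime, a call site inside this crate (new_from_api and ContextCreationError are private as written), and an API key supplied via the OPENAI_API_KEY environment variable; the model name and budget values are illustrative.

    use std::num::NonZeroUsize;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Hypothetical driver: the budgets below are illustrative only and
        // stay well under gpt-3.5-turbo's 4096-token limit.
        let api_key = std::env::var("OPENAI_API_KEY")?;
        let _context = Context::new_from_api(
            "gpt-3.5-turbo".to_string(),
            api_key,
            NonZeroUsize::new(512).unwrap(),  // summary_budget
            NonZeroUsize::new(2048).unwrap(), // history_target
            NonZeroUsize::new(256).unwrap(),  // alias_budget
        ).await?;
        Ok(())
    }

The ? on new_from_api propagates ContextCreationError into Box<dyn Error>, which is why the patch implements Error (and Display) for the enum and each of its variant types.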