Implement context creation from api key

This commit is contained in:
Gabriel Tofvesson 2023-06-25 04:05:53 +02:00
parent 304f1974d6
commit c3b4822b63
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
2 changed files with 106 additions and 3 deletions

View File

@ -1,4 +1,4 @@
use std::{error::Error, ops::Add}; use std::{error::Error};
use openai_rs::{chat::{ChatMessage, Role, ChatHistoryBuilder}, context::Context, edits::EditRequestBuilder}; use openai_rs::{chat::{ChatMessage, Role, ChatHistoryBuilder}, context::Context, edits::EditRequestBuilder};
use tiktoken::{CoreBPE, model::{model_cl100k_base, cl100k_base}}; use tiktoken::{CoreBPE, model::{model_cl100k_base, cl100k_base}};

View File

@ -1,7 +1,7 @@
use std::{ptr::NonNull, marker::PhantomPinned, pin::Pin, cmp::min, num::NonZeroUsize, fmt::Display, error::Error}; use std::{ptr::NonNull, marker::PhantomPinned, pin::Pin, cmp::min, num::NonZeroUsize, fmt::Display, error::Error};
use openai_rs::{chat::{ChatMessage, Role, ChatHistoryBuilder}, context::Context as OpenAIContext}; use openai_rs::{chat::{ChatMessage, Role, ChatHistoryBuilder}, context::Context as OpenAIContext};
use tiktoken::CoreBPE; use tiktoken::{CoreBPE, model::{model_cl100k_base, cl100k_base}};
const PROMPT_COMPRESS: &str = "Summarize the chat history precisely and concisely"; const PROMPT_COMPRESS: &str = "Summarize the chat history precisely and concisely";
@ -35,6 +35,71 @@ impl ContextOverrunError {
} }
} }
#[derive(Debug)]
struct MissingModelError {
model: String
}
impl Display for MissingModelError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&format!("Missing model information for {}", self.model))?;
Ok(())
}
}
impl Error for MissingModelError {}
#[derive(Debug)]
struct InvalidModelTokenInformation {
model: String
}
impl Display for InvalidModelTokenInformation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&format!("Max tokens for model \"{}\" misreported as 0", self.model))?;
Ok(())
}
}
impl Error for InvalidModelTokenInformation {}
#[derive(Debug)]
enum ContextCreationError {
ContextOverrunError(ContextOverrunError),
MissingModelError(MissingModelError),
InvalidModelTokenInformation(InvalidModelTokenInformation),
}
impl Display for ContextCreationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ContextCreationError::ContextOverrunError(inner) => inner.fmt(f),
ContextCreationError::MissingModelError(inner) => inner.fmt(f),
ContextCreationError::InvalidModelTokenInformation(inner) => inner.fmt(f)
}
}
}
impl Error for ContextCreationError {}
impl From<ContextOverrunError> for ContextCreationError {
fn from(value: ContextOverrunError) -> Self {
Self::ContextOverrunError(value)
}
}
impl From<MissingModelError> for ContextCreationError {
fn from(value: MissingModelError) -> Self {
Self::MissingModelError(value)
}
}
impl From<InvalidModelTokenInformation> for ContextCreationError {
fn from(value: InvalidModelTokenInformation) -> Self {
Self::InvalidModelTokenInformation(value)
}
}
pub struct UserList { pub struct UserList {
pub users: Vec<UserAliases>, pub users: Vec<UserAliases>,
_pin: PhantomPinned _pin: PhantomPinned
@ -110,7 +175,20 @@ impl UserList {
} }
impl Context { impl Context {
fn new(max_tokens: NonZeroUsize, model: String, encoding: CoreBPE, openai_context: OpenAIContext, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, impl Error> { async fn new_from_api(model: String, openai_api_key: String, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, ContextCreationError> {
let encoding = get_model(&model).await.ok_or(MissingModelError { model: model.clone() })?;
Ok(Context::new(
NonZeroUsize::new(get_max_tokens(&model).ok_or(MissingModelError { model: model.clone() })? as usize).ok_or(InvalidModelTokenInformation { model: model.clone() })?,
model,
encoding,
OpenAIContext::new(openai_api_key),
summary_budget,
history_target,
alias_budget
)?)
}
fn new(max_tokens: NonZeroUsize, model: String, encoding: CoreBPE, openai_context: OpenAIContext, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, ContextOverrunError> {
let summary_instruction_budget = count_message_tokens(&get_summary_instruction(), &encoding, &model) as usize; let summary_instruction_budget = count_message_tokens(&get_summary_instruction(), &encoding, &model) as usize;
let summary_budget = summary_budget.get() + count_message_tokens(&get_summary_message(None), &encoding, &model) as usize; let summary_budget = summary_budget.get() + count_message_tokens(&get_summary_message(None), &encoding, &model) as usize;
if history_target.get() + summary_budget + alias_budget.get() + summary_instruction_budget >= max_tokens.get() { if history_target.get() + summary_budget + alias_budget.get() + summary_instruction_budget >= max_tokens.get() {
@ -315,4 +393,29 @@ fn count_message_tokens(message: &ChatMessage, encoding: &CoreBPE, model: &str)
return tpm + encoding.encode_ordinary(&message.content).len() as i64 + encoding.encode_ordinary(role_str(&message.role)).len() as i64 + if let Some(ref name) = message.name { return tpm + encoding.encode_ordinary(&message.content).len() as i64 + encoding.encode_ordinary(role_str(&message.role)).len() as i64 + if let Some(ref name) = message.name {
tpn + encoding.encode_ordinary(name).len() as i64 tpn + encoding.encode_ordinary(name).len() as i64
} else { 0i64 }; } else { 0i64 };
}
async fn get_model(model: &str) -> Option<CoreBPE> {
return match model {
"gpt-4" | "gpt-4-32k" | "gpt-3.5-turbo" | "text-embedding-ada-002" => {
let model = model_cl100k_base().await;
assert!(model.is_ok(), "Could not download model (model_cl100k_base): {:?}", model);
let model = cl100k_base(model.unwrap());
assert!(model.is_ok(), "Could not load model (cl100k_base): {:?}", model.err().unwrap());
return Some(model.unwrap());
}
_ => None
}
}
fn get_max_tokens(model: &str) -> Option<i64> {
match model {
"gpt-4" => Some(8192),
"gpt-4-32k" => Some(32768),
"gpt-3.5-turbo" => Some(4096),
"code-davinci-002" => Some(8001),
_ => None
}
} }