Implement context creation from api key
This commit is contained in:
parent
304f1974d6
commit
c3b4822b63
@ -1,4 +1,4 @@
|
||||
use std::{error::Error, ops::Add};
|
||||
use std::{error::Error};
|
||||
|
||||
use openai_rs::{chat::{ChatMessage, Role, ChatHistoryBuilder}, context::Context, edits::EditRequestBuilder};
|
||||
use tiktoken::{CoreBPE, model::{model_cl100k_base, cl100k_base}};
|
||||
|
107
src/message.rs
107
src/message.rs
@ -1,7 +1,7 @@
|
||||
use std::{ptr::NonNull, marker::PhantomPinned, pin::Pin, cmp::min, num::NonZeroUsize, fmt::Display, error::Error};
|
||||
|
||||
use openai_rs::{chat::{ChatMessage, Role, ChatHistoryBuilder}, context::Context as OpenAIContext};
|
||||
use tiktoken::CoreBPE;
|
||||
use tiktoken::{CoreBPE, model::{model_cl100k_base, cl100k_base}};
|
||||
|
||||
const PROMPT_COMPRESS: &str = "Summarize the chat history precisely and concisely";
|
||||
|
||||
@ -35,6 +35,71 @@ impl ContextOverrunError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MissingModelError {
|
||||
model: String
|
||||
}
|
||||
|
||||
impl Display for MissingModelError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(&format!("Missing model information for {}", self.model))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for MissingModelError {}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InvalidModelTokenInformation {
|
||||
model: String
|
||||
}
|
||||
|
||||
impl Display for InvalidModelTokenInformation {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(&format!("Max tokens for model \"{}\" misreported as 0", self.model))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for InvalidModelTokenInformation {}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ContextCreationError {
|
||||
ContextOverrunError(ContextOverrunError),
|
||||
MissingModelError(MissingModelError),
|
||||
InvalidModelTokenInformation(InvalidModelTokenInformation),
|
||||
}
|
||||
|
||||
impl Display for ContextCreationError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ContextCreationError::ContextOverrunError(inner) => inner.fmt(f),
|
||||
ContextCreationError::MissingModelError(inner) => inner.fmt(f),
|
||||
ContextCreationError::InvalidModelTokenInformation(inner) => inner.fmt(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for ContextCreationError {}
|
||||
|
||||
impl From<ContextOverrunError> for ContextCreationError {
|
||||
fn from(value: ContextOverrunError) -> Self {
|
||||
Self::ContextOverrunError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MissingModelError> for ContextCreationError {
|
||||
fn from(value: MissingModelError) -> Self {
|
||||
Self::MissingModelError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InvalidModelTokenInformation> for ContextCreationError {
|
||||
fn from(value: InvalidModelTokenInformation) -> Self {
|
||||
Self::InvalidModelTokenInformation(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UserList {
|
||||
pub users: Vec<UserAliases>,
|
||||
_pin: PhantomPinned
|
||||
@ -110,7 +175,20 @@ impl UserList {
|
||||
}
|
||||
|
||||
impl Context {
|
||||
fn new(max_tokens: NonZeroUsize, model: String, encoding: CoreBPE, openai_context: OpenAIContext, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, impl Error> {
|
||||
async fn new_from_api(model: String, openai_api_key: String, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, ContextCreationError> {
|
||||
let encoding = get_model(&model).await.ok_or(MissingModelError { model: model.clone() })?;
|
||||
Ok(Context::new(
|
||||
NonZeroUsize::new(get_max_tokens(&model).ok_or(MissingModelError { model: model.clone() })? as usize).ok_or(InvalidModelTokenInformation { model: model.clone() })?,
|
||||
model,
|
||||
encoding,
|
||||
OpenAIContext::new(openai_api_key),
|
||||
summary_budget,
|
||||
history_target,
|
||||
alias_budget
|
||||
)?)
|
||||
}
|
||||
|
||||
fn new(max_tokens: NonZeroUsize, model: String, encoding: CoreBPE, openai_context: OpenAIContext, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, ContextOverrunError> {
|
||||
let summary_instruction_budget = count_message_tokens(&get_summary_instruction(), &encoding, &model) as usize;
|
||||
let summary_budget = summary_budget.get() + count_message_tokens(&get_summary_message(None), &encoding, &model) as usize;
|
||||
if history_target.get() + summary_budget + alias_budget.get() + summary_instruction_budget >= max_tokens.get() {
|
||||
@ -315,4 +393,29 @@ fn count_message_tokens(message: &ChatMessage, encoding: &CoreBPE, model: &str)
|
||||
return tpm + encoding.encode_ordinary(&message.content).len() as i64 + encoding.encode_ordinary(role_str(&message.role)).len() as i64 + if let Some(ref name) = message.name {
|
||||
tpn + encoding.encode_ordinary(name).len() as i64
|
||||
} else { 0i64 };
|
||||
}
|
||||
|
||||
async fn get_model(model: &str) -> Option<CoreBPE> {
|
||||
return match model {
|
||||
"gpt-4" | "gpt-4-32k" | "gpt-3.5-turbo" | "text-embedding-ada-002" => {
|
||||
let model = model_cl100k_base().await;
|
||||
assert!(model.is_ok(), "Could not download model (model_cl100k_base): {:?}", model);
|
||||
|
||||
let model = cl100k_base(model.unwrap());
|
||||
assert!(model.is_ok(), "Could not load model (cl100k_base): {:?}", model.err().unwrap());
|
||||
|
||||
return Some(model.unwrap());
|
||||
}
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_max_tokens(model: &str) -> Option<i64> {
|
||||
match model {
|
||||
"gpt-4" => Some(8192),
|
||||
"gpt-4-32k" => Some(32768),
|
||||
"gpt-3.5-turbo" => Some(4096),
|
||||
"code-davinci-002" => Some(8001),
|
||||
_ => None
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user