Implement context creation from api key

parent 304f1974d6
commit c3b4822b63
@@ -1,4 +1,4 @@
-use std::{error::Error, ops::Add};
+use std::{error::Error};
 
 use openai_rs::{chat::{ChatMessage, Role, ChatHistoryBuilder}, context::Context, edits::EditRequestBuilder};
 use tiktoken::{CoreBPE, model::{model_cl100k_base, cl100k_base}};
src/message.rs (107 changed lines)
@@ -1,7 +1,7 @@
 use std::{ptr::NonNull, marker::PhantomPinned, pin::Pin, cmp::min, num::NonZeroUsize, fmt::Display, error::Error};
 
 use openai_rs::{chat::{ChatMessage, Role, ChatHistoryBuilder}, context::Context as OpenAIContext};
-use tiktoken::CoreBPE;
+use tiktoken::{CoreBPE, model::{model_cl100k_base, cl100k_base}};
 
 const PROMPT_COMPRESS: &str = "Summarize the chat history precisely and concisely";
 
@@ -35,6 +35,71 @@ impl ContextOverrunError {
     }
 }
 
+#[derive(Debug)]
+struct MissingModelError {
+    model: String
+}
+
+impl Display for MissingModelError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&format!("Missing model information for {}", self.model))?;
+        Ok(())
+    }
+}
+
+impl Error for MissingModelError {}
+
+#[derive(Debug)]
+struct InvalidModelTokenInformation {
+    model: String
+}
+
+impl Display for InvalidModelTokenInformation {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&format!("Max tokens for model \"{}\" misreported as 0", self.model))?;
+        Ok(())
+    }
+}
+
+impl Error for InvalidModelTokenInformation {}
+
+#[derive(Debug)]
+enum ContextCreationError {
+    ContextOverrunError(ContextOverrunError),
+    MissingModelError(MissingModelError),
+    InvalidModelTokenInformation(InvalidModelTokenInformation),
+}
+
+impl Display for ContextCreationError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            ContextCreationError::ContextOverrunError(inner) => inner.fmt(f),
+            ContextCreationError::MissingModelError(inner) => inner.fmt(f),
+            ContextCreationError::InvalidModelTokenInformation(inner) => inner.fmt(f)
+        }
+    }
+}
+
+impl Error for ContextCreationError {}
+
+impl From<ContextOverrunError> for ContextCreationError {
+    fn from(value: ContextOverrunError) -> Self {
+        Self::ContextOverrunError(value)
+    }
+}
+
+impl From<MissingModelError> for ContextCreationError {
+    fn from(value: MissingModelError) -> Self {
+        Self::MissingModelError(value)
+    }
+}
+
+impl From<InvalidModelTokenInformation> for ContextCreationError {
+    fn from(value: InvalidModelTokenInformation) -> Self {
+        Self::InvalidModelTokenInformation(value)
+    }
+}
+
 pub struct UserList {
     pub users: Vec<UserAliases>,
     _pin: PhantomPinned
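The three From impls above are what let new_from_api (next hunk) use ? on calls that fail with the concrete error types: ? lifts each one into ContextCreationError via From. A minimal sketch of the pattern, with a hypothetical load_encoding helper standing in for any such call:

    // Hypothetical helper; any Result<_, MissingModelError> behaves the same way.
    fn load_encoding() -> Result<(), MissingModelError> {
        Err(MissingModelError { model: "unknown-model".to_string() })
    }

    fn build() -> Result<(), ContextCreationError> {
        load_encoding()?; // `?` applies From<MissingModelError> for ContextCreationError
        Ok(())
    }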
@@ -110,7 +175,20 @@ impl UserList {
 }
 
 impl Context {
-    fn new(max_tokens: NonZeroUsize, model: String, encoding: CoreBPE, openai_context: OpenAIContext, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, impl Error> {
+    async fn new_from_api(model: String, openai_api_key: String, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, ContextCreationError> {
+        let encoding = get_model(&model).await.ok_or(MissingModelError { model: model.clone() })?;
+        Ok(Context::new(
+            NonZeroUsize::new(get_max_tokens(&model).ok_or(MissingModelError { model: model.clone() })? as usize).ok_or(InvalidModelTokenInformation { model: model.clone() })?,
+            model,
+            encoding,
+            OpenAIContext::new(openai_api_key),
+            summary_budget,
+            history_target,
+            alias_budget
+        )?)
+    }
+
+    fn new(max_tokens: NonZeroUsize, model: String, encoding: CoreBPE, openai_context: OpenAIContext, summary_budget: NonZeroUsize, history_target: NonZeroUsize, alias_budget: NonZeroUsize) -> Result<Self, ContextOverrunError> {
         let summary_instruction_budget = count_message_tokens(&get_summary_instruction(), &encoding, &model) as usize;
         let summary_budget = summary_budget.get() + count_message_tokens(&get_summary_message(None), &encoding, &model) as usize;
         if history_target.get() + summary_budget + alias_budget.get() + summary_instruction_budget >= max_tokens.get() {
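To make the overrun guard in new concrete: assuming gpt-3.5-turbo (4096 max tokens per get_max_tokens below) with illustrative budgets summary_budget = 512, history_target = 2048, and alias_budget = 256, the check evaluates 2048 + (512 + summary-message tokens) + 256 + instruction tokens >= 4096. If the two count_message_tokens overheads come to a few dozen tokens (an assumption; they depend on the actual prompts), the left side sits near 2850, so construction succeeds and the remainder of the window is left for the completion.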
@@ -315,4 +393,29 @@ fn count_message_tokens(message: &ChatMessage, encoding: &CoreBPE, model: &str)
     return tpm + encoding.encode_ordinary(&message.content).len() as i64 + encoding.encode_ordinary(role_str(&message.role)).len() as i64 + if let Some(ref name) = message.name {
         tpn + encoding.encode_ordinary(name).len() as i64
     } else { 0i64 };
 }
+
+async fn get_model(model: &str) -> Option<CoreBPE> {
+    return match model {
+        "gpt-4" | "gpt-4-32k" | "gpt-3.5-turbo" | "text-embedding-ada-002" => {
+            let model = model_cl100k_base().await;
+            assert!(model.is_ok(), "Could not download model (model_cl100k_base): {:?}", model);
+
+            let model = cl100k_base(model.unwrap());
+            assert!(model.is_ok(), "Could not load model (cl100k_base): {:?}", model.err().unwrap());
+
+            return Some(model.unwrap());
+        }
+        _ => None
+    }
+}
+
+fn get_max_tokens(model: &str) -> Option<i64> {
+    match model {
+        "gpt-4" => Some(8192),
+        "gpt-4-32k" => Some(32768),
+        "gpt-3.5-turbo" => Some(4096),
+        "code-davinci-002" => Some(8001),
+        _ => None
+    }
+}
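For orientation, a sketch of how the new constructor might be driven from inside this module (new_from_api is private); the async runtime, the environment-variable lookup, and the budget values are illustrative assumptions, not part of this commit:

    // Illustrative only; assumes an async runtime such as tokio is already running.
    let ctx = Context::new_from_api(
        "gpt-3.5-turbo".to_string(),      // must be known to both get_model and get_max_tokens
        std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"),
        NonZeroUsize::new(512).unwrap(),  // summary_budget
        NonZeroUsize::new(2048).unwrap(), // history_target
        NonZeroUsize::new(256).unwrap(),  // alias_budget
    ).await.expect("context creation failed");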