Update old implementation

parent f90bb5f317
commit 304f1974d6

src/chat_context.rs | 211 (new file)
src/main.rs         |  62

src/chat_context.rs (new file)
@@ -0,0 +1,211 @@
use std::error::Error;

use openai_rs::{chat::{ChatMessage, Role, ChatHistoryBuilder}, context::Context, edits::EditRequestBuilder};
use tiktoken::{CoreBPE, model::{model_cl100k_base, cl100k_base}};

#[derive(Debug, Clone)]
pub struct ChatContextError<'l> {
    reason: &'l str
}

impl std::fmt::Display for ChatContextError<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.reason)
    }
}

impl Error for ChatContextError<'_> {}

#[derive(Clone)]
pub enum MessageType {
    AssistantMessage,
    UserMessage {
        sender: UserAlias,
    }
}

#[derive(Clone)]
pub struct MetaChatMessage {
    pub chat_message: ChatMessage,
    pub message_type: MessageType,
}

#[derive(Clone)]
pub struct UserAlias {
    // Public so callers (e.g. main.rs) can construct aliases directly
    pub id: u16,
    pub names: Vec<String>,
}

pub struct ChatContext {
    model: String,
    encoding: CoreBPE,
    max_tokens: i64,
    api_context: Context,
    history: Vec<MetaChatMessage>,
    context: Option<String>,
    user_aliases: Vec<UserAlias>,
}

impl ChatContext {
    pub async fn new(model: String, api_key: String) -> anyhow::Result<Self> {
        Ok(Self {
            encoding: get_model(&model).await.ok_or(ChatContextError { reason: "Couldn't get model encoding" })?,
            max_tokens: get_max_tokens(&model).ok_or(ChatContextError { reason: "Couldn't get max tokens for model" })?,
            api_context: Context::new(api_key),
            history: Vec::new(),
            context: None,
            model,
            user_aliases: Vec::new()
        })
    }

    pub async fn send_message(&mut self, message: MetaChatMessage) -> MetaChatMessage {
        self.history.push(message);
        let tpm = get_tokens_per_message(&self.model).unwrap();
        let message_token_count = count_tokens(&self.history, &self.encoding, &self.model) + tpm;
        if message_token_count >= self.max_tokens - tpm {
            panic!("Message history exceeds token limit! No new message can be generated.");
        }

        // Compute the maximum number of tokens the completion may generate
        let max_tokens = self.max_tokens - message_token_count - tpm - 1;

        let completion = self.api_context
            .create_chat_completion_sync(
                ChatHistoryBuilder::default()
                    .temperature(0.3) // Model suffers from excessive hallucination. TODO: fine-tune temperature
                    .messages(self.history.iter().map(|message| message.chat_message.clone()).collect::<Vec<ChatMessage>>())
                    .max_tokens(max_tokens as u64)
                    .model(&self.model),
            )
            .await;
        assert!(
            completion.is_ok(),
            "Could not create completion: {}",
            completion.unwrap_err()
        );

        let mut result = completion.unwrap();
        assert!(result.choices.len() == 1, "Expected exactly one completion choice");
        return MetaChatMessage {
            chat_message: result.choices.pop().unwrap().message,
            message_type: MessageType::AssistantMessage
        };
    }

    async fn update_aliases(&self, instruction: &str, aliases: &mut Vec<UserAlias>, message_context: &[MetaChatMessage], context_count: usize) -> anyhow::Result<()> {
        if message_context.len() < context_count {
            return Ok(());
        }
        let latest = &message_context[message_context.len() - 1];
        if let MessageType::UserMessage { ref sender } = latest.message_type {
            let mut alias_prompt = String::new();

            for alias in aliases.iter() {
                alias_prompt.push_str(&format!("u{}:", alias.id));

                for name in &alias.names {
                    alias_prompt.push_str(&format!(" {name},"));
                }

                if !alias.names.is_empty() {
                    alias_prompt.pop(); // Drop the trailing comma
                }
                alias_prompt.push('\n'); // One alias entry per line
            }

            let mut edit_instruction = String::new();
            edit_instruction.push_str("Update the list of user aliases based on the chat message:");
            edit_instruction.push_str(&format!("\nu{}: \"{}\"", sender.id, latest.chat_message.content));

            // TODO: await the request and fold the edited alias list back into `aliases`
            let _edit = self.api_context.create_edit(
                EditRequestBuilder::default()
                    .input(alias_prompt)
                    .instruction(edit_instruction)
                    .build()?
            );
        }

        return Ok(());
    }

    pub fn get_history(&mut self) -> &mut Vec<MetaChatMessage> {
        &mut self.history
    }
}

async fn get_model(model: &str) -> Option<CoreBPE> {
    match model {
        "gpt-4" | "gpt-4-32k" | "gpt-3.5-turbo" | "text-embedding-ada-002" => {
            let model = model_cl100k_base().await;
            assert!(model.is_ok(), "Could not download model (model_cl100k_base): {:?}", model);

            let model = cl100k_base(model.unwrap());
            assert!(model.is_ok(), "Could not load model (cl100k_base): {:?}", model.err().unwrap());

            Some(model.unwrap())
        }
        _ => None
    }
}

fn get_tokens_per_message(model: &str) -> Option<i64> {
    match model {
        "gpt-4" | "gpt-4-32k" => Some(3),
        "gpt-3.5-turbo" => Some(4),
        _ => None
    }
}

fn get_tokens_per_name(model: &str) -> Option<i64> {
    match model {
        "gpt-4" | "gpt-4-32k" => Some(1),
        "gpt-3.5-turbo" => Some(-1), // With gpt-3.5-turbo a name replaces the role, saving one token
        _ => None
    }
}

fn role_str(role: &Role) -> &str {
    // Lower-case strings, matching the role values the API actually serializes
    match role {
        Role::Assistant => "assistant",
        Role::System => "system",
        Role::User => "user",
    }
}

fn count_message_tokens(message: &ChatMessage, encoding: &CoreBPE, model: &str) -> i64 {
    let tpm = get_tokens_per_message(model).expect("Unknown tokens-per-message value");
    let tpn = get_tokens_per_name(model).expect("Unknown tokens-per-name value");

    tpm + encoding.encode_ordinary(&message.content).len() as i64
        + encoding.encode_ordinary(role_str(&message.role)).len() as i64
        + if let Some(ref name) = message.name {
            tpn + encoding.encode_ordinary(name).len() as i64
        } else { 0i64 }
}

fn count_tokens(history: &Vec<MetaChatMessage>, encoding: &CoreBPE, model: &str) -> i64 {
    let mut count = 0i64;
    let tpm = get_tokens_per_message(model).expect("Unknown tokens-per-message value");
    let tpn = get_tokens_per_name(model).expect("Unknown tokens-per-name value");
    for entry in history {
        count += tpm
            + encoding.encode_ordinary(&entry.chat_message.content).len() as i64
            + encoding.encode_ordinary(role_str(&entry.chat_message.role)).len() as i64;

        if let Some(ref name) = entry.chat_message.name {
            count += tpn + encoding.encode_ordinary(name).len() as i64;
        }
    }
    count
}

fn get_max_tokens(model: &str) -> Option<i64> {
    match model {
        "gpt-4" => Some(8192),
        "gpt-4-32k" => Some(32768),
        "gpt-3.5-turbo" => Some(4096),
        "code-davinci-002" => Some(8001),
        _ => None
    }
}
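
A note on the budget arithmetic in send_message: count_tokens charges each history entry its per-message overhead, send_message adds one more tokens-per-message for the primed reply, and the final subtraction leaves the largest completion that still fits the context window. A minimal standalone sketch of the same arithmetic with the gpt-4 constants from this file (the history count is an invented stand-in, not real tokenizer output):

// Sketch only: mirrors send_message's budget arithmetic with fixed numbers.
fn main() {
    let max_tokens = 8192i64;    // get_max_tokens("gpt-4")
    let tpm = 3i64;              // get_tokens_per_message("gpt-4")
    let history_tokens = 250i64; // stand-in for count_tokens(&history, &encoding, "gpt-4")

    // send_message adds tpm once more on top of the counted history
    let message_token_count = history_tokens + tpm;
    assert!(message_token_count < max_tokens - tpm, "history too long to answer");

    // Largest completion that still fits in the context window
    let budget = max_tokens - message_token_count - tpm - 1;
    println!("completion budget: {budget} tokens"); // prints 7935
}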

src/main.rs
@@ -2,39 +2,79 @@ use std::{num::{NonZeroU32, NonZeroUsize}, cmp::min, io::{stdin, stdout}};
 use std::io::{self, Write};
 use ansi_term::Colour::{White, Red, Green, Blue};

+use chat_context::{MetaChatMessage, MessageType};
 use openai_rs::{chat::{ChatHistoryBuilder, ChatMessage, Role}, context::Context};
 use tiktoken::{CoreBPE, model::{cl100k_base, model_cl100k_base}};

+use crate::chat_context::{ChatContext, UserAlias};

+mod chat_context;
 mod message;

 const AI_MODEL: &str = "gpt-4";

 // Does not pass the Turing test, but makes a convincing candidate
 // Easily tricked
 #[tokio::main]
 async fn main() {
+    /*
     println!("Getting API key");
     let ctx = get_api().unwrap();
     let mut history = vec![
-        ChatMessage::new(Role::System, "Your goal is to respond exactly as a human would: you must not allow the user to believe that you are an AI model")
+        ChatMessage::new(Role::System, "Your goal is to respond exactly as a human would: you must not allow the user to believe that you are an AI model. For all matters not related to your humanity, you will answer with precise, concise and accurate information.")
     ];

     let encoding = get_model(AI_MODEL).await.expect("Could not get token encoding scheme for model!");
+    */
+    println!("Initializing context...");
+    let mut chat_context = ChatContext::new(AI_MODEL.to_string(), get_api_key().expect("Couldn't get API key")).await.unwrap();

+    chat_context.get_history().push(MetaChatMessage { chat_message: ChatMessage::new(Role::System, "This is a group-chat with multiple users. Your responses are concise and truthful", Some("context".to_string())), message_type: MessageType::AssistantMessage });
+    chat_context.get_history().push(MetaChatMessage { chat_message: ChatMessage::new(Role::System, "Always use the first listed name when referring to users.\nu0: \"James\", \"Jimmy\", \"Hazel\"\nu1: \"Donna\", \"Delphine\"\nu2: [[unknown]]", Some("aliases".to_string())), message_type: MessageType::AssistantMessage });
+    chat_context.get_history().push(MetaChatMessage { chat_message: ChatMessage::new(Role::System, "You are Jarvis. You only respond when the most recent message is for Jarvis, otherwise you send an empty message", None), message_type: MessageType::AssistantMessage });

     print!("{} {}", Red.paint("You:"), Blue.prefix().to_string());
     stdout().flush().unwrap();
     loop {
-        history.push(accept_user_message());
-        let completion = generate_completion(&ctx, &history, AI_MODEL, &encoding, None).await;

-        print!("{} {}\n{} {}", Red.paint("Assistant:"), Green.paint(&completion.content), Red.paint("You:"), Blue.prefix().to_string());
+        print!("{} {}", Red.paint("You:"), Blue.prefix().to_string());
         stdout().flush().unwrap();

-        history.push(completion);
+        let user_message = accept_user_message();
+        if user_message.is_none() {
+            continue;
+        }

+        let completion = chat_context.send_message(user_message.unwrap()).await;

+        if completion.chat_message.content.len() > 0 {
+            println!("{} {}", Red.paint("Assistant:"), Green.paint(&completion.chat_message.content));

+            chat_context.get_history().push(completion);
+        }
     }
 }

-fn accept_user_message() -> ChatMessage {
+fn accept_user_message() -> Option<MetaChatMessage> {
     let mut input = String::new();
     stdin().read_line(&mut input).unwrap();
-    return ChatMessage { role: Role::User, content: input };
+    print!("{}", White.prefix());
+    stdout().flush().unwrap();

+    if input.len() < 3 {
+        println!("{} {}", Red.paint("Error:"), "Invalid user ID");
+        return None;
+    }

+    let (name, input) = match &input[0..2] {
+        "u0" | "u1" => (input[0..2].to_string(), input[2..].to_string()),
+        _ => ("u2".to_string(), input)
+    };

+    // Placeholder alias for the sender; no names are known yet
+    return Some(MetaChatMessage { chat_message: ChatMessage::new(Role::User, input, Some(name)), message_type: MessageType::UserMessage { sender: UserAlias { id: 4, names: Vec::new() } }});
 }

+fn get_api_key() -> anyhow::Result<String> {
+    Ok(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?
+        .trim()
+        .to_string())
+}

 fn get_api() -> anyhow::Result<Context> {

@@ -113,6 +153,8 @@ async fn generate_completion(ctx: &Context, history: &Vec<ChatMessage>, model: &
     let completion = ctx
         .create_chat_completion_sync(
             ChatHistoryBuilder::default()
                 .temperature(0.55) // Model suffers from excessive hallucination. TODO: fine-tune temperature
+                .frequency_penalty(0.1)
                 .messages(history.clone())
                 .model(model),
         )
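
The sender routing added in accept_user_message is easy to miss in the diff: a leading "u0" or "u1" is split off and kept as the message's name, while anything else is bucketed under the unknown alias "u2". A standalone sketch of just that parsing step (split_sender is a hypothetical helper name, not part of the commit):

// Sketch only: replicates the prefix routing from accept_user_message.
fn split_sender(input: &str) -> (String, String) {
    if input.len() >= 2 && input.is_char_boundary(2) {
        match &input[0..2] {
            // Known users keep their prefix as the sender name
            "u0" | "u1" => return (input[0..2].to_string(), input[2..].to_string()),
            _ => {}
        }
    }
    // Unrecognized senders fall through to the unknown alias
    ("u2".to_string(), input.to_string())
}

fn main() {
    assert_eq!(split_sender("u1 hello"), ("u1".to_string(), " hello".to_string()));
    assert_eq!(split_sender("hey Jarvis"), ("u2".to_string(), "hey Jarvis".to_string()));
}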