Implement basic chat logic

Gabriel Tofvesson 2023-06-10 10:59:00 +02:00
parent ded63bc494
commit 746727886a
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
3 changed files with 125 additions and 17 deletions

Cargo.lock (generated)

@@ -11,6 +11,15 @@ dependencies = [
"memchr",
]

[[package]]
name = "ansi_term"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2"
dependencies = [
"winapi",
]

[[package]]
name = "anyhow"
version = "1.0.71"
@@ -90,6 +99,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
name = "chat"
version = "0.1.0"
dependencies = [
"ansi_term",
"anyhow",
"openai_rs",
"tiktoken",

Cargo.toml

@@ -6,6 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ansi_term = "0.12.1"
anyhow = "1.0.71"
openai_rs = { path = "./OpenAI-Rust" }
tiktoken = { path = "./tiktoken" }

src/main.rs

@@ -1,27 +1,40 @@
use std::{num::{NonZeroU32, NonZeroUsize}, cmp::min, io::{stdin, stdout}};
use std::io::{self, Write};
use ansi_term::Colour::{White, Red, Green, Blue};
use openai_rs::{chat::{ChatHistoryBuilder, ChatMessage, Role}, context::Context};
use tiktoken::{CoreBPE, model::{cl100k_base, model_cl100k_base}};
const AI_MODEL: &str = "gpt-4";
// Does not pass the Turing test, but makes a convincing candidate
#[tokio::main]
async fn main() {
    println!("Getting API key");
    let ctx = get_api().unwrap();
    let mut history = vec![
        ChatMessage::new(Role::System, "Your goal is to respond exactly as a human would: you must not allow the user to believe that you are an AI model")
    ];
    let encoding = get_model(AI_MODEL).await.expect("Could not get token encoding scheme for model!");

    print!("{} {}", Red.paint("You:"), Blue.prefix().to_string());
    stdout().flush().unwrap();

    loop {
        history.push(accept_user_message());

        let completion = generate_completion(&ctx, &history, AI_MODEL, &encoding, None).await;
        print!("{} {}\n{} {}", Red.paint("Assistant:"), Green.paint(&completion.content), Red.paint("You:"), Blue.prefix().to_string());
        stdout().flush().unwrap();

        history.push(completion);
    }
}
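
// Read a single line from stdin as the next user message. Note that
// read_line keeps the trailing newline in the message content.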
fn accept_user_message() -> ChatMessage {
    let mut input = String::new();
    stdin().read_line(&mut input).unwrap();
    return ChatMessage { role: Role::User, content: input };
}
fn get_api() -> anyhow::Result<Context> {
@@ -31,3 +44,87 @@ fn get_api() -> anyhow::Result<Context> {
        .to_string(),
    ))
}
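
// Map a model name to its tiktoken encoding. The chat and embedding models
// listed here all use the cl100k_base encoding; anything else is unsupported.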
async fn get_model(model: &str) -> Option<CoreBPE> {
    return match model {
        "gpt-4" | "gpt-3.5-turbo" | "text-embedding-ada-002" => {
            let model = model_cl100k_base().await;
            assert!(model.is_ok(), "Could not download model (model_cl100k_base): {:?}", model);
            let model = cl100k_base(model.unwrap());
            assert!(model.is_ok(), "Could not load model (cl100k_base): {:?}", model.err().unwrap());
            return Some(model.unwrap());
        }
        _ => None
    }
}
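
// Fixed token overhead added per message for role/metadata framing,
// presumably following OpenAI's published token-counting guidance.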
fn get_tokens_per_message(model: &str) -> Option<usize> {
    match model {
        "gpt-4" => Some(3),
        "gpt-3.5-turbo" => Some(4),
        _ => None
    }
}
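
// Human-readable role name, as counted against the prompt budget.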
fn role_str(role: &Role) -> &str {
    match role {
        Role::Assistant => "Assistant",
        Role::System => "System",
        Role::User => "User",
    }
}
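
// Estimate the prompt tokens consumed by a chat history: the per-message
// overhead plus the encoded lengths of each message's role and content.
// This is an approximation; it does not account for reply-priming tokens.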
fn count_tokens(history: &Vec<ChatMessage>, encoding: &CoreBPE, model: &str) -> usize {
    let mut count = 0;
    let tpm = get_tokens_per_message(model).expect("Unknown tokens-per-message value");
    for entry in history {
        count += tpm + encoding.encode_ordinary(&entry.content).len() + encoding.encode_ordinary(role_str(&entry.role)).len();
    }
    return count;
}
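
// Total context window (prompt + completion combined) for supported models.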
fn get_max_tokens(model: &str) -> Option<usize> {
    match model {
        "gpt-4" => Some(8192),
        "gpt-4-32k" => Some(32768),
        "gpt-3.5-turbo" => Some(4096),
        "code-davinci-002" => Some(8001),
        _ => None
    }
}
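
// Generate the next assistant message for the given history. The remaining
// completion budget is the model's context window minus the estimated prompt
// size, optionally capped by `token_limit`. Note that the computed
// `max_tokens` value is not yet attached to the outgoing request below.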
async fn generate_completion(ctx: &Context, history: &Vec<ChatMessage>, model: &str, encoding: &CoreBPE, token_limit: Option<NonZeroUsize>) -> ChatMessage {
    let message_token_count = count_tokens(history, encoding, model);
    let abs_max = get_max_tokens(model).expect("Undefined maximum token count for model!");
    if message_token_count >= abs_max - get_tokens_per_message(model).unwrap() {
        panic!("Message history exceeds token limit! No new message can be generated.");
    }

    // Compute maximum number of tokens to generate
    let max_tokens = match token_limit {
        Some(lim) => min(abs_max - message_token_count, lim.get()),
        _ => abs_max - message_token_count
    };

    let completion = ctx
        .create_chat_completion_sync(
            ChatHistoryBuilder::default()
                .messages(history.clone())
                .model(model),
        )
        .await;
    assert!(
        completion.is_ok(),
        "Could not create completion: {}",
        completion.unwrap_err()
    );

    let mut result = completion.unwrap();
    assert!(result.choices.len() == 1, "No completion found");
    return result.choices.pop().unwrap().message;
}
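
// Illustrative call (hypothetical usage, not part of this commit): cap the
// assistant's reply at 256 tokens.
//     let reply = generate_completion(&ctx, &history, AI_MODEL, &encoding, NonZeroUsize::new(256)).await;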