From 746727886a5bd052d6e986e17426b706def465b2 Mon Sep 17 00:00:00 2001 From: Gabriel Tofvesson Date: Sat, 10 Jun 2023 10:59:00 +0200 Subject: [PATCH] Implement basic chat logic --- Cargo.lock | 10 ++++ Cargo.toml | 1 + src/main.rs | 131 +++++++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 125 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7c5aba5..6c7a71d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anyhow" version = "1.0.71" @@ -90,6 +99,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" name = "chat" version = "0.1.0" dependencies = [ + "ansi_term", "anyhow", "openai_rs", "tiktoken", diff --git a/Cargo.toml b/Cargo.toml index 1e940d4..cd82668 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +ansi_term = "0.12.1" anyhow = "1.0.71" openai_rs = { path = "./OpenAI-Rust" } tiktoken = { path = "./tiktoken" } diff --git a/src/main.rs b/src/main.rs index 6163972..36d6c76 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,27 +1,40 @@ -use openai_rs::{chat::{ChatHistoryBuilder, ChatMessage, Role}, context::Context}; +use std::{num::{NonZeroU32, NonZeroUsize}, cmp::min, io::{stdin, stdout}}; +use std::io::{self, Write}; +use ansi_term::Colour::{White, Red, Green, Blue}; +use openai_rs::{chat::{ChatHistoryBuilder, ChatMessage, Role}, context::Context}; +use tiktoken::{CoreBPE, model::{cl100k_base, model_cl100k_base}}; + +const AI_MODEL: &str = "gpt-4"; + +// Does not pass the Turing test, but makes a convincing candidate #[tokio::main] async fn main() { println!("Getting API 
key"); let ctx = get_api().unwrap(); + let mut history = vec![ + ChatMessage::new(Role::System, "Your goal is to respond exactly as a human would: you must not allow the user to believe that you are an AI model") + ]; - println!("Generating completion..."); - let completion = ctx - .create_chat_completion_sync( - ChatHistoryBuilder::default() - .messages(vec![ChatMessage::new(Role::User, "Who are you?")]) - .model("gpt-4"), - ) - .await; - assert!( - completion.is_ok(), - "Could not create completion: {}", - completion.unwrap_err() - ); + let encoding = get_model(AI_MODEL).await.expect("Could not get token encoding scheme for model!"); - let result = completion.unwrap(); - assert!(result.choices.len() == 1, "No completion found"); - println!("Got completion: {:?}", result.choices[0].message); + print!("{} {}", Red.paint("You:"), Blue.prefix().to_string()); + stdout().flush().unwrap(); + loop { + history.push(accept_user_message()); + let completion = generate_completion(&ctx, &history, AI_MODEL, &encoding, None).await; + + print!("{} {}\n{} {}", Red.paint("Assistant:"), Green.paint(&completion.content), Red.paint("You:"), Blue.prefix().to_string()); + stdout().flush().unwrap(); + + history.push(completion); + } +} + +fn accept_user_message() -> ChatMessage { + let mut input = String::new(); + stdin().read_line(&mut input).unwrap(); + return ChatMessage { role: Role::User, content: input }; } fn get_api() -> anyhow::Result<Context> { @@ -31,3 +44,87 @@ fn get_api() -> anyhow::Result<Context> { .to_string(), )) } + +async fn get_model(model: &str) -> Option<CoreBPE> { + return match model { + "gpt-4" | "gpt-3.5-turbo" | "text-embedding-ada-002" => { + let model = model_cl100k_base().await; + assert!(model.is_ok(), "Could not download model (model_cl100k_base): {:?}", model); + + let model = cl100k_base(model.unwrap()); + assert!(model.is_ok(), "Could not load model (cl100k_base): {:?}", model.err().unwrap()); + + return Some(model.unwrap()); + } + _ => None + } +} + +fn 
get_tokens_per_message(model: &str) -> Option<usize> { + match model { + "gpt-4" => Some(3), + "gpt-3.5-turbo" => Some(4), + _ => None + } +} + +fn role_str(role: &Role) -> &str { + match role { + Role::Assistant => "Assistant", + Role::System => "System", + Role::User => "User", + } +} + +fn count_tokens(history: &Vec<ChatMessage>, encoding: &CoreBPE, model: &str) -> usize { + let mut count = 0; + let tpm = get_tokens_per_message(model).expect("Unknown tokens-per-message value"); + for entry in history { + count += tpm + encoding.encode_ordinary(&entry.content).len() + encoding.encode_ordinary(role_str(&entry.role)).len(); + } + return count; +} + +fn get_max_tokens(model: &str) -> Option<usize> { + match model { + "gpt-4" => Some(8192), + "gpt-4-32k" => Some(32768), + "gpt-3.5-turbo" => Some(4096), + "code-davinci-002" => Some(8001), + _ => None + } +} + +async fn generate_completion(ctx: &Context, history: &Vec<ChatMessage>, model: &str, encoding: &CoreBPE, token_limit: Option<NonZeroUsize>) -> ChatMessage { + let message_token_count = count_tokens(history, encoding, model); + let abs_max = get_max_tokens(model).expect("Undefined maximum token count for model!"); + + if message_token_count >= abs_max - get_tokens_per_message(model).unwrap() { + panic!("Message history exceeds token limit! No new message can be generated."); + } + + // Compute maximum number of tokens to generate + let max_tokens = match token_limit { + Some(lim) => min(abs_max - message_token_count, lim.get()), + _ => abs_max - message_token_count + }; + + + let completion = ctx + .create_chat_completion_sync( + ChatHistoryBuilder::default() + .messages(history.clone()) + .model(model), + ) + .await; + assert!( + completion.is_ok(), + "Could not create completion: {}", + completion.unwrap_err() + ); + + let mut result = completion.unwrap(); + assert!(result.choices.len() == 1, "No completion found"); + + return result.choices.pop().unwrap().message; +} \ No newline at end of file