Implement basic chat logic
This commit is contained in:
parent
ded63bc494
commit
746727886a
10
Cargo.lock
generated
10
Cargo.lock
generated
@ -11,6 +11,15 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ansi_term"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.71"
|
||||
@ -90,6 +99,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
name = "chat"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ansi_term",
|
||||
"anyhow",
|
||||
"openai_rs",
|
||||
"tiktoken",
|
||||
|
@ -6,6 +6,7 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
ansi_term = "0.12.1"
|
||||
anyhow = "1.0.71"
|
||||
openai_rs = { path = "./OpenAI-Rust" }
|
||||
tiktoken = { path = "./tiktoken" }
|
||||
|
131
src/main.rs
131
src/main.rs
@ -1,27 +1,40 @@
|
||||
use std::{cmp::min, io::{self, stdin, stdout, Write}, num::{NonZeroU32, NonZeroUsize}};

use ansi_term::Colour::{Blue, Green, Red, White};
use openai_rs::{chat::{ChatHistoryBuilder, ChatMessage, Role}, context::Context};
use tiktoken::{model::{cl100k_base, model_cl100k_base}, CoreBPE};
|
||||
|
||||
const AI_MODEL: &str = "gpt-4";
|
||||
|
||||
// Does not pass the Turing test, but makes a convincing candidate
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
println!("Getting API key");
|
||||
let ctx = get_api().unwrap();
|
||||
let mut history = vec![
|
||||
ChatMessage::new(Role::System, "Your goal is to respond exactly as a human would: you must not allow the user to believe that you are an AI model")
|
||||
];
|
||||
|
||||
println!("Generating completion...");
|
||||
let completion = ctx
|
||||
.create_chat_completion_sync(
|
||||
ChatHistoryBuilder::default()
|
||||
.messages(vec![ChatMessage::new(Role::User, "Who are you?")])
|
||||
.model("gpt-4"),
|
||||
)
|
||||
.await;
|
||||
assert!(
|
||||
completion.is_ok(),
|
||||
"Could not create completion: {}",
|
||||
completion.unwrap_err()
|
||||
);
|
||||
let encoding = get_model(AI_MODEL).await.expect("Could not get token encoding scheme for model!");
|
||||
|
||||
let result = completion.unwrap();
|
||||
assert!(result.choices.len() == 1, "No completion found");
|
||||
println!("Got completion: {:?}", result.choices[0].message);
|
||||
print!("{} {}", Red.paint("You:"), Blue.prefix().to_string());
|
||||
stdout().flush().unwrap();
|
||||
loop {
|
||||
history.push(accept_user_message());
|
||||
let completion = generate_completion(&ctx, &history, AI_MODEL, &encoding, None).await;
|
||||
|
||||
print!("{} {}\n{} {}", Red.paint("Assistant:"), Green.paint(&completion.content), Red.paint("You:"), Blue.prefix().to_string());
|
||||
stdout().flush().unwrap();
|
||||
|
||||
history.push(completion);
|
||||
}
|
||||
}
|
||||
|
||||
fn accept_user_message() -> ChatMessage {
|
||||
let mut input = String::new();
|
||||
stdin().read_line(&mut input).unwrap();
|
||||
return ChatMessage { role: Role::User, content: input };
|
||||
}
|
||||
|
||||
fn get_api() -> anyhow::Result<Context> {
|
||||
@ -31,3 +44,87 @@ fn get_api() -> anyhow::Result<Context> {
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn get_model(model: &str) -> Option<CoreBPE> {
|
||||
return match model {
|
||||
"gpt-4" | "gpt-3.5-turbo" | "text-embedding-ada-002" => {
|
||||
let model = model_cl100k_base().await;
|
||||
assert!(model.is_ok(), "Could not download model (model_cl100k_base): {:?}", model);
|
||||
|
||||
let model = cl100k_base(model.unwrap());
|
||||
assert!(model.is_ok(), "Could not load model (cl100k_base): {:?}", model.err().unwrap());
|
||||
|
||||
return Some(model.unwrap());
|
||||
}
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
|
||||
/// Fixed per-message token overhead of the ChatML format for the given model,
/// or `None` for models whose overhead is not known here.
fn get_tokens_per_message(model: &str) -> Option<usize> {
    let per_message = match model {
        "gpt-4" => 3,
        "gpt-3.5-turbo" => 4,
        _ => return None,
    };
    Some(per_message)
}
|
||||
|
||||
fn role_str(role: &Role) -> &str {
|
||||
match role {
|
||||
Role::Assistant => "Assistant",
|
||||
Role::System => "System",
|
||||
Role::User => "User",
|
||||
}
|
||||
}
|
||||
|
||||
fn count_tokens(history: &Vec<ChatMessage>, encoding: &CoreBPE, model: &str) -> usize {
|
||||
let mut count = 0;
|
||||
let tpm = get_tokens_per_message(model).expect("Unknown tokens-per-message value");
|
||||
for entry in history {
|
||||
count += tpm + encoding.encode_ordinary(&entry.content).len() + encoding.encode_ordinary(role_str(&entry.role)).len();
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
/// Total context-window size (in tokens) for known models, or `None` for an
/// unrecognized model name.
fn get_max_tokens(model: &str) -> Option<usize> {
    const LIMITS: &[(&str, usize)] = &[
        ("gpt-4", 8192),
        ("gpt-4-32k", 32768),
        ("gpt-3.5-turbo", 4096),
        ("code-davinci-002", 8001),
    ];
    LIMITS
        .iter()
        .find(|(name, _)| *name == model)
        .map(|&(_, limit)| limit)
}
|
||||
|
||||
/// Request the next assistant message for `history` from the API, after
/// checking that the history still fits inside the model's context window.
///
/// `token_limit`, when given, is intended to cap the generated length (see
/// review note below). Returns the single choice's message.
///
/// # Panics
/// Panics if the history already exceeds the model's token budget, if the
/// model is unknown to `get_max_tokens`/`get_tokens_per_message`, if the API
/// call fails, or if the response does not contain exactly one choice.
async fn generate_completion(ctx: &Context, history: &Vec<ChatMessage>, model: &str, encoding: &CoreBPE, token_limit: Option<NonZeroUsize>) -> ChatMessage {
    let message_token_count = count_tokens(history, encoding, model);
    let abs_max = get_max_tokens(model).expect("Undefined maximum token count for model!");

    // NOTE(review): get_tokens_per_message() covers fewer models than
    // get_max_tokens() (e.g. "gpt-4-32k", "code-davinci-002"), so this
    // unwrap can panic for a model that passed the expect() above — confirm
    // the two tables are meant to stay in sync.
    if message_token_count >= abs_max - get_tokens_per_message(model).unwrap() {
        panic!("Message history exceeds token limit! No new message can be generated.");
    }

    // Compute maximum number of tokens to generate
    let max_tokens = match token_limit {
        Some(lim) => min(abs_max - message_token_count, lim.get()),
        _ => abs_max - message_token_count
    };


    // NOTE(review): `max_tokens` is computed but never passed to the request
    // below, so `token_limit` currently has no effect — presumably it should
    // be set on ChatHistoryBuilder; confirm against the builder API.
    let completion = ctx
        .create_chat_completion_sync(
            ChatHistoryBuilder::default()
                .messages(history.clone())
                .model(model),
        )
        .await;
    assert!(
        completion.is_ok(),
        "Could not create completion: {}",
        completion.unwrap_err()
    );

    let mut result = completion.unwrap();
    assert!(result.choices.len() == 1, "No completion found");

    // Exactly one choice is present (asserted above), so pop() yields it.
    return result.choices.pop().unwrap().message;
}
|
Loading…
x
Reference in New Issue
Block a user