commit 76bcd861175ae20275975e99c689323cf756f9d1
Author: Gabriel Tofvesson
Date:   Tue Dec 13 03:48:06 2022 +0100

    Initial commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..82f0a25
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+/target
+sample*
+Cargo.lock
+**/*.rs.bk
+.vscode/
+Notes.md
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..0f11ab9
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,12 @@
+[package]
+name = "ai_chat"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+openai-api = "0.1.5-alpha.0"
+
+[patch.crates-io]
+openai-api = { path = "./openai-api-rust" }
\ No newline at end of file
diff --git a/openai-api-rust b/openai-api-rust
new file mode 160000
index 0000000..f0f969e
--- /dev/null
+++ b/openai-api-rust
@@ -0,0 +1 @@
+Subproject commit f0f969ef4f0e860f5cea383e3a859817b06f9e13
diff --git a/src/ai.rs b/src/ai.rs
new file mode 100644
index 0000000..382a5e9
--- /dev/null
+++ b/src/ai.rs
@@ -0,0 +1,123 @@
+use std::time::SystemTime;
+
+use openai_api::Client;
+
+use crate::chat::Chat;
+
+struct ContextManager {
+    context: Option<String>,
+    context_refresh_timer: u32,
+    refresh_interval: u32,
+}
+
+impl ContextManager {
+    pub fn new(refresh_interval: u32) -> ContextManager {
+        ContextManager { context: Option::None, context_refresh_timer: 0, refresh_interval }
+    }
+
+    fn to_context_string(context: &String) -> String {
+        format!("Context: {context}")
+    }
+
+    fn should_refresh_context(&self) -> bool {
+        self.context.is_none() || self.context_refresh_timer >= self.refresh_interval
+    }
+
+    fn get_contextfree_count(&self) -> u32 {
+        self.context_refresh_timer
+    }
+
+    fn refresh_context(&mut self, chat: &Chat, ai_client: &Client) {
+        self.context.replace(ai_client.complete_prompt_sync(openai_api::api::CompletionArgs::builder()
+            .prompt(format!("Explain the conversation concisely:\n{}Explanation:", self.get_chat_context(chat)))
+            .engine("text-davinci-003")
+            .max_tokens(512)
+            .temperature(0.23)
+            .top_p(0.9)
+            .frequency_penalty(0.3)
+            .presence_penalty(0.12)
+            .build()
+            .unwrap())
+            .unwrap()
+            .choices[0]
+            .text
+            .trim()
+            .to_string());
+
+        self.context_refresh_timer = 0;
+    }
+
+    pub fn get_context(&mut self, chat: &Chat, ai_client: &Client) -> &String {
+        if self.should_refresh_context() {
+            self.refresh_context(chat, ai_client);
+        }
+
+        self.context.as_ref().unwrap()
+    }
+
+    pub fn on_message(&mut self) {
+        self.context_refresh_timer += 1;
+    }
+
+    pub fn get_chat_context(&self, chat: &Chat) -> String {
+        let mut chat_context = String::new();
+        if let Option::Some(context) = self.context.as_ref() {
+            chat_context.push_str(&ContextManager::to_context_string(context));
+            chat_context.push_str("\n");
+        }
+
+        let history = chat.get_history();
+        let mut idx = 0;
+        for chat_entry in &history[std::cmp::max(0, history.len() as i32 - self.context_refresh_timer as i32) as usize..] {
+            chat_context.push_str(&format!("{idx}. {}: {}\n", chat_entry.sender, chat_entry.message));
+            idx += 1;
+        }
+
+        chat_context
+    }
+}
+
+pub struct Ai {
+    client: Client,
+    chat: Chat,
+    context_manager: ContextManager
+}
+
+impl Ai {
+    pub fn new(chat: Chat, token: String, context_refresh_interval: u32) -> Ai {
+        let client = Client::new(token.as_str());
+        Ai { client, chat, context_manager: ContextManager::new(context_refresh_interval) }
+    }
+
+    fn get_prompt(&self) -> String {
+        format!("{}{}. AI:", self.context_manager.get_chat_context(&self.chat), self.context_manager.get_contextfree_count())
AI:", self.context_manager.get_chat_context(&self.chat), self.context_manager.get_contextfree_count()) + } + + pub fn send_message(&mut self, message: String) -> String { + self.chat.add_message("User".to_owned(), message, SystemTime::now()); + self.context_manager.on_message(); + + let prompt = self.get_prompt(); + + let response = self.client.complete_prompt_sync(openai_api::api::CompletionArgs::builder() + .prompt(prompt) + .engine("text-davinci-003") + .max_tokens(1536) + .temperature(0.35) + .top_p(0.95) + .frequency_penalty(0.5) + .presence_penalty(0.33) + .build() + .unwrap()) + .unwrap() + .choices[0] + .text + .trim() + .to_string(); + + self.context_manager.on_message(); + self.chat.add_message("AI".to_owned(), response.to_owned(), SystemTime::now()); + + response + } +} \ No newline at end of file diff --git a/src/chat.rs b/src/chat.rs new file mode 100644 index 0000000..5a41aae --- /dev/null +++ b/src/chat.rs @@ -0,0 +1,42 @@ +use std::time::SystemTime; + +pub struct ChatMessage { + pub sender: String, + pub message: String, + pub timestamp: SystemTime, +} + +pub struct Chat { + participants: Vec, + history: Vec, +} + +impl Chat { + pub fn new() -> Chat { + Chat { + participants: vec![], + history: vec![], + } + } + + pub fn add_message(&mut self, sender: String, message: String, timestamp: SystemTime) { + let mut index = self.participants.iter().find(|&x| x == &sender); + if !self.participants.contains(&sender) { + self.participants.push(sender.to_owned()); + } + + self.history.push(ChatMessage { + sender: sender, + message, + timestamp, + }); + } + + pub fn get_participants(&self) -> &Vec { + &self.participants + } + + pub fn get_history(&self) -> &Vec { + &self.history + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..63be37e --- /dev/null +++ b/src/main.rs @@ -0,0 +1,20 @@ +use std::io::{stdout, stdin, Write}; + +use ai::Ai; +use chat::Chat; + +mod chat; +mod ai; + +fn main() { + println!("Welcome to your personal AI chat. Enter a message and the AI will respond..."); + let mut ai = Ai::new(Chat::new(), std::env::var("OPENAI_SK").unwrap(), 30); + loop { + let mut input = String::new(); + print!("You: "); + stdout().flush().unwrap(); + stdin().read_line(&mut input).expect("Failed to read from stdin???"); // Fuck it. Just crash on error + + println!("AI: {}", ai.send_message(input)); + } +}