Initial commit

This commit is contained in:
Gabriel Tofvesson 2022-12-13 03:48:06 +01:00
commit 76bcd86117
6 changed files with 204 additions and 0 deletions

6
.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
/target
sample*
Cargo.lock
**/*.rs.bk
.vscode/
Notes.md

12
Cargo.toml Normal file
View File

@ -0,0 +1,12 @@
[package]
name = "ai_chat"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
openai-api = "0.1.5-alpha.0"
[patch.crates-io]
openai-api = { path = "./openai-api-rust" }

1
openai-api-rust Submodule

@ -0,0 +1 @@
Subproject commit f0f969ef4f0e860f5cea383e3a859817b06f9e13

123
src/ai.rs Normal file
View File

@ -0,0 +1,123 @@
use std::time::SystemTime;
use openai_api::Client;
use crate::chat::Chat;
struct ContextManager {
context: Option<String>,
context_refresh_timer: u32,
refresh_interval: u32,
}
impl ContextManager {
pub fn new(refresh_interval: u32) -> ContextManager {
ContextManager { context: Option::None, context_refresh_timer: 0, refresh_interval }
}
fn to_context_string(context: &String) -> String {
format!("Context: {context}")
}
fn should_refresh_context(&self) -> bool {
self.context.is_none() || self.context_refresh_timer >= self.refresh_interval
}
fn get_contextfree_count(&self) -> u32 {
self.context_refresh_timer
}
fn refresh_context(&mut self, chat: &Chat, ai_client: &Client) {
self.context.replace(ai_client.complete_prompt_sync(openai_api::api::CompletionArgs::builder()
.prompt(format!("Explain the conversation concisely:\n{}Explanation:", self.get_chat_context(chat)))
.engine("text-davinci-003")
.max_tokens(512)
.temperature(0.23)
.top_p(0.9)
.frequency_penalty(0.3)
.presence_penalty(0.12)
.build()
.unwrap())
.unwrap()
.choices[0]
.text
.trim()
.to_string());
self.context_refresh_timer = 0;
}
pub fn get_context(&mut self, chat: &Chat, ai_client: &Client) -> &String {
if self.should_refresh_context() {
self.refresh_context(chat, ai_client);
}
return self.context.as_ref().unwrap()
}
pub fn on_message(&mut self) {
self.context_refresh_timer += 1
}
pub fn get_chat_context(&self, chat: &Chat) -> String {
let mut chat_context = String::new();
if let Option::Some(context) = self.context.as_ref() {
chat_context.push_str(&ContextManager::to_context_string(context));
chat_context.push_str("\n");
}
let history = chat.get_history();
let mut idx = 0;
for chat_entry in &history[std::cmp::max(0, history.len() as i32 - self.context_refresh_timer as i32) as usize..] {
chat_context.push_str(&format!("{idx}. {}: {}\n", chat_entry.sender, chat_entry.message));
idx += 1;
}
chat_context
}
}
pub struct Ai {
client: Client,
chat: Chat,
context_manager: ContextManager
}
impl Ai {
pub fn new(chat: Chat, token: String, context_refresh_interval: u32) -> Ai {
let client = Client::new(token.as_str());
Ai { client, chat, context_manager: ContextManager::new(context_refresh_interval) }
}
fn get_prompt(&self) -> String {
format!("{}{}. AI:", self.context_manager.get_chat_context(&self.chat), self.context_manager.get_contextfree_count())
}
pub fn send_message(&mut self, message: String) -> String {
self.chat.add_message("User".to_owned(), message, SystemTime::now());
self.context_manager.on_message();
let prompt = self.get_prompt();
let response = self.client.complete_prompt_sync(openai_api::api::CompletionArgs::builder()
.prompt(prompt)
.engine("text-davinci-003")
.max_tokens(1536)
.temperature(0.35)
.top_p(0.95)
.frequency_penalty(0.5)
.presence_penalty(0.33)
.build()
.unwrap())
.unwrap()
.choices[0]
.text
.trim()
.to_string();
self.context_manager.on_message();
self.chat.add_message("AI".to_owned(), response.to_owned(), SystemTime::now());
response
}
}

42
src/chat.rs Normal file
View File

@ -0,0 +1,42 @@
use std::time::SystemTime;
pub struct ChatMessage {
pub sender: String,
pub message: String,
pub timestamp: SystemTime,
}
pub struct Chat {
participants: Vec<String>,
history: Vec<ChatMessage>,
}
impl Chat {
pub fn new() -> Chat {
Chat {
participants: vec![],
history: vec![],
}
}
pub fn add_message(&mut self, sender: String, message: String, timestamp: SystemTime) {
let mut index = self.participants.iter().find(|&x| x == &sender);
if !self.participants.contains(&sender) {
self.participants.push(sender.to_owned());
}
self.history.push(ChatMessage {
sender: sender,
message,
timestamp,
});
}
pub fn get_participants(&self) -> &Vec<String> {
&self.participants
}
pub fn get_history(&self) -> &Vec<ChatMessage> {
&self.history
}
}

20
src/main.rs Normal file
View File

@ -0,0 +1,20 @@
use std::io::{stdout, stdin, Write};
use ai::Ai;
use chat::Chat;
mod chat;
mod ai;
fn main() {
println!("Welcome to your personal AI chat. Enter a message and the AI will respond...");
let mut ai = Ai::new(Chat::new(), std::env::var("OPENAI_SK").unwrap(), 30);
loop {
let mut input = String::new();
print!("You: ");
stdout().flush().unwrap();
stdin().read_line(&mut input).expect("Failed to read from stdin???"); // Fuck it. Just crash on error
println!("AI: {}", ai.send_message(input));
}
}