Initial commit
This commit is contained in:
commit
76bcd86117
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
/target
|
||||||
|
sample*
|
||||||
|
Cargo.lock
|
||||||
|
**/*.rs.bk
|
||||||
|
.vscode/
|
||||||
|
Notes.md
|
12
Cargo.toml
Normal file
12
Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
[package]
|
||||||
|
name = "ai_chat"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
openai-api = "0.1.5-alpha.0"
|
||||||
|
|
||||||
|
[patch.crates-io]
|
||||||
|
openai-api = { path = "./openai-api-rust" }
|
1
openai-api-rust
Submodule
1
openai-api-rust
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit f0f969ef4f0e860f5cea383e3a859817b06f9e13
|
123
src/ai.rs
Normal file
123
src/ai.rs
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
use openai_api::Client;
|
||||||
|
|
||||||
|
use crate::chat::Chat;
|
||||||
|
|
||||||
|
struct ContextManager {
|
||||||
|
context: Option<String>,
|
||||||
|
context_refresh_timer: u32,
|
||||||
|
refresh_interval: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ContextManager {
|
||||||
|
pub fn new(refresh_interval: u32) -> ContextManager {
|
||||||
|
ContextManager { context: Option::None, context_refresh_timer: 0, refresh_interval }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_context_string(context: &String) -> String {
|
||||||
|
format!("Context: {context}")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_refresh_context(&self) -> bool {
|
||||||
|
self.context.is_none() || self.context_refresh_timer >= self.refresh_interval
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_contextfree_count(&self) -> u32 {
|
||||||
|
self.context_refresh_timer
|
||||||
|
}
|
||||||
|
|
||||||
|
fn refresh_context(&mut self, chat: &Chat, ai_client: &Client) {
|
||||||
|
self.context.replace(ai_client.complete_prompt_sync(openai_api::api::CompletionArgs::builder()
|
||||||
|
.prompt(format!("Explain the conversation concisely:\n{}Explanation:", self.get_chat_context(chat)))
|
||||||
|
.engine("text-davinci-003")
|
||||||
|
.max_tokens(512)
|
||||||
|
.temperature(0.23)
|
||||||
|
.top_p(0.9)
|
||||||
|
.frequency_penalty(0.3)
|
||||||
|
.presence_penalty(0.12)
|
||||||
|
.build()
|
||||||
|
.unwrap())
|
||||||
|
.unwrap()
|
||||||
|
.choices[0]
|
||||||
|
.text
|
||||||
|
.trim()
|
||||||
|
.to_string());
|
||||||
|
|
||||||
|
self.context_refresh_timer = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_context(&mut self, chat: &Chat, ai_client: &Client) -> &String {
|
||||||
|
if self.should_refresh_context() {
|
||||||
|
self.refresh_context(chat, ai_client);
|
||||||
|
}
|
||||||
|
|
||||||
|
return self.context.as_ref().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_message(&mut self) {
|
||||||
|
self.context_refresh_timer += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_chat_context(&self, chat: &Chat) -> String {
|
||||||
|
let mut chat_context = String::new();
|
||||||
|
if let Option::Some(context) = self.context.as_ref() {
|
||||||
|
chat_context.push_str(&ContextManager::to_context_string(context));
|
||||||
|
chat_context.push_str("\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
let history = chat.get_history();
|
||||||
|
let mut idx = 0;
|
||||||
|
for chat_entry in &history[std::cmp::max(0, history.len() as i32 - self.context_refresh_timer as i32) as usize..] {
|
||||||
|
chat_context.push_str(&format!("{idx}. {}: {}\n", chat_entry.sender, chat_entry.message));
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
chat_context
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Ai {
|
||||||
|
client: Client,
|
||||||
|
chat: Chat,
|
||||||
|
context_manager: ContextManager
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Ai {
|
||||||
|
pub fn new(chat: Chat, token: String, context_refresh_interval: u32) -> Ai {
|
||||||
|
let client = Client::new(token.as_str());
|
||||||
|
Ai { client, chat, context_manager: ContextManager::new(context_refresh_interval) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_prompt(&self) -> String {
|
||||||
|
format!("{}{}. AI:", self.context_manager.get_chat_context(&self.chat), self.context_manager.get_contextfree_count())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_message(&mut self, message: String) -> String {
|
||||||
|
self.chat.add_message("User".to_owned(), message, SystemTime::now());
|
||||||
|
self.context_manager.on_message();
|
||||||
|
|
||||||
|
let prompt = self.get_prompt();
|
||||||
|
|
||||||
|
let response = self.client.complete_prompt_sync(openai_api::api::CompletionArgs::builder()
|
||||||
|
.prompt(prompt)
|
||||||
|
.engine("text-davinci-003")
|
||||||
|
.max_tokens(1536)
|
||||||
|
.temperature(0.35)
|
||||||
|
.top_p(0.95)
|
||||||
|
.frequency_penalty(0.5)
|
||||||
|
.presence_penalty(0.33)
|
||||||
|
.build()
|
||||||
|
.unwrap())
|
||||||
|
.unwrap()
|
||||||
|
.choices[0]
|
||||||
|
.text
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
self.context_manager.on_message();
|
||||||
|
self.chat.add_message("AI".to_owned(), response.to_owned(), SystemTime::now());
|
||||||
|
|
||||||
|
response
|
||||||
|
}
|
||||||
|
}
|
42
src/chat.rs
Normal file
42
src/chat.rs
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
pub struct ChatMessage {
|
||||||
|
pub sender: String,
|
||||||
|
pub message: String,
|
||||||
|
pub timestamp: SystemTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Chat {
|
||||||
|
participants: Vec<String>,
|
||||||
|
history: Vec<ChatMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Chat {
|
||||||
|
pub fn new() -> Chat {
|
||||||
|
Chat {
|
||||||
|
participants: vec![],
|
||||||
|
history: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_message(&mut self, sender: String, message: String, timestamp: SystemTime) {
|
||||||
|
let mut index = self.participants.iter().find(|&x| x == &sender);
|
||||||
|
if !self.participants.contains(&sender) {
|
||||||
|
self.participants.push(sender.to_owned());
|
||||||
|
}
|
||||||
|
|
||||||
|
self.history.push(ChatMessage {
|
||||||
|
sender: sender,
|
||||||
|
message,
|
||||||
|
timestamp,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_participants(&self) -> &Vec<String> {
|
||||||
|
&self.participants
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_history(&self) -> &Vec<ChatMessage> {
|
||||||
|
&self.history
|
||||||
|
}
|
||||||
|
}
|
20
src/main.rs
Normal file
20
src/main.rs
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
use std::io::{stdout, stdin, Write};
|
||||||
|
|
||||||
|
use ai::Ai;
|
||||||
|
use chat::Chat;
|
||||||
|
|
||||||
|
mod chat;
|
||||||
|
mod ai;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
println!("Welcome to your personal AI chat. Enter a message and the AI will respond...");
|
||||||
|
let mut ai = Ai::new(Chat::new(), std::env::var("OPENAI_SK").unwrap(), 30);
|
||||||
|
loop {
|
||||||
|
let mut input = String::new();
|
||||||
|
print!("You: ");
|
||||||
|
stdout().flush().unwrap();
|
||||||
|
stdin().read_line(&mut input).expect("Failed to read from stdin???"); // Fuck it. Just crash on error
|
||||||
|
|
||||||
|
println!("AI: {}", ai.send_message(input));
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user