diff --git a/Cargo.lock b/Cargo.lock index e53ce38..860e2e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -87,7 +87,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn", + "syn 1.0.109", ] [[package]] @@ -98,7 +98,7 @@ checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" dependencies = [ "darling_core", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -119,7 +119,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -129,7 +129,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" dependencies = [ "derive_builder_core", - "syn", + "syn 1.0.109", ] [[package]] @@ -162,6 +162,17 @@ dependencies = [ "libc", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "1.9.0" @@ -201,6 +212,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "531ac96c6ff5fd7c62263c5e3c67a603af4fcaee2e1a0ae5565ba3a11e69e549" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.27" @@ -208,6 +234,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -216,6 +243,17 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd" +[[package]] +name = "futures-executor" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1997dd9df74cdac935c76252744c1ed5794fac083242ea4fe77ef3ed60ba0f83" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.27" @@ -230,7 +268,7 @@ checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -245,12 +283,19 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -492,6 +537,12 @@ dependencies = [ "unicase", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "mio" version = "0.8.6" @@ -522,6 +573,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "num_cpus" version = "1.15.0" @@ -546,8 +607,10 @@ dependencies = [ "base64", "bytes", "derive_builder", + "futures", "futures-core", "reqwest", + "reqwest-eventsource", "serde", "serde_json", "tokio", @@ -577,7 +640,7 @@ checksum = "b501e44f11665960c7e7fcf062c7d96a14ade4aa98116c004b2e37b5be7d736c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -713,6 +776,22 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest-eventsource" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f03f570355882dd8d15acc3a313841e6e90eddbc76a93c748fd82cc13ba9f51" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest", + "thiserror", +] + [[package]] name = "rustix" version = "0.36.9" @@ -788,7 +867,7 @@ checksum = "d7e29c4601e36bcec74a223228dce795f4cd3616341a4af93520ca1a837c087d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -865,6 +944,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59d3276aee1fa0c33612917969b5172b5be2db051232a6e4826f1a1a9191b045" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "tempfile" version = "3.4.0" @@ -878,6 +968,26 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "thiserror" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.2", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -921,7 +1031,7 @@ checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -1070,7 +1180,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 1.0.109", "wasm-bindgen-shared", ] @@ -1104,7 +1214,7 @@ checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/Cargo.toml b/Cargo.toml index 73c8846..222c0ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,10 @@ anyhow = "1.0.69" base64 = "0.21.0" bytes = "1.4.0" derive_builder = "0.12.0" +futures = "0.3.27" futures-core = "0.3.27" reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] } +reqwest-eventsource = "0.4.0" serde = { version = "1.0.156", features = ["derive"] } serde_json = "1.0.94" tokio = { version = "1.26.0", features = [ "full" ] } diff --git a/src/chat.rs b/src/chat.rs index c8fa7fe..b7f16b6 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -1,7 +1,9 @@ -use std::collections::HashMap; +use std::{collections::HashMap, str::FromStr, pin::Pin, task::Poll}; use derive_builder::Builder; -use reqwest::Client; +use futures::{Stream, StreamExt}; +use reqwest::{Client, RequestBuilder}; +use reqwest_eventsource::{RequestBuilderExt, Event, EventSource}; use serde::{Serialize, Deserialize}; use crate::{completion::{Sequence, Usage}, context::{API_URL, Context}}; @@ -56,6 +58,7 @@ impl ChatMessage { } #[derive(Debug, Serialize, Builder)] +#[builder(pattern = "owned")] pub struct ChatHistory { #[builder(setter(into))] pub messages: Vec, @@ -72,7 +75,7 @@ pub struct ChatHistory { pub n: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] - pub stream: Option, + stream: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub stop: Option, @@ -93,15 +96,66 @@ pub struct ChatHistory { pub user: Option, } +#[derive(Debug)] +pub enum FinishReason { + Stop, + Length, + ContentFilter, +} + +impl<'de> Deserialize<'de> for FinishReason { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> { + // Deserialize the String + match String::deserialize(deserializer)? { + s if s == "stop" => Ok(Self::Stop), + s if s == "length" => Ok(Self::Length), + s if s == "content_filter" => Ok(Self::ContentFilter), + _ => Err(serde::de::Error::custom("Invalid stop reason")), + } + + } +} + #[derive(Debug, Deserialize)] pub struct ChatCompletion { pub index: i32, pub message: ChatMessage, - pub finish_reason: String, // TODO: Create enum for this + pub finish_reason: Option } #[derive(Debug, Deserialize)] -pub struct ChatCompletionResponse { +pub struct DeltaMessage { + pub role: Option, + pub content: Option, +} + +#[derive(Debug, Deserialize)] +pub struct DeltaChatCompletion { + pub index: i32, + pub delta: DeltaMessage, + pub finish_reason: Option, +} +#[derive(Debug, Deserialize)] +pub struct ChatCompletionDeltaResponse { + pub id: String, + /* pub object: "chat.completion", */ + pub created: u64, + pub model: String, + pub choices: Vec, +} + +impl FromStr for ChatCompletionDeltaResponse { + type Err = serde_json::Error; + + fn from_str(s: &str) -> Result { + serde_json::from_str(s) + } +} + +#[derive(Debug, Deserialize)] +pub struct ChatCompletionSyncResponse { pub id: String, /* pub object: "chat.completion", */ pub created: u64, @@ -110,16 +164,58 @@ pub struct ChatCompletionResponse { pub usage: Usage } +struct CompletionStream { + stream: EventSource +} + +impl Stream for CompletionStream { + type Item = anyhow::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + loop { + return match self.stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(event))) => { + match event { + Event::Message(message) => { + // Stream has ended + if message.data == "[DONE]" { + return Poll::Ready(None) + } + + match message.data.parse::() { + Ok(value) => Poll::Ready(Some(Ok(value))), + Err(e) => Poll::Ready(Some(Err(e.into()))) + } + }, + _ => continue + } + }, + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(anyhow::Error::new(e)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending + } + } + } +} + impl Context { - pub async fn create_chat_completion(&self, chat_completion_request: ChatHistory) -> anyhow::Result { + fn build_request(&self, stream: bool, chat_completion_request: ChatHistoryBuilder) -> anyhow::Result { + Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/chat/completions"))) + .json(&chat_completion_request.stream(stream).build()?)) + } + + pub async fn create_chat_completion_sync(&self, chat_completion_request: ChatHistoryBuilder) -> anyhow::Result { Ok( - self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/chat/completions"))) - .json(&chat_completion_request) + self.build_request(false, chat_completion_request)? .send() .await? .error_for_status()? - .json::() + .json::() .await? ) } + + pub async fn create_chat_completion_streamed(&self, chat_completion_request: ChatHistoryBuilder) -> anyhow::Result>> { + Ok(CompletionStream { stream: self.build_request(true, chat_completion_request)?.eventsource()? }) + } } \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index c981a5e..8dd97fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ pub mod util; #[cfg(test)] mod tests { + use futures::StreamExt; use tokio::fs::File; use crate::chat::{ChatHistoryBuilder, ChatMessage, Role}; @@ -65,19 +66,53 @@ mod tests { #[tokio::test] async fn test_chat_completion() { + const PROMPT: &str = "Respond to this message with 'this is a test'"; + let ctx = get_api(); assert!(ctx.is_ok(), "Could not load context"); - let completion = ctx.unwrap().create_chat_completion( + let ctx = ctx.unwrap(); + + println!("Generating completion for prompt: {PROMPT}"); + let completion = ctx.create_chat_completion_sync( ChatHistoryBuilder::default() - .messages(vec![ChatMessage::new(Role::User, "Respond to this message with 'this is a test'")]) + .messages(vec![ChatMessage::new(Role::User, PROMPT)]) .model("gpt-3.5-turbo") - .build() - .unwrap() ).await; - assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err()); - assert!(completion.unwrap().choices.len() == 1, "No completion found"); + assert!(completion.is_ok(), "Could not create completion: {}", completion.unwrap_err()); + + let result = completion.unwrap(); + assert!(result.choices.len() == 1, "No completion found"); + println!("Got completion: {:?}", result.choices[0].message); + + println!("Generating streamed completion for prompt: {PROMPT}"); + let completion = ctx.create_chat_completion_streamed( + ChatHistoryBuilder::default() + .messages(vec![ChatMessage::new(Role::User, PROMPT)]) + .model("gpt-3.5-turbo") + ).await; + + assert!(completion.is_ok(), "Could not create completion: {}", completion.err().unwrap()); + let mut stream = completion.unwrap(); + while let Some(result) = stream.next().await { + assert!(result.is_ok(), "Could not get completion: {}", result.unwrap_err()); + let result = result.unwrap(); + assert!(result.choices.len() == 1, "No completion found"); + + let delta = &result.choices[0]; + if let Some(ref reason) = delta.finish_reason { + println!("Got completion end. Reason: {:?}", reason); + } else { + if let Some(ref role) = delta.delta.role { + println!("Got role: {:?}", role); + } + + if let Some(ref message) = delta.delta.content { + println!("Got completion: {:?}", message); + } + } + } } #[tokio::test]