Implement completion streaming
This commit is contained in:
parent
3339895c22
commit
ff9f8f9339
130
Cargo.lock
generated
130
Cargo.lock
generated
@ -87,7 +87,7 @@ dependencies = [
|
|||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"strsim",
|
"strsim",
|
||||||
"syn",
|
"syn 1.0.109",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -98,7 +98,7 @@ checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"darling_core",
|
"darling_core",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 1.0.109",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -119,7 +119,7 @@ dependencies = [
|
|||||||
"darling",
|
"darling",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 1.0.109",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -129,7 +129,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
|
checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"derive_builder_core",
|
"derive_builder_core",
|
||||||
"syn",
|
"syn 1.0.109",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -162,6 +162,17 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "eventsource-stream"
|
||||||
|
version = "0.2.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
|
||||||
|
dependencies = [
|
||||||
|
"futures-core",
|
||||||
|
"nom",
|
||||||
|
"pin-project-lite",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastrand"
|
name = "fastrand"
|
||||||
version = "1.9.0"
|
version = "1.9.0"
|
||||||
@ -201,6 +212,21 @@ dependencies = [
|
|||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures"
|
||||||
|
version = "0.3.27"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "531ac96c6ff5fd7c62263c5e3c67a603af4fcaee2e1a0ae5565ba3a11e69e549"
|
||||||
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
|
"futures-core",
|
||||||
|
"futures-executor",
|
||||||
|
"futures-io",
|
||||||
|
"futures-sink",
|
||||||
|
"futures-task",
|
||||||
|
"futures-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-channel"
|
name = "futures-channel"
|
||||||
version = "0.3.27"
|
version = "0.3.27"
|
||||||
@ -208,6 +234,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac"
|
checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
"futures-sink",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -216,6 +243,17 @@ version = "0.3.27"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
|
checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-executor"
|
||||||
|
version = "0.3.27"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1997dd9df74cdac935c76252744c1ed5794fac083242ea4fe77ef3ed60ba0f83"
|
||||||
|
dependencies = [
|
||||||
|
"futures-core",
|
||||||
|
"futures-task",
|
||||||
|
"futures-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-io"
|
name = "futures-io"
|
||||||
version = "0.3.27"
|
version = "0.3.27"
|
||||||
@ -230,7 +268,7 @@ checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 1.0.109",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -245,12 +283,19 @@ version = "0.3.27"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879"
|
checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-timer"
|
||||||
|
version = "3.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-util"
|
name = "futures-util"
|
||||||
version = "0.3.27"
|
version = "0.3.27"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
|
checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-io",
|
"futures-io",
|
||||||
"futures-macro",
|
"futures-macro",
|
||||||
@ -492,6 +537,12 @@ dependencies = [
|
|||||||
"unicase",
|
"unicase",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "minimal-lexical"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mio"
|
name = "mio"
|
||||||
version = "0.8.6"
|
version = "0.8.6"
|
||||||
@ -522,6 +573,16 @@ dependencies = [
|
|||||||
"tempfile",
|
"tempfile",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nom"
|
||||||
|
version = "7.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
"minimal-lexical",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num_cpus"
|
name = "num_cpus"
|
||||||
version = "1.15.0"
|
version = "1.15.0"
|
||||||
@ -546,8 +607,10 @@ dependencies = [
|
|||||||
"base64",
|
"base64",
|
||||||
"bytes",
|
"bytes",
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
|
"futures",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
"reqwest-eventsource",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tokio",
|
"tokio",
|
||||||
@ -577,7 +640,7 @@ checksum = "b501e44f11665960c7e7fcf062c7d96a14ade4aa98116c004b2e37b5be7d736c"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 1.0.109",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -713,6 +776,22 @@ dependencies = [
|
|||||||
"winreg",
|
"winreg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "reqwest-eventsource"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8f03f570355882dd8d15acc3a313841e6e90eddbc76a93c748fd82cc13ba9f51"
|
||||||
|
dependencies = [
|
||||||
|
"eventsource-stream",
|
||||||
|
"futures-core",
|
||||||
|
"futures-timer",
|
||||||
|
"mime",
|
||||||
|
"nom",
|
||||||
|
"pin-project-lite",
|
||||||
|
"reqwest",
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustix"
|
name = "rustix"
|
||||||
version = "0.36.9"
|
version = "0.36.9"
|
||||||
@ -788,7 +867,7 @@ checksum = "d7e29c4601e36bcec74a223228dce795f4cd3616341a4af93520ca1a837c087d"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 1.0.109",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -865,6 +944,17 @@ dependencies = [
|
|||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "syn"
|
||||||
|
version = "2.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "59d3276aee1fa0c33612917969b5172b5be2db051232a6e4826f1a1a9191b045"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tempfile"
|
name = "tempfile"
|
||||||
version = "3.4.0"
|
version = "3.4.0"
|
||||||
@ -878,6 +968,26 @@ dependencies = [
|
|||||||
"windows-sys 0.42.0",
|
"windows-sys 0.42.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror"
|
||||||
|
version = "1.0.40"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac"
|
||||||
|
dependencies = [
|
||||||
|
"thiserror-impl",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror-impl"
|
||||||
|
version = "1.0.40"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tinyvec"
|
name = "tinyvec"
|
||||||
version = "1.6.0"
|
version = "1.6.0"
|
||||||
@ -921,7 +1031,7 @@ checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 1.0.109",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -1070,7 +1180,7 @@ dependencies = [
|
|||||||
"once_cell",
|
"once_cell",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 1.0.109",
|
||||||
"wasm-bindgen-shared",
|
"wasm-bindgen-shared",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -1104,7 +1214,7 @@ checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 1.0.109",
|
||||||
"wasm-bindgen-backend",
|
"wasm-bindgen-backend",
|
||||||
"wasm-bindgen-shared",
|
"wasm-bindgen-shared",
|
||||||
]
|
]
|
||||||
|
@ -10,8 +10,10 @@ anyhow = "1.0.69"
|
|||||||
base64 = "0.21.0"
|
base64 = "0.21.0"
|
||||||
bytes = "1.4.0"
|
bytes = "1.4.0"
|
||||||
derive_builder = "0.12.0"
|
derive_builder = "0.12.0"
|
||||||
|
futures = "0.3.27"
|
||||||
futures-core = "0.3.27"
|
futures-core = "0.3.27"
|
||||||
reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] }
|
reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] }
|
||||||
|
reqwest-eventsource = "0.4.0"
|
||||||
serde = { version = "1.0.156", features = ["derive"] }
|
serde = { version = "1.0.156", features = ["derive"] }
|
||||||
serde_json = "1.0.94"
|
serde_json = "1.0.94"
|
||||||
tokio = { version = "1.26.0", features = [ "full" ] }
|
tokio = { version = "1.26.0", features = [ "full" ] }
|
||||||
|
114
src/chat.rs
114
src/chat.rs
@ -1,7 +1,9 @@
|
|||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, str::FromStr, pin::Pin, task::Poll};
|
||||||
|
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use reqwest::Client;
|
use futures::{Stream, StreamExt};
|
||||||
|
use reqwest::{Client, RequestBuilder};
|
||||||
|
use reqwest_eventsource::{RequestBuilderExt, Event, EventSource};
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::{completion::{Sequence, Usage}, context::{API_URL, Context}};
|
use crate::{completion::{Sequence, Usage}, context::{API_URL, Context}};
|
||||||
@ -56,6 +58,7 @@ impl ChatMessage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Builder)]
|
#[derive(Debug, Serialize, Builder)]
|
||||||
|
#[builder(pattern = "owned")]
|
||||||
pub struct ChatHistory {
|
pub struct ChatHistory {
|
||||||
#[builder(setter(into))]
|
#[builder(setter(into))]
|
||||||
pub messages: Vec<ChatMessage>,
|
pub messages: Vec<ChatMessage>,
|
||||||
@ -72,7 +75,7 @@ pub struct ChatHistory {
|
|||||||
pub n: Option<u32>,
|
pub n: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
#[builder(setter(into, strip_option), default)]
|
#[builder(setter(into, strip_option), default)]
|
||||||
pub stream: Option<bool>,
|
stream: Option<bool>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
#[builder(setter(into, strip_option), default)]
|
#[builder(setter(into, strip_option), default)]
|
||||||
pub stop: Option<Sequence>,
|
pub stop: Option<Sequence>,
|
||||||
@ -93,15 +96,66 @@ pub struct ChatHistory {
|
|||||||
pub user: Option<String>,
|
pub user: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum FinishReason {
|
||||||
|
Stop,
|
||||||
|
Length,
|
||||||
|
ContentFilter,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for FinishReason {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de> {
|
||||||
|
// Deserialize the String
|
||||||
|
match String::deserialize(deserializer)? {
|
||||||
|
s if s == "stop" => Ok(Self::Stop),
|
||||||
|
s if s == "length" => Ok(Self::Length),
|
||||||
|
s if s == "content_filter" => Ok(Self::ContentFilter),
|
||||||
|
_ => Err(serde::de::Error::custom("Invalid stop reason")),
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct ChatCompletion {
|
pub struct ChatCompletion {
|
||||||
pub index: i32,
|
pub index: i32,
|
||||||
pub message: ChatMessage,
|
pub message: ChatMessage,
|
||||||
pub finish_reason: String, // TODO: Create enum for this
|
pub finish_reason: Option<FinishReason>
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct ChatCompletionResponse {
|
pub struct DeltaMessage {
|
||||||
|
pub role: Option<Role>,
|
||||||
|
pub content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct DeltaChatCompletion {
|
||||||
|
pub index: i32,
|
||||||
|
pub delta: DeltaMessage,
|
||||||
|
pub finish_reason: Option<FinishReason>,
|
||||||
|
}
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ChatCompletionDeltaResponse {
|
||||||
|
pub id: String,
|
||||||
|
/* pub object: "chat.completion", */
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<DeltaChatCompletion>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for ChatCompletionDeltaResponse {
|
||||||
|
type Err = serde_json::Error;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
|
serde_json::from_str(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ChatCompletionSyncResponse {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
/* pub object: "chat.completion", */
|
/* pub object: "chat.completion", */
|
||||||
pub created: u64,
|
pub created: u64,
|
||||||
@ -110,16 +164,58 @@ pub struct ChatCompletionResponse {
|
|||||||
pub usage: Usage
|
pub usage: Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct CompletionStream {
|
||||||
|
stream: EventSource
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for CompletionStream {
|
||||||
|
type Item = anyhow::Result<ChatCompletionDeltaResponse>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
loop {
|
||||||
|
return match self.stream.poll_next_unpin(cx) {
|
||||||
|
Poll::Ready(Some(Ok(event))) => {
|
||||||
|
match event {
|
||||||
|
Event::Message(message) => {
|
||||||
|
// Stream has ended
|
||||||
|
if message.data == "[DONE]" {
|
||||||
|
return Poll::Ready(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
match message.data.parse::<ChatCompletionDeltaResponse>() {
|
||||||
|
Ok(value) => Poll::Ready(Some(Ok(value))),
|
||||||
|
Err(e) => Poll::Ready(Some(Err(e.into())))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => continue
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(anyhow::Error::new(e)))),
|
||||||
|
Poll::Ready(None) => Poll::Ready(None),
|
||||||
|
Poll::Pending => Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Context {
|
impl Context {
|
||||||
pub async fn create_chat_completion(&self, chat_completion_request: ChatHistory) -> anyhow::Result<ChatCompletionResponse> {
|
fn build_request(&self, stream: bool, chat_completion_request: ChatHistoryBuilder) -> anyhow::Result<RequestBuilder> {
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/chat/completions")))
|
||||||
|
.json(&chat_completion_request.stream(stream).build()?))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create_chat_completion_sync(&self, chat_completion_request: ChatHistoryBuilder) -> anyhow::Result<ChatCompletionSyncResponse> {
|
||||||
Ok(
|
Ok(
|
||||||
self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/chat/completions")))
|
self.build_request(false, chat_completion_request)?
|
||||||
.json(&chat_completion_request)
|
|
||||||
.send()
|
.send()
|
||||||
.await?
|
.await?
|
||||||
.error_for_status()?
|
.error_for_status()?
|
||||||
.json::<ChatCompletionResponse>()
|
.json::<ChatCompletionSyncResponse>()
|
||||||
.await?
|
.await?
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn create_chat_completion_streamed(&self, chat_completion_request: ChatHistoryBuilder) -> anyhow::Result<impl Stream<Item = anyhow::Result<ChatCompletionDeltaResponse>>> {
|
||||||
|
Ok(CompletionStream { stream: self.build_request(true, chat_completion_request)?.eventsource()? })
|
||||||
|
}
|
||||||
}
|
}
|
47
src/lib.rs
47
src/lib.rs
@ -17,6 +17,7 @@ pub mod util;
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use futures::StreamExt;
|
||||||
use tokio::fs::File;
|
use tokio::fs::File;
|
||||||
|
|
||||||
use crate::chat::{ChatHistoryBuilder, ChatMessage, Role};
|
use crate::chat::{ChatHistoryBuilder, ChatMessage, Role};
|
||||||
@ -65,19 +66,53 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_chat_completion() {
|
async fn test_chat_completion() {
|
||||||
|
const PROMPT: &str = "Respond to this message with 'this is a test'";
|
||||||
|
|
||||||
let ctx = get_api();
|
let ctx = get_api();
|
||||||
assert!(ctx.is_ok(), "Could not load context");
|
assert!(ctx.is_ok(), "Could not load context");
|
||||||
|
|
||||||
let completion = ctx.unwrap().create_chat_completion(
|
let ctx = ctx.unwrap();
|
||||||
|
|
||||||
|
println!("Generating completion for prompt: {PROMPT}");
|
||||||
|
let completion = ctx.create_chat_completion_sync(
|
||||||
ChatHistoryBuilder::default()
|
ChatHistoryBuilder::default()
|
||||||
.messages(vec![ChatMessage::new(Role::User, "Respond to this message with 'this is a test'")])
|
.messages(vec![ChatMessage::new(Role::User, PROMPT)])
|
||||||
.model("gpt-3.5-turbo")
|
.model("gpt-3.5-turbo")
|
||||||
.build()
|
|
||||||
.unwrap()
|
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
|
assert!(completion.is_ok(), "Could not create completion: {}", completion.unwrap_err());
|
||||||
assert!(completion.unwrap().choices.len() == 1, "No completion found");
|
|
||||||
|
let result = completion.unwrap();
|
||||||
|
assert!(result.choices.len() == 1, "No completion found");
|
||||||
|
println!("Got completion: {:?}", result.choices[0].message);
|
||||||
|
|
||||||
|
println!("Generating streamed completion for prompt: {PROMPT}");
|
||||||
|
let completion = ctx.create_chat_completion_streamed(
|
||||||
|
ChatHistoryBuilder::default()
|
||||||
|
.messages(vec![ChatMessage::new(Role::User, PROMPT)])
|
||||||
|
.model("gpt-3.5-turbo")
|
||||||
|
).await;
|
||||||
|
|
||||||
|
assert!(completion.is_ok(), "Could not create completion: {}", completion.err().unwrap());
|
||||||
|
let mut stream = completion.unwrap();
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
assert!(result.is_ok(), "Could not get completion: {}", result.unwrap_err());
|
||||||
|
let result = result.unwrap();
|
||||||
|
assert!(result.choices.len() == 1, "No completion found");
|
||||||
|
|
||||||
|
let delta = &result.choices[0];
|
||||||
|
if let Some(ref reason) = delta.finish_reason {
|
||||||
|
println!("Got completion end. Reason: {:?}", reason);
|
||||||
|
} else {
|
||||||
|
if let Some(ref role) = delta.delta.role {
|
||||||
|
println!("Got role: {:?}", role);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref message) = delta.delta.content {
|
||||||
|
println!("Got completion: {:?}", message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user