Implement completions
This commit is contained in:
parent
926b9044b8
commit
de0a109a4a
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
|||||||
/target
|
/target
|
||||||
|
apikey.txt
|
||||||
|
1184
Cargo.lock
generated
Normal file
1184
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
@ -6,3 +6,9 @@ edition = "2021"
|
|||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
anyhow = "1.0.69"
|
||||||
|
derive_builder = "0.12.0"
|
||||||
|
reqwest = { versino = "0.11.14", features = [ "json" ] }
|
||||||
|
serde = { version = "1.0.156", features = ["derive"] }
|
||||||
|
serde_json = "1.0.94"
|
||||||
|
tokio = { versino = "1.26.0", features = [ "full" ] }
|
||||||
|
110
src/chat.rs
Normal file
110
src/chat.rs
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use derive_builder::Builder;
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
|
use crate::completion::{Sequence, Usage};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum Role {
|
||||||
|
User,
|
||||||
|
System,
|
||||||
|
Assistant
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for Role {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer {
|
||||||
|
match self {
|
||||||
|
Self::User => serializer.serialize_str("user"),
|
||||||
|
Self::System => serializer.serialize_str("system"),
|
||||||
|
Self::Assistant => serializer.serialize_str("assistant"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for Role {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de> {
|
||||||
|
// Deserialize the String
|
||||||
|
match String::deserialize(deserializer)? {
|
||||||
|
s if s == "user" => Ok(Self::User),
|
||||||
|
s if s == "system" => Ok(Self::System),
|
||||||
|
s if s == "assistant" => Ok(Self::Assistant),
|
||||||
|
_ => Err(serde::de::Error::custom("Invalid role")),
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct ChatMessage {
|
||||||
|
pub role: Role,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatMessage {
|
||||||
|
pub fn new(role: Role, message: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
role,
|
||||||
|
content: message.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Builder)]
|
||||||
|
pub struct ChatHistory {
|
||||||
|
#[builder(setter(into))]
|
||||||
|
pub messages: Vec<ChatMessage>,
|
||||||
|
#[builder(setter(into))]
|
||||||
|
pub model: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub stream: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub stop: Option<Sequence>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub max_tokens: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub presence_penalty: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub frequency_penalty: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub logit_bias: Option<HashMap<u64, i8>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub user: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ChatCompletion {
|
||||||
|
pub index: i32,
|
||||||
|
pub message: ChatMessage,
|
||||||
|
pub finish_reason: String, // TODO: Create enum for this
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ChatCompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
/* pub object: "chat.completion", */
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChatCompletion>,
|
||||||
|
pub usage: Usage
|
||||||
|
}
|
124
src/completion.rs
Normal file
124
src/completion.rs
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use derive_builder::Builder;
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum Sequence {
|
||||||
|
String(String),
|
||||||
|
List(Vec<String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for Sequence {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Sequence::String(s) => serializer.serialize_str(s),
|
||||||
|
Sequence::List(l) => serializer.collect_seq(l),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for Sequence {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
Sequence::String(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Vec<String>> for Sequence {
|
||||||
|
fn from(v: Vec<String>) -> Self {
|
||||||
|
Sequence::List(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for Sequence {
|
||||||
|
fn from(s: &str) -> Self {
|
||||||
|
Sequence::String(s.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&[&str]> for Sequence {
|
||||||
|
fn from(v: &[&str]) -> Self {
|
||||||
|
Sequence::List(v.iter().map(|s| s.to_string()).collect())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Builder)]
|
||||||
|
pub struct CompletionRequest {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub prompt: Option<Sequence>,
|
||||||
|
#[builder(setter(into))]
|
||||||
|
pub model: String,
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub max_tokens: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub suffix: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub n: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub stream: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub logprobs: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub echo: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub stop: Option<Sequence>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub presence_penalty: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub frequency_penalty: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub best_of: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub logit_bias: Option<HashMap<u64, i8>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub return_prompt: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub user: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Choice {
|
||||||
|
pub index: u64,
|
||||||
|
pub text: String,
|
||||||
|
pub logprobs: Option<HashMap<String, f64>>,
|
||||||
|
pub finish_reason: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: u64,
|
||||||
|
pub completion_tokens: u64,
|
||||||
|
pub total_tokens: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct CompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<Choice>,
|
||||||
|
pub usage: Usage,
|
||||||
|
}
|
52
src/context.rs
Normal file
52
src/context.rs
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
use reqwest::{Client, RequestBuilder};
|
||||||
|
|
||||||
|
use crate::{model::{Model, ModelList}, completion::{CompletionRequest, CompletionResponse}, chat::{ChatCompletionResponse, ChatHistory}};
|
||||||
|
|
||||||
|
pub struct Context {
|
||||||
|
api_key: String,
|
||||||
|
org_id: Option<String>
|
||||||
|
}
|
||||||
|
|
||||||
|
const API_URL: &str = "https://api.openai.com";
|
||||||
|
|
||||||
|
impl Context {
|
||||||
|
pub fn new(api_key: String) -> Self {
|
||||||
|
Context {
|
||||||
|
api_key,
|
||||||
|
org_id: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_org(api_key: String, org_id: String) -> Self {
|
||||||
|
Context {
|
||||||
|
api_key,
|
||||||
|
org_id: Some(org_id),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_auth(&self, builder: RequestBuilder) -> RequestBuilder {
|
||||||
|
(
|
||||||
|
if let Some(ref org_id) = self.org_id {
|
||||||
|
builder.header("OpenAI-Organization", org_id)
|
||||||
|
} else {
|
||||||
|
builder
|
||||||
|
}
|
||||||
|
).bearer_auth(&self.api_key)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_models(&self) -> anyhow::Result<Vec<Model>> {
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::<ModelList>().await?.data)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_model(&self, model_id: &str) -> anyhow::Result<Model> {
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models/{model_id}", model_id = model_id))).send().await?.json::<Model>().await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create_completion(&self, completion_request: CompletionRequest) -> anyhow::Result<CompletionResponse> {
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/completions")).json(&completion_request)).send().await?.json::<CompletionResponse>().await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create_chat_completion(&self, chat_completion_request: ChatHistory) -> anyhow::Result<ChatCompletionResponse> {
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/chat/completions")).json(&chat_completion_request)).send().await?.json::<ChatCompletionResponse>().await?)
|
||||||
|
}
|
||||||
|
}
|
60
src/lib.rs
Normal file
60
src/lib.rs
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
pub mod context;
|
||||||
|
pub mod model;
|
||||||
|
pub mod completion;
|
||||||
|
pub mod chat;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::chat::ChatMessage;
|
||||||
|
use crate::context::Context;
|
||||||
|
use crate::completion::CompletionRequestBuilder;
|
||||||
|
|
||||||
|
fn get_api() -> anyhow::Result<Context> {
|
||||||
|
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_models() {
|
||||||
|
let ctx = get_api();
|
||||||
|
assert!(ctx.is_ok(), "Could not load context");
|
||||||
|
|
||||||
|
let models = ctx.unwrap().get_models().await;
|
||||||
|
assert!(models.is_ok(), "Could not get models: {}", models.unwrap_err());
|
||||||
|
assert!(models.unwrap().len() > 0, "No models found");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_completion() {
|
||||||
|
let ctx = get_api();
|
||||||
|
assert!(ctx.is_ok(), "Could not load context");
|
||||||
|
|
||||||
|
let completion = ctx.unwrap().create_completion(
|
||||||
|
CompletionRequestBuilder::default()
|
||||||
|
.model("text-davinci-003")
|
||||||
|
.prompt("Say 'this is a test'")
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
).await;
|
||||||
|
|
||||||
|
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
|
||||||
|
assert!(completion.unwrap().choices.len() == 1, "No completion found");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_completion() {
|
||||||
|
let ctx = get_api();
|
||||||
|
assert!(ctx.is_ok(), "Could not load context");
|
||||||
|
|
||||||
|
let completion = ctx.unwrap().create_chat_completion(
|
||||||
|
crate::chat::ChatHistoryBuilder::default()
|
||||||
|
.messages(vec![ChatMessage::new(crate::chat::Role::User, "Respond to this message with 'this is a test'")])
|
||||||
|
.model("gpt-3.5-turbo")
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
).await;
|
||||||
|
|
||||||
|
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
|
||||||
|
assert!(completion.unwrap().choices.len() == 1, "No completion found");
|
||||||
|
}
|
||||||
|
}
|
@ -1,3 +0,0 @@
|
|||||||
fn main() {
|
|
||||||
println!("Hello, world!");
|
|
||||||
}
|
|
32
src/model.rs
Normal file
32
src/model.rs
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Permission {
|
||||||
|
pub id: String,
|
||||||
|
/* pub object: "model_permission", */
|
||||||
|
pub created: u64,
|
||||||
|
pub allow_create_engine: bool,
|
||||||
|
pub allow_sampling: bool,
|
||||||
|
pub allow_logprobs: bool,
|
||||||
|
pub allow_search_indices: bool,
|
||||||
|
pub allow_view: bool,
|
||||||
|
pub allow_fine_tuning: bool,
|
||||||
|
pub organization: String,
|
||||||
|
/* pub group: null, */
|
||||||
|
pub is_blocking: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Model {
|
||||||
|
pub id: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub owned_by: String,
|
||||||
|
pub permission: Vec<Permission>,
|
||||||
|
pub root: String,
|
||||||
|
pub parent: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub(crate) struct ModelList {
|
||||||
|
pub data: Vec<Model>,
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user