Implement completions

This commit is contained in:
Gabriel Tofvesson 2023-03-17 21:06:03 +01:00
parent 926b9044b8
commit de0a109a4a
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
9 changed files with 1569 additions and 3 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
/target
apikey.txt

1184
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -6,3 +6,9 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.69"
derive_builder = "0.12.0"
reqwest = { versino = "0.11.14", features = [ "json" ] }
serde = { version = "1.0.156", features = ["derive"] }
serde_json = "1.0.94"
tokio = { versino = "1.26.0", features = [ "full" ] }

110
src/chat.rs Normal file
View File

@ -0,0 +1,110 @@
use std::collections::HashMap;
use derive_builder::Builder;
use serde::{Serialize, Deserialize};
use crate::completion::{Sequence, Usage};
#[derive(Debug, Clone)]
pub enum Role {
User,
System,
Assistant
}
impl Serialize for Role {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer {
match self {
Self::User => serializer.serialize_str("user"),
Self::System => serializer.serialize_str("system"),
Self::Assistant => serializer.serialize_str("assistant"),
}
}
}
impl<'de> Deserialize<'de> for Role {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de> {
// Deserialize the String
match String::deserialize(deserializer)? {
s if s == "user" => Ok(Self::User),
s if s == "system" => Ok(Self::System),
s if s == "assistant" => Ok(Self::Assistant),
_ => Err(serde::de::Error::custom("Invalid role")),
}
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage {
pub role: Role,
pub content: String,
}
impl ChatMessage {
pub fn new(role: Role, message: impl Into<String>) -> Self {
Self {
role,
content: message.into()
}
}
}
#[derive(Debug, Serialize, Builder)]
pub struct ChatHistory {
#[builder(setter(into))]
pub messages: Vec<ChatMessage>,
#[builder(setter(into))]
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub stop: Option<Sequence>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub max_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub presence_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub frequency_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub logit_bias: Option<HashMap<u64, i8>>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub user: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct ChatCompletion {
pub index: i32,
pub message: ChatMessage,
pub finish_reason: String, // TODO: Create enum for this
}
#[derive(Debug, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
/* pub object: "chat.completion", */
pub created: u64,
pub model: String,
pub choices: Vec<ChatCompletion>,
pub usage: Usage
}

124
src/completion.rs Normal file
View File

@ -0,0 +1,124 @@
use std::collections::HashMap;
use derive_builder::Builder;
use serde::{Serialize, Deserialize};
#[derive(Debug, Clone)]
pub enum Sequence {
String(String),
List(Vec<String>),
}
impl Serialize for Sequence {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
Sequence::String(s) => serializer.serialize_str(s),
Sequence::List(l) => serializer.collect_seq(l),
}
}
}
impl From<String> for Sequence {
fn from(s: String) -> Self {
Sequence::String(s)
}
}
impl From<Vec<String>> for Sequence {
fn from(v: Vec<String>) -> Self {
Sequence::List(v)
}
}
impl From<&str> for Sequence {
fn from(s: &str) -> Self {
Sequence::String(s.to_string())
}
}
impl From<&[&str]> for Sequence {
fn from(v: &[&str]) -> Self {
Sequence::List(v.iter().map(|s| s.to_string()).collect())
}
}
#[derive(Debug, Serialize, Builder)]
pub struct CompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub prompt: Option<Sequence>,
#[builder(setter(into))]
pub model: String,
#[builder(setter(into, strip_option), default)]
pub max_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub suffix: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub n: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub logprobs: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub echo: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub stop: Option<Sequence>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub presence_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub frequency_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub best_of: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub logit_bias: Option<HashMap<u64, i8>>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub return_prompt: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub user: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct Choice {
pub index: u64,
pub text: String,
pub logprobs: Option<HashMap<String, f64>>,
pub finish_reason: String,
}
#[derive(Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: u64,
pub completion_tokens: u64,
pub total_tokens: u64,
}
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Usage,
}

52
src/context.rs Normal file
View File

@ -0,0 +1,52 @@
use reqwest::{Client, RequestBuilder};
use crate::{model::{Model, ModelList}, completion::{CompletionRequest, CompletionResponse}, chat::{ChatCompletionResponse, ChatHistory}};
pub struct Context {
api_key: String,
org_id: Option<String>
}
const API_URL: &str = "https://api.openai.com";
impl Context {
pub fn new(api_key: String) -> Self {
Context {
api_key,
org_id: None,
}
}
pub fn new_with_org(api_key: String, org_id: String) -> Self {
Context {
api_key,
org_id: Some(org_id),
}
}
fn with_auth(&self, builder: RequestBuilder) -> RequestBuilder {
(
if let Some(ref org_id) = self.org_id {
builder.header("OpenAI-Organization", org_id)
} else {
builder
}
).bearer_auth(&self.api_key)
}
pub async fn get_models(&self) -> anyhow::Result<Vec<Model>> {
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::<ModelList>().await?.data)
}
pub async fn get_model(&self, model_id: &str) -> anyhow::Result<Model> {
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models/{model_id}", model_id = model_id))).send().await?.json::<Model>().await?)
}
pub async fn create_completion(&self, completion_request: CompletionRequest) -> anyhow::Result<CompletionResponse> {
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/completions")).json(&completion_request)).send().await?.json::<CompletionResponse>().await?)
}
pub async fn create_chat_completion(&self, chat_completion_request: ChatHistory) -> anyhow::Result<ChatCompletionResponse> {
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/chat/completions")).json(&chat_completion_request)).send().await?.json::<ChatCompletionResponse>().await?)
}
}

60
src/lib.rs Normal file
View File

@ -0,0 +1,60 @@
pub mod context;
pub mod model;
pub mod completion;
pub mod chat;
#[cfg(test)]
mod tests {
use crate::chat::ChatMessage;
use crate::context::Context;
use crate::completion::CompletionRequestBuilder;
fn get_api() -> anyhow::Result<Context> {
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
}
#[tokio::test]
async fn test_get_models() {
let ctx = get_api();
assert!(ctx.is_ok(), "Could not load context");
let models = ctx.unwrap().get_models().await;
assert!(models.is_ok(), "Could not get models: {}", models.unwrap_err());
assert!(models.unwrap().len() > 0, "No models found");
}
#[tokio::test]
async fn test_completion() {
let ctx = get_api();
assert!(ctx.is_ok(), "Could not load context");
let completion = ctx.unwrap().create_completion(
CompletionRequestBuilder::default()
.model("text-davinci-003")
.prompt("Say 'this is a test'")
.build()
.unwrap()
).await;
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
assert!(completion.unwrap().choices.len() == 1, "No completion found");
}
#[tokio::test]
async fn test_chat_completion() {
let ctx = get_api();
assert!(ctx.is_ok(), "Could not load context");
let completion = ctx.unwrap().create_chat_completion(
crate::chat::ChatHistoryBuilder::default()
.messages(vec![ChatMessage::new(crate::chat::Role::User, "Respond to this message with 'this is a test'")])
.model("gpt-3.5-turbo")
.build()
.unwrap()
).await;
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
assert!(completion.unwrap().choices.len() == 1, "No completion found");
}
}

View File

@ -1,3 +0,0 @@
fn main() {
println!("Hello, world!");
}

32
src/model.rs Normal file
View File

@ -0,0 +1,32 @@
use serde::Deserialize;
#[derive(Debug, Deserialize)]
pub struct Permission {
pub id: String,
/* pub object: "model_permission", */
pub created: u64,
pub allow_create_engine: bool,
pub allow_sampling: bool,
pub allow_logprobs: bool,
pub allow_search_indices: bool,
pub allow_view: bool,
pub allow_fine_tuning: bool,
pub organization: String,
/* pub group: null, */
pub is_blocking: bool,
}
#[derive(Debug, Deserialize)]
pub struct Model {
pub id: String,
pub created: u64,
pub owned_by: String,
pub permission: Vec<Permission>,
pub root: String,
pub parent: Option<String>,
}
#[derive(Debug, Deserialize)]
pub(crate) struct ModelList {
pub data: Vec<Model>,
}