Implement completions
This commit is contained in:
parent
926b9044b8
commit
de0a109a4a
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
||||
/target
|
||||
apikey.txt
|
||||
|
1184
Cargo.lock
generated
Normal file
1184
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
@ -6,3 +6,9 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.69"
|
||||
derive_builder = "0.12.0"
|
||||
reqwest = { versino = "0.11.14", features = [ "json" ] }
|
||||
serde = { version = "1.0.156", features = ["derive"] }
|
||||
serde_json = "1.0.94"
|
||||
tokio = { versino = "1.26.0", features = [ "full" ] }
|
||||
|
110
src/chat.rs
Normal file
110
src/chat.rs
Normal file
@ -0,0 +1,110 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use derive_builder::Builder;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::completion::{Sequence, Usage};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Role {
|
||||
User,
|
||||
System,
|
||||
Assistant
|
||||
}
|
||||
|
||||
impl Serialize for Role {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer {
|
||||
match self {
|
||||
Self::User => serializer.serialize_str("user"),
|
||||
Self::System => serializer.serialize_str("system"),
|
||||
Self::Assistant => serializer.serialize_str("assistant"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Role {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de> {
|
||||
// Deserialize the String
|
||||
match String::deserialize(deserializer)? {
|
||||
s if s == "user" => Ok(Self::User),
|
||||
s if s == "system" => Ok(Self::System),
|
||||
s if s == "assistant" => Ok(Self::Assistant),
|
||||
_ => Err(serde::de::Error::custom("Invalid role")),
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ChatMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn new(role: Role, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role,
|
||||
content: message.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Builder)]
|
||||
pub struct ChatHistory {
|
||||
#[builder(setter(into))]
|
||||
pub messages: Vec<ChatMessage>,
|
||||
#[builder(setter(into))]
|
||||
pub model: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub n: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub stream: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub stop: Option<Sequence>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub max_tokens: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub presence_penalty: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub frequency_penalty: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub logit_bias: Option<HashMap<u64, i8>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChatCompletion {
|
||||
pub index: i32,
|
||||
pub message: ChatMessage,
|
||||
pub finish_reason: String, // TODO: Create enum for this
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
/* pub object: "chat.completion", */
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatCompletion>,
|
||||
pub usage: Usage
|
||||
}
|
124
src/completion.rs
Normal file
124
src/completion.rs
Normal file
@ -0,0 +1,124 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use derive_builder::Builder;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Sequence {
|
||||
String(String),
|
||||
List(Vec<String>),
|
||||
}
|
||||
|
||||
impl Serialize for Sequence {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
Sequence::String(s) => serializer.serialize_str(s),
|
||||
Sequence::List(l) => serializer.collect_seq(l),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for Sequence {
|
||||
fn from(s: String) -> Self {
|
||||
Sequence::String(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<String>> for Sequence {
|
||||
fn from(v: Vec<String>) -> Self {
|
||||
Sequence::List(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Sequence {
|
||||
fn from(s: &str) -> Self {
|
||||
Sequence::String(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&[&str]> for Sequence {
|
||||
fn from(v: &[&str]) -> Self {
|
||||
Sequence::List(v.iter().map(|s| s.to_string()).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Builder)]
|
||||
pub struct CompletionRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub prompt: Option<Sequence>,
|
||||
#[builder(setter(into))]
|
||||
pub model: String,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub max_tokens: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub suffix: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub n: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub stream: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub logprobs: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub echo: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub stop: Option<Sequence>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub presence_penalty: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub frequency_penalty: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub best_of: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub logit_bias: Option<HashMap<u64, i8>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub return_prompt: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub index: u64,
|
||||
pub text: String,
|
||||
pub logprobs: Option<HashMap<String, f64>>,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u64,
|
||||
pub completion_tokens: u64,
|
||||
pub total_tokens: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Usage,
|
||||
}
|
52
src/context.rs
Normal file
52
src/context.rs
Normal file
@ -0,0 +1,52 @@
|
||||
use reqwest::{Client, RequestBuilder};
|
||||
|
||||
use crate::{model::{Model, ModelList}, completion::{CompletionRequest, CompletionResponse}, chat::{ChatCompletionResponse, ChatHistory}};
|
||||
|
||||
pub struct Context {
|
||||
api_key: String,
|
||||
org_id: Option<String>
|
||||
}
|
||||
|
||||
const API_URL: &str = "https://api.openai.com";
|
||||
|
||||
impl Context {
|
||||
pub fn new(api_key: String) -> Self {
|
||||
Context {
|
||||
api_key,
|
||||
org_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_org(api_key: String, org_id: String) -> Self {
|
||||
Context {
|
||||
api_key,
|
||||
org_id: Some(org_id),
|
||||
}
|
||||
}
|
||||
|
||||
fn with_auth(&self, builder: RequestBuilder) -> RequestBuilder {
|
||||
(
|
||||
if let Some(ref org_id) = self.org_id {
|
||||
builder.header("OpenAI-Organization", org_id)
|
||||
} else {
|
||||
builder
|
||||
}
|
||||
).bearer_auth(&self.api_key)
|
||||
}
|
||||
|
||||
pub async fn get_models(&self) -> anyhow::Result<Vec<Model>> {
|
||||
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::<ModelList>().await?.data)
|
||||
}
|
||||
|
||||
pub async fn get_model(&self, model_id: &str) -> anyhow::Result<Model> {
|
||||
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models/{model_id}", model_id = model_id))).send().await?.json::<Model>().await?)
|
||||
}
|
||||
|
||||
pub async fn create_completion(&self, completion_request: CompletionRequest) -> anyhow::Result<CompletionResponse> {
|
||||
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/completions")).json(&completion_request)).send().await?.json::<CompletionResponse>().await?)
|
||||
}
|
||||
|
||||
pub async fn create_chat_completion(&self, chat_completion_request: ChatHistory) -> anyhow::Result<ChatCompletionResponse> {
|
||||
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/chat/completions")).json(&chat_completion_request)).send().await?.json::<ChatCompletionResponse>().await?)
|
||||
}
|
||||
}
|
60
src/lib.rs
Normal file
60
src/lib.rs
Normal file
@ -0,0 +1,60 @@
|
||||
pub mod context;
|
||||
pub mod model;
|
||||
pub mod completion;
|
||||
pub mod chat;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::chat::ChatMessage;
|
||||
use crate::context::Context;
|
||||
use crate::completion::CompletionRequestBuilder;
|
||||
|
||||
fn get_api() -> anyhow::Result<Context> {
|
||||
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_models() {
|
||||
let ctx = get_api();
|
||||
assert!(ctx.is_ok(), "Could not load context");
|
||||
|
||||
let models = ctx.unwrap().get_models().await;
|
||||
assert!(models.is_ok(), "Could not get models: {}", models.unwrap_err());
|
||||
assert!(models.unwrap().len() > 0, "No models found");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_completion() {
|
||||
let ctx = get_api();
|
||||
assert!(ctx.is_ok(), "Could not load context");
|
||||
|
||||
let completion = ctx.unwrap().create_completion(
|
||||
CompletionRequestBuilder::default()
|
||||
.model("text-davinci-003")
|
||||
.prompt("Say 'this is a test'")
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
||||
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
|
||||
assert!(completion.unwrap().choices.len() == 1, "No completion found");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_chat_completion() {
|
||||
let ctx = get_api();
|
||||
assert!(ctx.is_ok(), "Could not load context");
|
||||
|
||||
let completion = ctx.unwrap().create_chat_completion(
|
||||
crate::chat::ChatHistoryBuilder::default()
|
||||
.messages(vec![ChatMessage::new(crate::chat::Role::User, "Respond to this message with 'this is a test'")])
|
||||
.model("gpt-3.5-turbo")
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
||||
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
|
||||
assert!(completion.unwrap().choices.len() == 1, "No completion found");
|
||||
}
|
||||
}
|
@ -1,3 +0,0 @@
|
||||
fn main() {
|
||||
println!("Hello, world!");
|
||||
}
|
32
src/model.rs
Normal file
32
src/model.rs
Normal file
@ -0,0 +1,32 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Permission {
|
||||
pub id: String,
|
||||
/* pub object: "model_permission", */
|
||||
pub created: u64,
|
||||
pub allow_create_engine: bool,
|
||||
pub allow_sampling: bool,
|
||||
pub allow_logprobs: bool,
|
||||
pub allow_search_indices: bool,
|
||||
pub allow_view: bool,
|
||||
pub allow_fine_tuning: bool,
|
||||
pub organization: String,
|
||||
/* pub group: null, */
|
||||
pub is_blocking: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Model {
|
||||
pub id: String,
|
||||
pub created: u64,
|
||||
pub owned_by: String,
|
||||
pub permission: Vec<Permission>,
|
||||
pub root: String,
|
||||
pub parent: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct ModelList {
|
||||
pub data: Vec<Model>,
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user