#[macro_use] extern crate derive_builder; use reqwest::header::HeaderMap; use thiserror::Error; type Result = std::result::Result; #[allow(clippy::clippy::default_trait_access)] pub mod api { use std::collections::HashMap; use serde::{Deserialize, Serialize}; /// Container type. Used in the api, but not useful for clients of this library #[derive(Deserialize, Debug)] pub(super) struct Container { pub data: Vec, } /// Engine description type #[derive(Deserialize, Debug, Eq, PartialEq)] pub struct EngineInfo { pub id: Engine, pub owner: String, pub ready: bool, } /// Engine types, known and unknown #[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)] #[serde(rename_all = "kebab-case")] pub enum Engine { Ada, Babbage, Curie, Davinci, #[serde(rename = "content-filter-alpha-c4")] ContentFilter, #[serde(other)] Other, } // Custom Display to lowercase things impl std::fmt::Display for Engine { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Engine::Ada => f.write_str("ada"), Engine::Babbage => f.write_str("babbage"), Engine::Curie => f.write_str("curie"), Engine::Davinci => f.write_str("davinci"), Engine::ContentFilter => f.write_str("content-filter-alpha-c4"), _ => panic!("Can't write out Other engine id"), } } } /// Options that affect the result #[derive(Serialize, Debug, Builder)] pub struct CompletionArgs { #[builder(setter(into), default = "\"<|endoftext|>\".into()")] prompt: String, #[builder(default = "16")] max_tokens: u64, #[builder(default = "1.0")] temperature: f64, #[builder(default = "1.0")] top_p: f64, #[builder(default = "1")] n: u64, #[builder(setter(strip_option), default)] logprobs: Option, #[builder(default = "false")] echo: bool, #[builder(setter(strip_option), default)] stop: Option>, #[builder(default = "0.0")] presence_penalty: f64, #[builder(default = "0.0")] frequency_penalty: f64, #[builder(default)] logit_bias: HashMap, } /* { "stream": false, // SSE streams back results "best_of": Option, //cant be used with stream } */ // TODO: add validators for the different arguments impl Default for CompletionArgs { fn default() -> Self { CompletionArgsBuilder::default() .build() .expect("Client error, invalid defaults") } } impl From<&str> for CompletionArgs { fn from(prompt_string: &str) -> Self { Self { prompt: prompt_string.into(), ..CompletionArgs::default() } } } /// Represents a non-streamed completion response #[derive(Deserialize, Debug)] pub struct Completion { id: String, object: String, created: u64, model: String, choices: Vec, } /// Represents a single choice #[derive(Deserialize, Debug)] pub struct Choice { text: String, index: u64, logprobs: Option, finish_reason: FinishReason, } /// Represents a logprobs subdocument #[derive(Deserialize, Debug)] pub struct LogProbs { tokens: Vec, token_logprobs: Vec>, top_logprobs: Vec>>, text_offset: Vec, } #[derive(Deserialize, Debug)] #[serde(rename_all = "kebab-case")] pub enum FinishReason { Length, Stop, } #[derive(Deserialize, Debug, Eq, PartialEq)] pub struct ErrorMessage { pub message: String, #[serde(rename = "type")] pub error_type: String, } #[derive(Deserialize, Debug)] pub struct ErrorWrapper { pub error: ErrorMessage, } } #[derive(Error, Debug)] pub enum OpenAIError { #[error("Invalid secret key")] InvalidAPIKey { #[from] source: reqwest::header::InvalidHeaderValue, }, #[error("API Returned an Error document")] APIError(api::ErrorMessage), } pub struct OpenAIClient { client: reqwest::Client, root: String, } impl OpenAIClient { /// Creates a new `OpenAIClient` /// /// # Errors /// `OpenAIError::InvalidAPIKey` if the api token has invalid characters pub fn new(token: &str) -> Result { let mut headers = HeaderMap::new(); headers.insert( reqwest::header::AUTHORIZATION, reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))?, ); Ok(Self { client: reqwest::Client::builder() .default_headers(headers) .build() .expect("Client library error. Should have constructed a valid http client."), root: "https://api.openai.com/v1".into(), }) } /// Private helper for making gets async fn get(&self, endpoint: &str) -> Result { let url = &format!("{}/{}", self.root, endpoint); let response = self .client .get(url) .send() .await .expect("Client error. Should have passed a valid url"); if response.status() != 200 { return Err(OpenAIError::APIError( response .json::() .await .expect("The API has returned something funky") .error, )); } Ok(response.json::().await.unwrap()) } /// Lists the currently available engines. /// /// Provides basic information about each one such as the owner and availability. /// /// # Errors /// - `OpenAIError::APIError` if the server returns an error /// - `OpenAIError::ServerFormatError` if the json response wasn't parseable (most /// likely a bug in this client, please report it) pub async fn engines(&self) -> Result> { self.get("engines").await.map(|r: api::Container<_>| r.data) } /// Retrieves an engine instance /// Provides basic information about the engine such as the owner and availability. /// /// # Errors /// - `OpenAIError::APIError` if the server returns an error /// - `OpenAIError::ServerFormatError` if the json response wasn't parseable (most /// likely a bug in this client, please report it) pub async fn engine(&self, engine: api::Engine) -> Result { self.get(&format!("engines/{}", engine)).await } // Private helper to generate post requests. Needs to be a bit more flexible than // get because it should support SSE eventually async fn post( &self, endpoint: &str, body: B, ) -> Result { let url = &format!("{}/{}", self.root, endpoint); let response = self .client .post(url) .json(&body) .send() .await .expect("Client library error, json failed to parse"); if response.status() != 200 { return Err(OpenAIError::APIError( response .json::() .await .expect("The API has returned something funky") .error, )); } Ok(response) } /// Get predicted completion of the prompt /// /// # Errors /// - `OpenAIError::APIError` if the api returns an error pub async fn complete( &self, engine: api::Engine, prompt: impl Into, ) -> Result { Ok(self .post(&format!("engines/{}/completions", engine), prompt.into()) .await? //.text() .json() .await .expect("Client error. JSON didn't parse correctly.")) } } #[cfg(test)] mod unit { use crate::{api, OpenAIClient, OpenAIError}; fn mocked_client() -> OpenAIClient { let mut client = OpenAIClient::new("bogus").unwrap(); client.root = mockito::server_url(); client } #[test] fn can_create_client() { let _c = mocked_client(); } #[test] fn parse_engine_info() -> Result<(), Box> { let example = r#"{ "id": "ada", "object": "engine", "owner": "openai", "ready": true }"#; let ei: api::EngineInfo = serde_json::from_str(example)?; assert_eq!( ei, api::EngineInfo { id: api::Engine::Ada, owner: "openai".into(), ready: true, } ); Ok(()) } #[tokio::test] async fn parse_engines() -> crate::Result<()> { use api::{Engine, EngineInfo}; let _m = mockito::mock("GET", "/engines") .with_status(200) .with_header("content-type", "text/json") .with_body( r#"{ "object": "list", "data": [ { "id": "ada", "object": "engine", "owner": "openai", "ready": true }, { "id": "babbage", "object": "engine", "owner": "openai", "ready": true }, { "id": "experimental-engine-v7", "object": "engine", "owner": "openai", "ready": false }, { "id": "curie", "object": "engine", "owner": "openai", "ready": true }, { "id": "davinci", "object": "engine", "owner": "openai", "ready": true }, { "id": "content-filter-alpha-c4", "object": "engine", "owner": "openai", "ready": true } ] }"#, ) .create(); let expected = vec![ EngineInfo { id: Engine::Ada, owner: "openai".into(), ready: true, }, EngineInfo { id: Engine::Babbage, owner: "openai".into(), ready: true, }, EngineInfo { id: Engine::Other, owner: "openai".into(), ready: false, }, EngineInfo { id: Engine::Curie, owner: "openai".into(), ready: true, }, EngineInfo { id: Engine::Davinci, owner: "openai".into(), ready: true, }, EngineInfo { id: Engine::ContentFilter, owner: "openai".into(), ready: true, }, ]; let response = mocked_client().engines().await?; assert_eq!(response, expected); Ok(()) } #[tokio::test] async fn engine_error_response() -> crate::Result<()> { let _m = mockito::mock("GET", "/engines/davinci") .with_status(404) .with_header("content-type", "text/json") .with_body( r#"{ "error": { "code": null, "message": "Some kind of error happened", "type": "some_error_type" } }"#, ) .create(); let expected = api::ErrorMessage { message: "Some kind of error happened".into(), error_type: "some_error_type".into(), }; let response = mocked_client().engine(api::Engine::Davinci).await; if let Result::Err(OpenAIError::APIError(msg)) = response { assert_eq!(expected, msg); } Ok(()) } } #[cfg(test)] mod integration { use api::ErrorMessage; use crate::{OpenAIClient, OpenAIError, api}; /// Used by tests to get a client to the actual api fn get_client() -> OpenAIClient { let sk = std::env::var("OPENAI_SK").expect( "To run integration tests, you must put set the OPENAI_SK env var to your api token", ); OpenAIClient::new(&sk).expect("client build failed") } #[tokio::test] async fn can_get_engines() { let client = get_client(); client.engines().await.unwrap(); } #[tokio::test] async fn can_get_engine() { let client = get_client(); let result = client.engine(api::Engine::Ada).await; match result { Err(OpenAIError::APIError(ErrorMessage{message, error_type})) => { assert_eq!(message, "No engine with that ID: ada"); assert_eq!(error_type, "invalid_request_error"); } _ => {panic!("Expected an error message, got {:?}", result)} } } #[tokio::test] async fn complete_string() -> crate::Result<()> { let client = get_client(); client.complete(api::Engine::Ada, "Hey there").await?; Ok(()) } #[tokio::test] async fn complete_explicit_params() -> crate::Result<()> { let client = get_client(); let args = api::CompletionArgsBuilder::default() .prompt("Once upon a time,") .max_tokens(10) .temperature(0.5) .top_p(0.5) .n(1) .logprobs(3) .echo(false) .stop(vec!["\n".into()]) .presence_penalty(0.5) .frequency_penalty(0.5) .logit_bias(maplit::hashmap! { "1".into() => 1.0, "23".into() => 0.0, }) .build() .expect("Build should have succeeded"); client.complete(api::Engine::Ada, args).await?; Ok(()) } }