Basic functionality implemented

This commit is contained in:
Josh Kuhn 2020-11-25 21:28:03 -08:00
parent dd10b363f0
commit d18d72b646
4 changed files with 1977 additions and 3 deletions

5
.gitignore vendored
View File

@ -1,10 +1,9 @@
# Generated by Cargo
# will have compiled files and executables
/target/
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock
# These are backup files generated by rustfmt
**/*.rs.bk
.vscode/
Notes.md

1463
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

19
Cargo.toml Normal file
View File

@ -0,0 +1,19 @@
[package]
name = "openai-api"
version = "0.1.0"
authors = ["Josh Kuhn <deontologician@gmail.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
reqwest = { version = "0.10.9", features = ["json"] }
thiserror = "^1.0.22"
serde = { version = "^1.0.117", features = ["derive"] }
tokio = { version = "^0.2.5", features = ["full"]}
serde_json = "*"
derive_builder = "0.9.0"
[dev-dependencies]
mockito = "0.28.0"
maplit = "1.0.2"

493
src/lib.rs Normal file
View File

@ -0,0 +1,493 @@
#[macro_use]
extern crate derive_builder;
use reqwest::header::HeaderMap;
use thiserror::Error;
type Result<T> = std::result::Result<T, OpenAIError>;
#[allow(clippy::clippy::default_trait_access)]
pub mod api {
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Container type. Used in the api, but not useful for clients of this library
///
/// Deserialized from a JSON object whose `data` field is an array of `T`.
/// Only visible to the parent module (`pub(super)`).
#[derive(Deserialize, Debug)]
pub(super) struct Container<T> {
/// The items carried by the response
pub data: Vec<T>,
}
/// Engine description type
///
/// Deserialized from the JSON document the API returns for a single engine.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct EngineInfo {
/// Which engine this entry describes
pub id: Engine,
/// Value of the `owner` field of the engine document
pub owner: String,
/// Value of the `ready` field of the engine document
pub ready: bool,
}
/// Engine types, known and unknown
///
/// Variants are (de)serialized under their kebab-case names unless a variant
/// carries an explicit `rename`.
#[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
#[serde(rename_all = "kebab-case")]
pub enum Engine {
/// The `ada` engine
Ada,
/// The `babbage` engine
Babbage,
/// The `curie` engine
Curie,
/// The `davinci` engine
Davinci,
/// The engine whose id is `content-filter-alpha-c4`
#[serde(rename = "content-filter-alpha-c4")]
ContentFilter,
/// Catch-all used when deserializing an id that matches no other variant.
/// The original id string is not retained.
#[serde(other)]
Other,
}
// Custom Display that writes the id string of each engine.
//
// The match is exhaustive on purpose: adding a variant to `Engine` must fail to
// compile here until it is given an id.
impl std::fmt::Display for Engine {
    /// Writes the engine id to `f`.
    ///
    /// `Engine::Other` does not retain the id it was parsed from, so it is written
    /// as the placeholder `other` rather than panicking: formatting a value (for
    /// example in a log line) must never abort the program.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id = match self {
            Engine::Ada => "ada",
            Engine::Babbage => "babbage",
            Engine::Curie => "curie",
            Engine::Davinci => "davinci",
            Engine::ContentFilter => "content-filter-alpha-c4",
            Engine::Other => "other",
        };
        f.write_str(id)
    }
}
/// Options that affect the result
///
/// This is the request body serialized for a completion call. Values are normally
/// assembled through the generated `CompletionArgsBuilder`; every field has a
/// default, so any subset of setters may be used.
#[derive(Serialize, Debug, Builder)]
pub struct CompletionArgs {
/// The prompt to complete. The setter accepts anything convertible into a
/// `String`. Defaults to `<|endoftext|>`.
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
prompt: String,
/// Sent as `max_tokens`. Defaults to 16.
#[builder(default = "16")]
max_tokens: u64,
/// Sent as `temperature`. Defaults to 1.0.
#[builder(default = "1.0")]
temperature: f64,
/// Sent as `top_p`. Defaults to 1.0.
#[builder(default = "1.0")]
top_p: f64,
/// Sent as `n`. Defaults to 1.
#[builder(default = "1")]
n: u64,
/// Sent as `logprobs`. The setter takes the bare number; defaults to `None`.
#[builder(setter(strip_option), default)]
logprobs: Option<u64>,
/// Sent as `echo`. Defaults to `false`.
#[builder(default = "false")]
echo: bool,
/// Sent as `stop`. The setter takes the bare list; defaults to `None`.
#[builder(setter(strip_option), default)]
stop: Option<Vec<String>>,
/// Sent as `presence_penalty`. Defaults to 0.0.
#[builder(default = "0.0")]
presence_penalty: f64,
/// Sent as `frequency_penalty`. Defaults to 0.0.
#[builder(default = "0.0")]
frequency_penalty: f64,
/// Sent as `logit_bias`. Defaults to an empty map.
#[builder(default)]
logit_bias: HashMap<String, f64>,
}
/* Request fields that are not supported yet:
{
"stream": false, // SSE streams back results
"best_of": Option<u64>, // can't be used together with stream
}
*/
// TODO: add validators for the different arguments
impl Default for CompletionArgs {
    /// Builds the arguments with every field left at its builder default.
    ///
    /// # Panics
    /// If the builder rejects its own defaults.
    fn default() -> Self {
        let built = CompletionArgsBuilder::default().build();
        built.expect("Client error, invalid defaults")
    }
}
impl From<&str> for CompletionArgs {
fn from(prompt_string: &str) -> Self {
Self {
prompt: prompt_string.into(),
..CompletionArgs::default()
}
}
}
/// Represents a non-streamed completion response
///
/// Fields are public so that callers can actually read the result of a
/// completion request.
#[derive(Deserialize, Debug)]
pub struct Completion {
    /// Value of the `id` field of the response
    pub id: String,
    /// Value of the `object` field of the response
    pub object: String,
    /// Value of the `created` field (presumably a unix timestamp; confirm against
    /// the API documentation)
    pub created: u64,
    /// Value of the `model` field of the response
    pub model: String,
    /// The completions that were generated
    pub choices: Vec<Choice>,
}
/// Represents a single choice
///
/// Fields are public so that callers can read the generated text.
#[derive(Deserialize, Debug)]
pub struct Choice {
    /// The generated text
    pub text: String,
    /// Value of the `index` field of this choice
    pub index: u64,
    /// Log probability details; `None` when the response carries none
    pub logprobs: Option<LogProbs>,
    /// Why generation stopped for this choice
    pub finish_reason: FinishReason,
}
/// Represents a logprobs subdocument
///
/// Fields are public so that callers can inspect the values.
#[derive(Deserialize, Debug)]
pub struct LogProbs {
    /// Value of the `tokens` array
    pub tokens: Vec<String>,
    /// Value of the `token_logprobs` array; entries may be null
    pub token_logprobs: Vec<Option<f64>>,
    /// Value of the `top_logprobs` array; entries may be null
    pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
    /// Value of the `text_offset` array
    pub text_offset: Vec<u64>,
}
/// Reason a choice stopped generating, as reported in `finish_reason`
///
/// Variants are deserialized from their kebab-case names.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "kebab-case")]
pub enum FinishReason {
/// Deserialized from `length`
Length,
/// Deserialized from `stop`
Stop,
}
/// Error details reported by the API
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ErrorMessage {
/// Value of the `message` field of the error document
pub message: String,
/// Value of the `type` field of the error document (renamed because `type`
/// is a Rust keyword)
#[serde(rename = "type")]
pub error_type: String,
}
/// Envelope of an error response: a JSON object with a single `error` member
#[derive(Deserialize, Debug)]
pub struct ErrorWrapper {
/// The error details
pub error: ErrorMessage,
}
}
#[derive(Error, Debug)]
pub enum OpenAIError {
#[error("Invalid secret key")]
InvalidAPIKey {
#[from]
source: reqwest::header::InvalidHeaderValue,
},
#[error("API Returned an Error document")]
APIError(api::ErrorMessage),
}
pub struct OpenAIClient {
client: reqwest::Client,
root: String,
}
impl OpenAIClient {
/// Creates a new `OpenAIClient`
///
/// # Errors
/// `OpenAIError::InvalidAPIKey` if the api token has invalid characters
pub fn new(token: &str) -> Result<Self> {
let mut headers = HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))?,
);
Ok(Self {
client: reqwest::Client::builder()
.default_headers(headers)
.build()
.expect("Client library error. Should have constructed a valid http client."),
root: "https://api.openai.com/v1".into(),
})
}
/// Private helper for making gets
async fn get<T: serde::de::DeserializeOwned>(&self, endpoint: &str) -> Result<T> {
let url = &format!("{}/{}", self.root, endpoint);
let response = self
.client
.get(url)
.send()
.await
.expect("Client error. Should have passed a valid url");
if response.status() != 200 {
return Err(OpenAIError::APIError(
response
.json::<api::ErrorWrapper>()
.await
.expect("The API has returned something funky")
.error,
));
}
Ok(response.json::<T>().await.unwrap())
}
/// Lists the currently available engines.
///
/// Provides basic information about each one such as the owner and availability.
///
/// # Errors
/// - `OpenAIError::APIError` if the server returns an error
/// - `OpenAIError::ServerFormatError` if the json response wasn't parseable (most
/// likely a bug in this client, please report it)
pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
self.get("engines").await.map(|r: api::Container<_>| r.data)
}
/// Retrieves an engine instance
/// Provides basic information about the engine such as the owner and availability.
///
/// # Errors
/// - `OpenAIError::APIError` if the server returns an error
/// - `OpenAIError::ServerFormatError` if the json response wasn't parseable (most
/// likely a bug in this client, please report it)
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
self.get(&format!("engines/{}", engine)).await
}
// Private helper to generate post requests. Needs to be a bit more flexible than
// get because it should support SSE eventually
async fn post<B: serde::ser::Serialize>(
&self,
endpoint: &str,
body: B,
) -> Result<reqwest::Response> {
let url = &format!("{}/{}", self.root, endpoint);
let response = self
.client
.post(url)
.json(&body)
.send()
.await
.expect("Client library error, json failed to parse");
if response.status() != 200 {
return Err(OpenAIError::APIError(
response
.json::<api::ErrorWrapper>()
.await
.expect("The API has returned something funky")
.error,
));
}
Ok(response)
}
/// Get predicted completion of the prompt
///
/// # Errors
/// - `OpenAIError::APIError` if the api returns an error
pub async fn complete(
&self,
engine: api::Engine,
prompt: impl Into<api::CompletionArgs>,
) -> Result<api::Completion> {
Ok(self
.post(&format!("engines/{}/completions", engine), prompt.into())
.await?
//.text()
.json()
.await
.expect("Client error. JSON didn't parse correctly."))
}
}
#[cfg(test)]
mod unit {
use crate::{api, OpenAIClient, OpenAIError};
fn mocked_client() -> OpenAIClient {
let mut client = OpenAIClient::new("bogus").unwrap();
client.root = mockito::server_url();
client
}
/// Constructing a mocked client must not panic
#[test]
fn can_create_client() {
    let _client = mocked_client();
}
/// A single engine document deserializes into the matching `EngineInfo`
#[test]
fn parse_engine_info() -> Result<(), Box<dyn std::error::Error>> {
    let example = r#"{
"id": "ada",
"object": "engine",
"owner": "openai",
"ready": true
}"#;
    let expected = api::EngineInfo {
        id: api::Engine::Ada,
        owner: "openai".into(),
        ready: true,
    };
    let parsed = serde_json::from_str::<api::EngineInfo>(example)?;
    assert_eq!(parsed, expected);
    Ok(())
}
/// The engine list is unwrapped from its container, and an id that is not a
/// known engine comes back as `Engine::Other`
#[tokio::test]
async fn parse_engines() -> crate::Result<()> {
    use api::{Engine, EngineInfo};
    let body = r#"{
"object": "list",
"data": [
{
"id": "ada",
"object": "engine",
"owner": "openai",
"ready": true
},
{
"id": "babbage",
"object": "engine",
"owner": "openai",
"ready": true
},
{
"id": "experimental-engine-v7",
"object": "engine",
"owner": "openai",
"ready": false
},
{
"id": "curie",
"object": "engine",
"owner": "openai",
"ready": true
},
{
"id": "davinci",
"object": "engine",
"owner": "openai",
"ready": true
},
{
"id": "content-filter-alpha-c4",
"object": "engine",
"owner": "openai",
"ready": true
}
]
}"#;
    // The guard must stay alive until the request has been made.
    let _mock = mockito::mock("GET", "/engines")
        .with_status(200)
        .with_header("content-type", "text/json")
        .with_body(body)
        .create();
    // Every entry in the fixture shares the same owner.
    let info = |id: Engine, ready: bool| EngineInfo {
        id,
        owner: "openai".into(),
        ready,
    };
    let expected = vec![
        info(Engine::Ada, true),
        info(Engine::Babbage, true),
        info(Engine::Other, false),
        info(Engine::Curie, true),
        info(Engine::Davinci, true),
        info(Engine::ContentFilter, true),
    ];
    let engines = mocked_client().engines().await?;
    assert_eq!(engines, expected);
    Ok(())
}
/// A non-200 answer carrying an error document surfaces as `OpenAIError::APIError`
/// holding the parsed message.
#[tokio::test]
async fn engine_error_response() -> crate::Result<()> {
    let _m = mockito::mock("GET", "/engines/davinci")
        .with_status(404)
        .with_header("content-type", "text/json")
        .with_body(
            r#"{
"error": {
"code": null,
"message": "Some kind of error happened",
"type": "some_error_type"
}
}"#,
        )
        .create();
    let expected = api::ErrorMessage {
        message: "Some kind of error happened".into(),
        error_type: "some_error_type".into(),
    };
    let response = mocked_client().engine(api::Engine::Davinci).await;
    // Any outcome other than the expected error must fail the test; an `if let`
    // here would silently pass on `Ok` or on a different error.
    match response {
        Err(OpenAIError::APIError(msg)) => assert_eq!(expected, msg),
        other => panic!("Expected an APIError, got {:?}", other),
    }
    Ok(())
}
}
#[cfg(test)]
mod integration {
use api::ErrorMessage;
use crate::{OpenAIClient, OpenAIError, api};
/// Used by tests to get a client to the actual api
///
/// # Panics
/// If the `OPENAI_SK` environment variable is not set, or the client cannot be
/// built from its value.
fn get_client() -> OpenAIClient {
    let sk = std::env::var("OPENAI_SK").expect(
        "To run integration tests, you must set the OPENAI_SK env var to your api token",
    );
    OpenAIClient::new(&sk).expect("client build failed")
}
/// Listing engines against the real api succeeds
#[tokio::test]
async fn can_get_engines() {
    let listing = get_client().engines().await;
    listing.unwrap();
}
/// Requests a single engine from the real api.
///
/// NOTE(review): despite its name, this test expects the api to answer with an
/// "invalid_request_error"; confirm that is still the intended behavior.
#[tokio::test]
async fn can_get_engine() {
    let client = get_client();
    let outcome = client.engine(api::Engine::Ada).await;
    if let Err(OpenAIError::APIError(ErrorMessage { message, error_type })) = outcome {
        assert_eq!(message, "No engine with that ID: ada");
        assert_eq!(error_type, "invalid_request_error");
    } else {
        panic!("Expected an error message, got {:?}", outcome)
    }
}
/// A bare string is accepted as the completion arguments
#[tokio::test]
async fn complete_string() -> crate::Result<()> {
    let client = get_client();
    let _completion = client.complete(api::Engine::Ada, "Hey there").await?;
    Ok(())
}
/// A completion request with every builder option set explicitly is accepted
#[tokio::test]
async fn complete_explicit_params() -> crate::Result<()> {
    let client = get_client();
    let stop_sequences: Vec<String> = vec!["\n".into()];
    let bias: std::collections::HashMap<String, f64> = maplit::hashmap! {
        "1".into() => 1.0,
        "23".into() => 0.0,
    };
    let args = api::CompletionArgsBuilder::default()
        .prompt("Once upon a time,")
        .max_tokens(10)
        .temperature(0.5)
        .top_p(0.5)
        .n(1)
        .logprobs(3)
        .echo(false)
        .stop(stop_sequences)
        .presence_penalty(0.5)
        .frequency_penalty(0.5)
        .logit_bias(bias)
        .build()
        .expect("Build should have succeeded");
    let _completion = client.complete(api::Engine::Ada, args).await?;
    Ok(())
}
}