Make this a better library (#7)

Some additions to make this more useful as a library
- Move testing only dependencies into dev-dependencies
- The library itself now uses surf instead of reqwest, which is more flexible over runtimes
- Add some crate features to allow configuring which backend to use
- Add logging for requests and responses
This commit is contained in:
Josh Kuhn 2020-12-09 13:52:14 -08:00 committed by GitHub
parent 697acd0a0b
commit 94d633e6f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 102 additions and 61 deletions

View File

@ -10,17 +10,23 @@ repository = "https://github.com/deontologician/openai-api-rust/"
keywords = ["openai", "gpt3"] keywords = ["openai", "gpt3"]
categories = ["api-bindings", "asynchronous"] categories = ["api-bindings", "asynchronous"]
[features]
default = ["hyper"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html hyper = ["surf/hyper-client"]
curl = ["surf/curl-client"]
h1 = ["surf/h1-client"]
[dependencies] [dependencies]
reqwest = { version = "0.10.9", features = ["json"] } surf = { version = "^2.1.0", default-features = false }
thiserror = "^1.0.22" thiserror = "^1.0.22"
serde = { version = "^1.0.117", features = ["derive"] } serde = { version = "^1.0.117", features = ["derive"] }
tokio = { version = "^0.2.5", features = ["full"]} derive_builder = "^0.9.0"
serde_json = "^1.0" log = "^0.4.11"
derive_builder = "0.9.0"
[dev-dependencies] [dev-dependencies]
mockito = "0.28.0" mockito = "0.28.0"
maplit = "1.0.2" maplit = "1.0.2"
tokio = { version = "^0.2.5", features = ["full"]}
serde_json = "^1.0"
env_logger = "0.8.2"

View File

@ -1,8 +1,7 @@
///! OpenAI API client library ///! `OpenAI` API client library
#[macro_use] #[macro_use]
extern crate derive_builder; extern crate derive_builder;
use reqwest::header::HeaderMap;
use thiserror::Error; use thiserror::Error;
type Result<T> = std::result::Result<T, OpenAIError>; type Result<T> = std::result::Result<T, OpenAIError>;
@ -215,6 +214,9 @@ pub enum OpenAIError {
/// An error the client discovers before talking to the API /// An error the client discovers before talking to the API
#[error("Bad arguments")] #[error("Bad arguments")]
BadArguments(String), BadArguments(String),
/// Network / protocol related errors
#[error("Error at the protocol level")]
ProtocolError(surf::Error),
} }
impl From<api::ErrorMessage> for OpenAIError { impl From<api::ErrorMessage> for OpenAIError {
@ -229,49 +231,86 @@ impl From<String> for OpenAIError {
} }
} }
impl From<surf::Error> for OpenAIError {
fn from(e: surf::Error) -> Self {
OpenAIError::ProtocolError(e)
}
}
/// Client object. Must be constructed to talk to the API. /// Client object. Must be constructed to talk to the API.
pub struct OpenAIClient { pub struct OpenAIClient {
client: reqwest::Client, client: surf::Client,
root: String, }
/// Authentication middleware
struct BearerToken {
token: String,
}
impl std::fmt::Debug for BearerToken {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Get the first few characters to help debug, but not accidentally log key
write!(
f,
r#"Bearer {{ token: "{}" }}"#,
self.token.get(0..8).ok_or(std::fmt::Error)?
)
}
}
impl BearerToken {
fn new(token: &str) -> Self {
Self {
token: String::from(token),
}
}
}
#[surf::utils::async_trait]
impl surf::middleware::Middleware for BearerToken {
async fn handle(
&self,
mut req: surf::Request,
client: surf::Client,
next: surf::middleware::Next<'_>,
) -> surf::Result<surf::Response> {
log::debug!("Request: {:?}", req);
req.insert_header("Authorization", format!("Bearer {}", self.token));
let response = next.run(req, client).await?;
log::debug!("Response: {:?}", response);
Ok(response)
}
} }
impl OpenAIClient { impl OpenAIClient {
/// Creates a new `OpenAIClient` given an api token /// Creates a new `OpenAIClient` given an api token
#[must_use]
pub fn new(token: &str) -> Self { pub fn new(token: &str) -> Self {
let mut headers = HeaderMap::new(); let mut client = surf::client();
headers.insert( client.set_base_url(
reqwest::header::AUTHORIZATION, surf::Url::parse("https://api.openai.com/v1/").expect("Static string should parse"),
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
.expect("Client library error. Header value badly formatted"),
); );
Self { client = client.with(BearerToken::new(token));
client: reqwest::Client::builder() Self { client }
.default_headers(headers) }
.build()
.expect("Client library error. Should have constructed a valid http client."), /// Allow setting the api root in the tests
root: "https://api.openai.com/v1".into(), #[cfg(test)]
} fn set_api_root(&mut self, url: surf::Url) {
self.client.set_base_url(url);
} }
/// Private helper for making gets /// Private helper for making gets
async fn get<T: serde::de::DeserializeOwned>(&self, endpoint: &str) -> Result<T> { async fn get<T: serde::de::DeserializeOwned>(&self, endpoint: &str) -> Result<T> {
let url = &format!("{}/{}", self.root, endpoint); let mut response = self.client.get(endpoint).await?;
let response = self if let surf::StatusCode::Ok = response.status() {
.client Ok(response.body_json::<T>().await?)
.get(url) } else {
.send() {
.await let err = response.body_json::<api::ErrorWrapper>().await?.error;
.expect("Client error. Should have passed a valid url"); Err(OpenAIError::APIError(err))
if response.status() != 200 { }
return Err(OpenAIError::APIError(
response
.json::<api::ErrorWrapper>()
.await
.expect("The API has returned something funky")
.error,
));
} }
Ok(response.json::<T>().await.unwrap())
} }
/// Lists the currently available engines. /// Lists the currently available engines.
@ -295,29 +334,26 @@ impl OpenAIClient {
// Private helper to generate post requests. Needs to be a bit more flexible than // Private helper to generate post requests. Needs to be a bit more flexible than
// get because it should support SSE eventually // get because it should support SSE eventually
async fn post<B: serde::ser::Serialize>( async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
&self, where
endpoint: &str, B: serde::ser::Serialize,
body: B, R: serde::de::DeserializeOwned,
) -> Result<reqwest::Response> { {
let url = &format!("{}/{}", self.root, endpoint); let mut response = self
let response = self
.client .client
.post(url) .post(endpoint)
.json(&body) .body(surf::Body::from_json(&body)?)
.send() .await?;
.await match response.status() {
.expect("Client library error, json failed to parse"); surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
if response.status() != 200 { _ => Err(OpenAIError::APIError(
return Err(OpenAIError::APIError(
response response
.json::<api::ErrorWrapper>() .body_json::<api::ErrorWrapper>()
.await .await
.expect("The API has returned something funky") .expect("The API has returned something funky")
.error, .error,
)); )),
} }
Ok(response)
} }
/// Get predicted completion of the prompt /// Get predicted completion of the prompt
/// ///
@ -330,11 +366,7 @@ impl OpenAIClient {
let args = prompt.into(); let args = prompt.into();
Ok(self Ok(self
.post(&format!("engines/{}/completions", args.engine), args) .post(&format!("engines/{}/completions", args.engine), args)
.await? .await?)
//.text()
.json()
.await
.expect("Client error. JSON didn't parse correctly."))
} }
} }
@ -344,8 +376,11 @@ mod unit {
use crate::{api, OpenAIClient, OpenAIError}; use crate::{api, OpenAIClient, OpenAIError};
fn mocked_client() -> OpenAIClient { fn mocked_client() -> OpenAIClient {
let _ = env_logger::builder().is_test(true).try_init();
let mut client = OpenAIClient::new("bogus"); let mut client = OpenAIClient::new("bogus");
client.root = mockito::server_url(); client.set_api_root(
surf::Url::parse(&mockito::server_url()).expect("mockito url didn't parse"),
);
client client
} }
@ -493,9 +528,9 @@ mod integration {
use api::ErrorMessage; use api::ErrorMessage;
use crate::{api, OpenAIClient, OpenAIError}; use crate::{api, OpenAIClient, OpenAIError};
/// Used by tests to get a client to the actual api /// Used by tests to get a client to the actual api
fn get_client() -> OpenAIClient { fn get_client() -> OpenAIClient {
let _ = env_logger::builder().is_test(true).try_init();
let sk = std::env::var("OPENAI_SK").expect( let sk = std::env::var("OPENAI_SK").expect(
"To run integration tests, you must put set the OPENAI_SK env var to your api token", "To run integration tests, you must put set the OPENAI_SK env var to your api token",
); );