Make this a better library (#7)

Some additions to make this more useful as a library
- Move testing only dependencies into dev-dependencies
- The library itself now uses surf instead of reqwest, which is more flexible over runtimes
- Add some crate features to allow configuring which backend to use
- Add logging for requests and responses
This commit is contained in:
Josh Kuhn 2020-12-09 13:52:14 -08:00 committed by GitHub
parent 697acd0a0b
commit 94d633e6f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 102 additions and 61 deletions

View File

@ -10,17 +10,23 @@ repository = "https://github.com/deontologician/openai-api-rust/"
keywords = ["openai", "gpt3"]
categories = ["api-bindings", "asynchronous"]
[features]
default = ["hyper"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
hyper = ["surf/hyper-client"]
curl = ["surf/curl-client"]
h1 = ["surf/h1-client"]
[dependencies]
reqwest = { version = "0.10.9", features = ["json"] }
surf = { version = "^2.1.0", default-features = false }
thiserror = "^1.0.22"
serde = { version = "^1.0.117", features = ["derive"] }
tokio = { version = "^0.2.5", features = ["full"]}
serde_json = "^1.0"
derive_builder = "0.9.0"
derive_builder = "^0.9.0"
log = "^0.4.11"
[dev-dependencies]
mockito = "0.28.0"
maplit = "1.0.2"
tokio = { version = "^0.2.5", features = ["full"]}
serde_json = "^1.0"
env_logger = "0.8.2"

View File

@ -1,8 +1,7 @@
///! OpenAI API client library
///! `OpenAI` API client library
#[macro_use]
extern crate derive_builder;
use reqwest::header::HeaderMap;
use thiserror::Error;
type Result<T> = std::result::Result<T, OpenAIError>;
@ -215,6 +214,9 @@ pub enum OpenAIError {
/// An error the client discovers before talking to the API
#[error("Bad arguments")]
BadArguments(String),
/// Network / protocol related errors
#[error("Error at the protocol level")]
ProtocolError(surf::Error),
}
impl From<api::ErrorMessage> for OpenAIError {
@ -229,49 +231,86 @@ impl From<String> for OpenAIError {
}
}
impl From<surf::Error> for OpenAIError {
fn from(e: surf::Error) -> Self {
OpenAIError::ProtocolError(e)
}
}
/// Client object. Must be constructed to talk to the API.
pub struct OpenAIClient {
client: reqwest::Client,
root: String,
client: surf::Client,
}
/// Authentication middleware
struct BearerToken {
token: String,
}
impl std::fmt::Debug for BearerToken {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Get the first few characters to help debug, but not accidentally log key
write!(
f,
r#"Bearer {{ token: "{}" }}"#,
self.token.get(0..8).ok_or(std::fmt::Error)?
)
}
}
impl BearerToken {
fn new(token: &str) -> Self {
Self {
token: String::from(token),
}
}
}
#[surf::utils::async_trait]
impl surf::middleware::Middleware for BearerToken {
async fn handle(
&self,
mut req: surf::Request,
client: surf::Client,
next: surf::middleware::Next<'_>,
) -> surf::Result<surf::Response> {
log::debug!("Request: {:?}", req);
req.insert_header("Authorization", format!("Bearer {}", self.token));
let response = next.run(req, client).await?;
log::debug!("Response: {:?}", response);
Ok(response)
}
}
impl OpenAIClient {
/// Creates a new `OpenAIClient` given an api token
#[must_use]
pub fn new(token: &str) -> Self {
let mut headers = HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
.expect("Client library error. Header value badly formatted"),
let mut client = surf::client();
client.set_base_url(
surf::Url::parse("https://api.openai.com/v1/").expect("Static string should parse"),
);
Self {
client: reqwest::Client::builder()
.default_headers(headers)
.build()
.expect("Client library error. Should have constructed a valid http client."),
root: "https://api.openai.com/v1".into(),
}
client = client.with(BearerToken::new(token));
Self { client }
}
/// Allow setting the api root in the tests
#[cfg(test)]
fn set_api_root(&mut self, url: surf::Url) {
self.client.set_base_url(url);
}
/// Private helper for making gets
async fn get<T: serde::de::DeserializeOwned>(&self, endpoint: &str) -> Result<T> {
let url = &format!("{}/{}", self.root, endpoint);
let response = self
.client
.get(url)
.send()
.await
.expect("Client error. Should have passed a valid url");
if response.status() != 200 {
return Err(OpenAIError::APIError(
response
.json::<api::ErrorWrapper>()
.await
.expect("The API has returned something funky")
.error,
));
let mut response = self.client.get(endpoint).await?;
if let surf::StatusCode::Ok = response.status() {
Ok(response.body_json::<T>().await?)
} else {
{
let err = response.body_json::<api::ErrorWrapper>().await?.error;
Err(OpenAIError::APIError(err))
}
}
Ok(response.json::<T>().await.unwrap())
}
/// Lists the currently available engines.
@ -295,29 +334,26 @@ impl OpenAIClient {
// Private helper to generate post requests. Needs to be a bit more flexible than
// get because it should support SSE eventually
async fn post<B: serde::ser::Serialize>(
&self,
endpoint: &str,
body: B,
) -> Result<reqwest::Response> {
let url = &format!("{}/{}", self.root, endpoint);
let response = self
async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
where
B: serde::ser::Serialize,
R: serde::de::DeserializeOwned,
{
let mut response = self
.client
.post(url)
.json(&body)
.send()
.await
.expect("Client library error, json failed to parse");
if response.status() != 200 {
return Err(OpenAIError::APIError(
.post(endpoint)
.body(surf::Body::from_json(&body)?)
.await?;
match response.status() {
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
_ => Err(OpenAIError::APIError(
response
.json::<api::ErrorWrapper>()
.body_json::<api::ErrorWrapper>()
.await
.expect("The API has returned something funky")
.error,
));
)),
}
Ok(response)
}
/// Get predicted completion of the prompt
///
@ -330,11 +366,7 @@ impl OpenAIClient {
let args = prompt.into();
Ok(self
.post(&format!("engines/{}/completions", args.engine), args)
.await?
//.text()
.json()
.await
.expect("Client error. JSON didn't parse correctly."))
.await?)
}
}
@ -344,8 +376,11 @@ mod unit {
use crate::{api, OpenAIClient, OpenAIError};
fn mocked_client() -> OpenAIClient {
let _ = env_logger::builder().is_test(true).try_init();
let mut client = OpenAIClient::new("bogus");
client.root = mockito::server_url();
client.set_api_root(
surf::Url::parse(&mockito::server_url()).expect("mockito url didn't parse"),
);
client
}
@ -493,9 +528,9 @@ mod integration {
use api::ErrorMessage;
use crate::{api, OpenAIClient, OpenAIError};
/// Used by tests to get a client to the actual api
fn get_client() -> OpenAIClient {
let _ = env_logger::builder().is_test(true).try_init();
let sk = std::env::var("OPENAI_SK").expect(
"To run integration tests, you must put set the OPENAI_SK env var to your api token",
);