Make this a better library (#7)
Some additions to make this more useful as a library - Move testing only dependencies into dev-dependencies - The library itself now uses surf instead of reqwest, which is more flexible over runtimes - Add some crate features to allow configuring which backend to use - Add logging for requests and responses
This commit is contained in:
parent
697acd0a0b
commit
94d633e6f2
16
Cargo.toml
16
Cargo.toml
@ -10,17 +10,23 @@ repository = "https://github.com/deontologician/openai-api-rust/"
|
||||
keywords = ["openai", "gpt3"]
|
||||
categories = ["api-bindings", "asynchronous"]
|
||||
|
||||
[features]
|
||||
default = ["hyper"]
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
hyper = ["surf/hyper-client"]
|
||||
curl = ["surf/curl-client"]
|
||||
h1 = ["surf/h1-client"]
|
||||
|
||||
[dependencies]
|
||||
reqwest = { version = "0.10.9", features = ["json"] }
|
||||
surf = { version = "^2.1.0", default-features = false }
|
||||
thiserror = "^1.0.22"
|
||||
serde = { version = "^1.0.117", features = ["derive"] }
|
||||
tokio = { version = "^0.2.5", features = ["full"]}
|
||||
serde_json = "^1.0"
|
||||
derive_builder = "0.9.0"
|
||||
derive_builder = "^0.9.0"
|
||||
log = "^0.4.11"
|
||||
|
||||
[dev-dependencies]
|
||||
mockito = "0.28.0"
|
||||
maplit = "1.0.2"
|
||||
tokio = { version = "^0.2.5", features = ["full"]}
|
||||
serde_json = "^1.0"
|
||||
env_logger = "0.8.2"
|
||||
|
147
src/lib.rs
147
src/lib.rs
@ -1,8 +1,7 @@
|
||||
///! OpenAI API client library
|
||||
///! `OpenAI` API client library
|
||||
#[macro_use]
|
||||
extern crate derive_builder;
|
||||
|
||||
use reqwest::header::HeaderMap;
|
||||
use thiserror::Error;
|
||||
|
||||
type Result<T> = std::result::Result<T, OpenAIError>;
|
||||
@ -215,6 +214,9 @@ pub enum OpenAIError {
|
||||
/// An error the client discovers before talking to the API
|
||||
#[error("Bad arguments")]
|
||||
BadArguments(String),
|
||||
/// Network / protocol related errors
|
||||
#[error("Error at the protocol level")]
|
||||
ProtocolError(surf::Error),
|
||||
}
|
||||
|
||||
impl From<api::ErrorMessage> for OpenAIError {
|
||||
@ -229,49 +231,86 @@ impl From<String> for OpenAIError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<surf::Error> for OpenAIError {
|
||||
fn from(e: surf::Error) -> Self {
|
||||
OpenAIError::ProtocolError(e)
|
||||
}
|
||||
}
|
||||
|
||||
/// Client object. Must be constructed to talk to the API.
|
||||
pub struct OpenAIClient {
|
||||
client: reqwest::Client,
|
||||
root: String,
|
||||
client: surf::Client,
|
||||
}
|
||||
|
||||
/// Authentication middleware
|
||||
struct BearerToken {
|
||||
token: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for BearerToken {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Get the first few characters to help debug, but not accidentally log key
|
||||
write!(
|
||||
f,
|
||||
r#"Bearer {{ token: "{}" }}"#,
|
||||
self.token.get(0..8).ok_or(std::fmt::Error)?
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl BearerToken {
|
||||
fn new(token: &str) -> Self {
|
||||
Self {
|
||||
token: String::from(token),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[surf::utils::async_trait]
|
||||
impl surf::middleware::Middleware for BearerToken {
|
||||
async fn handle(
|
||||
&self,
|
||||
mut req: surf::Request,
|
||||
client: surf::Client,
|
||||
next: surf::middleware::Next<'_>,
|
||||
) -> surf::Result<surf::Response> {
|
||||
log::debug!("Request: {:?}", req);
|
||||
req.insert_header("Authorization", format!("Bearer {}", self.token));
|
||||
let response = next.run(req, client).await?;
|
||||
log::debug!("Response: {:?}", response);
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
/// Creates a new `OpenAIClient` given an api token
|
||||
#[must_use]
|
||||
pub fn new(token: &str) -> Self {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
|
||||
.expect("Client library error. Header value badly formatted"),
|
||||
let mut client = surf::client();
|
||||
client.set_base_url(
|
||||
surf::Url::parse("https://api.openai.com/v1/").expect("Static string should parse"),
|
||||
);
|
||||
Self {
|
||||
client: reqwest::Client::builder()
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.expect("Client library error. Should have constructed a valid http client."),
|
||||
root: "https://api.openai.com/v1".into(),
|
||||
}
|
||||
client = client.with(BearerToken::new(token));
|
||||
Self { client }
|
||||
}
|
||||
|
||||
/// Allow setting the api root in the tests
|
||||
#[cfg(test)]
|
||||
fn set_api_root(&mut self, url: surf::Url) {
|
||||
self.client.set_base_url(url);
|
||||
}
|
||||
|
||||
/// Private helper for making gets
|
||||
async fn get<T: serde::de::DeserializeOwned>(&self, endpoint: &str) -> Result<T> {
|
||||
let url = &format!("{}/{}", self.root, endpoint);
|
||||
let response = self
|
||||
.client
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.expect("Client error. Should have passed a valid url");
|
||||
if response.status() != 200 {
|
||||
return Err(OpenAIError::APIError(
|
||||
response
|
||||
.json::<api::ErrorWrapper>()
|
||||
.await
|
||||
.expect("The API has returned something funky")
|
||||
.error,
|
||||
));
|
||||
let mut response = self.client.get(endpoint).await?;
|
||||
if let surf::StatusCode::Ok = response.status() {
|
||||
Ok(response.body_json::<T>().await?)
|
||||
} else {
|
||||
{
|
||||
let err = response.body_json::<api::ErrorWrapper>().await?.error;
|
||||
Err(OpenAIError::APIError(err))
|
||||
}
|
||||
}
|
||||
Ok(response.json::<T>().await.unwrap())
|
||||
}
|
||||
|
||||
/// Lists the currently available engines.
|
||||
@ -295,29 +334,26 @@ impl OpenAIClient {
|
||||
|
||||
// Private helper to generate post requests. Needs to be a bit more flexible than
|
||||
// get because it should support SSE eventually
|
||||
async fn post<B: serde::ser::Serialize>(
|
||||
&self,
|
||||
endpoint: &str,
|
||||
body: B,
|
||||
) -> Result<reqwest::Response> {
|
||||
let url = &format!("{}/{}", self.root, endpoint);
|
||||
let response = self
|
||||
async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
|
||||
where
|
||||
B: serde::ser::Serialize,
|
||||
R: serde::de::DeserializeOwned,
|
||||
{
|
||||
let mut response = self
|
||||
.client
|
||||
.post(url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.expect("Client library error, json failed to parse");
|
||||
if response.status() != 200 {
|
||||
return Err(OpenAIError::APIError(
|
||||
.post(endpoint)
|
||||
.body(surf::Body::from_json(&body)?)
|
||||
.await?;
|
||||
match response.status() {
|
||||
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
|
||||
_ => Err(OpenAIError::APIError(
|
||||
response
|
||||
.json::<api::ErrorWrapper>()
|
||||
.body_json::<api::ErrorWrapper>()
|
||||
.await
|
||||
.expect("The API has returned something funky")
|
||||
.error,
|
||||
));
|
||||
)),
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
/// Get predicted completion of the prompt
|
||||
///
|
||||
@ -330,11 +366,7 @@ impl OpenAIClient {
|
||||
let args = prompt.into();
|
||||
Ok(self
|
||||
.post(&format!("engines/{}/completions", args.engine), args)
|
||||
.await?
|
||||
//.text()
|
||||
.json()
|
||||
.await
|
||||
.expect("Client error. JSON didn't parse correctly."))
|
||||
.await?)
|
||||
}
|
||||
}
|
||||
|
||||
@ -344,8 +376,11 @@ mod unit {
|
||||
use crate::{api, OpenAIClient, OpenAIError};
|
||||
|
||||
fn mocked_client() -> OpenAIClient {
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
let mut client = OpenAIClient::new("bogus");
|
||||
client.root = mockito::server_url();
|
||||
client.set_api_root(
|
||||
surf::Url::parse(&mockito::server_url()).expect("mockito url didn't parse"),
|
||||
);
|
||||
client
|
||||
}
|
||||
|
||||
@ -493,9 +528,9 @@ mod integration {
|
||||
use api::ErrorMessage;
|
||||
|
||||
use crate::{api, OpenAIClient, OpenAIError};
|
||||
|
||||
/// Used by tests to get a client to the actual api
|
||||
fn get_client() -> OpenAIClient {
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
let sk = std::env::var("OPENAI_SK").expect(
|
||||
"To run integration tests, you must put set the OPENAI_SK env var to your api token",
|
||||
);
|
||||
|
Loading…
x
Reference in New Issue
Block a user