diff --git a/Cargo.toml b/Cargo.toml index 094debe..0c20ac4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,17 +10,23 @@ repository = "https://github.com/deontologician/openai-api-rust/" keywords = ["openai", "gpt3"] categories = ["api-bindings", "asynchronous"] +[features] +default = ["hyper"] -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +hyper = ["surf/hyper-client"] +curl = ["surf/curl-client"] +h1 = ["surf/h1-client"] [dependencies] -reqwest = { version = "0.10.9", features = ["json"] } +surf = { version = "^2.1.0", default-features = false } thiserror = "^1.0.22" serde = { version = "^1.0.117", features = ["derive"] } -tokio = { version = "^0.2.5", features = ["full"]} -serde_json = "^1.0" -derive_builder = "0.9.0" +derive_builder = "^0.9.0" +log = "^0.4.11" [dev-dependencies] mockito = "0.28.0" maplit = "1.0.2" +tokio = { version = "^0.2.5", features = ["full"]} +serde_json = "^1.0" +env_logger = "0.8.2" diff --git a/src/lib.rs b/src/lib.rs index cf0c843..690c48f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,7 @@ -///! OpenAI API client library +///! `OpenAI` API client library #[macro_use] extern crate derive_builder; -use reqwest::header::HeaderMap; use thiserror::Error; type Result = std::result::Result; @@ -215,6 +214,9 @@ pub enum OpenAIError { /// An error the client discovers before talking to the API #[error("Bad arguments")] BadArguments(String), + /// Network / protocol related errors + #[error("Error at the protocol level")] + ProtocolError(surf::Error), } impl From for OpenAIError { @@ -229,49 +231,86 @@ impl From for OpenAIError { } } +impl From for OpenAIError { + fn from(e: surf::Error) -> Self { + OpenAIError::ProtocolError(e) + } +} + /// Client object. Must be constructed to talk to the API. pub struct OpenAIClient { - client: reqwest::Client, - root: String, + client: surf::Client, +} + +/// Authentication middleware +struct BearerToken { + token: String, +} + +impl std::fmt::Debug for BearerToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Get the first few characters to help debug, but not accidentally log key + write!( + f, + r#"Bearer {{ token: "{}" }}"#, + self.token.get(0..8).ok_or(std::fmt::Error)? + ) + } +} + +impl BearerToken { + fn new(token: &str) -> Self { + Self { + token: String::from(token), + } + } +} + +#[surf::utils::async_trait] +impl surf::middleware::Middleware for BearerToken { + async fn handle( + &self, + mut req: surf::Request, + client: surf::Client, + next: surf::middleware::Next<'_>, + ) -> surf::Result { + log::debug!("Request: {:?}", req); + req.insert_header("Authorization", format!("Bearer {}", self.token)); + let response = next.run(req, client).await?; + log::debug!("Response: {:?}", response); + Ok(response) + } } impl OpenAIClient { /// Creates a new `OpenAIClient` given an api token + #[must_use] pub fn new(token: &str) -> Self { - let mut headers = HeaderMap::new(); - headers.insert( - reqwest::header::AUTHORIZATION, - reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token)) - .expect("Client library error. Header value badly formatted"), + let mut client = surf::client(); + client.set_base_url( + surf::Url::parse("https://api.openai.com/v1/").expect("Static string should parse"), ); - Self { - client: reqwest::Client::builder() - .default_headers(headers) - .build() - .expect("Client library error. Should have constructed a valid http client."), - root: "https://api.openai.com/v1".into(), - } + client = client.with(BearerToken::new(token)); + Self { client } + } + + /// Allow setting the api root in the tests + #[cfg(test)] + fn set_api_root(&mut self, url: surf::Url) { + self.client.set_base_url(url); } /// Private helper for making gets async fn get(&self, endpoint: &str) -> Result { - let url = &format!("{}/{}", self.root, endpoint); - let response = self - .client - .get(url) - .send() - .await - .expect("Client error. Should have passed a valid url"); - if response.status() != 200 { - return Err(OpenAIError::APIError( - response - .json::() - .await - .expect("The API has returned something funky") - .error, - )); + let mut response = self.client.get(endpoint).await?; + if let surf::StatusCode::Ok = response.status() { + Ok(response.body_json::().await?) + } else { + { + let err = response.body_json::().await?.error; + Err(OpenAIError::APIError(err)) + } } - Ok(response.json::().await.unwrap()) } /// Lists the currently available engines. @@ -295,29 +334,26 @@ impl OpenAIClient { // Private helper to generate post requests. Needs to be a bit more flexible than // get because it should support SSE eventually - async fn post( - &self, - endpoint: &str, - body: B, - ) -> Result { - let url = &format!("{}/{}", self.root, endpoint); - let response = self + async fn post(&self, endpoint: &str, body: B) -> Result + where + B: serde::ser::Serialize, + R: serde::de::DeserializeOwned, + { + let mut response = self .client - .post(url) - .json(&body) - .send() - .await - .expect("Client library error, json failed to parse"); - if response.status() != 200 { - return Err(OpenAIError::APIError( + .post(endpoint) + .body(surf::Body::from_json(&body)?) + .await?; + match response.status() { + surf::StatusCode::Ok => Ok(response.body_json::().await?), + _ => Err(OpenAIError::APIError( response - .json::() + .body_json::() .await .expect("The API has returned something funky") .error, - )); + )), } - Ok(response) } /// Get predicted completion of the prompt /// @@ -330,11 +366,7 @@ impl OpenAIClient { let args = prompt.into(); Ok(self .post(&format!("engines/{}/completions", args.engine), args) - .await? - //.text() - .json() - .await - .expect("Client error. JSON didn't parse correctly.")) + .await?) } } @@ -344,8 +376,11 @@ mod unit { use crate::{api, OpenAIClient, OpenAIError}; fn mocked_client() -> OpenAIClient { + let _ = env_logger::builder().is_test(true).try_init(); let mut client = OpenAIClient::new("bogus"); - client.root = mockito::server_url(); + client.set_api_root( + surf::Url::parse(&mockito::server_url()).expect("mockito url didn't parse"), + ); client } @@ -493,9 +528,9 @@ mod integration { use api::ErrorMessage; use crate::{api, OpenAIClient, OpenAIError}; - /// Used by tests to get a client to the actual api fn get_client() -> OpenAIClient { + let _ = env_logger::builder().is_test(true).try_init(); let sk = std::env::var("OPENAI_SK").expect( "To run integration tests, you must put set the OPENAI_SK env var to your api token", );