Add synchronous client feature (#8)

This commit is contained in:
Josh Kuhn 2020-12-17 22:39:52 -08:00 committed by GitHub
parent 8a2f6400b4
commit 9ded1062d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 422 additions and 153 deletions

View File

@ -20,6 +20,14 @@ jobs:
- uses: actions-rs/cargo@v1 - uses: actions-rs/cargo@v1
with: with:
command: check command: check
- uses: actions-rs/cargo@v1
with:
command: check
args: --features=sync --no-default-features
- uses: actions-rs/cargo@v1
with:
command: check
args: --features=async --no-default-features
test: test:
name: Test Suite name: Test Suite

View File

@ -11,22 +11,25 @@ keywords = ["openai", "gpt3"]
categories = ["api-bindings", "asynchronous"] categories = ["api-bindings", "asynchronous"]
[features] [features]
default = ["hyper"] default = ["sync", "async"]
sync = ["ureq"]
hyper = ["surf/hyper-client"] async = ["surf/hyper-client"]
curl = ["surf/curl-client"]
h1 = ["surf/h1-client"]
[dependencies] [dependencies]
surf = { version = "^2.1.0", default-features = false }
thiserror = "^1.0.22" thiserror = "^1.0.22"
serde = { version = "^1.0.117", features = ["derive"] } serde = { version = "^1.0.117", features = ["derive"] }
derive_builder = "^0.9.0" derive_builder = "^0.9.0"
log = "^0.4.11" log = "^0.4.11"
# Used by sync feature
ureq = { version = "^1.5.4", optional=true, features = ["json", "tls"] }
serde_json = { version="^1.0"}
# Used by async feature
surf = { version = "^2.1.0", optional=true, default-features=false}
[dev-dependencies] [dev-dependencies]
mockito = "0.28.0" mockito = "0.28.0"
maplit = "1.0.2" maplit = "1.0.2"
tokio = { version = "^0.2.5", features = ["full"]} tokio = { version = "^0.2.5", features = ["full"]}
serde_json = "^1.0"
env_logger = "0.8.2" env_logger = "0.8.2"
serde_json = "^1.0"

View File

@ -1,6 +1,6 @@
use openai_api::{ use openai_api::{
api::{CompletionArgs, Engine}, api::{CompletionArgs, Engine},
OpenAIClient, Client,
}; };
const START_PROMPT: &str = " const START_PROMPT: &str = "
@ -11,7 +11,7 @@ AI: I am an AI. How can I help you today?";
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let api_token = std::env::var("OPENAI_SK").unwrap(); let api_token = std::env::var("OPENAI_SK").unwrap();
let client = OpenAIClient::new(&api_token); let client = Client::new(&api_token);
let mut context = String::from(START_PROMPT); let mut context = String::from(START_PROMPT);
let mut args = CompletionArgs::builder(); let mut args = CompletionArgs::builder();
args.engine(Engine::Davinci) args.engine(Engine::Davinci)

View File

@ -1,14 +1,14 @@
///! Example that prints out a story from a prompt. Used in the readme. ///! Example that prints out a story from a prompt. Used in the readme.
use openai_api::OpenAIClient; use openai_api::Client;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let api_token = std::env::var("OPENAI_SK").unwrap(); let api_token = std::env::var("OPENAI_SK").unwrap();
let client = OpenAIClient::new(&api_token); let client = Client::new(&api_token);
let prompt = String::from("Once upon a time,"); let prompt = String::from("Once upon a time,");
println!( println!(
"{}{}", "{}{}",
prompt, prompt,
client.complete(prompt.as_str()).await.unwrap() client.complete_prompt(prompt.as_str()).await.unwrap()
); );
} }

View File

@ -4,14 +4,14 @@ extern crate derive_builder;
use thiserror::Error; use thiserror::Error;
type Result<T> = std::result::Result<T, OpenAIError>; type Result<T> = std::result::Result<T, Error>;
#[allow(clippy::default_trait_access)] #[allow(clippy::default_trait_access)]
pub mod api { pub mod api {
//! Data types corresponding to requests and responses from the API //! Data types corresponding to requests and responses from the API
use std::collections::HashMap; use std::{collections::HashMap, fmt::Display};
use super::OpenAIClient; use super::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Container type. Used in the api, but not useful for clients of this library /// Container type. Used in the api, but not useful for clients of this library
@ -115,9 +115,15 @@ pub mod api {
/// Request a completion from the api /// Request a completion from the api
/// ///
/// # Errors /// # Errors
/// `OpenAIError::APIError` if the api returns an error /// `Error::APIError` if the api returns an error
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> { #[cfg(feature = "async")]
client.complete(self.clone()).await pub async fn complete_prompt(self, client: &Client) -> super::Result<Completion> {
client.complete_prompt(self).await
}
#[cfg(feature = "sync")]
pub fn complete_prompt_sync(self, client: &Client) -> super::Result<Completion> {
client.complete_prompt_sync(self)
} }
} }
@ -125,10 +131,16 @@ pub mod api {
/// Request a completion from the api /// Request a completion from the api
/// ///
/// # Errors /// # Errors
/// `OpenAIError::BadArguments` if the arguments to complete are not valid /// `Error::BadArguments` if the arguments to complete are not valid
/// `OpenAIError::APIError` if the api returns an error /// `Error::APIError` if the api returns an error
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> { #[cfg(feature = "async")]
client.complete(self.build()?).await pub async fn complete_prompt(&self, client: &Client) -> super::Result<Completion> {
client.complete_prompt(self.build()?).await
}
#[cfg(feature = "sync")]
pub fn complete_prompt_sync(&self, client: &Client) -> super::Result<Completion> {
client.complete_prompt_sync(self.build()?)
} }
} }
@ -199,55 +211,70 @@ pub mod api {
pub error_type: String, pub error_type: String,
} }
impl Display for ErrorMessage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.message.fmt(f)
}
}
/// API-level wrapper used in deserialization /// API-level wrapper used in deserialization
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub(super) struct ErrorWrapper { pub(crate) struct ErrorWrapper {
pub error: ErrorMessage, pub error: ErrorMessage,
} }
} }
/// This library's main `Error` type. /// This library's main `Error` type.
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum OpenAIError { pub enum Error {
/// An error returned by the API itself /// An error returned by the API itself
#[error("API Returned an Error document")] #[error("API returned an Error: {}", .0.message)]
APIError(api::ErrorMessage), APIError(api::ErrorMessage),
/// An error the client discovers before talking to the API /// An error the client discovers before talking to the API
#[error("Bad arguments")] #[error("Bad arguments: {0}")]
BadArguments(String), BadArguments(String),
/// Network / protocol related errors /// Network / protocol related errors
#[error("Error at the protocol level")] #[cfg(feature = "async")]
ProtocolError(surf::Error), #[error("Error at the protocol level: {0}")]
AsyncProtocolError(surf::Error),
#[cfg(feature = "sync")]
#[error("Error at the protocol level, sync client")]
SyncProtocolError(ureq::Error),
} }
impl From<api::ErrorMessage> for OpenAIError { impl From<api::ErrorMessage> for Error {
fn from(e: api::ErrorMessage) -> Self { fn from(e: api::ErrorMessage) -> Self {
OpenAIError::APIError(e) Error::APIError(e)
} }
} }
impl From<String> for OpenAIError { impl From<String> for Error {
fn from(e: String) -> Self { fn from(e: String) -> Self {
OpenAIError::BadArguments(e) Error::BadArguments(e)
} }
} }
impl From<surf::Error> for OpenAIError { #[cfg(feature = "async")]
impl From<surf::Error> for Error {
fn from(e: surf::Error) -> Self { fn from(e: surf::Error) -> Self {
OpenAIError::ProtocolError(e) Error::AsyncProtocolError(e)
} }
} }
/// Client object. Must be constructed to talk to the API. #[cfg(feature = "sync")]
pub struct OpenAIClient { impl From<ureq::Error> for Error {
client: surf::Client, fn from(e: ureq::Error) -> Self {
Error::SyncProtocolError(e)
}
} }
/// Authentication middleware /// Authentication middleware
#[cfg(feature = "async")]
struct BearerToken { struct BearerToken {
token: String, token: String,
} }
#[cfg(feature = "async")]
impl std::fmt::Debug for BearerToken { impl std::fmt::Debug for BearerToken {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Get the first few characters to help debug, but not accidentally log key // Get the first few characters to help debug, but not accidentally log key
@ -259,6 +286,7 @@ impl std::fmt::Debug for BearerToken {
} }
} }
#[cfg(feature = "async")]
impl BearerToken { impl BearerToken {
fn new(token: &str) -> Self { fn new(token: &str) -> Self {
Self { Self {
@ -267,6 +295,7 @@ impl BearerToken {
} }
} }
#[cfg(feature = "async")]
#[surf::utils::async_trait] #[surf::utils::async_trait]
impl surf::middleware::Middleware for BearerToken { impl surf::middleware::Middleware for BearerToken {
async fn handle( async fn handle(
@ -277,40 +306,99 @@ impl surf::middleware::Middleware for BearerToken {
) -> surf::Result<surf::Response> { ) -> surf::Result<surf::Response> {
log::debug!("Request: {:?}", req); log::debug!("Request: {:?}", req);
req.insert_header("Authorization", format!("Bearer {}", self.token)); req.insert_header("Authorization", format!("Bearer {}", self.token));
let response = next.run(req, client).await?; let response: surf::Response = next.run(req, client).await?;
log::debug!("Response: {:?}", response); log::debug!("Response: {:?}", response);
Ok(response) Ok(response)
} }
} }
impl OpenAIClient { #[cfg(feature = "async")]
/// Creates a new `OpenAIClient` given an api token fn async_client(token: &str, base_url: &str) -> surf::Client {
let mut async_client = surf::client();
async_client.set_base_url(surf::Url::parse(base_url).expect("Static string should parse"));
async_client.with(BearerToken::new(token))
}
#[cfg(feature = "sync")]
fn sync_client(token: &str) -> ureq::Agent {
ureq::agent().auth_kind("Bearer", token).build()
}
/// Client object. Must be constructed to talk to the API.
pub struct Client {
#[cfg(feature = "async")]
async_client: surf::Client,
#[cfg(feature = "sync")]
sync_client: ureq::Agent,
#[cfg(feature = "sync")]
base_url: String,
}
impl Client {
// Creates a new `Client` given an api token
#[must_use] #[must_use]
pub fn new(token: &str) -> Self { pub fn new(token: &str) -> Self {
let mut client = surf::client(); let base_url = String::from("https://api.openai.com/v1/");
client.set_base_url( Self {
surf::Url::parse("https://api.openai.com/v1/").expect("Static string should parse"), #[cfg(feature = "async")]
); async_client: async_client(token, &base_url),
client = client.with(BearerToken::new(token)); #[cfg(feature = "sync")]
Self { client } sync_client: sync_client(token),
#[cfg(feature = "sync")]
base_url,
}
} }
/// Allow setting the api root in the tests // Allow setting the api root in the tests
#[cfg(test)] #[cfg(test)]
fn set_api_root(&mut self, url: surf::Url) { fn set_api_root(mut self, base_url: &str) -> Self {
self.client.set_base_url(url); #[cfg(feature = "async")]
{
self.async_client.set_base_url(
surf::Url::parse(base_url).expect("static URL expected to parse correctly"),
);
}
#[cfg(feature = "sync")]
{
self.base_url = String::from(base_url);
}
self
} }
/// Private helper for making gets /// Private helper for making gets
async fn get<T: serde::de::DeserializeOwned>(&self, endpoint: &str) -> Result<T> { #[cfg(feature = "async")]
let mut response = self.client.get(endpoint).await?; async fn get<T>(&self, endpoint: &str) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
let mut response = self.async_client.get(endpoint).await?;
if let surf::StatusCode::Ok = response.status() { if let surf::StatusCode::Ok = response.status() {
Ok(response.body_json::<T>().await?) Ok(response.body_json::<T>().await?)
} else { } else {
{ let err = response.body_json::<api::ErrorWrapper>().await?.error;
let err = response.body_json::<api::ErrorWrapper>().await?.error; Err(Error::APIError(err))
Err(OpenAIError::APIError(err)) }
} }
#[cfg(feature = "sync")]
fn get_sync<T>(&self, endpoint: &str) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
let response = dbg!(self
.sync_client
.get(&format!("{}{}", self.base_url, endpoint)))
.call();
if let 200 = response.status() {
Ok(response
.into_json_deserialize()
.expect("Bug: client couldn't deserialize api response"))
} else {
let err = response
.into_json_deserialize::<api::ErrorWrapper>()
.expect("Bug: client couldn't deserialize api error response")
.error;
Err(Error::APIError(err))
} }
} }
@ -319,35 +407,55 @@ impl OpenAIClient {
/// Provides basic information about each one such as the owner and availability. /// Provides basic information about each one such as the owner and availability.
/// ///
/// # Errors /// # Errors
/// - `OpenAIError::APIError` if the server returns an error /// - `Error::APIError` if the server returns an error
#[cfg(feature = "async")]
pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> { pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
self.get("engines").await.map(|r: api::Container<_>| r.data) self.get("engines").await.map(|r: api::Container<_>| r.data)
} }
/// Lists the currently available engines.
///
/// Provides basic information about each one such as the owner and availability.
///
/// # Errors
/// - `Error::APIError` if the server returns an error
#[cfg(feature = "sync")]
pub fn engines_sync(&self) -> Result<Vec<api::EngineInfo>> {
self.get_sync("engines").map(|r: api::Container<_>| r.data)
}
/// Retrieves an engine instance /// Retrieves an engine instance
///
/// Provides basic information about the engine such as the owner and availability. /// Provides basic information about the engine such as the owner and availability.
/// ///
/// # Errors /// # Errors
/// - `OpenAIError::APIError` if the server returns an error /// - `Error::APIError` if the server returns an error
#[cfg(feature = "async")]
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> { pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
self.get(&format!("engines/{}", engine)).await self.get(&format!("engines/{}", engine)).await
} }
#[cfg(feature = "sync")]
pub fn engine_sync(&self, engine: api::Engine) -> Result<api::EngineInfo> {
self.get_sync(&format!("engines/{}", engine))
}
// Private helper to generate post requests. Needs to be a bit more flexible than // Private helper to generate post requests. Needs to be a bit more flexible than
// get because it should support SSE eventually // get because it should support SSE eventually
#[cfg(feature = "async")]
async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R> async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
where where
B: serde::ser::Serialize, B: serde::ser::Serialize,
R: serde::de::DeserializeOwned, R: serde::de::DeserializeOwned,
{ {
let mut response = self let mut response = self
.client .async_client
.post(endpoint) .post(endpoint)
.body(surf::Body::from_json(&body)?) .body(surf::Body::from_json(&body)?)
.await?; .await?;
match response.status() { match response.status() {
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?), surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
_ => Err(OpenAIError::APIError( _ => Err(Error::APIError(
response response
.body_json::<api::ErrorWrapper>() .body_json::<api::ErrorWrapper>()
.await .await
@ -356,11 +464,38 @@ impl OpenAIClient {
)), )),
} }
} }
#[cfg(feature = "sync")]
fn post_sync<B, R>(&self, endpoint: &str, body: B) -> Result<R>
where
B: serde::ser::Serialize,
R: serde::de::DeserializeOwned,
{
let response = self
.sync_client
.post(&format!("{}{}", self.base_url, endpoint))
.send_json(
serde_json::to_value(body).expect("Bug: client couldn't serialize its own type"),
);
match response.status() {
200 => Ok(response
.into_json_deserialize()
.expect("Bug: client couldn't deserialize api response")),
_ => Err(Error::APIError(
response
.into_json_deserialize::<api::ErrorWrapper>()
.expect("Bug: client couldn't deserialize api error response")
.error,
)),
}
}
/// Get predicted completion of the prompt /// Get predicted completion of the prompt
/// ///
/// # Errors /// # Errors
/// - `OpenAIError::APIError` if the api returns an error /// - `Error::APIError` if the api returns an error
pub async fn complete( #[cfg(feature = "async")]
pub async fn complete_prompt(
&self, &self,
prompt: impl Into<api::CompletionArgs>, prompt: impl Into<api::CompletionArgs>,
) -> Result<api::Completion> { ) -> Result<api::Completion> {
@ -369,20 +504,60 @@ impl OpenAIClient {
.post(&format!("engines/{}/completions", args.engine), args) .post(&format!("engines/{}/completions", args.engine), args)
.await?) .await?)
} }
/// Get predicted completion of the prompt synchronously
///
/// # Error
/// - `Error::APIError` if the api returns an error
#[cfg(feature = "sync")]
pub fn complete_prompt_sync(
&self,
prompt: impl Into<api::CompletionArgs>,
) -> Result<api::Completion> {
let args = prompt.into();
self.post_sync(&format!("engines/{}/completions", args.engine), args)
}
}
// TODO: add a macro to de-boilerplate the sync and async tests
#[allow(unused_macros)]
macro_rules! async_test {
($test_name: ident, $test_body: block) => {
#[cfg(feature = "async")]
#[tokio::test]
async fn $test_name() -> crate::Result<()> {
$test_body;
Ok(())
}
};
}
#[allow(unused_macros)]
macro_rules! sync_test {
($test_name: ident, $test_body: expr) => {
#[cfg(feature = "sync")]
#[test]
fn $test_name() -> crate::Result<()> {
$test_body;
Ok(())
}
};
} }
#[cfg(test)] #[cfg(test)]
mod unit { mod unit {
use crate::{api, OpenAIClient, OpenAIError}; use mockito::Mock;
fn mocked_client() -> OpenAIClient { use crate::{
api::{self, Completion, CompletionArgs, Engine, EngineInfo},
Client, Error,
};
fn mocked_client() -> Client {
let _ = env_logger::builder().is_test(true).try_init(); let _ = env_logger::builder().is_test(true).try_init();
let mut client = OpenAIClient::new("bogus"); Client::new("bogus").set_api_root(&format!("{}/", mockito::server_url()))
client.set_api_root(
surf::Url::parse(&mockito::server_url()).expect("mockito url didn't parse"),
);
client
} }
#[test] #[test]
@ -410,10 +585,8 @@ mod unit {
Ok(()) Ok(())
} }
#[tokio::test] fn mock_engines() -> (Mock, Vec<EngineInfo>) {
async fn parse_engines() -> crate::Result<()> { let mock = mockito::mock("GET", "/engines")
use api::{Engine, EngineInfo};
let _m = mockito::mock("GET", "/engines")
.with_status(200) .with_status(200)
.with_header("content-type", "application/json") .with_header("content-type", "application/json")
.with_body( .with_body(
@ -460,6 +633,7 @@ mod unit {
}"#, }"#,
) )
.create(); .create();
let expected = vec![ let expected = vec![
EngineInfo { EngineInfo {
id: Engine::Ada, id: Engine::Ada,
@ -492,40 +666,59 @@ mod unit {
ready: true, ready: true,
}, },
]; ];
let response = mocked_client().engines().await?; (mock, expected)
assert_eq!(response, expected);
Ok(())
} }
#[tokio::test] async_test!(parse_engines_async, {
async fn engine_error_response() -> crate::Result<()> { let (_m, expected) = mock_engines();
let _m = mockito::mock("GET", "/engines/davinci") let response = mocked_client().engines().await?;
assert_eq!(response, expected);
});
sync_test!(parse_engines_sync, {
let (_m, expected) = mock_engines();
let response = mocked_client().engines_sync()?;
assert_eq!(response, expected);
});
fn mock_engine() -> (Mock, api::ErrorMessage) {
let mock = mockito::mock("GET", "/engines/davinci")
.with_status(404) .with_status(404)
.with_header("content-type", "application/json") .with_header("content-type", "application/json")
.with_body( .with_body(
r#"{ r#"{
"error": { "error": {
"code": null, "code": null,
"message": "Some kind of error happened", "message": "Some kind of error happened",
"type": "some_error_type" "type": "some_error_type"
} }
}"#, }"#,
) )
.create(); .create();
let expected = api::ErrorMessage { let expected = api::ErrorMessage {
message: "Some kind of error happened".into(), message: "Some kind of error happened".into(),
error_type: "some_error_type".into(), error_type: "some_error_type".into(),
}; };
let response = mocked_client().engine(api::Engine::Davinci).await; (mock, expected)
if let Result::Err(OpenAIError::APIError(msg)) = response {
assert_eq!(expected, msg);
}
Ok(())
} }
#[tokio::test] async_test!(engine_error_response_async, {
async fn completion_args() -> crate::Result<()> { let (_m, expected) = mock_engine();
let _m = mockito::mock("POST", "/engines/davinci/completions") let response = mocked_client().engine(api::Engine::Davinci).await;
if let Result::Err(Error::APIError(msg)) = response {
assert_eq!(expected, msg);
}
});
sync_test!(engine_error_response_sync, {
let (_m, expected) = mock_engine();
let response = mocked_client().engine_sync(api::Engine::Davinci);
if let Result::Err(Error::APIError(msg)) = response {
assert_eq!(expected, msg);
}
});
fn mock_completion() -> crate::Result<(Mock, CompletionArgs, Completion)> {
let mock = mockito::mock("POST", "/engines/davinci/completions")
.with_status(200) .with_status(200)
.with_header("content-type", "application/json") .with_header("content-type", "application/json")
.with_body( .with_body(
@ -544,7 +737,17 @@ mod unit {
] ]
}"#, }"#,
) )
.expect(1)
.create(); .create();
let args = api::CompletionArgs::builder()
.engine(api::Engine::Davinci)
.prompt("Once upon a time")
.max_tokens(5)
.temperature(1.0)
.top_p(1.0)
.n(1)
.stop(vec!["\n".into()])
.build()?;
let expected = api::Completion { let expected = api::Completion {
id: "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7".into(), id: "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7".into(),
created: 1589478378, created: 1589478378,
@ -556,54 +759,75 @@ mod unit {
finish_reason: api::FinishReason::MaxTokensReached, finish_reason: api::FinishReason::MaxTokensReached,
}], }],
}; };
let args = api::CompletionArgs::builder() Ok((mock, args, expected))
.engine(api::Engine::Davinci)
.prompt("Once upon a time")
.max_tokens(5)
.temperature(1.0)
.top_p(1.0)
.n(1)
.stop(vec!["\n".into()])
.build()?;
let response = mocked_client().complete(args).await?;
assert_eq!(response.model, expected.model);
assert_eq!(response.id, expected.id);
assert_eq!(response.created, expected.created);
let (resp_choice, expected_choice) = (&response.choices[0], &expected.choices[0]);
assert_eq!(resp_choice.text, expected_choice.text);
assert_eq!(resp_choice.index, expected_choice.index);
assert!(resp_choice.logprobs.is_none());
assert_eq!(resp_choice.finish_reason, expected_choice.finish_reason);
Ok(())
} }
// Defines boilerplate here. The Completion can't derive Eq since it contains
// floats in various places.
fn assert_completion_equal(a: Completion, b: Completion) {
assert_eq!(a.model, b.model);
assert_eq!(a.id, b.id);
assert_eq!(a.created, b.created);
let (a_choice, b_choice) = (&a.choices[0], &b.choices[0]);
assert_eq!(a_choice.text, b_choice.text);
assert_eq!(a_choice.index, b_choice.index);
assert!(a_choice.logprobs.is_none());
assert_eq!(a_choice.finish_reason, b_choice.finish_reason);
}
async_test!(completion_args_async, {
let (m, args, expected) = mock_completion()?;
let response = mocked_client().complete_prompt(args).await?;
assert_completion_equal(response, expected);
m.assert();
});
sync_test!(completion_args_sync, {
let (m, args, expected) = mock_completion()?;
let response = mocked_client().complete_prompt_sync(args)?;
assert_completion_equal(response, expected);
m.assert();
});
} }
#[cfg(test)] #[cfg(test)]
mod integration { mod integration {
use crate::{
use api::ErrorMessage; api::{self, Completion},
Client, Error,
use crate::{api, OpenAIClient, OpenAIError}; };
/// Used by tests to get a client to the actual api /// Used by tests to get a client to the actual api
fn get_client() -> OpenAIClient { fn get_client() -> Client {
let _ = env_logger::builder().is_test(true).try_init(); let _ = env_logger::builder().is_test(true).try_init();
let sk = std::env::var("OPENAI_SK").expect( let sk = std::env::var("OPENAI_SK").expect(
"To run integration tests, you must put set the OPENAI_SK env var to your api token", "To run integration tests, you must put set the OPENAI_SK env var to your api token",
); );
OpenAIClient::new(&sk) Client::new(&sk)
} }
#[tokio::test] async_test!(can_get_engines_async, {
async fn can_get_engines() {
let client = get_client(); let client = get_client();
client.engines().await.unwrap(); client.engines().await?
} });
#[tokio::test]
async fn can_get_engine() { sync_test!(can_get_engines_sync, {
let client = get_client(); let client = get_client();
let result = client.engine(api::Engine::Ada).await; let engines = client
.engines_sync()?
.into_iter()
.map(|ei| ei.id)
.collect::<Vec<_>>();
assert!(engines.contains(&api::Engine::Ada));
assert!(engines.contains(&api::Engine::Babbage));
assert!(engines.contains(&api::Engine::Curie));
assert!(engines.contains(&api::Engine::Davinci));
});
fn assert_expected_engine_failure<T>(result: Result<T, Error>)
where
T: std::fmt::Debug,
{
match result { match result {
Err(OpenAIError::APIError(ErrorMessage { Err(Error::APIError(api::ErrorMessage {
message, message,
error_type, error_type,
})) => { })) => {
@ -615,18 +839,28 @@ mod integration {
} }
} }
} }
async_test!(can_get_engine_async, {
#[tokio::test]
async fn complete_string() -> crate::Result<()> {
let client = get_client(); let client = get_client();
client.complete("Hey there").await?; assert_expected_engine_failure(client.engine(api::Engine::Ada).await);
Ok(()) });
}
#[tokio::test] sync_test!(can_get_engine_sync, {
async fn complete_explicit_params() -> crate::Result<()> {
let client = get_client(); let client = get_client();
let args = api::CompletionArgsBuilder::default() assert_expected_engine_failure(client.engine_sync(api::Engine::Ada));
});
async_test!(complete_string_async, {
let client = get_client();
client.complete_prompt("Hey there").await?;
});
sync_test!(complete_string_sync, {
let client = get_client();
client.complete_prompt_sync("Hey there")?;
});
fn create_args() -> api::CompletionArgs {
api::CompletionArgsBuilder::default()
.prompt("Once upon a time,") .prompt("Once upon a time,")
.max_tokens(10) .max_tokens(10)
.temperature(0.5) .temperature(0.5)
@ -642,28 +876,52 @@ mod integration {
"23".into() => 0.0, "23".into() => 0.0,
}) })
.build() .build()
.expect("Build should have succeeded"); .expect("Bug: build should succeed")
client.complete(args).await?;
Ok(())
} }
async_test!(complete_explicit_params_async, {
#[tokio::test]
async fn complete_stop_condition() -> crate::Result<()> {
let client = get_client(); let client = get_client();
let args = create_args();
client.complete_prompt(args).await?;
});
sync_test!(complete_explicit_params_sync, {
let client = get_client();
let args = create_args();
client.complete_prompt_sync(args)?
});
fn stop_condition_args() -> api::CompletionArgs {
let mut args = api::CompletionArgs::builder(); let mut args = api::CompletionArgs::builder();
let completion = args args.prompt(
.prompt( r#"
r#"
Q: Please type `#` now Q: Please type `#` now
A:"#, A:"#,
) )
// turn temp & top_p way down to prevent test flakiness // turn temp & top_p way down to prevent test flakiness
.temperature(0.0) .temperature(0.0)
.top_p(0.0) .top_p(0.0)
.max_tokens(100) .max_tokens(100)
.stop(vec!["#".into(), "\n".into()]) .stop(vec!["#".into(), "\n".into()])
.complete(&client).await?; .build()
assert_eq!(completion.choices[0].finish_reason, api::FinishReason::StopSequenceReached); .expect("Bug: build should succeed")
Ok(())
} }
fn assert_completion_finish_reason(completion: Completion) {
assert_eq!(
completion.choices[0].finish_reason,
api::FinishReason::StopSequenceReached
);
}
async_test!(complete_stop_condition_async, {
let client = get_client();
let args = stop_condition_args();
assert_completion_finish_reason(client.complete_prompt(args).await?);
});
sync_test!(complete_stop_condition_sync, {
let client = get_client();
let args = stop_condition_args();
assert_completion_finish_reason(client.complete_prompt_sync(args)?);
});
} }