Add synchronous client feature (#8)
This commit is contained in:
parent
8a2f6400b4
commit
9ded1062d1
8
.github/workflows/rust.yml
vendored
8
.github/workflows/rust.yml
vendored
@ -20,6 +20,14 @@ jobs:
|
|||||||
- uses: actions-rs/cargo@v1
|
- uses: actions-rs/cargo@v1
|
||||||
with:
|
with:
|
||||||
command: check
|
command: check
|
||||||
|
- uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: check
|
||||||
|
args: --features=sync --no-default-features
|
||||||
|
- uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: check
|
||||||
|
args: --features=async --no-default-features
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: Test Suite
|
name: Test Suite
|
||||||
|
17
Cargo.toml
17
Cargo.toml
@ -11,22 +11,25 @@ keywords = ["openai", "gpt3"]
|
|||||||
categories = ["api-bindings", "asynchronous"]
|
categories = ["api-bindings", "asynchronous"]
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["hyper"]
|
default = ["sync", "async"]
|
||||||
|
sync = ["ureq"]
|
||||||
hyper = ["surf/hyper-client"]
|
async = ["surf/hyper-client"]
|
||||||
curl = ["surf/curl-client"]
|
|
||||||
h1 = ["surf/h1-client"]
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
surf = { version = "^2.1.0", default-features = false }
|
|
||||||
thiserror = "^1.0.22"
|
thiserror = "^1.0.22"
|
||||||
serde = { version = "^1.0.117", features = ["derive"] }
|
serde = { version = "^1.0.117", features = ["derive"] }
|
||||||
derive_builder = "^0.9.0"
|
derive_builder = "^0.9.0"
|
||||||
log = "^0.4.11"
|
log = "^0.4.11"
|
||||||
|
|
||||||
|
# Used by sync feature
|
||||||
|
ureq = { version = "^1.5.4", optional=true, features = ["json", "tls"] }
|
||||||
|
serde_json = { version="^1.0"}
|
||||||
|
# Used by async feature
|
||||||
|
surf = { version = "^2.1.0", optional=true, default-features=false}
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
mockito = "0.28.0"
|
mockito = "0.28.0"
|
||||||
maplit = "1.0.2"
|
maplit = "1.0.2"
|
||||||
tokio = { version = "^0.2.5", features = ["full"]}
|
tokio = { version = "^0.2.5", features = ["full"]}
|
||||||
serde_json = "^1.0"
|
|
||||||
env_logger = "0.8.2"
|
env_logger = "0.8.2"
|
||||||
|
serde_json = "^1.0"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use openai_api::{
|
use openai_api::{
|
||||||
api::{CompletionArgs, Engine},
|
api::{CompletionArgs, Engine},
|
||||||
OpenAIClient,
|
Client,
|
||||||
};
|
};
|
||||||
|
|
||||||
const START_PROMPT: &str = "
|
const START_PROMPT: &str = "
|
||||||
@ -11,7 +11,7 @@ AI: I am an AI. How can I help you today?";
|
|||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
let api_token = std::env::var("OPENAI_SK").unwrap();
|
let api_token = std::env::var("OPENAI_SK").unwrap();
|
||||||
let client = OpenAIClient::new(&api_token);
|
let client = Client::new(&api_token);
|
||||||
let mut context = String::from(START_PROMPT);
|
let mut context = String::from(START_PROMPT);
|
||||||
let mut args = CompletionArgs::builder();
|
let mut args = CompletionArgs::builder();
|
||||||
args.engine(Engine::Davinci)
|
args.engine(Engine::Davinci)
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
///! Example that prints out a story from a prompt. Used in the readme.
|
///! Example that prints out a story from a prompt. Used in the readme.
|
||||||
use openai_api::OpenAIClient;
|
use openai_api::Client;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
let api_token = std::env::var("OPENAI_SK").unwrap();
|
let api_token = std::env::var("OPENAI_SK").unwrap();
|
||||||
let client = OpenAIClient::new(&api_token);
|
let client = Client::new(&api_token);
|
||||||
let prompt = String::from("Once upon a time,");
|
let prompt = String::from("Once upon a time,");
|
||||||
println!(
|
println!(
|
||||||
"{}{}",
|
"{}{}",
|
||||||
prompt,
|
prompt,
|
||||||
client.complete(prompt.as_str()).await.unwrap()
|
client.complete_prompt(prompt.as_str()).await.unwrap()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
540
src/lib.rs
540
src/lib.rs
@ -4,14 +4,14 @@ extern crate derive_builder;
|
|||||||
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
type Result<T> = std::result::Result<T, OpenAIError>;
|
type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
#[allow(clippy::default_trait_access)]
|
#[allow(clippy::default_trait_access)]
|
||||||
pub mod api {
|
pub mod api {
|
||||||
//! Data types corresponding to requests and responses from the API
|
//! Data types corresponding to requests and responses from the API
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, fmt::Display};
|
||||||
|
|
||||||
use super::OpenAIClient;
|
use super::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Container type. Used in the api, but not useful for clients of this library
|
/// Container type. Used in the api, but not useful for clients of this library
|
||||||
@ -115,9 +115,15 @@ pub mod api {
|
|||||||
/// Request a completion from the api
|
/// Request a completion from the api
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
/// `OpenAIError::APIError` if the api returns an error
|
/// `Error::APIError` if the api returns an error
|
||||||
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> {
|
#[cfg(feature = "async")]
|
||||||
client.complete(self.clone()).await
|
pub async fn complete_prompt(self, client: &Client) -> super::Result<Completion> {
|
||||||
|
client.complete_prompt(self).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
pub fn complete_prompt_sync(self, client: &Client) -> super::Result<Completion> {
|
||||||
|
client.complete_prompt_sync(self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,10 +131,16 @@ pub mod api {
|
|||||||
/// Request a completion from the api
|
/// Request a completion from the api
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
/// `OpenAIError::BadArguments` if the arguments to complete are not valid
|
/// `Error::BadArguments` if the arguments to complete are not valid
|
||||||
/// `OpenAIError::APIError` if the api returns an error
|
/// `Error::APIError` if the api returns an error
|
||||||
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> {
|
#[cfg(feature = "async")]
|
||||||
client.complete(self.build()?).await
|
pub async fn complete_prompt(&self, client: &Client) -> super::Result<Completion> {
|
||||||
|
client.complete_prompt(self.build()?).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
pub fn complete_prompt_sync(&self, client: &Client) -> super::Result<Completion> {
|
||||||
|
client.complete_prompt_sync(self.build()?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,55 +211,70 @@ pub mod api {
|
|||||||
pub error_type: String,
|
pub error_type: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Display for ErrorMessage {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
self.message.fmt(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// API-level wrapper used in deserialization
|
/// API-level wrapper used in deserialization
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug)]
|
||||||
pub(super) struct ErrorWrapper {
|
pub(crate) struct ErrorWrapper {
|
||||||
pub error: ErrorMessage,
|
pub error: ErrorMessage,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This library's main `Error` type.
|
/// This library's main `Error` type.
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum OpenAIError {
|
pub enum Error {
|
||||||
/// An error returned by the API itself
|
/// An error returned by the API itself
|
||||||
#[error("API Returned an Error document")]
|
#[error("API returned an Error: {}", .0.message)]
|
||||||
APIError(api::ErrorMessage),
|
APIError(api::ErrorMessage),
|
||||||
/// An error the client discovers before talking to the API
|
/// An error the client discovers before talking to the API
|
||||||
#[error("Bad arguments")]
|
#[error("Bad arguments: {0}")]
|
||||||
BadArguments(String),
|
BadArguments(String),
|
||||||
/// Network / protocol related errors
|
/// Network / protocol related errors
|
||||||
#[error("Error at the protocol level")]
|
#[cfg(feature = "async")]
|
||||||
ProtocolError(surf::Error),
|
#[error("Error at the protocol level: {0}")]
|
||||||
|
AsyncProtocolError(surf::Error),
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
#[error("Error at the protocol level, sync client")]
|
||||||
|
SyncProtocolError(ureq::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<api::ErrorMessage> for OpenAIError {
|
impl From<api::ErrorMessage> for Error {
|
||||||
fn from(e: api::ErrorMessage) -> Self {
|
fn from(e: api::ErrorMessage) -> Self {
|
||||||
OpenAIError::APIError(e)
|
Error::APIError(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<String> for OpenAIError {
|
impl From<String> for Error {
|
||||||
fn from(e: String) -> Self {
|
fn from(e: String) -> Self {
|
||||||
OpenAIError::BadArguments(e)
|
Error::BadArguments(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<surf::Error> for OpenAIError {
|
#[cfg(feature = "async")]
|
||||||
|
impl From<surf::Error> for Error {
|
||||||
fn from(e: surf::Error) -> Self {
|
fn from(e: surf::Error) -> Self {
|
||||||
OpenAIError::ProtocolError(e)
|
Error::AsyncProtocolError(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Client object. Must be constructed to talk to the API.
|
#[cfg(feature = "sync")]
|
||||||
pub struct OpenAIClient {
|
impl From<ureq::Error> for Error {
|
||||||
client: surf::Client,
|
fn from(e: ureq::Error) -> Self {
|
||||||
|
Error::SyncProtocolError(e)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Authentication middleware
|
/// Authentication middleware
|
||||||
|
#[cfg(feature = "async")]
|
||||||
struct BearerToken {
|
struct BearerToken {
|
||||||
token: String,
|
token: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "async")]
|
||||||
impl std::fmt::Debug for BearerToken {
|
impl std::fmt::Debug for BearerToken {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
// Get the first few characters to help debug, but not accidentally log key
|
// Get the first few characters to help debug, but not accidentally log key
|
||||||
@ -259,6 +286,7 @@ impl std::fmt::Debug for BearerToken {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "async")]
|
||||||
impl BearerToken {
|
impl BearerToken {
|
||||||
fn new(token: &str) -> Self {
|
fn new(token: &str) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@ -267,6 +295,7 @@ impl BearerToken {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "async")]
|
||||||
#[surf::utils::async_trait]
|
#[surf::utils::async_trait]
|
||||||
impl surf::middleware::Middleware for BearerToken {
|
impl surf::middleware::Middleware for BearerToken {
|
||||||
async fn handle(
|
async fn handle(
|
||||||
@ -277,40 +306,99 @@ impl surf::middleware::Middleware for BearerToken {
|
|||||||
) -> surf::Result<surf::Response> {
|
) -> surf::Result<surf::Response> {
|
||||||
log::debug!("Request: {:?}", req);
|
log::debug!("Request: {:?}", req);
|
||||||
req.insert_header("Authorization", format!("Bearer {}", self.token));
|
req.insert_header("Authorization", format!("Bearer {}", self.token));
|
||||||
let response = next.run(req, client).await?;
|
let response: surf::Response = next.run(req, client).await?;
|
||||||
log::debug!("Response: {:?}", response);
|
log::debug!("Response: {:?}", response);
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAIClient {
|
#[cfg(feature = "async")]
|
||||||
/// Creates a new `OpenAIClient` given an api token
|
fn async_client(token: &str, base_url: &str) -> surf::Client {
|
||||||
|
let mut async_client = surf::client();
|
||||||
|
async_client.set_base_url(surf::Url::parse(base_url).expect("Static string should parse"));
|
||||||
|
async_client.with(BearerToken::new(token))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
fn sync_client(token: &str) -> ureq::Agent {
|
||||||
|
ureq::agent().auth_kind("Bearer", token).build()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Client object. Must be constructed to talk to the API.
|
||||||
|
pub struct Client {
|
||||||
|
#[cfg(feature = "async")]
|
||||||
|
async_client: surf::Client,
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
sync_client: ureq::Agent,
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
base_url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Client {
|
||||||
|
// Creates a new `Client` given an api token
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(token: &str) -> Self {
|
pub fn new(token: &str) -> Self {
|
||||||
let mut client = surf::client();
|
let base_url = String::from("https://api.openai.com/v1/");
|
||||||
client.set_base_url(
|
Self {
|
||||||
surf::Url::parse("https://api.openai.com/v1/").expect("Static string should parse"),
|
#[cfg(feature = "async")]
|
||||||
);
|
async_client: async_client(token, &base_url),
|
||||||
client = client.with(BearerToken::new(token));
|
#[cfg(feature = "sync")]
|
||||||
Self { client }
|
sync_client: sync_client(token),
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
base_url,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allow setting the api root in the tests
|
// Allow setting the api root in the tests
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn set_api_root(&mut self, url: surf::Url) {
|
fn set_api_root(mut self, base_url: &str) -> Self {
|
||||||
self.client.set_base_url(url);
|
#[cfg(feature = "async")]
|
||||||
|
{
|
||||||
|
self.async_client.set_base_url(
|
||||||
|
surf::Url::parse(base_url).expect("static URL expected to parse correctly"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
{
|
||||||
|
self.base_url = String::from(base_url);
|
||||||
|
}
|
||||||
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Private helper for making gets
|
/// Private helper for making gets
|
||||||
async fn get<T: serde::de::DeserializeOwned>(&self, endpoint: &str) -> Result<T> {
|
#[cfg(feature = "async")]
|
||||||
let mut response = self.client.get(endpoint).await?;
|
async fn get<T>(&self, endpoint: &str) -> Result<T>
|
||||||
|
where
|
||||||
|
T: serde::de::DeserializeOwned,
|
||||||
|
{
|
||||||
|
let mut response = self.async_client.get(endpoint).await?;
|
||||||
if let surf::StatusCode::Ok = response.status() {
|
if let surf::StatusCode::Ok = response.status() {
|
||||||
Ok(response.body_json::<T>().await?)
|
Ok(response.body_json::<T>().await?)
|
||||||
} else {
|
} else {
|
||||||
{
|
let err = response.body_json::<api::ErrorWrapper>().await?.error;
|
||||||
let err = response.body_json::<api::ErrorWrapper>().await?.error;
|
Err(Error::APIError(err))
|
||||||
Err(OpenAIError::APIError(err))
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
fn get_sync<T>(&self, endpoint: &str) -> Result<T>
|
||||||
|
where
|
||||||
|
T: serde::de::DeserializeOwned,
|
||||||
|
{
|
||||||
|
let response = dbg!(self
|
||||||
|
.sync_client
|
||||||
|
.get(&format!("{}{}", self.base_url, endpoint)))
|
||||||
|
.call();
|
||||||
|
if let 200 = response.status() {
|
||||||
|
Ok(response
|
||||||
|
.into_json_deserialize()
|
||||||
|
.expect("Bug: client couldn't deserialize api response"))
|
||||||
|
} else {
|
||||||
|
let err = response
|
||||||
|
.into_json_deserialize::<api::ErrorWrapper>()
|
||||||
|
.expect("Bug: client couldn't deserialize api error response")
|
||||||
|
.error;
|
||||||
|
Err(Error::APIError(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -319,35 +407,55 @@ impl OpenAIClient {
|
|||||||
/// Provides basic information about each one such as the owner and availability.
|
/// Provides basic information about each one such as the owner and availability.
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
/// - `OpenAIError::APIError` if the server returns an error
|
/// - `Error::APIError` if the server returns an error
|
||||||
|
#[cfg(feature = "async")]
|
||||||
pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
|
pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
|
||||||
self.get("engines").await.map(|r: api::Container<_>| r.data)
|
self.get("engines").await.map(|r: api::Container<_>| r.data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Lists the currently available engines.
|
||||||
|
///
|
||||||
|
/// Provides basic information about each one such as the owner and availability.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
/// - `Error::APIError` if the server returns an error
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
pub fn engines_sync(&self) -> Result<Vec<api::EngineInfo>> {
|
||||||
|
self.get_sync("engines").map(|r: api::Container<_>| r.data)
|
||||||
|
}
|
||||||
|
|
||||||
/// Retrieves an engine instance
|
/// Retrieves an engine instance
|
||||||
|
///
|
||||||
/// Provides basic information about the engine such as the owner and availability.
|
/// Provides basic information about the engine such as the owner and availability.
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
/// - `OpenAIError::APIError` if the server returns an error
|
/// - `Error::APIError` if the server returns an error
|
||||||
|
#[cfg(feature = "async")]
|
||||||
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
|
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
|
||||||
self.get(&format!("engines/{}", engine)).await
|
self.get(&format!("engines/{}", engine)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
pub fn engine_sync(&self, engine: api::Engine) -> Result<api::EngineInfo> {
|
||||||
|
self.get_sync(&format!("engines/{}", engine))
|
||||||
|
}
|
||||||
|
|
||||||
// Private helper to generate post requests. Needs to be a bit more flexible than
|
// Private helper to generate post requests. Needs to be a bit more flexible than
|
||||||
// get because it should support SSE eventually
|
// get because it should support SSE eventually
|
||||||
|
#[cfg(feature = "async")]
|
||||||
async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
|
async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
|
||||||
where
|
where
|
||||||
B: serde::ser::Serialize,
|
B: serde::ser::Serialize,
|
||||||
R: serde::de::DeserializeOwned,
|
R: serde::de::DeserializeOwned,
|
||||||
{
|
{
|
||||||
let mut response = self
|
let mut response = self
|
||||||
.client
|
.async_client
|
||||||
.post(endpoint)
|
.post(endpoint)
|
||||||
.body(surf::Body::from_json(&body)?)
|
.body(surf::Body::from_json(&body)?)
|
||||||
.await?;
|
.await?;
|
||||||
match response.status() {
|
match response.status() {
|
||||||
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
|
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
|
||||||
_ => Err(OpenAIError::APIError(
|
_ => Err(Error::APIError(
|
||||||
response
|
response
|
||||||
.body_json::<api::ErrorWrapper>()
|
.body_json::<api::ErrorWrapper>()
|
||||||
.await
|
.await
|
||||||
@ -356,11 +464,38 @@ impl OpenAIClient {
|
|||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
fn post_sync<B, R>(&self, endpoint: &str, body: B) -> Result<R>
|
||||||
|
where
|
||||||
|
B: serde::ser::Serialize,
|
||||||
|
R: serde::de::DeserializeOwned,
|
||||||
|
{
|
||||||
|
let response = self
|
||||||
|
.sync_client
|
||||||
|
.post(&format!("{}{}", self.base_url, endpoint))
|
||||||
|
.send_json(
|
||||||
|
serde_json::to_value(body).expect("Bug: client couldn't serialize its own type"),
|
||||||
|
);
|
||||||
|
match response.status() {
|
||||||
|
200 => Ok(response
|
||||||
|
.into_json_deserialize()
|
||||||
|
.expect("Bug: client couldn't deserialize api response")),
|
||||||
|
_ => Err(Error::APIError(
|
||||||
|
response
|
||||||
|
.into_json_deserialize::<api::ErrorWrapper>()
|
||||||
|
.expect("Bug: client couldn't deserialize api error response")
|
||||||
|
.error,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get predicted completion of the prompt
|
/// Get predicted completion of the prompt
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
/// - `OpenAIError::APIError` if the api returns an error
|
/// - `Error::APIError` if the api returns an error
|
||||||
pub async fn complete(
|
#[cfg(feature = "async")]
|
||||||
|
pub async fn complete_prompt(
|
||||||
&self,
|
&self,
|
||||||
prompt: impl Into<api::CompletionArgs>,
|
prompt: impl Into<api::CompletionArgs>,
|
||||||
) -> Result<api::Completion> {
|
) -> Result<api::Completion> {
|
||||||
@ -369,20 +504,60 @@ impl OpenAIClient {
|
|||||||
.post(&format!("engines/{}/completions", args.engine), args)
|
.post(&format!("engines/{}/completions", args.engine), args)
|
||||||
.await?)
|
.await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get predicted completion of the prompt synchronously
|
||||||
|
///
|
||||||
|
/// # Error
|
||||||
|
/// - `Error::APIError` if the api returns an error
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
pub fn complete_prompt_sync(
|
||||||
|
&self,
|
||||||
|
prompt: impl Into<api::CompletionArgs>,
|
||||||
|
) -> Result<api::Completion> {
|
||||||
|
let args = prompt.into();
|
||||||
|
self.post_sync(&format!("engines/{}/completions", args.engine), args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: add a macro to de-boilerplate the sync and async tests
|
||||||
|
|
||||||
|
#[allow(unused_macros)]
|
||||||
|
macro_rules! async_test {
|
||||||
|
($test_name: ident, $test_body: block) => {
|
||||||
|
#[cfg(feature = "async")]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn $test_name() -> crate::Result<()> {
|
||||||
|
$test_body;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused_macros)]
|
||||||
|
macro_rules! sync_test {
|
||||||
|
($test_name: ident, $test_body: expr) => {
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
#[test]
|
||||||
|
fn $test_name() -> crate::Result<()> {
|
||||||
|
$test_body;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod unit {
|
mod unit {
|
||||||
|
|
||||||
use crate::{api, OpenAIClient, OpenAIError};
|
use mockito::Mock;
|
||||||
|
|
||||||
fn mocked_client() -> OpenAIClient {
|
use crate::{
|
||||||
|
api::{self, Completion, CompletionArgs, Engine, EngineInfo},
|
||||||
|
Client, Error,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn mocked_client() -> Client {
|
||||||
let _ = env_logger::builder().is_test(true).try_init();
|
let _ = env_logger::builder().is_test(true).try_init();
|
||||||
let mut client = OpenAIClient::new("bogus");
|
Client::new("bogus").set_api_root(&format!("{}/", mockito::server_url()))
|
||||||
client.set_api_root(
|
|
||||||
surf::Url::parse(&mockito::server_url()).expect("mockito url didn't parse"),
|
|
||||||
);
|
|
||||||
client
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -410,10 +585,8 @@ mod unit {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
fn mock_engines() -> (Mock, Vec<EngineInfo>) {
|
||||||
async fn parse_engines() -> crate::Result<()> {
|
let mock = mockito::mock("GET", "/engines")
|
||||||
use api::{Engine, EngineInfo};
|
|
||||||
let _m = mockito::mock("GET", "/engines")
|
|
||||||
.with_status(200)
|
.with_status(200)
|
||||||
.with_header("content-type", "application/json")
|
.with_header("content-type", "application/json")
|
||||||
.with_body(
|
.with_body(
|
||||||
@ -460,6 +633,7 @@ mod unit {
|
|||||||
}"#,
|
}"#,
|
||||||
)
|
)
|
||||||
.create();
|
.create();
|
||||||
|
|
||||||
let expected = vec![
|
let expected = vec![
|
||||||
EngineInfo {
|
EngineInfo {
|
||||||
id: Engine::Ada,
|
id: Engine::Ada,
|
||||||
@ -492,40 +666,59 @@ mod unit {
|
|||||||
ready: true,
|
ready: true,
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
let response = mocked_client().engines().await?;
|
(mock, expected)
|
||||||
assert_eq!(response, expected);
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
async_test!(parse_engines_async, {
|
||||||
async fn engine_error_response() -> crate::Result<()> {
|
let (_m, expected) = mock_engines();
|
||||||
let _m = mockito::mock("GET", "/engines/davinci")
|
let response = mocked_client().engines().await?;
|
||||||
|
assert_eq!(response, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
sync_test!(parse_engines_sync, {
|
||||||
|
let (_m, expected) = mock_engines();
|
||||||
|
let response = mocked_client().engines_sync()?;
|
||||||
|
assert_eq!(response, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
fn mock_engine() -> (Mock, api::ErrorMessage) {
|
||||||
|
let mock = mockito::mock("GET", "/engines/davinci")
|
||||||
.with_status(404)
|
.with_status(404)
|
||||||
.with_header("content-type", "application/json")
|
.with_header("content-type", "application/json")
|
||||||
.with_body(
|
.with_body(
|
||||||
r#"{
|
r#"{
|
||||||
"error": {
|
"error": {
|
||||||
"code": null,
|
"code": null,
|
||||||
"message": "Some kind of error happened",
|
"message": "Some kind of error happened",
|
||||||
"type": "some_error_type"
|
"type": "some_error_type"
|
||||||
}
|
}
|
||||||
}"#,
|
}"#,
|
||||||
)
|
)
|
||||||
.create();
|
.create();
|
||||||
let expected = api::ErrorMessage {
|
let expected = api::ErrorMessage {
|
||||||
message: "Some kind of error happened".into(),
|
message: "Some kind of error happened".into(),
|
||||||
error_type: "some_error_type".into(),
|
error_type: "some_error_type".into(),
|
||||||
};
|
};
|
||||||
let response = mocked_client().engine(api::Engine::Davinci).await;
|
(mock, expected)
|
||||||
if let Result::Err(OpenAIError::APIError(msg)) = response {
|
|
||||||
assert_eq!(expected, msg);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
async_test!(engine_error_response_async, {
|
||||||
async fn completion_args() -> crate::Result<()> {
|
let (_m, expected) = mock_engine();
|
||||||
let _m = mockito::mock("POST", "/engines/davinci/completions")
|
let response = mocked_client().engine(api::Engine::Davinci).await;
|
||||||
|
if let Result::Err(Error::APIError(msg)) = response {
|
||||||
|
assert_eq!(expected, msg);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
sync_test!(engine_error_response_sync, {
|
||||||
|
let (_m, expected) = mock_engine();
|
||||||
|
let response = mocked_client().engine_sync(api::Engine::Davinci);
|
||||||
|
if let Result::Err(Error::APIError(msg)) = response {
|
||||||
|
assert_eq!(expected, msg);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
fn mock_completion() -> crate::Result<(Mock, CompletionArgs, Completion)> {
|
||||||
|
let mock = mockito::mock("POST", "/engines/davinci/completions")
|
||||||
.with_status(200)
|
.with_status(200)
|
||||||
.with_header("content-type", "application/json")
|
.with_header("content-type", "application/json")
|
||||||
.with_body(
|
.with_body(
|
||||||
@ -544,7 +737,17 @@ mod unit {
|
|||||||
]
|
]
|
||||||
}"#,
|
}"#,
|
||||||
)
|
)
|
||||||
|
.expect(1)
|
||||||
.create();
|
.create();
|
||||||
|
let args = api::CompletionArgs::builder()
|
||||||
|
.engine(api::Engine::Davinci)
|
||||||
|
.prompt("Once upon a time")
|
||||||
|
.max_tokens(5)
|
||||||
|
.temperature(1.0)
|
||||||
|
.top_p(1.0)
|
||||||
|
.n(1)
|
||||||
|
.stop(vec!["\n".into()])
|
||||||
|
.build()?;
|
||||||
let expected = api::Completion {
|
let expected = api::Completion {
|
||||||
id: "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7".into(),
|
id: "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7".into(),
|
||||||
created: 1589478378,
|
created: 1589478378,
|
||||||
@ -556,54 +759,75 @@ mod unit {
|
|||||||
finish_reason: api::FinishReason::MaxTokensReached,
|
finish_reason: api::FinishReason::MaxTokensReached,
|
||||||
}],
|
}],
|
||||||
};
|
};
|
||||||
let args = api::CompletionArgs::builder()
|
Ok((mock, args, expected))
|
||||||
.engine(api::Engine::Davinci)
|
|
||||||
.prompt("Once upon a time")
|
|
||||||
.max_tokens(5)
|
|
||||||
.temperature(1.0)
|
|
||||||
.top_p(1.0)
|
|
||||||
.n(1)
|
|
||||||
.stop(vec!["\n".into()])
|
|
||||||
.build()?;
|
|
||||||
let response = mocked_client().complete(args).await?;
|
|
||||||
assert_eq!(response.model, expected.model);
|
|
||||||
assert_eq!(response.id, expected.id);
|
|
||||||
assert_eq!(response.created, expected.created);
|
|
||||||
let (resp_choice, expected_choice) = (&response.choices[0], &expected.choices[0]);
|
|
||||||
assert_eq!(resp_choice.text, expected_choice.text);
|
|
||||||
assert_eq!(resp_choice.index, expected_choice.index);
|
|
||||||
assert!(resp_choice.logprobs.is_none());
|
|
||||||
assert_eq!(resp_choice.finish_reason, expected_choice.finish_reason);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Defines boilerplate here. The Completion can't derive Eq since it contains
|
||||||
|
// floats in various places.
|
||||||
|
fn assert_completion_equal(a: Completion, b: Completion) {
|
||||||
|
assert_eq!(a.model, b.model);
|
||||||
|
assert_eq!(a.id, b.id);
|
||||||
|
assert_eq!(a.created, b.created);
|
||||||
|
let (a_choice, b_choice) = (&a.choices[0], &b.choices[0]);
|
||||||
|
assert_eq!(a_choice.text, b_choice.text);
|
||||||
|
assert_eq!(a_choice.index, b_choice.index);
|
||||||
|
assert!(a_choice.logprobs.is_none());
|
||||||
|
assert_eq!(a_choice.finish_reason, b_choice.finish_reason);
|
||||||
|
}
|
||||||
|
|
||||||
|
async_test!(completion_args_async, {
|
||||||
|
let (m, args, expected) = mock_completion()?;
|
||||||
|
let response = mocked_client().complete_prompt(args).await?;
|
||||||
|
assert_completion_equal(response, expected);
|
||||||
|
m.assert();
|
||||||
|
});
|
||||||
|
|
||||||
|
sync_test!(completion_args_sync, {
|
||||||
|
let (m, args, expected) = mock_completion()?;
|
||||||
|
let response = mocked_client().complete_prompt_sync(args)?;
|
||||||
|
assert_completion_equal(response, expected);
|
||||||
|
m.assert();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod integration {
|
mod integration {
|
||||||
|
use crate::{
|
||||||
use api::ErrorMessage;
|
api::{self, Completion},
|
||||||
|
Client, Error,
|
||||||
use crate::{api, OpenAIClient, OpenAIError};
|
};
|
||||||
/// Used by tests to get a client to the actual api
|
/// Used by tests to get a client to the actual api
|
||||||
fn get_client() -> OpenAIClient {
|
fn get_client() -> Client {
|
||||||
let _ = env_logger::builder().is_test(true).try_init();
|
let _ = env_logger::builder().is_test(true).try_init();
|
||||||
let sk = std::env::var("OPENAI_SK").expect(
|
let sk = std::env::var("OPENAI_SK").expect(
|
||||||
"To run integration tests, you must put set the OPENAI_SK env var to your api token",
|
"To run integration tests, you must put set the OPENAI_SK env var to your api token",
|
||||||
);
|
);
|
||||||
OpenAIClient::new(&sk)
|
Client::new(&sk)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
async_test!(can_get_engines_async, {
|
||||||
async fn can_get_engines() {
|
|
||||||
let client = get_client();
|
let client = get_client();
|
||||||
client.engines().await.unwrap();
|
client.engines().await?
|
||||||
}
|
});
|
||||||
#[tokio::test]
|
|
||||||
async fn can_get_engine() {
|
sync_test!(can_get_engines_sync, {
|
||||||
let client = get_client();
|
let client = get_client();
|
||||||
let result = client.engine(api::Engine::Ada).await;
|
let engines = client
|
||||||
|
.engines_sync()?
|
||||||
|
.into_iter()
|
||||||
|
.map(|ei| ei.id)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
assert!(engines.contains(&api::Engine::Ada));
|
||||||
|
assert!(engines.contains(&api::Engine::Babbage));
|
||||||
|
assert!(engines.contains(&api::Engine::Curie));
|
||||||
|
assert!(engines.contains(&api::Engine::Davinci));
|
||||||
|
});
|
||||||
|
|
||||||
|
fn assert_expected_engine_failure<T>(result: Result<T, Error>)
|
||||||
|
where
|
||||||
|
T: std::fmt::Debug,
|
||||||
|
{
|
||||||
match result {
|
match result {
|
||||||
Err(OpenAIError::APIError(ErrorMessage {
|
Err(Error::APIError(api::ErrorMessage {
|
||||||
message,
|
message,
|
||||||
error_type,
|
error_type,
|
||||||
})) => {
|
})) => {
|
||||||
@ -615,18 +839,28 @@ mod integration {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
async_test!(can_get_engine_async, {
|
||||||
#[tokio::test]
|
|
||||||
async fn complete_string() -> crate::Result<()> {
|
|
||||||
let client = get_client();
|
let client = get_client();
|
||||||
client.complete("Hey there").await?;
|
assert_expected_engine_failure(client.engine(api::Engine::Ada).await);
|
||||||
Ok(())
|
});
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
sync_test!(can_get_engine_sync, {
|
||||||
async fn complete_explicit_params() -> crate::Result<()> {
|
|
||||||
let client = get_client();
|
let client = get_client();
|
||||||
let args = api::CompletionArgsBuilder::default()
|
assert_expected_engine_failure(client.engine_sync(api::Engine::Ada));
|
||||||
|
});
|
||||||
|
|
||||||
|
async_test!(complete_string_async, {
|
||||||
|
let client = get_client();
|
||||||
|
client.complete_prompt("Hey there").await?;
|
||||||
|
});
|
||||||
|
|
||||||
|
sync_test!(complete_string_sync, {
|
||||||
|
let client = get_client();
|
||||||
|
client.complete_prompt_sync("Hey there")?;
|
||||||
|
});
|
||||||
|
|
||||||
|
fn create_args() -> api::CompletionArgs {
|
||||||
|
api::CompletionArgsBuilder::default()
|
||||||
.prompt("Once upon a time,")
|
.prompt("Once upon a time,")
|
||||||
.max_tokens(10)
|
.max_tokens(10)
|
||||||
.temperature(0.5)
|
.temperature(0.5)
|
||||||
@ -642,28 +876,52 @@ mod integration {
|
|||||||
"23".into() => 0.0,
|
"23".into() => 0.0,
|
||||||
})
|
})
|
||||||
.build()
|
.build()
|
||||||
.expect("Build should have succeeded");
|
.expect("Bug: build should succeed")
|
||||||
client.complete(args).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
async_test!(complete_explicit_params_async, {
|
||||||
#[tokio::test]
|
|
||||||
async fn complete_stop_condition() -> crate::Result<()> {
|
|
||||||
let client = get_client();
|
let client = get_client();
|
||||||
|
let args = create_args();
|
||||||
|
client.complete_prompt(args).await?;
|
||||||
|
});
|
||||||
|
|
||||||
|
sync_test!(complete_explicit_params_sync, {
|
||||||
|
let client = get_client();
|
||||||
|
let args = create_args();
|
||||||
|
client.complete_prompt_sync(args)?
|
||||||
|
});
|
||||||
|
|
||||||
|
fn stop_condition_args() -> api::CompletionArgs {
|
||||||
let mut args = api::CompletionArgs::builder();
|
let mut args = api::CompletionArgs::builder();
|
||||||
let completion = args
|
args.prompt(
|
||||||
.prompt(
|
r#"
|
||||||
r#"
|
|
||||||
Q: Please type `#` now
|
Q: Please type `#` now
|
||||||
A:"#,
|
A:"#,
|
||||||
)
|
)
|
||||||
// turn temp & top_p way down to prevent test flakiness
|
// turn temp & top_p way down to prevent test flakiness
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.top_p(0.0)
|
.top_p(0.0)
|
||||||
.max_tokens(100)
|
.max_tokens(100)
|
||||||
.stop(vec!["#".into(), "\n".into()])
|
.stop(vec!["#".into(), "\n".into()])
|
||||||
.complete(&client).await?;
|
.build()
|
||||||
assert_eq!(completion.choices[0].finish_reason, api::FinishReason::StopSequenceReached);
|
.expect("Bug: build should succeed")
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn assert_completion_finish_reason(completion: Completion) {
|
||||||
|
assert_eq!(
|
||||||
|
completion.choices[0].finish_reason,
|
||||||
|
api::FinishReason::StopSequenceReached
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
async_test!(complete_stop_condition_async, {
|
||||||
|
let client = get_client();
|
||||||
|
let args = stop_condition_args();
|
||||||
|
assert_completion_finish_reason(client.complete_prompt(args).await?);
|
||||||
|
});
|
||||||
|
|
||||||
|
sync_test!(complete_stop_condition_sync, {
|
||||||
|
let client = get_client();
|
||||||
|
let args = stop_condition_args();
|
||||||
|
assert_completion_finish_reason(client.complete_prompt_sync(args)?);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user