Immutable setter (#11)

Immutable builder makes things a bit less awkward
This commit is contained in:
Josh Kuhn 2021-01-05 22:33:49 -08:00 committed by GitHub
parent c13b019ee9
commit 68a7b500f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 60 deletions

View File

@ -10,8 +10,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let api_token = std::env::var("OPENAI_SK")?; let api_token = std::env::var("OPENAI_SK")?;
let client = Client::new(&api_token); let client = Client::new(&api_token);
let mut context = String::from(START_PROMPT); let mut context = String::from(START_PROMPT);
let mut args = CompletionArgs::builder(); let args = CompletionArgs::builder()
args.engine("davinci") .engine("davinci")
.max_tokens(45) .max_tokens(45)
.stop(vec!["\n".into()]) .stop(vec!["\n".into()])
.top_p(0.5) .top_p(0.5)
@ -25,7 +25,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
break; break;
} }
context.push_str("\nAI: "); context.push_str("\nAI: ");
match args.prompt(context.as_str()).complete_prompt(&client).await { match client
.complete_prompt(args.prompt(context.as_str()).build()?)
.await
{
Ok(completion) => { Ok(completion) => {
println!("\x1b[1;36m{}\x1b[1;0m", completion); println!("\x1b[1;36m{}\x1b[1;0m", completion);
context.push_str(&completion.choices[0].text); context.push_str(&completion.choices[0].text);

View File

@ -11,7 +11,6 @@ pub mod api {
//! Data types corresponding to requests and responses from the API //! Data types corresponding to requests and responses from the API
use std::{collections::HashMap, fmt::Display}; use std::{collections::HashMap, fmt::Display};
use super::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Container type. Used in the api, but not useful for clients of this library /// Container type. Used in the api, but not useful for clients of this library
@ -30,6 +29,7 @@ pub mod api {
/// Options that affect the result /// Options that affect the result
#[derive(Serialize, Debug, Builder, Clone)] #[derive(Serialize, Debug, Builder, Clone)]
#[builder(pattern = "immutable")]
pub struct CompletionArgs { pub struct CompletionArgs {
#[builder(setter(into), default = "\"<|endoftext|>\".into()")] #[builder(setter(into), default = "\"<|endoftext|>\".into()")]
prompt: String, prompt: String,
@ -77,37 +77,6 @@ pub mod api {
pub fn builder() -> CompletionArgsBuilder { pub fn builder() -> CompletionArgsBuilder {
CompletionArgsBuilder::default() CompletionArgsBuilder::default()
} }
/// Request a completion from the api
///
/// # Errors
/// `Error::APIError` if the api returns an error
#[cfg(feature = "async")]
pub async fn complete_prompt(self, client: &Client) -> super::Result<Completion> {
client.complete_prompt(self).await
}
#[cfg(feature = "sync")]
pub fn complete_prompt_sync(self, client: &Client) -> super::Result<Completion> {
client.complete_prompt_sync(self)
}
}
impl CompletionArgsBuilder {
/// Request a completion from the api
///
/// # Errors
/// `Error::BadArguments` if the arguments to complete are not valid
/// `Error::APIError` if the api returns an error
#[cfg(feature = "async")]
pub async fn complete_prompt(&self, client: &Client) -> super::Result<Completion> {
client.complete_prompt(self.build()?).await
}
#[cfg(feature = "sync")]
pub fn complete_prompt_sync(&self, client: &Client) -> super::Result<Completion> {
client.complete_prompt_sync(self.build()?)
}
} }
/// Represents a non-streamed completion response /// Represents a non-streamed completion response
@ -183,22 +152,22 @@ pub mod api {
pub enum Error { pub enum Error {
/// An error returned by the API itself /// An error returned by the API itself
#[error("API returned an Error: {}", .0.message)] #[error("API returned an Error: {}", .0.message)]
APIError(api::ErrorMessage), API(api::ErrorMessage),
/// An error the client discovers before talking to the API /// An error the client discovers before talking to the API
#[error("Bad arguments: {0}")] #[error("Bad arguments: {0}")]
BadArguments(String), BadArguments(String),
/// Network / protocol related errors /// Network / protocol related errors
#[cfg(feature = "async")] #[cfg(feature = "async")]
#[error("Error at the protocol level: {0}")] #[error("Error at the protocol level: {0}")]
AsyncProtocolError(surf::Error), AsyncProtocol(surf::Error),
#[cfg(feature = "sync")] #[cfg(feature = "sync")]
#[error("Error at the protocol level, sync client")] #[error("Error at the protocol level, sync client")]
SyncProtocolError(ureq::Error), SyncProtocol(ureq::Error),
} }
impl From<api::ErrorMessage> for Error { impl From<api::ErrorMessage> for Error {
fn from(e: api::ErrorMessage) -> Self { fn from(e: api::ErrorMessage) -> Self {
Error::APIError(e) Error::API(e)
} }
} }
@ -211,14 +180,14 @@ impl From<String> for Error {
#[cfg(feature = "async")] #[cfg(feature = "async")]
impl From<surf::Error> for Error { impl From<surf::Error> for Error {
fn from(e: surf::Error) -> Self { fn from(e: surf::Error) -> Self {
Error::AsyncProtocolError(e) Error::AsyncProtocol(e)
} }
} }
#[cfg(feature = "sync")] #[cfg(feature = "sync")]
impl From<ureq::Error> for Error { impl From<ureq::Error> for Error {
fn from(e: ureq::Error) -> Self { fn from(e: ureq::Error) -> Self {
Error::SyncProtocolError(e) Error::SyncProtocol(e)
} }
} }
@ -293,7 +262,7 @@ impl Client {
// Creates a new `Client` given an api token // Creates a new `Client` given an api token
#[must_use] #[must_use]
pub fn new(token: &str) -> Self { pub fn new(token: &str) -> Self {
let base_url = String::from("https://api.openai.com/v1/"); let base_url: String = "https://api.openai.com/v1/".into();
Self { Self {
#[cfg(feature = "async")] #[cfg(feature = "async")]
async_client: async_client(token, &base_url), async_client: async_client(token, &base_url),
@ -331,7 +300,7 @@ impl Client {
Ok(response.body_json::<T>().await?) Ok(response.body_json::<T>().await?)
} else { } else {
let err = response.body_json::<api::ErrorWrapper>().await?.error; let err = response.body_json::<api::ErrorWrapper>().await?.error;
Err(Error::APIError(err)) Err(Error::API(err))
} }
} }
@ -353,7 +322,7 @@ impl Client {
.into_json_deserialize::<api::ErrorWrapper>() .into_json_deserialize::<api::ErrorWrapper>()
.expect("Bug: client couldn't deserialize api error response") .expect("Bug: client couldn't deserialize api error response")
.error; .error;
Err(Error::APIError(err)) Err(Error::API(err))
} }
} }
@ -410,7 +379,7 @@ impl Client {
.await?; .await?;
match response.status() { match response.status() {
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?), surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
_ => Err(Error::APIError( _ => Err(Error::API(
response response
.body_json::<api::ErrorWrapper>() .body_json::<api::ErrorWrapper>()
.await .await
@ -436,7 +405,7 @@ impl Client {
200 => Ok(response 200 => Ok(response
.into_json_deserialize() .into_json_deserialize()
.expect("Bug: client couldn't deserialize api response")), .expect("Bug: client couldn't deserialize api response")),
_ => Err(Error::APIError( _ => Err(Error::API(
response response
.into_json_deserialize::<api::ErrorWrapper>() .into_json_deserialize::<api::ErrorWrapper>()
.expect("Bug: client couldn't deserialize api error response") .expect("Bug: client couldn't deserialize api error response")
@ -660,7 +629,7 @@ mod unit {
async_test!(engine_error_response_async, { async_test!(engine_error_response_async, {
let (_m, expected) = mock_engine(); let (_m, expected) = mock_engine();
let response = mocked_client().engine("davinci").await; let response = mocked_client().engine("davinci").await;
if let Result::Err(Error::APIError(msg)) = response { if let Result::Err(Error::API(msg)) = response {
assert_eq!(expected, msg); assert_eq!(expected, msg);
} }
}); });
@ -668,7 +637,7 @@ mod unit {
sync_test!(engine_error_response_sync, { sync_test!(engine_error_response_sync, {
let (_m, expected) = mock_engine(); let (_m, expected) = mock_engine();
let response = mocked_client().engine_sync("davinci"); let response = mocked_client().engine_sync("davinci");
if let Result::Err(Error::APIError(msg)) = response { if let Result::Err(Error::API(msg)) = response {
assert_eq!(expected, msg); assert_eq!(expected, msg);
} }
}); });
@ -782,7 +751,7 @@ mod integration {
T: std::fmt::Debug, T: std::fmt::Debug,
{ {
match result { match result {
Err(Error::APIError(api::ErrorMessage { Err(Error::API(api::ErrorMessage {
message, message,
error_type, error_type,
})) => { })) => {
@ -846,19 +815,19 @@ mod integration {
}); });
fn stop_condition_args() -> api::CompletionArgs { fn stop_condition_args() -> api::CompletionArgs {
let mut args = api::CompletionArgs::builder(); api::CompletionArgs::builder()
args.prompt( .prompt(
r#" r#"
Q: Please type `#` now Q: Please type `#` now
A:"#, A:"#,
) )
// turn temp & top_p way down to prevent test flakiness // turn temp & top_p way down to prevent test flakiness
.temperature(0.0) .temperature(0.0)
.top_p(0.0) .top_p(0.0)
.max_tokens(100) .max_tokens(100)
.stop(vec!["#".into(), "\n".into()]) .stop(vec!["#".into(), "\n".into()])
.build() .build()
.expect("Bug: build should succeed") .expect("Bug: build should succeed")
} }
fn assert_completion_finish_reason(completion: Completion) { fn assert_completion_finish_reason(completion: Completion) {