From 68a7b500f5a34162a4dfcacaaeb81ece09e4d82a Mon Sep 17 00:00:00 2001 From: Josh Kuhn Date: Tue, 5 Jan 2021 22:33:49 -0800 Subject: [PATCH] Immutable setter (#11) Immutable builder makes things a bit less awkward --- examples/chatloop.rs | 9 +++-- src/lib.rs | 83 ++++++++++++++------------------------------ 2 files changed, 32 insertions(+), 60 deletions(-) diff --git a/examples/chatloop.rs b/examples/chatloop.rs index 124afde..89920f4 100644 --- a/examples/chatloop.rs +++ b/examples/chatloop.rs @@ -10,8 +10,8 @@ async fn main() -> Result<(), Box> { let api_token = std::env::var("OPENAI_SK")?; let client = Client::new(&api_token); let mut context = String::from(START_PROMPT); - let mut args = CompletionArgs::builder(); - args.engine("davinci") + let args = CompletionArgs::builder() + .engine("davinci") .max_tokens(45) .stop(vec!["\n".into()]) .top_p(0.5) @@ -25,7 +25,10 @@ async fn main() -> Result<(), Box> { break; } context.push_str("\nAI: "); - match args.prompt(context.as_str()).complete_prompt(&client).await { + match client + .complete_prompt(args.prompt(context.as_str()).build()?) + .await + { Ok(completion) => { println!("\x1b[1;36m{}\x1b[1;0m", completion); context.push_str(&completion.choices[0].text); diff --git a/src/lib.rs b/src/lib.rs index 5e4f63a..6c6fbef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,6 @@ pub mod api { //! Data types corresponding to requests and responses from the API use std::{collections::HashMap, fmt::Display}; - use super::Client; use serde::{Deserialize, Serialize}; /// Container type. Used in the api, but not useful for clients of this library @@ -30,6 +29,7 @@ pub mod api { /// Options that affect the result #[derive(Serialize, Debug, Builder, Clone)] + #[builder(pattern = "immutable")] pub struct CompletionArgs { #[builder(setter(into), default = "\"<|endoftext|>\".into()")] prompt: String, @@ -77,37 +77,6 @@ pub mod api { pub fn builder() -> CompletionArgsBuilder { CompletionArgsBuilder::default() } - - /// Request a completion from the api - /// - /// # Errors - /// `Error::APIError` if the api returns an error - #[cfg(feature = "async")] - pub async fn complete_prompt(self, client: &Client) -> super::Result { - client.complete_prompt(self).await - } - - #[cfg(feature = "sync")] - pub fn complete_prompt_sync(self, client: &Client) -> super::Result { - client.complete_prompt_sync(self) - } - } - - impl CompletionArgsBuilder { - /// Request a completion from the api - /// - /// # Errors - /// `Error::BadArguments` if the arguments to complete are not valid - /// `Error::APIError` if the api returns an error - #[cfg(feature = "async")] - pub async fn complete_prompt(&self, client: &Client) -> super::Result { - client.complete_prompt(self.build()?).await - } - - #[cfg(feature = "sync")] - pub fn complete_prompt_sync(&self, client: &Client) -> super::Result { - client.complete_prompt_sync(self.build()?) - } } /// Represents a non-streamed completion response @@ -183,22 +152,22 @@ pub mod api { pub enum Error { /// An error returned by the API itself #[error("API returned an Error: {}", .0.message)] - APIError(api::ErrorMessage), + API(api::ErrorMessage), /// An error the client discovers before talking to the API #[error("Bad arguments: {0}")] BadArguments(String), /// Network / protocol related errors #[cfg(feature = "async")] #[error("Error at the protocol level: {0}")] - AsyncProtocolError(surf::Error), + AsyncProtocol(surf::Error), #[cfg(feature = "sync")] #[error("Error at the protocol level, sync client")] - SyncProtocolError(ureq::Error), + SyncProtocol(ureq::Error), } impl From for Error { fn from(e: api::ErrorMessage) -> Self { - Error::APIError(e) + Error::API(e) } } @@ -211,14 +180,14 @@ impl From for Error { #[cfg(feature = "async")] impl From for Error { fn from(e: surf::Error) -> Self { - Error::AsyncProtocolError(e) + Error::AsyncProtocol(e) } } #[cfg(feature = "sync")] impl From for Error { fn from(e: ureq::Error) -> Self { - Error::SyncProtocolError(e) + Error::SyncProtocol(e) } } @@ -293,7 +262,7 @@ impl Client { // Creates a new `Client` given an api token #[must_use] pub fn new(token: &str) -> Self { - let base_url = String::from("https://api.openai.com/v1/"); + let base_url: String = "https://api.openai.com/v1/".into(); Self { #[cfg(feature = "async")] async_client: async_client(token, &base_url), @@ -331,7 +300,7 @@ impl Client { Ok(response.body_json::().await?) } else { let err = response.body_json::().await?.error; - Err(Error::APIError(err)) + Err(Error::API(err)) } } @@ -353,7 +322,7 @@ impl Client { .into_json_deserialize::() .expect("Bug: client couldn't deserialize api error response") .error; - Err(Error::APIError(err)) + Err(Error::API(err)) } } @@ -410,7 +379,7 @@ impl Client { .await?; match response.status() { surf::StatusCode::Ok => Ok(response.body_json::().await?), - _ => Err(Error::APIError( + _ => Err(Error::API( response .body_json::() .await @@ -436,7 +405,7 @@ impl Client { 200 => Ok(response .into_json_deserialize() .expect("Bug: client couldn't deserialize api response")), - _ => Err(Error::APIError( + _ => Err(Error::API( response .into_json_deserialize::() .expect("Bug: client couldn't deserialize api error response") @@ -660,7 +629,7 @@ mod unit { async_test!(engine_error_response_async, { let (_m, expected) = mock_engine(); let response = mocked_client().engine("davinci").await; - if let Result::Err(Error::APIError(msg)) = response { + if let Result::Err(Error::API(msg)) = response { assert_eq!(expected, msg); } }); @@ -668,7 +637,7 @@ mod unit { sync_test!(engine_error_response_sync, { let (_m, expected) = mock_engine(); let response = mocked_client().engine_sync("davinci"); - if let Result::Err(Error::APIError(msg)) = response { + if let Result::Err(Error::API(msg)) = response { assert_eq!(expected, msg); } }); @@ -782,7 +751,7 @@ mod integration { T: std::fmt::Debug, { match result { - Err(Error::APIError(api::ErrorMessage { + Err(Error::API(api::ErrorMessage { message, error_type, })) => { @@ -846,19 +815,19 @@ mod integration { }); fn stop_condition_args() -> api::CompletionArgs { - let mut args = api::CompletionArgs::builder(); - args.prompt( - r#" + api::CompletionArgs::builder() + .prompt( + r#" Q: Please type `#` now A:"#, - ) - // turn temp & top_p way down to prevent test flakiness - .temperature(0.0) - .top_p(0.0) - .max_tokens(100) - .stop(vec!["#".into(), "\n".into()]) - .build() - .expect("Bug: build should succeed") + ) + // turn temp & top_p way down to prevent test flakiness + .temperature(0.0) + .top_p(0.0) + .max_tokens(100) + .stop(vec!["#".into(), "\n".into()]) + .build() + .expect("Bug: build should succeed") } fn assert_completion_finish_reason(completion: Completion) {