Immutable setter (#11)

Immutable builder makes things a bit less awkward
This commit is contained in:
Josh Kuhn 2021-01-05 22:33:49 -08:00 committed by GitHub
parent c13b019ee9
commit 68a7b500f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 60 deletions

View File

@ -10,8 +10,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let api_token = std::env::var("OPENAI_SK")?;
let client = Client::new(&api_token);
let mut context = String::from(START_PROMPT);
let mut args = CompletionArgs::builder();
args.engine("davinci")
let args = CompletionArgs::builder()
.engine("davinci")
.max_tokens(45)
.stop(vec!["\n".into()])
.top_p(0.5)
@ -25,7 +25,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
break;
}
context.push_str("\nAI: ");
match args.prompt(context.as_str()).complete_prompt(&client).await {
match client
.complete_prompt(args.prompt(context.as_str()).build()?)
.await
{
Ok(completion) => {
println!("\x1b[1;36m{}\x1b[1;0m", completion);
context.push_str(&completion.choices[0].text);

View File

@ -11,7 +11,6 @@ pub mod api {
//! Data types corresponding to requests and responses from the API
use std::{collections::HashMap, fmt::Display};
use super::Client;
use serde::{Deserialize, Serialize};
/// Container type. Used in the api, but not useful for clients of this library
@ -30,6 +29,7 @@ pub mod api {
/// Options that affect the result
#[derive(Serialize, Debug, Builder, Clone)]
#[builder(pattern = "immutable")]
pub struct CompletionArgs {
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
prompt: String,
@ -77,37 +77,6 @@ pub mod api {
pub fn builder() -> CompletionArgsBuilder {
CompletionArgsBuilder::default()
}
/// Request a completion from the api
///
/// # Errors
/// `Error::APIError` if the api returns an error
#[cfg(feature = "async")]
pub async fn complete_prompt(self, client: &Client) -> super::Result<Completion> {
client.complete_prompt(self).await
}
#[cfg(feature = "sync")]
pub fn complete_prompt_sync(self, client: &Client) -> super::Result<Completion> {
client.complete_prompt_sync(self)
}
}
impl CompletionArgsBuilder {
/// Request a completion from the api
///
/// # Errors
/// `Error::BadArguments` if the arguments to complete are not valid
/// `Error::APIError` if the api returns an error
#[cfg(feature = "async")]
pub async fn complete_prompt(&self, client: &Client) -> super::Result<Completion> {
client.complete_prompt(self.build()?).await
}
#[cfg(feature = "sync")]
pub fn complete_prompt_sync(&self, client: &Client) -> super::Result<Completion> {
client.complete_prompt_sync(self.build()?)
}
}
/// Represents a non-streamed completion response
@ -183,22 +152,22 @@ pub mod api {
pub enum Error {
/// An error returned by the API itself
#[error("API returned an Error: {}", .0.message)]
APIError(api::ErrorMessage),
API(api::ErrorMessage),
/// An error the client discovers before talking to the API
#[error("Bad arguments: {0}")]
BadArguments(String),
/// Network / protocol related errors
#[cfg(feature = "async")]
#[error("Error at the protocol level: {0}")]
AsyncProtocolError(surf::Error),
AsyncProtocol(surf::Error),
#[cfg(feature = "sync")]
#[error("Error at the protocol level, sync client")]
SyncProtocolError(ureq::Error),
SyncProtocol(ureq::Error),
}
impl From<api::ErrorMessage> for Error {
fn from(e: api::ErrorMessage) -> Self {
Error::APIError(e)
Error::API(e)
}
}
@ -211,14 +180,14 @@ impl From<String> for Error {
#[cfg(feature = "async")]
impl From<surf::Error> for Error {
fn from(e: surf::Error) -> Self {
Error::AsyncProtocolError(e)
Error::AsyncProtocol(e)
}
}
#[cfg(feature = "sync")]
impl From<ureq::Error> for Error {
fn from(e: ureq::Error) -> Self {
Error::SyncProtocolError(e)
Error::SyncProtocol(e)
}
}
@ -293,7 +262,7 @@ impl Client {
// Creates a new `Client` given an api token
#[must_use]
pub fn new(token: &str) -> Self {
let base_url = String::from("https://api.openai.com/v1/");
let base_url: String = "https://api.openai.com/v1/".into();
Self {
#[cfg(feature = "async")]
async_client: async_client(token, &base_url),
@ -331,7 +300,7 @@ impl Client {
Ok(response.body_json::<T>().await?)
} else {
let err = response.body_json::<api::ErrorWrapper>().await?.error;
Err(Error::APIError(err))
Err(Error::API(err))
}
}
@ -353,7 +322,7 @@ impl Client {
.into_json_deserialize::<api::ErrorWrapper>()
.expect("Bug: client couldn't deserialize api error response")
.error;
Err(Error::APIError(err))
Err(Error::API(err))
}
}
@ -410,7 +379,7 @@ impl Client {
.await?;
match response.status() {
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
_ => Err(Error::APIError(
_ => Err(Error::API(
response
.body_json::<api::ErrorWrapper>()
.await
@ -436,7 +405,7 @@ impl Client {
200 => Ok(response
.into_json_deserialize()
.expect("Bug: client couldn't deserialize api response")),
_ => Err(Error::APIError(
_ => Err(Error::API(
response
.into_json_deserialize::<api::ErrorWrapper>()
.expect("Bug: client couldn't deserialize api error response")
@ -660,7 +629,7 @@ mod unit {
async_test!(engine_error_response_async, {
let (_m, expected) = mock_engine();
let response = mocked_client().engine("davinci").await;
if let Result::Err(Error::APIError(msg)) = response {
if let Result::Err(Error::API(msg)) = response {
assert_eq!(expected, msg);
}
});
@ -668,7 +637,7 @@ mod unit {
sync_test!(engine_error_response_sync, {
let (_m, expected) = mock_engine();
let response = mocked_client().engine_sync("davinci");
if let Result::Err(Error::APIError(msg)) = response {
if let Result::Err(Error::API(msg)) = response {
assert_eq!(expected, msg);
}
});
@ -782,7 +751,7 @@ mod integration {
T: std::fmt::Debug,
{
match result {
Err(Error::APIError(api::ErrorMessage {
Err(Error::API(api::ErrorMessage {
message,
error_type,
})) => {
@ -846,8 +815,8 @@ mod integration {
});
fn stop_condition_args() -> api::CompletionArgs {
let mut args = api::CompletionArgs::builder();
args.prompt(
api::CompletionArgs::builder()
.prompt(
r#"
Q: Please type `#` now
A:"#,