Immutable setter (#11)
Immutable builder makes things a bit less awkward
This commit is contained in:
parent
c13b019ee9
commit
68a7b500f5
@ -10,8 +10,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let api_token = std::env::var("OPENAI_SK")?;
|
||||
let client = Client::new(&api_token);
|
||||
let mut context = String::from(START_PROMPT);
|
||||
let mut args = CompletionArgs::builder();
|
||||
args.engine("davinci")
|
||||
let args = CompletionArgs::builder()
|
||||
.engine("davinci")
|
||||
.max_tokens(45)
|
||||
.stop(vec!["\n".into()])
|
||||
.top_p(0.5)
|
||||
@ -25,7 +25,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
break;
|
||||
}
|
||||
context.push_str("\nAI: ");
|
||||
match args.prompt(context.as_str()).complete_prompt(&client).await {
|
||||
match client
|
||||
.complete_prompt(args.prompt(context.as_str()).build()?)
|
||||
.await
|
||||
{
|
||||
Ok(completion) => {
|
||||
println!("\x1b[1;36m{}\x1b[1;0m", completion);
|
||||
context.push_str(&completion.choices[0].text);
|
||||
|
65
src/lib.rs
65
src/lib.rs
@ -11,7 +11,6 @@ pub mod api {
|
||||
//! Data types corresponding to requests and responses from the API
|
||||
use std::{collections::HashMap, fmt::Display};
|
||||
|
||||
use super::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Container type. Used in the api, but not useful for clients of this library
|
||||
@ -30,6 +29,7 @@ pub mod api {
|
||||
|
||||
/// Options that affect the result
|
||||
#[derive(Serialize, Debug, Builder, Clone)]
|
||||
#[builder(pattern = "immutable")]
|
||||
pub struct CompletionArgs {
|
||||
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
|
||||
prompt: String,
|
||||
@ -77,37 +77,6 @@ pub mod api {
|
||||
pub fn builder() -> CompletionArgsBuilder {
|
||||
CompletionArgsBuilder::default()
|
||||
}
|
||||
|
||||
/// Request a completion from the api
|
||||
///
|
||||
/// # Errors
|
||||
/// `Error::APIError` if the api returns an error
|
||||
#[cfg(feature = "async")]
|
||||
pub async fn complete_prompt(self, client: &Client) -> super::Result<Completion> {
|
||||
client.complete_prompt(self).await
|
||||
}
|
||||
|
||||
#[cfg(feature = "sync")]
|
||||
pub fn complete_prompt_sync(self, client: &Client) -> super::Result<Completion> {
|
||||
client.complete_prompt_sync(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionArgsBuilder {
|
||||
/// Request a completion from the api
|
||||
///
|
||||
/// # Errors
|
||||
/// `Error::BadArguments` if the arguments to complete are not valid
|
||||
/// `Error::APIError` if the api returns an error
|
||||
#[cfg(feature = "async")]
|
||||
pub async fn complete_prompt(&self, client: &Client) -> super::Result<Completion> {
|
||||
client.complete_prompt(self.build()?).await
|
||||
}
|
||||
|
||||
#[cfg(feature = "sync")]
|
||||
pub fn complete_prompt_sync(&self, client: &Client) -> super::Result<Completion> {
|
||||
client.complete_prompt_sync(self.build()?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a non-streamed completion response
|
||||
@ -183,22 +152,22 @@ pub mod api {
|
||||
pub enum Error {
|
||||
/// An error returned by the API itself
|
||||
#[error("API returned an Error: {}", .0.message)]
|
||||
APIError(api::ErrorMessage),
|
||||
API(api::ErrorMessage),
|
||||
/// An error the client discovers before talking to the API
|
||||
#[error("Bad arguments: {0}")]
|
||||
BadArguments(String),
|
||||
/// Network / protocol related errors
|
||||
#[cfg(feature = "async")]
|
||||
#[error("Error at the protocol level: {0}")]
|
||||
AsyncProtocolError(surf::Error),
|
||||
AsyncProtocol(surf::Error),
|
||||
#[cfg(feature = "sync")]
|
||||
#[error("Error at the protocol level, sync client")]
|
||||
SyncProtocolError(ureq::Error),
|
||||
SyncProtocol(ureq::Error),
|
||||
}
|
||||
|
||||
impl From<api::ErrorMessage> for Error {
|
||||
fn from(e: api::ErrorMessage) -> Self {
|
||||
Error::APIError(e)
|
||||
Error::API(e)
|
||||
}
|
||||
}
|
||||
|
||||
@ -211,14 +180,14 @@ impl From<String> for Error {
|
||||
#[cfg(feature = "async")]
|
||||
impl From<surf::Error> for Error {
|
||||
fn from(e: surf::Error) -> Self {
|
||||
Error::AsyncProtocolError(e)
|
||||
Error::AsyncProtocol(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "sync")]
|
||||
impl From<ureq::Error> for Error {
|
||||
fn from(e: ureq::Error) -> Self {
|
||||
Error::SyncProtocolError(e)
|
||||
Error::SyncProtocol(e)
|
||||
}
|
||||
}
|
||||
|
||||
@ -293,7 +262,7 @@ impl Client {
|
||||
// Creates a new `Client` given an api token
|
||||
#[must_use]
|
||||
pub fn new(token: &str) -> Self {
|
||||
let base_url = String::from("https://api.openai.com/v1/");
|
||||
let base_url: String = "https://api.openai.com/v1/".into();
|
||||
Self {
|
||||
#[cfg(feature = "async")]
|
||||
async_client: async_client(token, &base_url),
|
||||
@ -331,7 +300,7 @@ impl Client {
|
||||
Ok(response.body_json::<T>().await?)
|
||||
} else {
|
||||
let err = response.body_json::<api::ErrorWrapper>().await?.error;
|
||||
Err(Error::APIError(err))
|
||||
Err(Error::API(err))
|
||||
}
|
||||
}
|
||||
|
||||
@ -353,7 +322,7 @@ impl Client {
|
||||
.into_json_deserialize::<api::ErrorWrapper>()
|
||||
.expect("Bug: client couldn't deserialize api error response")
|
||||
.error;
|
||||
Err(Error::APIError(err))
|
||||
Err(Error::API(err))
|
||||
}
|
||||
}
|
||||
|
||||
@ -410,7 +379,7 @@ impl Client {
|
||||
.await?;
|
||||
match response.status() {
|
||||
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
|
||||
_ => Err(Error::APIError(
|
||||
_ => Err(Error::API(
|
||||
response
|
||||
.body_json::<api::ErrorWrapper>()
|
||||
.await
|
||||
@ -436,7 +405,7 @@ impl Client {
|
||||
200 => Ok(response
|
||||
.into_json_deserialize()
|
||||
.expect("Bug: client couldn't deserialize api response")),
|
||||
_ => Err(Error::APIError(
|
||||
_ => Err(Error::API(
|
||||
response
|
||||
.into_json_deserialize::<api::ErrorWrapper>()
|
||||
.expect("Bug: client couldn't deserialize api error response")
|
||||
@ -660,7 +629,7 @@ mod unit {
|
||||
async_test!(engine_error_response_async, {
|
||||
let (_m, expected) = mock_engine();
|
||||
let response = mocked_client().engine("davinci").await;
|
||||
if let Result::Err(Error::APIError(msg)) = response {
|
||||
if let Result::Err(Error::API(msg)) = response {
|
||||
assert_eq!(expected, msg);
|
||||
}
|
||||
});
|
||||
@ -668,7 +637,7 @@ mod unit {
|
||||
sync_test!(engine_error_response_sync, {
|
||||
let (_m, expected) = mock_engine();
|
||||
let response = mocked_client().engine_sync("davinci");
|
||||
if let Result::Err(Error::APIError(msg)) = response {
|
||||
if let Result::Err(Error::API(msg)) = response {
|
||||
assert_eq!(expected, msg);
|
||||
}
|
||||
});
|
||||
@ -782,7 +751,7 @@ mod integration {
|
||||
T: std::fmt::Debug,
|
||||
{
|
||||
match result {
|
||||
Err(Error::APIError(api::ErrorMessage {
|
||||
Err(Error::API(api::ErrorMessage {
|
||||
message,
|
||||
error_type,
|
||||
})) => {
|
||||
@ -846,8 +815,8 @@ mod integration {
|
||||
});
|
||||
|
||||
fn stop_condition_args() -> api::CompletionArgs {
|
||||
let mut args = api::CompletionArgs::builder();
|
||||
args.prompt(
|
||||
api::CompletionArgs::builder()
|
||||
.prompt(
|
||||
r#"
|
||||
Q: Please type `#` now
|
||||
A:"#,
|
||||
|
Loading…
x
Reference in New Issue
Block a user