Immutable setter (#11)
Immutable builder makes things a bit less awkward
This commit is contained in:
parent
c13b019ee9
commit
68a7b500f5
@ -10,8 +10,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let api_token = std::env::var("OPENAI_SK")?;
|
let api_token = std::env::var("OPENAI_SK")?;
|
||||||
let client = Client::new(&api_token);
|
let client = Client::new(&api_token);
|
||||||
let mut context = String::from(START_PROMPT);
|
let mut context = String::from(START_PROMPT);
|
||||||
let mut args = CompletionArgs::builder();
|
let args = CompletionArgs::builder()
|
||||||
args.engine("davinci")
|
.engine("davinci")
|
||||||
.max_tokens(45)
|
.max_tokens(45)
|
||||||
.stop(vec!["\n".into()])
|
.stop(vec!["\n".into()])
|
||||||
.top_p(0.5)
|
.top_p(0.5)
|
||||||
@ -25,7 +25,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
context.push_str("\nAI: ");
|
context.push_str("\nAI: ");
|
||||||
match args.prompt(context.as_str()).complete_prompt(&client).await {
|
match client
|
||||||
|
.complete_prompt(args.prompt(context.as_str()).build()?)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(completion) => {
|
Ok(completion) => {
|
||||||
println!("\x1b[1;36m{}\x1b[1;0m", completion);
|
println!("\x1b[1;36m{}\x1b[1;0m", completion);
|
||||||
context.push_str(&completion.choices[0].text);
|
context.push_str(&completion.choices[0].text);
|
||||||
|
83
src/lib.rs
83
src/lib.rs
@ -11,7 +11,6 @@ pub mod api {
|
|||||||
//! Data types corresponding to requests and responses from the API
|
//! Data types corresponding to requests and responses from the API
|
||||||
use std::{collections::HashMap, fmt::Display};
|
use std::{collections::HashMap, fmt::Display};
|
||||||
|
|
||||||
use super::Client;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Container type. Used in the api, but not useful for clients of this library
|
/// Container type. Used in the api, but not useful for clients of this library
|
||||||
@ -30,6 +29,7 @@ pub mod api {
|
|||||||
|
|
||||||
/// Options that affect the result
|
/// Options that affect the result
|
||||||
#[derive(Serialize, Debug, Builder, Clone)]
|
#[derive(Serialize, Debug, Builder, Clone)]
|
||||||
|
#[builder(pattern = "immutable")]
|
||||||
pub struct CompletionArgs {
|
pub struct CompletionArgs {
|
||||||
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
|
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
|
||||||
prompt: String,
|
prompt: String,
|
||||||
@ -77,37 +77,6 @@ pub mod api {
|
|||||||
pub fn builder() -> CompletionArgsBuilder {
|
pub fn builder() -> CompletionArgsBuilder {
|
||||||
CompletionArgsBuilder::default()
|
CompletionArgsBuilder::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Request a completion from the api
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
/// `Error::APIError` if the api returns an error
|
|
||||||
#[cfg(feature = "async")]
|
|
||||||
pub async fn complete_prompt(self, client: &Client) -> super::Result<Completion> {
|
|
||||||
client.complete_prompt(self).await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "sync")]
|
|
||||||
pub fn complete_prompt_sync(self, client: &Client) -> super::Result<Completion> {
|
|
||||||
client.complete_prompt_sync(self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CompletionArgsBuilder {
|
|
||||||
/// Request a completion from the api
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
/// `Error::BadArguments` if the arguments to complete are not valid
|
|
||||||
/// `Error::APIError` if the api returns an error
|
|
||||||
#[cfg(feature = "async")]
|
|
||||||
pub async fn complete_prompt(&self, client: &Client) -> super::Result<Completion> {
|
|
||||||
client.complete_prompt(self.build()?).await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "sync")]
|
|
||||||
pub fn complete_prompt_sync(&self, client: &Client) -> super::Result<Completion> {
|
|
||||||
client.complete_prompt_sync(self.build()?)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Represents a non-streamed completion response
|
/// Represents a non-streamed completion response
|
||||||
@ -183,22 +152,22 @@ pub mod api {
|
|||||||
pub enum Error {
|
pub enum Error {
|
||||||
/// An error returned by the API itself
|
/// An error returned by the API itself
|
||||||
#[error("API returned an Error: {}", .0.message)]
|
#[error("API returned an Error: {}", .0.message)]
|
||||||
APIError(api::ErrorMessage),
|
API(api::ErrorMessage),
|
||||||
/// An error the client discovers before talking to the API
|
/// An error the client discovers before talking to the API
|
||||||
#[error("Bad arguments: {0}")]
|
#[error("Bad arguments: {0}")]
|
||||||
BadArguments(String),
|
BadArguments(String),
|
||||||
/// Network / protocol related errors
|
/// Network / protocol related errors
|
||||||
#[cfg(feature = "async")]
|
#[cfg(feature = "async")]
|
||||||
#[error("Error at the protocol level: {0}")]
|
#[error("Error at the protocol level: {0}")]
|
||||||
AsyncProtocolError(surf::Error),
|
AsyncProtocol(surf::Error),
|
||||||
#[cfg(feature = "sync")]
|
#[cfg(feature = "sync")]
|
||||||
#[error("Error at the protocol level, sync client")]
|
#[error("Error at the protocol level, sync client")]
|
||||||
SyncProtocolError(ureq::Error),
|
SyncProtocol(ureq::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<api::ErrorMessage> for Error {
|
impl From<api::ErrorMessage> for Error {
|
||||||
fn from(e: api::ErrorMessage) -> Self {
|
fn from(e: api::ErrorMessage) -> Self {
|
||||||
Error::APIError(e)
|
Error::API(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -211,14 +180,14 @@ impl From<String> for Error {
|
|||||||
#[cfg(feature = "async")]
|
#[cfg(feature = "async")]
|
||||||
impl From<surf::Error> for Error {
|
impl From<surf::Error> for Error {
|
||||||
fn from(e: surf::Error) -> Self {
|
fn from(e: surf::Error) -> Self {
|
||||||
Error::AsyncProtocolError(e)
|
Error::AsyncProtocol(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "sync")]
|
#[cfg(feature = "sync")]
|
||||||
impl From<ureq::Error> for Error {
|
impl From<ureq::Error> for Error {
|
||||||
fn from(e: ureq::Error) -> Self {
|
fn from(e: ureq::Error) -> Self {
|
||||||
Error::SyncProtocolError(e)
|
Error::SyncProtocol(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -293,7 +262,7 @@ impl Client {
|
|||||||
// Creates a new `Client` given an api token
|
// Creates a new `Client` given an api token
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(token: &str) -> Self {
|
pub fn new(token: &str) -> Self {
|
||||||
let base_url = String::from("https://api.openai.com/v1/");
|
let base_url: String = "https://api.openai.com/v1/".into();
|
||||||
Self {
|
Self {
|
||||||
#[cfg(feature = "async")]
|
#[cfg(feature = "async")]
|
||||||
async_client: async_client(token, &base_url),
|
async_client: async_client(token, &base_url),
|
||||||
@ -331,7 +300,7 @@ impl Client {
|
|||||||
Ok(response.body_json::<T>().await?)
|
Ok(response.body_json::<T>().await?)
|
||||||
} else {
|
} else {
|
||||||
let err = response.body_json::<api::ErrorWrapper>().await?.error;
|
let err = response.body_json::<api::ErrorWrapper>().await?.error;
|
||||||
Err(Error::APIError(err))
|
Err(Error::API(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -353,7 +322,7 @@ impl Client {
|
|||||||
.into_json_deserialize::<api::ErrorWrapper>()
|
.into_json_deserialize::<api::ErrorWrapper>()
|
||||||
.expect("Bug: client couldn't deserialize api error response")
|
.expect("Bug: client couldn't deserialize api error response")
|
||||||
.error;
|
.error;
|
||||||
Err(Error::APIError(err))
|
Err(Error::API(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -410,7 +379,7 @@ impl Client {
|
|||||||
.await?;
|
.await?;
|
||||||
match response.status() {
|
match response.status() {
|
||||||
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
|
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
|
||||||
_ => Err(Error::APIError(
|
_ => Err(Error::API(
|
||||||
response
|
response
|
||||||
.body_json::<api::ErrorWrapper>()
|
.body_json::<api::ErrorWrapper>()
|
||||||
.await
|
.await
|
||||||
@ -436,7 +405,7 @@ impl Client {
|
|||||||
200 => Ok(response
|
200 => Ok(response
|
||||||
.into_json_deserialize()
|
.into_json_deserialize()
|
||||||
.expect("Bug: client couldn't deserialize api response")),
|
.expect("Bug: client couldn't deserialize api response")),
|
||||||
_ => Err(Error::APIError(
|
_ => Err(Error::API(
|
||||||
response
|
response
|
||||||
.into_json_deserialize::<api::ErrorWrapper>()
|
.into_json_deserialize::<api::ErrorWrapper>()
|
||||||
.expect("Bug: client couldn't deserialize api error response")
|
.expect("Bug: client couldn't deserialize api error response")
|
||||||
@ -660,7 +629,7 @@ mod unit {
|
|||||||
async_test!(engine_error_response_async, {
|
async_test!(engine_error_response_async, {
|
||||||
let (_m, expected) = mock_engine();
|
let (_m, expected) = mock_engine();
|
||||||
let response = mocked_client().engine("davinci").await;
|
let response = mocked_client().engine("davinci").await;
|
||||||
if let Result::Err(Error::APIError(msg)) = response {
|
if let Result::Err(Error::API(msg)) = response {
|
||||||
assert_eq!(expected, msg);
|
assert_eq!(expected, msg);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -668,7 +637,7 @@ mod unit {
|
|||||||
sync_test!(engine_error_response_sync, {
|
sync_test!(engine_error_response_sync, {
|
||||||
let (_m, expected) = mock_engine();
|
let (_m, expected) = mock_engine();
|
||||||
let response = mocked_client().engine_sync("davinci");
|
let response = mocked_client().engine_sync("davinci");
|
||||||
if let Result::Err(Error::APIError(msg)) = response {
|
if let Result::Err(Error::API(msg)) = response {
|
||||||
assert_eq!(expected, msg);
|
assert_eq!(expected, msg);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -782,7 +751,7 @@ mod integration {
|
|||||||
T: std::fmt::Debug,
|
T: std::fmt::Debug,
|
||||||
{
|
{
|
||||||
match result {
|
match result {
|
||||||
Err(Error::APIError(api::ErrorMessage {
|
Err(Error::API(api::ErrorMessage {
|
||||||
message,
|
message,
|
||||||
error_type,
|
error_type,
|
||||||
})) => {
|
})) => {
|
||||||
@ -846,19 +815,19 @@ mod integration {
|
|||||||
});
|
});
|
||||||
|
|
||||||
fn stop_condition_args() -> api::CompletionArgs {
|
fn stop_condition_args() -> api::CompletionArgs {
|
||||||
let mut args = api::CompletionArgs::builder();
|
api::CompletionArgs::builder()
|
||||||
args.prompt(
|
.prompt(
|
||||||
r#"
|
r#"
|
||||||
Q: Please type `#` now
|
Q: Please type `#` now
|
||||||
A:"#,
|
A:"#,
|
||||||
)
|
)
|
||||||
// turn temp & top_p way down to prevent test flakiness
|
// turn temp & top_p way down to prevent test flakiness
|
||||||
.temperature(0.0)
|
.temperature(0.0)
|
||||||
.top_p(0.0)
|
.top_p(0.0)
|
||||||
.max_tokens(100)
|
.max_tokens(100)
|
||||||
.stop(vec!["#".into(), "\n".into()])
|
.stop(vec!["#".into(), "\n".into()])
|
||||||
.build()
|
.build()
|
||||||
.expect("Bug: build should succeed")
|
.expect("Bug: build should succeed")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn assert_completion_finish_reason(completion: Completion) {
|
fn assert_completion_finish_reason(completion: Completion) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user