No more enums (#10)

The strings in the API change often enough that enums are more of a hindrance than a help.
This commit is contained in:
Josh Kuhn 2021-01-04 21:58:36 -08:00 committed by GitHub
parent 7d274dfcbd
commit 7b3a2ad5c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 80 deletions

View File

@ -1,7 +1,4 @@
use openai_api::{ use openai_api::{api::CompletionArgs, Client};
api::{CompletionArgs, Engine},
Client,
};
const START_PROMPT: &str = " const START_PROMPT: &str = "
The following is a conversation with an AI assistant. The following is a conversation with an AI assistant.
@ -14,7 +11,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new(&api_token); let client = Client::new(&api_token);
let mut context = String::from(START_PROMPT); let mut context = String::from(START_PROMPT);
let mut args = CompletionArgs::builder(); let mut args = CompletionArgs::builder();
args.engine(Engine::Davinci) args.engine("davinci")
.max_tokens(45) .max_tokens(45)
.stop(vec!["\n".into()]) .stop(vec!["\n".into()])
.top_p(0.5) .top_p(0.5)
@ -27,7 +24,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
eprintln!("Error: {}", e); eprintln!("Error: {}", e);
break; break;
} }
context.push_str("\nAI:"); context.push_str("\nAI: ");
match args.prompt(context.as_str()).complete_prompt(&client).await { match args.prompt(context.as_str()).complete_prompt(&client).await {
Ok(completion) => { Ok(completion) => {
println!("\x1b[1;36m{}\x1b[1;0m", completion); println!("\x1b[1;36m{}\x1b[1;0m", completion);

View File

@ -23,48 +23,19 @@ pub mod api {
/// Engine description type /// Engine description type
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)] #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct EngineInfo { pub struct EngineInfo {
pub id: Engine, pub id: String,
pub owner: String, pub owner: String,
pub ready: bool, pub ready: bool,
} }
/// Engine types, known and unknown
#[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive] // prevent clients from matching on every option
pub enum Engine {
Ada,
Babbage,
Curie,
Davinci,
#[serde(rename = "content-filter-alpha-c4")]
ContentFilter,
#[serde(other)]
Other,
}
// Custom Display to lowercase things
impl std::fmt::Display for Engine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Engine::Ada => f.write_str("ada"),
Engine::Babbage => f.write_str("babbage"),
Engine::Curie => f.write_str("curie"),
Engine::Davinci => f.write_str("davinci"),
Engine::ContentFilter => f.write_str("content-filter-alpha-c4"),
_ => panic!("Can't write out Other engine id"),
}
}
}
/// Options that affect the result /// Options that affect the result
#[derive(Serialize, Debug, Builder, Clone)] #[derive(Serialize, Debug, Builder, Clone)]
pub struct CompletionArgs { pub struct CompletionArgs {
#[builder(setter(into), default = "\"<|endoftext|>\".into()")] #[builder(setter(into), default = "\"<|endoftext|>\".into()")]
prompt: String, prompt: String,
#[builder(default = "Engine::Davinci")] #[builder(setter(into), default = "\"davinci\".into()")]
#[serde(skip_serializing)] #[serde(skip_serializing)]
pub(super) engine: Engine, pub(super) engine: String,
#[builder(default = "16")] #[builder(default = "16")]
max_tokens: u64, max_tokens: u64,
#[builder(default = "1.0")] #[builder(default = "1.0")]
@ -87,11 +58,6 @@ pub mod api {
logit_bias: HashMap<String, f64>, logit_bias: HashMap<String, f64>,
} }
/* {
"stream": false, // SSE streams back results
"best_of": Option<u64>, //cant be used with stream
}
*/
// TODO: add validators for the different arguments // TODO: add validators for the different arguments
impl From<&str> for CompletionArgs { impl From<&str> for CompletionArgs {
@ -173,7 +139,7 @@ pub mod api {
/// If requested, the log probabilities of the completion tokens /// If requested, the log probabilities of the completion tokens
pub logprobs: Option<LogProbs>, pub logprobs: Option<LogProbs>,
/// Why the completion ended when it did /// Why the completion ended when it did
pub finish_reason: FinishReason, pub finish_reason: String,
} }
impl std::fmt::Display for Choice { impl std::fmt::Display for Choice {
@ -191,18 +157,6 @@ pub mod api {
pub text_offset: Vec<u64>, pub text_offset: Vec<u64>,
} }
/// Reason a prompt completion finished.
#[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
#[non_exhaustive]
pub enum FinishReason {
/// The maximum length was reached
#[serde(rename = "length")]
MaxTokensReached,
/// The stop token was encountered
#[serde(rename = "stop")]
StopSequenceReached,
}
/// Error response object from the server /// Error response object from the server
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)] #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ErrorMessage { pub struct ErrorMessage {
@ -432,12 +386,12 @@ impl Client {
/// # Errors /// # Errors
/// - `Error::APIError` if the server returns an error /// - `Error::APIError` if the server returns an error
#[cfg(feature = "async")] #[cfg(feature = "async")]
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> { pub async fn engine(&self, engine: &str) -> Result<api::EngineInfo> {
self.get(&format!("engines/{}", engine)).await self.get(&format!("engines/{}", engine)).await
} }
#[cfg(feature = "sync")] #[cfg(feature = "sync")]
pub fn engine_sync(&self, engine: api::Engine) -> Result<api::EngineInfo> { pub fn engine_sync(&self, engine: &str) -> Result<api::EngineInfo> {
self.get_sync(&format!("engines/{}", engine)) self.get_sync(&format!("engines/{}", engine))
} }
@ -552,7 +506,7 @@ mod unit {
use mockito::Mock; use mockito::Mock;
use crate::{ use crate::{
api::{self, Completion, CompletionArgs, Engine, EngineInfo}, api::{self, Completion, CompletionArgs, EngineInfo},
Client, Error, Client, Error,
}; };
@ -578,7 +532,7 @@ mod unit {
assert_eq!( assert_eq!(
ei, ei,
api::EngineInfo { api::EngineInfo {
id: api::Engine::Ada, id: "ada".into(),
owner: "openai".into(), owner: "openai".into(),
ready: true, ready: true,
} }
@ -637,32 +591,32 @@ mod unit {
let expected = vec![ let expected = vec![
EngineInfo { EngineInfo {
id: Engine::Ada, id: "ada".into(),
owner: "openai".into(), owner: "openai".into(),
ready: true, ready: true,
}, },
EngineInfo { EngineInfo {
id: Engine::Babbage, id: "babbage".into(),
owner: "openai".into(), owner: "openai".into(),
ready: true, ready: true,
}, },
EngineInfo { EngineInfo {
id: Engine::Other, id: "experimental-engine-v7".into(),
owner: "openai".into(), owner: "openai".into(),
ready: false, ready: false,
}, },
EngineInfo { EngineInfo {
id: Engine::Curie, id: "curie".into(),
owner: "openai".into(), owner: "openai".into(),
ready: true, ready: true,
}, },
EngineInfo { EngineInfo {
id: Engine::Davinci, id: "davinci".into(),
owner: "openai".into(), owner: "openai".into(),
ready: true, ready: true,
}, },
EngineInfo { EngineInfo {
id: Engine::ContentFilter, id: "content-filter-alpha-c4".into(),
owner: "openai".into(), owner: "openai".into(),
ready: true, ready: true,
}, },
@ -705,7 +659,7 @@ mod unit {
async_test!(engine_error_response_async, { async_test!(engine_error_response_async, {
let (_m, expected) = mock_engine(); let (_m, expected) = mock_engine();
let response = mocked_client().engine(api::Engine::Davinci).await; let response = mocked_client().engine("davinci").await;
if let Result::Err(Error::APIError(msg)) = response { if let Result::Err(Error::APIError(msg)) = response {
assert_eq!(expected, msg); assert_eq!(expected, msg);
} }
@ -713,7 +667,7 @@ mod unit {
sync_test!(engine_error_response_sync, { sync_test!(engine_error_response_sync, {
let (_m, expected) = mock_engine(); let (_m, expected) = mock_engine();
let response = mocked_client().engine_sync(api::Engine::Davinci); let response = mocked_client().engine_sync("davinci");
if let Result::Err(Error::APIError(msg)) = response { if let Result::Err(Error::APIError(msg)) = response {
assert_eq!(expected, msg); assert_eq!(expected, msg);
} }
@ -741,7 +695,7 @@ mod unit {
.expect(1) .expect(1)
.create(); .create();
let args = api::CompletionArgs::builder() let args = api::CompletionArgs::builder()
.engine(api::Engine::Davinci) .engine("davinci")
.prompt("Once upon a time") .prompt("Once upon a time")
.max_tokens(5) .max_tokens(5)
.temperature(1.0) .temperature(1.0)
@ -757,7 +711,7 @@ mod unit {
text: " there was a girl who".into(), text: " there was a girl who".into(),
index: 0, index: 0,
logprobs: None, logprobs: None,
finish_reason: api::FinishReason::MaxTokensReached, finish_reason: "length".into(),
}], }],
}; };
Ok((mock, args, expected)) Ok((mock, args, expected))
@ -817,10 +771,10 @@ mod integration {
.into_iter() .into_iter()
.map(|ei| ei.id) .map(|ei| ei.id)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
assert!(engines.contains(&api::Engine::Ada)); assert!(engines.contains(&"ada".into()));
assert!(engines.contains(&api::Engine::Babbage)); assert!(engines.contains(&"babbage".into()));
assert!(engines.contains(&api::Engine::Curie)); assert!(engines.contains(&"curie".into()));
assert!(engines.contains(&api::Engine::Davinci)); assert!(engines.contains(&"davinci".into()));
}); });
fn assert_expected_engine_failure<T>(result: Result<T, Error>) fn assert_expected_engine_failure<T>(result: Result<T, Error>)
@ -842,12 +796,12 @@ mod integration {
} }
async_test!(can_get_engine_async, { async_test!(can_get_engine_async, {
let client = get_client(); let client = get_client();
assert_expected_engine_failure(client.engine(api::Engine::Ada).await); assert_expected_engine_failure(client.engine("ada").await);
}); });
sync_test!(can_get_engine_sync, { sync_test!(can_get_engine_sync, {
let client = get_client(); let client = get_client();
assert_expected_engine_failure(client.engine_sync(api::Engine::Ada)); assert_expected_engine_failure(client.engine_sync("ada"));
}); });
async_test!(complete_string_async, { async_test!(complete_string_async, {
@ -908,10 +862,7 @@ A:"#,
} }
fn assert_completion_finish_reason(completion: Completion) { fn assert_completion_finish_reason(completion: Completion) {
assert_eq!( assert_eq!(completion.choices[0].finish_reason, "stop",);
completion.choices[0].finish_reason,
api::FinishReason::StopSequenceReached
);
} }
async_test!(complete_stop_condition_async, { async_test!(complete_stop_condition_async, {