No more enums (#10)

The strings in the api change enough that enums are more of a hindrance than a help.
This commit is contained in:
Josh Kuhn 2021-01-04 21:58:36 -08:00 committed by GitHub
parent 7d274dfcbd
commit 7b3a2ad5c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 80 deletions

View File

@ -1,7 +1,4 @@
use openai_api::{
api::{CompletionArgs, Engine},
Client,
};
use openai_api::{api::CompletionArgs, Client};
const START_PROMPT: &str = "
The following is a conversation with an AI assistant.
@ -14,7 +11,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new(&api_token);
let mut context = String::from(START_PROMPT);
let mut args = CompletionArgs::builder();
args.engine(Engine::Davinci)
args.engine("davinci")
.max_tokens(45)
.stop(vec!["\n".into()])
.top_p(0.5)
@ -27,7 +24,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
eprintln!("Error: {}", e);
break;
}
context.push_str("\nAI:");
context.push_str("\nAI: ");
match args.prompt(context.as_str()).complete_prompt(&client).await {
Ok(completion) => {
println!("\x1b[1;36m{}\x1b[1;0m", completion);

View File

@ -23,48 +23,19 @@ pub mod api {
/// Engine description type
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct EngineInfo {
pub id: Engine,
pub id: String,
pub owner: String,
pub ready: bool,
}
/// Engine types, known and unknown
#[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive] // prevent clients from matching on every option
pub enum Engine {
Ada,
Babbage,
Curie,
Davinci,
#[serde(rename = "content-filter-alpha-c4")]
ContentFilter,
#[serde(other)]
Other,
}
// Custom Display to lowercase things
impl std::fmt::Display for Engine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Engine::Ada => f.write_str("ada"),
Engine::Babbage => f.write_str("babbage"),
Engine::Curie => f.write_str("curie"),
Engine::Davinci => f.write_str("davinci"),
Engine::ContentFilter => f.write_str("content-filter-alpha-c4"),
_ => panic!("Can't write out Other engine id"),
}
}
}
/// Options that affect the result
#[derive(Serialize, Debug, Builder, Clone)]
pub struct CompletionArgs {
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
prompt: String,
#[builder(default = "Engine::Davinci")]
#[builder(setter(into), default = "\"davinci\".into()")]
#[serde(skip_serializing)]
pub(super) engine: Engine,
pub(super) engine: String,
#[builder(default = "16")]
max_tokens: u64,
#[builder(default = "1.0")]
@ -87,11 +58,6 @@ pub mod api {
logit_bias: HashMap<String, f64>,
}
/* {
"stream": false, // SSE streams back results
"best_of": Option<u64>, //cant be used with stream
}
*/
// TODO: add validators for the different arguments
impl From<&str> for CompletionArgs {
@ -173,7 +139,7 @@ pub mod api {
/// If requested, the log probabilities of the completion tokens
pub logprobs: Option<LogProbs>,
/// Why the completion ended when it did
pub finish_reason: FinishReason,
pub finish_reason: String,
}
impl std::fmt::Display for Choice {
@ -191,18 +157,6 @@ pub mod api {
pub text_offset: Vec<u64>,
}
/// Reason a prompt completion finished.
#[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
#[non_exhaustive]
pub enum FinishReason {
/// The maximum length was reached
#[serde(rename = "length")]
MaxTokensReached,
/// The stop token was encountered
#[serde(rename = "stop")]
StopSequenceReached,
}
/// Error response object from the server
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ErrorMessage {
@ -432,12 +386,12 @@ impl Client {
/// # Errors
/// - `Error::APIError` if the server returns an error
#[cfg(feature = "async")]
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
pub async fn engine(&self, engine: &str) -> Result<api::EngineInfo> {
self.get(&format!("engines/{}", engine)).await
}
#[cfg(feature = "sync")]
pub fn engine_sync(&self, engine: api::Engine) -> Result<api::EngineInfo> {
pub fn engine_sync(&self, engine: &str) -> Result<api::EngineInfo> {
self.get_sync(&format!("engines/{}", engine))
}
@ -552,7 +506,7 @@ mod unit {
use mockito::Mock;
use crate::{
api::{self, Completion, CompletionArgs, Engine, EngineInfo},
api::{self, Completion, CompletionArgs, EngineInfo},
Client, Error,
};
@ -578,7 +532,7 @@ mod unit {
assert_eq!(
ei,
api::EngineInfo {
id: api::Engine::Ada,
id: "ada".into(),
owner: "openai".into(),
ready: true,
}
@ -637,32 +591,32 @@ mod unit {
let expected = vec![
EngineInfo {
id: Engine::Ada,
id: "ada".into(),
owner: "openai".into(),
ready: true,
},
EngineInfo {
id: Engine::Babbage,
id: "babbage".into(),
owner: "openai".into(),
ready: true,
},
EngineInfo {
id: Engine::Other,
id: "experimental-engine-v7".into(),
owner: "openai".into(),
ready: false,
},
EngineInfo {
id: Engine::Curie,
id: "curie".into(),
owner: "openai".into(),
ready: true,
},
EngineInfo {
id: Engine::Davinci,
id: "davinci".into(),
owner: "openai".into(),
ready: true,
},
EngineInfo {
id: Engine::ContentFilter,
id: "content-filter-alpha-c4".into(),
owner: "openai".into(),
ready: true,
},
@ -705,7 +659,7 @@ mod unit {
async_test!(engine_error_response_async, {
let (_m, expected) = mock_engine();
let response = mocked_client().engine(api::Engine::Davinci).await;
let response = mocked_client().engine("davinci").await;
if let Result::Err(Error::APIError(msg)) = response {
assert_eq!(expected, msg);
}
@ -713,7 +667,7 @@ mod unit {
sync_test!(engine_error_response_sync, {
let (_m, expected) = mock_engine();
let response = mocked_client().engine_sync(api::Engine::Davinci);
let response = mocked_client().engine_sync("davinci");
if let Result::Err(Error::APIError(msg)) = response {
assert_eq!(expected, msg);
}
@ -741,7 +695,7 @@ mod unit {
.expect(1)
.create();
let args = api::CompletionArgs::builder()
.engine(api::Engine::Davinci)
.engine("davinci")
.prompt("Once upon a time")
.max_tokens(5)
.temperature(1.0)
@ -757,7 +711,7 @@ mod unit {
text: " there was a girl who".into(),
index: 0,
logprobs: None,
finish_reason: api::FinishReason::MaxTokensReached,
finish_reason: "length".into(),
}],
};
Ok((mock, args, expected))
@ -817,10 +771,10 @@ mod integration {
.into_iter()
.map(|ei| ei.id)
.collect::<Vec<_>>();
assert!(engines.contains(&api::Engine::Ada));
assert!(engines.contains(&api::Engine::Babbage));
assert!(engines.contains(&api::Engine::Curie));
assert!(engines.contains(&api::Engine::Davinci));
assert!(engines.contains(&"ada".into()));
assert!(engines.contains(&"babbage".into()));
assert!(engines.contains(&"curie".into()));
assert!(engines.contains(&"davinci".into()));
});
fn assert_expected_engine_failure<T>(result: Result<T, Error>)
@ -842,12 +796,12 @@ mod integration {
}
async_test!(can_get_engine_async, {
let client = get_client();
assert_expected_engine_failure(client.engine(api::Engine::Ada).await);
assert_expected_engine_failure(client.engine("ada").await);
});
sync_test!(can_get_engine_sync, {
let client = get_client();
assert_expected_engine_failure(client.engine_sync(api::Engine::Ada));
assert_expected_engine_failure(client.engine_sync("ada"));
});
async_test!(complete_string_async, {
@ -908,10 +862,7 @@ A:"#,
}
fn assert_completion_finish_reason(completion: Completion) {
assert_eq!(
completion.choices[0].finish_reason,
api::FinishReason::StopSequenceReached
);
assert_eq!(completion.choices[0].finish_reason, "stop",);
}
async_test!(complete_stop_condition_async, {