From 7b3a2ad5c189d6e2394e14d3bbdafb1e4f49f217 Mon Sep 17 00:00:00 2001 From: Josh Kuhn Date: Mon, 4 Jan 2021 21:58:36 -0800 Subject: [PATCH] No more enums (#10) The strings in the api change enough that enums are more of a hindrance than a help. --- examples/chatloop.rs | 9 ++-- src/lib.rs | 99 +++++++++++--------------------------------- 2 files changed, 28 insertions(+), 80 deletions(-) diff --git a/examples/chatloop.rs b/examples/chatloop.rs index 26cec0a..124afde 100644 --- a/examples/chatloop.rs +++ b/examples/chatloop.rs @@ -1,7 +1,4 @@ -use openai_api::{ - api::{CompletionArgs, Engine}, - Client, -}; +use openai_api::{api::CompletionArgs, Client}; const START_PROMPT: &str = " The following is a conversation with an AI assistant. @@ -14,7 +11,7 @@ async fn main() -> Result<(), Box> { let client = Client::new(&api_token); let mut context = String::from(START_PROMPT); let mut args = CompletionArgs::builder(); - args.engine(Engine::Davinci) + args.engine("davinci") .max_tokens(45) .stop(vec!["\n".into()]) .top_p(0.5) @@ -27,7 +24,7 @@ async fn main() -> Result<(), Box> { eprintln!("Error: {}", e); break; } - context.push_str("\nAI:"); + context.push_str("\nAI: "); match args.prompt(context.as_str()).complete_prompt(&client).await { Ok(completion) => { println!("\x1b[1;36m{}\x1b[1;0m", completion); diff --git a/src/lib.rs b/src/lib.rs index 4840b64..5e4f63a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,48 +23,19 @@ pub mod api { /// Engine description type #[derive(Deserialize, Debug, Eq, PartialEq, Clone)] pub struct EngineInfo { - pub id: Engine, + pub id: String, pub owner: String, pub ready: bool, } - /// Engine types, known and unknown - #[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)] - #[serde(rename_all = "kebab-case")] - #[non_exhaustive] // prevent clients from matching on every option - pub enum Engine { - Ada, - Babbage, - Curie, - Davinci, - #[serde(rename = "content-filter-alpha-c4")] - ContentFilter, - #[serde(other)] - Other, - } - - // Custom Display to lowercase things - impl std::fmt::Display for Engine { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Engine::Ada => f.write_str("ada"), - Engine::Babbage => f.write_str("babbage"), - Engine::Curie => f.write_str("curie"), - Engine::Davinci => f.write_str("davinci"), - Engine::ContentFilter => f.write_str("content-filter-alpha-c4"), - _ => panic!("Can't write out Other engine id"), - } - } - } - /// Options that affect the result #[derive(Serialize, Debug, Builder, Clone)] pub struct CompletionArgs { #[builder(setter(into), default = "\"<|endoftext|>\".into()")] prompt: String, - #[builder(default = "Engine::Davinci")] + #[builder(setter(into), default = "\"davinci\".into()")] #[serde(skip_serializing)] - pub(super) engine: Engine, + pub(super) engine: String, #[builder(default = "16")] max_tokens: u64, #[builder(default = "1.0")] @@ -87,11 +58,6 @@ pub mod api { logit_bias: HashMap, } - /* { - "stream": false, // SSE streams back results - "best_of": Option, //cant be used with stream - } - */ // TODO: add validators for the different arguments impl From<&str> for CompletionArgs { @@ -173,7 +139,7 @@ pub mod api { /// If requested, the log probabilities of the completion tokens pub logprobs: Option, /// Why the completion ended when it did - pub finish_reason: FinishReason, + pub finish_reason: String, } impl std::fmt::Display for Choice { @@ -191,18 +157,6 @@ pub mod api { pub text_offset: Vec, } - /// Reason a prompt completion finished. - #[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)] - #[non_exhaustive] - pub enum FinishReason { - /// The maximum length was reached - #[serde(rename = "length")] - MaxTokensReached, - /// The stop token was encountered - #[serde(rename = "stop")] - StopSequenceReached, - } - /// Error response object from the server #[derive(Deserialize, Debug, Eq, PartialEq, Clone)] pub struct ErrorMessage { @@ -432,12 +386,12 @@ impl Client { /// # Errors /// - `Error::APIError` if the server returns an error #[cfg(feature = "async")] - pub async fn engine(&self, engine: api::Engine) -> Result { + pub async fn engine(&self, engine: &str) -> Result { self.get(&format!("engines/{}", engine)).await } #[cfg(feature = "sync")] - pub fn engine_sync(&self, engine: api::Engine) -> Result { + pub fn engine_sync(&self, engine: &str) -> Result { self.get_sync(&format!("engines/{}", engine)) } @@ -552,7 +506,7 @@ mod unit { use mockito::Mock; use crate::{ - api::{self, Completion, CompletionArgs, Engine, EngineInfo}, + api::{self, Completion, CompletionArgs, EngineInfo}, Client, Error, }; @@ -578,7 +532,7 @@ mod unit { assert_eq!( ei, api::EngineInfo { - id: api::Engine::Ada, + id: "ada".into(), owner: "openai".into(), ready: true, } @@ -637,32 +591,32 @@ mod unit { let expected = vec![ EngineInfo { - id: Engine::Ada, + id: "ada".into(), owner: "openai".into(), ready: true, }, EngineInfo { - id: Engine::Babbage, + id: "babbage".into(), owner: "openai".into(), ready: true, }, EngineInfo { - id: Engine::Other, + id: "experimental-engine-v7".into(), owner: "openai".into(), ready: false, }, EngineInfo { - id: Engine::Curie, + id: "curie".into(), owner: "openai".into(), ready: true, }, EngineInfo { - id: Engine::Davinci, + id: "davinci".into(), owner: "openai".into(), ready: true, }, EngineInfo { - id: Engine::ContentFilter, + id: "content-filter-alpha-c4".into(), owner: "openai".into(), ready: true, }, @@ -705,7 +659,7 @@ mod unit { async_test!(engine_error_response_async, { let (_m, expected) = mock_engine(); - let response = mocked_client().engine(api::Engine::Davinci).await; + let response = mocked_client().engine("davinci").await; if let Result::Err(Error::APIError(msg)) = response { assert_eq!(expected, msg); } @@ -713,7 +667,7 @@ mod unit { sync_test!(engine_error_response_sync, { let (_m, expected) = mock_engine(); - let response = mocked_client().engine_sync(api::Engine::Davinci); + let response = mocked_client().engine_sync("davinci"); if let Result::Err(Error::APIError(msg)) = response { assert_eq!(expected, msg); } @@ -741,7 +695,7 @@ mod unit { .expect(1) .create(); let args = api::CompletionArgs::builder() - .engine(api::Engine::Davinci) + .engine("davinci") .prompt("Once upon a time") .max_tokens(5) .temperature(1.0) @@ -757,7 +711,7 @@ mod unit { text: " there was a girl who".into(), index: 0, logprobs: None, - finish_reason: api::FinishReason::MaxTokensReached, + finish_reason: "length".into(), }], }; Ok((mock, args, expected)) @@ -817,10 +771,10 @@ mod integration { .into_iter() .map(|ei| ei.id) .collect::>(); - assert!(engines.contains(&api::Engine::Ada)); - assert!(engines.contains(&api::Engine::Babbage)); - assert!(engines.contains(&api::Engine::Curie)); - assert!(engines.contains(&api::Engine::Davinci)); + assert!(engines.contains(&"ada".into())); + assert!(engines.contains(&"babbage".into())); + assert!(engines.contains(&"curie".into())); + assert!(engines.contains(&"davinci".into())); }); fn assert_expected_engine_failure(result: Result) @@ -842,12 +796,12 @@ mod integration { } async_test!(can_get_engine_async, { let client = get_client(); - assert_expected_engine_failure(client.engine(api::Engine::Ada).await); + assert_expected_engine_failure(client.engine("ada").await); }); sync_test!(can_get_engine_sync, { let client = get_client(); - assert_expected_engine_failure(client.engine_sync(api::Engine::Ada)); + assert_expected_engine_failure(client.engine_sync("ada")); }); async_test!(complete_string_async, { @@ -908,10 +862,7 @@ A:"#, } fn assert_completion_finish_reason(completion: Completion) { - assert_eq!( - completion.choices[0].finish_reason, - api::FinishReason::StopSequenceReached - ); + assert_eq!(completion.choices[0].finish_reason, "stop",); } async_test!(complete_stop_condition_async, {