No more enums (#10)
The strings in the api change enough that enums are more of a hindrance than a help.
This commit is contained in:
parent
7d274dfcbd
commit
7b3a2ad5c1
@ -1,7 +1,4 @@
|
||||
use openai_api::{
|
||||
api::{CompletionArgs, Engine},
|
||||
Client,
|
||||
};
|
||||
use openai_api::{api::CompletionArgs, Client};
|
||||
|
||||
const START_PROMPT: &str = "
|
||||
The following is a conversation with an AI assistant.
|
||||
@ -14,7 +11,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client = Client::new(&api_token);
|
||||
let mut context = String::from(START_PROMPT);
|
||||
let mut args = CompletionArgs::builder();
|
||||
args.engine(Engine::Davinci)
|
||||
args.engine("davinci")
|
||||
.max_tokens(45)
|
||||
.stop(vec!["\n".into()])
|
||||
.top_p(0.5)
|
||||
@ -27,7 +24,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
eprintln!("Error: {}", e);
|
||||
break;
|
||||
}
|
||||
context.push_str("\nAI:");
|
||||
context.push_str("\nAI: ");
|
||||
match args.prompt(context.as_str()).complete_prompt(&client).await {
|
||||
Ok(completion) => {
|
||||
println!("\x1b[1;36m{}\x1b[1;0m", completion);
|
||||
|
99
src/lib.rs
99
src/lib.rs
@ -23,48 +23,19 @@ pub mod api {
|
||||
/// Engine description type
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct EngineInfo {
|
||||
pub id: Engine,
|
||||
pub id: String,
|
||||
pub owner: String,
|
||||
pub ready: bool,
|
||||
}
|
||||
|
||||
/// Engine types, known and unknown
|
||||
#[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
#[non_exhaustive] // prevent clients from matching on every option
|
||||
pub enum Engine {
|
||||
Ada,
|
||||
Babbage,
|
||||
Curie,
|
||||
Davinci,
|
||||
#[serde(rename = "content-filter-alpha-c4")]
|
||||
ContentFilter,
|
||||
#[serde(other)]
|
||||
Other,
|
||||
}
|
||||
|
||||
// Custom Display to lowercase things
|
||||
impl std::fmt::Display for Engine {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Engine::Ada => f.write_str("ada"),
|
||||
Engine::Babbage => f.write_str("babbage"),
|
||||
Engine::Curie => f.write_str("curie"),
|
||||
Engine::Davinci => f.write_str("davinci"),
|
||||
Engine::ContentFilter => f.write_str("content-filter-alpha-c4"),
|
||||
_ => panic!("Can't write out Other engine id"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Options that affect the result
|
||||
#[derive(Serialize, Debug, Builder, Clone)]
|
||||
pub struct CompletionArgs {
|
||||
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
|
||||
prompt: String,
|
||||
#[builder(default = "Engine::Davinci")]
|
||||
#[builder(setter(into), default = "\"davinci\".into()")]
|
||||
#[serde(skip_serializing)]
|
||||
pub(super) engine: Engine,
|
||||
pub(super) engine: String,
|
||||
#[builder(default = "16")]
|
||||
max_tokens: u64,
|
||||
#[builder(default = "1.0")]
|
||||
@ -87,11 +58,6 @@ pub mod api {
|
||||
logit_bias: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
/* {
|
||||
"stream": false, // SSE streams back results
|
||||
"best_of": Option<u64>, //cant be used with stream
|
||||
}
|
||||
*/
|
||||
// TODO: add validators for the different arguments
|
||||
|
||||
impl From<&str> for CompletionArgs {
|
||||
@ -173,7 +139,7 @@ pub mod api {
|
||||
/// If requested, the log probabilities of the completion tokens
|
||||
pub logprobs: Option<LogProbs>,
|
||||
/// Why the completion ended when it did
|
||||
pub finish_reason: FinishReason,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Choice {
|
||||
@ -191,18 +157,6 @@ pub mod api {
|
||||
pub text_offset: Vec<u64>,
|
||||
}
|
||||
|
||||
/// Reason a prompt completion finished.
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
|
||||
#[non_exhaustive]
|
||||
pub enum FinishReason {
|
||||
/// The maximum length was reached
|
||||
#[serde(rename = "length")]
|
||||
MaxTokensReached,
|
||||
/// The stop token was encountered
|
||||
#[serde(rename = "stop")]
|
||||
StopSequenceReached,
|
||||
}
|
||||
|
||||
/// Error response object from the server
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct ErrorMessage {
|
||||
@ -432,12 +386,12 @@ impl Client {
|
||||
/// # Errors
|
||||
/// - `Error::APIError` if the server returns an error
|
||||
#[cfg(feature = "async")]
|
||||
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
|
||||
pub async fn engine(&self, engine: &str) -> Result<api::EngineInfo> {
|
||||
self.get(&format!("engines/{}", engine)).await
|
||||
}
|
||||
|
||||
#[cfg(feature = "sync")]
|
||||
pub fn engine_sync(&self, engine: api::Engine) -> Result<api::EngineInfo> {
|
||||
pub fn engine_sync(&self, engine: &str) -> Result<api::EngineInfo> {
|
||||
self.get_sync(&format!("engines/{}", engine))
|
||||
}
|
||||
|
||||
@ -552,7 +506,7 @@ mod unit {
|
||||
use mockito::Mock;
|
||||
|
||||
use crate::{
|
||||
api::{self, Completion, CompletionArgs, Engine, EngineInfo},
|
||||
api::{self, Completion, CompletionArgs, EngineInfo},
|
||||
Client, Error,
|
||||
};
|
||||
|
||||
@ -578,7 +532,7 @@ mod unit {
|
||||
assert_eq!(
|
||||
ei,
|
||||
api::EngineInfo {
|
||||
id: api::Engine::Ada,
|
||||
id: "ada".into(),
|
||||
owner: "openai".into(),
|
||||
ready: true,
|
||||
}
|
||||
@ -637,32 +591,32 @@ mod unit {
|
||||
|
||||
let expected = vec![
|
||||
EngineInfo {
|
||||
id: Engine::Ada,
|
||||
id: "ada".into(),
|
||||
owner: "openai".into(),
|
||||
ready: true,
|
||||
},
|
||||
EngineInfo {
|
||||
id: Engine::Babbage,
|
||||
id: "babbage".into(),
|
||||
owner: "openai".into(),
|
||||
ready: true,
|
||||
},
|
||||
EngineInfo {
|
||||
id: Engine::Other,
|
||||
id: "experimental-engine-v7".into(),
|
||||
owner: "openai".into(),
|
||||
ready: false,
|
||||
},
|
||||
EngineInfo {
|
||||
id: Engine::Curie,
|
||||
id: "curie".into(),
|
||||
owner: "openai".into(),
|
||||
ready: true,
|
||||
},
|
||||
EngineInfo {
|
||||
id: Engine::Davinci,
|
||||
id: "davinci".into(),
|
||||
owner: "openai".into(),
|
||||
ready: true,
|
||||
},
|
||||
EngineInfo {
|
||||
id: Engine::ContentFilter,
|
||||
id: "content-filter-alpha-c4".into(),
|
||||
owner: "openai".into(),
|
||||
ready: true,
|
||||
},
|
||||
@ -705,7 +659,7 @@ mod unit {
|
||||
|
||||
async_test!(engine_error_response_async, {
|
||||
let (_m, expected) = mock_engine();
|
||||
let response = mocked_client().engine(api::Engine::Davinci).await;
|
||||
let response = mocked_client().engine("davinci").await;
|
||||
if let Result::Err(Error::APIError(msg)) = response {
|
||||
assert_eq!(expected, msg);
|
||||
}
|
||||
@ -713,7 +667,7 @@ mod unit {
|
||||
|
||||
sync_test!(engine_error_response_sync, {
|
||||
let (_m, expected) = mock_engine();
|
||||
let response = mocked_client().engine_sync(api::Engine::Davinci);
|
||||
let response = mocked_client().engine_sync("davinci");
|
||||
if let Result::Err(Error::APIError(msg)) = response {
|
||||
assert_eq!(expected, msg);
|
||||
}
|
||||
@ -741,7 +695,7 @@ mod unit {
|
||||
.expect(1)
|
||||
.create();
|
||||
let args = api::CompletionArgs::builder()
|
||||
.engine(api::Engine::Davinci)
|
||||
.engine("davinci")
|
||||
.prompt("Once upon a time")
|
||||
.max_tokens(5)
|
||||
.temperature(1.0)
|
||||
@ -757,7 +711,7 @@ mod unit {
|
||||
text: " there was a girl who".into(),
|
||||
index: 0,
|
||||
logprobs: None,
|
||||
finish_reason: api::FinishReason::MaxTokensReached,
|
||||
finish_reason: "length".into(),
|
||||
}],
|
||||
};
|
||||
Ok((mock, args, expected))
|
||||
@ -817,10 +771,10 @@ mod integration {
|
||||
.into_iter()
|
||||
.map(|ei| ei.id)
|
||||
.collect::<Vec<_>>();
|
||||
assert!(engines.contains(&api::Engine::Ada));
|
||||
assert!(engines.contains(&api::Engine::Babbage));
|
||||
assert!(engines.contains(&api::Engine::Curie));
|
||||
assert!(engines.contains(&api::Engine::Davinci));
|
||||
assert!(engines.contains(&"ada".into()));
|
||||
assert!(engines.contains(&"babbage".into()));
|
||||
assert!(engines.contains(&"curie".into()));
|
||||
assert!(engines.contains(&"davinci".into()));
|
||||
});
|
||||
|
||||
fn assert_expected_engine_failure<T>(result: Result<T, Error>)
|
||||
@ -842,12 +796,12 @@ mod integration {
|
||||
}
|
||||
async_test!(can_get_engine_async, {
|
||||
let client = get_client();
|
||||
assert_expected_engine_failure(client.engine(api::Engine::Ada).await);
|
||||
assert_expected_engine_failure(client.engine("ada").await);
|
||||
});
|
||||
|
||||
sync_test!(can_get_engine_sync, {
|
||||
let client = get_client();
|
||||
assert_expected_engine_failure(client.engine_sync(api::Engine::Ada));
|
||||
assert_expected_engine_failure(client.engine_sync("ada"));
|
||||
});
|
||||
|
||||
async_test!(complete_string_async, {
|
||||
@ -908,10 +862,7 @@ A:"#,
|
||||
}
|
||||
|
||||
fn assert_completion_finish_reason(completion: Completion) {
|
||||
assert_eq!(
|
||||
completion.choices[0].finish_reason,
|
||||
api::FinishReason::StopSequenceReached
|
||||
);
|
||||
assert_eq!(completion.choices[0].finish_reason, "stop",);
|
||||
}
|
||||
|
||||
async_test!(complete_stop_condition_async, {
|
||||
|
Loading…
x
Reference in New Issue
Block a user