No more enums (#10)
The strings in the api change enough that enums are more of a hindrance than a help.
This commit is contained in:
parent
7d274dfcbd
commit
7b3a2ad5c1
@ -1,7 +1,4 @@
|
|||||||
use openai_api::{
|
use openai_api::{api::CompletionArgs, Client};
|
||||||
api::{CompletionArgs, Engine},
|
|
||||||
Client,
|
|
||||||
};
|
|
||||||
|
|
||||||
const START_PROMPT: &str = "
|
const START_PROMPT: &str = "
|
||||||
The following is a conversation with an AI assistant.
|
The following is a conversation with an AI assistant.
|
||||||
@ -14,7 +11,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let client = Client::new(&api_token);
|
let client = Client::new(&api_token);
|
||||||
let mut context = String::from(START_PROMPT);
|
let mut context = String::from(START_PROMPT);
|
||||||
let mut args = CompletionArgs::builder();
|
let mut args = CompletionArgs::builder();
|
||||||
args.engine(Engine::Davinci)
|
args.engine("davinci")
|
||||||
.max_tokens(45)
|
.max_tokens(45)
|
||||||
.stop(vec!["\n".into()])
|
.stop(vec!["\n".into()])
|
||||||
.top_p(0.5)
|
.top_p(0.5)
|
||||||
@ -27,7 +24,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
eprintln!("Error: {}", e);
|
eprintln!("Error: {}", e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
context.push_str("\nAI:");
|
context.push_str("\nAI: ");
|
||||||
match args.prompt(context.as_str()).complete_prompt(&client).await {
|
match args.prompt(context.as_str()).complete_prompt(&client).await {
|
||||||
Ok(completion) => {
|
Ok(completion) => {
|
||||||
println!("\x1b[1;36m{}\x1b[1;0m", completion);
|
println!("\x1b[1;36m{}\x1b[1;0m", completion);
|
||||||
|
99
src/lib.rs
99
src/lib.rs
@ -23,48 +23,19 @@ pub mod api {
|
|||||||
/// Engine description type
|
/// Engine description type
|
||||||
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
|
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||||
pub struct EngineInfo {
|
pub struct EngineInfo {
|
||||||
pub id: Engine,
|
pub id: String,
|
||||||
pub owner: String,
|
pub owner: String,
|
||||||
pub ready: bool,
|
pub ready: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Engine types, known and unknown
|
|
||||||
#[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
|
|
||||||
#[serde(rename_all = "kebab-case")]
|
|
||||||
#[non_exhaustive] // prevent clients from matching on every option
|
|
||||||
pub enum Engine {
|
|
||||||
Ada,
|
|
||||||
Babbage,
|
|
||||||
Curie,
|
|
||||||
Davinci,
|
|
||||||
#[serde(rename = "content-filter-alpha-c4")]
|
|
||||||
ContentFilter,
|
|
||||||
#[serde(other)]
|
|
||||||
Other,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Custom Display to lowercase things
|
|
||||||
impl std::fmt::Display for Engine {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Engine::Ada => f.write_str("ada"),
|
|
||||||
Engine::Babbage => f.write_str("babbage"),
|
|
||||||
Engine::Curie => f.write_str("curie"),
|
|
||||||
Engine::Davinci => f.write_str("davinci"),
|
|
||||||
Engine::ContentFilter => f.write_str("content-filter-alpha-c4"),
|
|
||||||
_ => panic!("Can't write out Other engine id"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Options that affect the result
|
/// Options that affect the result
|
||||||
#[derive(Serialize, Debug, Builder, Clone)]
|
#[derive(Serialize, Debug, Builder, Clone)]
|
||||||
pub struct CompletionArgs {
|
pub struct CompletionArgs {
|
||||||
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
|
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
|
||||||
prompt: String,
|
prompt: String,
|
||||||
#[builder(default = "Engine::Davinci")]
|
#[builder(setter(into), default = "\"davinci\".into()")]
|
||||||
#[serde(skip_serializing)]
|
#[serde(skip_serializing)]
|
||||||
pub(super) engine: Engine,
|
pub(super) engine: String,
|
||||||
#[builder(default = "16")]
|
#[builder(default = "16")]
|
||||||
max_tokens: u64,
|
max_tokens: u64,
|
||||||
#[builder(default = "1.0")]
|
#[builder(default = "1.0")]
|
||||||
@ -87,11 +58,6 @@ pub mod api {
|
|||||||
logit_bias: HashMap<String, f64>,
|
logit_bias: HashMap<String, f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* {
|
|
||||||
"stream": false, // SSE streams back results
|
|
||||||
"best_of": Option<u64>, //cant be used with stream
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
// TODO: add validators for the different arguments
|
// TODO: add validators for the different arguments
|
||||||
|
|
||||||
impl From<&str> for CompletionArgs {
|
impl From<&str> for CompletionArgs {
|
||||||
@ -173,7 +139,7 @@ pub mod api {
|
|||||||
/// If requested, the log probabilities of the completion tokens
|
/// If requested, the log probabilities of the completion tokens
|
||||||
pub logprobs: Option<LogProbs>,
|
pub logprobs: Option<LogProbs>,
|
||||||
/// Why the completion ended when it did
|
/// Why the completion ended when it did
|
||||||
pub finish_reason: FinishReason,
|
pub finish_reason: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for Choice {
|
impl std::fmt::Display for Choice {
|
||||||
@ -191,18 +157,6 @@ pub mod api {
|
|||||||
pub text_offset: Vec<u64>,
|
pub text_offset: Vec<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Reason a prompt completion finished.
|
|
||||||
#[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
|
|
||||||
#[non_exhaustive]
|
|
||||||
pub enum FinishReason {
|
|
||||||
/// The maximum length was reached
|
|
||||||
#[serde(rename = "length")]
|
|
||||||
MaxTokensReached,
|
|
||||||
/// The stop token was encountered
|
|
||||||
#[serde(rename = "stop")]
|
|
||||||
StopSequenceReached,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Error response object from the server
|
/// Error response object from the server
|
||||||
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
|
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||||
pub struct ErrorMessage {
|
pub struct ErrorMessage {
|
||||||
@ -432,12 +386,12 @@ impl Client {
|
|||||||
/// # Errors
|
/// # Errors
|
||||||
/// - `Error::APIError` if the server returns an error
|
/// - `Error::APIError` if the server returns an error
|
||||||
#[cfg(feature = "async")]
|
#[cfg(feature = "async")]
|
||||||
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
|
pub async fn engine(&self, engine: &str) -> Result<api::EngineInfo> {
|
||||||
self.get(&format!("engines/{}", engine)).await
|
self.get(&format!("engines/{}", engine)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "sync")]
|
#[cfg(feature = "sync")]
|
||||||
pub fn engine_sync(&self, engine: api::Engine) -> Result<api::EngineInfo> {
|
pub fn engine_sync(&self, engine: &str) -> Result<api::EngineInfo> {
|
||||||
self.get_sync(&format!("engines/{}", engine))
|
self.get_sync(&format!("engines/{}", engine))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -552,7 +506,7 @@ mod unit {
|
|||||||
use mockito::Mock;
|
use mockito::Mock;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
api::{self, Completion, CompletionArgs, Engine, EngineInfo},
|
api::{self, Completion, CompletionArgs, EngineInfo},
|
||||||
Client, Error,
|
Client, Error,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -578,7 +532,7 @@ mod unit {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
ei,
|
ei,
|
||||||
api::EngineInfo {
|
api::EngineInfo {
|
||||||
id: api::Engine::Ada,
|
id: "ada".into(),
|
||||||
owner: "openai".into(),
|
owner: "openai".into(),
|
||||||
ready: true,
|
ready: true,
|
||||||
}
|
}
|
||||||
@ -637,32 +591,32 @@ mod unit {
|
|||||||
|
|
||||||
let expected = vec![
|
let expected = vec![
|
||||||
EngineInfo {
|
EngineInfo {
|
||||||
id: Engine::Ada,
|
id: "ada".into(),
|
||||||
owner: "openai".into(),
|
owner: "openai".into(),
|
||||||
ready: true,
|
ready: true,
|
||||||
},
|
},
|
||||||
EngineInfo {
|
EngineInfo {
|
||||||
id: Engine::Babbage,
|
id: "babbage".into(),
|
||||||
owner: "openai".into(),
|
owner: "openai".into(),
|
||||||
ready: true,
|
ready: true,
|
||||||
},
|
},
|
||||||
EngineInfo {
|
EngineInfo {
|
||||||
id: Engine::Other,
|
id: "experimental-engine-v7".into(),
|
||||||
owner: "openai".into(),
|
owner: "openai".into(),
|
||||||
ready: false,
|
ready: false,
|
||||||
},
|
},
|
||||||
EngineInfo {
|
EngineInfo {
|
||||||
id: Engine::Curie,
|
id: "curie".into(),
|
||||||
owner: "openai".into(),
|
owner: "openai".into(),
|
||||||
ready: true,
|
ready: true,
|
||||||
},
|
},
|
||||||
EngineInfo {
|
EngineInfo {
|
||||||
id: Engine::Davinci,
|
id: "davinci".into(),
|
||||||
owner: "openai".into(),
|
owner: "openai".into(),
|
||||||
ready: true,
|
ready: true,
|
||||||
},
|
},
|
||||||
EngineInfo {
|
EngineInfo {
|
||||||
id: Engine::ContentFilter,
|
id: "content-filter-alpha-c4".into(),
|
||||||
owner: "openai".into(),
|
owner: "openai".into(),
|
||||||
ready: true,
|
ready: true,
|
||||||
},
|
},
|
||||||
@ -705,7 +659,7 @@ mod unit {
|
|||||||
|
|
||||||
async_test!(engine_error_response_async, {
|
async_test!(engine_error_response_async, {
|
||||||
let (_m, expected) = mock_engine();
|
let (_m, expected) = mock_engine();
|
||||||
let response = mocked_client().engine(api::Engine::Davinci).await;
|
let response = mocked_client().engine("davinci").await;
|
||||||
if let Result::Err(Error::APIError(msg)) = response {
|
if let Result::Err(Error::APIError(msg)) = response {
|
||||||
assert_eq!(expected, msg);
|
assert_eq!(expected, msg);
|
||||||
}
|
}
|
||||||
@ -713,7 +667,7 @@ mod unit {
|
|||||||
|
|
||||||
sync_test!(engine_error_response_sync, {
|
sync_test!(engine_error_response_sync, {
|
||||||
let (_m, expected) = mock_engine();
|
let (_m, expected) = mock_engine();
|
||||||
let response = mocked_client().engine_sync(api::Engine::Davinci);
|
let response = mocked_client().engine_sync("davinci");
|
||||||
if let Result::Err(Error::APIError(msg)) = response {
|
if let Result::Err(Error::APIError(msg)) = response {
|
||||||
assert_eq!(expected, msg);
|
assert_eq!(expected, msg);
|
||||||
}
|
}
|
||||||
@ -741,7 +695,7 @@ mod unit {
|
|||||||
.expect(1)
|
.expect(1)
|
||||||
.create();
|
.create();
|
||||||
let args = api::CompletionArgs::builder()
|
let args = api::CompletionArgs::builder()
|
||||||
.engine(api::Engine::Davinci)
|
.engine("davinci")
|
||||||
.prompt("Once upon a time")
|
.prompt("Once upon a time")
|
||||||
.max_tokens(5)
|
.max_tokens(5)
|
||||||
.temperature(1.0)
|
.temperature(1.0)
|
||||||
@ -757,7 +711,7 @@ mod unit {
|
|||||||
text: " there was a girl who".into(),
|
text: " there was a girl who".into(),
|
||||||
index: 0,
|
index: 0,
|
||||||
logprobs: None,
|
logprobs: None,
|
||||||
finish_reason: api::FinishReason::MaxTokensReached,
|
finish_reason: "length".into(),
|
||||||
}],
|
}],
|
||||||
};
|
};
|
||||||
Ok((mock, args, expected))
|
Ok((mock, args, expected))
|
||||||
@ -817,10 +771,10 @@ mod integration {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|ei| ei.id)
|
.map(|ei| ei.id)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
assert!(engines.contains(&api::Engine::Ada));
|
assert!(engines.contains(&"ada".into()));
|
||||||
assert!(engines.contains(&api::Engine::Babbage));
|
assert!(engines.contains(&"babbage".into()));
|
||||||
assert!(engines.contains(&api::Engine::Curie));
|
assert!(engines.contains(&"curie".into()));
|
||||||
assert!(engines.contains(&api::Engine::Davinci));
|
assert!(engines.contains(&"davinci".into()));
|
||||||
});
|
});
|
||||||
|
|
||||||
fn assert_expected_engine_failure<T>(result: Result<T, Error>)
|
fn assert_expected_engine_failure<T>(result: Result<T, Error>)
|
||||||
@ -842,12 +796,12 @@ mod integration {
|
|||||||
}
|
}
|
||||||
async_test!(can_get_engine_async, {
|
async_test!(can_get_engine_async, {
|
||||||
let client = get_client();
|
let client = get_client();
|
||||||
assert_expected_engine_failure(client.engine(api::Engine::Ada).await);
|
assert_expected_engine_failure(client.engine("ada").await);
|
||||||
});
|
});
|
||||||
|
|
||||||
sync_test!(can_get_engine_sync, {
|
sync_test!(can_get_engine_sync, {
|
||||||
let client = get_client();
|
let client = get_client();
|
||||||
assert_expected_engine_failure(client.engine_sync(api::Engine::Ada));
|
assert_expected_engine_failure(client.engine_sync("ada"));
|
||||||
});
|
});
|
||||||
|
|
||||||
async_test!(complete_string_async, {
|
async_test!(complete_string_async, {
|
||||||
@ -908,10 +862,7 @@ A:"#,
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn assert_completion_finish_reason(completion: Completion) {
|
fn assert_completion_finish_reason(completion: Completion) {
|
||||||
assert_eq!(
|
assert_eq!(completion.choices[0].finish_reason, "stop",);
|
||||||
completion.choices[0].finish_reason,
|
|
||||||
api::FinishReason::StopSequenceReached
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async_test!(complete_stop_condition_async, {
|
async_test!(complete_stop_condition_async, {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user