Fix finish reason breaking
This commit is contained in:
parent
1b6d7f69be
commit
d416fb4d20
98
src/lib.rs
98
src/lib.rs
@ -166,28 +166,29 @@ pub mod api {
|
||||
|
||||
impl std::fmt::Display for Choice {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.text)
|
||||
self.text.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a logprobs subdocument
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct LogProbs {
|
||||
tokens: Vec<String>,
|
||||
token_logprobs: Vec<Option<f64>>,
|
||||
top_logprobs: Vec<Option<HashMap<String, f64>>>,
|
||||
text_offset: Vec<u64>,
|
||||
pub tokens: Vec<String>,
|
||||
pub token_logprobs: Vec<Option<f64>>,
|
||||
pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
|
||||
pub text_offset: Vec<u64>,
|
||||
}
|
||||
|
||||
/// Reason a prompt completion finished.
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq)]
|
||||
#[non_exhaustive]
|
||||
pub enum FinishReason {
|
||||
/// The maximum length was reached
|
||||
Length,
|
||||
#[serde(rename = "max_tokens")]
|
||||
MaxTokensReached,
|
||||
/// The stop token was encountered
|
||||
Stop,
|
||||
#[serde(rename = "stop")]
|
||||
StopSequenceReached,
|
||||
}
|
||||
|
||||
/// Error response object from the server
|
||||
@ -414,7 +415,7 @@ mod unit {
|
||||
use api::{Engine, EngineInfo};
|
||||
let _m = mockito::mock("GET", "/engines")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "text/json")
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(
|
||||
r#"{
|
||||
"object": "list",
|
||||
@ -500,7 +501,7 @@ mod unit {
|
||||
async fn engine_error_response() -> crate::Result<()> {
|
||||
let _m = mockito::mock("GET", "/engines/davinci")
|
||||
.with_status(404)
|
||||
.with_header("content-type", "text/json")
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(
|
||||
r#"{
|
||||
"error": {
|
||||
@ -521,6 +522,61 @@ mod unit {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn completion_args() -> crate::Result<()> {
|
||||
let _m = mockito::mock("POST", "/engines/davinci/completions")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(
|
||||
r#"{
|
||||
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
|
||||
"object": "text_completion",
|
||||
"created": 1589478378,
|
||||
"model": "davinci:2020-05-03",
|
||||
"choices": [
|
||||
{
|
||||
"text": " there was a girl who",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"finish_reason": "max_tokens"
|
||||
}
|
||||
]
|
||||
}"#,
|
||||
)
|
||||
.create();
|
||||
let expected = api::Completion {
|
||||
id: "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7".into(),
|
||||
created: 1589478378,
|
||||
model: "davinci:2020-05-03".into(),
|
||||
choices: vec![api::Choice {
|
||||
text: " there was a girl who".into(),
|
||||
index: 0,
|
||||
logprobs: None,
|
||||
finish_reason: api::FinishReason::MaxTokensReached,
|
||||
}],
|
||||
};
|
||||
let args = api::CompletionArgs::builder()
|
||||
.engine(api::Engine::Davinci)
|
||||
.prompt("Once upon a time")
|
||||
.max_tokens(5)
|
||||
.temperature(1.0)
|
||||
.top_p(1.0)
|
||||
.n(1)
|
||||
.stop(vec!["\n".into()])
|
||||
.build()?;
|
||||
let response = mocked_client().complete(args).await?;
|
||||
assert_eq!(response.model, expected.model);
|
||||
assert_eq!(response.id, expected.id);
|
||||
assert_eq!(response.created, expected.created);
|
||||
let (resp_choice, expected_choice) = (&response.choices[0], &expected.choices[0]);
|
||||
assert_eq!(resp_choice.text, expected_choice.text);
|
||||
assert_eq!(resp_choice.index, expected_choice.index);
|
||||
assert!(resp_choice.logprobs.is_none());
|
||||
assert_eq!(resp_choice.finish_reason, expected_choice.finish_reason);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod integration {
|
||||
@ -590,4 +646,24 @@ mod integration {
|
||||
client.complete(args).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn complete_stop_condition() -> crate::Result<()> {
|
||||
let client = get_client();
|
||||
let mut args = api::CompletionArgs::builder();
|
||||
let completion = args
|
||||
.prompt(
|
||||
r#"
|
||||
Q: Please type `#` now
|
||||
A:"#,
|
||||
)
|
||||
// turn temp & top_p way down to prevent test flakiness
|
||||
.temperature(0.0)
|
||||
.top_p(0.0)
|
||||
.max_tokens(100)
|
||||
.stop(vec!["#".into(), "\n".into()])
|
||||
.complete(&client).await?;
|
||||
assert_eq!(completion.choices[0].finish_reason, api::FinishReason::StopSequenceReached);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user