Fix finish reason breaking

This commit is contained in:
Josh Kuhn 2020-12-12 22:43:26 -08:00
parent 1b6d7f69be
commit d416fb4d20

View File

@ -166,28 +166,29 @@ pub mod api {
impl std::fmt::Display for Choice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.text)
self.text.fmt(f)
}
}
/// Represents a logprobs subdocument
#[derive(Deserialize, Debug)]
pub struct LogProbs {
tokens: Vec<String>,
token_logprobs: Vec<Option<f64>>,
top_logprobs: Vec<Option<HashMap<String, f64>>>,
text_offset: Vec<u64>,
pub tokens: Vec<String>,
pub token_logprobs: Vec<Option<f64>>,
pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
pub text_offset: Vec<u64>,
}
/// Reason a prompt completion finished.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "kebab-case")]
#[derive(Deserialize, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum FinishReason {
/// The maximum length was reached
Length,
#[serde(rename = "max_tokens")]
MaxTokensReached,
/// The stop token was encountered
Stop,
#[serde(rename = "stop")]
StopSequenceReached,
}
/// Error response object from the server
@ -414,7 +415,7 @@ mod unit {
use api::{Engine, EngineInfo};
let _m = mockito::mock("GET", "/engines")
.with_status(200)
.with_header("content-type", "text/json")
.with_header("content-type", "application/json")
.with_body(
r#"{
"object": "list",
@ -500,7 +501,7 @@ mod unit {
async fn engine_error_response() -> crate::Result<()> {
let _m = mockito::mock("GET", "/engines/davinci")
.with_status(404)
.with_header("content-type", "text/json")
.with_header("content-type", "application/json")
.with_body(
r#"{
"error": {
@ -521,6 +522,61 @@ mod unit {
}
Ok(())
}
#[tokio::test]
async fn completion_args() -> crate::Result<()> {
let _m = mockito::mock("POST", "/engines/davinci/completions")
.with_status(200)
.with_header("content-type", "application/json")
.with_body(
r#"{
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "davinci:2020-05-03",
"choices": [
{
"text": " there was a girl who",
"index": 0,
"logprobs": null,
"finish_reason": "max_tokens"
}
]
}"#,
)
.create();
let expected = api::Completion {
id: "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7".into(),
created: 1589478378,
model: "davinci:2020-05-03".into(),
choices: vec![api::Choice {
text: " there was a girl who".into(),
index: 0,
logprobs: None,
finish_reason: api::FinishReason::MaxTokensReached,
}],
};
let args = api::CompletionArgs::builder()
.engine(api::Engine::Davinci)
.prompt("Once upon a time")
.max_tokens(5)
.temperature(1.0)
.top_p(1.0)
.n(1)
.stop(vec!["\n".into()])
.build()?;
let response = mocked_client().complete(args).await?;
assert_eq!(response.model, expected.model);
assert_eq!(response.id, expected.id);
assert_eq!(response.created, expected.created);
let (resp_choice, expected_choice) = (&response.choices[0], &expected.choices[0]);
assert_eq!(resp_choice.text, expected_choice.text);
assert_eq!(resp_choice.index, expected_choice.index);
assert!(resp_choice.logprobs.is_none());
assert_eq!(resp_choice.finish_reason, expected_choice.finish_reason);
Ok(())
}
}
#[cfg(test)]
mod integration {
@ -590,4 +646,24 @@ mod integration {
client.complete(args).await?;
Ok(())
}
#[tokio::test]
async fn complete_stop_condition() -> crate::Result<()> {
let client = get_client();
let mut args = api::CompletionArgs::builder();
let completion = args
.prompt(
r#"
Q: Please type `#` now
A:"#,
)
// turn temp & top_p way down to prevent test flakiness
.temperature(0.0)
.top_p(0.0)
.max_tokens(100)
.stop(vec!["#".into(), "\n".into()])
.complete(&client).await?;
assert_eq!(completion.choices[0].finish_reason, api::FinishReason::StopSequenceReached);
Ok(())
}
}