Fix finish reason breaking
This commit is contained in:
parent
1b6d7f69be
commit
d416fb4d20
98
src/lib.rs
98
src/lib.rs
@ -166,28 +166,29 @@ pub mod api {
|
|||||||
|
|
||||||
impl std::fmt::Display for Choice {
|
impl std::fmt::Display for Choice {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "{}", self.text)
|
self.text.fmt(f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Represents a logprobs subdocument
|
/// Represents a logprobs subdocument
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug)]
|
||||||
pub struct LogProbs {
|
pub struct LogProbs {
|
||||||
tokens: Vec<String>,
|
pub tokens: Vec<String>,
|
||||||
token_logprobs: Vec<Option<f64>>,
|
pub token_logprobs: Vec<Option<f64>>,
|
||||||
top_logprobs: Vec<Option<HashMap<String, f64>>>,
|
pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
|
||||||
text_offset: Vec<u64>,
|
pub text_offset: Vec<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Reason a prompt completion finished.
|
/// Reason a prompt completion finished.
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug, Eq, PartialEq)]
|
||||||
#[serde(rename_all = "kebab-case")]
|
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
pub enum FinishReason {
|
pub enum FinishReason {
|
||||||
/// The maximum length was reached
|
/// The maximum length was reached
|
||||||
Length,
|
#[serde(rename = "max_tokens")]
|
||||||
|
MaxTokensReached,
|
||||||
/// The stop token was encountered
|
/// The stop token was encountered
|
||||||
Stop,
|
#[serde(rename = "stop")]
|
||||||
|
StopSequenceReached,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Error response object from the server
|
/// Error response object from the server
|
||||||
@ -414,7 +415,7 @@ mod unit {
|
|||||||
use api::{Engine, EngineInfo};
|
use api::{Engine, EngineInfo};
|
||||||
let _m = mockito::mock("GET", "/engines")
|
let _m = mockito::mock("GET", "/engines")
|
||||||
.with_status(200)
|
.with_status(200)
|
||||||
.with_header("content-type", "text/json")
|
.with_header("content-type", "application/json")
|
||||||
.with_body(
|
.with_body(
|
||||||
r#"{
|
r#"{
|
||||||
"object": "list",
|
"object": "list",
|
||||||
@ -500,7 +501,7 @@ mod unit {
|
|||||||
async fn engine_error_response() -> crate::Result<()> {
|
async fn engine_error_response() -> crate::Result<()> {
|
||||||
let _m = mockito::mock("GET", "/engines/davinci")
|
let _m = mockito::mock("GET", "/engines/davinci")
|
||||||
.with_status(404)
|
.with_status(404)
|
||||||
.with_header("content-type", "text/json")
|
.with_header("content-type", "application/json")
|
||||||
.with_body(
|
.with_body(
|
||||||
r#"{
|
r#"{
|
||||||
"error": {
|
"error": {
|
||||||
@ -521,6 +522,61 @@ mod unit {
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn completion_args() -> crate::Result<()> {
|
||||||
|
let _m = mockito::mock("POST", "/engines/davinci/completions")
|
||||||
|
.with_status(200)
|
||||||
|
.with_header("content-type", "application/json")
|
||||||
|
.with_body(
|
||||||
|
r#"{
|
||||||
|
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1589478378,
|
||||||
|
"model": "davinci:2020-05-03",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": " there was a girl who",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"finish_reason": "max_tokens"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.create();
|
||||||
|
let expected = api::Completion {
|
||||||
|
id: "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7".into(),
|
||||||
|
created: 1589478378,
|
||||||
|
model: "davinci:2020-05-03".into(),
|
||||||
|
choices: vec![api::Choice {
|
||||||
|
text: " there was a girl who".into(),
|
||||||
|
index: 0,
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: api::FinishReason::MaxTokensReached,
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
let args = api::CompletionArgs::builder()
|
||||||
|
.engine(api::Engine::Davinci)
|
||||||
|
.prompt("Once upon a time")
|
||||||
|
.max_tokens(5)
|
||||||
|
.temperature(1.0)
|
||||||
|
.top_p(1.0)
|
||||||
|
.n(1)
|
||||||
|
.stop(vec!["\n".into()])
|
||||||
|
.build()?;
|
||||||
|
let response = mocked_client().complete(args).await?;
|
||||||
|
assert_eq!(response.model, expected.model);
|
||||||
|
assert_eq!(response.id, expected.id);
|
||||||
|
assert_eq!(response.created, expected.created);
|
||||||
|
let (resp_choice, expected_choice) = (&response.choices[0], &expected.choices[0]);
|
||||||
|
assert_eq!(resp_choice.text, expected_choice.text);
|
||||||
|
assert_eq!(resp_choice.index, expected_choice.index);
|
||||||
|
assert!(resp_choice.logprobs.is_none());
|
||||||
|
assert_eq!(resp_choice.finish_reason, expected_choice.finish_reason);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod integration {
|
mod integration {
|
||||||
@ -590,4 +646,24 @@ mod integration {
|
|||||||
client.complete(args).await?;
|
client.complete(args).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn complete_stop_condition() -> crate::Result<()> {
|
||||||
|
let client = get_client();
|
||||||
|
let mut args = api::CompletionArgs::builder();
|
||||||
|
let completion = args
|
||||||
|
.prompt(
|
||||||
|
r#"
|
||||||
|
Q: Please type `#` now
|
||||||
|
A:"#,
|
||||||
|
)
|
||||||
|
// turn temp & top_p way down to prevent test flakiness
|
||||||
|
.temperature(0.0)
|
||||||
|
.top_p(0.0)
|
||||||
|
.max_tokens(100)
|
||||||
|
.stop(vec!["#".into(), "\n".into()])
|
||||||
|
.complete(&client).await?;
|
||||||
|
assert_eq!(completion.choices[0].finish_reason, api::FinishReason::StopSequenceReached);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user