diff --git a/src/lib.rs b/src/lib.rs index 690c48f..2a04392 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -166,28 +166,29 @@ pub mod api { impl std::fmt::Display for Choice { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.text) + self.text.fmt(f) } } /// Represents a logprobs subdocument #[derive(Deserialize, Debug)] pub struct LogProbs { - tokens: Vec, - token_logprobs: Vec>, - top_logprobs: Vec>>, - text_offset: Vec, + pub tokens: Vec, + pub token_logprobs: Vec>, + pub top_logprobs: Vec>>, + pub text_offset: Vec, } /// Reason a prompt completion finished. - #[derive(Deserialize, Debug)] - #[serde(rename_all = "kebab-case")] + #[derive(Deserialize, Debug, Eq, PartialEq)] #[non_exhaustive] pub enum FinishReason { /// The maximum length was reached - Length, + #[serde(rename = "max_tokens")] + MaxTokensReached, /// The stop token was encountered - Stop, + #[serde(rename = "stop")] + StopSequenceReached, } /// Error response object from the server @@ -414,7 +415,7 @@ mod unit { use api::{Engine, EngineInfo}; let _m = mockito::mock("GET", "/engines") .with_status(200) - .with_header("content-type", "text/json") + .with_header("content-type", "application/json") .with_body( r#"{ "object": "list", @@ -500,7 +501,7 @@ mod unit { async fn engine_error_response() -> crate::Result<()> { let _m = mockito::mock("GET", "/engines/davinci") .with_status(404) - .with_header("content-type", "text/json") + .with_header("content-type", "application/json") .with_body( r#"{ "error": { @@ -521,6 +522,61 @@ mod unit { } Ok(()) } + + #[tokio::test] + async fn completion_args() -> crate::Result<()> { + let _m = mockito::mock("POST", "/engines/davinci/completions") + .with_status(200) + .with_header("content-type", "application/json") + .with_body( + r#"{ + "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", + "created": 1589478378, + "model": "davinci:2020-05-03", + "choices": [ + { + "text": " there was a girl who", + "index": 0, + "logprobs": null, + "finish_reason": "max_tokens" + } + ] + }"#, + ) + .create(); + let expected = api::Completion { + id: "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7".into(), + created: 1589478378, + model: "davinci:2020-05-03".into(), + choices: vec![api::Choice { + text: " there was a girl who".into(), + index: 0, + logprobs: None, + finish_reason: api::FinishReason::MaxTokensReached, + }], + }; + let args = api::CompletionArgs::builder() + .engine(api::Engine::Davinci) + .prompt("Once upon a time") + .max_tokens(5) + .temperature(1.0) + .top_p(1.0) + .n(1) + .stop(vec!["\n".into()]) + .build()?; + let response = mocked_client().complete(args).await?; + assert_eq!(response.model, expected.model); + assert_eq!(response.id, expected.id); + assert_eq!(response.created, expected.created); + let (resp_choice, expected_choice) = (&response.choices[0], &expected.choices[0]); + assert_eq!(resp_choice.text, expected_choice.text); + assert_eq!(resp_choice.index, expected_choice.index); + assert!(resp_choice.logprobs.is_none()); + assert_eq!(resp_choice.finish_reason, expected_choice.finish_reason); + + Ok(()) + } } #[cfg(test)] mod integration { @@ -590,4 +646,24 @@ mod integration { client.complete(args).await?; Ok(()) } + + #[tokio::test] + async fn complete_stop_condition() -> crate::Result<()> { + let client = get_client(); + let mut args = api::CompletionArgs::builder(); + let completion = args + .prompt( + r#" +Q: Please type `#` now +A:"#, + ) + // turn temp & top_p way down to prevent test flakiness + .temperature(0.0) + .top_p(0.0) + .max_tokens(100) + .stop(vec!["#".into(), "\n".into()]) + .complete(&client).await?; + assert_eq!(completion.choices[0].finish_reason, api::FinishReason::StopSequenceReached); + Ok(()) + } }