From d552f4047b233016ae47d1c72f616f5eed719958 Mon Sep 17 00:00:00 2001 From: Josh Kuhn Date: Wed, 28 Apr 2021 09:59:06 -0700 Subject: [PATCH] Update docs(#12) --- Cargo.toml | 3 ++ src/lib.rs | 115 +++++++++++++++++++++++++++++++++++++---------------- 2 files changed, 83 insertions(+), 35 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index aac46f3..fdd4325 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,9 @@ repository = "https://github.com/deontologician/openai-api-rust/" keywords = ["openai", "gpt3"] categories = ["api-bindings", "asynchronous"] +[build] +rustdocflags = ["--all-features"] + [features] default = ["sync", "async"] sync = ["ureq"] diff --git a/src/lib.rs b/src/lib.rs index 6c6fbef..9d36603 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -///! `OpenAI` API client library +/// `OpenAI` API client library #[macro_use] extern crate derive_builder; @@ -9,35 +9,83 @@ type Result = std::result::Result; #[allow(clippy::default_trait_access)] pub mod api { //! Data types corresponding to requests and responses from the API - use std::{collections::HashMap, fmt::Display}; + use std::{collections::HashMap, convert::TryFrom, fmt::Display}; use serde::{Deserialize, Serialize}; /// Container type. Used in the api, but not useful for clients of this library #[derive(Deserialize, Debug)] - pub(super) struct Container { + pub(crate) struct Container { + /// Items in the page's results pub data: Vec, } - /// Engine description type + /// Detailed information on a particular engine. #[derive(Deserialize, Debug, Eq, PartialEq, Clone)] pub struct EngineInfo { + /// The name of the engine, e.g. `"davinci"` or `"ada"` pub id: String, + /// The owner of the model. Usually (always?) `"openai"` pub owner: String, + /// Whether the model is ready for use. Usually (always?) `true` pub ready: bool, } - /// Options that affect the result + /// Options for the query completion #[derive(Serialize, Debug, Builder, Clone)] #[builder(pattern = "immutable")] pub struct CompletionArgs { - #[builder(setter(into), default = "\"<|endoftext|>\".into()")] - prompt: String, + /// The id of the engine to use for this request + /// + /// # Example + /// ``` + /// # use openai_api::api::CompletionArgs; + /// CompletionArgs::builder().engine("davinci"); + /// ``` #[builder(setter(into), default = "\"davinci\".into()")] #[serde(skip_serializing)] pub(super) engine: String, + /// The prompt to complete from. + /// + /// Defaults to `"<|endoftext|>"` which is a special token seen during training. + /// + /// # Example + /// ``` + /// # use openai_api::api::CompletionArgs; + /// CompletionArgs::builder().prompt("Once upon a time..."); + /// ``` + #[builder(setter(into), default = "\"<|endoftext|>\".into()")] + prompt: String, + /// Maximum number of tokens to complete. + /// + /// Defaults to 16 + /// # Example + /// ``` + /// # use openai_api::api::CompletionArgs; + /// CompletionArgs::builder().max_tokens(64); + /// ``` #[builder(default = "16")] max_tokens: u64, + /// What sampling temperature to use. + /// + /// Default is `1.0` + /// + /// Higher values means the model will take more risks. + /// Try 0.9 for more creative applications, and 0 (argmax sampling) + /// for ones with a well-defined answer. + /// + /// OpenAI recommends altering this or top_p but not both. + /// + /// # Example + /// ``` + /// # use openai_api::api::{CompletionArgs, CompletionArgsBuilder}; + /// # use std::convert::{TryInto, TryFrom}; + /// # fn main() -> Result<(), String> { + /// let builder = CompletionArgs::builder().temperature(0.7); + /// let args: CompletionArgs = builder.try_into()?; + /// # Ok::<(), String>(()) + /// # } + /// ``` #[builder(default = "1.0")] temperature: f64, #[builder(default = "1.0")] @@ -79,6 +127,14 @@ pub mod api { } } + impl TryFrom for CompletionArgs { + type Error = String; + + fn try_from(builder: CompletionArgsBuilder) -> Result { + builder.build() + } + } + /// Represents a non-streamed completion response #[derive(Deserialize, Debug, Clone)] pub struct Completion { @@ -88,7 +144,7 @@ pub mod api { pub created: u64, /// Exact model type and version used for the completion pub model: String, - /// Timestamp + /// List of completions generated by the model pub choices: Vec, } @@ -152,7 +208,7 @@ pub mod api { pub enum Error { /// An error returned by the API itself #[error("API returned an Error: {}", .0.message)] - API(api::ErrorMessage), + Api(api::ErrorMessage), /// An error the client discovers before talking to the API #[error("Bad arguments: {0}")] BadArguments(String), @@ -167,7 +223,7 @@ pub enum Error { impl From for Error { fn from(e: api::ErrorMessage) -> Self { - Error::API(e) + Error::Api(e) } } @@ -300,7 +356,7 @@ impl Client { Ok(response.body_json::().await?) } else { let err = response.body_json::().await?.error; - Err(Error::API(err)) + Err(Error::Api(err)) } } @@ -322,7 +378,7 @@ impl Client { .into_json_deserialize::() .expect("Bug: client couldn't deserialize api error response") .error; - Err(Error::API(err)) + Err(Error::Api(err)) } } @@ -379,7 +435,7 @@ impl Client { .await?; match response.status() { surf::StatusCode::Ok => Ok(response.body_json::().await?), - _ => Err(Error::API( + _ => Err(Error::Api( response .body_json::() .await @@ -405,7 +461,7 @@ impl Client { 200 => Ok(response .into_json_deserialize() .expect("Bug: client couldn't deserialize api response")), - _ => Err(Error::API( + _ => Err(Error::Api( response .into_json_deserialize::() .expect("Bug: client couldn't deserialize api error response") @@ -629,7 +685,7 @@ mod unit { async_test!(engine_error_response_async, { let (_m, expected) = mock_engine(); let response = mocked_client().engine("davinci").await; - if let Result::Err(Error::API(msg)) = response { + if let Result::Err(Error::Api(msg)) = response { assert_eq!(expected, msg); } }); @@ -637,7 +693,7 @@ mod unit { sync_test!(engine_error_response_sync, { let (_m, expected) = mock_engine(); let response = mocked_client().engine_sync("davinci"); - if let Result::Err(Error::API(msg)) = response { + if let Result::Err(Error::Api(msg)) = response { assert_eq!(expected, msg); } }); @@ -713,11 +769,12 @@ mod unit { m.assert(); }); } + #[cfg(test)] mod integration { use crate::{ api::{self, Completion}, - Client, Error, + Client, }; /// Used by tests to get a client to the actual api fn get_client() -> Client { @@ -746,31 +803,19 @@ mod integration { assert!(engines.contains(&"davinci".into())); }); - fn assert_expected_engine_failure(result: Result) - where - T: std::fmt::Debug, - { - match result { - Err(Error::API(api::ErrorMessage { - message, - error_type, - })) => { - assert_eq!(message, "No engine with that ID: ada"); - assert_eq!(error_type, "invalid_request_error"); - } - _ => { - panic!("Expected an error message, got {:?}", result) - } - } + fn assert_engine_correct(engine_id: &str, info: api::EngineInfo) { + assert_eq!(info.id, engine_id); + assert!(info.ready); + assert_eq!(info.owner, "openai"); } async_test!(can_get_engine_async, { let client = get_client(); - assert_expected_engine_failure(client.engine("ada").await); + assert_engine_correct("ada", client.engine("ada").await?); }); sync_test!(can_get_engine_sync, { let client = get_client(); - assert_expected_engine_failure(client.engine_sync("ada")); + assert_engine_correct("ada", client.engine_sync("ada")?); }); async_test!(complete_string_async, {