Update docs(#12)

This commit is contained in:
Josh Kuhn 2021-04-28 09:59:06 -07:00 committed by GitHub
parent 68a7b500f5
commit d552f4047b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 35 deletions

View File

@ -10,6 +10,9 @@ repository = "https://github.com/deontologician/openai-api-rust/"
keywords = ["openai", "gpt3"]
categories = ["api-bindings", "asynchronous"]
[build]
rustdocflags = ["--all-features"]
[features]
default = ["sync", "async"]
sync = ["ureq"]

View File

@ -1,4 +1,4 @@
///! `OpenAI` API client library
/// `OpenAI` API client library
#[macro_use]
extern crate derive_builder;
@ -9,35 +9,83 @@ type Result<T> = std::result::Result<T, Error>;
#[allow(clippy::default_trait_access)]
pub mod api {
//! Data types corresponding to requests and responses from the API
use std::{collections::HashMap, fmt::Display};
use std::{collections::HashMap, convert::TryFrom, fmt::Display};
use serde::{Deserialize, Serialize};
/// Container type. Used in the api, but not useful for clients of this library
#[derive(Deserialize, Debug)]
pub(super) struct Container<T> {
pub(crate) struct Container<T> {
/// Items in the page's results
pub data: Vec<T>,
}
/// Engine description type
/// Detailed information on a particular engine.
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct EngineInfo {
/// The name of the engine, e.g. `"davinci"` or `"ada"`
pub id: String,
/// The owner of the model. Usually (always?) `"openai"`
pub owner: String,
/// Whether the model is ready for use. Usually (always?) `true`
pub ready: bool,
}
/// Options that affect the result
/// Options for the query completion
#[derive(Serialize, Debug, Builder, Clone)]
#[builder(pattern = "immutable")]
pub struct CompletionArgs {
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
prompt: String,
/// The id of the engine to use for this request
///
/// # Example
/// ```
/// # use openai_api::api::CompletionArgs;
/// CompletionArgs::builder().engine("davinci");
/// ```
#[builder(setter(into), default = "\"davinci\".into()")]
#[serde(skip_serializing)]
pub(super) engine: String,
/// The prompt to complete from.
///
/// Defaults to `"<|endoftext|>"` which is a special token seen during training.
///
/// # Example
/// ```
/// # use openai_api::api::CompletionArgs;
/// CompletionArgs::builder().prompt("Once upon a time...");
/// ```
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
prompt: String,
/// Maximum number of tokens to complete.
///
/// Defaults to 16
/// # Example
/// ```
/// # use openai_api::api::CompletionArgs;
/// CompletionArgs::builder().max_tokens(64);
/// ```
#[builder(default = "16")]
max_tokens: u64,
/// What sampling temperature to use.
///
/// Default is `1.0`
///
/// Higher values means the model will take more risks.
/// Try 0.9 for more creative applications, and 0 (argmax sampling)
/// for ones with a well-defined answer.
///
/// OpenAI recommends altering this or top_p but not both.
///
/// # Example
/// ```
/// # use openai_api::api::{CompletionArgs, CompletionArgsBuilder};
/// # use std::convert::{TryInto, TryFrom};
/// # fn main() -> Result<(), String> {
/// let builder = CompletionArgs::builder().temperature(0.7);
/// let args: CompletionArgs = builder.try_into()?;
/// # Ok::<(), String>(())
/// # }
/// ```
#[builder(default = "1.0")]
temperature: f64,
#[builder(default = "1.0")]
@ -79,6 +127,14 @@ pub mod api {
}
}
impl TryFrom<CompletionArgsBuilder> for CompletionArgs {
type Error = String;
fn try_from(builder: CompletionArgsBuilder) -> Result<Self, Self::Error> {
builder.build()
}
}
/// Represents a non-streamed completion response
#[derive(Deserialize, Debug, Clone)]
pub struct Completion {
@ -88,7 +144,7 @@ pub mod api {
pub created: u64,
/// Exact model type and version used for the completion
pub model: String,
/// Timestamp
/// List of completions generated by the model
pub choices: Vec<Choice>,
}
@ -152,7 +208,7 @@ pub mod api {
pub enum Error {
/// An error returned by the API itself
#[error("API returned an Error: {}", .0.message)]
API(api::ErrorMessage),
Api(api::ErrorMessage),
/// An error the client discovers before talking to the API
#[error("Bad arguments: {0}")]
BadArguments(String),
@ -167,7 +223,7 @@ pub enum Error {
impl From<api::ErrorMessage> for Error {
fn from(e: api::ErrorMessage) -> Self {
Error::API(e)
Error::Api(e)
}
}
@ -300,7 +356,7 @@ impl Client {
Ok(response.body_json::<T>().await?)
} else {
let err = response.body_json::<api::ErrorWrapper>().await?.error;
Err(Error::API(err))
Err(Error::Api(err))
}
}
@ -322,7 +378,7 @@ impl Client {
.into_json_deserialize::<api::ErrorWrapper>()
.expect("Bug: client couldn't deserialize api error response")
.error;
Err(Error::API(err))
Err(Error::Api(err))
}
}
@ -379,7 +435,7 @@ impl Client {
.await?;
match response.status() {
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
_ => Err(Error::API(
_ => Err(Error::Api(
response
.body_json::<api::ErrorWrapper>()
.await
@ -405,7 +461,7 @@ impl Client {
200 => Ok(response
.into_json_deserialize()
.expect("Bug: client couldn't deserialize api response")),
_ => Err(Error::API(
_ => Err(Error::Api(
response
.into_json_deserialize::<api::ErrorWrapper>()
.expect("Bug: client couldn't deserialize api error response")
@ -629,7 +685,7 @@ mod unit {
async_test!(engine_error_response_async, {
let (_m, expected) = mock_engine();
let response = mocked_client().engine("davinci").await;
if let Result::Err(Error::API(msg)) = response {
if let Result::Err(Error::Api(msg)) = response {
assert_eq!(expected, msg);
}
});
@ -637,7 +693,7 @@ mod unit {
sync_test!(engine_error_response_sync, {
let (_m, expected) = mock_engine();
let response = mocked_client().engine_sync("davinci");
if let Result::Err(Error::API(msg)) = response {
if let Result::Err(Error::Api(msg)) = response {
assert_eq!(expected, msg);
}
});
@ -713,11 +769,12 @@ mod unit {
m.assert();
});
}
#[cfg(test)]
mod integration {
use crate::{
api::{self, Completion},
Client, Error,
Client,
};
/// Used by tests to get a client to the actual api
fn get_client() -> Client {
@ -746,31 +803,19 @@ mod integration {
assert!(engines.contains(&"davinci".into()));
});
fn assert_expected_engine_failure<T>(result: Result<T, Error>)
where
T: std::fmt::Debug,
{
match result {
Err(Error::API(api::ErrorMessage {
message,
error_type,
})) => {
assert_eq!(message, "No engine with that ID: ada");
assert_eq!(error_type, "invalid_request_error");
}
_ => {
panic!("Expected an error message, got {:?}", result)
}
}
fn assert_engine_correct(engine_id: &str, info: api::EngineInfo) {
assert_eq!(info.id, engine_id);
assert!(info.ready);
assert_eq!(info.owner, "openai");
}
async_test!(can_get_engine_async, {
let client = get_client();
assert_expected_engine_failure(client.engine("ada").await);
assert_engine_correct("ada", client.engine("ada").await?);
});
sync_test!(can_get_engine_sync, {
let client = get_client();
assert_expected_engine_failure(client.engine_sync("ada"));
assert_engine_correct("ada", client.engine_sync("ada")?);
});
async_test!(complete_string_async, {