Update docs (#12)

parent 68a7b500f5
commit d552f4047b
Cargo.toml

@@ -10,6 +10,9 @@ repository = "https://github.com/deontologician/openai-api-rust/"
 keywords = ["openai", "gpt3"]
 categories = ["api-bindings", "asynchronous"]
 
+[build]
+rustdocflags = ["--all-features"]
+
 [features]
 default = ["sync", "async"]
 sync = ["ureq"]

src/lib.rs

@@ -1,4 +1,4 @@
-///! `OpenAI` API client library
+/// `OpenAI` API client library
 #[macro_use]
 extern crate derive_builder;
 
@@ -9,35 +9,83 @@ type Result<T> = std::result::Result<T, Error>;
 #[allow(clippy::default_trait_access)]
 pub mod api {
     //! Data types corresponding to requests and responses from the API
-    use std::{collections::HashMap, fmt::Display};
+    use std::{collections::HashMap, convert::TryFrom, fmt::Display};
 
     use serde::{Deserialize, Serialize};
 
     /// Container type. Used in the api, but not useful for clients of this library
     #[derive(Deserialize, Debug)]
-    pub(super) struct Container<T> {
+    pub(crate) struct Container<T> {
+        /// Items in the page's results
         pub data: Vec<T>,
     }
 
-    /// Engine description type
+    /// Detailed information on a particular engine.
     #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
     pub struct EngineInfo {
+        /// The name of the engine, e.g. `"davinci"` or `"ada"`
         pub id: String,
+        /// The owner of the model. Usually (always?) `"openai"`
         pub owner: String,
+        /// Whether the model is ready for use. Usually (always?) `true`
         pub ready: bool,
     }
 
-    /// Options that affect the result
+    /// Options for the query completion
     #[derive(Serialize, Debug, Builder, Clone)]
     #[builder(pattern = "immutable")]
     pub struct CompletionArgs {
-        #[builder(setter(into), default = "\"<|endoftext|>\".into()")]
-        prompt: String,
+        /// The id of the engine to use for this request
+        ///
+        /// # Example
+        /// ```
+        /// # use openai_api::api::CompletionArgs;
+        /// CompletionArgs::builder().engine("davinci");
+        /// ```
         #[builder(setter(into), default = "\"davinci\".into()")]
         #[serde(skip_serializing)]
         pub(super) engine: String,
+        /// The prompt to complete from.
+        ///
+        /// Defaults to `"<|endoftext|>"` which is a special token seen during training.
+        ///
+        /// # Example
+        /// ```
+        /// # use openai_api::api::CompletionArgs;
+        /// CompletionArgs::builder().prompt("Once upon a time...");
+        /// ```
+        #[builder(setter(into), default = "\"<|endoftext|>\".into()")]
+        prompt: String,
+        /// Maximum number of tokens to complete.
+        ///
+        /// Defaults to 16
+        /// # Example
+        /// ```
+        /// # use openai_api::api::CompletionArgs;
+        /// CompletionArgs::builder().max_tokens(64);
+        /// ```
         #[builder(default = "16")]
         max_tokens: u64,
+        /// What sampling temperature to use.
+        ///
+        /// Default is `1.0`
+        ///
+        /// Higher values means the model will take more risks.
+        /// Try 0.9 for more creative applications, and 0 (argmax sampling)
+        /// for ones with a well-defined answer.
+        ///
+        /// OpenAI recommends altering this or top_p but not both.
+        ///
+        /// # Example
+        /// ```
+        /// # use openai_api::api::{CompletionArgs, CompletionArgsBuilder};
+        /// # use std::convert::{TryInto, TryFrom};
+        /// # fn main() -> Result<(), String> {
+        /// let builder = CompletionArgs::builder().temperature(0.7);
+        /// let args: CompletionArgs = builder.try_into()?;
+        /// # Ok::<(), String>(())
+        /// # }
+        /// ```
         #[builder(default = "1.0")]
         temperature: f64,
         #[builder(default = "1.0")]
@@ -79,6 +127,14 @@ pub mod api {
         }
     }
 
+    impl TryFrom<CompletionArgsBuilder> for CompletionArgs {
+        type Error = String;
+
+        fn try_from(builder: CompletionArgsBuilder) -> Result<Self, Self::Error> {
+            builder.build()
+        }
+    }
+
     /// Represents a non-streamed completion response
     #[derive(Deserialize, Debug, Clone)]
     pub struct Completion {
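
Usage note (not part of the diff): the new `TryFrom<CompletionArgsBuilder>` impl lets a caller finish the immutable builder with `try_into()`, as the doc examples added above suggest. A minimal sketch, assuming the crate is imported as `openai_api`; the function name `example_args` is purely illustrative:

```rust
use std::convert::TryInto;

use openai_api::api::CompletionArgs;

fn example_args() -> Result<CompletionArgs, String> {
    // Each setter on the immutable builder returns a new builder;
    // the final conversion goes through TryFrom<CompletionArgsBuilder>,
    // whose error type is String.
    let args: CompletionArgs = CompletionArgs::builder()
        .engine("davinci")
        .prompt("Once upon a time...")
        .max_tokens(64)
        .temperature(0.7)
        .try_into()?;
    Ok(args)
}
```
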
@@ -88,7 +144,7 @@ pub mod api {
         pub created: u64,
         /// Exact model type and version used for the completion
         pub model: String,
-        /// Timestamp
+        /// List of completions generated by the model
         pub choices: Vec<Choice>,
     }
 
@@ -152,7 +208,7 @@ pub mod api {
 pub enum Error {
     /// An error returned by the API itself
     #[error("API returned an Error: {}", .0.message)]
-    API(api::ErrorMessage),
+    Api(api::ErrorMessage),
     /// An error the client discovers before talking to the API
     #[error("Bad arguments: {0}")]
     BadArguments(String),
@@ -167,7 +223,7 @@ pub enum Error {
 
 impl From<api::ErrorMessage> for Error {
     fn from(e: api::ErrorMessage) -> Self {
-        Error::API(e)
+        Error::Api(e)
     }
 }
 
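
Side note (not part of the diff): downstream code that matched the old `Error::API` variant only needs the rename. A hedged sketch of such a caller, again assuming the crate is used as `openai_api`; `log_error` is a hypothetical helper:

```rust
use openai_api::Error;

fn log_error(err: &Error) {
    match err {
        // Renamed in this commit from `Error::API(..)`.
        Error::Api(_) => eprintln!("the API returned an error: {}", err),
        Error::BadArguments(msg) => eprintln!("bad arguments: {}", msg),
        // Any other client-side failure (e.g. transport errors).
        _ => eprintln!("request failed: {}", err),
    }
}
```
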
@@ -300,7 +356,7 @@ impl Client {
             Ok(response.body_json::<T>().await?)
         } else {
             let err = response.body_json::<api::ErrorWrapper>().await?.error;
-            Err(Error::API(err))
+            Err(Error::Api(err))
         }
     }
 
@@ -322,7 +378,7 @@ impl Client {
                 .into_json_deserialize::<api::ErrorWrapper>()
                 .expect("Bug: client couldn't deserialize api error response")
                 .error;
-            Err(Error::API(err))
+            Err(Error::Api(err))
         }
     }
 
@@ -379,7 +435,7 @@ impl Client {
             .await?;
         match response.status() {
             surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
-            _ => Err(Error::API(
+            _ => Err(Error::Api(
                 response
                     .body_json::<api::ErrorWrapper>()
                     .await
@@ -405,7 +461,7 @@ impl Client {
             200 => Ok(response
                 .into_json_deserialize()
                 .expect("Bug: client couldn't deserialize api response")),
-            _ => Err(Error::API(
+            _ => Err(Error::Api(
                 response
                     .into_json_deserialize::<api::ErrorWrapper>()
                     .expect("Bug: client couldn't deserialize api error response")
@@ -629,7 +685,7 @@ mod unit {
     async_test!(engine_error_response_async, {
         let (_m, expected) = mock_engine();
         let response = mocked_client().engine("davinci").await;
-        if let Result::Err(Error::API(msg)) = response {
+        if let Result::Err(Error::Api(msg)) = response {
             assert_eq!(expected, msg);
         }
     });
@@ -637,7 +693,7 @@ mod unit {
     sync_test!(engine_error_response_sync, {
         let (_m, expected) = mock_engine();
         let response = mocked_client().engine_sync("davinci");
-        if let Result::Err(Error::API(msg)) = response {
+        if let Result::Err(Error::Api(msg)) = response {
             assert_eq!(expected, msg);
         }
     });
@@ -713,11 +769,12 @@ mod unit {
         m.assert();
     });
 }
+
 #[cfg(test)]
 mod integration {
     use crate::{
         api::{self, Completion},
-        Client, Error,
+        Client,
     };
     /// Used by tests to get a client to the actual api
     fn get_client() -> Client {
@@ -746,31 +803,19 @@ mod integration {
         assert!(engines.contains(&"davinci".into()));
     });
 
-    fn assert_expected_engine_failure<T>(result: Result<T, Error>)
-    where
-        T: std::fmt::Debug,
-    {
-        match result {
-            Err(Error::API(api::ErrorMessage {
-                message,
-                error_type,
-            })) => {
-                assert_eq!(message, "No engine with that ID: ada");
-                assert_eq!(error_type, "invalid_request_error");
-            }
-            _ => {
-                panic!("Expected an error message, got {:?}", result)
-            }
-        }
+    fn assert_engine_correct(engine_id: &str, info: api::EngineInfo) {
+        assert_eq!(info.id, engine_id);
+        assert!(info.ready);
+        assert_eq!(info.owner, "openai");
     }
     async_test!(can_get_engine_async, {
         let client = get_client();
-        assert_expected_engine_failure(client.engine("ada").await);
+        assert_engine_correct("ada", client.engine("ada").await?);
    });
 
     sync_test!(can_get_engine_sync, {
         let client = get_client();
-        assert_expected_engine_failure(client.engine_sync("ada"));
+        assert_engine_correct("ada", client.engine_sync("ada")?);
     });
 
     async_test!(complete_string_async, {
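
For context, a small sketch (not part of the commit) mirroring the new `assert_engine_correct` test helper for use outside the test module; `engine_is_usable` is a hypothetical name and relies only on the public `EngineInfo` fields documented above:

```rust
use openai_api::api::EngineInfo;

// Same check as the test helper, but returning a bool instead of panicking.
// `id`, `owner`, and `ready` are the public fields added to the docs above.
fn engine_is_usable(info: &EngineInfo, expected_id: &str) -> bool {
    info.id == expected_id && info.owner == "openai" && info.ready
}
```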