Update docs(#12)
This commit is contained in:
parent
68a7b500f5
commit
d552f4047b
@ -10,6 +10,9 @@ repository = "https://github.com/deontologician/openai-api-rust/"
|
||||
keywords = ["openai", "gpt3"]
|
||||
categories = ["api-bindings", "asynchronous"]
|
||||
|
||||
[build]
|
||||
rustdocflags = ["--all-features"]
|
||||
|
||||
[features]
|
||||
default = ["sync", "async"]
|
||||
sync = ["ureq"]
|
||||
|
115
src/lib.rs
115
src/lib.rs
@ -1,4 +1,4 @@
|
||||
///! `OpenAI` API client library
|
||||
/// `OpenAI` API client library
|
||||
#[macro_use]
|
||||
extern crate derive_builder;
|
||||
|
||||
@ -9,35 +9,83 @@ type Result<T> = std::result::Result<T, Error>;
|
||||
#[allow(clippy::default_trait_access)]
|
||||
pub mod api {
|
||||
//! Data types corresponding to requests and responses from the API
|
||||
use std::{collections::HashMap, fmt::Display};
|
||||
use std::{collections::HashMap, convert::TryFrom, fmt::Display};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Container type. Used in the api, but not useful for clients of this library
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub(super) struct Container<T> {
|
||||
pub(crate) struct Container<T> {
|
||||
/// Items in the page's results
|
||||
pub data: Vec<T>,
|
||||
}
|
||||
|
||||
/// Engine description type
|
||||
/// Detailed information on a particular engine.
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct EngineInfo {
|
||||
/// The name of the engine, e.g. `"davinci"` or `"ada"`
|
||||
pub id: String,
|
||||
/// The owner of the model. Usually (always?) `"openai"`
|
||||
pub owner: String,
|
||||
/// Whether the model is ready for use. Usually (always?) `true`
|
||||
pub ready: bool,
|
||||
}
|
||||
|
||||
/// Options that affect the result
|
||||
/// Options for the query completion
|
||||
#[derive(Serialize, Debug, Builder, Clone)]
|
||||
#[builder(pattern = "immutable")]
|
||||
pub struct CompletionArgs {
|
||||
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
|
||||
prompt: String,
|
||||
/// The id of the engine to use for this request
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use openai_api::api::CompletionArgs;
|
||||
/// CompletionArgs::builder().engine("davinci");
|
||||
/// ```
|
||||
#[builder(setter(into), default = "\"davinci\".into()")]
|
||||
#[serde(skip_serializing)]
|
||||
pub(super) engine: String,
|
||||
/// The prompt to complete from.
|
||||
///
|
||||
/// Defaults to `"<|endoftext|>"` which is a special token seen during training.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use openai_api::api::CompletionArgs;
|
||||
/// CompletionArgs::builder().prompt("Once upon a time...");
|
||||
/// ```
|
||||
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
|
||||
prompt: String,
|
||||
/// Maximum number of tokens to complete.
|
||||
///
|
||||
/// Defaults to 16
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use openai_api::api::CompletionArgs;
|
||||
/// CompletionArgs::builder().max_tokens(64);
|
||||
/// ```
|
||||
#[builder(default = "16")]
|
||||
max_tokens: u64,
|
||||
/// What sampling temperature to use.
|
||||
///
|
||||
/// Default is `1.0`
|
||||
///
|
||||
/// Higher values means the model will take more risks.
|
||||
/// Try 0.9 for more creative applications, and 0 (argmax sampling)
|
||||
/// for ones with a well-defined answer.
|
||||
///
|
||||
/// OpenAI recommends altering this or top_p but not both.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use openai_api::api::{CompletionArgs, CompletionArgsBuilder};
|
||||
/// # use std::convert::{TryInto, TryFrom};
|
||||
/// # fn main() -> Result<(), String> {
|
||||
/// let builder = CompletionArgs::builder().temperature(0.7);
|
||||
/// let args: CompletionArgs = builder.try_into()?;
|
||||
/// # Ok::<(), String>(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[builder(default = "1.0")]
|
||||
temperature: f64,
|
||||
#[builder(default = "1.0")]
|
||||
@ -79,6 +127,14 @@ pub mod api {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<CompletionArgsBuilder> for CompletionArgs {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(builder: CompletionArgsBuilder) -> Result<Self, Self::Error> {
|
||||
builder.build()
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a non-streamed completion response
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct Completion {
|
||||
@ -88,7 +144,7 @@ pub mod api {
|
||||
pub created: u64,
|
||||
/// Exact model type and version used for the completion
|
||||
pub model: String,
|
||||
/// Timestamp
|
||||
/// List of completions generated by the model
|
||||
pub choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
@ -152,7 +208,7 @@ pub mod api {
|
||||
pub enum Error {
|
||||
/// An error returned by the API itself
|
||||
#[error("API returned an Error: {}", .0.message)]
|
||||
API(api::ErrorMessage),
|
||||
Api(api::ErrorMessage),
|
||||
/// An error the client discovers before talking to the API
|
||||
#[error("Bad arguments: {0}")]
|
||||
BadArguments(String),
|
||||
@ -167,7 +223,7 @@ pub enum Error {
|
||||
|
||||
impl From<api::ErrorMessage> for Error {
|
||||
fn from(e: api::ErrorMessage) -> Self {
|
||||
Error::API(e)
|
||||
Error::Api(e)
|
||||
}
|
||||
}
|
||||
|
||||
@ -300,7 +356,7 @@ impl Client {
|
||||
Ok(response.body_json::<T>().await?)
|
||||
} else {
|
||||
let err = response.body_json::<api::ErrorWrapper>().await?.error;
|
||||
Err(Error::API(err))
|
||||
Err(Error::Api(err))
|
||||
}
|
||||
}
|
||||
|
||||
@ -322,7 +378,7 @@ impl Client {
|
||||
.into_json_deserialize::<api::ErrorWrapper>()
|
||||
.expect("Bug: client couldn't deserialize api error response")
|
||||
.error;
|
||||
Err(Error::API(err))
|
||||
Err(Error::Api(err))
|
||||
}
|
||||
}
|
||||
|
||||
@ -379,7 +435,7 @@ impl Client {
|
||||
.await?;
|
||||
match response.status() {
|
||||
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
|
||||
_ => Err(Error::API(
|
||||
_ => Err(Error::Api(
|
||||
response
|
||||
.body_json::<api::ErrorWrapper>()
|
||||
.await
|
||||
@ -405,7 +461,7 @@ impl Client {
|
||||
200 => Ok(response
|
||||
.into_json_deserialize()
|
||||
.expect("Bug: client couldn't deserialize api response")),
|
||||
_ => Err(Error::API(
|
||||
_ => Err(Error::Api(
|
||||
response
|
||||
.into_json_deserialize::<api::ErrorWrapper>()
|
||||
.expect("Bug: client couldn't deserialize api error response")
|
||||
@ -629,7 +685,7 @@ mod unit {
|
||||
async_test!(engine_error_response_async, {
|
||||
let (_m, expected) = mock_engine();
|
||||
let response = mocked_client().engine("davinci").await;
|
||||
if let Result::Err(Error::API(msg)) = response {
|
||||
if let Result::Err(Error::Api(msg)) = response {
|
||||
assert_eq!(expected, msg);
|
||||
}
|
||||
});
|
||||
@ -637,7 +693,7 @@ mod unit {
|
||||
sync_test!(engine_error_response_sync, {
|
||||
let (_m, expected) = mock_engine();
|
||||
let response = mocked_client().engine_sync("davinci");
|
||||
if let Result::Err(Error::API(msg)) = response {
|
||||
if let Result::Err(Error::Api(msg)) = response {
|
||||
assert_eq!(expected, msg);
|
||||
}
|
||||
});
|
||||
@ -713,11 +769,12 @@ mod unit {
|
||||
m.assert();
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration {
|
||||
use crate::{
|
||||
api::{self, Completion},
|
||||
Client, Error,
|
||||
Client,
|
||||
};
|
||||
/// Used by tests to get a client to the actual api
|
||||
fn get_client() -> Client {
|
||||
@ -746,31 +803,19 @@ mod integration {
|
||||
assert!(engines.contains(&"davinci".into()));
|
||||
});
|
||||
|
||||
fn assert_expected_engine_failure<T>(result: Result<T, Error>)
|
||||
where
|
||||
T: std::fmt::Debug,
|
||||
{
|
||||
match result {
|
||||
Err(Error::API(api::ErrorMessage {
|
||||
message,
|
||||
error_type,
|
||||
})) => {
|
||||
assert_eq!(message, "No engine with that ID: ada");
|
||||
assert_eq!(error_type, "invalid_request_error");
|
||||
}
|
||||
_ => {
|
||||
panic!("Expected an error message, got {:?}", result)
|
||||
}
|
||||
}
|
||||
fn assert_engine_correct(engine_id: &str, info: api::EngineInfo) {
|
||||
assert_eq!(info.id, engine_id);
|
||||
assert!(info.ready);
|
||||
assert_eq!(info.owner, "openai");
|
||||
}
|
||||
async_test!(can_get_engine_async, {
|
||||
let client = get_client();
|
||||
assert_expected_engine_failure(client.engine("ada").await);
|
||||
assert_engine_correct("ada", client.engine("ada").await?);
|
||||
});
|
||||
|
||||
sync_test!(can_get_engine_sync, {
|
||||
let client = get_client();
|
||||
assert_expected_engine_failure(client.engine_sync("ada"));
|
||||
assert_engine_correct("ada", client.engine_sync("ada")?);
|
||||
});
|
||||
|
||||
async_test!(complete_string_async, {
|
||||
|
Loading…
x
Reference in New Issue
Block a user