Cleanups before publishing (#2)

Several cleanups:
- Flesh out the README
- Make the interface a little easier to use
- Add some examples
- Add some (sparse) documentation for public components
- Mark the `FinishReason` and `Engine` enums as non-exhaustive so new members can be added in the future without breaking backwards compatibility
This commit is contained in:
Josh Kuhn 2020-12-05 16:43:15 -08:00 committed by GitHub
parent 06cde8ae68
commit 62193135c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 224 additions and 1510 deletions

1463
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,14 @@
name = "openai-api" name = "openai-api"
version = "0.1.0" version = "0.1.0"
authors = ["Josh Kuhn <deontologician@gmail.com>"] authors = ["Josh Kuhn <deontologician@gmail.com>"]
license-file = "LICENSE"
edition = "2018" edition = "2018"
description = "OpenAI API library for rust"
homepage = "https://github.com/deontologician/openai-api-rust/"
repository = "https://github.com/deontologician/openai-api-rust/"
keywords = ["openai", "gpt3"]
categories = ["api-bindings", "asynchronous"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View File

@ -1,2 +1,55 @@
# openai-api-rust
A simple rust client for OpenAI API.
Has a few conveniences, but is mostly at the level of the API itself.
# Installation
```
$ cargo add openai-api
```
# Quickstart
```rust
use openai_api::{api::CompletionArgs, OpenAIClient};
#[tokio::main]
async fn main() {
let api_token = std::env::var("OPENAI_SK").unwrap();
let client = OpenAIClient::new(&api_token);
let prompt = String::from("Once upon a time,");
println!(
"{}{}",
prompt,
client.complete(prompt.as_str()).await.unwrap()
);
}
```
# Basic Usage
## Creating a completion
For simple demos and debugging, you can do a completion and use the `Display` instance of a `Completion` object to convert it to a string:
```rust
let response = client.complete("Once upon a time").await?;
println!("{}", response);
```
To configure the prompt more explicitly, you can use the `CompletionArgs` builder:
```rust
let args = openai_api::api::CompletionArgs::builder()
.prompt("Once upon a time,")
.engine(Engine::Davinci)
.max_tokens(20)
.temperature(0.7)
.top_p(0.9)
.stop(vec!["\n".into()]);
let response = args.complete(&client).await?;
println!("Response: {}", response.choices[0].text);
println!("Model used: {}", response.model);
```
See [examples/](./examples)

43
examples/chatloop.rs Normal file
View File

@ -0,0 +1,43 @@
use openai_api::{
api::{CompletionArgs, Engine},
OpenAIClient,
};
const START_PROMPT: &str = "
The following is a conversation with an AI assistant.
The assistant is helpful, creative, clever, and very friendly.
Human: Hello, who are you?
AI: I am an AI. How can I help you today?";
#[tokio::main]
async fn main() {
// API token comes from the OPENAI_SK environment variable; panics if unset.
let api_token = std::env::var("OPENAI_SK").unwrap();
let client = OpenAIClient::new(&api_token);
// Running transcript of the whole conversation; grows with every exchange
// and is resent as the prompt each turn so the model sees full history.
let mut context = String::from(START_PROMPT);
// Completion settings shared by every turn; only the prompt changes per loop.
let mut args = CompletionArgs::builder();
args.engine(Engine::Davinci)
.max_tokens(45)
.stop(vec!["\n".into()])
.top_p(0.5)
.temperature(0.9)
.frequency_penalty(0.5);
// Print the AI's opening line (the last line of START_PROMPT) in cyan
// (\x1b[1;36m ... \x1b[1;0m are ANSI color escapes).
println!("\x1b[1;36m{}\x1b[1;0m", context.split('\n').last().unwrap());
loop {
// Read the next human utterance directly into the transcript.
context.push_str("\nHuman: ");
if let Err(e) = std::io::stdin().read_line(&mut context) {
eprintln!("Error: {}", e);
break;
}
context.push_str("\nAI:");
// Send the entire transcript as the prompt; the "\n" stop token ends the
// AI's turn at the first newline.
match args.prompt(context.as_str()).complete(&client).await {
Ok(completion) => {
println!("\x1b[1;36m{}\x1b[1;0m", completion);
// Append the AI's reply so the next turn includes it in the prompt.
context.push_str(&completion.choices[0].text);
}
Err(e) => {
// On API error, bail out of the chat loop rather than retrying.
eprintln!("Error: {}", e);
break;
}
}
}
println!("Full conversation:\n{}", context);
}

14
examples/story.rs Normal file
View File

@ -0,0 +1,14 @@
//! Example that prints out a story from a prompt. Used in the readme.
use openai_api::OpenAIClient;
#[tokio::main]
async fn main() {
    // The API token is read from the OPENAI_SK environment variable;
    // panics with the env error if it is unset.
    let token = std::env::var("OPENAI_SK").unwrap();
    let client = OpenAIClient::new(&token);
    // Ask the API to continue a story opening, then print the prompt
    // followed immediately by the generated continuation.
    let prompt = "Once upon a time,";
    let story = client.complete(prompt).await.unwrap();
    println!("{}{}", prompt, story);
}

View File

@ -1,3 +1,4 @@
///! OpenAI API client library
#[macro_use] #[macro_use]
extern crate derive_builder; extern crate derive_builder;
@ -6,11 +7,12 @@ use thiserror::Error;
type Result<T> = std::result::Result<T, OpenAIError>; type Result<T> = std::result::Result<T, OpenAIError>;
#[allow(clippy::clippy::default_trait_access)] #[allow(clippy::default_trait_access)]
pub mod api { pub mod api {
//! Data types corresponding to requests and responses from the API
use std::collections::HashMap; use std::collections::HashMap;
use super::OpenAIClient;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Container type. Used in the api, but not useful for clients of this library /// Container type. Used in the api, but not useful for clients of this library
@ -30,6 +32,7 @@ pub mod api {
/// Engine types, known and unknown /// Engine types, known and unknown
#[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)] #[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
#[serde(rename_all = "kebab-case")] #[serde(rename_all = "kebab-case")]
#[non_exhaustive] // prevent clients from matching on every option
pub enum Engine { pub enum Engine {
Ada, Ada,
Babbage, Babbage,
@ -56,10 +59,13 @@ pub mod api {
} }
/// Options that affect the result /// Options that affect the result
#[derive(Serialize, Debug, Builder)] #[derive(Serialize, Debug, Builder, Clone)]
pub struct CompletionArgs { pub struct CompletionArgs {
#[builder(setter(into), default = "\"<|endoftext|>\".into()")] #[builder(setter(into), default = "\"<|endoftext|>\".into()")]
prompt: String, prompt: String,
#[builder(default = "Engine::Davinci")]
#[serde(skip_serializing)]
pub(super) engine: Engine,
#[builder(default = "16")] #[builder(default = "16")]
max_tokens: u64, max_tokens: u64,
#[builder(default = "1.0")] #[builder(default = "1.0")]
@ -81,6 +87,7 @@ pub mod api {
#[builder(default)] #[builder(default)]
logit_bias: HashMap<String, f64>, logit_bias: HashMap<String, f64>,
} }
/* { /* {
"stream": false, // SSE streams back results "stream": false, // SSE streams back results
"best_of": Option<u64>, //cant be used with stream "best_of": Option<u64>, //cant be used with stream
@ -88,40 +95,80 @@ pub mod api {
*/ */
// TODO: add validators for the different arguments // TODO: add validators for the different arguments
// Defaults are taken from the builder's per-field `#[builder(default = ...)]`
// attributes (e.g. engine = Davinci, max_tokens = 16). The `expect` is sound
// as long as every field has a builder default, so `build()` cannot fail;
// a failure here would be a bug in this library, not in caller code.
impl Default for CompletionArgs {
fn default() -> Self {
CompletionArgsBuilder::default()
.build()
.expect("Client error, invalid defaults")
}
}
impl From<&str> for CompletionArgs { impl From<&str> for CompletionArgs {
fn from(prompt_string: &str) -> Self { fn from(prompt_string: &str) -> Self {
Self { Self {
prompt: prompt_string.into(), prompt: prompt_string.into(),
..CompletionArgs::default() ..CompletionArgsBuilder::default()
.build()
.expect("default should build")
} }
} }
} }
impl CompletionArgs {
/// Build a `CompletionArgs` from the defaults
#[must_use]
pub fn builder() -> CompletionArgsBuilder {
CompletionArgsBuilder::default()
}
/// Request a completion from the api
///
/// # Errors
/// `OpenAIError::APIError` if the api returns an error
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> {
client.complete(self.clone()).await
}
}
impl CompletionArgsBuilder {
/// Request a completion from the api
///
/// # Errors
/// `OpenAIError::BadArguments` if the arguments to complete are not valid
/// `OpenAIError::APIError` if the api returns an error
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> {
client.complete(self.build()?).await
}
}
/// Represents a non-streamed completion response /// Represents a non-streamed completion response
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct Completion { pub struct Completion {
id: String, /// Completion unique identifier
object: String, pub id: String,
created: u64, /// Unix timestamp when the completion was generated
model: String, pub created: u64,
choices: Vec<Choice>, /// Exact model type and version used for the completion
pub model: String,
/// Timestamp
pub choices: Vec<Choice>,
} }
/// Represents a single choice impl std::fmt::Display for Completion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.choices[0])
}
}
/// A single completion result
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct Choice { pub struct Choice {
text: String, /// The text of the completion. Will contain the prompt if echo is True.
index: u64, pub text: String,
logprobs: Option<LogProbs>, /// Offset in the result where the completion began. Useful if using echo.
finish_reason: FinishReason, pub index: u64,
/// If requested, the log probabilities of the completion tokens
pub logprobs: Option<LogProbs>,
/// Why the completion ended when it did
pub finish_reason: FinishReason,
}
impl std::fmt::Display for Choice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.text)
}
} }
/// Represents a logprobs subdocument /// Represents a logprobs subdocument
@ -133,13 +180,18 @@ pub mod api {
text_offset: Vec<u64>, text_offset: Vec<u64>,
} }
/// Reason a prompt completion finished.
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
#[serde(rename_all = "kebab-case")] #[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub enum FinishReason { pub enum FinishReason {
/// The maximum length was reached
Length, Length,
/// The stop token was encountered
Stop, Stop,
} }
/// Error response object from the server
#[derive(Deserialize, Debug, Eq, PartialEq)] #[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ErrorMessage { pub struct ErrorMessage {
pub message: String, pub message: String,
@ -147,46 +199,58 @@ pub mod api {
pub error_type: String, pub error_type: String,
} }
/// API-level wrapper used in deserialization
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct ErrorWrapper { pub(super) struct ErrorWrapper {
pub error: ErrorMessage, pub error: ErrorMessage,
} }
} }
/// This library's main `Error` type.
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum OpenAIError { pub enum OpenAIError {
#[error("Invalid secret key")] /// An error returned by the API itself
InvalidAPIKey {
#[from]
source: reqwest::header::InvalidHeaderValue,
},
#[error("API Returned an Error document")] #[error("API Returned an Error document")]
APIError(api::ErrorMessage), APIError(api::ErrorMessage),
/// An error the client discovers before talking to the API
#[error("Bad arguments")]
BadArguments(String),
} }
// Lets an API error document be converted (e.g. via `?` or `.into()`) into
// the library's error type as the `APIError` variant.
impl From<api::ErrorMessage> for OpenAIError {
fn from(e: api::ErrorMessage) -> Self {
OpenAIError::APIError(e)
}
}
// Lets a plain message string be converted into the client-side
// `BadArguments` variant, used for validation failures before any request.
impl From<String> for OpenAIError {
fn from(e: String) -> Self {
OpenAIError::BadArguments(e)
}
}
/// Client object. Must be constructed to talk to the API.
pub struct OpenAIClient { pub struct OpenAIClient {
client: reqwest::Client, client: reqwest::Client,
root: String, root: String,
} }
impl OpenAIClient { impl OpenAIClient {
/// Creates a new `OpenAIClient` /// Creates a new `OpenAIClient` given an api token
/// pub fn new(token: &str) -> Self {
/// # Errors
/// `OpenAIError::InvalidAPIKey` if the api token has invalid characters
pub fn new(token: &str) -> Result<Self> {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert( headers.insert(
reqwest::header::AUTHORIZATION, reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))?, reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
.expect("Client library error. Header value badly formatted"),
); );
Ok(Self { Self {
client: reqwest::Client::builder() client: reqwest::Client::builder()
.default_headers(headers) .default_headers(headers)
.build() .build()
.expect("Client library error. Should have constructed a valid http client."), .expect("Client library error. Should have constructed a valid http client."),
root: "https://api.openai.com/v1".into(), root: "https://api.openai.com/v1".into(),
}) }
} }
/// Private helper for making gets /// Private helper for making gets
@ -216,8 +280,6 @@ impl OpenAIClient {
/// ///
/// # Errors /// # Errors
/// - `OpenAIError::APIError` if the server returns an error /// - `OpenAIError::APIError` if the server returns an error
/// - `OpenAIError::ServerFormatError` if the json response wasn't parseable (most
/// likely a bug in this client, please report it)
pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> { pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
self.get("engines").await.map(|r: api::Container<_>| r.data) self.get("engines").await.map(|r: api::Container<_>| r.data)
} }
@ -227,8 +289,6 @@ impl OpenAIClient {
/// ///
/// # Errors /// # Errors
/// - `OpenAIError::APIError` if the server returns an error /// - `OpenAIError::APIError` if the server returns an error
/// - `OpenAIError::ServerFormatError` if the json response wasn't parseable (most
/// likely a bug in this client, please report it)
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> { pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
self.get(&format!("engines/{}", engine)).await self.get(&format!("engines/{}", engine)).await
} }
@ -265,11 +325,11 @@ impl OpenAIClient {
/// - `OpenAIError::APIError` if the api returns an error /// - `OpenAIError::APIError` if the api returns an error
pub async fn complete( pub async fn complete(
&self, &self,
engine: api::Engine,
prompt: impl Into<api::CompletionArgs>, prompt: impl Into<api::CompletionArgs>,
) -> Result<api::Completion> { ) -> Result<api::Completion> {
let args = prompt.into();
Ok(self Ok(self
.post(&format!("engines/{}/completions", engine), prompt.into()) .post(&format!("engines/{}/completions", args.engine), args)
.await? .await?
//.text() //.text()
.json() .json()
@ -284,7 +344,7 @@ mod unit {
use crate::{api, OpenAIClient, OpenAIError}; use crate::{api, OpenAIClient, OpenAIError};
fn mocked_client() -> OpenAIClient { fn mocked_client() -> OpenAIClient {
let mut client = OpenAIClient::new("bogus").unwrap(); let mut client = OpenAIClient::new("bogus");
client.root = mockito::server_url(); client.root = mockito::server_url();
client client
} }
@ -439,7 +499,7 @@ mod integration {
let sk = std::env::var("OPENAI_SK").expect( let sk = std::env::var("OPENAI_SK").expect(
"To run integration tests, you must put set the OPENAI_SK env var to your api token", "To run integration tests, you must put set the OPENAI_SK env var to your api token",
); );
OpenAIClient::new(&sk).expect("client build failed") OpenAIClient::new(&sk)
} }
#[tokio::test] #[tokio::test]
@ -468,7 +528,7 @@ mod integration {
#[tokio::test] #[tokio::test]
async fn complete_string() -> crate::Result<()> { async fn complete_string() -> crate::Result<()> {
let client = get_client(); let client = get_client();
client.complete(api::Engine::Ada, "Hey there").await?; client.complete("Hey there").await?;
Ok(()) Ok(())
} }
@ -492,7 +552,7 @@ mod integration {
}) })
.build() .build()
.expect("Build should have succeeded"); .expect("Build should have succeeded");
client.complete(api::Engine::Ada, args).await?; client.complete(args).await?;
Ok(()) Ok(())
} }
} }