Cleanups before publishing (#2)

Several cleanups:
- Flesh out the README
- Make the interface a little easier to use
- Add some examples
- Add some (sparse) documentation for public components
- Mark the `FinishReason` and `Engine` enums as non-exhaustive so new members can be added in the future without breaking backwards compatibility
This commit is contained in:
Josh Kuhn 2020-12-05 16:43:15 -08:00 committed by GitHub
parent 06cde8ae68
commit 62193135c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 224 additions and 1510 deletions

1463
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,14 @@
name = "openai-api"
version = "0.1.0"
authors = ["Josh Kuhn <deontologician@gmail.com>"]
license-file = "LICENSE"
edition = "2018"
description = "OpenAI API library for rust"
homepage = "https://github.com/deontologician/openai-api-rust/"
repository = "https://github.com/deontologician/openai-api-rust/"
keywords = ["openai", "gpt3"]
categories = ["api-bindings", "asynchronous"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View File

@ -1,2 +1,55 @@
# openai-api-rust
Rust client for OpenAI API
A simple rust client for OpenAI API.
Has a few conveniences, but is mostly at the level of the API itself.
# Installation
```
$ cargo add openai-api
```
# Quickstart
```rust
use openai_api::OpenAIClient;
#[tokio::main]
async fn main() {
let api_token = std::env::var("OPENAI_SK").unwrap();
let client = OpenAIClient::new(&api_token);
let prompt = String::from("Once upon a time,");
println!(
"{}{}",
prompt,
client.complete(prompt.as_str()).await.unwrap()
);
}
```
# Basic Usage
## Creating a completion
For simple demos and debugging, you can do a completion and use the `Display` instance of a `Completion` object to convert it to a string:
```rust
let response = client.complete("Once upon a time").await?;
println!("{}", response);
```
To configure the prompt more explicitly, you can use the `CompletionArgs` builder:
```rust
let args = openai_api::api::CompletionArgs::builder()
.prompt("Once upon a time,")
    .engine(openai_api::api::Engine::Davinci)
.max_tokens(20)
.temperature(0.7)
.top_p(0.9)
.stop(vec!["\n".into()]);
let response = args.complete(&client).await?;
println!("Response: {}", response.choices[0].text);
println!("Model used: {}", response.model);
```
See [examples/](./examples)

43
examples/chatloop.rs Normal file
View File

@ -0,0 +1,43 @@
use openai_api::{
    api::{CompletionArgs, Engine},
    OpenAIClient,
};

/// Seed conversation that primes the model before the user speaks.
/// One turn per line; the completion stop token below is "\n", so each
/// API reply is a single line that continues this format.
const START_PROMPT: &str = "
The following is a conversation with an AI assistant.
The assistant is helpful, creative, clever, and very friendly.
Human: Hello, who are you?
AI: I am an AI. How can I help you today?";

/// Interactive chat loop against the completions API.
///
/// Reads one line of user input at a time, appends it to the running
/// transcript, asks the API for the assistant's next line, and prints it.
/// Exits on stdin EOF (Ctrl-D), on a read error, or on an API error.
#[tokio::main]
async fn main() {
    let api_token =
        std::env::var("OPENAI_SK").expect("set the OPENAI_SK env var to your OpenAI API key");
    let client = OpenAIClient::new(&api_token);
    let mut context = String::from(START_PROMPT);
    let mut args = CompletionArgs::builder();
    args.engine(Engine::Davinci)
        .max_tokens(45)
        .stop(vec!["\n".into()])
        .top_p(0.5)
        .temperature(0.9)
        .frequency_penalty(0.5);
    // Print the model's canned opening line (the last line of the seed prompt).
    println!("\x1b[1;36m{}\x1b[1;0m", context.split('\n').last().unwrap());
    loop {
        context.push_str("\nHuman: ");
        match std::io::stdin().read_line(&mut context) {
            // EOF: the user closed stdin (e.g. Ctrl-D). Stop cleanly instead
            // of looping forever on empty reads.
            Ok(0) => break,
            Ok(_) => {
                // Drop the trailing newline `read_line` appends so the
                // transcript stays one-turn-per-line, matching START_PROMPT.
                while context.ends_with('\n') || context.ends_with('\r') {
                    context.pop();
                }
            }
            Err(e) => {
                eprintln!("Error: {}", e);
                break;
            }
        }
        context.push_str("\nAI:");
        match args.prompt(context.as_str()).complete(&client).await {
            Ok(completion) => {
                println!("\x1b[1;36m{}\x1b[1;0m", completion);
                context.push_str(&completion.choices[0].text);
            }
            Err(e) => {
                eprintln!("Error: {}", e);
                break;
            }
        }
    }
    println!("Full conversation:\n{}", context);
}

14
examples/story.rs Normal file
View File

@ -0,0 +1,14 @@
//! Example that prints out a story from a prompt. Used in the readme.
use openai_api::OpenAIClient;

/// Completes "Once upon a time," with the client's default settings and
/// prints the prompt followed by the generated continuation.
#[tokio::main]
async fn main() {
    let api_token =
        std::env::var("OPENAI_SK").expect("set the OPENAI_SK env var to your OpenAI API key");
    let client = OpenAIClient::new(&api_token);
    let prompt = String::from("Once upon a time,");
    println!(
        "{}{}",
        prompt,
        client.complete(prompt.as_str()).await.unwrap()
    );
}

View File

@ -1,3 +1,4 @@
///! OpenAI API client library
#[macro_use]
extern crate derive_builder;
@ -6,11 +7,12 @@ use thiserror::Error;
type Result<T> = std::result::Result<T, OpenAIError>;
#[allow(clippy::clippy::default_trait_access)]
#[allow(clippy::default_trait_access)]
pub mod api {
//! Data types corresponding to requests and responses from the API
use std::collections::HashMap;
use super::OpenAIClient;
use serde::{Deserialize, Serialize};
/// Container type. Used in the api, but not useful for clients of this library
@ -30,6 +32,7 @@ pub mod api {
/// Engine types, known and unknown
#[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive] // prevent clients from matching on every option
pub enum Engine {
Ada,
Babbage,
@ -56,10 +59,13 @@ pub mod api {
}
/// Options that affect the result
#[derive(Serialize, Debug, Builder)]
#[derive(Serialize, Debug, Builder, Clone)]
pub struct CompletionArgs {
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
prompt: String,
#[builder(default = "Engine::Davinci")]
#[serde(skip_serializing)]
pub(super) engine: Engine,
#[builder(default = "16")]
max_tokens: u64,
#[builder(default = "1.0")]
@ -81,6 +87,7 @@ pub mod api {
#[builder(default)]
logit_bias: HashMap<String, f64>,
}
/* {
"stream": false, // SSE streams back results
"best_of": Option<u64>, //cant be used with stream
@ -88,40 +95,80 @@ pub mod api {
*/
// TODO: add validators for the different arguments
impl Default for CompletionArgs {
fn default() -> Self {
CompletionArgsBuilder::default()
.build()
.expect("Client error, invalid defaults")
}
}
impl From<&str> for CompletionArgs {
fn from(prompt_string: &str) -> Self {
Self {
prompt: prompt_string.into(),
..CompletionArgs::default()
..CompletionArgsBuilder::default()
.build()
.expect("default should build")
}
}
}
impl CompletionArgs {
/// Build a `CompletionArgs` from the defaults
#[must_use]
pub fn builder() -> CompletionArgsBuilder {
CompletionArgsBuilder::default()
}
/// Request a completion from the api
///
/// # Errors
/// `OpenAIError::APIError` if the api returns an error
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> {
client.complete(self.clone()).await
}
}
impl CompletionArgsBuilder {
/// Request a completion from the api
///
/// # Errors
/// `OpenAIError::BadArguments` if the arguments to complete are not valid
/// `OpenAIError::APIError` if the api returns an error
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> {
client.complete(self.build()?).await
}
}
/// Represents a non-streamed completion response
#[derive(Deserialize, Debug)]
pub struct Completion {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<Choice>,
/// Completion unique identifier
pub id: String,
/// Unix timestamp when the completion was generated
pub created: u64,
/// Exact model type and version used for the completion
pub model: String,
/// The list of returned completion choices
pub choices: Vec<Choice>,
}
/// Represents a single choice
impl std::fmt::Display for Completion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.choices[0])
}
}
/// A single completion result
#[derive(Deserialize, Debug)]
pub struct Choice {
text: String,
index: u64,
logprobs: Option<LogProbs>,
finish_reason: FinishReason,
/// The text of the completion. Will contain the prompt if echo is True.
pub text: String,
/// Offset in the result where the completion began. Useful if using echo.
pub index: u64,
/// If requested, the log probabilities of the completion tokens
pub logprobs: Option<LogProbs>,
/// Why the completion ended when it did
pub finish_reason: FinishReason,
}
impl std::fmt::Display for Choice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.text)
}
}
/// Represents a logprobs subdocument
@ -133,13 +180,18 @@ pub mod api {
text_offset: Vec<u64>,
}
/// Reason a prompt completion finished.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub enum FinishReason {
/// The maximum length was reached
Length,
/// The stop token was encountered
Stop,
}
/// Error response object from the server
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ErrorMessage {
pub message: String,
@ -147,46 +199,58 @@ pub mod api {
pub error_type: String,
}
/// API-level wrapper used in deserialization
#[derive(Deserialize, Debug)]
pub struct ErrorWrapper {
pub(super) struct ErrorWrapper {
pub error: ErrorMessage,
}
}
/// This library's main `Error` type.
#[derive(Error, Debug)]
pub enum OpenAIError {
#[error("Invalid secret key")]
InvalidAPIKey {
#[from]
source: reqwest::header::InvalidHeaderValue,
},
/// An error returned by the API itself
#[error("API Returned an Error document")]
APIError(api::ErrorMessage),
/// An error the client discovers before talking to the API
#[error("Bad arguments")]
BadArguments(String),
}
impl From<api::ErrorMessage> for OpenAIError {
fn from(e: api::ErrorMessage) -> Self {
OpenAIError::APIError(e)
}
}
impl From<String> for OpenAIError {
fn from(e: String) -> Self {
OpenAIError::BadArguments(e)
}
}
/// Client object. Must be constructed to talk to the API.
pub struct OpenAIClient {
client: reqwest::Client,
root: String,
}
impl OpenAIClient {
/// Creates a new `OpenAIClient`
///
/// # Errors
/// `OpenAIError::InvalidAPIKey` if the api token has invalid characters
pub fn new(token: &str) -> Result<Self> {
/// Creates a new `OpenAIClient` given an api token
pub fn new(token: &str) -> Self {
let mut headers = HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))?,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
.expect("Client library error. Header value badly formatted"),
);
Ok(Self {
Self {
client: reqwest::Client::builder()
.default_headers(headers)
.build()
.expect("Client library error. Should have constructed a valid http client."),
root: "https://api.openai.com/v1".into(),
})
}
}
/// Private helper for making gets
@ -216,8 +280,6 @@ impl OpenAIClient {
///
/// # Errors
/// - `OpenAIError::APIError` if the server returns an error
/// - `OpenAIError::ServerFormatError` if the json response wasn't parseable (most
/// likely a bug in this client, please report it)
pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
self.get("engines").await.map(|r: api::Container<_>| r.data)
}
@ -227,8 +289,6 @@ impl OpenAIClient {
///
/// # Errors
/// - `OpenAIError::APIError` if the server returns an error
/// - `OpenAIError::ServerFormatError` if the json response wasn't parseable (most
/// likely a bug in this client, please report it)
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
self.get(&format!("engines/{}", engine)).await
}
@ -265,11 +325,11 @@ impl OpenAIClient {
/// - `OpenAIError::APIError` if the api returns an error
pub async fn complete(
&self,
engine: api::Engine,
prompt: impl Into<api::CompletionArgs>,
) -> Result<api::Completion> {
let args = prompt.into();
Ok(self
.post(&format!("engines/{}/completions", engine), prompt.into())
.post(&format!("engines/{}/completions", args.engine), args)
.await?
//.text()
.json()
@ -284,7 +344,7 @@ mod unit {
use crate::{api, OpenAIClient, OpenAIError};
fn mocked_client() -> OpenAIClient {
let mut client = OpenAIClient::new("bogus").unwrap();
let mut client = OpenAIClient::new("bogus");
client.root = mockito::server_url();
client
}
@ -439,7 +499,7 @@ mod integration {
let sk = std::env::var("OPENAI_SK").expect(
"To run integration tests, you must put set the OPENAI_SK env var to your api token",
);
OpenAIClient::new(&sk).expect("client build failed")
OpenAIClient::new(&sk)
}
#[tokio::test]
@ -468,7 +528,7 @@ mod integration {
#[tokio::test]
async fn complete_string() -> crate::Result<()> {
let client = get_client();
client.complete(api::Engine::Ada, "Hey there").await?;
client.complete("Hey there").await?;
Ok(())
}
@ -492,7 +552,7 @@ mod integration {
})
.build()
.expect("Build should have succeeded");
client.complete(api::Engine::Ada, args).await?;
client.complete(args).await?;
Ok(())
}
}