Cleanups before publishing (#2)
Several cleanups: - Flesh out the README - Make the interface a little easier to use - Add some examples - Add some (sparse) documentation for public components - Mark the `FinishReason` and `Engine` enums as non-exhaustive so new members can be added in the future without breaking backwards compatibility
This commit is contained in:
parent
06cde8ae68
commit
62193135c7
1463
Cargo.lock
generated
1463
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -2,7 +2,14 @@
|
||||
name = "openai-api"
|
||||
version = "0.1.0"
|
||||
authors = ["Josh Kuhn <deontologician@gmail.com>"]
|
||||
license-file = "LICENSE"
|
||||
edition = "2018"
|
||||
description = "OpenAI API library for rust"
|
||||
homepage = "https://github.com/deontologician/openai-api-rust/"
|
||||
repository = "https://github.com/deontologician/openai-api-rust/"
|
||||
keywords = ["openai", "gpt3"]
|
||||
categories = ["api-bindings", "asynchronous"]
|
||||
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
|
55
README.md
55
README.md
@ -1,2 +1,55 @@
|
||||
# openai-api-rust
|
||||
Rust client for OpenAI API
|
||||
A simple rust client for OpenAI API.
|
||||
|
||||
Has a few conveniences, but is mostly at the level of the API itself.
|
||||
|
||||
# Installation
|
||||
|
||||
```
|
||||
$ cargo add openai-api-rust
|
||||
```
|
||||
|
||||
# Quickstart
|
||||
|
||||
```rust
|
||||
use openai_api::{api::CompletionArgs, OpenAIClient};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let api_token = std::env::var("OPENAI_SK").unwrap();
|
||||
let client = OpenAIClient::new(&api_token);
|
||||
let prompt = String::from("Once upon a time,");
|
||||
println!(
|
||||
"{}{}",
|
||||
prompt,
|
||||
client.complete(prompt.as_str()).await.unwrap()
|
||||
);
|
||||
}
|
||||
```
|
||||
# Basic Usage
|
||||
|
||||
## Creating a completion
|
||||
|
||||
For simple demos and debugging, you can do a completion and use the `Display` instance of a `Completion` object to convert it to a string:
|
||||
|
||||
```rust
|
||||
let response = client.complete("Once upon a time").await?;
|
||||
println!("{}", response);
|
||||
```
|
||||
|
||||
To configure the prompt more explicitly, you can use the `CompletionArgs` builder:
|
||||
|
||||
```rust
|
||||
let args = openai_api::api::CompletionArgs::builder()
|
||||
.prompt("Once upon a time,")
|
||||
.engine(Engine::Davinci)
|
||||
.max_tokens(20)
|
||||
.temperature(0.7)
|
||||
.top_p(0.9)
|
||||
.stop(vec!["\n".into()]);
|
||||
let response = args.complete(&client).await?;
|
||||
println!("Response: {}", response.choices[0].text);
|
||||
println!("Model used: {}", response.model);
|
||||
```
|
||||
|
||||
See [examples/](./examples)
|
43
examples/chatloop.rs
Normal file
43
examples/chatloop.rs
Normal file
@ -0,0 +1,43 @@
|
||||
use openai_api::{
|
||||
api::{CompletionArgs, Engine},
|
||||
OpenAIClient,
|
||||
};
|
||||
|
||||
const START_PROMPT: &str = "
|
||||
The following is a conversation with an AI assistant.
|
||||
The assistant is helpful, creative, clever, and very friendly.
|
||||
Human: Hello, who are you?
|
||||
AI: I am an AI. How can I help you today?";
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let api_token = std::env::var("OPENAI_SK").unwrap();
|
||||
let client = OpenAIClient::new(&api_token);
|
||||
let mut context = String::from(START_PROMPT);
|
||||
let mut args = CompletionArgs::builder();
|
||||
args.engine(Engine::Davinci)
|
||||
.max_tokens(45)
|
||||
.stop(vec!["\n".into()])
|
||||
.top_p(0.5)
|
||||
.temperature(0.9)
|
||||
.frequency_penalty(0.5);
|
||||
println!("\x1b[1;36m{}\x1b[1;0m", context.split('\n').last().unwrap());
|
||||
loop {
|
||||
context.push_str("\nHuman: ");
|
||||
if let Err(e) = std::io::stdin().read_line(&mut context) {
|
||||
eprintln!("Error: {}", e);
|
||||
break;
|
||||
}
|
||||
context.push_str("\nAI:");
|
||||
match args.prompt(context.as_str()).complete(&client).await {
|
||||
Ok(completion) => {
|
||||
println!("\x1b[1;36m{}\x1b[1;0m", completion);
|
||||
context.push_str(&completion.choices[0].text);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
println!("Full conversation:\n{}", context);
|
||||
}
|
14
examples/story.rs
Normal file
14
examples/story.rs
Normal file
@ -0,0 +1,14 @@
|
||||
///! Example that prints out a story from a prompt. Used in the readme.
|
||||
use openai_api::OpenAIClient;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let api_token = std::env::var("OPENAI_SK").unwrap();
|
||||
let client = OpenAIClient::new(&api_token);
|
||||
let prompt = String::from("Once upon a time,");
|
||||
println!(
|
||||
"{}{}",
|
||||
prompt,
|
||||
client.complete(prompt.as_str()).await.unwrap()
|
||||
);
|
||||
}
|
152
src/lib.rs
152
src/lib.rs
@ -1,3 +1,4 @@
|
||||
///! OpenAI API client library
|
||||
#[macro_use]
|
||||
extern crate derive_builder;
|
||||
|
||||
@ -6,11 +7,12 @@ use thiserror::Error;
|
||||
|
||||
type Result<T> = std::result::Result<T, OpenAIError>;
|
||||
|
||||
#[allow(clippy::clippy::default_trait_access)]
|
||||
#[allow(clippy::default_trait_access)]
|
||||
pub mod api {
|
||||
|
||||
//! Data types corresponding to requests and responses from the API
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::OpenAIClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Container type. Used in the api, but not useful for clients of this library
|
||||
@ -30,6 +32,7 @@ pub mod api {
|
||||
/// Engine types, known and unknown
|
||||
#[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
#[non_exhaustive] // prevent clients from matching on every option
|
||||
pub enum Engine {
|
||||
Ada,
|
||||
Babbage,
|
||||
@ -56,10 +59,13 @@ pub mod api {
|
||||
}
|
||||
|
||||
/// Options that affect the result
|
||||
#[derive(Serialize, Debug, Builder)]
|
||||
#[derive(Serialize, Debug, Builder, Clone)]
|
||||
pub struct CompletionArgs {
|
||||
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
|
||||
prompt: String,
|
||||
#[builder(default = "Engine::Davinci")]
|
||||
#[serde(skip_serializing)]
|
||||
pub(super) engine: Engine,
|
||||
#[builder(default = "16")]
|
||||
max_tokens: u64,
|
||||
#[builder(default = "1.0")]
|
||||
@ -81,6 +87,7 @@ pub mod api {
|
||||
#[builder(default)]
|
||||
logit_bias: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
/* {
|
||||
"stream": false, // SSE streams back results
|
||||
"best_of": Option<u64>, //cant be used with stream
|
||||
@ -88,40 +95,80 @@ pub mod api {
|
||||
*/
|
||||
// TODO: add validators for the different arguments
|
||||
|
||||
impl Default for CompletionArgs {
|
||||
fn default() -> Self {
|
||||
CompletionArgsBuilder::default()
|
||||
.build()
|
||||
.expect("Client error, invalid defaults")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for CompletionArgs {
|
||||
fn from(prompt_string: &str) -> Self {
|
||||
Self {
|
||||
prompt: prompt_string.into(),
|
||||
..CompletionArgs::default()
|
||||
..CompletionArgsBuilder::default()
|
||||
.build()
|
||||
.expect("default should build")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionArgs {
|
||||
/// Build a `CompletionArgs` from the defaults
|
||||
#[must_use]
|
||||
pub fn builder() -> CompletionArgsBuilder {
|
||||
CompletionArgsBuilder::default()
|
||||
}
|
||||
|
||||
/// Request a completion from the api
|
||||
///
|
||||
/// # Errors
|
||||
/// `OpenAIError::APIError` if the api returns an error
|
||||
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> {
|
||||
client.complete(self.clone()).await
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionArgsBuilder {
|
||||
/// Request a completion from the api
|
||||
///
|
||||
/// # Errors
|
||||
/// `OpenAIError::BadArguments` if the arguments to complete are not valid
|
||||
/// `OpenAIError::APIError` if the api returns an error
|
||||
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> {
|
||||
client.complete(self.build()?).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a non-streamed completion response
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct Completion {
|
||||
id: String,
|
||||
object: String,
|
||||
created: u64,
|
||||
model: String,
|
||||
choices: Vec<Choice>,
|
||||
/// Completion unique identifier
|
||||
pub id: String,
|
||||
/// Unix timestamp when the completion was generated
|
||||
pub created: u64,
|
||||
/// Exact model type and version used for the completion
|
||||
pub model: String,
|
||||
/// Timestamp
|
||||
pub choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
/// Represents a single choice
|
||||
impl std::fmt::Display for Completion {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.choices[0])
|
||||
}
|
||||
}
|
||||
|
||||
/// A single completion result
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct Choice {
|
||||
text: String,
|
||||
index: u64,
|
||||
logprobs: Option<LogProbs>,
|
||||
finish_reason: FinishReason,
|
||||
/// The text of the completion. Will contain the prompt if echo is True.
|
||||
pub text: String,
|
||||
/// Offset in the result where the completion began. Useful if using echo.
|
||||
pub index: u64,
|
||||
/// If requested, the log probabilities of the completion tokens
|
||||
pub logprobs: Option<LogProbs>,
|
||||
/// Why the completion ended when it did
|
||||
pub finish_reason: FinishReason,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Choice {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.text)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a logprobs subdocument
|
||||
@ -133,13 +180,18 @@ pub mod api {
|
||||
text_offset: Vec<u64>,
|
||||
}
|
||||
|
||||
/// Reason a prompt completion finished.
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
#[non_exhaustive]
|
||||
pub enum FinishReason {
|
||||
/// The maximum length was reached
|
||||
Length,
|
||||
/// The stop token was encountered
|
||||
Stop,
|
||||
}
|
||||
|
||||
/// Error response object from the server
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ErrorMessage {
|
||||
pub message: String,
|
||||
@ -147,46 +199,58 @@ pub mod api {
|
||||
pub error_type: String,
|
||||
}
|
||||
|
||||
/// API-level wrapper used in deserialization
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ErrorWrapper {
|
||||
pub(super) struct ErrorWrapper {
|
||||
pub error: ErrorMessage,
|
||||
}
|
||||
}
|
||||
|
||||
/// This library's main `Error` type.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum OpenAIError {
|
||||
#[error("Invalid secret key")]
|
||||
InvalidAPIKey {
|
||||
#[from]
|
||||
source: reqwest::header::InvalidHeaderValue,
|
||||
},
|
||||
/// An error returned by the API itself
|
||||
#[error("API Returned an Error document")]
|
||||
APIError(api::ErrorMessage),
|
||||
/// An error the client discovers before talking to the API
|
||||
#[error("Bad arguments")]
|
||||
BadArguments(String),
|
||||
}
|
||||
|
||||
impl From<api::ErrorMessage> for OpenAIError {
|
||||
fn from(e: api::ErrorMessage) -> Self {
|
||||
OpenAIError::APIError(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for OpenAIError {
|
||||
fn from(e: String) -> Self {
|
||||
OpenAIError::BadArguments(e)
|
||||
}
|
||||
}
|
||||
|
||||
/// Client object. Must be constructed to talk to the API.
|
||||
pub struct OpenAIClient {
|
||||
client: reqwest::Client,
|
||||
root: String,
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
/// Creates a new `OpenAIClient`
|
||||
///
|
||||
/// # Errors
|
||||
/// `OpenAIError::InvalidAPIKey` if the api token has invalid characters
|
||||
pub fn new(token: &str) -> Result<Self> {
|
||||
/// Creates a new `OpenAIClient` given an api token
|
||||
pub fn new(token: &str) -> Self {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))?,
|
||||
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
|
||||
.expect("Client library error. Header value badly formatted"),
|
||||
);
|
||||
Ok(Self {
|
||||
Self {
|
||||
client: reqwest::Client::builder()
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.expect("Client library error. Should have constructed a valid http client."),
|
||||
root: "https://api.openai.com/v1".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Private helper for making gets
|
||||
@ -216,8 +280,6 @@ impl OpenAIClient {
|
||||
///
|
||||
/// # Errors
|
||||
/// - `OpenAIError::APIError` if the server returns an error
|
||||
/// - `OpenAIError::ServerFormatError` if the json response wasn't parseable (most
|
||||
/// likely a bug in this client, please report it)
|
||||
pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
|
||||
self.get("engines").await.map(|r: api::Container<_>| r.data)
|
||||
}
|
||||
@ -227,8 +289,6 @@ impl OpenAIClient {
|
||||
///
|
||||
/// # Errors
|
||||
/// - `OpenAIError::APIError` if the server returns an error
|
||||
/// - `OpenAIError::ServerFormatError` if the json response wasn't parseable (most
|
||||
/// likely a bug in this client, please report it)
|
||||
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
|
||||
self.get(&format!("engines/{}", engine)).await
|
||||
}
|
||||
@ -265,11 +325,11 @@ impl OpenAIClient {
|
||||
/// - `OpenAIError::APIError` if the api returns an error
|
||||
pub async fn complete(
|
||||
&self,
|
||||
engine: api::Engine,
|
||||
prompt: impl Into<api::CompletionArgs>,
|
||||
) -> Result<api::Completion> {
|
||||
let args = prompt.into();
|
||||
Ok(self
|
||||
.post(&format!("engines/{}/completions", engine), prompt.into())
|
||||
.post(&format!("engines/{}/completions", args.engine), args)
|
||||
.await?
|
||||
//.text()
|
||||
.json()
|
||||
@ -284,7 +344,7 @@ mod unit {
|
||||
use crate::{api, OpenAIClient, OpenAIError};
|
||||
|
||||
fn mocked_client() -> OpenAIClient {
|
||||
let mut client = OpenAIClient::new("bogus").unwrap();
|
||||
let mut client = OpenAIClient::new("bogus");
|
||||
client.root = mockito::server_url();
|
||||
client
|
||||
}
|
||||
@ -439,7 +499,7 @@ mod integration {
|
||||
let sk = std::env::var("OPENAI_SK").expect(
|
||||
"To run integration tests, you must put set the OPENAI_SK env var to your api token",
|
||||
);
|
||||
OpenAIClient::new(&sk).expect("client build failed")
|
||||
OpenAIClient::new(&sk)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@ -468,7 +528,7 @@ mod integration {
|
||||
#[tokio::test]
|
||||
async fn complete_string() -> crate::Result<()> {
|
||||
let client = get_client();
|
||||
client.complete(api::Engine::Ada, "Hey there").await?;
|
||||
client.complete("Hey there").await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -492,7 +552,7 @@ mod integration {
|
||||
})
|
||||
.build()
|
||||
.expect("Build should have succeeded");
|
||||
client.complete(api::Engine::Ada, args).await?;
|
||||
client.complete(args).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user