- stop reason changed back to `length` from `max_tokens`
- Fixed the chatloop example
- Added `Clone` and `Debug` to all public types (if possible, `Error` couldn't because ureq & surf don't implement Clone for their `Error` types)
- Fix README to match the current library and not use `.unwrap`
This commit is contained in:
Josh Kuhn 2020-12-21 22:32:16 -08:00 committed by GitHub
parent 3c3553d100
commit 6946acfd6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 19 deletions

View File

@ -16,18 +16,17 @@ $ cargo add openai-api
# Quickstart # Quickstart
```rust ```rust
use openai_api::OpenAIClient;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> Result<(), Box<dyn std::error::Error>> {
let api_token = std::env::var("OPENAI_SK").unwrap(); let api_token = std::env::var("OPENAI_SK")?;
let client = OpenAIClient::new(&api_token); let client = openai_api::Client::new(&api_token);
let prompt = String::from("Once upon a time,"); let prompt = String::from("Once upon a time,");
println!( println!(
"{}{}", "{}{}",
prompt, prompt,
client.complete(prompt.as_str()).await.unwrap() client.complete(prompt.as_str()).await?
); );
Ok(())
} }
``` ```
# Basic Usage # Basic Usage
@ -37,7 +36,7 @@ async fn main() {
For simple demos and debugging, you can do a completion and use the `Display` instance of a `Completion` object to convert it to a string: For simple demos and debugging, you can do a completion and use the `Display` instance of a `Completion` object to convert it to a string:
```rust ```rust
let response = client.complete("Once upon a time").await?; let response = client.complete_prompt("Once upon a time").await?;
println!("{}", response); println!("{}", response);
``` ```
@ -51,7 +50,7 @@ let args = openai_api::api::CompletionArgs::builder()
.temperature(0.7) .temperature(0.7)
.top_p(0.9) .top_p(0.9)
.stop(vec!["\n".into()]); .stop(vec!["\n".into()]);
let response = args.complete(&client).await?; let completion = client.complete_prompt(args).await?;
println!("Response: {}", response.choices[0].text); println!("Response: {}", response.choices[0].text);
println!("Model used: {}", response.model); println!("Model used: {}", response.model);
``` ```

View File

@ -9,8 +9,8 @@ The assistant is helpful, creative, clever, and very friendly.
Human: Hello, who are you? Human: Hello, who are you?
AI: I am an AI. How can I help you today?"; AI: I am an AI. How can I help you today?";
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> Result<(), Box<dyn std::error::Error>> {
let api_token = std::env::var("OPENAI_SK").unwrap(); let api_token = std::env::var("OPENAI_SK")?;
let client = Client::new(&api_token); let client = Client::new(&api_token);
let mut context = String::from(START_PROMPT); let mut context = String::from(START_PROMPT);
let mut args = CompletionArgs::builder(); let mut args = CompletionArgs::builder();
@ -28,7 +28,7 @@ async fn main() {
break; break;
} }
context.push_str("\nAI:"); context.push_str("\nAI:");
match args.prompt(context.as_str()).complete(&client).await { match args.prompt(context.as_str()).complete_prompt(&client).await {
Ok(completion) => { Ok(completion) => {
println!("\x1b[1;36m{}\x1b[1;0m", completion); println!("\x1b[1;36m{}\x1b[1;0m", completion);
context.push_str(&completion.choices[0].text); context.push_str(&completion.choices[0].text);
@ -40,4 +40,5 @@ async fn main() {
} }
} }
println!("Full conversation:\n{}", context); println!("Full conversation:\n{}", context);
Ok(())
} }

View File

@ -21,7 +21,7 @@ pub mod api {
} }
/// Engine description type /// Engine description type
#[derive(Deserialize, Debug, Eq, PartialEq)] #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct EngineInfo { pub struct EngineInfo {
pub id: Engine, pub id: Engine,
pub owner: String, pub owner: String,
@ -145,7 +145,7 @@ pub mod api {
} }
/// Represents a non-streamed completion response /// Represents a non-streamed completion response
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug, Clone)]
pub struct Completion { pub struct Completion {
/// Completion unique identifier /// Completion unique identifier
pub id: String, pub id: String,
@ -164,7 +164,7 @@ pub mod api {
} }
/// A single completion result /// A single completion result
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug, Clone)]
pub struct Choice { pub struct Choice {
/// The text of the completion. Will contain the prompt if echo is True. /// The text of the completion. Will contain the prompt if echo is True.
pub text: String, pub text: String,
@ -183,7 +183,7 @@ pub mod api {
} }
/// Represents a logprobs subdocument /// Represents a logprobs subdocument
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug, Clone)]
pub struct LogProbs { pub struct LogProbs {
pub tokens: Vec<String>, pub tokens: Vec<String>,
pub token_logprobs: Vec<Option<f64>>, pub token_logprobs: Vec<Option<f64>>,
@ -192,11 +192,11 @@ pub mod api {
} }
/// Reason a prompt completion finished. /// Reason a prompt completion finished.
#[derive(Deserialize, Debug, Eq, PartialEq)] #[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
#[non_exhaustive] #[non_exhaustive]
pub enum FinishReason { pub enum FinishReason {
/// The maximum length was reached /// The maximum length was reached
#[serde(rename = "max_tokens")] #[serde(rename = "length")]
MaxTokensReached, MaxTokensReached,
/// The stop token was encountered /// The stop token was encountered
#[serde(rename = "stop")] #[serde(rename = "stop")]
@ -204,7 +204,7 @@ pub mod api {
} }
/// Error response object from the server /// Error response object from the server
#[derive(Deserialize, Debug, Eq, PartialEq)] #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ErrorMessage { pub struct ErrorMessage {
pub message: String, pub message: String,
#[serde(rename = "type")] #[serde(rename = "type")]
@ -325,6 +325,7 @@ fn sync_client(token: &str) -> ureq::Agent {
} }
/// Client object. Must be constructed to talk to the API. /// Client object. Must be constructed to talk to the API.
#[derive(Debug, Clone)]
pub struct Client { pub struct Client {
#[cfg(feature = "async")] #[cfg(feature = "async")]
async_client: surf::Client, async_client: surf::Client,
@ -732,7 +733,7 @@ mod unit {
"text": " there was a girl who", "text": " there was a girl who",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"finish_reason": "max_tokens" "finish_reason": "length"
} }
] ]
}"#, }"#,