- stop reason changed back to `length` from `max_tokens`
- Fixed the chatloop example
- Added `Clone` and `Debug` to all public types (if possible, `Error` couldn't because ureq & surf don't implement Clone for their `Error` types)
- Fix README to match the current library and not use `.unwrap`
This commit is contained in:
Josh Kuhn 2020-12-21 22:32:16 -08:00 committed by GitHub
parent 3c3553d100
commit 6946acfd6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 19 deletions

View File

@ -16,18 +16,17 @@ $ cargo add openai-api
# Quickstart
```rust
use openai_api::OpenAIClient;
#[tokio::main]
async fn main() {
let api_token = std::env::var("OPENAI_SK").unwrap();
let client = OpenAIClient::new(&api_token);
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let api_token = std::env::var("OPENAI_SK")?;
let client = openai_api::Client::new(&api_token);
let prompt = String::from("Once upon a time,");
println!(
"{}{}",
prompt,
client.complete(prompt.as_str()).await.unwrap()
client.complete(prompt.as_str()).await?
);
Ok(())
}
```
# Basic Usage
@ -37,7 +36,7 @@ async fn main() {
For simple demos and debugging, you can do a completion and use the `Display` instance of a `Completion` object to convert it to a string:
```rust
let response = client.complete("Once upon a time").await?;
let response = client.complete_prompt("Once upon a time").await?;
println!("{}", response);
```
@ -51,7 +50,7 @@ let args = openai_api::api::CompletionArgs::builder()
.temperature(0.7)
.top_p(0.9)
.stop(vec!["\n".into()]);
let response = args.complete(&client).await?;
let completion = client.complete_prompt(args).await?;
println!("Response: {}", response.choices[0].text);
println!("Model used: {}", response.model);
```

View File

@ -9,8 +9,8 @@ The assistant is helpful, creative, clever, and very friendly.
Human: Hello, who are you?
AI: I am an AI. How can I help you today?";
#[tokio::main]
async fn main() {
let api_token = std::env::var("OPENAI_SK").unwrap();
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let api_token = std::env::var("OPENAI_SK")?;
let client = Client::new(&api_token);
let mut context = String::from(START_PROMPT);
let mut args = CompletionArgs::builder();
@ -28,7 +28,7 @@ async fn main() {
break;
}
context.push_str("\nAI:");
match args.prompt(context.as_str()).complete(&client).await {
match args.prompt(context.as_str()).complete_prompt(&client).await {
Ok(completion) => {
println!("\x1b[1;36m{}\x1b[1;0m", completion);
context.push_str(&completion.choices[0].text);
@ -40,4 +40,5 @@ async fn main() {
}
}
println!("Full conversation:\n{}", context);
Ok(())
}

View File

@ -21,7 +21,7 @@ pub mod api {
}
/// Engine description type
#[derive(Deserialize, Debug, Eq, PartialEq)]
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct EngineInfo {
pub id: Engine,
pub owner: String,
@ -145,7 +145,7 @@ pub mod api {
}
/// Represents a non-streamed completion response
#[derive(Deserialize, Debug)]
#[derive(Deserialize, Debug, Clone)]
pub struct Completion {
/// Completion unique identifier
pub id: String,
@ -164,7 +164,7 @@ pub mod api {
}
/// A single completion result
#[derive(Deserialize, Debug)]
#[derive(Deserialize, Debug, Clone)]
pub struct Choice {
/// The text of the completion. Will contain the prompt if echo is True.
pub text: String,
@ -183,7 +183,7 @@ pub mod api {
}
/// Represents a logprobs subdocument
#[derive(Deserialize, Debug)]
#[derive(Deserialize, Debug, Clone)]
pub struct LogProbs {
pub tokens: Vec<String>,
pub token_logprobs: Vec<Option<f64>>,
@ -192,11 +192,11 @@ pub mod api {
}
/// Reason a prompt completion finished.
#[derive(Deserialize, Debug, Eq, PartialEq)]
#[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
#[non_exhaustive]
pub enum FinishReason {
/// The maximum length was reached
#[serde(rename = "max_tokens")]
#[serde(rename = "length")]
MaxTokensReached,
/// The stop token was encountered
#[serde(rename = "stop")]
@ -204,7 +204,7 @@ pub mod api {
}
/// Error response object from the server
#[derive(Deserialize, Debug, Eq, PartialEq)]
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ErrorMessage {
pub message: String,
#[serde(rename = "type")]
@ -325,6 +325,7 @@ fn sync_client(token: &str) -> ureq::Agent {
}
/// Client object. Must be constructed to talk to the API.
#[derive(Debug, Clone)]
pub struct Client {
#[cfg(feature = "async")]
async_client: surf::Client,
@ -732,7 +733,7 @@ mod unit {
"text": " there was a girl who",
"index": 0,
"logprobs": null,
"finish_reason": "max_tokens"
"finish_reason": "length"
}
]
}"#,