Tweaks (#9)
- stop reason changed back to `length` from `max_tokens` - Fixed the chatloop example - Added `Clone` and `Debug` to all public types (if possible, `Error` couldn't because ureq & surf don't implement Clone for their `Error` types) - Fix README to match the current library and not use `.unwrap`
This commit is contained in:
parent
3c3553d100
commit
6946acfd6c
15
README.md
15
README.md
@ -16,18 +16,17 @@ $ cargo add openai-api
|
|||||||
# Quickstart
|
# Quickstart
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
use openai_api::OpenAIClient;
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let api_token = std::env::var("OPENAI_SK").unwrap();
|
let api_token = std::env::var("OPENAI_SK")?;
|
||||||
let client = OpenAIClient::new(&api_token);
|
let client = openai_api::Client::new(&api_token);
|
||||||
let prompt = String::from("Once upon a time,");
|
let prompt = String::from("Once upon a time,");
|
||||||
println!(
|
println!(
|
||||||
"{}{}",
|
"{}{}",
|
||||||
prompt,
|
prompt,
|
||||||
client.complete(prompt.as_str()).await.unwrap()
|
client.complete(prompt.as_str()).await?
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
# Basic Usage
|
# Basic Usage
|
||||||
@ -37,7 +36,7 @@ async fn main() {
|
|||||||
For simple demos and debugging, you can do a completion and use the `Display` instance of a `Completion` object to convert it to a string:
|
For simple demos and debugging, you can do a completion and use the `Display` instance of a `Completion` object to convert it to a string:
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
let response = client.complete("Once upon a time").await?;
|
let response = client.complete_prompt("Once upon a time").await?;
|
||||||
println!("{}", response);
|
println!("{}", response);
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -51,7 +50,7 @@ let args = openai_api::api::CompletionArgs::builder()
|
|||||||
.temperature(0.7)
|
.temperature(0.7)
|
||||||
.top_p(0.9)
|
.top_p(0.9)
|
||||||
.stop(vec!["\n".into()]);
|
.stop(vec!["\n".into()]);
|
||||||
let response = args.complete(&client).await?;
|
let completion = client.complete_prompt(args).await?;
|
||||||
println!("Response: {}", response.choices[0].text);
|
println!("Response: {}", response.choices[0].text);
|
||||||
println!("Model used: {}", response.model);
|
println!("Model used: {}", response.model);
|
||||||
```
|
```
|
||||||
|
@ -9,8 +9,8 @@ The assistant is helpful, creative, clever, and very friendly.
|
|||||||
Human: Hello, who are you?
|
Human: Hello, who are you?
|
||||||
AI: I am an AI. How can I help you today?";
|
AI: I am an AI. How can I help you today?";
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let api_token = std::env::var("OPENAI_SK").unwrap();
|
let api_token = std::env::var("OPENAI_SK")?;
|
||||||
let client = Client::new(&api_token);
|
let client = Client::new(&api_token);
|
||||||
let mut context = String::from(START_PROMPT);
|
let mut context = String::from(START_PROMPT);
|
||||||
let mut args = CompletionArgs::builder();
|
let mut args = CompletionArgs::builder();
|
||||||
@ -28,7 +28,7 @@ async fn main() {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
context.push_str("\nAI:");
|
context.push_str("\nAI:");
|
||||||
match args.prompt(context.as_str()).complete(&client).await {
|
match args.prompt(context.as_str()).complete_prompt(&client).await {
|
||||||
Ok(completion) => {
|
Ok(completion) => {
|
||||||
println!("\x1b[1;36m{}\x1b[1;0m", completion);
|
println!("\x1b[1;36m{}\x1b[1;0m", completion);
|
||||||
context.push_str(&completion.choices[0].text);
|
context.push_str(&completion.choices[0].text);
|
||||||
@ -40,4 +40,5 @@ async fn main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
println!("Full conversation:\n{}", context);
|
println!("Full conversation:\n{}", context);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
17
src/lib.rs
17
src/lib.rs
@ -21,7 +21,7 @@ pub mod api {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Engine description type
|
/// Engine description type
|
||||||
#[derive(Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||||
pub struct EngineInfo {
|
pub struct EngineInfo {
|
||||||
pub id: Engine,
|
pub id: Engine,
|
||||||
pub owner: String,
|
pub owner: String,
|
||||||
@ -145,7 +145,7 @@ pub mod api {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Represents a non-streamed completion response
|
/// Represents a non-streamed completion response
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug, Clone)]
|
||||||
pub struct Completion {
|
pub struct Completion {
|
||||||
/// Completion unique identifier
|
/// Completion unique identifier
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@ -164,7 +164,7 @@ pub mod api {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A single completion result
|
/// A single completion result
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug, Clone)]
|
||||||
pub struct Choice {
|
pub struct Choice {
|
||||||
/// The text of the completion. Will contain the prompt if echo is True.
|
/// The text of the completion. Will contain the prompt if echo is True.
|
||||||
pub text: String,
|
pub text: String,
|
||||||
@ -183,7 +183,7 @@ pub mod api {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Represents a logprobs subdocument
|
/// Represents a logprobs subdocument
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug, Clone)]
|
||||||
pub struct LogProbs {
|
pub struct LogProbs {
|
||||||
pub tokens: Vec<String>,
|
pub tokens: Vec<String>,
|
||||||
pub token_logprobs: Vec<Option<f64>>,
|
pub token_logprobs: Vec<Option<f64>>,
|
||||||
@ -192,11 +192,11 @@ pub mod api {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Reason a prompt completion finished.
|
/// Reason a prompt completion finished.
|
||||||
#[derive(Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
pub enum FinishReason {
|
pub enum FinishReason {
|
||||||
/// The maximum length was reached
|
/// The maximum length was reached
|
||||||
#[serde(rename = "max_tokens")]
|
#[serde(rename = "length")]
|
||||||
MaxTokensReached,
|
MaxTokensReached,
|
||||||
/// The stop token was encountered
|
/// The stop token was encountered
|
||||||
#[serde(rename = "stop")]
|
#[serde(rename = "stop")]
|
||||||
@ -204,7 +204,7 @@ pub mod api {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Error response object from the server
|
/// Error response object from the server
|
||||||
#[derive(Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||||
pub struct ErrorMessage {
|
pub struct ErrorMessage {
|
||||||
pub message: String,
|
pub message: String,
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
@ -325,6 +325,7 @@ fn sync_client(token: &str) -> ureq::Agent {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Client object. Must be constructed to talk to the API.
|
/// Client object. Must be constructed to talk to the API.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
#[cfg(feature = "async")]
|
#[cfg(feature = "async")]
|
||||||
async_client: surf::Client,
|
async_client: surf::Client,
|
||||||
@ -732,7 +733,7 @@ mod unit {
|
|||||||
"text": " there was a girl who",
|
"text": " there was a girl who",
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null,
|
"logprobs": null,
|
||||||
"finish_reason": "max_tokens"
|
"finish_reason": "length"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}"#,
|
}"#,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user