Tweaks (#9)
- stop reason changed back to `length` from `max_tokens` - Fixed the chatloop example - Added `Clone` and `Debug` to all public types (if possible, `Error` couldn't because ureq & surf don't implement Clone for their `Error` types) - Fix README to match the current library and not use `.unwrap`
This commit is contained in:
parent
3c3553d100
commit
6946acfd6c
15
README.md
15
README.md
@ -16,18 +16,17 @@ $ cargo add openai-api
|
||||
# Quickstart
|
||||
|
||||
```rust
|
||||
use openai_api::OpenAIClient;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let api_token = std::env::var("OPENAI_SK").unwrap();
|
||||
let client = OpenAIClient::new(&api_token);
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let api_token = std::env::var("OPENAI_SK")?;
|
||||
let client = openai_api::Client::new(&api_token);
|
||||
let prompt = String::from("Once upon a time,");
|
||||
println!(
|
||||
"{}{}",
|
||||
prompt,
|
||||
client.complete(prompt.as_str()).await.unwrap()
|
||||
client.complete(prompt.as_str()).await?
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
# Basic Usage
|
||||
@ -37,7 +36,7 @@ async fn main() {
|
||||
For simple demos and debugging, you can do a completion and use the `Display` instance of a `Completion` object to convert it to a string:
|
||||
|
||||
```rust
|
||||
let response = client.complete("Once upon a time").await?;
|
||||
let response = client.complete_prompt("Once upon a time").await?;
|
||||
println!("{}", response);
|
||||
```
|
||||
|
||||
@ -51,7 +50,7 @@ let args = openai_api::api::CompletionArgs::builder()
|
||||
.temperature(0.7)
|
||||
.top_p(0.9)
|
||||
.stop(vec!["\n".into()]);
|
||||
let response = args.complete(&client).await?;
|
||||
let completion = client.complete_prompt(args).await?;
|
||||
println!("Response: {}", response.choices[0].text);
|
||||
println!("Model used: {}", response.model);
|
||||
```
|
||||
|
@ -9,8 +9,8 @@ The assistant is helpful, creative, clever, and very friendly.
|
||||
Human: Hello, who are you?
|
||||
AI: I am an AI. How can I help you today?";
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let api_token = std::env::var("OPENAI_SK").unwrap();
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let api_token = std::env::var("OPENAI_SK")?;
|
||||
let client = Client::new(&api_token);
|
||||
let mut context = String::from(START_PROMPT);
|
||||
let mut args = CompletionArgs::builder();
|
||||
@ -28,7 +28,7 @@ async fn main() {
|
||||
break;
|
||||
}
|
||||
context.push_str("\nAI:");
|
||||
match args.prompt(context.as_str()).complete(&client).await {
|
||||
match args.prompt(context.as_str()).complete_prompt(&client).await {
|
||||
Ok(completion) => {
|
||||
println!("\x1b[1;36m{}\x1b[1;0m", completion);
|
||||
context.push_str(&completion.choices[0].text);
|
||||
@ -40,4 +40,5 @@ async fn main() {
|
||||
}
|
||||
}
|
||||
println!("Full conversation:\n{}", context);
|
||||
Ok(())
|
||||
}
|
||||
|
17
src/lib.rs
17
src/lib.rs
@ -21,7 +21,7 @@ pub mod api {
|
||||
}
|
||||
|
||||
/// Engine description type
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct EngineInfo {
|
||||
pub id: Engine,
|
||||
pub owner: String,
|
||||
@ -145,7 +145,7 @@ pub mod api {
|
||||
}
|
||||
|
||||
/// Represents a non-streamed completion response
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct Completion {
|
||||
/// Completion unique identifier
|
||||
pub id: String,
|
||||
@ -164,7 +164,7 @@ pub mod api {
|
||||
}
|
||||
|
||||
/// A single completion result
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct Choice {
|
||||
/// The text of the completion. Will contain the prompt if echo is True.
|
||||
pub text: String,
|
||||
@ -183,7 +183,7 @@ pub mod api {
|
||||
}
|
||||
|
||||
/// Represents a logprobs subdocument
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct LogProbs {
|
||||
pub tokens: Vec<String>,
|
||||
pub token_logprobs: Vec<Option<f64>>,
|
||||
@ -192,11 +192,11 @@ pub mod api {
|
||||
}
|
||||
|
||||
/// Reason a prompt completion finished.
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
|
||||
#[non_exhaustive]
|
||||
pub enum FinishReason {
|
||||
/// The maximum length was reached
|
||||
#[serde(rename = "max_tokens")]
|
||||
#[serde(rename = "length")]
|
||||
MaxTokensReached,
|
||||
/// The stop token was encountered
|
||||
#[serde(rename = "stop")]
|
||||
@ -204,7 +204,7 @@ pub mod api {
|
||||
}
|
||||
|
||||
/// Error response object from the server
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct ErrorMessage {
|
||||
pub message: String,
|
||||
#[serde(rename = "type")]
|
||||
@ -325,6 +325,7 @@ fn sync_client(token: &str) -> ureq::Agent {
|
||||
}
|
||||
|
||||
/// Client object. Must be constructed to talk to the API.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Client {
|
||||
#[cfg(feature = "async")]
|
||||
async_client: surf::Client,
|
||||
@ -732,7 +733,7 @@ mod unit {
|
||||
"text": " there was a girl who",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"finish_reason": "max_tokens"
|
||||
"finish_reason": "length"
|
||||
}
|
||||
]
|
||||
}"#,
|
||||
|
Loading…
x
Reference in New Issue
Block a user