diff --git a/README.md b/README.md index efe2bfa..f599e7c 100644 --- a/README.md +++ b/README.md @@ -16,18 +16,17 @@ $ cargo add openai-api # Quickstart ```rust -use openai_api::OpenAIClient; - #[tokio::main] -async fn main() { - let api_token = std::env::var("OPENAI_SK").unwrap(); - let client = OpenAIClient::new(&api_token); +async fn main() -> Result<(), Box> { + let api_token = std::env::var("OPENAI_SK")?; + let client = openai_api::Client::new(&api_token); let prompt = String::from("Once upon a time,"); println!( "{}{}", prompt, - client.complete(prompt.as_str()).await.unwrap() + client.complete(prompt.as_str()).await? ); + Ok(()) } ``` # Basic Usage @@ -37,7 +36,7 @@ async fn main() { For simple demos and debugging, you can do a completion and use the `Display` instance of a `Completion` object to convert it to a string: ```rust -let response = client.complete("Once upon a time").await?; +let response = client.complete_prompt("Once upon a time").await?; println!("{}", response); ``` @@ -51,7 +50,7 @@ let args = openai_api::api::CompletionArgs::builder() .temperature(0.7) .top_p(0.9) .stop(vec!["\n".into()]); -let response = args.complete(&client).await?; +let completion = client.complete_prompt(args).await?; println!("Response: {}", response.choices[0].text); println!("Model used: {}", response.model); ``` diff --git a/examples/chatloop.rs b/examples/chatloop.rs index 94a7752..26cec0a 100644 --- a/examples/chatloop.rs +++ b/examples/chatloop.rs @@ -9,8 +9,8 @@ The assistant is helpful, creative, clever, and very friendly. Human: Hello, who are you? AI: I am an AI. How can I help you today?"; #[tokio::main] -async fn main() { - let api_token = std::env::var("OPENAI_SK").unwrap(); +async fn main() -> Result<(), Box> { + let api_token = std::env::var("OPENAI_SK")?; let client = Client::new(&api_token); let mut context = String::from(START_PROMPT); let mut args = CompletionArgs::builder(); @@ -28,7 +28,7 @@ async fn main() { break; } context.push_str("\nAI:"); - match args.prompt(context.as_str()).complete(&client).await { + match args.prompt(context.as_str()).complete_prompt(&client).await { Ok(completion) => { println!("\x1b[1;36m{}\x1b[1;0m", completion); context.push_str(&completion.choices[0].text); @@ -40,4 +40,5 @@ async fn main() { } } println!("Full conversation:\n{}", context); + Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index b4d16bb..4840b64 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,7 @@ pub mod api { } /// Engine description type - #[derive(Deserialize, Debug, Eq, PartialEq)] + #[derive(Deserialize, Debug, Eq, PartialEq, Clone)] pub struct EngineInfo { pub id: Engine, pub owner: String, @@ -145,7 +145,7 @@ pub mod api { } /// Represents a non-streamed completion response - #[derive(Deserialize, Debug)] + #[derive(Deserialize, Debug, Clone)] pub struct Completion { /// Completion unique identifier pub id: String, @@ -164,7 +164,7 @@ pub mod api { } /// A single completion result - #[derive(Deserialize, Debug)] + #[derive(Deserialize, Debug, Clone)] pub struct Choice { /// The text of the completion. Will contain the prompt if echo is True. pub text: String, @@ -183,7 +183,7 @@ pub mod api { } /// Represents a logprobs subdocument - #[derive(Deserialize, Debug)] + #[derive(Deserialize, Debug, Clone)] pub struct LogProbs { pub tokens: Vec, pub token_logprobs: Vec>, @@ -192,11 +192,11 @@ pub mod api { } /// Reason a prompt completion finished. - #[derive(Deserialize, Debug, Eq, PartialEq)] + #[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)] #[non_exhaustive] pub enum FinishReason { /// The maximum length was reached - #[serde(rename = "max_tokens")] + #[serde(rename = "length")] MaxTokensReached, /// The stop token was encountered #[serde(rename = "stop")] @@ -204,7 +204,7 @@ pub mod api { } /// Error response object from the server - #[derive(Deserialize, Debug, Eq, PartialEq)] + #[derive(Deserialize, Debug, Eq, PartialEq, Clone)] pub struct ErrorMessage { pub message: String, #[serde(rename = "type")] @@ -325,6 +325,7 @@ fn sync_client(token: &str) -> ureq::Agent { } /// Client object. Must be constructed to talk to the API. +#[derive(Debug, Clone)] pub struct Client { #[cfg(feature = "async")] async_client: surf::Client, @@ -732,7 +733,7 @@ mod unit { "text": " there was a girl who", "index": 0, "logprobs": null, - "finish_reason": "max_tokens" + "finish_reason": "length" } ] }"#,