diff --git a/Cargo.lock b/Cargo.lock index acc4040..dfeda87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -216,6 +216,23 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd" +[[package]] +name = "futures-io" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91" + +[[package]] +name = "futures-macro" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.27" @@ -235,9 +252,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab" dependencies = [ "futures-core", + "futures-io", + "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -521,11 +543,13 @@ name = "openai_rs" version = "0.1.0" dependencies = [ "anyhow", + "base64", "derive_builder", "reqwest", "serde", "serde_json", "tokio", + "tokio-util", ] [[package]] @@ -677,10 +701,12 @@ dependencies = [ "serde_urlencoded", "tokio", "tokio-native-tls", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "winreg", ] @@ -1087,6 +1113,19 @@ version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" +[[package]] +name = "wasm-streams" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.61" diff --git a/Cargo.toml b/Cargo.toml index 0917de6..d83279c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,10 @@ edition = "2021" [dependencies] anyhow = "1.0.69" +base64 = "0.21.0" derive_builder = "0.12.0" -reqwest = { version = "0.11.14", features = [ "json", "multipart" ] } +reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] } serde = { version = "1.0.156", features = ["derive"] } serde_json = "1.0.94" tokio = { version = "1.26.0", features = [ "full" ] } +tokio-util = { version = "0.7.7", features = [ "codec" ] } diff --git a/clown.png b/clown.png new file mode 100755 index 0000000..30b5168 Binary files /dev/null and b/clown.png differ diff --git a/src/context.rs b/src/context.rs index fb4468f..9f4d4fd 100644 --- a/src/context.rs +++ b/src/context.rs @@ -7,7 +7,7 @@ pub struct Context { org_id: Option } -const API_URL: &str = "https://api.openai.com"; +pub(crate) const API_URL: &str = "https://api.openai.com"; impl Context { pub fn new(api_key: String) -> Self { @@ -24,7 +24,7 @@ impl Context { } } - fn with_auth(&self, builder: RequestBuilder) -> RequestBuilder { + pub(crate) fn with_auth(&self, builder: RequestBuilder) -> RequestBuilder { ( if let Some(ref org_id) = self.org_id { builder.header("OpenAI-Organization", org_id) diff --git a/src/image.rs b/src/image.rs index 7850c3f..de1b9de 100644 --- a/src/image.rs +++ b/src/image.rs @@ -7,14 +7,20 @@ pub enum ResponseFormat { Base64, } +impl ToString for ResponseFormat { + fn to_string(&self) -> String { + match self { + Self::URL => "url".to_string(), + Self::Base64 => "b64_json".to_string(), + } + } +} + impl Serialize for ResponseFormat { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { - match self { - Self::URL => serializer.serialize_str("url"), - Self::Base64 => serializer.serialize_str("b64_json"), - } + serializer.serialize_str(&self.to_string()) } } @@ -53,6 +59,9 @@ pub struct ImageRequest { #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub user: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(setter(into, strip_option), default)] + pub temperature: Option, } #[derive(Debug, Deserialize)] diff --git a/src/image_edit.rs b/src/image_edit.rs new file mode 100644 index 0000000..ab59a22 --- /dev/null +++ b/src/image_edit.rs @@ -0,0 +1,68 @@ +use base64::{Engine, prelude::BASE64_STANDARD}; +use derive_builder::Builder; +use reqwest::{Body, multipart::{Part, Form}, Client}; +use crate::{image::{ResponseFormat, ImageResponse}, context::{API_URL, Context}}; +use tokio_util::codec::{BytesCodec, FramedRead}; + +#[derive(Debug)] +pub enum ImageFile { + File(tokio::fs::File), + Data(Vec), +} + +#[derive(Debug, Builder)] +#[builder(pattern = "owned")] +pub struct ImageEditRequest { + #[builder(setter(into))] + pub image: ImageFile, + #[builder(setter(into, strip_option), default)] + pub mask: Option, + #[builder(setter(into))] + pub prompt: String, + #[builder(setter(into, strip_option), default)] + pub n: Option, + #[builder(setter(into, strip_option), default)] + pub response_format: Option, + #[builder(setter(into, strip_option), default)] + pub user: Option, + #[builder(setter(into, strip_option), default)] + pub temperature: Option, +} + +fn write_file(form: Form, file: ImageFile, name: impl Into) -> Form { + let name = name.into(); + match file { + ImageFile::File(file) => + form.part(name.clone(), Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new()))).file_name(name)), + ImageFile::Data(data) => + form.text(name, BASE64_STANDARD.encode(data.as_slice())), + } +} + +impl Context { + pub async fn create_image_edit(&self, req: ImageEditRequest) -> anyhow::Result { + let mut form = Form::new(); + form = form.text("prompt", req.prompt); + form = write_file(form, req.image, "image"); + + if let Some(n) = req.n { + form = form.text("n", n.to_string()); + } + if let Some(response_format) = req.response_format { + form = form.text("response_format", response_format.to_string()); + } + if let Some(user) = req.user { + form = form.text("user", user); + } + + if let Some(mask) = req.mask { + form = write_file(form, mask, "mask"); + } + + if let Some(temperature) = req.temperature { + form = form.text("temperature", temperature.to_string()); + } + + Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/edits")).multipart(form)).send().await?.json::().await?) + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index ac92544..db026a0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,13 +4,18 @@ pub mod completion; pub mod chat; pub mod edits; pub mod image; +pub mod image_edit; #[cfg(test)] mod tests { - use crate::chat::ChatMessage; + use tokio::fs::File; + + use crate::chat::{ChatHistoryBuilder, ChatMessage, Role}; use crate::context::Context; use crate::completion::CompletionRequestBuilder; - use crate::image::{Image, ResponseFormat}; + use crate::image::{Image, ResponseFormat, ImageRequestBuilder}; + use crate::edits::EditRequestBuilder; + use crate::image_edit::{ImageEditRequestBuilder, ImageFile}; fn get_api() -> anyhow::Result { Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string())) @@ -34,10 +39,10 @@ mod tests { let completion = ctx.unwrap().create_completion( CompletionRequestBuilder::default() - .model("text-davinci-003") - .prompt("Say 'this is a test'") - .build() - .unwrap() + .model("text-davinci-003") + .prompt("Say 'this is a test'") + .build() + .unwrap() ).await; assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err()); @@ -50,11 +55,11 @@ mod tests { assert!(ctx.is_ok(), "Could not load context"); let completion = ctx.unwrap().create_chat_completion( - crate::chat::ChatHistoryBuilder::default() - .messages(vec![ChatMessage::new(crate::chat::Role::User, "Respond to this message with 'this is a test'")]) - .model("gpt-3.5-turbo") - .build() - .unwrap() + ChatHistoryBuilder::default() + .messages(vec![ChatMessage::new(Role::User, "Respond to this message with 'this is a test'")]) + .model("gpt-3.5-turbo") + .build() + .unwrap() ).await; assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err()); @@ -69,12 +74,12 @@ mod tests { // Autocorrect English spelling errors let edit = ctx.create_edit( - crate::edits::EditRequestBuilder::default() - .model("text-davinci-edit-001") - .instruction("Correct all spelling mistakes") - .input("What a wnoderful day!") - .build() - .unwrap() + EditRequestBuilder::default() + .model("text-davinci-edit-001") + .instruction("Correct all spelling mistakes") + .input("What a wnoderful day!") + .build() + .unwrap() ).await; assert!(edit.is_ok(), "Could not get edit: {}", edit.unwrap_err()); @@ -84,17 +89,17 @@ mod tests { #[tokio::test] async fn test_image() { - const IMAGE_PROMPT: &str = "In a realistic style, a ginger cat gracefully walking along a thin brick wall"; + const IMAGE_PROMPT: &str = "A real ginger cat gracefully walking along a real, thin brick wall"; let ctx = get_api(); assert!(ctx.is_ok(), "Could not load context"); let ctx = ctx.unwrap(); let image = ctx.create_image( - crate::image::ImageRequestBuilder::default() - .prompt(IMAGE_PROMPT) - .response_format(ResponseFormat::URL) - .build() - .unwrap() + ImageRequestBuilder::default() + .prompt(IMAGE_PROMPT) + .response_format(ResponseFormat::URL) + .build() + .unwrap() ).await; assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err()); @@ -110,4 +115,33 @@ mod tests { } } } + + #[tokio::test] + async fn test_image_edit() { + const IMAGE_PROMPT: &str = "A real ginger cat gracefully walking along a real, thin brick wall"; + let ctx = get_api(); + assert!(ctx.is_ok(), "Could not load context"); + let ctx = ctx.unwrap(); + + let image = ctx.create_image_edit( + ImageEditRequestBuilder::default() + .image(ImageFile::File(File::open("clown.png").await.unwrap())) + .prompt("Blue nose") + .build() + .unwrap() + ).await; + + assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err()); + assert!(image.as_ref().unwrap().data.len() == 1, "No image found"); + assert!(matches!(image.as_ref().unwrap().data[0], Image::URL(_)), "No image found"); + println!("Image prompt: {IMAGE_PROMPT}"); + match image.unwrap().data[0] { + Image::URL(ref url) => { + println!("Generated edited image URL: {url}"); + } + Image::Base64(ref b64) => { + println!("Generated edited image Base64: {b64}"); + } + } + } } \ No newline at end of file