diff --git a/clown_original.png b/clown_original.png new file mode 100644 index 0000000..a4b663c Binary files /dev/null and b/clown_original.png differ diff --git a/src/image.rs b/src/image.rs index feaf251..867a37e 100644 --- a/src/image.rs +++ b/src/image.rs @@ -34,15 +34,21 @@ pub enum ImageSize { Size1024, } +impl ToString for ImageSize { + fn to_string(&self) -> String { + match self { + Self::Size256 => "256x256".to_string(), + Self::Size512 => "512x512".to_string(), + Self::Size1024 => "1024x1024".to_string(), + } + } +} + impl Serialize for ImageSize { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { - match self { - Self::Size256 => serializer.serialize_str("256x256"), - Self::Size512 => serializer.serialize_str("512x512"), - Self::Size1024 => serializer.serialize_str("1024x1024"), - } + serializer.serialize_str(&self.to_string()) } } diff --git a/src/image_edit.rs b/src/image_edit.rs index ab59a22..1166a5b 100644 --- a/src/image_edit.rs +++ b/src/image_edit.rs @@ -1,7 +1,7 @@ use base64::{Engine, prelude::BASE64_STANDARD}; use derive_builder::Builder; use reqwest::{Body, multipart::{Part, Form}, Client}; -use crate::{image::{ResponseFormat, ImageResponse}, context::{API_URL, Context}}; +use crate::{image::{ResponseFormat, ImageResponse, ImageSize}, context::{API_URL, Context}}; use tokio_util::codec::{BytesCodec, FramedRead}; #[derive(Debug)] @@ -27,9 +27,11 @@ pub struct ImageEditRequest { pub user: Option, #[builder(setter(into, strip_option), default)] pub temperature: Option, + #[builder(setter(into, strip_option), default)] + pub size: Option, } -fn write_file(form: Form, file: ImageFile, name: impl Into) -> Form { +pub(crate) fn write_file(form: Form, file: ImageFile, name: impl Into) -> Form { let name = name.into(); match file { ImageFile::File(file) => @@ -62,6 +64,10 @@ impl Context { if let Some(temperature) = req.temperature { form = form.text("temperature", temperature.to_string()); } + + if let Some(size) = req.size { + form = form.text("size", size.to_string()); + } Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/edits")).multipart(form)).send().await?.json::().await?) } diff --git a/src/image_variation.rs b/src/image_variation.rs new file mode 100644 index 0000000..6addda8 --- /dev/null +++ b/src/image_variation.rs @@ -0,0 +1,43 @@ +use derive_builder::Builder; +use reqwest::{multipart::Form, Client}; + +use crate::{image_edit::{ImageFile, write_file}, image::{ImageSize, ResponseFormat, ImageResponse}, context::{API_URL, Context}}; + +#[derive(Debug, Builder)] +#[builder(pattern = "owned")] +pub struct ImageVariationRequest { + #[builder(setter(into))] + pub image: ImageFile, + #[builder(setter(into, strip_option), default)] + pub n: Option, + #[builder(setter(into, strip_option), default)] + pub size: Option, + #[builder(setter(into, strip_option), default)] + pub user: Option, + #[builder(setter(into, strip_option), default)] + pub response_format: Option, +} + + +impl Context { + pub async fn create_image_variation(&self, req: ImageVariationRequest) -> anyhow::Result { + let mut form = Form::new(); + form = write_file(form, req.image, "image"); + + if let Some(n) = req.n { + form = form.text("n", n.to_string()); + } + if let Some(response_format) = req.response_format { + form = form.text("response_format", response_format.to_string()); + } + if let Some(user) = req.user { + form = form.text("user", user); + } + + if let Some(size) = req.size { + form = form.text("size", size.to_string()); + } + + Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/variations")).multipart(form)).send().await?.json::().await?) + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index db026a0..2d53cb0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub mod chat; pub mod edits; pub mod image; pub mod image_edit; +pub mod image_variation; #[cfg(test)] mod tests { @@ -16,6 +17,7 @@ mod tests { use crate::image::{Image, ResponseFormat, ImageRequestBuilder}; use crate::edits::EditRequestBuilder; use crate::image_edit::{ImageEditRequestBuilder, ImageFile}; + use crate::image_variation::ImageVariationRequestBuilder; fn get_api() -> anyhow::Result { Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string())) @@ -144,4 +146,30 @@ mod tests { } } } + + #[tokio::test] + async fn test_image_variation() { + let ctx = get_api(); + assert!(ctx.is_ok(), "Could not load context"); + let ctx = ctx.unwrap(); + + let image = ctx.create_image_variation( + ImageVariationRequestBuilder::default() + .image(ImageFile::File(File::open("clown_original.png").await.unwrap())) + .build() + .unwrap() + ).await; + + assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err()); + assert!(image.as_ref().unwrap().data.len() == 1, "No image found"); + assert!(matches!(image.as_ref().unwrap().data[0], Image::URL(_)), "No image found"); + match image.unwrap().data[0] { + Image::URL(ref url) => { + println!("Generated image variation URL: {url}"); + } + Image::Base64(ref b64) => { + println!("Generated image variation Base64: {b64}"); + } + } + } } \ No newline at end of file