Implement image variations

This commit is contained in:
Gabriel Tofvesson 2023-03-18 03:04:49 +01:00
parent 9750d5279d
commit 9d8e858c1e
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
5 changed files with 90 additions and 7 deletions

BIN
clown_original.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 MiB

View File

@ -34,15 +34,21 @@ pub enum ImageSize {
Size1024, Size1024,
} }
impl ToString for ImageSize {
fn to_string(&self) -> String {
match self {
Self::Size256 => "256x256".to_string(),
Self::Size512 => "512x512".to_string(),
Self::Size1024 => "1024x1024".to_string(),
}
}
}
impl Serialize for ImageSize { impl Serialize for ImageSize {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
S: serde::Serializer { S: serde::Serializer {
match self { serializer.serialize_str(&self.to_string())
Self::Size256 => serializer.serialize_str("256x256"),
Self::Size512 => serializer.serialize_str("512x512"),
Self::Size1024 => serializer.serialize_str("1024x1024"),
}
} }
} }

View File

@ -1,7 +1,7 @@
use base64::{Engine, prelude::BASE64_STANDARD}; use base64::{Engine, prelude::BASE64_STANDARD};
use derive_builder::Builder; use derive_builder::Builder;
use reqwest::{Body, multipart::{Part, Form}, Client}; use reqwest::{Body, multipart::{Part, Form}, Client};
use crate::{image::{ResponseFormat, ImageResponse}, context::{API_URL, Context}}; use crate::{image::{ResponseFormat, ImageResponse, ImageSize}, context::{API_URL, Context}};
use tokio_util::codec::{BytesCodec, FramedRead}; use tokio_util::codec::{BytesCodec, FramedRead};
#[derive(Debug)] #[derive(Debug)]
@ -27,9 +27,11 @@ pub struct ImageEditRequest {
pub user: Option<String>, pub user: Option<String>,
#[builder(setter(into, strip_option), default)] #[builder(setter(into, strip_option), default)]
pub temperature: Option<f64>, pub temperature: Option<f64>,
#[builder(setter(into, strip_option), default)]
pub size: Option<ImageSize>,
} }
fn write_file(form: Form, file: ImageFile, name: impl Into<String>) -> Form { pub(crate) fn write_file(form: Form, file: ImageFile, name: impl Into<String>) -> Form {
let name = name.into(); let name = name.into();
match file { match file {
ImageFile::File(file) => ImageFile::File(file) =>
@ -62,6 +64,10 @@ impl Context {
if let Some(temperature) = req.temperature { if let Some(temperature) = req.temperature {
form = form.text("temperature", temperature.to_string()); form = form.text("temperature", temperature.to_string());
} }
if let Some(size) = req.size {
form = form.text("size", size.to_string());
}
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/edits")).multipart(form)).send().await?.json::<ImageResponse>().await?) Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/edits")).multipart(form)).send().await?.json::<ImageResponse>().await?)
} }

43
src/image_variation.rs Normal file
View File

@ -0,0 +1,43 @@
use derive_builder::Builder;
use reqwest::{multipart::Form, Client};
use crate::{image_edit::{ImageFile, write_file}, image::{ImageSize, ResponseFormat, ImageResponse}, context::{API_URL, Context}};
#[derive(Debug, Builder)]
#[builder(pattern = "owned")]
pub struct ImageVariationRequest {
#[builder(setter(into))]
pub image: ImageFile,
#[builder(setter(into, strip_option), default)]
pub n: Option<u32>,
#[builder(setter(into, strip_option), default)]
pub size: Option<ImageSize>,
#[builder(setter(into, strip_option), default)]
pub user: Option<String>,
#[builder(setter(into, strip_option), default)]
pub response_format: Option<ResponseFormat>,
}
impl Context {
pub async fn create_image_variation(&self, req: ImageVariationRequest) -> anyhow::Result<ImageResponse> {
let mut form = Form::new();
form = write_file(form, req.image, "image");
if let Some(n) = req.n {
form = form.text("n", n.to_string());
}
if let Some(response_format) = req.response_format {
form = form.text("response_format", response_format.to_string());
}
if let Some(user) = req.user {
form = form.text("user", user);
}
if let Some(size) = req.size {
form = form.text("size", size.to_string());
}
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/variations")).multipart(form)).send().await?.json::<ImageResponse>().await?)
}
}

View File

@ -5,6 +5,7 @@ pub mod chat;
pub mod edits; pub mod edits;
pub mod image; pub mod image;
pub mod image_edit; pub mod image_edit;
pub mod image_variation;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -16,6 +17,7 @@ mod tests {
use crate::image::{Image, ResponseFormat, ImageRequestBuilder}; use crate::image::{Image, ResponseFormat, ImageRequestBuilder};
use crate::edits::EditRequestBuilder; use crate::edits::EditRequestBuilder;
use crate::image_edit::{ImageEditRequestBuilder, ImageFile}; use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
use crate::image_variation::ImageVariationRequestBuilder;
fn get_api() -> anyhow::Result<Context> { fn get_api() -> anyhow::Result<Context> {
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string())) Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
@ -144,4 +146,30 @@ mod tests {
} }
} }
} }
#[tokio::test]
async fn test_image_variation() {
let ctx = get_api();
assert!(ctx.is_ok(), "Could not load context");
let ctx = ctx.unwrap();
let image = ctx.create_image_variation(
ImageVariationRequestBuilder::default()
.image(ImageFile::File(File::open("clown_original.png").await.unwrap()))
.build()
.unwrap()
).await;
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
assert!(image.as_ref().unwrap().data.len() == 1, "No image found");
assert!(matches!(image.as_ref().unwrap().data[0], Image::URL(_)), "No image found");
match image.unwrap().data[0] {
Image::URL(ref url) => {
println!("Generated image variation URL: {url}");
}
Image::Base64(ref b64) => {
println!("Generated image variation Base64: {b64}");
}
}
}
} }