Implement image variations
This commit is contained in:
parent
9750d5279d
commit
9d8e858c1e
BIN
clown_original.png
Normal file
BIN
clown_original.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.0 MiB |
16
src/image.rs
16
src/image.rs
@ -34,15 +34,21 @@ pub enum ImageSize {
|
||||
Size1024,
|
||||
}
|
||||
|
||||
impl ToString for ImageSize {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Self::Size256 => "256x256".to_string(),
|
||||
Self::Size512 => "512x512".to_string(),
|
||||
Self::Size1024 => "1024x1024".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for ImageSize {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer {
|
||||
match self {
|
||||
Self::Size256 => serializer.serialize_str("256x256"),
|
||||
Self::Size512 => serializer.serialize_str("512x512"),
|
||||
Self::Size1024 => serializer.serialize_str("1024x1024"),
|
||||
}
|
||||
serializer.serialize_str(&self.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
use base64::{Engine, prelude::BASE64_STANDARD};
|
||||
use derive_builder::Builder;
|
||||
use reqwest::{Body, multipart::{Part, Form}, Client};
|
||||
use crate::{image::{ResponseFormat, ImageResponse}, context::{API_URL, Context}};
|
||||
use crate::{image::{ResponseFormat, ImageResponse, ImageSize}, context::{API_URL, Context}};
|
||||
use tokio_util::codec::{BytesCodec, FramedRead};
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -27,9 +27,11 @@ pub struct ImageEditRequest {
|
||||
pub user: Option<String>,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub temperature: Option<f64>,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub size: Option<ImageSize>,
|
||||
}
|
||||
|
||||
fn write_file(form: Form, file: ImageFile, name: impl Into<String>) -> Form {
|
||||
pub(crate) fn write_file(form: Form, file: ImageFile, name: impl Into<String>) -> Form {
|
||||
let name = name.into();
|
||||
match file {
|
||||
ImageFile::File(file) =>
|
||||
@ -62,6 +64,10 @@ impl Context {
|
||||
if let Some(temperature) = req.temperature {
|
||||
form = form.text("temperature", temperature.to_string());
|
||||
}
|
||||
|
||||
if let Some(size) = req.size {
|
||||
form = form.text("size", size.to_string());
|
||||
}
|
||||
|
||||
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/edits")).multipart(form)).send().await?.json::<ImageResponse>().await?)
|
||||
}
|
||||
|
43
src/image_variation.rs
Normal file
43
src/image_variation.rs
Normal file
@ -0,0 +1,43 @@
|
||||
use derive_builder::Builder;
|
||||
use reqwest::{multipart::Form, Client};
|
||||
|
||||
use crate::{image_edit::{ImageFile, write_file}, image::{ImageSize, ResponseFormat, ImageResponse}, context::{API_URL, Context}};
|
||||
|
||||
#[derive(Debug, Builder)]
|
||||
#[builder(pattern = "owned")]
|
||||
pub struct ImageVariationRequest {
|
||||
#[builder(setter(into))]
|
||||
pub image: ImageFile,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub n: Option<u32>,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub size: Option<ImageSize>,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub user: Option<String>,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
}
|
||||
|
||||
|
||||
impl Context {
|
||||
pub async fn create_image_variation(&self, req: ImageVariationRequest) -> anyhow::Result<ImageResponse> {
|
||||
let mut form = Form::new();
|
||||
form = write_file(form, req.image, "image");
|
||||
|
||||
if let Some(n) = req.n {
|
||||
form = form.text("n", n.to_string());
|
||||
}
|
||||
if let Some(response_format) = req.response_format {
|
||||
form = form.text("response_format", response_format.to_string());
|
||||
}
|
||||
if let Some(user) = req.user {
|
||||
form = form.text("user", user);
|
||||
}
|
||||
|
||||
if let Some(size) = req.size {
|
||||
form = form.text("size", size.to_string());
|
||||
}
|
||||
|
||||
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/variations")).multipart(form)).send().await?.json::<ImageResponse>().await?)
|
||||
}
|
||||
}
|
28
src/lib.rs
28
src/lib.rs
@ -5,6 +5,7 @@ pub mod chat;
|
||||
pub mod edits;
|
||||
pub mod image;
|
||||
pub mod image_edit;
|
||||
pub mod image_variation;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@ -16,6 +17,7 @@ mod tests {
|
||||
use crate::image::{Image, ResponseFormat, ImageRequestBuilder};
|
||||
use crate::edits::EditRequestBuilder;
|
||||
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
|
||||
use crate::image_variation::ImageVariationRequestBuilder;
|
||||
|
||||
fn get_api() -> anyhow::Result<Context> {
|
||||
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
||||
@ -144,4 +146,30 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_image_variation() {
|
||||
let ctx = get_api();
|
||||
assert!(ctx.is_ok(), "Could not load context");
|
||||
let ctx = ctx.unwrap();
|
||||
|
||||
let image = ctx.create_image_variation(
|
||||
ImageVariationRequestBuilder::default()
|
||||
.image(ImageFile::File(File::open("clown_original.png").await.unwrap()))
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
||||
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
|
||||
assert!(image.as_ref().unwrap().data.len() == 1, "No image found");
|
||||
assert!(matches!(image.as_ref().unwrap().data[0], Image::URL(_)), "No image found");
|
||||
match image.unwrap().data[0] {
|
||||
Image::URL(ref url) => {
|
||||
println!("Generated image variation URL: {url}");
|
||||
}
|
||||
Image::Base64(ref b64) => {
|
||||
println!("Generated image variation Base64: {b64}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user