From d5a3bc25207883482359f5f622ceedfa65742b46 Mon Sep 17 00:00:00 2001 From: Gabriel Tofvesson Date: Fri, 17 Mar 2023 21:53:26 +0100 Subject: [PATCH] Implement image generation --- src/context.rs | 4 ++ src/image.rs | 103 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 25 +++++++++++- 3 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 src/image.rs diff --git a/src/context.rs b/src/context.rs index f891cc0..fb4468f 100644 --- a/src/context.rs +++ b/src/context.rs @@ -53,4 +53,8 @@ impl Context { pub async fn create_edit(&self, edit_request: EditRequest) -> anyhow::Result { Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/edits")).json(&edit_request)).send().await?.json::().await?) } + + pub async fn create_image(&self, image_request: crate::image::ImageRequest) -> anyhow::Result { + Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/generations")).json(&image_request)).send().await?.json::().await?) + } } \ No newline at end of file diff --git a/src/image.rs b/src/image.rs new file mode 100644 index 0000000..45416cc --- /dev/null +++ b/src/image.rs @@ -0,0 +1,103 @@ +use derive_builder::Builder; +use serde::{Serialize, Deserialize}; + +#[derive(Debug, Clone)] +pub enum ResponseFormat { + URL, + Base64, +} + +impl Serialize for ResponseFormat { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer { + match self { + Self::URL => serializer.serialize_str("url"), + Self::Base64 => serializer.serialize_str("b64_json"), + } + } +} + +#[derive(Debug, Clone)] +pub enum ImageSize { + Size256, + Size512, + Size1024, +} + +impl Serialize for ImageSize { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer { + match self { + Self::Size256 => serializer.serialize_str("256x256"), + Self::Size512 => serializer.serialize_str("512x512"), + Self::Size1024 => serializer.serialize_str("1024x1024"), + } + } +} + +#[derive(Debug, Serialize, Builder)] +pub struct ImageRequest { + #[builder(setter(into))] + pub prompt: String, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(setter(into, strip_option), default)] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(setter(into, strip_option), default)] + pub size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(setter(into, strip_option), default)] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(setter(into, strip_option), default)] + pub user: Option, +} + +#[derive(Debug, Deserialize)] +struct RawImage { + pub(crate) url: Option, + pub(crate) b64_json: Option, +} + +#[derive(Debug)] +pub enum Image { + URL(String), + Base64(String), +} + +impl Image { + pub fn isURL(&self) -> bool { + return match self { + Self::URL(_) => true, + _ => false, + } + } + + pub fn isBase64(&self) -> bool { + return match self { + Self::Base64(_) => true, + _ => false + } + } +} + +impl<'de> Deserialize<'de> for Image { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> { + let raw = RawImage::deserialize(deserializer)?; + match (raw.url, raw.b64_json) { + (Some(url), None) => Ok(Self::URL(url)), + (None, Some(b64)) => Ok(Self::Base64(b64)), + _ => Err(serde::de::Error::custom("Invalid image")), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct ImageResponse { + pub created: u64, + pub data: Vec, +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 9dd3d9f..03f68e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,12 +3,14 @@ pub mod model; pub mod completion; pub mod chat; pub mod edits; +pub mod image; #[cfg(test)] mod tests { use crate::chat::ChatMessage; use crate::context::Context; use crate::completion::CompletionRequestBuilder; + use crate::image::Image; fn get_api() -> anyhow::Result { Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string())) @@ -92,8 +94,29 @@ mod tests { assert!(edit.is_ok(), "Could not get edit: {}", edit.unwrap_err()); assert!(edit.as_ref().unwrap().choices.len() == 1, "No edit found"); - + // This one might be pushing my luck a bit, but it seems to work //assert!(edit.unwrap().choices[0].text.replace("\n", "").eq("Ik hou van jouw moeder")); } + + #[tokio::test] + async fn test_image() { + let ctx = get_api(); + assert!(ctx.is_ok(), "Could not load context"); + let ctx = ctx.unwrap(); + + let image = ctx.create_image( + crate::image::ImageRequestBuilder::default() + .prompt("In a realistic style, a ginger cat gracefully walking along a thin brick wall") + .build() + .unwrap() + ).await; + + assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err()); + assert!(image.as_ref().unwrap().data.len() == 1, "No image found"); + assert!(image.as_ref().unwrap().data[0].isURL(), "No image found"); + if let Image::URL(url) = &image.as_ref().unwrap().data[0] { + println!("{}", url); + } + } } \ No newline at end of file