Implement image generation

This commit is contained in:
Gabriel Tofvesson 2023-03-17 21:53:26 +01:00
parent 0be700b0b8
commit d5a3bc2520
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
3 changed files with 131 additions and 1 deletions

View File

@ -53,4 +53,8 @@ impl Context {
pub async fn create_edit(&self, edit_request: EditRequest) -> anyhow::Result<crate::edits::EditResponse> {
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/edits")).json(&edit_request)).send().await?.json::<EditResponse>().await?)
}
pub async fn create_image(&self, image_request: crate::image::ImageRequest) -> anyhow::Result<crate::image::ImageResponse> {
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/generations")).json(&image_request)).send().await?.json::<crate::image::ImageResponse>().await?)
}
}

103
src/image.rs Normal file
View File

@ -0,0 +1,103 @@
use derive_builder::Builder;
use serde::{Serialize, Deserialize};
#[derive(Debug, Clone)]
pub enum ResponseFormat {
URL,
Base64,
}
impl Serialize for ResponseFormat {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer {
match self {
Self::URL => serializer.serialize_str("url"),
Self::Base64 => serializer.serialize_str("b64_json"),
}
}
}
#[derive(Debug, Clone)]
pub enum ImageSize {
Size256,
Size512,
Size1024,
}
impl Serialize for ImageSize {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer {
match self {
Self::Size256 => serializer.serialize_str("256x256"),
Self::Size512 => serializer.serialize_str("512x512"),
Self::Size1024 => serializer.serialize_str("1024x1024"),
}
}
}
#[derive(Debug, Serialize, Builder)]
pub struct ImageRequest {
#[builder(setter(into))]
pub prompt: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub size: Option<ImageSize>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub response_format: Option<ResponseFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub user: Option<String>,
}
#[derive(Debug, Deserialize)]
struct RawImage {
pub(crate) url: Option<String>,
pub(crate) b64_json: Option<String>,
}
#[derive(Debug)]
pub enum Image {
URL(String),
Base64(String),
}
impl Image {
pub fn isURL(&self) -> bool {
return match self {
Self::URL(_) => true,
_ => false,
}
}
pub fn isBase64(&self) -> bool {
return match self {
Self::Base64(_) => true,
_ => false
}
}
}
impl<'de> Deserialize<'de> for Image {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de> {
let raw = RawImage::deserialize(deserializer)?;
match (raw.url, raw.b64_json) {
(Some(url), None) => Ok(Self::URL(url)),
(None, Some(b64)) => Ok(Self::Base64(b64)),
_ => Err(serde::de::Error::custom("Invalid image")),
}
}
}
#[derive(Debug, Deserialize)]
pub struct ImageResponse {
pub created: u64,
pub data: Vec<Image>,
}

View File

@ -3,12 +3,14 @@ pub mod model;
pub mod completion;
pub mod chat;
pub mod edits;
pub mod image;
#[cfg(test)]
mod tests {
use crate::chat::ChatMessage;
use crate::context::Context;
use crate::completion::CompletionRequestBuilder;
use crate::image::Image;
fn get_api() -> anyhow::Result<Context> {
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
@ -92,8 +94,29 @@ mod tests {
assert!(edit.is_ok(), "Could not get edit: {}", edit.unwrap_err());
assert!(edit.as_ref().unwrap().choices.len() == 1, "No edit found");
// This one might be pushing my luck a bit, but it seems to work
//assert!(edit.unwrap().choices[0].text.replace("\n", "").eq("Ik hou van jouw moeder"));
}
#[tokio::test]
async fn test_image() {
let ctx = get_api();
assert!(ctx.is_ok(), "Could not load context");
let ctx = ctx.unwrap();
let image = ctx.create_image(
crate::image::ImageRequestBuilder::default()
.prompt("In a realistic style, a ginger cat gracefully walking along a thin brick wall")
.build()
.unwrap()
).await;
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
assert!(image.as_ref().unwrap().data.len() == 1, "No image found");
assert!(image.as_ref().unwrap().data[0].isURL(), "No image found");
if let Image::URL(url) = &image.as_ref().unwrap().data[0] {
println!("{}", url);
}
}
}