Implement image generation
This commit is contained in:
parent
0be700b0b8
commit
d5a3bc2520
@ -53,4 +53,8 @@ impl Context {
|
||||
pub async fn create_edit(&self, edit_request: EditRequest) -> anyhow::Result<crate::edits::EditResponse> {
|
||||
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/edits")).json(&edit_request)).send().await?.json::<EditResponse>().await?)
|
||||
}
|
||||
|
||||
pub async fn create_image(&self, image_request: crate::image::ImageRequest) -> anyhow::Result<crate::image::ImageResponse> {
|
||||
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/generations")).json(&image_request)).send().await?.json::<crate::image::ImageResponse>().await?)
|
||||
}
|
||||
}
|
103
src/image.rs
Normal file
103
src/image.rs
Normal file
@ -0,0 +1,103 @@
|
||||
use derive_builder::Builder;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ResponseFormat {
|
||||
URL,
|
||||
Base64,
|
||||
}
|
||||
|
||||
impl Serialize for ResponseFormat {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer {
|
||||
match self {
|
||||
Self::URL => serializer.serialize_str("url"),
|
||||
Self::Base64 => serializer.serialize_str("b64_json"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ImageSize {
|
||||
Size256,
|
||||
Size512,
|
||||
Size1024,
|
||||
}
|
||||
|
||||
impl Serialize for ImageSize {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer {
|
||||
match self {
|
||||
Self::Size256 => serializer.serialize_str("256x256"),
|
||||
Self::Size512 => serializer.serialize_str("512x512"),
|
||||
Self::Size1024 => serializer.serialize_str("1024x1024"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Builder)]
|
||||
pub struct ImageRequest {
|
||||
#[builder(setter(into))]
|
||||
pub prompt: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub n: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub size: Option<ImageSize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RawImage {
|
||||
pub(crate) url: Option<String>,
|
||||
pub(crate) b64_json: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Image {
|
||||
URL(String),
|
||||
Base64(String),
|
||||
}
|
||||
|
||||
impl Image {
|
||||
pub fn isURL(&self) -> bool {
|
||||
return match self {
|
||||
Self::URL(_) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn isBase64(&self) -> bool {
|
||||
return match self {
|
||||
Self::Base64(_) => true,
|
||||
_ => false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Image {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de> {
|
||||
let raw = RawImage::deserialize(deserializer)?;
|
||||
match (raw.url, raw.b64_json) {
|
||||
(Some(url), None) => Ok(Self::URL(url)),
|
||||
(None, Some(b64)) => Ok(Self::Base64(b64)),
|
||||
_ => Err(serde::de::Error::custom("Invalid image")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ImageResponse {
|
||||
pub created: u64,
|
||||
pub data: Vec<Image>,
|
||||
}
|
25
src/lib.rs
25
src/lib.rs
@ -3,12 +3,14 @@ pub mod model;
|
||||
pub mod completion;
|
||||
pub mod chat;
|
||||
pub mod edits;
|
||||
pub mod image;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::chat::ChatMessage;
|
||||
use crate::context::Context;
|
||||
use crate::completion::CompletionRequestBuilder;
|
||||
use crate::image::Image;
|
||||
|
||||
fn get_api() -> anyhow::Result<Context> {
|
||||
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
||||
@ -92,8 +94,29 @@ mod tests {
|
||||
|
||||
assert!(edit.is_ok(), "Could not get edit: {}", edit.unwrap_err());
|
||||
assert!(edit.as_ref().unwrap().choices.len() == 1, "No edit found");
|
||||
|
||||
|
||||
// This one might be pushing my luck a bit, but it seems to work
|
||||
//assert!(edit.unwrap().choices[0].text.replace("\n", "").eq("Ik hou van jouw moeder"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_image() {
|
||||
let ctx = get_api();
|
||||
assert!(ctx.is_ok(), "Could not load context");
|
||||
let ctx = ctx.unwrap();
|
||||
|
||||
let image = ctx.create_image(
|
||||
crate::image::ImageRequestBuilder::default()
|
||||
.prompt("In a realistic style, a ginger cat gracefully walking along a thin brick wall")
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
||||
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
|
||||
assert!(image.as_ref().unwrap().data.len() == 1, "No image found");
|
||||
assert!(image.as_ref().unwrap().data[0].isURL(), "No image found");
|
||||
if let Image::URL(url) = &image.as_ref().unwrap().data[0] {
|
||||
println!("{}", url);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user