Implement image generation
This commit is contained in:
parent
0be700b0b8
commit
d5a3bc2520
@ -53,4 +53,8 @@ impl Context {
|
|||||||
pub async fn create_edit(&self, edit_request: EditRequest) -> anyhow::Result<crate::edits::EditResponse> {
|
pub async fn create_edit(&self, edit_request: EditRequest) -> anyhow::Result<crate::edits::EditResponse> {
|
||||||
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/edits")).json(&edit_request)).send().await?.json::<EditResponse>().await?)
|
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/edits")).json(&edit_request)).send().await?.json::<EditResponse>().await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn create_image(&self, image_request: crate::image::ImageRequest) -> anyhow::Result<crate::image::ImageResponse> {
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/generations")).json(&image_request)).send().await?.json::<crate::image::ImageResponse>().await?)
|
||||||
|
}
|
||||||
}
|
}
|
103
src/image.rs
Normal file
103
src/image.rs
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
use derive_builder::Builder;
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ResponseFormat {
|
||||||
|
URL,
|
||||||
|
Base64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for ResponseFormat {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer {
|
||||||
|
match self {
|
||||||
|
Self::URL => serializer.serialize_str("url"),
|
||||||
|
Self::Base64 => serializer.serialize_str("b64_json"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ImageSize {
|
||||||
|
Size256,
|
||||||
|
Size512,
|
||||||
|
Size1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for ImageSize {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer {
|
||||||
|
match self {
|
||||||
|
Self::Size256 => serializer.serialize_str("256x256"),
|
||||||
|
Self::Size512 => serializer.serialize_str("512x512"),
|
||||||
|
Self::Size1024 => serializer.serialize_str("1024x1024"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Builder)]
|
||||||
|
pub struct ImageRequest {
|
||||||
|
#[builder(setter(into))]
|
||||||
|
pub prompt: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub size: Option<ImageSize>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub response_format: Option<ResponseFormat>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub user: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct RawImage {
|
||||||
|
pub(crate) url: Option<String>,
|
||||||
|
pub(crate) b64_json: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum Image {
|
||||||
|
URL(String),
|
||||||
|
Base64(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Image {
|
||||||
|
pub fn isURL(&self) -> bool {
|
||||||
|
return match self {
|
||||||
|
Self::URL(_) => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isBase64(&self) -> bool {
|
||||||
|
return match self {
|
||||||
|
Self::Base64(_) => true,
|
||||||
|
_ => false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for Image {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de> {
|
||||||
|
let raw = RawImage::deserialize(deserializer)?;
|
||||||
|
match (raw.url, raw.b64_json) {
|
||||||
|
(Some(url), None) => Ok(Self::URL(url)),
|
||||||
|
(None, Some(b64)) => Ok(Self::Base64(b64)),
|
||||||
|
_ => Err(serde::de::Error::custom("Invalid image")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ImageResponse {
|
||||||
|
pub created: u64,
|
||||||
|
pub data: Vec<Image>,
|
||||||
|
}
|
23
src/lib.rs
23
src/lib.rs
@ -3,12 +3,14 @@ pub mod model;
|
|||||||
pub mod completion;
|
pub mod completion;
|
||||||
pub mod chat;
|
pub mod chat;
|
||||||
pub mod edits;
|
pub mod edits;
|
||||||
|
pub mod image;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::chat::ChatMessage;
|
use crate::chat::ChatMessage;
|
||||||
use crate::context::Context;
|
use crate::context::Context;
|
||||||
use crate::completion::CompletionRequestBuilder;
|
use crate::completion::CompletionRequestBuilder;
|
||||||
|
use crate::image::Image;
|
||||||
|
|
||||||
fn get_api() -> anyhow::Result<Context> {
|
fn get_api() -> anyhow::Result<Context> {
|
||||||
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
||||||
@ -96,4 +98,25 @@ mod tests {
|
|||||||
// This one might be pushing my luck a bit, but it seems to work
|
// This one might be pushing my luck a bit, but it seems to work
|
||||||
//assert!(edit.unwrap().choices[0].text.replace("\n", "").eq("Ik hou van jouw moeder"));
|
//assert!(edit.unwrap().choices[0].text.replace("\n", "").eq("Ik hou van jouw moeder"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_image() {
|
||||||
|
let ctx = get_api();
|
||||||
|
assert!(ctx.is_ok(), "Could not load context");
|
||||||
|
let ctx = ctx.unwrap();
|
||||||
|
|
||||||
|
let image = ctx.create_image(
|
||||||
|
crate::image::ImageRequestBuilder::default()
|
||||||
|
.prompt("In a realistic style, a ginger cat gracefully walking along a thin brick wall")
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
).await;
|
||||||
|
|
||||||
|
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
|
||||||
|
assert!(image.as_ref().unwrap().data.len() == 1, "No image found");
|
||||||
|
assert!(image.as_ref().unwrap().data[0].isURL(), "No image found");
|
||||||
|
if let Image::URL(url) = &image.as_ref().unwrap().data[0] {
|
||||||
|
println!("{}", url);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
x
Reference in New Issue
Block a user