Implement image edits

This commit is contained in:
Gabriel Tofvesson 2023-03-18 02:28:03 +01:00
parent 48f07d3201
commit b04f2ea7d2
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
7 changed files with 182 additions and 30 deletions

39
Cargo.lock generated
View File

@ -216,6 +216,23 @@ version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd" checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
[[package]]
name = "futures-io"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91"
[[package]]
name = "futures-macro"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.27" version = "0.3.27"
@ -235,9 +252,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab" checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task", "futures-task",
"memchr",
"pin-project-lite", "pin-project-lite",
"pin-utils", "pin-utils",
"slab",
] ]
[[package]] [[package]]
@ -521,11 +543,13 @@ name = "openai_rs"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"base64",
"derive_builder", "derive_builder",
"reqwest", "reqwest",
"serde", "serde",
"serde_json", "serde_json",
"tokio", "tokio",
"tokio-util",
] ]
[[package]] [[package]]
@ -677,10 +701,12 @@ dependencies = [
"serde_urlencoded", "serde_urlencoded",
"tokio", "tokio",
"tokio-native-tls", "tokio-native-tls",
"tokio-util",
"tower-service", "tower-service",
"url", "url",
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"wasm-streams",
"web-sys", "web-sys",
"winreg", "winreg",
] ]
@ -1087,6 +1113,19 @@ version = "0.2.84"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d"
[[package]]
name = "wasm-streams"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]] [[package]]
name = "web-sys" name = "web-sys"
version = "0.3.61" version = "0.3.61"

View File

@ -7,8 +7,10 @@ edition = "2021"
[dependencies] [dependencies]
anyhow = "1.0.69" anyhow = "1.0.69"
base64 = "0.21.0"
derive_builder = "0.12.0" derive_builder = "0.12.0"
reqwest = { version = "0.11.14", features = [ "json", "multipart" ] } reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] }
serde = { version = "1.0.156", features = ["derive"] } serde = { version = "1.0.156", features = ["derive"] }
serde_json = "1.0.94" serde_json = "1.0.94"
tokio = { version = "1.26.0", features = [ "full" ] } tokio = { version = "1.26.0", features = [ "full" ] }
tokio-util = { version = "0.7.7", features = [ "codec" ] }

BIN
clown.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

View File

@ -7,7 +7,7 @@ pub struct Context {
org_id: Option<String> org_id: Option<String>
} }
const API_URL: &str = "https://api.openai.com"; pub(crate) const API_URL: &str = "https://api.openai.com";
impl Context { impl Context {
pub fn new(api_key: String) -> Self { pub fn new(api_key: String) -> Self {
@ -24,7 +24,7 @@ impl Context {
} }
} }
fn with_auth(&self, builder: RequestBuilder) -> RequestBuilder { pub(crate) fn with_auth(&self, builder: RequestBuilder) -> RequestBuilder {
( (
if let Some(ref org_id) = self.org_id { if let Some(ref org_id) = self.org_id {
builder.header("OpenAI-Organization", org_id) builder.header("OpenAI-Organization", org_id)

View File

@ -7,14 +7,20 @@ pub enum ResponseFormat {
Base64, Base64,
} }
/// Textual form of a [`ResponseFormat`], as used when the value is serialized
/// or written into a request field.
///
/// `Display` is implemented instead of `ToString`: the standard library's
/// blanket `impl<T: Display> ToString for T` keeps `.to_string()` available to
/// existing callers, while also allowing the value to be used with `{}` in
/// formatting macros.
impl std::fmt::Display for ResponseFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::URL => "url",
            Self::Base64 => "b64_json",
        })
    }
}
impl Serialize for ResponseFormat { impl Serialize for ResponseFormat {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
S: serde::Serializer { S: serde::Serializer {
match self { serializer.serialize_str(&self.to_string())
Self::URL => serializer.serialize_str("url"),
Self::Base64 => serializer.serialize_str("b64_json"),
}
} }
} }
@ -53,6 +59,9 @@ pub struct ImageRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)] #[builder(setter(into, strip_option), default)]
pub user: Option<String>, pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub temperature: Option<f64>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]

68
src/image_edit.rs Normal file
View File

@ -0,0 +1,68 @@
use base64::{Engine, prelude::BASE64_STANDARD};
use derive_builder::Builder;
use reqwest::{Body, multipart::{Part, Form}, Client};
use crate::{image::{ResponseFormat, ImageResponse}, context::{API_URL, Context}};
use tokio_util::codec::{BytesCodec, FramedRead};
/// Image payload supplied to an image-edit request, either as an open file
/// handle or as bytes already held in memory.
#[derive(Debug)]
pub enum ImageFile {
/// An opened `tokio::fs::File` whose contents provide the image.
File(tokio::fs::File),
/// The image contents as an in-memory byte buffer.
Data(Vec<u8>),
}
/// Parameters for an image-edit call, consumed by `Context::create_image_edit`.
///
/// Constructed through the derived `ImageEditRequestBuilder`, which uses the
/// owned builder pattern (each setter takes and returns the builder by value).
/// `image` and `prompt` have no builder default and must be set; every
/// `Option` field defaults to `None` and is left out of the request when unset.
#[derive(Debug, Builder)]
#[builder(pattern = "owned")]
pub struct ImageEditRequest {
/// Image to edit; written to the form under the name `image`.
#[builder(setter(into))]
pub image: ImageFile,
/// Optional mask image; written to the form under the name `mask`.
#[builder(setter(into, strip_option), default)]
pub mask: Option<ImageFile>,
/// Text prompt; sent as the `prompt` form field.
#[builder(setter(into))]
pub prompt: String,
/// Optional count, sent as the `n` form field.
/// Presumably the number of images to generate — confirm against the API reference.
#[builder(setter(into, strip_option), default)]
pub n: Option<u32>,
/// Optional response format, sent as the `response_format` form field.
#[builder(setter(into, strip_option), default)]
pub response_format: Option<ResponseFormat>,
/// Optional value sent as the `user` form field.
#[builder(setter(into, strip_option), default)]
pub user: Option<String>,
/// Optional value sent as the `temperature` form field.
/// NOTE(review): confirm that the image-edit endpoint actually accepts this parameter.
#[builder(setter(into, strip_option), default)]
pub temperature: Option<f64>,
}
/// Attaches `file` to `form` as a multipart file part named `name`.
///
/// Both variants are uploaded as a file part, with `name` doubling as the
/// reported file name:
/// * `ImageFile::File` is streamed from the open file handle.
/// * `ImageFile::Data` is sent as the raw bytes it holds.
///
/// In-memory data was previously base64-encoded into a plain text field, which
/// is not a file upload and was inconsistent with how the `File` variant is
/// sent; it now goes through the same file-part path.
///
/// Returns the form with the new part appended.
fn write_file(form: Form, file: ImageFile, name: impl Into<String>) -> Form {
    let name = name.into();
    let part = match file {
        ImageFile::File(file) => {
            Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new())))
        }
        ImageFile::Data(data) => Part::bytes(data),
    };
    // The part name and the file name are intentionally the same string.
    form.part(name.clone(), part.file_name(name))
}
impl Context {
pub async fn create_image_edit(&self, req: ImageEditRequest) -> anyhow::Result<ImageResponse> {
let mut form = Form::new();
form = form.text("prompt", req.prompt);
form = write_file(form, req.image, "image");
if let Some(n) = req.n {
form = form.text("n", n.to_string());
}
if let Some(response_format) = req.response_format {
form = form.text("response_format", response_format.to_string());
}
if let Some(user) = req.user {
form = form.text("user", user);
}
if let Some(mask) = req.mask {
form = write_file(form, mask, "mask");
}
if let Some(temperature) = req.temperature {
form = form.text("temperature", temperature.to_string());
}
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/edits")).multipart(form)).send().await?.json::<ImageResponse>().await?)
}
}

View File

@ -4,13 +4,18 @@ pub mod completion;
pub mod chat; pub mod chat;
pub mod edits; pub mod edits;
pub mod image; pub mod image;
pub mod image_edit;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::chat::ChatMessage; use tokio::fs::File;
use crate::chat::{ChatHistoryBuilder, ChatMessage, Role};
use crate::context::Context; use crate::context::Context;
use crate::completion::CompletionRequestBuilder; use crate::completion::CompletionRequestBuilder;
use crate::image::{Image, ResponseFormat}; use crate::image::{Image, ResponseFormat, ImageRequestBuilder};
use crate::edits::EditRequestBuilder;
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
fn get_api() -> anyhow::Result<Context> { fn get_api() -> anyhow::Result<Context> {
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string())) Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
@ -34,10 +39,10 @@ mod tests {
let completion = ctx.unwrap().create_completion( let completion = ctx.unwrap().create_completion(
CompletionRequestBuilder::default() CompletionRequestBuilder::default()
.model("text-davinci-003") .model("text-davinci-003")
.prompt("Say 'this is a test'") .prompt("Say 'this is a test'")
.build() .build()
.unwrap() .unwrap()
).await; ).await;
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err()); assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
@ -50,11 +55,11 @@ mod tests {
assert!(ctx.is_ok(), "Could not load context"); assert!(ctx.is_ok(), "Could not load context");
let completion = ctx.unwrap().create_chat_completion( let completion = ctx.unwrap().create_chat_completion(
crate::chat::ChatHistoryBuilder::default() ChatHistoryBuilder::default()
.messages(vec![ChatMessage::new(crate::chat::Role::User, "Respond to this message with 'this is a test'")]) .messages(vec![ChatMessage::new(Role::User, "Respond to this message with 'this is a test'")])
.model("gpt-3.5-turbo") .model("gpt-3.5-turbo")
.build() .build()
.unwrap() .unwrap()
).await; ).await;
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err()); assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
@ -69,12 +74,12 @@ mod tests {
// Autocorrect English spelling errors // Autocorrect English spelling errors
let edit = ctx.create_edit( let edit = ctx.create_edit(
crate::edits::EditRequestBuilder::default() EditRequestBuilder::default()
.model("text-davinci-edit-001") .model("text-davinci-edit-001")
.instruction("Correct all spelling mistakes") .instruction("Correct all spelling mistakes")
.input("What a wnoderful day!") .input("What a wnoderful day!")
.build() .build()
.unwrap() .unwrap()
).await; ).await;
assert!(edit.is_ok(), "Could not get edit: {}", edit.unwrap_err()); assert!(edit.is_ok(), "Could not get edit: {}", edit.unwrap_err());
@ -84,17 +89,17 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_image() { async fn test_image() {
const IMAGE_PROMPT: &str = "In a realistic style, a ginger cat gracefully walking along a thin brick wall"; const IMAGE_PROMPT: &str = "A real ginger cat gracefully walking along a real, thin brick wall";
let ctx = get_api(); let ctx = get_api();
assert!(ctx.is_ok(), "Could not load context"); assert!(ctx.is_ok(), "Could not load context");
let ctx = ctx.unwrap(); let ctx = ctx.unwrap();
let image = ctx.create_image( let image = ctx.create_image(
crate::image::ImageRequestBuilder::default() ImageRequestBuilder::default()
.prompt(IMAGE_PROMPT) .prompt(IMAGE_PROMPT)
.response_format(ResponseFormat::URL) .response_format(ResponseFormat::URL)
.build() .build()
.unwrap() .unwrap()
).await; ).await;
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err()); assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
@ -110,4 +115,33 @@ mod tests {
} }
} }
} }
/// Sends an image-edit request for `clown.png` and checks that exactly one
/// URL-style image comes back.
#[tokio::test]
async fn test_image_edit() {
    // The prompt that is actually sent; also used for the log line below so the
    // printed prompt always matches the request.
    const IMAGE_PROMPT: &str = "Blue nose";

    let ctx = get_api();
    assert!(ctx.is_ok(), "Could not load context");
    let ctx = ctx.unwrap();

    let image = ctx.create_image_edit(
        ImageEditRequestBuilder::default()
            .image(ImageFile::File(File::open("clown.png").await.unwrap()))
            .prompt(IMAGE_PROMPT)
            .build()
            .unwrap()
    ).await;

    assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
    assert_eq!(image.as_ref().unwrap().data.len(), 1, "No image found");
    assert!(matches!(image.as_ref().unwrap().data[0], Image::URL(_)), "No image found");

    println!("Image prompt: {IMAGE_PROMPT}");
    match image.unwrap().data[0] {
        Image::URL(ref url) => {
            println!("Generated edited image URL: {url}");
        }
        Image::Base64(ref b64) => {
            println!("Generated edited image Base64: {b64}");
        }
    }
}
} }