Implement image edits
This commit is contained in:
parent
48f07d3201
commit
b04f2ea7d2
39
Cargo.lock
generated
39
Cargo.lock
generated
@ -216,6 +216,23 @@ version = "0.3.27"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
|
checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-io"
|
||||||
|
version = "0.3.27"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-macro"
|
||||||
|
version = "0.3.27"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-sink"
|
name = "futures-sink"
|
||||||
version = "0.3.27"
|
version = "0.3.27"
|
||||||
@ -235,9 +252,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
|
checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
"futures-io",
|
||||||
|
"futures-macro",
|
||||||
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
|
"memchr",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"pin-utils",
|
"pin-utils",
|
||||||
|
"slab",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -521,11 +543,13 @@ name = "openai_rs"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"base64",
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -677,10 +701,12 @@ dependencies = [
|
|||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-native-tls",
|
"tokio-native-tls",
|
||||||
|
"tokio-util",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"url",
|
"url",
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
"wasm-bindgen-futures",
|
"wasm-bindgen-futures",
|
||||||
|
"wasm-streams",
|
||||||
"web-sys",
|
"web-sys",
|
||||||
"winreg",
|
"winreg",
|
||||||
]
|
]
|
||||||
@ -1087,6 +1113,19 @@ version = "0.2.84"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d"
|
checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasm-streams"
|
||||||
|
version = "0.2.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078"
|
||||||
|
dependencies = [
|
||||||
|
"futures-util",
|
||||||
|
"js-sys",
|
||||||
|
"wasm-bindgen",
|
||||||
|
"wasm-bindgen-futures",
|
||||||
|
"web-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "web-sys"
|
name = "web-sys"
|
||||||
version = "0.3.61"
|
version = "0.3.61"
|
||||||
|
@ -7,8 +7,10 @@ edition = "2021"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.69"
|
anyhow = "1.0.69"
|
||||||
|
base64 = "0.21.0"
|
||||||
derive_builder = "0.12.0"
|
derive_builder = "0.12.0"
|
||||||
reqwest = { version = "0.11.14", features = [ "json", "multipart" ] }
|
reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] }
|
||||||
serde = { version = "1.0.156", features = ["derive"] }
|
serde = { version = "1.0.156", features = ["derive"] }
|
||||||
serde_json = "1.0.94"
|
serde_json = "1.0.94"
|
||||||
tokio = { version = "1.26.0", features = [ "full" ] }
|
tokio = { version = "1.26.0", features = [ "full" ] }
|
||||||
|
tokio-util = { version = "0.7.7", features = [ "codec" ] }
|
||||||
|
BIN
clown.png
Executable file
BIN
clown.png
Executable file
Binary file not shown.
After Width: | Height: | Size: 1.4 MiB |
@ -7,7 +7,7 @@ pub struct Context {
|
|||||||
org_id: Option<String>
|
org_id: Option<String>
|
||||||
}
|
}
|
||||||
|
|
||||||
const API_URL: &str = "https://api.openai.com";
|
pub(crate) const API_URL: &str = "https://api.openai.com";
|
||||||
|
|
||||||
impl Context {
|
impl Context {
|
||||||
pub fn new(api_key: String) -> Self {
|
pub fn new(api_key: String) -> Self {
|
||||||
@ -24,7 +24,7 @@ impl Context {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn with_auth(&self, builder: RequestBuilder) -> RequestBuilder {
|
pub(crate) fn with_auth(&self, builder: RequestBuilder) -> RequestBuilder {
|
||||||
(
|
(
|
||||||
if let Some(ref org_id) = self.org_id {
|
if let Some(ref org_id) = self.org_id {
|
||||||
builder.header("OpenAI-Organization", org_id)
|
builder.header("OpenAI-Organization", org_id)
|
||||||
|
17
src/image.rs
17
src/image.rs
@ -7,14 +7,20 @@ pub enum ResponseFormat {
|
|||||||
Base64,
|
Base64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ToString for ResponseFormat {
|
||||||
|
fn to_string(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::URL => "url".to_string(),
|
||||||
|
Self::Base64 => "b64_json".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Serialize for ResponseFormat {
|
impl Serialize for ResponseFormat {
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
where
|
where
|
||||||
S: serde::Serializer {
|
S: serde::Serializer {
|
||||||
match self {
|
serializer.serialize_str(&self.to_string())
|
||||||
Self::URL => serializer.serialize_str("url"),
|
|
||||||
Self::Base64 => serializer.serialize_str("b64_json"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,6 +59,9 @@ pub struct ImageRequest {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
#[builder(setter(into, strip_option), default)]
|
#[builder(setter(into, strip_option), default)]
|
||||||
pub user: Option<String>,
|
pub user: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
68
src/image_edit.rs
Normal file
68
src/image_edit.rs
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
use base64::{Engine, prelude::BASE64_STANDARD};
|
||||||
|
use derive_builder::Builder;
|
||||||
|
use reqwest::{Body, multipart::{Part, Form}, Client};
|
||||||
|
use crate::{image::{ResponseFormat, ImageResponse}, context::{API_URL, Context}};
|
||||||
|
use tokio_util::codec::{BytesCodec, FramedRead};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ImageFile {
|
||||||
|
File(tokio::fs::File),
|
||||||
|
Data(Vec<u8>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Builder)]
|
||||||
|
#[builder(pattern = "owned")]
|
||||||
|
pub struct ImageEditRequest {
|
||||||
|
#[builder(setter(into))]
|
||||||
|
pub image: ImageFile,
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub mask: Option<ImageFile>,
|
||||||
|
#[builder(setter(into))]
|
||||||
|
pub prompt: String,
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub response_format: Option<ResponseFormat>,
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub user: Option<String>,
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_file(form: Form, file: ImageFile, name: impl Into<String>) -> Form {
|
||||||
|
let name = name.into();
|
||||||
|
match file {
|
||||||
|
ImageFile::File(file) =>
|
||||||
|
form.part(name.clone(), Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new()))).file_name(name)),
|
||||||
|
ImageFile::Data(data) =>
|
||||||
|
form.text(name, BASE64_STANDARD.encode(data.as_slice())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Context {
|
||||||
|
pub async fn create_image_edit(&self, req: ImageEditRequest) -> anyhow::Result<ImageResponse> {
|
||||||
|
let mut form = Form::new();
|
||||||
|
form = form.text("prompt", req.prompt);
|
||||||
|
form = write_file(form, req.image, "image");
|
||||||
|
|
||||||
|
if let Some(n) = req.n {
|
||||||
|
form = form.text("n", n.to_string());
|
||||||
|
}
|
||||||
|
if let Some(response_format) = req.response_format {
|
||||||
|
form = form.text("response_format", response_format.to_string());
|
||||||
|
}
|
||||||
|
if let Some(user) = req.user {
|
||||||
|
form = form.text("user", user);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(mask) = req.mask {
|
||||||
|
form = write_file(form, mask, "mask");
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(temperature) = req.temperature {
|
||||||
|
form = form.text("temperature", temperature.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/edits")).multipart(form)).send().await?.json::<ImageResponse>().await?)
|
||||||
|
}
|
||||||
|
}
|
80
src/lib.rs
80
src/lib.rs
@ -4,13 +4,18 @@ pub mod completion;
|
|||||||
pub mod chat;
|
pub mod chat;
|
||||||
pub mod edits;
|
pub mod edits;
|
||||||
pub mod image;
|
pub mod image;
|
||||||
|
pub mod image_edit;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::chat::ChatMessage;
|
use tokio::fs::File;
|
||||||
|
|
||||||
|
use crate::chat::{ChatHistoryBuilder, ChatMessage, Role};
|
||||||
use crate::context::Context;
|
use crate::context::Context;
|
||||||
use crate::completion::CompletionRequestBuilder;
|
use crate::completion::CompletionRequestBuilder;
|
||||||
use crate::image::{Image, ResponseFormat};
|
use crate::image::{Image, ResponseFormat, ImageRequestBuilder};
|
||||||
|
use crate::edits::EditRequestBuilder;
|
||||||
|
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
|
||||||
|
|
||||||
fn get_api() -> anyhow::Result<Context> {
|
fn get_api() -> anyhow::Result<Context> {
|
||||||
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
||||||
@ -34,10 +39,10 @@ mod tests {
|
|||||||
|
|
||||||
let completion = ctx.unwrap().create_completion(
|
let completion = ctx.unwrap().create_completion(
|
||||||
CompletionRequestBuilder::default()
|
CompletionRequestBuilder::default()
|
||||||
.model("text-davinci-003")
|
.model("text-davinci-003")
|
||||||
.prompt("Say 'this is a test'")
|
.prompt("Say 'this is a test'")
|
||||||
.build()
|
.build()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
|
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
|
||||||
@ -50,11 +55,11 @@ mod tests {
|
|||||||
assert!(ctx.is_ok(), "Could not load context");
|
assert!(ctx.is_ok(), "Could not load context");
|
||||||
|
|
||||||
let completion = ctx.unwrap().create_chat_completion(
|
let completion = ctx.unwrap().create_chat_completion(
|
||||||
crate::chat::ChatHistoryBuilder::default()
|
ChatHistoryBuilder::default()
|
||||||
.messages(vec![ChatMessage::new(crate::chat::Role::User, "Respond to this message with 'this is a test'")])
|
.messages(vec![ChatMessage::new(Role::User, "Respond to this message with 'this is a test'")])
|
||||||
.model("gpt-3.5-turbo")
|
.model("gpt-3.5-turbo")
|
||||||
.build()
|
.build()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
|
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
|
||||||
@ -69,12 +74,12 @@ mod tests {
|
|||||||
|
|
||||||
// Autocorrect English spelling errors
|
// Autocorrect English spelling errors
|
||||||
let edit = ctx.create_edit(
|
let edit = ctx.create_edit(
|
||||||
crate::edits::EditRequestBuilder::default()
|
EditRequestBuilder::default()
|
||||||
.model("text-davinci-edit-001")
|
.model("text-davinci-edit-001")
|
||||||
.instruction("Correct all spelling mistakes")
|
.instruction("Correct all spelling mistakes")
|
||||||
.input("What a wnoderful day!")
|
.input("What a wnoderful day!")
|
||||||
.build()
|
.build()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
assert!(edit.is_ok(), "Could not get edit: {}", edit.unwrap_err());
|
assert!(edit.is_ok(), "Could not get edit: {}", edit.unwrap_err());
|
||||||
@ -84,17 +89,17 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_image() {
|
async fn test_image() {
|
||||||
const IMAGE_PROMPT: &str = "In a realistic style, a ginger cat gracefully walking along a thin brick wall";
|
const IMAGE_PROMPT: &str = "A real ginger cat gracefully walking along a real, thin brick wall";
|
||||||
let ctx = get_api();
|
let ctx = get_api();
|
||||||
assert!(ctx.is_ok(), "Could not load context");
|
assert!(ctx.is_ok(), "Could not load context");
|
||||||
let ctx = ctx.unwrap();
|
let ctx = ctx.unwrap();
|
||||||
|
|
||||||
let image = ctx.create_image(
|
let image = ctx.create_image(
|
||||||
crate::image::ImageRequestBuilder::default()
|
ImageRequestBuilder::default()
|
||||||
.prompt(IMAGE_PROMPT)
|
.prompt(IMAGE_PROMPT)
|
||||||
.response_format(ResponseFormat::URL)
|
.response_format(ResponseFormat::URL)
|
||||||
.build()
|
.build()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
|
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
|
||||||
@ -110,4 +115,33 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_image_edit() {
|
||||||
|
const IMAGE_PROMPT: &str = "A real ginger cat gracefully walking along a real, thin brick wall";
|
||||||
|
let ctx = get_api();
|
||||||
|
assert!(ctx.is_ok(), "Could not load context");
|
||||||
|
let ctx = ctx.unwrap();
|
||||||
|
|
||||||
|
let image = ctx.create_image_edit(
|
||||||
|
ImageEditRequestBuilder::default()
|
||||||
|
.image(ImageFile::File(File::open("clown.png").await.unwrap()))
|
||||||
|
.prompt("Blue nose")
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
).await;
|
||||||
|
|
||||||
|
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
|
||||||
|
assert!(image.as_ref().unwrap().data.len() == 1, "No image found");
|
||||||
|
assert!(matches!(image.as_ref().unwrap().data[0], Image::URL(_)), "No image found");
|
||||||
|
println!("Image prompt: {IMAGE_PROMPT}");
|
||||||
|
match image.unwrap().data[0] {
|
||||||
|
Image::URL(ref url) => {
|
||||||
|
println!("Generated edited image URL: {url}");
|
||||||
|
}
|
||||||
|
Image::Base64(ref b64) => {
|
||||||
|
println!("Generated edited image Base64: {b64}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
x
Reference in New Issue
Block a user