Implement image edits
This commit is contained in:
parent
48f07d3201
commit
b04f2ea7d2
39
Cargo.lock
generated
39
Cargo.lock
generated
@ -216,6 +216,23 @@ version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.27"
|
||||
@ -235,9 +252,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -521,11 +543,13 @@ name = "openai_rs"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64",
|
||||
"derive_builder",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -677,10 +701,12 @@ dependencies = [
|
||||
"serde_urlencoded",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-util",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"winreg",
|
||||
]
|
||||
@ -1087,6 +1113,19 @@ version = "0.2.84"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-streams"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web-sys"
|
||||
version = "0.3.61"
|
||||
|
@ -7,8 +7,10 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.69"
|
||||
base64 = "0.21.0"
|
||||
derive_builder = "0.12.0"
|
||||
reqwest = { version = "0.11.14", features = [ "json", "multipart" ] }
|
||||
reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] }
|
||||
serde = { version = "1.0.156", features = ["derive"] }
|
||||
serde_json = "1.0.94"
|
||||
tokio = { version = "1.26.0", features = [ "full" ] }
|
||||
tokio-util = { version = "0.7.7", features = [ "codec" ] }
|
||||
|
BIN
clown.png
Executable file
BIN
clown.png
Executable file
Binary file not shown.
After Width: | Height: | Size: 1.4 MiB |
@ -7,7 +7,7 @@ pub struct Context {
|
||||
org_id: Option<String>
|
||||
}
|
||||
|
||||
const API_URL: &str = "https://api.openai.com";
|
||||
pub(crate) const API_URL: &str = "https://api.openai.com";
|
||||
|
||||
impl Context {
|
||||
pub fn new(api_key: String) -> Self {
|
||||
@ -24,7 +24,7 @@ impl Context {
|
||||
}
|
||||
}
|
||||
|
||||
fn with_auth(&self, builder: RequestBuilder) -> RequestBuilder {
|
||||
pub(crate) fn with_auth(&self, builder: RequestBuilder) -> RequestBuilder {
|
||||
(
|
||||
if let Some(ref org_id) = self.org_id {
|
||||
builder.header("OpenAI-Organization", org_id)
|
||||
|
17
src/image.rs
17
src/image.rs
@ -7,14 +7,20 @@ pub enum ResponseFormat {
|
||||
Base64,
|
||||
}
|
||||
|
||||
impl ToString for ResponseFormat {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Self::URL => "url".to_string(),
|
||||
Self::Base64 => "b64_json".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for ResponseFormat {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer {
|
||||
match self {
|
||||
Self::URL => serializer.serialize_str("url"),
|
||||
Self::Base64 => serializer.serialize_str("b64_json"),
|
||||
}
|
||||
serializer.serialize_str(&self.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,6 +59,9 @@ pub struct ImageRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub user: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub temperature: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
68
src/image_edit.rs
Normal file
68
src/image_edit.rs
Normal file
@ -0,0 +1,68 @@
|
||||
use base64::{Engine, prelude::BASE64_STANDARD};
|
||||
use derive_builder::Builder;
|
||||
use reqwest::{Body, multipart::{Part, Form}, Client};
|
||||
use crate::{image::{ResponseFormat, ImageResponse}, context::{API_URL, Context}};
|
||||
use tokio_util::codec::{BytesCodec, FramedRead};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ImageFile {
|
||||
File(tokio::fs::File),
|
||||
Data(Vec<u8>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Builder)]
|
||||
#[builder(pattern = "owned")]
|
||||
pub struct ImageEditRequest {
|
||||
#[builder(setter(into))]
|
||||
pub image: ImageFile,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub mask: Option<ImageFile>,
|
||||
#[builder(setter(into))]
|
||||
pub prompt: String,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub n: Option<u32>,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub user: Option<String>,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub temperature: Option<f64>,
|
||||
}
|
||||
|
||||
fn write_file(form: Form, file: ImageFile, name: impl Into<String>) -> Form {
|
||||
let name = name.into();
|
||||
match file {
|
||||
ImageFile::File(file) =>
|
||||
form.part(name.clone(), Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new()))).file_name(name)),
|
||||
ImageFile::Data(data) =>
|
||||
form.text(name, BASE64_STANDARD.encode(data.as_slice())),
|
||||
}
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub async fn create_image_edit(&self, req: ImageEditRequest) -> anyhow::Result<ImageResponse> {
|
||||
let mut form = Form::new();
|
||||
form = form.text("prompt", req.prompt);
|
||||
form = write_file(form, req.image, "image");
|
||||
|
||||
if let Some(n) = req.n {
|
||||
form = form.text("n", n.to_string());
|
||||
}
|
||||
if let Some(response_format) = req.response_format {
|
||||
form = form.text("response_format", response_format.to_string());
|
||||
}
|
||||
if let Some(user) = req.user {
|
||||
form = form.text("user", user);
|
||||
}
|
||||
|
||||
if let Some(mask) = req.mask {
|
||||
form = write_file(form, mask, "mask");
|
||||
}
|
||||
|
||||
if let Some(temperature) = req.temperature {
|
||||
form = form.text("temperature", temperature.to_string());
|
||||
}
|
||||
|
||||
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/edits")).multipart(form)).send().await?.json::<ImageResponse>().await?)
|
||||
}
|
||||
}
|
80
src/lib.rs
80
src/lib.rs
@ -4,13 +4,18 @@ pub mod completion;
|
||||
pub mod chat;
|
||||
pub mod edits;
|
||||
pub mod image;
|
||||
pub mod image_edit;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::chat::ChatMessage;
|
||||
use tokio::fs::File;
|
||||
|
||||
use crate::chat::{ChatHistoryBuilder, ChatMessage, Role};
|
||||
use crate::context::Context;
|
||||
use crate::completion::CompletionRequestBuilder;
|
||||
use crate::image::{Image, ResponseFormat};
|
||||
use crate::image::{Image, ResponseFormat, ImageRequestBuilder};
|
||||
use crate::edits::EditRequestBuilder;
|
||||
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
|
||||
|
||||
fn get_api() -> anyhow::Result<Context> {
|
||||
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
||||
@ -34,10 +39,10 @@ mod tests {
|
||||
|
||||
let completion = ctx.unwrap().create_completion(
|
||||
CompletionRequestBuilder::default()
|
||||
.model("text-davinci-003")
|
||||
.prompt("Say 'this is a test'")
|
||||
.build()
|
||||
.unwrap()
|
||||
.model("text-davinci-003")
|
||||
.prompt("Say 'this is a test'")
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
||||
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
|
||||
@ -50,11 +55,11 @@ mod tests {
|
||||
assert!(ctx.is_ok(), "Could not load context");
|
||||
|
||||
let completion = ctx.unwrap().create_chat_completion(
|
||||
crate::chat::ChatHistoryBuilder::default()
|
||||
.messages(vec![ChatMessage::new(crate::chat::Role::User, "Respond to this message with 'this is a test'")])
|
||||
.model("gpt-3.5-turbo")
|
||||
.build()
|
||||
.unwrap()
|
||||
ChatHistoryBuilder::default()
|
||||
.messages(vec![ChatMessage::new(Role::User, "Respond to this message with 'this is a test'")])
|
||||
.model("gpt-3.5-turbo")
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
||||
assert!(completion.is_ok(), "Could not get completion: {}", completion.unwrap_err());
|
||||
@ -69,12 +74,12 @@ mod tests {
|
||||
|
||||
// Autocorrect English spelling errors
|
||||
let edit = ctx.create_edit(
|
||||
crate::edits::EditRequestBuilder::default()
|
||||
.model("text-davinci-edit-001")
|
||||
.instruction("Correct all spelling mistakes")
|
||||
.input("What a wnoderful day!")
|
||||
.build()
|
||||
.unwrap()
|
||||
EditRequestBuilder::default()
|
||||
.model("text-davinci-edit-001")
|
||||
.instruction("Correct all spelling mistakes")
|
||||
.input("What a wnoderful day!")
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
||||
assert!(edit.is_ok(), "Could not get edit: {}", edit.unwrap_err());
|
||||
@ -84,17 +89,17 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_image() {
|
||||
const IMAGE_PROMPT: &str = "In a realistic style, a ginger cat gracefully walking along a thin brick wall";
|
||||
const IMAGE_PROMPT: &str = "A real ginger cat gracefully walking along a real, thin brick wall";
|
||||
let ctx = get_api();
|
||||
assert!(ctx.is_ok(), "Could not load context");
|
||||
let ctx = ctx.unwrap();
|
||||
|
||||
let image = ctx.create_image(
|
||||
crate::image::ImageRequestBuilder::default()
|
||||
.prompt(IMAGE_PROMPT)
|
||||
.response_format(ResponseFormat::URL)
|
||||
.build()
|
||||
.unwrap()
|
||||
ImageRequestBuilder::default()
|
||||
.prompt(IMAGE_PROMPT)
|
||||
.response_format(ResponseFormat::URL)
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
||||
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
|
||||
@ -110,4 +115,33 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_image_edit() {
|
||||
const IMAGE_PROMPT: &str = "A real ginger cat gracefully walking along a real, thin brick wall";
|
||||
let ctx = get_api();
|
||||
assert!(ctx.is_ok(), "Could not load context");
|
||||
let ctx = ctx.unwrap();
|
||||
|
||||
let image = ctx.create_image_edit(
|
||||
ImageEditRequestBuilder::default()
|
||||
.image(ImageFile::File(File::open("clown.png").await.unwrap()))
|
||||
.prompt("Blue nose")
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
||||
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
|
||||
assert!(image.as_ref().unwrap().data.len() == 1, "No image found");
|
||||
assert!(matches!(image.as_ref().unwrap().data[0], Image::URL(_)), "No image found");
|
||||
println!("Image prompt: {IMAGE_PROMPT}");
|
||||
match image.unwrap().data[0] {
|
||||
Image::URL(ref url) => {
|
||||
println!("Generated edited image URL: {url}");
|
||||
}
|
||||
Image::Base64(ref b64) => {
|
||||
println!("Generated edited image Base64: {b64}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user