From 82bfadc2fe7e9e26e3d27bd5ff098535a7e7a086 Mon Sep 17 00:00:00 2001 From: Gabriel Tofvesson Date: Sat, 18 Mar 2023 12:16:30 +0100 Subject: [PATCH] Implement file upload --- Cargo.lock | 2 ++ Cargo.toml | 2 ++ src/file.rs | 52 +++++++++++++++++++++++++++++++++++++++ src/image_edit.rs | 30 +++++----------------- src/image_variation.rs | 6 ++--- src/lib.rs | 9 ++++--- src/model.rs | 9 ++----- src/transcription.rs | 7 +++--- src/util.rs | 56 ++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 132 insertions(+), 41 deletions(-) create mode 100644 src/file.rs create mode 100644 src/util.rs diff --git a/Cargo.lock b/Cargo.lock index dfeda87..e53ce38 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -544,7 +544,9 @@ version = "0.1.0" dependencies = [ "anyhow", "base64", + "bytes", "derive_builder", + "futures-core", "reqwest", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index d83279c..1687f16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,9 @@ edition = "2021" [dependencies] anyhow = "1.0.69" base64 = "0.21.0" +bytes = "1.4.0" derive_builder = "0.12.0" +futures-core = "0.3.27" reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] } serde = { version = "1.0.156", features = ["derive"] } serde_json = "1.0.94" diff --git a/src/file.rs b/src/file.rs new file mode 100644 index 0000000..1bc2de2 --- /dev/null +++ b/src/file.rs @@ -0,0 +1,52 @@ +use bytes::Bytes; +use reqwest::{Client, multipart::Form}; +use serde::Deserialize; + +use crate::{context::{API_URL, Context}, util::{DataList, FileResource}}; + +#[derive(Debug, Deserialize)] +pub struct FileInfo { + pub id: String, + /* pub object: "file", */ + pub bytes: u64, + pub created_at: u64, + pub filename: String, + pub purpose: String, +} + +#[derive(Debug, Deserialize)] +pub struct FileDeleteResponse { + pub id: String, + /* pub object: "file", */ + pub deleted: bool, +} + +impl Context { + pub async fn get_files(&self) -> anyhow::Result> { + Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files"))).send().await?.json::>().await?.data) + + } + + pub async fn upload_file(&self, file: FileResource, file_name: String, purpose: String) -> anyhow::Result { + Ok( + self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/files"))) + .multipart(file.write_file_named(Form::new().text("purpose", purpose), "file", file_name)) + .send() + .await? + .json::() + .await? + ) + } + + pub async fn delete_file(&self, file_id: &str) -> anyhow::Result { + Ok(self.with_auth(Client::builder().build()?.delete(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.json::().await?) + } + + pub async fn get_file(&self, file_id: &str) -> anyhow::Result>> { + Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.bytes_stream()) + } + + pub async fn get_file_direct(&self, file_id: &str) -> anyhow::Result { + Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.bytes().await?) + } +} \ No newline at end of file diff --git a/src/image_edit.rs b/src/image_edit.rs index 1166a5b..4491ee4 100644 --- a/src/image_edit.rs +++ b/src/image_edit.rs @@ -1,22 +1,14 @@ -use base64::{Engine, prelude::BASE64_STANDARD}; use derive_builder::Builder; -use reqwest::{Body, multipart::{Part, Form}, Client}; -use crate::{image::{ResponseFormat, ImageResponse, ImageSize}, context::{API_URL, Context}}; -use tokio_util::codec::{BytesCodec, FramedRead}; - -#[derive(Debug)] -pub enum ImageFile { - File(tokio::fs::File), - Data(Vec), -} +use reqwest::{multipart::Form, Client}; +use crate::{image::{ResponseFormat, ImageResponse, ImageSize}, context::{API_URL, Context}, util::FileResource}; #[derive(Debug, Builder)] #[builder(pattern = "owned")] pub struct ImageEditRequest { #[builder(setter(into))] - pub image: ImageFile, + pub image: FileResource, #[builder(setter(into, strip_option), default)] - pub mask: Option, + pub mask: Option, #[builder(setter(into))] pub prompt: String, #[builder(setter(into, strip_option), default)] @@ -31,21 +23,11 @@ pub struct ImageEditRequest { pub size: Option, } -pub(crate) fn write_file(form: Form, file: ImageFile, name: impl Into) -> Form { - let name = name.into(); - match file { - ImageFile::File(file) => - form.part(name.clone(), Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new()))).file_name(name)), - ImageFile::Data(data) => - form.text(name, BASE64_STANDARD.encode(data.as_slice())), - } -} - impl Context { pub async fn create_image_edit(&self, req: ImageEditRequest) -> anyhow::Result { let mut form = Form::new(); form = form.text("prompt", req.prompt); - form = write_file(form, req.image, "image"); + form = req.image.write_file(form, "image"); if let Some(n) = req.n { form = form.text("n", n.to_string()); @@ -58,7 +40,7 @@ impl Context { } if let Some(mask) = req.mask { - form = write_file(form, mask, "mask"); + form = mask.write_file(form, "mask"); } if let Some(temperature) = req.temperature { diff --git a/src/image_variation.rs b/src/image_variation.rs index 6addda8..a84d48d 100644 --- a/src/image_variation.rs +++ b/src/image_variation.rs @@ -1,13 +1,13 @@ use derive_builder::Builder; use reqwest::{multipart::Form, Client}; -use crate::{image_edit::{ImageFile, write_file}, image::{ImageSize, ResponseFormat, ImageResponse}, context::{API_URL, Context}}; +use crate::{image::{ImageSize, ResponseFormat, ImageResponse}, context::{API_URL, Context}, util::FileResource}; #[derive(Debug, Builder)] #[builder(pattern = "owned")] pub struct ImageVariationRequest { #[builder(setter(into))] - pub image: ImageFile, + pub image: FileResource, #[builder(setter(into, strip_option), default)] pub n: Option, #[builder(setter(into, strip_option), default)] @@ -22,7 +22,7 @@ pub struct ImageVariationRequest { impl Context { pub async fn create_image_variation(&self, req: ImageVariationRequest) -> anyhow::Result { let mut form = Form::new(); - form = write_file(form, req.image, "image"); + form = req.image.write_file(form, "image"); if let Some(n) = req.n { form = form.text("n", n.to_string()); diff --git a/src/lib.rs b/src/lib.rs index 86c4ea9..c7b2090 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,9 @@ pub mod image_variation; pub mod embedding; pub mod transcription; pub mod translation; +pub mod file; + +pub mod util; #[cfg(test)] mod tests { @@ -19,7 +22,7 @@ mod tests { use crate::completion::CompletionRequestBuilder; use crate::image::{Image, ResponseFormat, ImageRequestBuilder}; use crate::edits::EditRequestBuilder; - use crate::image_edit::{ImageEditRequestBuilder, ImageFile}; + use crate::image_edit::ImageEditRequestBuilder; use crate::image_variation::ImageVariationRequestBuilder; use crate::embedding::EmbeddingRequestBuilder; use crate::transcription::{TranscriptionRequestBuilder, AudioFile}; @@ -132,7 +135,7 @@ mod tests { let image = ctx.create_image_edit( ImageEditRequestBuilder::default() - .image(ImageFile::File(File::open("clown.png").await.unwrap())) + .image(File::open("clown.png").await.unwrap()) .prompt("Blue nose") .build() .unwrap() @@ -160,7 +163,7 @@ mod tests { let image = ctx.create_image_variation( ImageVariationRequestBuilder::default() - .image(ImageFile::File(File::open("clown_original.png").await.unwrap())) + .image(File::open("clown_original.png").await.unwrap()) .build() .unwrap() ).await; diff --git a/src/model.rs b/src/model.rs index 8ae287b..4b79fe5 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,7 +1,7 @@ use reqwest::Client; use serde::Deserialize; -use crate::context::{API_URL, Context}; +use crate::{context::{API_URL, Context}, util::DataList}; #[derive(Debug, Deserialize)] pub struct Permission { @@ -29,14 +29,9 @@ pub struct Model { pub parent: Option, } -#[derive(Debug, Deserialize)] -pub(crate) struct ModelList { - pub data: Vec, -} - impl Context { pub async fn get_models(&self) -> anyhow::Result> { - Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::().await?.data) + Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::>().await?.data) } pub async fn get_model(&self, model_id: &str) -> anyhow::Result { diff --git a/src/transcription.rs b/src/transcription.rs index 6bc4eaa..8490a73 100644 --- a/src/transcription.rs +++ b/src/transcription.rs @@ -1,10 +1,9 @@ use derive_builder::Builder; -use reqwest::{multipart::{Form, Part}, Body, Client}; +use reqwest::{multipart::Form, Client}; use serde::Deserialize; use tokio::fs::File; -use tokio_util::codec::{FramedRead, BytesCodec}; -use crate::context::{API_URL, Context}; +use crate::{context::{API_URL, Context}, util::FileResource}; #[derive(Debug, Clone)] pub enum AudioResponseFormat { @@ -87,7 +86,7 @@ impl Context { pub async fn create_transcription(&self, req: TranscriptionRequest) -> anyhow::Result { let mut form = Form::new(); let file_name = req.file.file_name(); - form = form.part("file", Part::stream(Body::wrap_stream(FramedRead::new(req.file.file(), BytesCodec::new()))).file_name(file_name)); + form = FileResource::from(req.file.file()).write_file_named(form, "file", file_name); form = form.text("model", req.model); if let Some(response_format) = req.response_format { diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..a30bae7 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,56 @@ +use base64::{prelude::BASE64_STANDARD, Engine}; +use reqwest::{multipart::{Form, Part}, Body}; +use serde::Deserialize; +use tokio_util::codec::{FramedRead, BytesCodec}; + +pub struct DataList Deserialize<'d>> { + pub data: Vec, + /* pub object: "list", */ +} + +impl<'de, T: for<'d> Deserialize<'d>> Deserialize<'de> for DataList { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Ok(Self { + data: Vec::::deserialize(deserializer)? + }) + } +} + + +#[derive(Debug)] +pub enum FileResource { + File(tokio::fs::File), + Data(Vec), +} + +impl FileResource { + pub(crate) fn write_file_named(self, form: Form, part_name: impl Into, file_name: impl Into) -> Form { + match self { + FileResource::File(file) => + form.part(part_name.into(), Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new()))).file_name(file_name.into())), + FileResource::Data(data) => + form.part(part_name.into(), Part::bytes(BASE64_STANDARD.encode(data.as_slice()).as_bytes().to_owned()).file_name(file_name.into())), + } + } + + pub(crate) fn write_file(self, form: Form, name: impl Into) -> Form { + let name = name.into(); + self.write_file_named(form, name.clone(), name) + } + +} + +impl From for FileResource { + fn from(file: tokio::fs::File) -> Self { + Self::File(file) + } +} + +impl From> for FileResource { + fn from(data: Vec) -> Self { + Self::Data(data) + } +} \ No newline at end of file