Implement file upload
This commit is contained in:
parent
bcaa49c4c6
commit
82bfadc2fe
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -544,7 +544,9 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64",
|
||||
"bytes",
|
||||
"derive_builder",
|
||||
"futures-core",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -8,7 +8,9 @@ edition = "2021"
|
||||
[dependencies]
|
||||
anyhow = "1.0.69"
|
||||
base64 = "0.21.0"
|
||||
bytes = "1.4.0"
|
||||
derive_builder = "0.12.0"
|
||||
futures-core = "0.3.27"
|
||||
reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] }
|
||||
serde = { version = "1.0.156", features = ["derive"] }
|
||||
serde_json = "1.0.94"
|
||||
|
52
src/file.rs
Normal file
52
src/file.rs
Normal file
@ -0,0 +1,52 @@
|
||||
use bytes::Bytes;
|
||||
use reqwest::{Client, multipart::Form};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{context::{API_URL, Context}, util::{DataList, FileResource}};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FileInfo {
|
||||
pub id: String,
|
||||
/* pub object: "file", */
|
||||
pub bytes: u64,
|
||||
pub created_at: u64,
|
||||
pub filename: String,
|
||||
pub purpose: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FileDeleteResponse {
|
||||
pub id: String,
|
||||
/* pub object: "file", */
|
||||
pub deleted: bool,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub async fn get_files(&self) -> anyhow::Result<Vec<FileInfo>> {
|
||||
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files"))).send().await?.json::<DataList<FileInfo>>().await?.data)
|
||||
|
||||
}
|
||||
|
||||
pub async fn upload_file(&self, file: FileResource, file_name: String, purpose: String) -> anyhow::Result<FileInfo> {
|
||||
Ok(
|
||||
self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/files")))
|
||||
.multipart(file.write_file_named(Form::new().text("purpose", purpose), "file", file_name))
|
||||
.send()
|
||||
.await?
|
||||
.json::<FileInfo>()
|
||||
.await?
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn delete_file(&self, file_id: &str) -> anyhow::Result<FileDeleteResponse> {
|
||||
Ok(self.with_auth(Client::builder().build()?.delete(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.json::<FileDeleteResponse>().await?)
|
||||
}
|
||||
|
||||
pub async fn get_file(&self, file_id: &str) -> anyhow::Result<impl futures_core::Stream<Item = reqwest::Result<Bytes>>> {
|
||||
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.bytes_stream())
|
||||
}
|
||||
|
||||
pub async fn get_file_direct(&self, file_id: &str) -> anyhow::Result<Bytes> {
|
||||
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.bytes().await?)
|
||||
}
|
||||
}
|
@ -1,22 +1,14 @@
|
||||
use base64::{Engine, prelude::BASE64_STANDARD};
|
||||
use derive_builder::Builder;
|
||||
use reqwest::{Body, multipart::{Part, Form}, Client};
|
||||
use crate::{image::{ResponseFormat, ImageResponse, ImageSize}, context::{API_URL, Context}};
|
||||
use tokio_util::codec::{BytesCodec, FramedRead};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ImageFile {
|
||||
File(tokio::fs::File),
|
||||
Data(Vec<u8>),
|
||||
}
|
||||
use reqwest::{multipart::Form, Client};
|
||||
use crate::{image::{ResponseFormat, ImageResponse, ImageSize}, context::{API_URL, Context}, util::FileResource};
|
||||
|
||||
#[derive(Debug, Builder)]
|
||||
#[builder(pattern = "owned")]
|
||||
pub struct ImageEditRequest {
|
||||
#[builder(setter(into))]
|
||||
pub image: ImageFile,
|
||||
pub image: FileResource,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub mask: Option<ImageFile>,
|
||||
pub mask: Option<FileResource>,
|
||||
#[builder(setter(into))]
|
||||
pub prompt: String,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
@ -31,21 +23,11 @@ pub struct ImageEditRequest {
|
||||
pub size: Option<ImageSize>,
|
||||
}
|
||||
|
||||
pub(crate) fn write_file(form: Form, file: ImageFile, name: impl Into<String>) -> Form {
|
||||
let name = name.into();
|
||||
match file {
|
||||
ImageFile::File(file) =>
|
||||
form.part(name.clone(), Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new()))).file_name(name)),
|
||||
ImageFile::Data(data) =>
|
||||
form.text(name, BASE64_STANDARD.encode(data.as_slice())),
|
||||
}
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub async fn create_image_edit(&self, req: ImageEditRequest) -> anyhow::Result<ImageResponse> {
|
||||
let mut form = Form::new();
|
||||
form = form.text("prompt", req.prompt);
|
||||
form = write_file(form, req.image, "image");
|
||||
form = req.image.write_file(form, "image");
|
||||
|
||||
if let Some(n) = req.n {
|
||||
form = form.text("n", n.to_string());
|
||||
@ -58,7 +40,7 @@ impl Context {
|
||||
}
|
||||
|
||||
if let Some(mask) = req.mask {
|
||||
form = write_file(form, mask, "mask");
|
||||
form = mask.write_file(form, "mask");
|
||||
}
|
||||
|
||||
if let Some(temperature) = req.temperature {
|
||||
|
@ -1,13 +1,13 @@
|
||||
use derive_builder::Builder;
|
||||
use reqwest::{multipart::Form, Client};
|
||||
|
||||
use crate::{image_edit::{ImageFile, write_file}, image::{ImageSize, ResponseFormat, ImageResponse}, context::{API_URL, Context}};
|
||||
use crate::{image::{ImageSize, ResponseFormat, ImageResponse}, context::{API_URL, Context}, util::FileResource};
|
||||
|
||||
#[derive(Debug, Builder)]
|
||||
#[builder(pattern = "owned")]
|
||||
pub struct ImageVariationRequest {
|
||||
#[builder(setter(into))]
|
||||
pub image: ImageFile,
|
||||
pub image: FileResource,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub n: Option<u32>,
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
@ -22,7 +22,7 @@ pub struct ImageVariationRequest {
|
||||
impl Context {
|
||||
pub async fn create_image_variation(&self, req: ImageVariationRequest) -> anyhow::Result<ImageResponse> {
|
||||
let mut form = Form::new();
|
||||
form = write_file(form, req.image, "image");
|
||||
form = req.image.write_file(form, "image");
|
||||
|
||||
if let Some(n) = req.n {
|
||||
form = form.text("n", n.to_string());
|
||||
|
@ -9,6 +9,9 @@ pub mod image_variation;
|
||||
pub mod embedding;
|
||||
pub mod transcription;
|
||||
pub mod translation;
|
||||
pub mod file;
|
||||
|
||||
pub mod util;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@ -19,7 +22,7 @@ mod tests {
|
||||
use crate::completion::CompletionRequestBuilder;
|
||||
use crate::image::{Image, ResponseFormat, ImageRequestBuilder};
|
||||
use crate::edits::EditRequestBuilder;
|
||||
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
|
||||
use crate::image_edit::ImageEditRequestBuilder;
|
||||
use crate::image_variation::ImageVariationRequestBuilder;
|
||||
use crate::embedding::EmbeddingRequestBuilder;
|
||||
use crate::transcription::{TranscriptionRequestBuilder, AudioFile};
|
||||
@ -132,7 +135,7 @@ mod tests {
|
||||
|
||||
let image = ctx.create_image_edit(
|
||||
ImageEditRequestBuilder::default()
|
||||
.image(ImageFile::File(File::open("clown.png").await.unwrap()))
|
||||
.image(File::open("clown.png").await.unwrap())
|
||||
.prompt("Blue nose")
|
||||
.build()
|
||||
.unwrap()
|
||||
@ -160,7 +163,7 @@ mod tests {
|
||||
|
||||
let image = ctx.create_image_variation(
|
||||
ImageVariationRequestBuilder::default()
|
||||
.image(ImageFile::File(File::open("clown_original.png").await.unwrap()))
|
||||
.image(File::open("clown_original.png").await.unwrap())
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
@ -1,7 +1,7 @@
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::context::{API_URL, Context};
|
||||
use crate::{context::{API_URL, Context}, util::DataList};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Permission {
|
||||
@ -29,14 +29,9 @@ pub struct Model {
|
||||
pub parent: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct ModelList {
|
||||
pub data: Vec<Model>,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub async fn get_models(&self) -> anyhow::Result<Vec<Model>> {
|
||||
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::<ModelList>().await?.data)
|
||||
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::<DataList<Model>>().await?.data)
|
||||
}
|
||||
|
||||
pub async fn get_model(&self, model_id: &str) -> anyhow::Result<Model> {
|
||||
|
@ -1,10 +1,9 @@
|
||||
use derive_builder::Builder;
|
||||
use reqwest::{multipart::{Form, Part}, Body, Client};
|
||||
use reqwest::{multipart::Form, Client};
|
||||
use serde::Deserialize;
|
||||
use tokio::fs::File;
|
||||
use tokio_util::codec::{FramedRead, BytesCodec};
|
||||
|
||||
use crate::context::{API_URL, Context};
|
||||
use crate::{context::{API_URL, Context}, util::FileResource};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AudioResponseFormat {
|
||||
@ -87,7 +86,7 @@ impl Context {
|
||||
pub async fn create_transcription(&self, req: TranscriptionRequest) -> anyhow::Result<TranscriptionResponse> {
|
||||
let mut form = Form::new();
|
||||
let file_name = req.file.file_name();
|
||||
form = form.part("file", Part::stream(Body::wrap_stream(FramedRead::new(req.file.file(), BytesCodec::new()))).file_name(file_name));
|
||||
form = FileResource::from(req.file.file()).write_file_named(form, "file", file_name);
|
||||
form = form.text("model", req.model);
|
||||
|
||||
if let Some(response_format) = req.response_format {
|
||||
|
56
src/util.rs
Normal file
56
src/util.rs
Normal file
@ -0,0 +1,56 @@
|
||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||
use reqwest::{multipart::{Form, Part}, Body};
|
||||
use serde::Deserialize;
|
||||
use tokio_util::codec::{FramedRead, BytesCodec};
|
||||
|
||||
pub struct DataList<T: for<'d> Deserialize<'d>> {
|
||||
pub data: Vec<T>,
|
||||
/* pub object: "list", */
|
||||
}
|
||||
|
||||
impl<'de, T: for<'d> Deserialize<'d>> Deserialize<'de> for DataList<T> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
Ok(Self {
|
||||
data: Vec::<T>::deserialize(deserializer)?
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FileResource {
|
||||
File(tokio::fs::File),
|
||||
Data(Vec<u8>),
|
||||
}
|
||||
|
||||
impl FileResource {
|
||||
pub(crate) fn write_file_named(self, form: Form, part_name: impl Into<String>, file_name: impl Into<String>) -> Form {
|
||||
match self {
|
||||
FileResource::File(file) =>
|
||||
form.part(part_name.into(), Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new()))).file_name(file_name.into())),
|
||||
FileResource::Data(data) =>
|
||||
form.part(part_name.into(), Part::bytes(BASE64_STANDARD.encode(data.as_slice()).as_bytes().to_owned()).file_name(file_name.into())),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn write_file(self, form: Form, name: impl Into<String>) -> Form {
|
||||
let name = name.into();
|
||||
self.write_file_named(form, name.clone(), name)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl From<tokio::fs::File> for FileResource {
|
||||
fn from(file: tokio::fs::File) -> Self {
|
||||
Self::File(file)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<u8>> for FileResource {
|
||||
fn from(data: Vec<u8>) -> Self {
|
||||
Self::Data(data)
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user