Implement file upload
This commit is contained in:
parent
bcaa49c4c6
commit
82bfadc2fe
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -544,7 +544,9 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"base64",
|
"base64",
|
||||||
|
"bytes",
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
|
"futures-core",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
@ -8,7 +8,9 @@ edition = "2021"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.69"
|
anyhow = "1.0.69"
|
||||||
base64 = "0.21.0"
|
base64 = "0.21.0"
|
||||||
|
bytes = "1.4.0"
|
||||||
derive_builder = "0.12.0"
|
derive_builder = "0.12.0"
|
||||||
|
futures-core = "0.3.27"
|
||||||
reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] }
|
reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] }
|
||||||
serde = { version = "1.0.156", features = ["derive"] }
|
serde = { version = "1.0.156", features = ["derive"] }
|
||||||
serde_json = "1.0.94"
|
serde_json = "1.0.94"
|
||||||
|
52
src/file.rs
Normal file
52
src/file.rs
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
use bytes::Bytes;
|
||||||
|
use reqwest::{Client, multipart::Form};
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
use crate::{context::{API_URL, Context}, util::{DataList, FileResource}};
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct FileInfo {
|
||||||
|
pub id: String,
|
||||||
|
/* pub object: "file", */
|
||||||
|
pub bytes: u64,
|
||||||
|
pub created_at: u64,
|
||||||
|
pub filename: String,
|
||||||
|
pub purpose: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct FileDeleteResponse {
|
||||||
|
pub id: String,
|
||||||
|
/* pub object: "file", */
|
||||||
|
pub deleted: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Context {
|
||||||
|
pub async fn get_files(&self) -> anyhow::Result<Vec<FileInfo>> {
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files"))).send().await?.json::<DataList<FileInfo>>().await?.data)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn upload_file(&self, file: FileResource, file_name: String, purpose: String) -> anyhow::Result<FileInfo> {
|
||||||
|
Ok(
|
||||||
|
self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/files")))
|
||||||
|
.multipart(file.write_file_named(Form::new().text("purpose", purpose), "file", file_name))
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.json::<FileInfo>()
|
||||||
|
.await?
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_file(&self, file_id: &str) -> anyhow::Result<FileDeleteResponse> {
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.delete(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.json::<FileDeleteResponse>().await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_file(&self, file_id: &str) -> anyhow::Result<impl futures_core::Stream<Item = reqwest::Result<Bytes>>> {
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.bytes_stream())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_file_direct(&self, file_id: &str) -> anyhow::Result<Bytes> {
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.bytes().await?)
|
||||||
|
}
|
||||||
|
}
|
@ -1,22 +1,14 @@
|
|||||||
use base64::{Engine, prelude::BASE64_STANDARD};
|
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use reqwest::{Body, multipart::{Part, Form}, Client};
|
use reqwest::{multipart::Form, Client};
|
||||||
use crate::{image::{ResponseFormat, ImageResponse, ImageSize}, context::{API_URL, Context}};
|
use crate::{image::{ResponseFormat, ImageResponse, ImageSize}, context::{API_URL, Context}, util::FileResource};
|
||||||
use tokio_util::codec::{BytesCodec, FramedRead};
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum ImageFile {
|
|
||||||
File(tokio::fs::File),
|
|
||||||
Data(Vec<u8>),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Builder)]
|
#[derive(Debug, Builder)]
|
||||||
#[builder(pattern = "owned")]
|
#[builder(pattern = "owned")]
|
||||||
pub struct ImageEditRequest {
|
pub struct ImageEditRequest {
|
||||||
#[builder(setter(into))]
|
#[builder(setter(into))]
|
||||||
pub image: ImageFile,
|
pub image: FileResource,
|
||||||
#[builder(setter(into, strip_option), default)]
|
#[builder(setter(into, strip_option), default)]
|
||||||
pub mask: Option<ImageFile>,
|
pub mask: Option<FileResource>,
|
||||||
#[builder(setter(into))]
|
#[builder(setter(into))]
|
||||||
pub prompt: String,
|
pub prompt: String,
|
||||||
#[builder(setter(into, strip_option), default)]
|
#[builder(setter(into, strip_option), default)]
|
||||||
@ -31,21 +23,11 @@ pub struct ImageEditRequest {
|
|||||||
pub size: Option<ImageSize>,
|
pub size: Option<ImageSize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn write_file(form: Form, file: ImageFile, name: impl Into<String>) -> Form {
|
|
||||||
let name = name.into();
|
|
||||||
match file {
|
|
||||||
ImageFile::File(file) =>
|
|
||||||
form.part(name.clone(), Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new()))).file_name(name)),
|
|
||||||
ImageFile::Data(data) =>
|
|
||||||
form.text(name, BASE64_STANDARD.encode(data.as_slice())),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Context {
|
impl Context {
|
||||||
pub async fn create_image_edit(&self, req: ImageEditRequest) -> anyhow::Result<ImageResponse> {
|
pub async fn create_image_edit(&self, req: ImageEditRequest) -> anyhow::Result<ImageResponse> {
|
||||||
let mut form = Form::new();
|
let mut form = Form::new();
|
||||||
form = form.text("prompt", req.prompt);
|
form = form.text("prompt", req.prompt);
|
||||||
form = write_file(form, req.image, "image");
|
form = req.image.write_file(form, "image");
|
||||||
|
|
||||||
if let Some(n) = req.n {
|
if let Some(n) = req.n {
|
||||||
form = form.text("n", n.to_string());
|
form = form.text("n", n.to_string());
|
||||||
@ -58,7 +40,7 @@ impl Context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(mask) = req.mask {
|
if let Some(mask) = req.mask {
|
||||||
form = write_file(form, mask, "mask");
|
form = mask.write_file(form, "mask");
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(temperature) = req.temperature {
|
if let Some(temperature) = req.temperature {
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use reqwest::{multipart::Form, Client};
|
use reqwest::{multipart::Form, Client};
|
||||||
|
|
||||||
use crate::{image_edit::{ImageFile, write_file}, image::{ImageSize, ResponseFormat, ImageResponse}, context::{API_URL, Context}};
|
use crate::{image::{ImageSize, ResponseFormat, ImageResponse}, context::{API_URL, Context}, util::FileResource};
|
||||||
|
|
||||||
#[derive(Debug, Builder)]
|
#[derive(Debug, Builder)]
|
||||||
#[builder(pattern = "owned")]
|
#[builder(pattern = "owned")]
|
||||||
pub struct ImageVariationRequest {
|
pub struct ImageVariationRequest {
|
||||||
#[builder(setter(into))]
|
#[builder(setter(into))]
|
||||||
pub image: ImageFile,
|
pub image: FileResource,
|
||||||
#[builder(setter(into, strip_option), default)]
|
#[builder(setter(into, strip_option), default)]
|
||||||
pub n: Option<u32>,
|
pub n: Option<u32>,
|
||||||
#[builder(setter(into, strip_option), default)]
|
#[builder(setter(into, strip_option), default)]
|
||||||
@ -22,7 +22,7 @@ pub struct ImageVariationRequest {
|
|||||||
impl Context {
|
impl Context {
|
||||||
pub async fn create_image_variation(&self, req: ImageVariationRequest) -> anyhow::Result<ImageResponse> {
|
pub async fn create_image_variation(&self, req: ImageVariationRequest) -> anyhow::Result<ImageResponse> {
|
||||||
let mut form = Form::new();
|
let mut form = Form::new();
|
||||||
form = write_file(form, req.image, "image");
|
form = req.image.write_file(form, "image");
|
||||||
|
|
||||||
if let Some(n) = req.n {
|
if let Some(n) = req.n {
|
||||||
form = form.text("n", n.to_string());
|
form = form.text("n", n.to_string());
|
||||||
|
@ -9,6 +9,9 @@ pub mod image_variation;
|
|||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
pub mod transcription;
|
pub mod transcription;
|
||||||
pub mod translation;
|
pub mod translation;
|
||||||
|
pub mod file;
|
||||||
|
|
||||||
|
pub mod util;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
@ -19,7 +22,7 @@ mod tests {
|
|||||||
use crate::completion::CompletionRequestBuilder;
|
use crate::completion::CompletionRequestBuilder;
|
||||||
use crate::image::{Image, ResponseFormat, ImageRequestBuilder};
|
use crate::image::{Image, ResponseFormat, ImageRequestBuilder};
|
||||||
use crate::edits::EditRequestBuilder;
|
use crate::edits::EditRequestBuilder;
|
||||||
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
|
use crate::image_edit::ImageEditRequestBuilder;
|
||||||
use crate::image_variation::ImageVariationRequestBuilder;
|
use crate::image_variation::ImageVariationRequestBuilder;
|
||||||
use crate::embedding::EmbeddingRequestBuilder;
|
use crate::embedding::EmbeddingRequestBuilder;
|
||||||
use crate::transcription::{TranscriptionRequestBuilder, AudioFile};
|
use crate::transcription::{TranscriptionRequestBuilder, AudioFile};
|
||||||
@ -132,7 +135,7 @@ mod tests {
|
|||||||
|
|
||||||
let image = ctx.create_image_edit(
|
let image = ctx.create_image_edit(
|
||||||
ImageEditRequestBuilder::default()
|
ImageEditRequestBuilder::default()
|
||||||
.image(ImageFile::File(File::open("clown.png").await.unwrap()))
|
.image(File::open("clown.png").await.unwrap())
|
||||||
.prompt("Blue nose")
|
.prompt("Blue nose")
|
||||||
.build()
|
.build()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@ -160,7 +163,7 @@ mod tests {
|
|||||||
|
|
||||||
let image = ctx.create_image_variation(
|
let image = ctx.create_image_variation(
|
||||||
ImageVariationRequestBuilder::default()
|
ImageVariationRequestBuilder::default()
|
||||||
.image(ImageFile::File(File::open("clown_original.png").await.unwrap()))
|
.image(File::open("clown_original.png").await.unwrap())
|
||||||
.build()
|
.build()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
).await;
|
).await;
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
use crate::context::{API_URL, Context};
|
use crate::{context::{API_URL, Context}, util::DataList};
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct Permission {
|
pub struct Permission {
|
||||||
@ -29,14 +29,9 @@ pub struct Model {
|
|||||||
pub parent: Option<String>,
|
pub parent: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub(crate) struct ModelList {
|
|
||||||
pub data: Vec<Model>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Context {
|
impl Context {
|
||||||
pub async fn get_models(&self) -> anyhow::Result<Vec<Model>> {
|
pub async fn get_models(&self) -> anyhow::Result<Vec<Model>> {
|
||||||
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::<ModelList>().await?.data)
|
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::<DataList<Model>>().await?.data)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_model(&self, model_id: &str) -> anyhow::Result<Model> {
|
pub async fn get_model(&self, model_id: &str) -> anyhow::Result<Model> {
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use reqwest::{multipart::{Form, Part}, Body, Client};
|
use reqwest::{multipart::Form, Client};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use tokio::fs::File;
|
use tokio::fs::File;
|
||||||
use tokio_util::codec::{FramedRead, BytesCodec};
|
|
||||||
|
|
||||||
use crate::context::{API_URL, Context};
|
use crate::{context::{API_URL, Context}, util::FileResource};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum AudioResponseFormat {
|
pub enum AudioResponseFormat {
|
||||||
@ -87,7 +86,7 @@ impl Context {
|
|||||||
pub async fn create_transcription(&self, req: TranscriptionRequest) -> anyhow::Result<TranscriptionResponse> {
|
pub async fn create_transcription(&self, req: TranscriptionRequest) -> anyhow::Result<TranscriptionResponse> {
|
||||||
let mut form = Form::new();
|
let mut form = Form::new();
|
||||||
let file_name = req.file.file_name();
|
let file_name = req.file.file_name();
|
||||||
form = form.part("file", Part::stream(Body::wrap_stream(FramedRead::new(req.file.file(), BytesCodec::new()))).file_name(file_name));
|
form = FileResource::from(req.file.file()).write_file_named(form, "file", file_name);
|
||||||
form = form.text("model", req.model);
|
form = form.text("model", req.model);
|
||||||
|
|
||||||
if let Some(response_format) = req.response_format {
|
if let Some(response_format) = req.response_format {
|
||||||
|
56
src/util.rs
Normal file
56
src/util.rs
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||||
|
use reqwest::{multipart::{Form, Part}, Body};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tokio_util::codec::{FramedRead, BytesCodec};
|
||||||
|
|
||||||
|
pub struct DataList<T: for<'d> Deserialize<'d>> {
|
||||||
|
pub data: Vec<T>,
|
||||||
|
/* pub object: "list", */
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de, T: for<'d> Deserialize<'d>> Deserialize<'de> for DataList<T> {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
Ok(Self {
|
||||||
|
data: Vec::<T>::deserialize(deserializer)?
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum FileResource {
|
||||||
|
File(tokio::fs::File),
|
||||||
|
Data(Vec<u8>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FileResource {
|
||||||
|
pub(crate) fn write_file_named(self, form: Form, part_name: impl Into<String>, file_name: impl Into<String>) -> Form {
|
||||||
|
match self {
|
||||||
|
FileResource::File(file) =>
|
||||||
|
form.part(part_name.into(), Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new()))).file_name(file_name.into())),
|
||||||
|
FileResource::Data(data) =>
|
||||||
|
form.part(part_name.into(), Part::bytes(BASE64_STANDARD.encode(data.as_slice()).as_bytes().to_owned()).file_name(file_name.into())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn write_file(self, form: Form, name: impl Into<String>) -> Form {
|
||||||
|
let name = name.into();
|
||||||
|
self.write_file_named(form, name.clone(), name)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<tokio::fs::File> for FileResource {
|
||||||
|
fn from(file: tokio::fs::File) -> Self {
|
||||||
|
Self::File(file)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Vec<u8>> for FileResource {
|
||||||
|
fn from(data: Vec<u8>) -> Self {
|
||||||
|
Self::Data(data)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user