Implement file upload

This commit is contained in:
Gabriel Tofvesson 2023-03-18 12:16:30 +01:00
parent bcaa49c4c6
commit 82bfadc2fe
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
9 changed files with 132 additions and 41 deletions

2
Cargo.lock generated
View File

@ -544,7 +544,9 @@ version = "0.1.0"
dependencies = [
"anyhow",
"base64",
"bytes",
"derive_builder",
"futures-core",
"reqwest",
"serde",
"serde_json",

View File

@ -8,7 +8,9 @@ edition = "2021"
[dependencies]
anyhow = "1.0.69"
base64 = "0.21.0"
bytes = "1.4.0"
derive_builder = "0.12.0"
futures-core = "0.3.27"
reqwest = { version = "0.11.14", features = [ "json", "multipart", "stream" ] }
serde = { version = "1.0.156", features = ["derive"] }
serde_json = "1.0.94"

52
src/file.rs Normal file
View File

@ -0,0 +1,52 @@
use bytes::Bytes;
use reqwest::{Client, multipart::Form};
use serde::Deserialize;
use crate::{context::{API_URL, Context}, util::{DataList, FileResource}};
#[derive(Debug, Deserialize)]
pub struct FileInfo {
pub id: String,
/* pub object: "file", */
pub bytes: u64,
pub created_at: u64,
pub filename: String,
pub purpose: String,
}
#[derive(Debug, Deserialize)]
pub struct FileDeleteResponse {
pub id: String,
/* pub object: "file", */
pub deleted: bool,
}
impl Context {
pub async fn get_files(&self) -> anyhow::Result<Vec<FileInfo>> {
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files"))).send().await?.json::<DataList<FileInfo>>().await?.data)
}
pub async fn upload_file(&self, file: FileResource, file_name: String, purpose: String) -> anyhow::Result<FileInfo> {
Ok(
self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/files")))
.multipart(file.write_file_named(Form::new().text("purpose", purpose), "file", file_name))
.send()
.await?
.json::<FileInfo>()
.await?
)
}
pub async fn delete_file(&self, file_id: &str) -> anyhow::Result<FileDeleteResponse> {
Ok(self.with_auth(Client::builder().build()?.delete(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.json::<FileDeleteResponse>().await?)
}
pub async fn get_file(&self, file_id: &str) -> anyhow::Result<impl futures_core::Stream<Item = reqwest::Result<Bytes>>> {
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.bytes_stream())
}
pub async fn get_file_direct(&self, file_id: &str) -> anyhow::Result<Bytes> {
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/files/{file_id}"))).send().await?.bytes().await?)
}
}

View File

@ -1,22 +1,14 @@
use base64::{Engine, prelude::BASE64_STANDARD};
use derive_builder::Builder;
use reqwest::{Body, multipart::{Part, Form}, Client};
use crate::{image::{ResponseFormat, ImageResponse, ImageSize}, context::{API_URL, Context}};
use tokio_util::codec::{BytesCodec, FramedRead};
#[derive(Debug)]
pub enum ImageFile {
File(tokio::fs::File),
Data(Vec<u8>),
}
use reqwest::{multipart::Form, Client};
use crate::{image::{ResponseFormat, ImageResponse, ImageSize}, context::{API_URL, Context}, util::FileResource};
#[derive(Debug, Builder)]
#[builder(pattern = "owned")]
pub struct ImageEditRequest {
#[builder(setter(into))]
pub image: ImageFile,
pub image: FileResource,
#[builder(setter(into, strip_option), default)]
pub mask: Option<ImageFile>,
pub mask: Option<FileResource>,
#[builder(setter(into))]
pub prompt: String,
#[builder(setter(into, strip_option), default)]
@ -31,21 +23,11 @@ pub struct ImageEditRequest {
pub size: Option<ImageSize>,
}
pub(crate) fn write_file(form: Form, file: ImageFile, name: impl Into<String>) -> Form {
let name = name.into();
match file {
ImageFile::File(file) =>
form.part(name.clone(), Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new()))).file_name(name)),
ImageFile::Data(data) =>
form.text(name, BASE64_STANDARD.encode(data.as_slice())),
}
}
impl Context {
pub async fn create_image_edit(&self, req: ImageEditRequest) -> anyhow::Result<ImageResponse> {
let mut form = Form::new();
form = form.text("prompt", req.prompt);
form = write_file(form, req.image, "image");
form = req.image.write_file(form, "image");
if let Some(n) = req.n {
form = form.text("n", n.to_string());
@ -58,7 +40,7 @@ impl Context {
}
if let Some(mask) = req.mask {
form = write_file(form, mask, "mask");
form = mask.write_file(form, "mask");
}
if let Some(temperature) = req.temperature {

View File

@ -1,13 +1,13 @@
use derive_builder::Builder;
use reqwest::{multipart::Form, Client};
use crate::{image_edit::{ImageFile, write_file}, image::{ImageSize, ResponseFormat, ImageResponse}, context::{API_URL, Context}};
use crate::{image::{ImageSize, ResponseFormat, ImageResponse}, context::{API_URL, Context}, util::FileResource};
#[derive(Debug, Builder)]
#[builder(pattern = "owned")]
pub struct ImageVariationRequest {
#[builder(setter(into))]
pub image: ImageFile,
pub image: FileResource,
#[builder(setter(into, strip_option), default)]
pub n: Option<u32>,
#[builder(setter(into, strip_option), default)]
@ -22,7 +22,7 @@ pub struct ImageVariationRequest {
impl Context {
pub async fn create_image_variation(&self, req: ImageVariationRequest) -> anyhow::Result<ImageResponse> {
let mut form = Form::new();
form = write_file(form, req.image, "image");
form = req.image.write_file(form, "image");
if let Some(n) = req.n {
form = form.text("n", n.to_string());

View File

@ -9,6 +9,9 @@ pub mod image_variation;
pub mod embedding;
pub mod transcription;
pub mod translation;
pub mod file;
pub mod util;
#[cfg(test)]
mod tests {
@ -19,7 +22,7 @@ mod tests {
use crate::completion::CompletionRequestBuilder;
use crate::image::{Image, ResponseFormat, ImageRequestBuilder};
use crate::edits::EditRequestBuilder;
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
use crate::image_edit::ImageEditRequestBuilder;
use crate::image_variation::ImageVariationRequestBuilder;
use crate::embedding::EmbeddingRequestBuilder;
use crate::transcription::{TranscriptionRequestBuilder, AudioFile};
@ -132,7 +135,7 @@ mod tests {
let image = ctx.create_image_edit(
ImageEditRequestBuilder::default()
.image(ImageFile::File(File::open("clown.png").await.unwrap()))
.image(File::open("clown.png").await.unwrap())
.prompt("Blue nose")
.build()
.unwrap()
@ -160,7 +163,7 @@ mod tests {
let image = ctx.create_image_variation(
ImageVariationRequestBuilder::default()
.image(ImageFile::File(File::open("clown_original.png").await.unwrap()))
.image(File::open("clown_original.png").await.unwrap())
.build()
.unwrap()
).await;

View File

@ -1,7 +1,7 @@
use reqwest::Client;
use serde::Deserialize;
use crate::context::{API_URL, Context};
use crate::{context::{API_URL, Context}, util::DataList};
#[derive(Debug, Deserialize)]
pub struct Permission {
@ -29,14 +29,9 @@ pub struct Model {
pub parent: Option<String>,
}
#[derive(Debug, Deserialize)]
pub(crate) struct ModelList {
pub data: Vec<Model>,
}
impl Context {
pub async fn get_models(&self) -> anyhow::Result<Vec<Model>> {
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::<ModelList>().await?.data)
Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::<DataList<Model>>().await?.data)
}
pub async fn get_model(&self, model_id: &str) -> anyhow::Result<Model> {

View File

@ -1,10 +1,9 @@
use derive_builder::Builder;
use reqwest::{multipart::{Form, Part}, Body, Client};
use reqwest::{multipart::Form, Client};
use serde::Deserialize;
use tokio::fs::File;
use tokio_util::codec::{FramedRead, BytesCodec};
use crate::context::{API_URL, Context};
use crate::{context::{API_URL, Context}, util::FileResource};
#[derive(Debug, Clone)]
pub enum AudioResponseFormat {
@ -87,7 +86,7 @@ impl Context {
pub async fn create_transcription(&self, req: TranscriptionRequest) -> anyhow::Result<TranscriptionResponse> {
let mut form = Form::new();
let file_name = req.file.file_name();
form = form.part("file", Part::stream(Body::wrap_stream(FramedRead::new(req.file.file(), BytesCodec::new()))).file_name(file_name));
form = FileResource::from(req.file.file()).write_file_named(form, "file", file_name);
form = form.text("model", req.model);
if let Some(response_format) = req.response_format {

56
src/util.rs Normal file
View File

@ -0,0 +1,56 @@
use base64::{prelude::BASE64_STANDARD, Engine};
use reqwest::{multipart::{Form, Part}, Body};
use serde::Deserialize;
use tokio_util::codec::{FramedRead, BytesCodec};
pub struct DataList<T: for<'d> Deserialize<'d>> {
pub data: Vec<T>,
/* pub object: "list", */
}
impl<'de, T: for<'d> Deserialize<'d>> Deserialize<'de> for DataList<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(Self {
data: Vec::<T>::deserialize(deserializer)?
})
}
}
#[derive(Debug)]
pub enum FileResource {
File(tokio::fs::File),
Data(Vec<u8>),
}
impl FileResource {
pub(crate) fn write_file_named(self, form: Form, part_name: impl Into<String>, file_name: impl Into<String>) -> Form {
match self {
FileResource::File(file) =>
form.part(part_name.into(), Part::stream(Body::wrap_stream(FramedRead::new(file, BytesCodec::new()))).file_name(file_name.into())),
FileResource::Data(data) =>
form.part(part_name.into(), Part::bytes(BASE64_STANDARD.encode(data.as_slice()).as_bytes().to_owned()).file_name(file_name.into())),
}
}
pub(crate) fn write_file(self, form: Form, name: impl Into<String>) -> Form {
let name = name.into();
self.write_file_named(form, name.clone(), name)
}
}
impl From<tokio::fs::File> for FileResource {
fn from(file: tokio::fs::File) -> Self {
Self::File(file)
}
}
impl From<Vec<u8>> for FileResource {
fn from(data: Vec<u8>) -> Self {
Self::Data(data)
}
}