From bcaa49c4c653663200d66a9b53ee8b9a3c0507d6 Mon Sep 17 00:00:00 2001 From: Gabriel Tofvesson Date: Sat, 18 Mar 2023 03:47:47 +0100 Subject: [PATCH] Implement translation --- src/lib.rs | 3 +++ src/transcription.rs | 4 ++-- src/translation.rs | 45 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 src/translation.rs diff --git a/src/lib.rs b/src/lib.rs index 706bd35..86c4ea9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ pub mod image_edit; pub mod image_variation; pub mod embedding; pub mod transcription; +pub mod translation; #[cfg(test)] mod tests { @@ -216,4 +217,6 @@ mod tests { assert!(transcription.is_ok(), "Could not get transcription: {}", transcription.unwrap_err()); println!("Transcription: {:?}", transcription.unwrap().text); } + + // TODO: Add translation test } \ No newline at end of file diff --git a/src/transcription.rs b/src/transcription.rs index 2383dc6..6bc4eaa 100644 --- a/src/transcription.rs +++ b/src/transcription.rs @@ -26,7 +26,7 @@ pub enum AudioFile { } impl AudioFile { - fn file_name(&self) -> &'static str { + pub(crate) fn file_name(&self) -> &'static str { match self { AudioFile::MP3(_) => "file.mp3", AudioFile::MP4(_) => "file.mp4", @@ -37,7 +37,7 @@ impl AudioFile { } } - fn file(self) -> File { + pub(crate) fn file(self) -> File { match self { AudioFile::MP3(file) => file, AudioFile::MP4(file) => file, diff --git a/src/translation.rs b/src/translation.rs new file mode 100644 index 0000000..489438c --- /dev/null +++ b/src/translation.rs @@ -0,0 +1,45 @@ +use derive_builder::Builder; +use reqwest::{multipart::{Form, Part}, Body, Client}; +use tokio_util::codec::{FramedRead, BytesCodec}; + +use crate::{context::{API_URL, Context}, transcription::TranscriptionResponse}; +use crate::transcription::{AudioFile, AudioResponseFormat}; + +type TranslationResponse = TranscriptionResponse; + +#[derive(Debug, Builder)] +#[builder(pattern = "owned")] +pub struct TranslationRequest { + #[builder(setter(into))] + pub file: AudioFile, + #[builder(setter(into))] + pub model: String, + #[builder(setter(into, strip_option), default)] + pub prompt: Option, + #[builder(setter(into, strip_option), default)] + pub response_format: Option, + #[builder(setter(into, strip_option), default)] + pub temperature: Option, +} + +impl Context { + pub async fn create_translation(&self, req: TranslationRequest) -> anyhow::Result { + let mut form = Form::new(); + let file_name = req.file.file_name(); + form = form.part("file", Part::stream(Body::wrap_stream(FramedRead::new(req.file.file(), BytesCodec::new()))).file_name(file_name)); + form = form.text("model", req.model); + + if let Some(response_format) = req.response_format { + form = form.text("response_format", response_format.to_string()); + } + if let Some(prompt) = req.prompt { + form = form.text("prompt", prompt); + } + + if let Some(temperature) = req.temperature { + form = form.text("temperature", temperature.to_string()); + } + + Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/audio/translations")).multipart(form)).send().await?.json::().await?) + } +} \ No newline at end of file