Implement translation

This commit is contained in:
Gabriel Tofvesson 2023-03-18 03:47:47 +01:00
parent f63547c67a
commit bcaa49c4c6
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
3 changed files with 50 additions and 2 deletions

View File

@ -8,6 +8,7 @@ pub mod image_edit;
pub mod image_variation;
pub mod embedding;
pub mod transcription;
pub mod translation;
#[cfg(test)]
mod tests {
@ -216,4 +217,6 @@ mod tests {
assert!(transcription.is_ok(), "Could not get transcription: {}", transcription.unwrap_err());
println!("Transcription: {:?}", transcription.unwrap().text);
}
// TODO: Add translation test
}

View File

@ -26,7 +26,7 @@ pub enum AudioFile {
}
impl AudioFile {
fn file_name(&self) -> &'static str {
pub(crate) fn file_name(&self) -> &'static str {
match self {
AudioFile::MP3(_) => "file.mp3",
AudioFile::MP4(_) => "file.mp4",
@ -37,7 +37,7 @@ impl AudioFile {
}
}
fn file(self) -> File {
pub(crate) fn file(self) -> File {
match self {
AudioFile::MP3(file) => file,
AudioFile::MP4(file) => file,

45
src/translation.rs Normal file
View File

@ -0,0 +1,45 @@
use derive_builder::Builder;
use reqwest::{multipart::{Form, Part}, Body, Client};
use tokio_util::codec::{FramedRead, BytesCodec};
use crate::{context::{API_URL, Context}, transcription::TranscriptionResponse};
use crate::transcription::{AudioFile, AudioResponseFormat};
type TranslationResponse = TranscriptionResponse;
#[derive(Debug, Builder)]
#[builder(pattern = "owned")]
pub struct TranslationRequest {
#[builder(setter(into))]
pub file: AudioFile,
#[builder(setter(into))]
pub model: String,
#[builder(setter(into, strip_option), default)]
pub prompt: Option<String>,
#[builder(setter(into, strip_option), default)]
pub response_format: Option<AudioResponseFormat>,
#[builder(setter(into, strip_option), default)]
pub temperature: Option<f64>,
}
impl Context {
pub async fn create_translation(&self, req: TranslationRequest) -> anyhow::Result<TranslationResponse> {
let mut form = Form::new();
let file_name = req.file.file_name();
form = form.part("file", Part::stream(Body::wrap_stream(FramedRead::new(req.file.file(), BytesCodec::new()))).file_name(file_name));
form = form.text("model", req.model);
if let Some(response_format) = req.response_format {
form = form.text("response_format", response_format.to_string());
}
if let Some(prompt) = req.prompt {
form = form.text("prompt", prompt);
}
if let Some(temperature) = req.temperature {
form = form.text("temperature", temperature.to_string());
}
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/audio/translations")).multipart(form)).send().await?.json::<TranslationResponse>().await?)
}
}