diff --git a/sample_audio.mp3 b/sample_audio.mp3 new file mode 100644 index 0000000..b405f8d Binary files /dev/null and b/sample_audio.mp3 differ diff --git a/src/lib.rs b/src/lib.rs index 1827bb3..706bd35 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ pub mod image; pub mod image_edit; pub mod image_variation; pub mod embedding; +pub mod transcription; #[cfg(test)] mod tests { @@ -20,6 +21,7 @@ mod tests { use crate::image_edit::{ImageEditRequestBuilder, ImageFile}; use crate::image_variation::ImageVariationRequestBuilder; use crate::embedding::EmbeddingRequestBuilder; + use crate::transcription::{TranscriptionRequestBuilder, AudioFile}; fn get_api() -> anyhow::Result { Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string())) @@ -194,4 +196,24 @@ mod tests { assert!(embeddings.as_ref().unwrap().data[0].embedding.len() > 0, "No embeddings found"); println!("Embeddings: {:?}", embeddings.unwrap().data[0].embedding); } + + #[tokio::test] + async fn test_transcription() { + let ctx = get_api(); + assert!(ctx.is_ok(), "Could not load context"); + let ctx = ctx.unwrap(); + + // Original script: "Hello. This is a sample piece of audio for which the whisper AI will generate a transcript" + // Expected result: "Hello, this is a sample piece of audio for which the Whisper AI will generate a transcript." + let transcription = ctx.create_transcription( + TranscriptionRequestBuilder::default() + .model("whisper-1") + .file(AudioFile::MP3(File::open("sample_audio.mp3").await.unwrap())) + .build() + .unwrap() + ).await; + + assert!(transcription.is_ok(), "Could not get transcription: {}", transcription.unwrap_err()); + println!("Transcription: {:?}", transcription.unwrap().text); + } } \ No newline at end of file diff --git a/src/transcription.rs b/src/transcription.rs new file mode 100644 index 0000000..2383dc6 --- /dev/null +++ b/src/transcription.rs @@ -0,0 +1,110 @@ +use derive_builder::Builder; +use reqwest::{multipart::{Form, Part}, Body, Client}; +use serde::Deserialize; +use tokio::fs::File; +use tokio_util::codec::{FramedRead, BytesCodec}; + +use crate::context::{API_URL, Context}; + +#[derive(Debug, Clone)] +pub enum AudioResponseFormat { + Text, + Json, + Srt, + Vtt, + VerboseJson, +} + +#[derive(Debug)] +pub enum AudioFile { + MP3(File), + MP4(File), + MPEG(File), + MPGA(File), + WAV(File), + WEBM(File), +} + +impl AudioFile { + fn file_name(&self) -> &'static str { + match self { + AudioFile::MP3(_) => "file.mp3", + AudioFile::MP4(_) => "file.mp4", + AudioFile::MPEG(_) => "file.mpeg", + AudioFile::MPGA(_) => "file.mpga", + AudioFile::WAV(_) => "file.wav", + AudioFile::WEBM(_) => "file.webm", + } + } + + fn file(self) -> File { + match self { + AudioFile::MP3(file) => file, + AudioFile::MP4(file) => file, + AudioFile::MPEG(file) => file, + AudioFile::MPGA(file) => file, + AudioFile::WAV(file) => file, + AudioFile::WEBM(file) => file, + } + } +} + +impl ToString for AudioResponseFormat { + fn to_string(&self) -> String { + match self { + AudioResponseFormat::Text => "text", + AudioResponseFormat::Json => "json", + AudioResponseFormat::Srt => "srt", + AudioResponseFormat::Vtt => "vtt", + AudioResponseFormat::VerboseJson => "verbose_json", + }.to_string() + } +} + +#[derive(Debug, Builder)] +#[builder(pattern = "owned")] +pub struct TranscriptionRequest { + #[builder(setter(into))] + pub file: AudioFile, + #[builder(setter(into))] + pub model: String, + #[builder(setter(into, strip_option), default)] + pub prompt: Option, + #[builder(setter(into, strip_option), default)] + pub response_format: Option, + #[builder(setter(into, strip_option), default)] + pub temperature: Option, + #[builder(setter(into, strip_option), default)] + pub language: Option, +} + +#[derive(Debug, Deserialize)] +pub struct TranscriptionResponse { + pub text: String, +} + +impl Context { + pub async fn create_transcription(&self, req: TranscriptionRequest) -> anyhow::Result { + let mut form = Form::new(); + let file_name = req.file.file_name(); + form = form.part("file", Part::stream(Body::wrap_stream(FramedRead::new(req.file.file(), BytesCodec::new()))).file_name(file_name)); + form = form.text("model", req.model); + + if let Some(response_format) = req.response_format { + form = form.text("response_format", response_format.to_string()); + } + if let Some(prompt) = req.prompt { + form = form.text("prompt", prompt); + } + + if let Some(temperature) = req.temperature { + form = form.text("temperature", temperature.to_string()); + } + + if let Some(language) = req.language { + form = form.text("language", language.to_string()); + } + + Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/audio/transcriptions")).multipart(form)).send().await?.json::().await?) + } +} \ No newline at end of file