Implement transcription

This commit is contained in:
Gabriel Tofvesson 2023-03-18 03:43:07 +01:00
parent 703dbc3565
commit f63547c67a
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
3 changed files with 132 additions and 0 deletions

BIN
sample_audio.mp3 Normal file

Binary file not shown.

View File

@ -7,6 +7,7 @@ pub mod image;
pub mod image_edit;
pub mod image_variation;
pub mod embedding;
pub mod transcription;
#[cfg(test)]
mod tests {
@ -20,6 +21,7 @@ mod tests {
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
use crate::image_variation::ImageVariationRequestBuilder;
use crate::embedding::EmbeddingRequestBuilder;
use crate::transcription::{TranscriptionRequestBuilder, AudioFile};
fn get_api() -> anyhow::Result<Context> {
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
@ -194,4 +196,24 @@ mod tests {
assert!(embeddings.as_ref().unwrap().data[0].embedding.len() > 0, "No embeddings found");
println!("Embeddings: {:?}", embeddings.unwrap().data[0].embedding);
}
/// Integration test: uploads `sample_audio.mp3` for transcription with the `whisper-1` model
/// and prints the returned text. Needs a readable `apikey.txt` (via `get_api`) and network access.
#[tokio::test]
async fn test_transcription() {
    let ctx = get_api();
    assert!(ctx.is_ok(), "Could not load context");
    let ctx = ctx.unwrap();

    // Original script: "Hello. This is a sample piece of audio for which the whisper AI will generate a transcript"
    // Expected result: "Hello, this is a sample piece of audio for which the Whisper AI will generate a transcript."
    let audio = File::open("sample_audio.mp3").await.unwrap();
    let request = TranscriptionRequestBuilder::default()
        .model("whisper-1")
        .file(AudioFile::MP3(audio))
        .build()
        .unwrap();

    let transcription = ctx.create_transcription(request).await;
    assert!(transcription.is_ok(), "Could not get transcription: {}", transcription.unwrap_err());
    println!("Transcription: {:?}", transcription.unwrap().text);
}
}

110
src/transcription.rs Normal file
View File

@ -0,0 +1,110 @@
use derive_builder::Builder;
use reqwest::{multipart::{Form, Part}, Body, Client};
use serde::Deserialize;
use tokio::fs::File;
use tokio_util::codec::{FramedRead, BytesCodec};
use crate::context::{API_URL, Context};
/// Output formats that can be requested for a transcription.
///
/// Calling `to_string()` on a variant yields the lowercase identifier that is placed in the
/// `response_format` field of the transcription request.
#[derive(Debug, Clone)]
pub enum AudioResponseFormat {
/// Sent as `text`.
Text,
/// Sent as `json`.
Json,
/// Sent as `srt`.
Srt,
/// Sent as `vtt`.
Vtt,
/// Sent as `verbose_json`.
VerboseJson,
}
/// An open audio file (`tokio::fs::File`) tagged with its format.
///
/// The variant only selects the placeholder file name (and thereby the extension) that is
/// attached to the upload; the file contents themselves are not inspected here.
#[derive(Debug)]
pub enum AudioFile {
/// Uploaded under the name `file.mp3`.
MP3(File),
/// Uploaded under the name `file.mp4`.
MP4(File),
/// Uploaded under the name `file.mpeg`.
MPEG(File),
/// Uploaded under the name `file.mpga`.
MPGA(File),
/// Uploaded under the name `file.wav`.
WAV(File),
/// Uploaded under the name `file.webm`.
WEBM(File),
}
impl AudioFile {
fn file_name(&self) -> &'static str {
match self {
AudioFile::MP3(_) => "file.mp3",
AudioFile::MP4(_) => "file.mp4",
AudioFile::MPEG(_) => "file.mpeg",
AudioFile::MPGA(_) => "file.mpga",
AudioFile::WAV(_) => "file.wav",
AudioFile::WEBM(_) => "file.webm",
}
}
fn file(self) -> File {
match self {
AudioFile::MP3(file) => file,
AudioFile::MP4(file) => file,
AudioFile::MPEG(file) => file,
AudioFile::MPGA(file) => file,
AudioFile::WAV(file) => file,
AudioFile::WEBM(file) => file,
}
}
}
impl ToString for AudioResponseFormat {
fn to_string(&self) -> String {
match self {
AudioResponseFormat::Text => "text",
AudioResponseFormat::Json => "json",
AudioResponseFormat::Srt => "srt",
AudioResponseFormat::Vtt => "vtt",
AudioResponseFormat::VerboseJson => "verbose_json",
}.to_string()
}
}
/// Parameters for a transcription request, consumed by `Context::create_transcription`.
///
/// Construct it through the derived `TranscriptionRequestBuilder`, which uses the owned
/// builder pattern (each setter takes and returns the builder by value). `file` and `model`
/// have no default; the remaining fields default to `None` and are only sent when set.
#[derive(Debug, Builder)]
#[builder(pattern = "owned")]
pub struct TranscriptionRequest {
/// Audio to transcribe; streamed as the `file` part of the multipart form.
#[builder(setter(into))]
pub file: AudioFile,
/// Model identifier, sent as the `model` form field (the crate's test uses `"whisper-1"`).
#[builder(setter(into))]
pub model: String,
/// Optional text sent as the `prompt` form field.
#[builder(setter(into, strip_option), default)]
pub prompt: Option<String>,
/// Optional output format, sent as the `response_format` form field.
#[builder(setter(into, strip_option), default)]
pub response_format: Option<AudioResponseFormat>,
/// Optional value sent (stringified) as the `temperature` form field.
#[builder(setter(into, strip_option), default)]
pub temperature: Option<f64>,
/// Optional value sent as the `language` form field.
// NOTE(review): presumably the language code of the input audio — confirm against the API docs.
#[builder(setter(into, strip_option), default)]
pub language: Option<String>,
}
/// Result of a transcription request, as returned by `Context::create_transcription`.
#[derive(Debug, Deserialize)]
pub struct TranscriptionResponse {
/// The transcription text returned by the API.
pub text: String,
}
impl Context {
pub async fn create_transcription(&self, req: TranscriptionRequest) -> anyhow::Result<TranscriptionResponse> {
let mut form = Form::new();
let file_name = req.file.file_name();
form = form.part("file", Part::stream(Body::wrap_stream(FramedRead::new(req.file.file(), BytesCodec::new()))).file_name(file_name));
form = form.text("model", req.model);
if let Some(response_format) = req.response_format {
form = form.text("response_format", response_format.to_string());
}
if let Some(prompt) = req.prompt {
form = form.text("prompt", prompt);
}
if let Some(temperature) = req.temperature {
form = form.text("temperature", temperature.to_string());
}
if let Some(language) = req.language {
form = form.text("language", language.to_string());
}
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/audio/transcriptions")).multipart(form)).send().await?.json::<TranscriptionResponse>().await?)
}
}