Implement transcription
This commit is contained in:
parent
703dbc3565
commit
f63547c67a
BIN
sample_audio.mp3
Normal file
BIN
sample_audio.mp3
Normal file
Binary file not shown.
22
src/lib.rs
22
src/lib.rs
@ -7,6 +7,7 @@ pub mod image;
|
|||||||
pub mod image_edit;
|
pub mod image_edit;
|
||||||
pub mod image_variation;
|
pub mod image_variation;
|
||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
|
pub mod transcription;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
@ -20,6 +21,7 @@ mod tests {
|
|||||||
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
|
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
|
||||||
use crate::image_variation::ImageVariationRequestBuilder;
|
use crate::image_variation::ImageVariationRequestBuilder;
|
||||||
use crate::embedding::EmbeddingRequestBuilder;
|
use crate::embedding::EmbeddingRequestBuilder;
|
||||||
|
use crate::transcription::{TranscriptionRequestBuilder, AudioFile};
|
||||||
|
|
||||||
fn get_api() -> anyhow::Result<Context> {
|
fn get_api() -> anyhow::Result<Context> {
|
||||||
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
||||||
@ -194,4 +196,24 @@ mod tests {
|
|||||||
assert!(embeddings.as_ref().unwrap().data[0].embedding.len() > 0, "No embeddings found");
|
assert!(embeddings.as_ref().unwrap().data[0].embedding.len() > 0, "No embeddings found");
|
||||||
println!("Embeddings: {:?}", embeddings.unwrap().data[0].embedding);
|
println!("Embeddings: {:?}", embeddings.unwrap().data[0].embedding);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_transcription() {
|
||||||
|
let ctx = get_api();
|
||||||
|
assert!(ctx.is_ok(), "Could not load context");
|
||||||
|
let ctx = ctx.unwrap();
|
||||||
|
|
||||||
|
// Original script: "Hello. This is a sample piece of audio for which the whisper AI will generate a transcript"
|
||||||
|
// Expected result: "Hello, this is a sample piece of audio for which the Whisper AI will generate a transcript."
|
||||||
|
let transcription = ctx.create_transcription(
|
||||||
|
TranscriptionRequestBuilder::default()
|
||||||
|
.model("whisper-1")
|
||||||
|
.file(AudioFile::MP3(File::open("sample_audio.mp3").await.unwrap()))
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
).await;
|
||||||
|
|
||||||
|
assert!(transcription.is_ok(), "Could not get transcription: {}", transcription.unwrap_err());
|
||||||
|
println!("Transcription: {:?}", transcription.unwrap().text);
|
||||||
|
}
|
||||||
}
|
}
|
110
src/transcription.rs
Normal file
110
src/transcription.rs
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
use derive_builder::Builder;
|
||||||
|
use reqwest::{multipart::{Form, Part}, Body, Client};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tokio::fs::File;
|
||||||
|
use tokio_util::codec::{FramedRead, BytesCodec};
|
||||||
|
|
||||||
|
use crate::context::{API_URL, Context};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum AudioResponseFormat {
|
||||||
|
Text,
|
||||||
|
Json,
|
||||||
|
Srt,
|
||||||
|
Vtt,
|
||||||
|
VerboseJson,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum AudioFile {
|
||||||
|
MP3(File),
|
||||||
|
MP4(File),
|
||||||
|
MPEG(File),
|
||||||
|
MPGA(File),
|
||||||
|
WAV(File),
|
||||||
|
WEBM(File),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AudioFile {
|
||||||
|
fn file_name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
AudioFile::MP3(_) => "file.mp3",
|
||||||
|
AudioFile::MP4(_) => "file.mp4",
|
||||||
|
AudioFile::MPEG(_) => "file.mpeg",
|
||||||
|
AudioFile::MPGA(_) => "file.mpga",
|
||||||
|
AudioFile::WAV(_) => "file.wav",
|
||||||
|
AudioFile::WEBM(_) => "file.webm",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn file(self) -> File {
|
||||||
|
match self {
|
||||||
|
AudioFile::MP3(file) => file,
|
||||||
|
AudioFile::MP4(file) => file,
|
||||||
|
AudioFile::MPEG(file) => file,
|
||||||
|
AudioFile::MPGA(file) => file,
|
||||||
|
AudioFile::WAV(file) => file,
|
||||||
|
AudioFile::WEBM(file) => file,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToString for AudioResponseFormat {
|
||||||
|
fn to_string(&self) -> String {
|
||||||
|
match self {
|
||||||
|
AudioResponseFormat::Text => "text",
|
||||||
|
AudioResponseFormat::Json => "json",
|
||||||
|
AudioResponseFormat::Srt => "srt",
|
||||||
|
AudioResponseFormat::Vtt => "vtt",
|
||||||
|
AudioResponseFormat::VerboseJson => "verbose_json",
|
||||||
|
}.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Builder)]
|
||||||
|
#[builder(pattern = "owned")]
|
||||||
|
pub struct TranscriptionRequest {
|
||||||
|
#[builder(setter(into))]
|
||||||
|
pub file: AudioFile,
|
||||||
|
#[builder(setter(into))]
|
||||||
|
pub model: String,
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub prompt: Option<String>,
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub response_format: Option<AudioResponseFormat>,
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub language: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct TranscriptionResponse {
|
||||||
|
pub text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Context {
|
||||||
|
pub async fn create_transcription(&self, req: TranscriptionRequest) -> anyhow::Result<TranscriptionResponse> {
|
||||||
|
let mut form = Form::new();
|
||||||
|
let file_name = req.file.file_name();
|
||||||
|
form = form.part("file", Part::stream(Body::wrap_stream(FramedRead::new(req.file.file(), BytesCodec::new()))).file_name(file_name));
|
||||||
|
form = form.text("model", req.model);
|
||||||
|
|
||||||
|
if let Some(response_format) = req.response_format {
|
||||||
|
form = form.text("response_format", response_format.to_string());
|
||||||
|
}
|
||||||
|
if let Some(prompt) = req.prompt {
|
||||||
|
form = form.text("prompt", prompt);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(temperature) = req.temperature {
|
||||||
|
form = form.text("temperature", temperature.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(language) = req.language {
|
||||||
|
form = form.text("language", language.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/audio/transcriptions")).multipart(form)).send().await?.json::<TranscriptionResponse>().await?)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user