Implement voice functions

This commit is contained in:
Gabriel Tofvesson 2023-02-28 03:51:47 +01:00
parent 241b70724e
commit 90c0090947
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
8 changed files with 246 additions and 7 deletions

20
Cargo.lock generated
View File

@ -557,6 +557,16 @@ version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
[[package]]
name = "mime_guess"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
dependencies = [
"mime",
"unicase",
]
[[package]] [[package]]
name = "miniz_oxide" name = "miniz_oxide"
version = "0.6.2" version = "0.6.2"
@ -786,6 +796,7 @@ dependencies = [
"js-sys", "js-sys",
"log", "log",
"mime", "mime",
"mime_guess",
"native-tls", "native-tls",
"once_cell", "once_cell",
"percent-encoding", "percent-encoding",
@ -1114,6 +1125,15 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba"
[[package]]
name = "unicase"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
dependencies = [
"version_check",
]
[[package]] [[package]]
name = "unicode-bidi" name = "unicode-bidi"
version = "0.3.10" version = "0.3.10"

View File

@ -7,7 +7,7 @@ edition = "2021"
[dependencies] [dependencies]
bytes = "1.4.0" bytes = "1.4.0"
reqwest = { version = "0.11", features = ["json"] } reqwest = { version = "0.11", features = ["json", "multipart"] }
serde = { version = "1.0.152", features = ["std", "derive"] } serde = { version = "1.0.152", features = ["std", "derive"] }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
zip = "0.6.4" zip = "0.6.4"

View File

@ -1,3 +1,4 @@
pub mod user; pub mod user;
pub mod history; pub mod history;
pub mod tts; pub mod tts;
pub mod voice;

132
src/api/voice.rs Normal file
View File

@ -0,0 +1,132 @@
use reqwest::multipart::Part;
use crate::{elevenlabs_api::ElevenLabsAPI, model::{voice::{Voice, VoiceSettings, VoiceCreation, VoiceId}, error::APIError}};
impl ElevenLabsAPI {
pub async fn get_voices(&self) -> Result<Vec<Voice>, APIError> {
let response = self.get(crate::elevenlabs_api::voice::GET::List)?.send().await?;
if response.status().is_success() {
Ok(response.json().await?)
} else {
let error: crate::model::error::HTTPValidationError = response.json().await?;
Err(crate::model::error::APIError::HTTPError(error))
}
}
pub async fn get_default_voice_settings(&self) -> Result<VoiceSettings, APIError> {
let response = self.get(crate::elevenlabs_api::voice::GET::DefaultSettings)?.send().await?;
if response.status().is_success() {
Ok(response.json().await?)
} else {
let error: crate::model::error::HTTPValidationError = response.json().await?;
Err(crate::model::error::APIError::HTTPError(error))
}
}
pub async fn get_voice_settings(&self, voice_id: String) -> Result<VoiceSettings, APIError> {
let response = self.get(crate::elevenlabs_api::voice::GET::Settings { voice_id })?.send().await?;
if response.status().is_success() {
Ok(response.json().await?)
} else {
let error: crate::model::error::HTTPValidationError = response.json().await?;
Err(crate::model::error::APIError::HTTPError(error))
}
}
pub async fn get_voice(&self, voice_id: String) -> Result<Voice, APIError> {
let response = self.get(crate::elevenlabs_api::voice::GET::Voice { voice_id })?.send().await?;
if response.status().is_success() {
Ok(response.json().await?)
} else {
let error: crate::model::error::HTTPValidationError = response.json().await?;
Err(crate::model::error::APIError::HTTPError(error))
}
}
pub async fn delete_voice(&self, voice_id: String) -> Result<String, APIError> {
let response = self.delete(crate::elevenlabs_api::voice::DELETE::Voice { voice_id })?.send().await?;
if response.status().is_success() {
Ok(response.text().await?)
} else {
let error: crate::model::error::HTTPValidationError = response.json().await?;
Err(crate::model::error::APIError::HTTPError(error))
}
}
pub async fn edit_voice_settings(&self, voice_id: String, settings: VoiceSettings) -> Result<String, APIError> {
let response = self.post(crate::elevenlabs_api::voice::POST::EditSettings { voice_id })?.json(&settings).send().await?;
if response.status().is_success() {
Ok(response.text().await?)
} else {
let error: crate::model::error::HTTPValidationError = response.json().await?;
Err(crate::model::error::APIError::HTTPError(error))
}
}
pub async fn add_voice(&self, voice: VoiceCreation) -> Result<VoiceId, APIError> {
let mut form = reqwest::multipart::Form::new().text("name", voice.name);
for (name, file) in voice.files {
form = form.part("files", Part::bytes(file).file_name(name));
}
let response = self.post(crate::elevenlabs_api::voice::POST::AddVoice)?.multipart(form).send().await?;
if response.status().is_success() {
Ok(response.json().await?)
} else {
let error: crate::model::error::HTTPValidationError = response.json().await?;
Err(crate::model::error::APIError::HTTPError(error))
}
}
pub async fn edit_voice(&self, voice_id: String, voice: VoiceCreation) -> Result<String, APIError> {
let mut form = reqwest::multipart::Form::new().text("name", voice.name);
for (name, file) in voice.files {
form = form.part("files", Part::bytes(file).file_name(name));
}
let response = self.post(crate::elevenlabs_api::voice::POST::EditVoice { voice_id })?.multipart(form).send().await?;
if response.status().is_success() {
Ok(response.text().await?)
} else {
let error: crate::model::error::HTTPValidationError = response.json().await?;
Err(crate::model::error::APIError::HTTPError(error))
}
}
pub async fn add_professional_voice(&self, voice: VoiceCreation) -> Result<VoiceId, APIError> {
let mut form = reqwest::multipart::Form::new().text("name", voice.name);
for (name, file) in voice.files {
form = form.part("files", Part::bytes(file).file_name(name));
}
let response = self.post(crate::elevenlabs_api::voice::POST::AddProfessionalVoice)?.multipart(form).send().await?;
if response.status().is_success() {
Ok(response.json().await?)
} else {
let error: crate::model::error::HTTPValidationError = response.json().await?;
Err(crate::model::error::APIError::HTTPError(error))
}
}
pub async fn start_fine_tuning_voice(&self, voice_id: String) -> Result<String, APIError> {
let response = self.post(crate::elevenlabs_api::voice::POST::StartFineTuning { voice_id })?.send().await?;
if response.status().is_success() {
Ok(response.text().await?)
} else {
let error: crate::model::error::HTTPValidationError = response.json().await?;
Err(crate::model::error::APIError::HTTPError(error))
}
}
}

View File

@ -61,7 +61,7 @@ pub mod tts {
} }
} }
pub mod voices { pub mod voice {
use super::Endpoint; use super::Endpoint;
pub enum GET { pub enum GET {
@ -108,6 +108,10 @@ pub mod voices {
EditVoice { EditVoice {
voice_id: String, voice_id: String,
}, },
AddProfessionalVoice,
StartFineTuning {
voice_id: String,
},
} }
impl Endpoint for POST { impl Endpoint for POST {
@ -116,6 +120,8 @@ pub mod voices {
POST::EditSettings { voice_id } => format!("/v1/voices/{}/settings/edit", voice_id), POST::EditSettings { voice_id } => format!("/v1/voices/{}/settings/edit", voice_id),
POST::AddVoice => "/v1/voices".to_string(), POST::AddVoice => "/v1/voices".to_string(),
POST::EditVoice { voice_id } => format!("/v1/voices/{}", voice_id), POST::EditVoice { voice_id } => format!("/v1/voices/{}", voice_id),
POST::AddProfessionalVoice => "/v1/voices/add-professional".to_string(),
POST::StartFineTuning { voice_id } => format!("/v1/voices/{}/start-fine-tuning", voice_id),
} }
} }
} }

View File

@ -37,7 +37,7 @@ mod tests {
let single_audio = api.get_history_audio(item.history_item_id.clone()).await; let single_audio = api.get_history_audio(item.history_item_id.clone()).await;
assert!(single_audio.is_ok()); assert!(single_audio.is_ok());
//std::fs::write("test0.mp3", single_audio.audio(0).unwrap()).unwrap(); std::fs::write("test0.mp3", single_audio.unwrap().to_vec()).unwrap();
if result.history.len() > 1 { if result.history.len() > 1 {
let audio_result = api.download_history(HistoryItems { let audio_result = api.download_history(HistoryItems {
@ -50,8 +50,8 @@ mod tests {
let audio = audio_result.audio(); let audio = audio_result.audio();
assert!(audio.len() == 2); assert!(audio.len() == 2);
//std::fs::write("test1.mp3", audio.audio(0).unwrap()).unwrap(); std::fs::write("test1.mp3", &audio[0]).unwrap();
//std::fs::write("test2.mp3", audio.audio(1).unwrap()).unwrap(); std::fs::write("test2.mp3", &audio[1]).unwrap();
} }
} }
} }

View File

@ -1,3 +1,4 @@
pub mod user; pub mod user;
pub mod error; pub mod error;
pub mod history; pub mod history;
pub mod voice;

79
src/model/voice.rs Normal file
View File

@ -0,0 +1,79 @@
use std::{collections::HashMap, path::Path};
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize)]
pub struct Sample {
pub sample_id: String,
pub file_name: String,
pub mime_type: String,
pub size_bytes: u32,
pub hash: String,
}
#[derive(Debug, Deserialize)]
pub enum FineTuningState {
#[serde(rename = "not_started")]
NotStarted,
#[serde(rename = "is_fine_tuning")]
IsFineTuning,
#[serde(rename = "fine_tuned")]
FineTuned,
}
#[derive(Debug, Deserialize)]
pub struct FineTuning {
pub is_allowed_to_fine_tune: bool,
pub fine_tuning_requesed: bool,
pub finetuning_state: FineTuningState,
pub verification_attempts_count: u32,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct VoiceSettings {
pub stability: f32,
pub similiarity_boost: f32,
}
#[derive(Debug, Deserialize)]
pub struct Voice {
pub voice_id: String,
pub name: String,
pub samples: Vec<Sample>,
pub category: String,
pub fine_tuning: FineTuning,
pub preview_url: String,
pub available_for_tiers: Vec<String>,
pub settings: VoiceSettings,
pub labels: HashMap<String, String>,
}
#[derive(Debug, Serialize)]
pub struct VoiceCreation {
pub(crate) name: String,
pub(crate) files: Vec<(String, Vec<u8>)>,
pub(crate) labels: HashMap<String, String>,
}
impl VoiceCreation {
pub fn new(name: String, files: Vec<(String, Vec<u8>)>, labels: HashMap<String, String>) -> Self {
Self { name, files, labels }
}
pub fn new_files(name: String, files: Vec<&Path>, labels: HashMap<String, String>) -> std::io::Result<Self> {
let mut collect = Vec::new();
for path in files {
collect.push((
path.file_name().unwrap().to_str().unwrap().to_string(),
std::io::Read::bytes(std::fs::File::open(path)?).map_while(|it| it.ok()).collect()
));
}
Ok(Self { name, files: collect, labels })
}
}
#[derive(Debug, Deserialize)]
pub struct VoiceId {
pub voice_id: String,
}