Implement moderation
This commit is contained in:
parent
0409bd1fed
commit
ad68ecd7ab
23
src/lib.rs
23
src/lib.rs
@ -11,6 +11,7 @@ pub mod transcription;
|
|||||||
pub mod translation;
|
pub mod translation;
|
||||||
pub mod file;
|
pub mod file;
|
||||||
pub mod fine_tune;
|
pub mod fine_tune;
|
||||||
|
pub mod moderation;
|
||||||
|
|
||||||
pub mod util;
|
pub mod util;
|
||||||
|
|
||||||
@ -28,6 +29,7 @@ mod tests {
|
|||||||
use crate::embedding::EmbeddingRequestBuilder;
|
use crate::embedding::EmbeddingRequestBuilder;
|
||||||
use crate::transcription::{TranscriptionRequestBuilder, AudioFile};
|
use crate::transcription::{TranscriptionRequestBuilder, AudioFile};
|
||||||
use crate::translation::TranslationRequestBuilder;
|
use crate::translation::TranslationRequestBuilder;
|
||||||
|
use crate::moderation::ModerationRequestBuilder;
|
||||||
|
|
||||||
fn get_api() -> anyhow::Result<Context> {
|
fn get_api() -> anyhow::Result<Context> {
|
||||||
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
||||||
@ -239,4 +241,25 @@ mod tests {
|
|||||||
assert!(translation.is_ok(), "Could not get translation: {}", translation.unwrap_err());
|
assert!(translation.is_ok(), "Could not get translation: {}", translation.unwrap_err());
|
||||||
println!("Translation: {:?}", translation.unwrap().text);
|
println!("Translation: {:?}", translation.unwrap().text);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_moderation() {
|
||||||
|
let ctx = get_api();
|
||||||
|
assert!(ctx.is_ok(), "Could not load context");
|
||||||
|
let ctx = ctx.unwrap();
|
||||||
|
|
||||||
|
let moderation = ctx.create_moderation(
|
||||||
|
ModerationRequestBuilder::default()
|
||||||
|
.model("text-moderation-latest")
|
||||||
|
.input("I want to kill them")
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
).await;
|
||||||
|
|
||||||
|
assert!(moderation.is_ok(), "Could not get moderation: {}", moderation.unwrap_err());
|
||||||
|
let moderation = moderation.unwrap();
|
||||||
|
assert!(moderation.results.len() == 1, "No moderation results found");
|
||||||
|
assert!(moderation.results[0].flagged, "Violent language not flagged");
|
||||||
|
println!("Moderation: {:?}", moderation.results[0]);
|
||||||
|
}
|
||||||
}
|
}
|
56
src/moderation.rs
Normal file
56
src/moderation.rs
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
use derive_builder::Builder;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::{completion::Sequence, context::{API_URL, Context}};
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Builder)]
|
||||||
|
pub struct ModerationRequest {
|
||||||
|
#[builder(setter(into))]
|
||||||
|
pub input: Sequence,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[builder(setter(into, strip_option), default)]
|
||||||
|
pub model: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Categories<T> {
|
||||||
|
pub hate: T,
|
||||||
|
#[serde(rename = "hate/threatening")]
|
||||||
|
pub hate_threatening: T,
|
||||||
|
#[serde(rename = "self-harm")]
|
||||||
|
pub self_harm: T,
|
||||||
|
pub sexual: T,
|
||||||
|
#[serde(rename = "sexual/minors")]
|
||||||
|
pub sexual_minors: T,
|
||||||
|
pub violence: T,
|
||||||
|
#[serde(rename = "violence/graphic")]
|
||||||
|
pub violence_graphic: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Moderation {
|
||||||
|
pub categories: Categories<bool>,
|
||||||
|
pub category_scores: Categories<f64>,
|
||||||
|
pub flagged: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ModerationResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub model: String,
|
||||||
|
pub results: Vec<Moderation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Context {
|
||||||
|
pub async fn create_moderation(&self, moderation_request: ModerationRequest) -> anyhow::Result<ModerationResponse> {
|
||||||
|
Ok(
|
||||||
|
self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/moderations")).json(&moderation_request))
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.error_for_status()?
|
||||||
|
.json::<ModerationResponse>()
|
||||||
|
.await?
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user