Implement moderation
This commit is contained in:
parent
0409bd1fed
commit
ad68ecd7ab
23
src/lib.rs
23
src/lib.rs
@ -11,6 +11,7 @@ pub mod transcription;
|
||||
pub mod translation;
|
||||
pub mod file;
|
||||
pub mod fine_tune;
|
||||
pub mod moderation;
|
||||
|
||||
pub mod util;
|
||||
|
||||
@ -28,6 +29,7 @@ mod tests {
|
||||
use crate::embedding::EmbeddingRequestBuilder;
|
||||
use crate::transcription::{TranscriptionRequestBuilder, AudioFile};
|
||||
use crate::translation::TranslationRequestBuilder;
|
||||
use crate::moderation::ModerationRequestBuilder;
|
||||
|
||||
fn get_api() -> anyhow::Result<Context> {
|
||||
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
||||
@ -239,4 +241,25 @@ mod tests {
|
||||
assert!(translation.is_ok(), "Could not get translation: {}", translation.unwrap_err());
|
||||
println!("Translation: {:?}", translation.unwrap().text);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_moderation() {
|
||||
let ctx = get_api();
|
||||
assert!(ctx.is_ok(), "Could not load context");
|
||||
let ctx = ctx.unwrap();
|
||||
|
||||
let moderation = ctx.create_moderation(
|
||||
ModerationRequestBuilder::default()
|
||||
.model("text-moderation-latest")
|
||||
.input("I want to kill them")
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
||||
assert!(moderation.is_ok(), "Could not get moderation: {}", moderation.unwrap_err());
|
||||
let moderation = moderation.unwrap();
|
||||
assert!(moderation.results.len() == 1, "No moderation results found");
|
||||
assert!(moderation.results[0].flagged, "Violent language not flagged");
|
||||
println!("Moderation: {:?}", moderation.results[0]);
|
||||
}
|
||||
}
|
56
src/moderation.rs
Normal file
56
src/moderation.rs
Normal file
@ -0,0 +1,56 @@
|
||||
use derive_builder::Builder;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{completion::Sequence, context::{API_URL, Context}};
|
||||
|
||||
#[derive(Debug, Serialize, Builder)]
|
||||
pub struct ModerationRequest {
|
||||
#[builder(setter(into))]
|
||||
pub input: Sequence,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub model: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Categories<T> {
|
||||
pub hate: T,
|
||||
#[serde(rename = "hate/threatening")]
|
||||
pub hate_threatening: T,
|
||||
#[serde(rename = "self-harm")]
|
||||
pub self_harm: T,
|
||||
pub sexual: T,
|
||||
#[serde(rename = "sexual/minors")]
|
||||
pub sexual_minors: T,
|
||||
pub violence: T,
|
||||
#[serde(rename = "violence/graphic")]
|
||||
pub violence_graphic: T,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Moderation {
|
||||
pub categories: Categories<bool>,
|
||||
pub category_scores: Categories<f64>,
|
||||
pub flagged: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ModerationResponse {
|
||||
pub id: String,
|
||||
pub model: String,
|
||||
pub results: Vec<Moderation>,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub async fn create_moderation(&self, moderation_request: ModerationRequest) -> anyhow::Result<ModerationResponse> {
|
||||
Ok(
|
||||
self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/moderations")).json(&moderation_request))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json::<ModerationResponse>()
|
||||
.await?
|
||||
)
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user