Implement moderation

This commit is contained in:
Gabriel Tofvesson 2023-03-18 15:48:29 +01:00
parent 0409bd1fed
commit ad68ecd7ab
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
2 changed files with 79 additions and 0 deletions

View File

@ -11,6 +11,7 @@ pub mod transcription;
pub mod translation;
pub mod file;
pub mod fine_tune;
pub mod moderation;
pub mod util;
@ -28,6 +29,7 @@ mod tests {
use crate::embedding::EmbeddingRequestBuilder;
use crate::transcription::{TranscriptionRequestBuilder, AudioFile};
use crate::translation::TranslationRequestBuilder;
use crate::moderation::ModerationRequestBuilder;
fn get_api() -> anyhow::Result<Context> {
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
@ -239,4 +241,25 @@ mod tests {
assert!(translation.is_ok(), "Could not get translation: {}", translation.unwrap_err());
println!("Translation: {:?}", translation.unwrap().text);
}
#[tokio::test]
async fn test_moderation() {
let ctx = get_api();
assert!(ctx.is_ok(), "Could not load context");
let ctx = ctx.unwrap();
let moderation = ctx.create_moderation(
ModerationRequestBuilder::default()
.model("text-moderation-latest")
.input("I want to kill them")
.build()
.unwrap()
).await;
assert!(moderation.is_ok(), "Could not get moderation: {}", moderation.unwrap_err());
let moderation = moderation.unwrap();
assert!(moderation.results.len() == 1, "No moderation results found");
assert!(moderation.results[0].flagged, "Violent language not flagged");
println!("Moderation: {:?}", moderation.results[0]);
}
}

56
src/moderation.rs Normal file
View File

@ -0,0 +1,56 @@
use derive_builder::Builder;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::{completion::Sequence, context::{API_URL, Context}};
#[derive(Debug, Serialize, Builder)]
pub struct ModerationRequest {
#[builder(setter(into))]
pub input: Sequence,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub model: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct Categories<T> {
pub hate: T,
#[serde(rename = "hate/threatening")]
pub hate_threatening: T,
#[serde(rename = "self-harm")]
pub self_harm: T,
pub sexual: T,
#[serde(rename = "sexual/minors")]
pub sexual_minors: T,
pub violence: T,
#[serde(rename = "violence/graphic")]
pub violence_graphic: T,
}
#[derive(Debug, Deserialize)]
pub struct Moderation {
pub categories: Categories<bool>,
pub category_scores: Categories<f64>,
pub flagged: bool,
}
#[derive(Debug, Deserialize)]
pub struct ModerationResponse {
pub id: String,
pub model: String,
pub results: Vec<Moderation>,
}
impl Context {
pub async fn create_moderation(&self, moderation_request: ModerationRequest) -> anyhow::Result<ModerationResponse> {
Ok(
self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/moderations")).json(&moderation_request))
.send()
.await?
.error_for_status()?
.json::<ModerationResponse>()
.await?
)
}
}