From ad68ecd7ab4ef17776b2d05c23399cd2dae31fcd Mon Sep 17 00:00:00 2001 From: Gabriel Tofvesson Date: Sat, 18 Mar 2023 15:48:29 +0100 Subject: [PATCH] Implement moderation --- src/lib.rs | 23 +++++++++++++++++++ src/moderation.rs | 56 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 src/moderation.rs diff --git a/src/lib.rs b/src/lib.rs index e4f23eb..948b095 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ pub mod transcription; pub mod translation; pub mod file; pub mod fine_tune; +pub mod moderation; pub mod util; @@ -28,6 +29,7 @@ mod tests { use crate::embedding::EmbeddingRequestBuilder; use crate::transcription::{TranscriptionRequestBuilder, AudioFile}; use crate::translation::TranslationRequestBuilder; + use crate::moderation::ModerationRequestBuilder; fn get_api() -> anyhow::Result { Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string())) @@ -239,4 +241,25 @@ mod tests { assert!(translation.is_ok(), "Could not get translation: {}", translation.unwrap_err()); println!("Translation: {:?}", translation.unwrap().text); } + + #[tokio::test] + async fn test_moderation() { + let ctx = get_api(); + assert!(ctx.is_ok(), "Could not load context"); + let ctx = ctx.unwrap(); + + let moderation = ctx.create_moderation( + ModerationRequestBuilder::default() + .model("text-moderation-latest") + .input("I want to kill them") + .build() + .unwrap() + ).await; + + assert!(moderation.is_ok(), "Could not get moderation: {}", moderation.unwrap_err()); + let moderation = moderation.unwrap(); + assert!(moderation.results.len() == 1, "No moderation results found"); + assert!(moderation.results[0].flagged, "Violent language not flagged"); + println!("Moderation: {:?}", moderation.results[0]); + } } \ No newline at end of file diff --git a/src/moderation.rs b/src/moderation.rs new file mode 100644 index 0000000..1f6fd24 --- /dev/null +++ b/src/moderation.rs @@ -0,0 +1,56 @@ +use derive_builder::Builder; +use reqwest::Client; +use serde::{Deserialize, Serialize}; + +use crate::{completion::Sequence, context::{API_URL, Context}}; + +#[derive(Debug, Serialize, Builder)] +pub struct ModerationRequest { + #[builder(setter(into))] + pub input: Sequence, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(setter(into, strip_option), default)] + pub model: Option, +} + +#[derive(Debug, Deserialize)] +pub struct Categories { + pub hate: T, + #[serde(rename = "hate/threatening")] + pub hate_threatening: T, + #[serde(rename = "self-harm")] + pub self_harm: T, + pub sexual: T, + #[serde(rename = "sexual/minors")] + pub sexual_minors: T, + pub violence: T, + #[serde(rename = "violence/graphic")] + pub violence_graphic: T, +} + +#[derive(Debug, Deserialize)] +pub struct Moderation { + pub categories: Categories, + pub category_scores: Categories, + pub flagged: bool, +} + +#[derive(Debug, Deserialize)] +pub struct ModerationResponse { + pub id: String, + pub model: String, + pub results: Vec, +} + +impl Context { + pub async fn create_moderation(&self, moderation_request: ModerationRequest) -> anyhow::Result { + Ok( + self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/moderations")).json(&moderation_request)) + .send() + .await? + .error_for_status()? + .json::() + .await? + ) + } +} \ No newline at end of file