From 703dbc35654394295d9951958fe662cf5a47f325 Mon Sep 17 00:00:00 2001
From: Gabriel Tofvesson
Date: Sat, 18 Mar 2023 03:17:03 +0100
Subject: [PATCH] Implement embeddings

---
Review notes (this section sits after the `---` separator and is not part of
the commit message or of any hunk):

What the patch does
  * src/completion.rs: adds `impl From<Vec<&str>> for Sequence`, which copies
    each `&str` into an owned `String` and wraps the result in
    `Sequence::List`.
  * src/embedding.rs (new file): adds the serializable `EmbeddingRequest`
    (with a derive_builder builder; `user` is optional and omitted from the
    JSON when `None`), the deserializable `Embedding`, `EmbeddingUsage` and
    `EmbeddingResponse` types, and `Context::create_embedding`, which builds a
    reqwest client, POSTs the request as JSON to `{API_URL}/v1/embeddings`
    through `Context::with_auth`, and deserializes the JSON reply into
    `EmbeddingResponse`. Any client, transport or decoding error is returned
    through `anyhow::Result`.
  * src/lib.rs: registers the `embedding` module and adds `test_embedding`,
    which builds a request for model "text-embedding-ada-002" and asserts that
    exactly one non-empty embedding is returned.

Reconstruction notes
  * This file had been flattened onto two lines and every `<Ident...>` span
    had been stripped. Line breaks were restored so that each hunk matches the
    line counts declared in its `@@` header.
  * Restored from surrounding code: `From<Vec<&str>>` (from the `fn from`
    parameter), `Option<String>`, `Vec<Embedding>`,
    `anyhow::Result<EmbeddingResponse>`, `.json::<EmbeddingResponse>()` and
    `anyhow::Result<Context>`.
  * NOTE(review): the element type in `pub embedding: Vec<f64>` is a guess
    (it could equally be `f32`) -- confirm against the repository.
  * NOTE(review): the heading text `impl From<Vec<String>> for Sequence {` on
    the first completion.rs hunk is a guess; it is informational only.
  * NOTE(review): indentation of context lines assumes 4-space rustfmt style;
    confirm before applying without `--ignore-whitespace`.
  * The author's e-mail address was lost with the stripped text and has not
    been invented here; restore it before feeding this file to `git am`.

 src/completion.rs |  6 ++++++
 src/embedding.rs  | 43 +++++++++++++++++++++++++++++++++++++++++++
 src/lib.rs        | 22 ++++++++++++++++++++++
 3 files changed, 71 insertions(+)
 create mode 100644 src/embedding.rs

diff --git a/src/completion.rs b/src/completion.rs
index 07f9847..bf87ddb 100644
--- a/src/completion.rs
+++ b/src/completion.rs
@@ -36,6 +36,12 @@ impl From<Vec<String>> for Sequence {
     }
 }
 
+impl From<Vec<&str>> for Sequence {
+    fn from(v: Vec<&str>) -> Self {
+        Sequence::List(v.iter().map(|s| s.to_string()).collect())
+    }
+}
+
 impl From<&str> for Sequence {
     fn from(s: &str) -> Self {
         Sequence::String(s.to_string())
diff --git a/src/embedding.rs b/src/embedding.rs
new file mode 100644
index 0000000..3a7a821
--- /dev/null
+++ b/src/embedding.rs
@@ -0,0 +1,43 @@
+use derive_builder::Builder;
+use reqwest::Client;
+use serde::{Serialize, Deserialize};
+
+use crate::{completion::Sequence, context::{API_URL, Context}};
+
+#[derive(Debug, Serialize, Builder)]
+pub struct EmbeddingRequest {
+    #[builder(setter(into))]
+    pub model: String,
+    #[builder(setter(into))]
+    pub input: Sequence,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[builder(setter(into, strip_option), default)]
+    pub user: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Embedding {
+    /* pub object: "embedding", */
+    pub embedding: Vec<f64>,
+    pub index: u32,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct EmbeddingUsage {
+    pub prompt_tokens: u64,
+    pub total_tokens: u64,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct EmbeddingResponse {
+    /* pub object: "list", */
+    pub data: Vec<Embedding>,
+    pub model: String,
+    pub usage: EmbeddingUsage,
+}
+
+impl Context {
+    pub async fn create_embedding(&self, embedding_request: EmbeddingRequest) -> anyhow::Result<EmbeddingResponse> {
+        Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/embeddings")).json(&embedding_request)).send().await?.json::<EmbeddingResponse>().await?)
+    }
+}
\ No newline at end of file
diff --git a/src/lib.rs b/src/lib.rs
index 2d53cb0..1827bb3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,6 +6,7 @@ pub mod edits;
 pub mod image;
 pub mod image_edit;
 pub mod image_variation;
+pub mod embedding;
 
 #[cfg(test)]
 mod tests {
@@ -18,6 +19,7 @@ mod tests {
     use crate::edits::EditRequestBuilder;
     use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
     use crate::image_variation::ImageVariationRequestBuilder;
+    use crate::embedding::EmbeddingRequestBuilder;
 
     fn get_api() -> anyhow::Result<Context> {
         Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
@@ -172,4 +174,24 @@ mod tests {
             }
         }
     }
+
+    #[tokio::test]
+    async fn test_embedding() {
+        let ctx = get_api();
+        assert!(ctx.is_ok(), "Could not load context");
+        let ctx = ctx.unwrap();
+
+        let embeddings = ctx.create_embedding(
+            EmbeddingRequestBuilder::default()
+                .model("text-embedding-ada-002")
+                .input("word sentence paragraph lorem ipsum dolor sit amet")
+                .build()
+                .unwrap()
+        ).await;
+
+        assert!(embeddings.is_ok(), "Could not get embeddings: {}", embeddings.unwrap_err());
+        assert!(embeddings.as_ref().unwrap().data.len() == 1, "No embeddings found");
+        assert!(embeddings.as_ref().unwrap().data[0].embedding.len() > 0, "No embeddings found");
+        println!("Embeddings: {:?}", embeddings.unwrap().data[0].embedding);
+    }
 }
\ No newline at end of file