Implement embeddings

This commit is contained in:
Gabriel Tofvesson 2023-03-18 03:17:03 +01:00
parent 9d8e858c1e
commit 703dbc3565
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
3 changed files with 71 additions and 0 deletions

View File

@ -36,6 +36,12 @@ impl From<Vec<String>> for Sequence {
}
}
impl From<Vec<&str>> for Sequence {
fn from(v: Vec<&str>) -> Self {
Sequence::List(v.iter().map(|s| s.to_string()).collect())
}
}
impl From<&str> for Sequence {
fn from(s: &str) -> Self {
Sequence::String(s.to_string())

43
src/embedding.rs Normal file
View File

@ -0,0 +1,43 @@
use derive_builder::Builder;
use reqwest::Client;
use serde::{Serialize, Deserialize};
use crate::{completion::Sequence, context::{API_URL, Context}};
#[derive(Debug, Serialize, Builder)]
pub struct EmbeddingRequest {
#[builder(setter(into))]
pub model: String,
#[builder(setter(into))]
pub input: Sequence,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(setter(into, strip_option), default)]
pub user: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct Embedding {
/* pub object: "embedding", */
pub embedding: Vec<f64>,
pub index: u32,
}
#[derive(Debug, Deserialize)]
pub struct EmbeddingUsage {
pub prompt_tokens: u64,
pub total_tokens: u64,
}
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
/* pub object: "list", */
pub data: Vec<Embedding>,
pub model: String,
pub usage: EmbeddingUsage,
}
impl Context {
pub async fn create_embedding(&self, embedding_request: EmbeddingRequest) -> anyhow::Result<EmbeddingResponse> {
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/embeddings")).json(&embedding_request)).send().await?.json::<EmbeddingResponse>().await?)
}
}

View File

@ -6,6 +6,7 @@ pub mod edits;
pub mod image;
pub mod image_edit;
pub mod image_variation;
pub mod embedding;
#[cfg(test)]
mod tests {
@ -18,6 +19,7 @@ mod tests {
use crate::edits::EditRequestBuilder;
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
use crate::image_variation::ImageVariationRequestBuilder;
use crate::embedding::EmbeddingRequestBuilder;
fn get_api() -> anyhow::Result<Context> {
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
@ -172,4 +174,24 @@ mod tests {
}
}
}
#[tokio::test]
async fn test_embedding() {
let ctx = get_api();
assert!(ctx.is_ok(), "Could not load context");
let ctx = ctx.unwrap();
let embeddings = ctx.create_embedding(
EmbeddingRequestBuilder::default()
.model("text-embedding-ada-002")
.input("word sentence paragraph lorem ipsum dolor sit amet")
.build()
.unwrap()
).await;
assert!(embeddings.is_ok(), "Could not get embeddings: {}", embeddings.unwrap_err());
assert!(embeddings.as_ref().unwrap().data.len() == 1, "No embeddings found");
assert!(embeddings.as_ref().unwrap().data[0].embedding.len() > 0, "No embeddings found");
println!("Embeddings: {:?}", embeddings.unwrap().data[0].embedding);
}
}