Implement embeddings
This commit is contained in:
parent
9d8e858c1e
commit
703dbc3565
@ -36,6 +36,12 @@ impl From<Vec<String>> for Sequence {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<&str>> for Sequence {
|
||||
fn from(v: Vec<&str>) -> Self {
|
||||
Sequence::List(v.iter().map(|s| s.to_string()).collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Sequence {
|
||||
fn from(s: &str) -> Self {
|
||||
Sequence::String(s.to_string())
|
||||
|
43
src/embedding.rs
Normal file
43
src/embedding.rs
Normal file
@ -0,0 +1,43 @@
|
||||
use derive_builder::Builder;
|
||||
use reqwest::Client;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::{completion::Sequence, context::{API_URL, Context}};
|
||||
|
||||
#[derive(Debug, Serialize, Builder)]
|
||||
pub struct EmbeddingRequest {
|
||||
#[builder(setter(into))]
|
||||
pub model: String,
|
||||
#[builder(setter(into))]
|
||||
pub input: Sequence,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[builder(setter(into, strip_option), default)]
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Embedding {
|
||||
/* pub object: "embedding", */
|
||||
pub embedding: Vec<f64>,
|
||||
pub index: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct EmbeddingUsage {
|
||||
pub prompt_tokens: u64,
|
||||
pub total_tokens: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct EmbeddingResponse {
|
||||
/* pub object: "list", */
|
||||
pub data: Vec<Embedding>,
|
||||
pub model: String,
|
||||
pub usage: EmbeddingUsage,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub async fn create_embedding(&self, embedding_request: EmbeddingRequest) -> anyhow::Result<EmbeddingResponse> {
|
||||
Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/embeddings")).json(&embedding_request)).send().await?.json::<EmbeddingResponse>().await?)
|
||||
}
|
||||
}
|
22
src/lib.rs
22
src/lib.rs
@ -6,6 +6,7 @@ pub mod edits;
|
||||
pub mod image;
|
||||
pub mod image_edit;
|
||||
pub mod image_variation;
|
||||
pub mod embedding;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@ -18,6 +19,7 @@ mod tests {
|
||||
use crate::edits::EditRequestBuilder;
|
||||
use crate::image_edit::{ImageEditRequestBuilder, ImageFile};
|
||||
use crate::image_variation::ImageVariationRequestBuilder;
|
||||
use crate::embedding::EmbeddingRequestBuilder;
|
||||
|
||||
fn get_api() -> anyhow::Result<Context> {
|
||||
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
|
||||
@ -172,4 +174,24 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_embedding() {
|
||||
let ctx = get_api();
|
||||
assert!(ctx.is_ok(), "Could not load context");
|
||||
let ctx = ctx.unwrap();
|
||||
|
||||
let embeddings = ctx.create_embedding(
|
||||
EmbeddingRequestBuilder::default()
|
||||
.model("text-embedding-ada-002")
|
||||
.input("word sentence paragraph lorem ipsum dolor sit amet")
|
||||
.build()
|
||||
.unwrap()
|
||||
).await;
|
||||
|
||||
assert!(embeddings.is_ok(), "Could not get embeddings: {}", embeddings.unwrap_err());
|
||||
assert!(embeddings.as_ref().unwrap().data.len() == 1, "No embeddings found");
|
||||
assert!(embeddings.as_ref().unwrap().data[0].embedding.len() > 0, "No embeddings found");
|
||||
println!("Embeddings: {:?}", embeddings.unwrap().data[0].embedding);
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user