use std::collections::HashMap; use derive_builder::Builder; use reqwest::Client; use serde::{Serialize, Deserialize}; use crate::context::{API_URL, Context}; #[derive(Debug, Clone)] pub enum Sequence { String(String), List(Vec), } impl Serialize for Sequence { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { match self { Sequence::String(s) => serializer.serialize_str(s), Sequence::List(l) => serializer.collect_seq(l), } } } impl From for Sequence { fn from(s: String) -> Self { Sequence::String(s) } } impl From> for Sequence { fn from(v: Vec) -> Self { Sequence::List(v) } } impl From<&str> for Sequence { fn from(s: &str) -> Self { Sequence::String(s.to_string()) } } impl From<&[&str]> for Sequence { fn from(v: &[&str]) -> Self { Sequence::List(v.iter().map(|s| s.to_string()).collect()) } } #[derive(Debug, Serialize, Builder)] pub struct CompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub prompt: Option, #[builder(setter(into))] pub model: String, #[builder(setter(into, strip_option), default)] pub max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub suffix: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub n: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub stream: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub logprobs: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub echo: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub stop: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub presence_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub frequency_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub best_of: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub logit_bias: Option>, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub return_prompt: Option, #[serde(skip_serializing_if = "Option::is_none")] #[builder(setter(into, strip_option), default)] pub user: Option, } #[derive(Debug, Deserialize)] pub struct Choice { pub index: u64, pub text: String, pub logprobs: Option>, pub finish_reason: String, } #[derive(Debug, Deserialize)] pub struct Usage { pub prompt_tokens: u64, pub completion_tokens: u64, pub total_tokens: u64, } #[derive(Debug, Deserialize)] pub struct CompletionResponse { pub id: String, pub object: String, pub created: u64, pub model: String, pub choices: Vec, pub usage: Usage, } impl Context { pub async fn create_completion(&self, completion_request: CompletionRequest) -> anyhow::Result { Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/completions")).json(&completion_request)).send().await?.json::().await?) } }