From 9750d5279d0edae3efc61a8da3fd81dd0a371225 Mon Sep 17 00:00:00 2001 From: Gabriel Tofvesson Date: Sat, 18 Mar 2023 02:56:08 +0100 Subject: [PATCH] Move Context-impls to relevant modules --- src/chat.rs | 9 ++++++++- src/completion.rs | 9 +++++++++ src/context.rs | 28 +--------------------------- src/edits.rs | 9 ++++++++- src/image.rs | 9 +++++++++ src/model.rs | 13 +++++++++++++ 6 files changed, 48 insertions(+), 29 deletions(-) diff --git a/src/chat.rs b/src/chat.rs index baa2498..8650ee5 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -1,9 +1,10 @@ use std::collections::HashMap; use derive_builder::Builder; +use reqwest::Client; use serde::{Serialize, Deserialize}; -use crate::completion::{Sequence, Usage}; +use crate::{completion::{Sequence, Usage}, context::{API_URL, Context}}; #[derive(Debug, Clone)] pub enum Role { @@ -107,4 +108,10 @@ pub struct ChatCompletionResponse { pub model: String, pub choices: Vec, pub usage: Usage +} + +impl Context { + pub async fn create_chat_completion(&self, chat_completion_request: ChatHistory) -> anyhow::Result { + Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/chat/completions"))).json(&chat_completion_request).send().await?.json::().await?) + } } \ No newline at end of file diff --git a/src/completion.rs b/src/completion.rs index 805e533..07f9847 100644 --- a/src/completion.rs +++ b/src/completion.rs @@ -1,8 +1,11 @@ use std::collections::HashMap; use derive_builder::Builder; +use reqwest::Client; use serde::{Serialize, Deserialize}; +use crate::context::{API_URL, Context}; + #[derive(Debug, Clone)] pub enum Sequence { String(String), @@ -121,4 +124,10 @@ pub struct CompletionResponse { pub model: String, pub choices: Vec, pub usage: Usage, +} + +impl Context { + pub async fn create_completion(&self, completion_request: CompletionRequest) -> anyhow::Result { + Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/completions")).json(&completion_request)).send().await?.json::().await?) + } } \ No newline at end of file diff --git a/src/context.rs b/src/context.rs index 9f4d4fd..b5ec29f 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,6 +1,4 @@ -use reqwest::{Client, RequestBuilder}; - -use crate::{model::{Model, ModelList}, completion::{CompletionRequest, CompletionResponse}, chat::{ChatCompletionResponse, ChatHistory}, edits::{EditRequest, EditResponse}}; +use reqwest::RequestBuilder; pub struct Context { api_key: String, @@ -33,28 +31,4 @@ impl Context { } ).bearer_auth(&self.api_key) } - - pub async fn get_models(&self) -> anyhow::Result> { - Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::().await?.data) - } - - pub async fn get_model(&self, model_id: &str) -> anyhow::Result { - Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models/{model_id}", model_id = model_id))).send().await?.json::().await?) - } - - pub async fn create_completion(&self, completion_request: CompletionRequest) -> anyhow::Result { - Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/completions")).json(&completion_request)).send().await?.json::().await?) - } - - pub async fn create_chat_completion(&self, chat_completion_request: ChatHistory) -> anyhow::Result { - Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/chat/completions")).json(&chat_completion_request)).send().await?.json::().await?) - } - - pub async fn create_edit(&self, edit_request: EditRequest) -> anyhow::Result { - Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/edits")).json(&edit_request)).send().await?.json::().await?) - } - - pub async fn create_image(&self, image_request: crate::image::ImageRequest) -> anyhow::Result { - Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/generations")).json(&image_request)).send().await?.json::().await?) - } } \ No newline at end of file diff --git a/src/edits.rs b/src/edits.rs index a1574d9..74f8d76 100644 --- a/src/edits.rs +++ b/src/edits.rs @@ -1,7 +1,8 @@ use derive_builder::Builder; +use reqwest::Client; use serde::{Serialize, Deserialize}; -use crate::completion::Usage; +use crate::{completion::Usage, context::{API_URL, Context}}; #[derive(Debug, Serialize, Builder)] pub struct EditRequest { @@ -35,4 +36,10 @@ pub struct EditResponse { pub created: u64, pub choices: Vec, pub usage: Usage +} + +impl Context { + pub async fn create_edit(&self, edit_request: EditRequest) -> anyhow::Result { + Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/edits")).json(&edit_request)).send().await?.json::().await?) + } } \ No newline at end of file diff --git a/src/image.rs b/src/image.rs index de1b9de..feaf251 100644 --- a/src/image.rs +++ b/src/image.rs @@ -1,6 +1,9 @@ use derive_builder::Builder; +use reqwest::Client; use serde::{Serialize, Deserialize}; +use crate::context::{API_URL, Context}; + #[derive(Debug, Clone)] pub enum ResponseFormat { URL, @@ -93,4 +96,10 @@ impl<'de> Deserialize<'de> for Image { pub struct ImageResponse { pub created: u64, pub data: Vec, +} + +impl Context { + pub async fn create_image(&self, image_request: crate::image::ImageRequest) -> anyhow::Result { + Ok(self.with_auth(Client::builder().build()?.post(&format!("{API_URL}/v1/images/generations")).json(&image_request)).send().await?.json::().await?) + } } \ No newline at end of file diff --git a/src/model.rs b/src/model.rs index c76e4b0..8ae287b 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,5 +1,8 @@ +use reqwest::Client; use serde::Deserialize; +use crate::context::{API_URL, Context}; + #[derive(Debug, Deserialize)] pub struct Permission { pub id: String, @@ -29,4 +32,14 @@ pub struct Model { #[derive(Debug, Deserialize)] pub(crate) struct ModelList { pub data: Vec, +} + +impl Context { + pub async fn get_models(&self) -> anyhow::Result> { + Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models"))).send().await?.json::().await?.data) + } + + pub async fn get_model(&self, model_id: &str) -> anyhow::Result { + Ok(self.with_auth(Client::builder().build()?.get(&format!("{API_URL}/v1/models/{model_id}", model_id = model_id))).send().await?.json::().await?) + } } \ No newline at end of file