Remove unnecessary type check functions

This commit is contained in:
Gabriel Tofvesson 2023-03-18 01:01:20 +01:00
parent 5393ecbff5
commit 412a79ac56
No known key found for this signature in database
GPG Key ID: 6F1345DF28EDA13E
2 changed files with 13 additions and 21 deletions

View File

@ -67,22 +67,6 @@ pub enum Image {
Base64(String), Base64(String),
} }
impl Image {
pub fn isURL(&self) -> bool {
return match self {
Self::URL(_) => true,
_ => false,
}
}
pub fn isBase64(&self) -> bool {
return match self {
Self::Base64(_) => true,
_ => false
}
}
}
impl<'de> Deserialize<'de> for Image { impl<'de> Deserialize<'de> for Image {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where where

View File

@ -10,7 +10,7 @@ mod tests {
use crate::chat::ChatMessage; use crate::chat::ChatMessage;
use crate::context::Context; use crate::context::Context;
use crate::completion::CompletionRequestBuilder; use crate::completion::CompletionRequestBuilder;
use crate::image::Image; use crate::image::{Image, ResponseFormat};
fn get_api() -> anyhow::Result<Context> { fn get_api() -> anyhow::Result<Context> {
Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string())) Ok(Context::new(std::fs::read_to_string(std::path::Path::new("apikey.txt"))?.trim().to_string()))
@ -84,22 +84,30 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_image() { async fn test_image() {
const IMAGE_PROMPT: &str = "In a realistic style, a ginger cat gracefully walking along a thin brick wall";
let ctx = get_api(); let ctx = get_api();
assert!(ctx.is_ok(), "Could not load context"); assert!(ctx.is_ok(), "Could not load context");
let ctx = ctx.unwrap(); let ctx = ctx.unwrap();
let image = ctx.create_image( let image = ctx.create_image(
crate::image::ImageRequestBuilder::default() crate::image::ImageRequestBuilder::default()
.prompt("In a realistic style, a ginger cat gracefully walking along a thin brick wall") .prompt(IMAGE_PROMPT)
.response_format(ResponseFormat::URL)
.build() .build()
.unwrap() .unwrap()
).await; ).await;
assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err()); assert!(image.is_ok(), "Could not get image: {}", image.unwrap_err());
assert!(image.as_ref().unwrap().data.len() == 1, "No image found"); assert!(image.as_ref().unwrap().data.len() == 1, "No image found");
assert!(image.as_ref().unwrap().data[0].isURL(), "No image found"); assert!(matches!(image.as_ref().unwrap().data[0], Image::URL(_)), "No image found");
if let Image::URL(url) = &image.as_ref().unwrap().data[0] { println!("Image prompt: {IMAGE_PROMPT}");
println!("{}", url); match image.unwrap().data[0] {
Image::URL(ref url) => {
println!("Generated test image URL: {url}");
}
Image::Base64(ref b64) => {
println!("Generated test image Base64: {b64}");
}
} }
} }
} }