//! Utilities, such as executing a SPARQL query on a remote endpoint
use llm::builder::LLMBackend;
/// Determines the LLM backend and API key based on available environment variables
/// OpenAI takes priority if both are available
/// Resolves the LLM backend, API key and model name from a `provider/model_name` string.
///
/// The provider is the part of `model` before the first `/`. Everything after it is
/// returned unchanged as the model name, so `"openrouter/meta-llama/llama-3"` yields
/// the model name `"meta-llama/llama-3"`. If `model` contains no `/`, the whole string
/// is used as both the provider and the model name.
///
/// The backend is selected purely from the provider prefix. The API key is then read
/// from that provider's environment variable:
///
/// | provider     | backend                    | environment variable |
/// |--------------|----------------------------|----------------------|
/// | `openai`     | [`LLMBackend::OpenAI`]     | `OPENAI_API_KEY`     |
/// | `mistralai`  | [`LLMBackend::Mistral`]    | `MISTRAL_API_KEY`    |
/// | `openrouter` | [`LLMBackend::OpenRouter`] | `OPENROUTER_API_KEY` |
/// | `groq`       | [`LLMBackend::Groq`]       | `GROQ_API_KEY`       |
///
/// Returns `(backend, api_key, model_name)`.
///
/// # Errors
///
/// Returns a human-readable message if:
/// - the provider is not one of those listed above,
/// - the provider's environment variable is not set, or
/// - the variable's value is not valid Unicode.
pub fn get_llm_config(model: &str) -> Result<(LLMBackend, String, String), String> {
    // Split on the first '/' only: model names (e.g. on OpenRouter) may contain slashes.
    let (provider, model_name) = model.split_once('/').unwrap_or((model, model));

    let (backend, env_var) = match provider {
        "openai" => (LLMBackend::OpenAI, "OPENAI_API_KEY"),
        "mistralai" => (LLMBackend::Mistral, "MISTRAL_API_KEY"),
        "openrouter" => (LLMBackend::OpenRouter, "OPENROUTER_API_KEY"),
        "groq" => (LLMBackend::Groq, "GROQ_API_KEY"),
        _ => return Err(format!("Unknown provider: {provider}")),
    };

    // Read the variable exactly once so the presence check and the value cannot disagree.
    let api_key = std::env::var(env_var).map_err(|err| match err {
        std::env::VarError::NotPresent => format!("{env_var} environment variable not set"),
        std::env::VarError::NotUnicode(_) => {
            format!("{env_var} environment variable is not valid Unicode")
        }
    })?;

    Ok((backend, api_key, model_name.to_string()))
}