// mcp.rs • 16.7 kB
use std::{collections::HashMap, sync::Arc};
use crate::{
config::SparqlEndpointConfig,
db::SearchDB,
index::init_endpoint,
sparql_client::{SparqlClient, SparqlResults},
validate::validate_sparql,
void_schema::SchemasMap,
};
use rmcp::{
ErrorData as McpError, RoleServer, ServerHandler,
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
model::*,
schemars,
service::RequestContext,
tool, tool_handler, tool_router,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
/// Number of "SPARQL endpoints query examples" documents requested from
/// `search_similar` in one search pass.
const SEARCH_EXAMPLES_COUNT: usize = 5;
/// Number of "SPARQL endpoints classes schema" documents requested from
/// `search_similar` in one search pass.
const SEARCH_CLASSES_COUNT: usize = 5;
/// User question to search SPARQL resources
// NOTE(review): this struct derives `schemars::JsonSchema`, so the `///` comments
// here are presumably emitted as the descriptions of the tool input schema.
// Treat their wording as user-facing text.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct InputQuestion {
    /// The search query to run, compiled from the user's question or query
    pub search_query: String,
    /// Optional SPARQL endpoint URL to focus the search on, if specified by the user
    pub endpoint_url: Option<String>,
    // Planned fields, not enabled yet:
    // /// Potential topics and classes relevant to the query
    // pub topics: Vec<String>,
    // /// Potential entities relevant to the query
    // pub entities: Vec<String>,
    // /// Potential steps to answer the query (leave empty if 1 is enough)
    // pub steps: Vec<String>,
}
/// Structured response for SPARQL query search results
///
/// Serialized to JSON and returned as the structured content of the search tools.
#[derive(Debug, Serialize, Deserialize)]
pub struct McpSearchResult {
/// Number of documents collected; callers set it to the summed length of the
/// result lists they put in `results`.
pub total_found: usize,
/// The search text the results were retrieved for.
pub query: String,
/// Documents returned by the database similarity search.
pub results: Vec<SearchDocument>,
}
/// Input structure for executing a SPARQL query against an endpoint
// Destructured directly in the parameter list of the `execute_sparql_query` tool.
// NOTE(review): because of the `schemars::JsonSchema` derive, the `///` comments
// below presumably become the field descriptions of the tool input schema - confirm
// before rewording them.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct InputSparqlQuery {
/// A valid SPARQL query string
pub sparql_query: String,
/// The SPARQL endpoint URL to execute the query against
pub endpoint_url: String,
}
/// Import `SearchDocument` from the database module; it is the element type of `McpSearchResult::results`
use crate::db::SearchDocument;
/// State shared by the SPARQL tools: documentation search, endpoint indexing and
/// query execution.
#[derive(Clone)]
pub struct SparqlTools {
// Router obtained from `Self::tool_router()` in `new`.
tool_router: ToolRouter<SparqlTools>,
// Search database queried through `search_similar` and handed to `init_endpoint`.
db: Arc<SearchDB>,
// Known endpoint schemas, passed to `validate_sparql` when a query is executed.
schemas_map: SchemasMap,
// When true, `access_sparql_resources` indexes an endpoint for which the search
// returned no document, then searches again.
auto_index: bool,
}
#[tool_router]
impl SparqlTools {
pub fn new(db: Arc<SearchDB>, schemas_map: SchemasMap, auto_index: bool) -> Self {
Self {
tool_router: Self::tool_router(),
db,
schemas_map,
auto_index,
}
}
fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
RawResource::new(uri, name.to_string()).no_annotation()
}
async fn search_docs(
&self,
// mut search_results: McpSearchResult,
question: &str,
endpoint_url: &Option<String>,
include_examples: bool,
) -> McpSearchResult {
// self.db.search_similar(query, top_k, filters).await
let mut search_results = McpSearchResult {
total_found: 0,
query: question.to_string(),
results: vec![],
};
let mut search_filters = HashMap::new();
if let Some(endpoint) = endpoint_url {
search_filters.insert("endpoint_url", endpoint.as_str());
}
if include_examples {
// Build filters for examples
search_filters.insert("doc_type", "SPARQL endpoints query examples");
match self
.db
.search_similar(question, SEARCH_EXAMPLES_COUNT, Some(&search_filters))
.await
{
Ok(results) => {
search_results.total_found += results.len();
search_results.results.extend(results);
}
Err(e) => tracing::error!("Database search error: {:?}", e),
}
// Build filters for classes
search_filters.remove("doc_type");
}
// Now search for classes
search_filters.insert("doc_type", "SPARQL endpoints classes schema");
match self
.db
.search_similar(question, SEARCH_CLASSES_COUNT, Some(&search_filters))
.await
{
Ok(results) => {
search_results.total_found += results.len();
search_results.results.extend(results);
}
Err(e) => tracing::error!("Database search error: {:?}", e),
}
search_results
}
#[tool(description = "Search for specific classes and their schema in the SPARQL endpoints.")]
async fn get_class_schema(
&self,
Parameters(InputQuestion {
search_query,
endpoint_url,
}): Parameters<InputQuestion>,
) -> Result<CallToolResult, McpError> {
let search_results = self.search_docs(&search_query, &endpoint_url, false).await;
Ok(CallToolResult::structured(
serde_json::to_value(&search_results).map_err(|e| {
McpError::internal_error(
"Failed to serialize search results",
Some(json!({"error": e.to_string()})),
)
})?,
))
}
#[tool(
description = "Assist users in writing SPARQL queries to access SIB biodata resources by retrieving relevant examples and classes schema."
)]
async fn access_sparql_resources(
&self,
Parameters(InputQuestion {
search_query,
endpoint_url,
}): Parameters<InputQuestion>,
) -> Result<CallToolResult, McpError> {
let mut search_results = self.search_docs(&search_query, &endpoint_url, true).await;
// TODO: If no results found then we try to index this endpoint and add its schema/docs to the DB
// How can I get the species present in the SPARQL endpoint with URL https://www.bgee.org/sparql/ ?
if self.auto_index && search_results.results.is_empty() && endpoint_url.is_some() {
tracing::info!(
"๐ No SPARQL query examples or schema found for endpoint {endpoint_url:?}. Indexing it"
);
let new_endpoint_conf = SparqlEndpointConfig {
label: endpoint_url.clone().unwrap(),
description: "".to_string(),
endpoint_url: endpoint_url.clone().unwrap(),
void_file: None,
examples_file: None,
homepage_url: None,
ontology: None,
};
let _ = init_endpoint(&new_endpoint_conf, &self.db, true).await;
// TODO: update schemas_map?
// Perform a new search after indexing
search_results = self.search_docs(&search_query, &endpoint_url, true).await;
}
Ok(CallToolResult::structured(
serde_json::to_value(&search_results).map_err(|e| {
McpError::internal_error(
"Failed to serialize search results",
Some(json!({"error": e.to_string()})),
)
})?,
))
}
#[tool(
description = "Get information about the service and resources available at the SIB Swiss Institute of Bioinformatics."
)]
async fn get_resources_info(
&self,
Parameters(InputQuestion {
search_query,
endpoint_url,
}): Parameters<InputQuestion>,
) -> Result<CallToolResult, McpError> {
// Search for information about the SIB resources
let mut search_filters = HashMap::new();
search_filters.insert("doc_type", "General information");
if let Some(ref endpoint) = endpoint_url {
search_filters.insert("endpoint_url", endpoint.as_str());
}
match self
.db
.search_similar(&search_query, 5, Some(&search_filters))
.await
{
Ok(results) => {
let search_results = McpSearchResult {
total_found: results.len(),
query: search_query.clone(),
results,
};
// Ok(CallToolResult::success(vec![Content::text(json_content)]))
Ok(CallToolResult::structured(
serde_json::to_value(&search_results).map_err(|e| {
McpError::internal_error(
"Failed to serialize search results",
Some(json!({"error": e.to_string()})),
)
})?,
))
// TODO: return JSON Value through structured_content
}
Err(e) => {
tracing::error!("Database search error: {:?}", e);
Err(McpError::internal_error(
"Failed to search for tools".to_string(),
Some(json!({"error": format!("{:?}", e)})),
))
}
}
}
#[tool(description = "Execute a SPARQL query against a given SPARQL endpoint.")]
async fn execute_sparql_query(
&self,
Parameters(InputSparqlQuery {
sparql_query,
endpoint_url,
}): Parameters<InputSparqlQuery>,
) -> Result<CallToolResult, McpError> {
// tracing::debug!("Validated SPARQL query, errors: {errors:?}");
validate_sparql(&endpoint_url, &sparql_query, &self.schemas_map)
.await
.unwrap();
let sparql_client = SparqlClient::builder().build().unwrap();
match sparql_client
.query_endpoint(&endpoint_url, &sparql_query)
.await
{
Ok(res) => {
// Try to parse as SparqlResults first
match serde_json::from_str::<SparqlResults>(&res) {
Ok(sparql_results) => {
// Successfully parsed as SparqlResults, return structured
if sparql_results.results.bindings.is_empty() {
tracing::warn!("SPARQL query returned empty results: {sparql_query}");
let mut warning_msg =
"The query did not return any results.".to_string();
// Run validation and add schema mismatch information if available
match validate_sparql(&endpoint_url, &sparql_query, &self.schemas_map)
.await
{
Ok(errors) if !errors.is_empty() => {
warning_msg += "\nHere are mismatches with the known schema that might help fix the query:\n";
for error in errors.iter() {
warning_msg.push_str(&format!("- {error}\n"));
}
}
Ok(_) => {
// tracing::info!("No schema validation errors found, query may be syntactically correct but returned no data");
}
Err(e) => {
warning_msg +=
&format!("\n\nError during SPARQL validation: {e}");
}
}
Ok(CallToolResult::success(vec![Content::text(warning_msg)]))
// Ok(CallToolResult::error(vec![Content::text(warning_msg)]))
} else {
Ok(CallToolResult::structured(
serde_json::to_value(&sparql_results).map_err(|e| {
McpError::internal_error(
"Failed to serialize SPARQL query results",
Some(json!({"error": e.to_string()})),
)
})?,
))
}
}
Err(_) => {
// Failed to parse as SparqlResults, return raw text content for turtle
Ok(CallToolResult::success(vec![Content::text(res)]))
}
}
}
// Failed to execute the SPARQL query request
Err(e) => {
tracing::error!("Error executing SPARQL query!!! {:?}", e);
Ok(CallToolResult::error(vec![Content::text(e.to_string())]))
}
}
}
}
#[tool_handler]
impl ServerHandler for SparqlTools {
    /// Describes the server: protocol version, enabled capabilities (prompts,
    /// resources, tools), build information and usage instructions.
    fn get_info(&self) -> ServerInfo {
        let capabilities = ServerCapabilities::builder()
            .enable_prompts()
            .enable_resources()
            .enable_tools()
            .build();
        let instructions = String::from(
            "This server provides search tools for SPARQL queries using semantic similarity with FastEmbed and LanceDB.",
        );
        ServerInfo {
            protocol_version: ProtocolVersion::V_2024_11_05,
            capabilities,
            server_info: Implementation::from_build_env(),
            instructions: Some(instructions),
        }
    }

    /// Lists the single static resource exposed by this server.
    async fn list_resources(
        &self,
        _request: Option<PaginatedRequestParam>,
        _: RequestContext<RoleServer>,
    ) -> Result<ListResourcesResult, McpError> {
        let database_resource =
            self._create_resource_text("meta://sparql-database", "sparql_docs");
        Ok(ListResourcesResult {
            resources: vec![database_resource],
            next_cursor: None,
        })
    }

    /// Returns the text of the `meta://sparql-database` resource.
    ///
    /// # Errors
    /// Any other URI yields a `resource_not_found` error carrying that URI.
    async fn read_resource(
        &self,
        ReadResourceRequestParam { uri }: ReadResourceRequestParam,
        _: RequestContext<RoleServer>,
    ) -> Result<ReadResourceResult, McpError> {
        if uri != "meta://sparql-database" {
            return Err(McpError::resource_not_found(
                "resource_not_found",
                Some(json!({
                    "uri": uri
                })),
            ));
        }
        let description = ResourceContents::text(
            "LanceDB with SPARQL query examples using FastEmbed embeddings",
            uri,
        );
        Ok(ReadResourceResult {
            contents: vec![description],
        })
    }

    /// Lists the single example prompt and its required `search_query` argument.
    async fn list_prompts(
        &self,
        _request: Option<PaginatedRequestParam>,
        _: RequestContext<RoleServer>,
    ) -> Result<ListPromptsResult, McpError> {
        let search_query_argument = PromptArgument {
            name: "search_query".to_string(),
            description: Some("Search query to find data".to_string()),
            required: Some(true),
        };
        let example_prompt = Prompt::new(
            "example_prompt",
            Some("This is an example prompt that takes one required argument, search_query"),
            Some(vec![search_query_argument]),
        );
        Ok(ListPromptsResult {
            next_cursor: None,
            prompts: vec![example_prompt],
        })
    }

    /// Renders `example_prompt` as a single user message built from the
    /// `search_query` argument.
    ///
    /// # Errors
    /// Returns `invalid_params` when the prompt name is unknown or when
    /// `search_query` is missing or not a string.
    async fn get_prompt(
        &self,
        GetPromptRequestParam { name, arguments }: GetPromptRequestParam,
        _: RequestContext<RoleServer>,
    ) -> Result<GetPromptResult, McpError> {
        if name != "example_prompt" {
            return Err(McpError::invalid_params("prompt not found", None));
        }
        let Some(search_query) = arguments
            .as_ref()
            .and_then(|args| args.get("search_query"))
            .and_then(|value| value.as_str())
        else {
            return Err(McpError::invalid_params(
                "No message provided to example_prompt",
                None,
            ));
        };
        let user_message = PromptMessage {
            role: PromptMessageRole::User,
            content: PromptMessageContent::text(format!(
                "I am looking for data about {search_query}"
            )),
        };
        Ok(GetPromptResult {
            description: None,
            messages: vec![user_message],
        })
    }

    /// Reports that no resource template is available.
    async fn list_resource_templates(
        &self,
        _request: Option<PaginatedRequestParam>,
        _: RequestContext<RoleServer>,
    ) -> Result<ListResourceTemplatesResult, McpError> {
        let no_templates = Vec::new();
        Ok(ListResourceTemplatesResult {
            next_cursor: None,
            resource_templates: no_templates,
        })
    }

    /// Answers the initialize request with `get_info`, logging the HTTP headers
    /// and URI first when the request context carries HTTP request parts.
    async fn initialize(
        &self,
        _request: InitializeRequestParam,
        context: RequestContext<RoleServer>,
    ) -> Result<InitializeResult, McpError> {
        let http_parts = context.extensions.get::<axum::http::request::Parts>();
        if let Some(http_request_part) = http_parts {
            // The binding names double as the log field names, keep them as they are.
            let initialize_headers = &http_request_part.headers;
            let initialize_uri = &http_request_part.uri;
            tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
        }
        let server_info = self.get_info();
        Ok(server_info)
    }
}