chat.rs • 34.7 kB
use axum::extract::State;
use axum::response::sse;
use axum::{
    extract::Json,
    http::{HeaderMap, StatusCode},
    response::{IntoResponse, Response, Sse},
};
use futures_util::StreamExt as _;
use futures_util::stream::Stream;
use llm::ToolCall;
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
use utoipa::ToSchema;

use llm::{
    builder::{FunctionBuilder, LLMBackend, LLMBuilder},
    chat::ChatMessage,
};
use rmcp::{
    ServiceExt,
    model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation},
    transport::StreamableHttpClientTransport,
};

use crate::AppState;
use crate::db::SearchDocument;
use crate::error::{AppError, AppResult};
use crate::mcp::McpSearchResult;
use crate::utils::get_llm_config;

// Map a selection of UniProtKB accession numbers (ACs) to HGNC identifiers and symbols:
// openrouter/moonshotai/kimi-k2:free
// openrouter/anthropic/claude-sonnet-4

/// Maximum number of iterations for the search workflow
const MAX_ITERATIONS: usize = 5;

// If the conversation already contains enough information to answer the user's question, do not call any tools.
// Only call tools if the answer cannot be derived from the provided context.
const SYSTEM_PROMPT_TOOLS: &str = r#"
You are Expasy, an assistant that helps users navigate the resources and databases of the Swiss Institute of Bioinformatics (SIB).
Do not answer general knowledge or personal questions; only answer questions about life science, bioinformatics, or the SIB.
Call tools to gather context to answer the user's question. If the conversation already contains enough information to answer the user's question, do not call any tools and answer the question directly.
If there is already some relevant context, do not call the same tool multiple times with almost the same arguments; try to answer the question instead.
Depending on the user request and the provided context, you may provide general information about the resources available at the SIB, help the user formulate a query to run on a SPARQL endpoint, or execute a previously formulated SPARQL query and communicate its results.
Always derive your answer from the context provided in the conversation; do not use information that is not in the context.
If answering with a query:
- Put the SPARQL query inside a markdown codeblock with the `sparql` language tag, and always add the URL of the endpoint on which the query should be executed in a comment at the start of the query inside the codeblock, starting with `#+ endpoint: ` (always only 1 endpoint).
- Always answer with one query; if the answer lies in different endpoints, provide a federated query. Do not add more codeblocks than necessary.
- Use DISTINCT as much as possible, and consider using LIMIT 100 to avoid timeouts and oversized responses.
- Communicate the query when you communicate its results to the user.
- Briefly explain the query.
"#;
// If a previous search query does not yield relevant results, next time try without providing exact entity IDs or endpoint URLs.
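// Illustrative example (not in the original file) of the answer shape that
// SYSTEM_PROMPT_TOOLS asks the LLM to produce: a single `sparql` codeblock whose
// first comment names the endpoint. The endpoint URL and query below are only
// placeholders:
//
//   ```sparql
//   #+ endpoint: https://sparql.uniprot.org/sparql
//   SELECT DISTINCT ?s WHERE { ?s ?p ?o } LIMIT 100
//   ```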
/// API-compatible search hit structure
#[derive(Debug, Deserialize, Serialize, ToSchema, Clone)]
pub struct SearchHit {
    pub question: String,
    pub answer: String,
    pub endpoint_url: String,
    pub doc_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub score: Option<f64>,
}

impl From<SearchDocument> for SearchHit {
    fn from(record: SearchDocument) -> Self {
        Self {
            question: record.question,
            answer: record.answer,
            endpoint_url: record.endpoint_url,
            doc_type: record.doc_type,
            score: None, // Score will be set later by the LLM
        }
    }
}

#[derive(Debug, Deserialize, Serialize, ToSchema)]
pub struct ChatInput {
    pub messages: Vec<ApiChatMessage>,
    // #[schema(example = "openai/gpt-4.1-nano")]
    // openrouter/moonshotai/kimi-k2:free
    // openrouter/anthropic/claude-sonnet-4
    #[schema(example = "mistralai/mistral-small-latest")]
    pub model: String,
    #[serde(default)]
    pub stream: bool,
}

// /// Token usage and cost information for LLM calls
// #[derive(Debug, Deserialize, Serialize, ToSchema, Clone)]
// pub struct UsageInfo {
//     pub prompt_tokens: u32,
//     pub completion_tokens: u32,
//     pub total_tokens: u32,
//     pub total_cost_usd: f64,
// }

/// Wrapper around `llm::ChatMessage` that implements `ToSchema` for API documentation
#[derive(Debug, Deserialize, Serialize, ToSchema, Clone)]
pub struct ApiChatMessage {
    #[schema(example = "user")]
    pub role: String,
    #[schema(example = "What is the HGNC symbol for the P68871 protein?")]
    pub content: String,
}

impl ApiChatMessage {
    /// Convert to `llm::ChatMessage` for use with the LLM client
    pub fn to_chat_message(&self) -> ChatMessage {
        match self.role.as_str() {
            "user" => ChatMessage::user().content(&self.content).build(),
            "assistant" => ChatMessage::assistant().content(&self.content).build(),
            _ => ChatMessage::assistant().content(&self.content).build(), // Default to assistant
        }
    }

    // /// Create from role and content
    // pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
    //     Self {
    //         role: role.into(),
    //         content: content.into(),
    //     }
    // }

    // /// Create a user message
    // pub fn user(content: impl Into<String>) -> Self {
    //     Self::new("user", content)
    // }
}

/// OpenAI-compatible streaming response chunk
#[derive(Debug, Serialize, ToSchema)]
struct StreamChunk {
    id: String,
    object: String,
    created: u64,
    model: String,
    choices: Vec<StreamChoice>,
}

#[derive(Debug, Serialize, ToSchema)]
struct StreamChoice {
    index: u32,
    delta: StreamDelta,
    finish_reason: Option<String>,
}

#[derive(Debug, Serialize, ToSchema)]
struct StreamDelta {
    role: Option<String>,
    content: Option<String>,
    function_call: Option<serde_json::Value>,
}

/// LLM response format
#[derive(Debug, Deserialize, Serialize, ToSchema)]
struct LLMStructuredOutput {
    summary: String,
    queries: Vec<LLMQuery>,
}

#[derive(Debug, Deserialize, Serialize, ToSchema)]
struct LLMQuery {
    question: String,
    score: f64,
}

/// Workflow manager for handling search operations with fragmented steps
pub struct SearchWorkflow {
    pub mcp_client:
        rmcp::service::RunningService<rmcp::RoleClient, rmcp::model::InitializeRequestParam>,
    pub llm_backend: LLMBackend,
    pub llm_api_key: String,
    pub llm_model: String,
    pub msg_id: String,
    pub created: u64,
}

impl SearchWorkflow {
    /// Initialize a new search workflow
    pub async fn new(model: String, bind_address: String) -> AppResult<Self> {
        let created = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
        let msg_id = format!("chatcmpl-{}", uuid::Uuid::new_v4());
        // Connect to MCP server
        let transport =
            StreamableHttpClientTransport::from_uri(format!("http://{bind_address}/mcp"));
        let client_info = ClientInfo {
            protocol_version: Default::default(),
            capabilities: ClientCapabilities::default(),
            client_info: Implementation {
                name: "MCP streamable HTTP client".to_string(),
                version: "0.0.1".to_string(),
            },
        };
        let client = match client_info.serve(transport).await {
            Ok(client) => client,
            Err(e) => {
                tracing::error!("client error: {:?}", e);
                return Err(AppError::Llm(format!(
                    "MCP client initialization failed: {e}"
                )));
            }
        };
        // Get LLM configuration
        let (llm_backend, llm_api_key, llm_model) = get_llm_config(&model).map_err(AppError::Llm)?;
        Ok(Self {
            mcp_client: client,
            llm_backend,
            llm_api_key,
            llm_model,
            msg_id,
            created,
        })
    }

    /// Streaming version: yields SSE events for each step (tool call args, results, or message tokens)
    pub async fn execute_tool_calls(
        &self,
        messages: &[ApiChatMessage],
    ) -> impl Stream<Item = AppResult<sse::Event>> {
        use async_stream::stream;
        stream! {
            // Convert messages to LLM `ChatMessage` format
            let mut chat_messages: Vec<ChatMessage> =
                messages.iter().map(|msg| msg.to_chat_message()).collect();
            let question = messages
                .last()
                .map(|msg| msg.content.as_str())
                .unwrap_or("");
            // Configure LLM client with dynamic tools from MCP
            let mut llm_builder = LLMBuilder::new()
                .backend(self.llm_backend.clone())
                .api_key(&self.llm_api_key)
                .model(&self.llm_model)
                .max_tokens(1024)
                .normalize_response(true)
                .function(FunctionBuilder::new("stop".to_string())
                    .description(format!("If a satisfactory response has been given to the question \"{question}\", stop the tool calls"))
                    // .json_schema(serde_json::json!({
                    //     "type": "object",
                    //     "properties": {},
                    //     "additionalProperties": false,
                    // }))
                )
                .system(SYSTEM_PROMPT_TOOLS);
            // Convert MCP tools to LLM functions and add them to the llm builder
            let tools_list = match self.mcp_client.list_tools(Default::default()).await {
                Ok(t) => t,
                Err(e) => {
                    yield Err(AppError::Llm(format!("MCP list_tools failed: {e}")));
                    return;
                }
            };
            for tool in &tools_list.tools {
                let schema_value = serde_json::Value::Object(tool.input_schema.as_ref().clone());
                // tracing::debug!("Adding tool as function: {:?}", schema_value);
                let function = FunctionBuilder::new(tool.name.to_string())
                    .description(tool.description.as_deref().unwrap_or(""))
                    .json_schema(schema_value);
                llm_builder = llm_builder.function(function);
            }
            let llm = match llm_builder.build() {
                Ok(l) => l,
                Err(e) => {
                    yield Err(AppError::Llm(format!("Failed to build LLM client: {e}")));
                    return;
                }
            };
            let mut iteration_count = 0;
            let mut resp_msg = String::new();
            let mut stop_requested = false;
            // Iterate over tool calls, until no more tool call (getting a regular LLM message)
            loop {
                iteration_count += 1;
                if iteration_count > MAX_ITERATIONS {
                    tracing::warn!("Maximum tool call iterations ({}) reached, breaking loop", MAX_ITERATIONS);
                    break;
                }
                // tracing::debug!("Iteration {iteration_count}: {chat_messages:?}");
                tracing::debug!("Calling LLM for iteration {iteration_count}, {} messages so far {chat_messages:?}", chat_messages.len());
                match llm.chat_stream_struct(&chat_messages).await {
                    Ok(mut stream) => {
                        let mut tool_msg = String::new();
                        // NOTE: groq and cohere do not return usage in stream responses
                        // let mut usage_data = None;
                        while let Some(chunk_result) = stream.next().await {
                            match chunk_result {
                                Ok(stream_response) => {
                                    // tracing::debug!("Stream response: {:?}", stream_response);
                                    if let Some(choice) = stream_response.choices.first() {
                                        // If tool calls are present, handle them
                                        if let Some(tc) =
                                            &choice.delta.tool_calls {
                                            yield self.create_sse_event("tool_call_requested", tc);
                                            for call in tc {
                                                if call.function.name == "stop" {
                                                    tracing::debug!("Stop tool calls requested by LLM");
                                                    stop_requested = true;
                                                    break;
                                                }
                                                tracing::debug!("Calling tool {:?}", call.function.name);
                                                // let arguments = match serde_json::from_str::<serde_json::Value>(&call.function.arguments) {
                                                //     Ok(value) => value.as_object().cloned(),
                                                //     Err(_) => None,
                                                // };
                                                // Call MCP tools
                                                let tool_results = match self.mcp_client.call_tool(CallToolRequestParam {
                                                    name: call.function.name.clone().into(),
                                                    arguments: match serde_json::from_str::<serde_json::Value>(&call.function.arguments) {
                                                        Ok(value) => value.as_object().cloned(),
                                                        Err(_) => None,
                                                    },
                                                }).await {
                                                    Ok(res) => res,
                                                    Err(e) => {
                                                        // Send error to LLM so it can handle the failure and continue
                                                        yield self.create_sse_event("tool_call_results", e.to_string());
                                                        tool_msg = format!(
                                                            "Tool `{}` failed with error: {e}",
                                                            call.function.name
                                                        );
                                                        continue;
                                                    }
                                                };
                                                if let Some(structured) = &tool_results.structured_content {
                                                    // Handle structured content if present
                                                    yield self.create_sse_event("tool_call_results", structured);
                                                    if let Ok(search_results) = serde_json::from_value::<McpSearchResult>(structured.clone()) {
                                                        tool_msg = format_context(search_results, question, call);
                                                    };
                                                    tracing::debug!("tool call results {:?}", structured);
                                                    if tool_msg.is_empty() {
                                                        tool_msg = format!(
                                                            "Tool `{}` output to answer question '{question}':\n\n{}",
                                                            call.function.name,
                                                            serde_json::to_string_pretty(structured).unwrap_or_else(|_| structured.to_string())
                                                        );
                                                    }
                                                } else {
                                                    // Fallback: plain content
                                                    let plain_content = tool_results
                                                        .content
                                                        .iter()
                                                        .filter_map(|annotated| match &annotated.raw {
                                                            rmcp::model::RawContent::Text(text_content) => {
                                                                Some(text_content.text.as_str())
                                                            }
                                                            _ => None,
                                                        })
                                                        .collect::<Vec<_>>()
                                                        .join(" ");
                                                    yield self.create_sse_event("tool_call_results", &plain_content);
                                                    tool_msg = format!(
                                                        "Tool `{}` output to answer question '{question}':\n\n{}",
                                                        call.function.name, plain_content
                                                    );
                                                }
                                            }
                                        }
                                    }
                                    // tracing::debug!("Stream response chunk: {:?}", stream_response);
                                    if let Some(choice) = stream_response.choices.first() {
                                        // If direct message content, stream it token by token
                                        if let Some(content) = &choice.delta.content {
                                            // if content.trim().is_empty() {
                                            //     continue;
                                            // }
                                            // tracing::debug!("Streaming content chunk: {}", content);
                                            resp_msg.push_str(content);
                                            yield self.create_stream_chunk(Some(content.to_string()), None);
                                        }
                                    }
                                    // if let Some(usage) = stream_response.usage {
                                    //     usage_data = Some(usage);
                                    // }
                                }
                                Err(e) => {
                                    tracing::error!("LLM chat_stream_struct chunk error: {e}");
                                    yield Err(AppError::Llm(format!("Stream error: {e}")));
                                    return;
                                }
                            }
                        }
                        tracing::debug!("Streaming tool call done, {} messages so far", chat_messages.len());
                        if stop_requested {
                            tracing::debug!("Question solved, breaking loop before validation");
                            break;
                        }
                        if !resp_msg.is_empty() {
                            chat_messages.push(ChatMessage::assistant().content(&resp_msg).build());
                            resp_msg.clear();
                        }
                        if !tool_msg.is_empty() {
                            chat_messages.push(ChatMessage::user().content(&tool_msg).build());
                            // chat_messages.push(ChatMessage::tool().content(&tool_msg).build());
                            tool_msg.clear();
                        }
                    }
                    Err(e) => {
                        tracing::error!("LLM chat_stream_struct error: {e}");
                        yield Err(AppError::Llm(format!("Stream error: {e}")));
                        return;
                    }
                }
            }
        }
    }

    /// Create SSE event for streaming responses
    pub fn create_sse_event(
        &self,
        event_type: &str,
        data: impl Serialize,
    ) -> AppResult<sse::Event> {
        Ok(sse::Event::default()
            .event(event_type)
            .data(serde_json::to_string(&data)?))
    }
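    // Illustrative example (not in the original file): with axum's SSE encoder, an
    // event built by create_sse_event("tool_call_requested", tc) reaches the client
    // roughly as the following frame (payload abbreviated; the exact ToolCall JSON
    // shape depends on the `llm` crate):
    //
    //   event: tool_call_requested
    //   data: [{"function":{"name":"...","arguments":"{...}"}}]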
    /// Create a streaming chunk that is roughly OpenAI-compatible
    pub fn create_stream_chunk(
        &self,
        content: Option<String>,
        finish_reason: Option<String>,
    ) -> AppResult<sse::Event> {
        let chunk = StreamChunk {
            id: self.msg_id.clone(),
            object: "chat.completion.chunk".to_string(),
            created: self.created,
            model: self.llm_model.clone(),
            choices: vec![StreamChoice {
                index: 0,
                delta: StreamDelta {
                    role: if content.is_some() {
                        Some("assistant".to_string())
                    } else {
                        None
                    },
                    content,
                    function_call: None,
                },
                finish_reason,
            }],
        };
        Ok(sse::Event::default()
            .event("message")
            .data(serde_json::to_string(&chunk)?))
    }
}

/// Chat with the MCP server tools to help build and execute SPARQL queries
#[utoipa::path(
    post,
    path = "/chat",
    request_body(content = ChatInput, description = "List of messages in the conversation"),
    responses(
        (status = 200, description = "Search results (SSE stream)", content_type = "text/event-stream", body = StreamChunk),
        (status = 401, description = "Unauthorized"),
        (status = 500, description = "Internal server error")
    ),
)]
pub async fn chat_handler(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(resp): Json<ChatInput>,
) -> Response {
    if resp.stream {
        streaming_chat_handler(headers, resp, state)
            .await
            .into_response()
    } else {
        // regular_search_handler(headers, resp, state)
        //     .await
        //     .into_response()
        println!("Regular search handler not implemented yet");
        (
            StatusCode::NOT_IMPLEMENTED,
            Json(serde_json::json!({"error": "Regular search handler not implemented"})),
        )
            .into_response()
    }
}

/// Format context as string for the search results
pub fn format_context(results: McpSearchResult, question: &str, call: &ToolCall) -> String {
    let mut context = format!(
        "\n\nTool `{}` called and found {} items relevant to the user question '{question}':\n\n",
        call.function.name, results.total_found,
    );
    for item in results.results.iter() {
        if item.doc_type.to_lowercase() == "sparql endpoints query examples" {
            context.push_str(&format!(
                "{}\n\n```sparql\n#+ endpoint: {}\n{}\n```\n\n",
                item.question, item.endpoint_url, item.answer
            ));
        } else if item.doc_type.to_lowercase() == "sparql endpoints classes schema" {
            context.push_str(&format!(
                "{} ({})\n\n```shex\n{}\n```\n\n",
                item.question, item.endpoint_url, item.answer
            ));
        } else {
            context.push_str(&format!(
                "{} ({})\n\n```{}\n{}\n```\n\n",
                item.question, item.endpoint_url, item.doc_type, item.answer
            ));
        }
    }
    context
}
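// Illustrative example (not in the original file) of the context string that
// format_context builds for a hit whose doc_type is "SPARQL endpoints query
// examples" (tool name, count, question, endpoint, and query are placeholders):
//
//   Tool `search` called and found 2 items relevant to the user question 'How many proteins are in UniProt?':
//
//   How many proteins are in UniProt?
//
//   ```sparql
//   #+ endpoint: https://sparql.example.org/sparql
//   SELECT (COUNT(?protein) AS ?count) WHERE { ... }
//   ```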
/// Streaming search handler for SSE responses
async fn streaming_chat_handler(
    headers: HeaderMap,
    resp: ChatInput,
    state: AppState,
) -> impl IntoResponse {
    // Validate API key only if SEARCH_API_KEY is set
    if let Ok(secret_key) = std::env::var("SEARCH_API_KEY") {
        let auth_header = headers.get("authorization");
        let auth_value = auth_header.and_then(|h| h.to_str().ok()).unwrap_or("");
        if auth_value != secret_key {
            return (
                StatusCode::UNAUTHORIZED,
                Json(serde_json::json!({"error": "unauthorized"})),
            )
                .into_response();
        }
    }
    let stream = create_stream(resp, state);
    Sse::new(stream)
        // .keep_alive(
        //     sse::KeepAlive::new()
        //         .interval(std::time::Duration::from_secs(10)),
        //         // .text("keep-alive-text"),
        // )
        .into_response()
}

/// Helper function to create error SSE events
fn send_error_event(error_message: &str) -> AppResult<sse::Event> {
    Ok(sse::Event::default()
        .event("error")
        .data(serde_json::to_string(&serde_json::json!({
            "error": error_message.to_string(),
        }))?))
    // sse::Event::default().data(
    //     serde_json::to_string(&StreamChunk {
    //         id: msg_id.to_string(),
    //         object: "chat.completion.chunk".to_string(),
    //         created,
    //         model: model.to_string(),
    //         choices: vec![StreamChoice {
    //             index: 0,
    //             delta: StreamDelta {
    //                 role: Some("assistant".to_string()),
    //                 content: Some(error_message.to_string()),
    //                 function_call: None,
    //             },
    //             finish_reason: Some(finish_reason.to_string()),
    //         }],
    //     })?,
    // )
}

/// Create a streaming search response
fn create_stream(resp: ChatInput, state: AppState) -> impl Stream<Item = AppResult<sse::Event>> {
    async_stream::stream! {
        // Initialize the workflow
        let workflow = match SearchWorkflow::new(resp.model, state.bind_address).await {
            Ok(workflow) => workflow,
            Err(e) => {
                tracing::error!("Failed to initialize workflow: {:?}", e);
                yield Ok(send_error_event("Error connecting to search service")?);
                return;
            }
        };
        // Stream all events from execute_tool_calls_streaming
        let conversation_messages = resp.messages;
        let mut event_stream = Box::pin(workflow.execute_tool_calls(&conversation_messages).await);
        while let Some(event_result) = event_stream.next().await {
            match event_result {
                Ok(event) => {
                    // TODO: check if regular response or tool call
                    yield Ok(event)
                },
                Err(e) => {
                    // tracing::error!("Tool call failed: {:?}", e);
                    yield Ok(send_error_event(&e.to_string())?);
                    return;
                }
            }
        }
    }
}

// /// List endpoints supported by the server and update the schemas_map state
// #[utoipa::path(
//     post,
//     path = "/endpoints",
//     // request_body(content = ChatInput, description = "List of messages in the conversation"),
//     responses(
//         (status = 200, description = "List of endpoints supported", content_type = "application/json"),
//         // , body = SchemasMap
//         (status = 401, description = "Unauthorized"),
//         (status = 500, description = "Internal server error")
//     ),
// )]
// pub async fn list_endpoints(
//     State(mut state): State<AppState>,
//     headers: HeaderMap,
// ) -> SchemasMap {
//     let schemas_map: SchemasMap = state.db.clone().get_schemas().await;
//     state.schemas_map = schemas_map.clone();
//     schemas_map
// }

// /// Search handler for non-streaming responses
// async fn regular_search_handler(
//     headers: HeaderMap,
//     resp: SearchInput,
//     state: AppState,
// ) -> impl IntoResponse {
//     // Validate API key only if SEARCH_API_KEY is set
//     if let Ok(secret_key) = std::env::var("SEARCH_API_KEY") {
//         let auth_header = headers.get("authorization");
//         let auth_value = auth_header.and_then(|h| h.to_str().ok()).unwrap_or("");
//         if auth_value != secret_key {
//             return (
//                 StatusCode::UNAUTHORIZED,
//                 Json(serde_json::json!({"error": "unauthorized"})),
//             );
//         }
//     }
//     // Initialize the workflow
//     // Get MCP endpoint from env or fallback
//     let workflow = match SearchWorkflow::new(resp.model, state.bind_address).await {
//         Ok(workflow) => workflow,
//         Err(e) => {
//             tracing::error!("Failed to initialize workflow: {:?}", e);
//             return (
//                 StatusCode::INTERNAL_SERVER_ERROR,
//                 Json(serde_json::json!({"error": "Error connecting to search service"})),
//             );
//         }
//     };
//     // Tool calling loop: continue until no more tools are needed or we get a regular response
//     let mut conversation_messages = resp.messages.clone();
//     let mut accumulated_search_results = SearchResult {
//         total_found: 0,
//         hits: vec![],
//     };
//     let mut iteration_count = 0;
//     const MAX_ITERATIONS: usize = 5;
//     loop {
//         iteration_count += 1;
//         if iteration_count > MAX_ITERATIONS {
//             tracing::warn!(
//                 "Maximum tool call iterations ({}) reached, breaking loop",
//                 MAX_ITERATIONS
//             );
//             break;
//         }
//         // Step 1: Execute tool calls if needed
//         let (resp_txt, tool_calls, search_results) =
//             match workflow.execute_tool_calls(&conversation_messages).await {
//                 Ok(result) => result,
//                 Err(e) => {
//                     tracing::error!("Tool call execution failed: {:?}", e);
//                     return (
//                         StatusCode::INTERNAL_SERVER_ERROR,
//                         Json(serde_json::json!({"error": "Error searching for datasets"})),
//                     );
//                 }
//             };
//         // If no tool calls were made, return regular response
//         if tool_calls.is_none() {
//             return (
//                 StatusCode::OK,
//                 Json(serde_json::json!({
//                     "hits": [],
//                     "summary": resp_txt,
//                 })),
//             );
//         }
//         // If no results found from tools, return error
//         if search_results.total_found == 0 {
//             return (
//                 StatusCode::OK,
//                 Json(serde_json::json!({
//                     "hits": [],
//                     "summary": "No datasets found for your query."
//                 })),
//             );
//         }
//         // Accumulate search results from this iteration
//         accumulated_search_results
//             .hits
//             .extend(search_results.hits.clone());
//         accumulated_search_results.total_found += search_results.total_found;
//         // Add the tool results to conversation for potential follow-up tool calls
//         if let Some(tc) = &tool_calls {
//             for call in tc {
//                 // Format the context for the LLM based on tool call results
//                 let last_message_content = conversation_messages
//                     .last()
//                     .map(|msg| msg.content.as_str())
//                     .unwrap_or("");
//                 let mut formatted_context = format!(
//                     "\n\nTool {} called and found {} SPARQL query examples relevant to the query '{}':\n\n",
//                     call.function.name, search_results.total_found, last_message_content
//                 );
//                 for (i, dataset) in search_results.hits.iter().enumerate() {
//                     formatted_context.push_str(&format!(
//                         "{}. **Question**: {}\n",
//                         i + 1,
//                         dataset.question
//                     ));
//                     formatted_context.push_str(&format!("   **Answer**: {}\n", dataset.answer));
//                     formatted_context
//                         .push_str(&format!("   **Endpoint**: {}\n", dataset.endpoint_url));
//                     formatted_context.push_str(&format!("   **Type**: {}\n\n", dataset.doc_type));
//                 }
//                 // Modify the last message (user message) to include the context
//                 if let Some(last_message) = conversation_messages.last_mut() {
//                     last_message.content.push_str(&formatted_context);
//                 }
//             }
//         }
//         // // Check if we should continue with more tool calls by making another LLM request
//         // let continue_check = match workflow.execute_tool_calls(&conversation_messages).await {
//         //     Ok((next_tool_calls, _)) => next_tool_calls.is_some(),
//         //     Err(_) => false, // If error, assume no more tools needed
//         // };
//         // if !continue_check {
//         //     break;
//         // }
//     }
//     // If no datasets were found after all tool calls, return early
//     if accumulated_search_results.total_found == 0 || accumulated_search_results.hits.is_empty() {
//         return (
//             StatusCode::OK,
//             Json(serde_json::json!({
//                 "hits": [],
//                 "summary": "No datasets found for your query."
//             })),
//         );
//     }
//     // Step 2: Generate summary and scores using LLM with accumulated results
//     let final_response = match workflow
//         .generate_summary_and_scores(&resp.messages, accumulated_search_results)
//         .await
//     {
//         Ok(response) => response,
//         Err(e) => {
//             tracing::error!("LLM processing failed: {:?}", e);
//             // Fallback response without scoring
//             return (
//                 StatusCode::OK,
//                 Json(serde_json::json!({
//                     "hits": [],
//                     "summary": "Found datasets for your query, but could not process relevance scores."
//                 })),
//             );
//         }
//     };
//     (StatusCode::OK, Json(serde_json::json!(final_response)))
// }
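For reference, here is a minimal client sketch (not part of chat.rs) that posts a ChatInput to the /chat endpoint and prints the raw SSE stream. It assumes the server listens on http://localhost:8000 (a placeholder bind address), that reqwest is built with the "json" and "stream" features, and that the authorization header is only needed when SEARCH_API_KEY is set on the server:

use futures_util::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Request body matching the ChatInput schema; model and question reuse the
    // examples from the struct annotations above.
    let body = serde_json::json!({
        "model": "mistralai/mistral-small-latest",
        "stream": true,
        "messages": [
            { "role": "user", "content": "What is the HGNC symbol for the P68871 protein?" }
        ]
    });

    let response = reqwest::Client::new()
        .post("http://localhost:8000/chat")
        // Raw key, compared verbatim by streaming_chat_handler ("my-secret-key" is a placeholder)
        .header("authorization", "my-secret-key")
        .json(&body)
        .send()
        .await?;

    // Print the SSE stream as it arrives; frames carry tool_call_requested,
    // tool_call_results, message (OpenAI-style chunks), or error events.
    let mut stream = response.bytes_stream();
    while let Some(chunk) = stream.next().await {
        print!("{}", String::from_utf8_lossy(&chunk?));
    }
    Ok(())
}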
