"""Database tools for MCP server"""
import json
import logging
import re
from typing import Any, Dict, List, Optional, Sequence

from mcp.types import Tool, TextContent, ImageContent, EmbeddedResource

from .database import PostgreSQLManager, DatabaseError
from .config import ServerConfig
from .query_optimizer import QueryValidator

# Module logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
class DatabaseTools:
    """Expose PostgreSQL operations to an MCP client as tools.

    ``get_tools`` advertises each tool's name, description and JSON input
    schema; ``handle_tool_call`` routes a call by name to the matching
    ``_handle_*`` coroutine.  Every handler returns a list of ``TextContent``
    items, and errors are reported back as text instead of being raised.
    """

    # Whole-word match for a row-limiting clause, so identifiers such as
    # ``rate_limit`` or ``limit_count`` do not suppress the automatic LIMIT.
    # NOTE: a LIMIT inside a subquery still counts as "already limited".
    _ROW_LIMIT_RE = re.compile(r"\b(?:LIMIT|FETCH)\b")
    # ``SELECT ... INTO new_table`` creates a table in PostgreSQL even though
    # the statement starts with SELECT.
    _SELECT_INTO_RE = re.compile(r"\bINTO\b")

    def __init__(self, db_manager: PostgreSQLManager, config: ServerConfig):
        """Store collaborators and build the query validator.

        Args:
            db_manager: Manager used to run queries and read catalog metadata.
            config: Server settings; the handlers read ``max_rows``,
                ``log_queries`` and ``allowed_schemas`` from it.
        """
        self.db_manager = db_manager
        self.config = config
        self.query_validator = QueryValidator(db_manager)

    def get_tools(self) -> List[Tool]:
        """Return the tool definitions advertised to MCP clients."""
        return [
            Tool(
                name="query",
                description="Execute a SQL query on the PostgreSQL database",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "sql": {
                            "type": "string",
                            "description": "SQL query to execute"
                        },
                        "params": {
                            "type": "array",
                            "description": "Parameters for the SQL query",
                            "items": {"type": "string"},
                            "default": []
                        },
                        # Read by _handle_query for validation; previously
                        # accepted but not advertised.
                        "schema": {
                            "type": "string",
                            "description": "Database schema name used for query validation",
                            "default": "public"
                        }
                    },
                    "required": ["sql"]
                }
            ),
            Tool(
                name="list_tables",
                description="List all tables in a database schema",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "schema": {
                            "type": "string",
                            "description": "Database schema name",
                            "default": "public"
                        }
                    }
                }
            ),
            Tool(
                name="describe_table",
                description="Get detailed information about a table's columns and structure",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "table_name": {
                            "type": "string",
                            "description": "Name of the table to describe"
                        },
                        "schema": {
                            "type": "string",
                            "description": "Database schema name",
                            "default": "public"
                        }
                    },
                    "required": ["table_name"]
                }
            ),
            Tool(
                name="list_schemas",
                description="List all available database schemas",
                inputSchema={
                    "type": "object",
                    "properties": {}
                }
            ),
            Tool(
                name="test_connection",
                description="Test database connection and get server information",
                inputSchema={
                    "type": "object",
                    "properties": {}
                }
            ),
            Tool(
                name="validate_query",
                description="Validate and analyze a SQL query for security, performance, and optimization opportunities",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "sql": {
                            "type": "string",
                            "description": "SQL query to validate and analyze"
                        },
                        "schema": {
                            "type": "string",
                            "description": "Database schema name for table/column validation",
                            "default": "public"
                        }
                    },
                    "required": ["sql"]
                }
            )
        ]

    async def handle_tool_call(self, name: str, arguments: Dict[str, Any]) -> List[TextContent]:
        """Dispatch a tool call to its handler and return the handler's output.

        Args:
            name: Tool name as advertised by ``get_tools``.
            arguments: Decoded tool arguments; ``None`` is treated as ``{}``.

        Returns:
            A list of ``TextContent`` items.  Unknown tool names, database
            errors and unexpected exceptions are all reported as text.
        """
        handlers = {
            "query": self._handle_query,
            "list_tables": self._handle_list_tables,
            "describe_table": self._handle_describe_table,
            "list_schemas": self._handle_list_schemas,
            "test_connection": self._handle_test_connection,
            "validate_query": self._handle_validate_query,
        }
        handler = handlers.get(name)
        if handler is None:
            return [TextContent(type="text", text=f"Unknown tool: {name}")]
        try:
            # Handlers call ``arguments.get``; guard against a null payload.
            return await handler(arguments or {})
        except DatabaseError as e:
            logger.error("Database error in tool %s: %s", name, e)
            return [TextContent(type="text", text=f"Database error: {e}")]
        except Exception as e:
            # Include the traceback: this branch means a bug, not bad input.
            logger.exception("Unexpected error in tool %s", name)
            return [TextContent(type="text", text=f"Error: {e}")]

    @classmethod
    def _read_only_violation(cls, sql_upper: str) -> Optional[str]:
        """Return an error message if the query is not one read-only SELECT, else ``None``.

        Args:
            sql_upper: The upper-cased query with trailing semicolons removed.
        """
        # For safety, only allow SELECT statements by default.
        if not sql_upper.startswith("SELECT"):
            return (
                "Error: Only SELECT queries are allowed for safety. "
                "Use a dedicated database admin tool for data modifications."
            )
        # The prefix check only inspects the first statement, so stacked input
        # such as "SELECT 1; DROP TABLE t" must be rejected separately.  This is
        # deliberately conservative: a ";" inside a string literal is rejected too.
        if ";" in sql_upper:
            return "Error: Only a single SQL statement is allowed per query."
        if cls._SELECT_INTO_RE.search(sql_upper):
            return (
                "Error: SELECT ... INTO creates a new table and is not allowed. "
                "Use a dedicated database admin tool for data modifications."
            )
        return None

    def _schema_access_error(self, schema: str) -> Optional[List[TextContent]]:
        """Return an error response if ``schema`` is not permitted, else ``None``.

        An empty/unset ``config.allowed_schemas`` permits every schema.
        """
        if self.config.allowed_schemas and schema not in self.config.allowed_schemas:
            return [TextContent(
                type="text",
                text=f"Error: Access to schema '{schema}' is not allowed"
            )]
        return None

    async def _handle_query(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Run a single read-only SELECT and return its rows as JSON text.

        Steps:
        1. Reject anything that is not a single SELECT (see ``_read_only_violation``).
        2. Analyse the query with ``QueryValidator``. Critical security issues
           block execution; other findings become feedback.
        3. Append ``LIMIT config.max_rows`` when no row limit is present.
        4. Execute the query and format the result.

        Args:
            arguments: ``sql`` (required), optional ``params`` (bind parameters,
                default ``[]``) and optional ``schema`` passed to the validator
                (default ``"public"``).

        Returns:
            A single ``TextContent`` with the results or an error message.
        """
        sql = (arguments.get("sql") or "").strip()
        params = arguments.get("params") or []
        # Drop trailing terminators: "SELECT 1;" would otherwise become
        # "SELECT 1; LIMIT n", which is a syntax error.
        while sql.endswith(";"):
            sql = sql[:-1].rstrip()
        if not sql:
            return [TextContent(type="text", text="Error: SQL query is required")]

        sql_upper = sql.upper()
        violation = self._read_only_violation(sql_upper)
        if violation:
            return [TextContent(type="text", text=violation)]

        try:
            schema = arguments.get("schema", "public")
            analysis = await self.query_validator.validate_query(sql, schema)
            validation_feedback: List[str] = []

            if analysis.security_issues:
                critical_issues = [
                    issue for issue in analysis.security_issues
                    if issue.level.value == "critical"
                ]
                if critical_issues:
                    return [TextContent(
                        type="text",
                        text="❌ Query blocked due to security issues:\n\n"
                        + "\n".join(f"• {issue.message}" for issue in critical_issues),
                    )]
                # Only non-critical issues remain at this point.
                validation_feedback.append("🔒 Security warnings detected - review query carefully")

            if analysis.estimated_complexity > 7:
                validation_feedback.append(
                    f"⚡ High complexity query (complexity: {analysis.estimated_complexity}/10)"
                )

            if not self._ROW_LIMIT_RE.search(sql_upper):
                # New line so a trailing "-- comment" cannot swallow the LIMIT.
                sql = f"{sql}\nLIMIT {self.config.max_rows}"

            if self.config.log_queries:
                logger.info("Executing query: %s...", sql[:200])

            results = await self.db_manager.execute_query(sql, params)

            if results:
                # default=str renders non-JSON column types (dates, decimals, ...) as text.
                formatted_results = json.dumps(results, indent=2, default=str)
                response_text = f"Query Results ({len(results)} rows):\n\n```json\n{formatted_results}\n```"
            else:
                response_text = "Query executed successfully. No results returned."

            # Same analysis section whether or not rows came back.
            if validation_feedback:
                response_text += "\n\n📋 Query Analysis:\n" + "\n".join(validation_feedback)
            if analysis.optimization_suggestions:
                response_text += "\n\n💡 Quick optimization tips:\n• " + "\n• ".join(
                    analysis.optimization_suggestions[:3]
                )
            return [TextContent(type="text", text=response_text)]
        except Exception as e:
            logger.warning("Query failed: %s", e)
            return [TextContent(type="text", text=f"Query error: {e}")]

    async def _handle_list_tables(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """List a schema's tables as ``- name (type)`` lines.

        Args:
            arguments: optional ``schema`` (default ``"public"``), subject to
                the ``allowed_schemas`` restriction.
        """
        schema = arguments.get("schema", "public")
        denied = self._schema_access_error(schema)
        if denied:
            return denied
        try:
            tables = await self.db_manager.get_table_info(schema)
            if not tables:
                return [TextContent(type="text", text=f"No tables found in schema '{schema}'")]
            table_info = [f"- {table['table_name']} ({table['table_type']})" for table in tables]
            result = f"Tables in schema '{schema}':\n\n" + "\n".join(table_info)
            return [TextContent(type="text", text=result)]
        except Exception as e:
            return [TextContent(type="text", text=f"Error listing tables: {e}")]

    async def _handle_describe_table(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Describe a table's columns as ``- name: type NULL|NOT NULL [DEFAULT x]`` lines.

        Args:
            arguments: ``table_name`` (required) and optional ``schema``
                (default ``"public"``), subject to ``allowed_schemas``.
        """
        table_name = arguments.get("table_name")
        schema = arguments.get("schema", "public")
        if not table_name:
            return [TextContent(type="text", text="Error: table_name is required")]
        denied = self._schema_access_error(schema)
        if denied:
            return denied
        try:
            columns = await self.db_manager.get_column_info(table_name, schema)
            if not columns:
                return [TextContent(
                    type="text",
                    text=f"Table '{schema}.{table_name}' not found or has no columns"
                )]
            column_info = []
            for col in columns:
                nullable = "NULL" if col['is_nullable'] == 'YES' else "NOT NULL"
                default = f" DEFAULT {col['column_default']}" if col['column_default'] else ""
                column_info.append(f"- {col['column_name']}: {col['data_type']} {nullable}{default}")
            result = f"Table '{schema}.{table_name}' structure:\n\n" + "\n".join(column_info)
            return [TextContent(type="text", text=result)]
        except Exception as e:
            return [TextContent(type="text", text=f"Error describing table: {e}")]

    async def _handle_list_schemas(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """List the schemas visible to the client, honouring ``allowed_schemas``."""
        try:
            schemas = await self.db_manager.get_schemas()
            # Filter before the emptiness check so a fully filtered-out list is
            # reported as "No schemas found" rather than an empty listing.
            if schemas and self.config.allowed_schemas:
                schemas = [s for s in schemas if s in self.config.allowed_schemas]
            if not schemas:
                return [TextContent(type="text", text="No schemas found")]
            result = "Available schemas:\n\n" + "\n".join(f"- {schema}" for schema in schemas)
            return [TextContent(type="text", text=result)]
        except Exception as e:
            return [TextContent(type="text", text=f"Error listing schemas: {e}")]

    async def _handle_test_connection(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Report database name, user and server version from ``test_connection``."""
        try:
            info = await self.db_manager.test_connection()
            result = (
                "Database Connection Test:\n"
                "✅ Connection successful\n"
                "Server Information:\n"
                f"- Database: {info['database']}\n"
                f"- User: {info['user']}\n"
                f"- PostgreSQL Version: {info['version']}\n"
            )
            return [TextContent(type="text", text=result)]
        except Exception as e:
            return [TextContent(type="text", text=f"❌ Connection test failed: {e}")]

    async def _handle_validate_query(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Analyse a query with ``QueryValidator`` and return its formatted report.

        The query is not executed.

        Args:
            arguments: ``sql`` (required) and optional ``schema`` (default ``"public"``).
        """
        sql = (arguments.get("sql") or "").strip()
        schema = arguments.get("schema", "public")
        if not sql:
            return [TextContent(type="text", text="Error: SQL query is required")]
        try:
            analysis = await self.query_validator.validate_query(sql, schema)
            report = self.query_validator.format_analysis_report(analysis)
            return [TextContent(type="text", text=report)]
        except Exception as e:
            return [TextContent(type="text", text=f"Validation error: {e}")]