"""Query validation and optimization for PostgreSQL MCP Server"""
import re
import logging
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum
import sqlparse
from sqlparse import sql, tokens as T
from .database import PostgreSQLManager
logger = logging.getLogger(__name__)
class ValidationLevel(Enum):
    """Severity attached to a single validation finding.

    Ordered informally from harmless to fatal; only CRITICAL findings
    make a query invalid.
    """

    INFO = "info"          # informational note, no action required
    WARNING = "warning"    # likely problem, query still allowed
    ERROR = "error"        # definite problem
    CRITICAL = "critical"  # query must be rejected
@dataclass
class ValidationResult:
    """A single finding produced by one of the query validation passes."""
    level: ValidationLevel  # severity of this finding
    category: str  # e.g. "syntax", "security", "sql_injection", "performance", "indexing"
    message: str  # human-readable description of the finding
    suggestion: Optional[str] = None  # optional remediation hint shown in reports
    line_number: Optional[int] = None  # 1-based line of the offending text, when known
    column: Optional[int] = None  # column offset of the offending text, when known
@dataclass
class QueryAnalysis:
    """Complete result of analyzing one SQL query."""
    query: str  # the original query text, unmodified
    is_valid: bool  # False when any CRITICAL finding was produced
    estimated_complexity: int  # heuristic complexity score, 1-10 scale
    validation_results: List[ValidationResult]  # all findings from every pass
    optimization_suggestions: List[str]  # free-form tuning hints
    security_issues: List[ValidationResult]  # subset of findings flagged as security-related
    performance_warnings: List[ValidationResult]  # subset flagged as performance-related
class QueryValidator:
    """SQL query validator and optimizer"""

    def __init__(self, db_manager: PostgreSQLManager):
        # Database access used only for metadata lookups (columns, indexes).
        self.db_manager = db_manager
        # Lazy caches keyed by "schema.table"; populated on first lookup
        # and never invalidated for the lifetime of this validator.
        self._table_cache: Dict[str, List[str]] = {}
        self._index_cache: Dict[str, List[Dict[str, Any]]] = {}
async def validate_query(self, query: str, schema: str = "public") -> QueryAnalysis:
    """Validate and analyze a SQL query.

    Runs the syntax, security and performance passes, generates
    optimization suggestions and estimates complexity.

    Args:
        query: Raw SQL text to analyze.
        schema: Schema used for metadata lookups (default "public").

    Returns:
        A QueryAnalysis; ``is_valid`` is False when any CRITICAL
        finding was produced.
    """
    logger.debug(f"Validating query: {query[:100]}...")

    # Parse the query. An unparsable (or empty) query is reported as a
    # hard syntax error with maximum complexity; sqlparse.parse("") yields
    # an empty sequence, so the [0] also raises here and is caught.
    try:
        parsed = sqlparse.parse(query)[0]
    except Exception as e:
        return QueryAnalysis(
            query=query,
            is_valid=False,
            estimated_complexity=10,
            validation_results=[
                ValidationResult(
                    level=ValidationLevel.ERROR,
                    category="syntax",
                    message=f"SQL parsing failed: {e}"
                )
            ],
            optimization_suggestions=[],
            security_issues=[],
            performance_warnings=[]
        )

    validation_results: List[ValidationResult] = []
    security_issues: List[ValidationResult] = []
    performance_warnings: List[ValidationResult] = []

    # Run the individual validation passes.
    validation_results.extend(await self._validate_syntax(parsed))
    validation_results.extend(await self._validate_security(parsed))
    validation_results.extend(await self._validate_performance(parsed, schema))

    # Categorize results.
    # BUG FIX: the "security" category (emitted by _validate_syntax for
    # non-SELECT statements) was previously not routed into
    # security_issues, so the most severe finding was missing from the
    # security section of reports.
    for result in validation_results:
        if result.category in ("security", "sql_injection", "dangerous_operation"):
            security_issues.append(result)
        elif result.category in ("performance", "indexing", "optimization"):
            performance_warnings.append(result)

    # Generate optimization suggestions.
    optimization_suggestions = list(await self._generate_optimizations(parsed, schema))

    # Calculate heuristic complexity (1-10).
    complexity = self._calculate_complexity(parsed)

    # A query is valid as long as no CRITICAL finding was produced.
    is_valid = not any(r.level == ValidationLevel.CRITICAL for r in validation_results)

    return QueryAnalysis(
        query=query,
        is_valid=is_valid,
        estimated_complexity=complexity,
        validation_results=validation_results,
        optimization_suggestions=optimization_suggestions,
        security_issues=security_issues,
        performance_warnings=performance_warnings
    )
async def _validate_syntax(self, parsed: sql.Statement) -> List[ValidationResult]:
    """Validate SQL syntax and structure.

    Checks that the statement is non-empty, is a SELECT (the server is
    read-only), and is terminated with a semicolon.
    """
    results: List[ValidationResult] = []

    # Check for basic SQL structure.
    if not parsed.tokens:
        results.append(ValidationResult(
            level=ValidationLevel.ERROR,
            category="syntax",
            message="Empty or invalid SQL statement"
        ))
        return results

    # Only SELECT statements are allowed (security requirement).
    # BUG FIX: the previous implementation scanned for the first token
    # with ``token.ttype is T.Keyword``, but sqlparse tags SELECT as
    # T.DML (a sub-type of Keyword), so the scan matched a later plain
    # keyword (e.g. FROM) and wrongly rejected valid SELECT queries.
    # Statement.get_type() returns the statement's DML/DDL verb directly.
    if parsed.get_type() != "SELECT":
        results.append(ValidationResult(
            level=ValidationLevel.CRITICAL,
            category="security",
            message="Only SELECT statements are allowed",
            suggestion="Use SELECT queries for data retrieval only"
        ))

    # Check for incomplete statements (length guard skips trivial input).
    query_text = str(parsed).strip()
    if not query_text.endswith(';') and len(query_text) > 10:
        results.append(ValidationResult(
            level=ValidationLevel.INFO,
            category="syntax",
            message="Query should end with semicolon",
            suggestion="Add ';' at the end of your query"
        ))

    return results
async def _validate_security(self, parsed: sql.Statement) -> List[ValidationResult]:
    """Validate query for security issues.

    Flags common injection patterns (WARNING) and dangerous server-side
    functions (CRITICAL). Pattern matching is heuristic and operates on
    the raw query text.
    """
    results: List[ValidationResult] = []
    query_text = str(parsed).upper()

    # Heuristic SQL-injection patterns; each match is a WARNING.
    injection_patterns = [
        (r";\s*(DROP|DELETE|UPDATE|INSERT|ALTER|CREATE)", "Potential SQL injection with dangerous commands"),
        (r"UNION\s+SELECT", "UNION SELECT statements can be used for SQL injection"),
        (r"--", "SQL comments can be used to bypass security"),
        (r"/\*.*\*/", "SQL block comments can hide malicious code"),
        (r"'\s*OR\s*'", "Potential SQL injection with OR condition"),
        (r"'\s*=\s*'", "Potential tautology-based SQL injection"),
    ]
    for pattern, message in injection_patterns:
        if re.search(pattern, query_text, re.IGNORECASE):
            results.append(ValidationResult(
                level=ValidationLevel.WARNING,
                category="sql_injection",
                message=message,
                suggestion="Use parameterized queries instead"
            ))

    # Server-side functions that read/write the filesystem or execute
    # arbitrary commands; any occurrence is CRITICAL.
    dangerous_functions = [
        "pg_read_file", "pg_write_file", "pg_execute", "copy",
        "pg_stat_file", "pg_ls_dir"
    ]
    for func in dangerous_functions:
        # BUG FIX: a plain substring test flagged harmless identifiers
        # that merely contain the name (e.g. a column called
        # "copy_count" matched "copy"). Match on word boundaries instead.
        if re.search(rf"\b{re.escape(func.upper())}\b", query_text):
            results.append(ValidationResult(
                level=ValidationLevel.CRITICAL,
                category="dangerous_operation",
                message=f"Dangerous function '{func}' detected",
                suggestion="Remove dangerous system functions"
            ))

    return results
async def _validate_performance(self, parsed: sql.Statement, schema: str) -> List[ValidationResult]:
    """Validate query for performance issues.

    Runs a fixed set of heuristic text checks against the upper-cased
    query and returns one finding per triggered check, in check order.
    """
    text = str(parsed).upper()

    # (triggered?, level, category, message, suggestion) — one row per check.
    checks = [
        # SELECT * pulls every column.
        ("SELECT *" in text,
         ValidationLevel.WARNING, "performance",
         "SELECT * can be inefficient",
         "Specify only needed columns instead of using SELECT *"),
        # Unbounded result sets (COUNT(...) queries are exempt).
        ("LIMIT" not in text and "COUNT(" not in text,
         ValidationLevel.INFO, "performance",
         "Consider adding LIMIT clause",
         "Add LIMIT clause to prevent large result sets"),
        # JOIN without ON/USING is likely a cartesian product.
        ("JOIN" in text and " ON " not in text and " USING" not in text,
         ValidationLevel.WARNING, "performance",
         "Potential cartesian product detected",
         "Ensure JOIN clauses have proper ON conditions"),
        # A function call on the left side of a WHERE comparison
        # usually defeats index usage.
        (bool(re.findall(r'WHERE.*?(\w+)\s*\([^)]*\)\s*=', text)),
         ValidationLevel.WARNING, "performance",
         "Functions in WHERE clause can prevent index usage",
         "Consider restructuring query to avoid functions on indexed columns"),
        # Leading-wildcard LIKE cannot use a btree index.
        (re.search(r"LIKE\s+['\"]%", text) is not None,
         ValidationLevel.WARNING, "indexing",
         "LIKE with leading wildcard prevents index usage",
         "Avoid leading wildcards in LIKE patterns or consider full-text search"),
    ]

    return [
        ValidationResult(level=level, category=category,
                         message=message, suggestion=hint)
        for triggered, level, category, message, hint in checks
        if triggered
    ]
async def _generate_optimizations(self, parsed: sql.Statement, schema: str) -> List[str]:
    """Generate optimization suggestions"""
    text = str(parsed).upper()

    # Always worth running before any tuning work.
    suggestions = ["Run EXPLAIN ANALYZE to see the actual execution plan"]

    # One index hint per table.column referenced in the WHERE clause.
    for table, columns in self._extract_where_columns(parsed).items():
        suggestions.extend(
            f"Consider adding an index on {table}.{column} if queries are slow"
            for column in columns
        )

    # Static hints, emitted in order whenever their trigger matches.
    conditional_hints = (
        ("ORDER BY" in text and "LIMIT" in text,
         "For ORDER BY with LIMIT, ensure there's an index on the ORDER BY columns"),
        (text.count("JOIN") > 2,
         "For complex JOINs, consider the join order - start with the most selective table"),
        (" IN (" in text and "SELECT" in text,
         "Consider using EXISTS instead of IN with subqueries for better performance"),
        ("::TEXT" in text or "CAST(" in text,
         "Avoid unnecessary type conversions by ensuring proper column data types"),
    )
    suggestions.extend(hint for triggered, hint in conditional_hints if triggered)

    return suggestions
def _extract_where_columns(self, parsed: sql.Statement) -> Dict[str, List[str]]:
    """Extract table.column references from WHERE clauses"""
    # Simplified, regex-based extraction; a full implementation would
    # walk the sqlparse token tree instead of pattern-matching raw text.
    text = str(parsed)

    where = re.search(r'WHERE\s+(.*?)(?:GROUP BY|ORDER BY|LIMIT|$)', text,
                      re.IGNORECASE | re.DOTALL)
    if not where:
        return {}

    # Collect qualified "table.column" references, grouped by table.
    columns: Dict[str, List[str]] = {}
    for table, column in re.findall(r'(\w+)\.(\w+)', where.group(1)):
        columns.setdefault(table, []).append(column)
    return columns
def _calculate_complexity(self, parsed: sql.Statement) -> int:
    """Calculate query complexity on a scale of 1-10.

    Heuristic: a weighted count of structural keywords plus the number
    of tables listed in the FROM clause, capped at 10.
    """
    complexity = 1
    query_text = str(parsed).upper()

    # Weighted structural features.
    complexity += query_text.count("JOIN") * 2
    # BUG FIX: the previous code counted the literal "SUBQUERY", which
    # never appears in SQL text, so subqueries were ignored. Count every
    # SELECT beyond the first (i.e. nested/derived SELECTs) instead.
    complexity += max(query_text.count("SELECT") - 1, 0) * 2
    complexity += query_text.count("UNION") * 2
    complexity += query_text.count("CASE") * 1
    complexity += query_text.count("GROUP BY") * 1
    complexity += query_text.count("ORDER BY") * 1
    complexity += query_text.count("HAVING") * 2
    complexity += query_text.count("WINDOW") * 3
    complexity += query_text.count("RECURSIVE") * 4

    # One point per comma-separated table in the FROM clause.
    from_match = re.search(r'FROM\s+([\w\s,\.]+?)(?:WHERE|GROUP|ORDER|LIMIT|$)', query_text)
    if from_match:
        complexity += len(from_match.group(1).split(','))

    return min(complexity, 10)  # cap at 10
async def get_table_columns(self, table_name: str, schema: str = "public") -> List[str]:
    """Get column names for a table (cached)"""
    cache_key = f"{schema}.{table_name}"

    cached = self._table_cache.get(cache_key)
    if cached is None:
        # First lookup for this table: fetch from the database and
        # cache the result (an empty list on failure, so we don't
        # hammer the database with repeated failing lookups).
        try:
            info = await self.db_manager.get_column_info(table_name, schema)
            cached = [entry['column_name'] for entry in info]
        except Exception as e:
            logger.warning(f"Could not fetch columns for {cache_key}: {e}")
            cached = []
        self._table_cache[cache_key] = cached

    return cached
async def get_table_indexes(self, table_name: str, schema: str = "public") -> List[Dict[str, Any]]:
    """Get index information for a table (cached)"""
    cache_key = f"{schema}.{table_name}"
    if cache_key not in self._index_cache:
        try:
            # One row per index on the table: name, column list (in
            # index-definition order, thanks to the WITH ORDINALITY
            # unnest of pg_index.indkey), plus unique / primary flags.
            query = """
SELECT
i.relname as index_name,
array_agg(a.attname ORDER BY c.ordinality) as column_names,
ix.indisunique as is_unique,
ix.indisprimary as is_primary
FROM pg_class t
JOIN pg_index ix ON t.oid = ix.indrelid
JOIN pg_class i ON i.oid = ix.indexrelid
JOIN pg_namespace n ON n.oid = t.relnamespace
JOIN unnest(ix.indkey) WITH ORDINALITY AS c(attnum, ordinality) ON true
JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = c.attnum
WHERE t.relname = $1 AND n.nspname = $2
GROUP BY i.relname, ix.indisunique, ix.indisprimary
ORDER BY i.relname
"""
            indexes = await self.db_manager.execute_query(query, [table_name, schema])
            # execute_query may return None; normalize to an empty list.
            self._index_cache[cache_key] = indexes or []
        except Exception as e:
            # Metadata lookup is best-effort: cache an empty list so a
            # failing table is not re-queried on every call.
            logger.warning(f"Could not fetch indexes for {cache_key}: {e}")
            self._index_cache[cache_key] = []
    return self._index_cache[cache_key]
def format_analysis_report(self, analysis: QueryAnalysis) -> str:
    """Format analysis results into a readable report"""
    lines: List[str] = [
        f"Query Analysis Report",
        "=" * 50,
        f"Valid: {'✅ Yes' if analysis.is_valid else '❌ No'}",
        f"Complexity: {analysis.estimated_complexity}/10",
        "",
    ]

    def render_section(header: str, findings: List[ValidationResult]) -> None:
        # Shared renderer used by the security / performance / other
        # sections: header line, then one line per finding with an
        # optional indented suggestion.
        lines.append(header)
        for finding in findings:
            lines.append(f"  {finding.level.value.upper()}: {finding.message}")
            if finding.suggestion:
                lines.append(f"    💡 {finding.suggestion}")

    if analysis.security_issues:
        render_section("🔒 Security Issues:", analysis.security_issues)
        lines.append("")

    if analysis.performance_warnings:
        render_section("⚡ Performance Warnings:", analysis.performance_warnings)
        lines.append("")

    if analysis.optimization_suggestions:
        lines.append("💡 Optimization Suggestions:")
        for i, suggestion in enumerate(analysis.optimization_suggestions, 1):
            lines.append(f"  {i}. {suggestion}")
        lines.append("")

    if analysis.validation_results:
        # Everything not already shown in the two sections above.
        categorized = analysis.security_issues + analysis.performance_warnings
        leftover = [r for r in analysis.validation_results if r not in categorized]
        if leftover:
            render_section("ℹ️ Other Issues:", leftover)

    return "\n".join(lines)