query_optimizer.py
"""Query validation and optimization for PostgreSQL MCP Server""" import re import logging from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass from enum import Enum import sqlparse from sqlparse import sql, tokens as T from .database import PostgreSQLManager logger = logging.getLogger(__name__) class ValidationLevel(Enum): """Validation severity levels""" INFO = "info" WARNING = "warning" ERROR = "error" CRITICAL = "critical" @dataclass class ValidationResult: """Result of query validation""" level: ValidationLevel category: str message: str suggestion: Optional[str] = None line_number: Optional[int] = None column: Optional[int] = None @dataclass class QueryAnalysis: """Complete query analysis result""" query: str is_valid: bool estimated_complexity: int # 1-10 scale validation_results: List[ValidationResult] optimization_suggestions: List[str] security_issues: List[ValidationResult] performance_warnings: List[ValidationResult] class QueryValidator: """SQL query validator and optimizer""" def __init__(self, db_manager: PostgreSQLManager): self.db_manager = db_manager self._table_cache: Dict[str, List[str]] = {} self._index_cache: Dict[str, List[Dict[str, Any]]] = {} async def validate_query(self, query: str, schema: str = "public") -> QueryAnalysis: """Validate and analyze a SQL query""" logger.debug(f"Validating query: {query[:100]}...") # Parse the query try: parsed = sqlparse.parse(query)[0] except Exception as e: return QueryAnalysis( query=query, is_valid=False, estimated_complexity=10, validation_results=[ ValidationResult( level=ValidationLevel.ERROR, category="syntax", message=f"SQL parsing failed: {e}" ) ], optimization_suggestions=[], security_issues=[], performance_warnings=[] ) # Initialize result lists validation_results = [] optimization_suggestions = [] security_issues = [] performance_warnings = [] # Perform various validations validation_results.extend(await self._validate_syntax(parsed)) validation_results.extend(await self._validate_security(parsed)) validation_results.extend(await self._validate_performance(parsed, schema)) # Categorize results for result in validation_results: if result.category in ["sql_injection", "dangerous_operation"]: security_issues.append(result) elif result.category in ["performance", "indexing", "optimization"]: performance_warnings.append(result) # Generate optimization suggestions optimization_suggestions.extend(await self._generate_optimizations(parsed, schema)) # Calculate complexity complexity = self._calculate_complexity(parsed) # Determine if query is valid (no critical errors) is_valid = not any(r.level == ValidationLevel.CRITICAL for r in validation_results) return QueryAnalysis( query=query, is_valid=is_valid, estimated_complexity=complexity, validation_results=validation_results, optimization_suggestions=optimization_suggestions, security_issues=security_issues, performance_warnings=performance_warnings ) async def _validate_syntax(self, parsed: sql.Statement) -> List[ValidationResult]: """Validate SQL syntax and structure""" results = [] # Check for basic SQL structure if not parsed.tokens: results.append(ValidationResult( level=ValidationLevel.ERROR, category="syntax", message="Empty or invalid SQL statement" )) return results # Check for SELECT statement (security requirement) first_token = None for token in parsed.flatten(): if token.ttype is T.Keyword and token.value.upper().strip(): first_token = token.value.upper().strip() break if first_token != "SELECT": results.append(ValidationResult( 
                level=ValidationLevel.CRITICAL,
                category="security",
                message="Only SELECT statements are allowed",
                suggestion="Use SELECT queries for data retrieval only"
            ))

        # Check for incomplete statements
        query_text = str(parsed).strip()
        if not query_text.endswith(';') and len(query_text) > 10:
            results.append(ValidationResult(
                level=ValidationLevel.INFO,
                category="syntax",
                message="Query should end with semicolon",
                suggestion="Add ';' at the end of your query"
            ))

        return results

    async def _validate_security(self, parsed: sql.Statement) -> List[ValidationResult]:
        """Validate query for security issues"""
        results = []
        query_text = str(parsed).upper()

        # Check for SQL injection patterns
        injection_patterns = [
            (r";\s*(DROP|DELETE|UPDATE|INSERT|ALTER|CREATE)",
             "Potential SQL injection with dangerous commands"),
            (r"UNION\s+SELECT", "UNION SELECT statements can be used for SQL injection"),
            (r"--", "SQL comments can be used to bypass security"),
            (r"/\*.*\*/", "SQL block comments can hide malicious code"),
            (r"'\s*OR\s*'", "Potential SQL injection with OR condition"),
            (r"'\s*=\s*'", "Potential tautology-based SQL injection"),
        ]

        for pattern, message in injection_patterns:
            if re.search(pattern, query_text, re.IGNORECASE):
                results.append(ValidationResult(
                    level=ValidationLevel.WARNING,
                    category="sql_injection",
                    message=message,
                    suggestion="Use parameterized queries instead"
                ))

        # Check for dangerous functions
        dangerous_functions = [
            "pg_read_file", "pg_write_file", "pg_execute",
            "copy", "pg_stat_file", "pg_ls_dir"
        ]

        for func in dangerous_functions:
            if func.upper() in query_text:
                results.append(ValidationResult(
                    level=ValidationLevel.CRITICAL,
                    category="dangerous_operation",
                    message=f"Dangerous function '{func}' detected",
                    suggestion="Remove dangerous system functions"
                ))

        return results

    async def _validate_performance(self, parsed: sql.Statement, schema: str) -> List[ValidationResult]:
        """Validate query for performance issues"""
        results = []
        query_text = str(parsed).upper()

        # Check for SELECT *
        if "SELECT *" in query_text:
            results.append(ValidationResult(
                level=ValidationLevel.WARNING,
                category="performance",
                message="SELECT * can be inefficient",
                suggestion="Specify only needed columns instead of using SELECT *"
            ))

        # Check for missing LIMIT
        if "LIMIT" not in query_text and "COUNT(" not in query_text:
            results.append(ValidationResult(
                level=ValidationLevel.INFO,
                category="performance",
                message="Consider adding LIMIT clause",
                suggestion="Add LIMIT clause to prevent large result sets"
            ))

        # Check for cartesian products (JOIN without ON)
        if "JOIN" in query_text and " ON " not in query_text and " USING" not in query_text:
            results.append(ValidationResult(
                level=ValidationLevel.WARNING,
                category="performance",
                message="Potential cartesian product detected",
                suggestion="Ensure JOIN clauses have proper ON conditions"
            ))

        # Check for functions in WHERE clause
        where_functions = re.findall(r'WHERE.*?(\w+)\s*\([^)]*\)\s*=', query_text)
        if where_functions:
            results.append(ValidationResult(
                level=ValidationLevel.WARNING,
                category="performance",
                message="Functions in WHERE clause can prevent index usage",
                suggestion="Consider restructuring query to avoid functions on indexed columns"
            ))

        # Check for LIKE with leading wildcard
        if re.search(r"LIKE\s+['\"]%", query_text):
            results.append(ValidationResult(
                level=ValidationLevel.WARNING,
                category="indexing",
                message="LIKE with leading wildcard prevents index usage",
                suggestion="Avoid leading wildcards in LIKE patterns or consider full-text search"
            ))

        return results
    async def _generate_optimizations(self, parsed: sql.Statement, schema: str) -> List[str]:
        """Generate optimization suggestions"""
        suggestions = []
        query_text = str(parsed).upper()

        # Suggest EXPLAIN ANALYZE
        suggestions.append("Run EXPLAIN ANALYZE to see the actual execution plan")

        # Suggest specific indexes based on WHERE clauses
        where_columns = self._extract_where_columns(parsed)
        if where_columns:
            for table, columns in where_columns.items():
                for column in columns:
                    suggestions.append(
                        f"Consider adding an index on {table}.{column} if queries are slow"
                    )

        # Suggest query restructuring
        if "ORDER BY" in query_text and "LIMIT" in query_text:
            suggestions.append(
                "For ORDER BY with LIMIT, ensure there's an index on the ORDER BY columns"
            )

        # Suggest JOIN order optimization
        if query_text.count("JOIN") > 2:
            suggestions.append(
                "For complex JOINs, consider the join order - start with the most selective table"
            )

        # Suggest EXISTS over IN for subqueries
        if " IN (" in query_text and "SELECT" in query_text:
            suggestions.append(
                "Consider using EXISTS instead of IN with subqueries for better performance"
            )

        # Suggest proper data types
        if "::TEXT" in query_text or "CAST(" in query_text:
            suggestions.append(
                "Avoid unnecessary type conversions by ensuring proper column data types"
            )

        return suggestions

    def _extract_where_columns(self, parsed: sql.Statement) -> Dict[str, List[str]]:
        """Extract table.column references from WHERE clauses"""
        # This is a simplified extraction - in practice, you'd want more sophisticated parsing
        where_columns = {}
        query_text = str(parsed)

        # Look for WHERE clauses and extract column references
        where_match = re.search(
            r'WHERE\s+(.*?)(?:GROUP BY|ORDER BY|LIMIT|$)',
            query_text,
            re.IGNORECASE | re.DOTALL
        )
        if where_match:
            where_clause = where_match.group(1)
            # Extract table.column or just column references
            column_refs = re.findall(r'(\w+)\.(\w+)', where_clause)
            for table, column in column_refs:
                if table not in where_columns:
                    where_columns[table] = []
                where_columns[table].append(column)

        return where_columns

    def _calculate_complexity(self, parsed: sql.Statement) -> int:
        """Calculate query complexity on a scale of 1-10"""
        complexity = 1
        query_text = str(parsed).upper()

        # Add complexity for various factors
        complexity += query_text.count("JOIN") * 2
        # Each SELECT beyond the first indicates a subquery
        # (the literal keyword "SUBQUERY" never appears in SQL text)
        complexity += max(query_text.count("SELECT") - 1, 0) * 2
        complexity += query_text.count("UNION") * 2
        complexity += query_text.count("CASE") * 1
        complexity += query_text.count("GROUP BY") * 1
        complexity += query_text.count("ORDER BY") * 1
        complexity += query_text.count("HAVING") * 2
        complexity += query_text.count("WINDOW") * 3
        complexity += query_text.count("RECURSIVE") * 4

        # Count number of tables
        from_match = re.search(r'FROM\s+([\w\s,\.]+?)(?:WHERE|GROUP|ORDER|LIMIT|$)', query_text)
        if from_match:
            tables = from_match.group(1).split(',')
            complexity += len(tables)

        return min(complexity, 10)  # Cap at 10

    async def get_table_columns(self, table_name: str, schema: str = "public") -> List[str]:
        """Get column names for a table (cached)"""
        cache_key = f"{schema}.{table_name}"
        if cache_key not in self._table_cache:
            try:
                columns = await self.db_manager.get_column_info(table_name, schema)
                self._table_cache[cache_key] = [col['column_name'] for col in columns]
            except Exception as e:
                logger.warning(f"Could not fetch columns for {cache_key}: {e}")
                self._table_cache[cache_key] = []

        return self._table_cache[cache_key]

    async def get_table_indexes(self, table_name: str, schema: str = "public") -> List[Dict[str, Any]]:
        """Get index information for a table (cached)"""
        cache_key = f"{schema}.{table_name}"
f"{schema}.{table_name}" if cache_key not in self._index_cache: try: query = """ SELECT i.relname as index_name, array_agg(a.attname ORDER BY c.ordinality) as column_names, ix.indisunique as is_unique, ix.indisprimary as is_primary FROM pg_class t JOIN pg_index ix ON t.oid = ix.indrelid JOIN pg_class i ON i.oid = ix.indexrelid JOIN pg_namespace n ON n.oid = t.relnamespace JOIN unnest(ix.indkey) WITH ORDINALITY AS c(attnum, ordinality) ON true JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = c.attnum WHERE t.relname = $1 AND n.nspname = $2 GROUP BY i.relname, ix.indisunique, ix.indisprimary ORDER BY i.relname """ indexes = await self.db_manager.execute_query(query, [table_name, schema]) self._index_cache[cache_key] = indexes or [] except Exception as e: logger.warning(f"Could not fetch indexes for {cache_key}: {e}") self._index_cache[cache_key] = [] return self._index_cache[cache_key] def format_analysis_report(self, analysis: QueryAnalysis) -> str: """Format analysis results into a readable report""" report = [] report.append(f"Query Analysis Report") report.append("=" * 50) report.append(f"Valid: {'✅ Yes' if analysis.is_valid else '❌ No'}") report.append(f"Complexity: {analysis.estimated_complexity}/10") report.append("") if analysis.security_issues: report.append("🔒 Security Issues:") for issue in analysis.security_issues: report.append(f" {issue.level.value.upper()}: {issue.message}") if issue.suggestion: report.append(f" 💡 {issue.suggestion}") report.append("") if analysis.performance_warnings: report.append("⚡ Performance Warnings:") for warning in analysis.performance_warnings: report.append(f" {warning.level.value.upper()}: {warning.message}") if warning.suggestion: report.append(f" 💡 {warning.suggestion}") report.append("") if analysis.optimization_suggestions: report.append("💡 Optimization Suggestions:") for i, suggestion in enumerate(analysis.optimization_suggestions, 1): report.append(f" {i}. {suggestion}") report.append("") if analysis.validation_results: other_issues = [r for r in analysis.validation_results if r not in analysis.security_issues + analysis.performance_warnings] if other_issues: report.append("ℹ️ Other Issues:") for issue in other_issues: report.append(f" {issue.level.value.upper()}: {issue.message}") if issue.suggestion: report.append(f" 💡 {issue.suggestion}") return "\n".join(report)
