"""Tests for PostgreSQL MCP Server"""
import pytest
import asyncio
import os
from unittest.mock import patch, AsyncMock, MagicMock
# Import the modules to test
from src.config import DatabaseConfig, ServerConfig, load_config
from src.database import PostgreSQLManager, DatabaseError
from src.tools import DatabaseTools
from src.mcp_server import PostgreSQLMCPServer
from src.query_optimizer import QueryValidator, ValidationLevel, QueryAnalysis
class TestDatabaseConfig:
    """Unit tests for building and validating a DatabaseConfig."""

    def test_connection_string(self):
        """The connection string assembled from the config fields is the expected URL."""
        settings = {
            "host": "localhost",
            "port": 5432,
            "database": "testdb",
            "username": "testuser",
            "password": "testpass",
            "ssl_mode": "require",
        }
        cfg = DatabaseConfig(**settings)
        assert cfg.connection_string == (
            "postgresql://testuser:testpass@localhost:5432/testdb?sslmode=require"
        )

    def test_port_validation(self):
        """A port above 65535 is rejected with a ValueError."""
        out_of_range_port = 70000
        with pytest.raises(ValueError, match="Port must be between 1 and 65535"):
            DatabaseConfig(
                host="localhost",
                port=out_of_range_port,
                database="testdb",
                username="testuser",
                password="testpass",
            )
class TestServerConfig:
    """Unit tests for ServerConfig field validation."""

    def test_log_level_validation(self):
        """'debug' is stored as 'DEBUG'; an unknown level raises ValueError."""
        stored_level = ServerConfig(log_level="debug").log_level
        assert stored_level == "DEBUG"

        with pytest.raises(ValueError, match="Log level must be one of"):
            ServerConfig(log_level="INVALID")
class TestDatabaseManager:
    """Tests for PostgreSQLManager pool initialisation and failure handling."""

    @pytest.fixture
    def db_config(self):
        """Local test-database settings; asyncpg is patched, so no connection is made."""
        return DatabaseConfig(
            host="localhost",
            port=5432,
            database="testdb",
            username="testuser",
            password="testpass",
        )

    @pytest.fixture
    def db_manager(self, db_config):
        """A PostgreSQLManager built from ``db_config``."""
        return PostgreSQLManager(db_config)

    @pytest.mark.asyncio
    async def test_initialization(self, db_manager):
        """initialize() stores the pool produced by asyncpg.create_pool."""
        # asyncpg.create_pool() is a plain function whose result is awaited
        # (``pool = await asyncpg.create_pool(...)`` per the asyncpg docs).
        # A default patch is a MagicMock, which would hand back the bare
        # AsyncMock instance below, and that object cannot be awaited.
        # Making the factory itself an AsyncMock means awaiting its call
        # yields ``mock_pool``.
        # NOTE(review): assumes PostgreSQLManager.initialize awaits the call.
        with patch('asyncpg.create_pool', new_callable=AsyncMock) as mock_create_pool:
            mock_pool = AsyncMock()
            mock_create_pool.return_value = mock_pool
            await db_manager.initialize()
            # Identity, not equality: the manager must keep this exact pool.
            assert db_manager.pool is mock_pool
            mock_create_pool.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_connection_error(self, db_manager):
        """A failure while creating the pool surfaces as DatabaseError."""
        with patch('asyncpg.create_pool', side_effect=Exception("Connection failed")):
            with pytest.raises(DatabaseError, match="Database connection failed"):
                await db_manager.initialize()
class TestDatabaseTools:
    """Tests for the individual DatabaseTools request handlers."""

    @pytest.fixture
    def db_manager(self):
        """MagicMock manager whose methods used by the tools are AsyncMock instances."""
        fake = MagicMock()
        for method_name in (
            "execute_query",
            "get_table_info",
            "get_column_info",
            "get_schemas",
            "test_connection",
        ):
            setattr(fake, method_name, AsyncMock())
        return fake

    @pytest.fixture
    def server_config(self):
        """Default server configuration."""
        return ServerConfig()

    @pytest.fixture
    def tools(self, db_manager, server_config):
        """DatabaseTools wired to the mocked manager."""
        return DatabaseTools(db_manager, server_config)

    def test_get_tools(self, tools):
        """Every core tool name is present in get_tools()."""
        registered = {tool.name for tool in tools.get_tools()}
        required = {"query", "list_tables", "describe_table", "list_schemas", "test_connection"}
        assert required <= registered

    @pytest.mark.asyncio
    async def test_query_tool_security(self, tools):
        """A DROP statement is refused, while a SELECT returns results."""
        refused = await tools._handle_query({"sql": "DROP TABLE users;"})
        assert len(refused) == 1
        assert "Only SELECT queries are allowed" in refused[0].text

        tools.db_manager.execute_query.return_value = [{"id": 1, "name": "test"}]
        accepted = await tools._handle_query({"sql": "SELECT * FROM users"})
        assert len(accepted) == 1
        assert "Query Results" in accepted[0].text

    @pytest.mark.asyncio
    async def test_list_tables_tool(self, tools):
        """Every table name returned by the manager appears in the listing."""
        table_rows = [
            {"table_name": "users", "table_type": "BASE TABLE"},
            {"table_name": "posts", "table_type": "BASE TABLE"},
        ]
        tools.db_manager.get_table_info.return_value = table_rows
        listing = await tools._handle_list_tables({"schema": "public"})
        assert len(listing) == 1
        listing_text = listing[0].text
        for row in table_rows:
            assert row["table_name"] in listing_text

    @pytest.mark.asyncio
    async def test_describe_table_tool(self, tools):
        """Each column is rendered with its type and nullability."""
        id_column = {
            "column_name": "id",
            "data_type": "integer",
            "is_nullable": "NO",
            "column_default": "nextval('users_id_seq'::regclass)",
        }
        name_column = {
            "column_name": "name",
            "data_type": "character varying",
            "is_nullable": "YES",
            "column_default": None,
        }
        tools.db_manager.get_column_info.return_value = [id_column, name_column]
        described = await tools._handle_describe_table({"table_name": "users", "schema": "public"})
        assert len(described) == 1
        assert "id: integer NOT NULL" in described[0].text
        assert "name: character varying NULL" in described[0].text
class TestMCPServer:
    """Tests for PostgreSQLMCPServer start-up wiring."""

    @pytest.fixture
    def server(self):
        """A freshly constructed, not yet initialised server."""
        return PostgreSQLMCPServer()

    @pytest.mark.asyncio
    async def test_server_initialization(self, server):
        """initialize() loads config, initialises the DB manager, and builds tools and server."""
        # Patch the names where src.mcp_server looks them up, not where they
        # are defined. This test module imported src.mcp_server at collection
        # time. If that module binds these names with ``from ... import ...``
        # (the style this test module itself uses), patching src.config,
        # src.database, src.tools or mcp.server would leave the server calling
        # the real objects.
        # NOTE(review): assumes src.mcp_server imports these four names
        # directly; confirm against src/mcp_server.py.
        with patch('src.mcp_server.load_config') as mock_load_config, \
             patch('src.mcp_server.PostgreSQLManager') as mock_db_manager, \
             patch('src.mcp_server.DatabaseTools') as mock_tools, \
             patch('src.mcp_server.Server') as mock_mcp_server:
            # Setup mocks
            mock_load_config.return_value = (MagicMock(), MagicMock())
            mock_db_instance = AsyncMock()
            mock_db_manager.return_value = mock_db_instance

            await server.initialize()

            # Verify initialization calls
            mock_load_config.assert_called_once()
            mock_db_instance.initialize.assert_called_once()
            mock_tools.assert_called_once()
            mock_mcp_server.assert_called_once()
# Test running the main CLI
def test_main_help():
    """Running ``main.py --help`` exits with status 0."""
    with patch('sys.argv', ['main.py', '--help']):
        # Import outside pytest.raises so that a SystemExit raised while
        # importing cannot satisfy the expectation; only main() is under test.
        from main import main

        with pytest.raises(SystemExit) as exc_info:
            main()
    assert exc_info.value.code == 0
def test_main_version():
    """Running ``main.py --version`` exits with status 0."""
    with patch('sys.argv', ['main.py', '--version']):
        # Import outside pytest.raises so that a SystemExit raised while
        # importing cannot satisfy the expectation; only main() is under test.
        from main import main

        with pytest.raises(SystemExit) as exc_info:
            main()
    assert exc_info.value.code == 0
class TestQueryValidator:
    """Tests for QueryValidator analysis, warnings, and report formatting."""

    @pytest.fixture
    def db_manager(self):
        """MagicMock manager with awaitable ``execute_query`` / ``get_column_info``."""
        manager = MagicMock()
        manager.execute_query = AsyncMock()
        manager.get_column_info = AsyncMock()
        return manager

    @pytest.fixture
    def validator(self, db_manager):
        """A QueryValidator backed by the mocked manager."""
        return QueryValidator(db_manager)

    @pytest.mark.asyncio
    async def test_valid_select_query(self, validator):
        """A bounded, column-specific SELECT is valid and has low complexity."""
        sql = "SELECT id, name FROM users WHERE active = true LIMIT 10;"
        analysis = await validator.validate_query(sql)
        assert analysis.is_valid is True
        assert analysis.query == sql
        assert analysis.estimated_complexity <= 5
        # The former ``len(...) >= 0`` check could never fail. Mirror
        # test_dangerous_query_blocked, where a CRITICAL finding accompanies
        # an invalid query: a valid query should carry none.
        assert not any(
            result.level == ValidationLevel.CRITICAL
            for result in analysis.validation_results
        )

    @pytest.mark.asyncio
    async def test_dangerous_query_blocked(self, validator):
        """A DROP statement is invalid and produces a CRITICAL finding mentioning SELECT."""
        sql = "DROP TABLE users;"
        analysis = await validator.validate_query(sql)
        assert analysis.is_valid is False
        assert any(result.level == ValidationLevel.CRITICAL for result in analysis.validation_results)
        assert any("SELECT" in result.message for result in analysis.validation_results)

    @pytest.mark.asyncio
    async def test_sql_injection_detection(self, validator):
        """A stacked statement with a trailing comment is flagged as SQL injection."""
        sql = "SELECT * FROM users WHERE id = 1; DROP TABLE users; --"
        analysis = await validator.validate_query(sql)
        security_issues = [r for r in analysis.validation_results if r.category == "sql_injection"]
        assert len(security_issues) > 0
        assert any("SQL injection" in issue.message for issue in security_issues)

    @pytest.mark.asyncio
    async def test_performance_warnings(self, validator):
        """SELECT * and JOINs without conditions both yield performance warnings."""
        sql = "SELECT * FROM users u JOIN posts p JOIN comments c"
        analysis = await validator.validate_query(sql)
        performance_warnings = [r for r in analysis.validation_results if r.category == "performance"]
        # Should warn about SELECT *
        assert any("SELECT *" in warning.message for warning in performance_warnings)
        # Should warn about JOINs that have no ON condition
        assert any("cartesian product" in warning.message for warning in performance_warnings)

    @pytest.mark.asyncio
    async def test_optimization_suggestions(self, validator):
        """Suggestions are produced and include a pointer to EXPLAIN ANALYZE."""
        sql = "SELECT name FROM users WHERE email = 'test@example.com' ORDER BY created_at LIMIT 5"
        analysis = await validator.validate_query(sql)
        assert len(analysis.optimization_suggestions) > 0
        assert any("EXPLAIN ANALYZE" in suggestion for suggestion in analysis.optimization_suggestions)

    @pytest.mark.asyncio
    async def test_complexity_calculation(self, validator):
        """Joins, grouping, and HAVING raise complexity above a trivial SELECT."""
        simple_sql = "SELECT id FROM users LIMIT 10"
        complex_sql = """
        SELECT u.name, COUNT(p.id) as post_count
        FROM users u
        LEFT JOIN posts p ON u.id = p.user_id
        LEFT JOIN comments c ON p.id = c.post_id
        WHERE u.created_at > '2023-01-01'
        GROUP BY u.id, u.name
        HAVING COUNT(p.id) > 5
        ORDER BY post_count DESC
        LIMIT 20
        """
        simple_analysis = await validator.validate_query(simple_sql)
        complex_analysis = await validator.validate_query(complex_sql)
        assert simple_analysis.estimated_complexity < complex_analysis.estimated_complexity
        assert complex_analysis.estimated_complexity > 5

    def test_format_analysis_report(self, validator):
        """The formatted report contains the title, validity, complexity, and sections."""
        from src.query_optimizer import ValidationResult

        # The same warning appears in both the general results and the
        # performance-specific list, so build it once.
        select_star_warning = ValidationResult(
            level=ValidationLevel.WARNING,
            category="performance",
            message="SELECT * can be inefficient",
            suggestion="Specify only needed columns",
        )
        analysis = QueryAnalysis(
            query="SELECT * FROM test",
            is_valid=True,
            estimated_complexity=3,
            validation_results=[select_star_warning],
            optimization_suggestions=["Add an index on frequently queried columns"],
            security_issues=[],
            performance_warnings=[select_star_warning],
        )
        report = validator.format_analysis_report(analysis)
        assert "Query Analysis Report" in report
        assert "✅ Yes" in report  # Valid query
        assert "3/10" in report  # Complexity
        assert "Performance Warnings" in report
        assert "Optimization Suggestions" in report
class TestDatabaseToolsWithValidation:
    """DatabaseTools tests that exercise the query-validator integration."""

    @pytest.fixture
    def db_manager(self):
        """MagicMock manager whose ``execute_query`` is an AsyncMock."""
        fake_manager = MagicMock()
        fake_manager.execute_query = AsyncMock()
        return fake_manager

    @pytest.fixture
    def server_config(self):
        """Default server configuration."""
        return ServerConfig()

    @pytest.fixture
    def tools(self, db_manager, server_config):
        """DatabaseTools wired to the mocked manager."""
        return DatabaseTools(db_manager, server_config)

    @staticmethod
    def _canned_analysis(complexity, suggestions):
        """Build a valid, finding-free QueryAnalysis for ``SELECT * FROM test``."""
        return QueryAnalysis(
            query="SELECT * FROM test",
            is_valid=True,
            estimated_complexity=complexity,
            validation_results=[],
            optimization_suggestions=suggestions,
            security_issues=[],
            performance_warnings=[],
        )

    def test_validate_query_tool_exists(self, tools):
        """get_tools() advertises the validate_query tool."""
        assert "validate_query" in {tool.name for tool in tools.get_tools()}

    @pytest.mark.asyncio
    async def test_validate_query_tool(self, tools):
        """_handle_validate_query passes sql and schema to the validator and renders its report."""
        request = {"sql": "SELECT * FROM test", "schema": "public"}
        with patch.object(tools.query_validator, 'validate_query') as fake_validate:
            fake_validate.return_value = self._canned_analysis(2, ["Use specific columns"])
            response = await tools._handle_validate_query(request)
            assert len(response) == 1
            assert "Query Analysis Report" in response[0].text
            fake_validate.assert_called_once_with("SELECT * FROM test", "public")

    @pytest.mark.asyncio
    async def test_query_tool_with_validation_integration(self, tools):
        """_handle_query adds complexity and optimisation feedback to query results."""
        with patch.object(tools.query_validator, 'validate_query') as fake_validate:
            # Complexity 8 is expected to be reported as a high-complexity query.
            fake_validate.return_value = self._canned_analysis(
                8, ["Add LIMIT clause", "Use specific columns"]
            )
            tools.db_manager.execute_query.return_value = [{"id": 1, "name": "test"}]
            response = await tools._handle_query({
                "sql": "SELECT * FROM test",
                "schema": "public",
            })
            assert len(response) == 1
            response_text = response[0].text
            # Results, the complexity notice, and the optimisation tips must all appear.
            for expected_fragment in (
                "Query Results",
                "High complexity query",
                "Quick optimization tips",
            ):
                assert expected_fragment in response_text