Skip to main content
Glama
test_main.py (15.4 kB)
"""Test suite for the PostgreSQL MCP server package.

Covers configuration objects, the database manager, the MCP tool handlers,
the server bootstrap, the command-line entry point and the query validator.
"""

import pytest
import asyncio
import os
from unittest.mock import patch, AsyncMock, MagicMock

# Import the modules to test
from src.config import DatabaseConfig, ServerConfig, load_config
from src.database import PostgreSQLManager, DatabaseError
from src.tools import DatabaseTools
from src.mcp_server import PostgreSQLMCPServer
from src.query_optimizer import QueryValidator, ValidationLevel, QueryAnalysis


class TestDatabaseConfig:
    """Checks on ``DatabaseConfig`` construction and derived values."""

    def test_connection_string(self):
        """The connection string is assembled from the individual fields."""
        cfg = DatabaseConfig(
            host="localhost",
            port=5432,
            database="testdb",
            username="testuser",
            password="testpass",
            ssl_mode="require",
        )

        assert (
            cfg.connection_string
            == "postgresql://testuser:testpass@localhost:5432/testdb?sslmode=require"
        )

    def test_port_validation(self):
        """A port above 65535 is rejected with a ``ValueError``."""
        bad_port = 70000  # Invalid port
        with pytest.raises(ValueError, match="Port must be between 1 and 65535"):
            DatabaseConfig(
                host="localhost",
                port=bad_port,
                database="testdb",
                username="testuser",
                password="testpass",
            )


class TestServerConfig:
    """Checks on ``ServerConfig`` validation."""

    def test_log_level_validation(self):
        """Lower-case levels are upper-cased; unknown levels raise."""
        lowered = ServerConfig(log_level="debug")
        assert lowered.log_level == "DEBUG"

        with pytest.raises(ValueError, match="Log level must be one of"):
            ServerConfig(log_level="INVALID")


class TestDatabaseManager:
    """Checks on ``PostgreSQLManager`` start-up behaviour."""

    @pytest.fixture
    def db_config(self):
        """Return a ``DatabaseConfig`` pointing at a local test database."""
        return DatabaseConfig(
            host="localhost",
            port=5432,
            database="testdb",
            username="testuser",
            password="testpass",
        )

    @pytest.fixture
    def db_manager(self, db_config):
        """Return a manager built from ``db_config``; no connection is made here."""
        return PostgreSQLManager(db_config)

    @pytest.mark.asyncio
    async def test_initialization(self, db_manager):
        """``initialize`` stores the pool obtained from ``asyncpg.create_pool``."""
        fake_pool = AsyncMock()
        with patch('asyncpg.create_pool', return_value=fake_pool) as create_pool:
            await db_manager.initialize()

            assert db_manager.pool == fake_pool
            create_pool.assert_called_once()

    @pytest.mark.asyncio
    async def test_connection_error(self, db_manager):
        """A failure inside ``create_pool`` surfaces as ``DatabaseError``."""
        boom = Exception("Connection failed")
        with patch('asyncpg.create_pool', side_effect=boom):
            with pytest.raises(DatabaseError, match="Database connection failed"):
                await db_manager.initialize()


class TestDatabaseTools:
    """Checks on the ``DatabaseTools`` handlers, using a mocked manager."""

    @pytest.fixture
    def db_manager(self):
        """Return a ``MagicMock`` whose manager coroutines are ``AsyncMock``s."""
        manager = MagicMock()
        for coroutine_name in (
            "execute_query",
            "get_table_info",
            "get_column_info",
            "get_schemas",
            "test_connection",
        ):
            setattr(manager, coroutine_name, AsyncMock())
        return manager

    @pytest.fixture
    def server_config(self):
        """Return a ``ServerConfig`` built with no arguments."""
        return ServerConfig()

    @pytest.fixture
    def tools(self, db_manager, server_config):
        """Return the ``DatabaseTools`` under test."""
        return DatabaseTools(db_manager, server_config)

    def test_get_tools(self, tools):
        """Every expected tool name is present in ``get_tools()``."""
        available = [tool.name for tool in tools.get_tools()]
        required = [
            "query",
            "list_tables",
            "describe_table",
            "list_schemas",
            "test_connection",
        ]

        missing = [name for name in required if name not in available]
        assert not missing

    @pytest.mark.asyncio
    async def test_query_tool_security(self, tools):
        """Non-SELECT statements are refused; a SELECT yields results."""
        # Test dangerous query
        rejected = await tools._handle_query({"sql": "DROP TABLE users;"})
        assert len(rejected) == 1
        assert "Only SELECT queries are allowed" in rejected[0].text

        # Test valid query
        tools.db_manager.execute_query.return_value = [{"id": 1, "name": "test"}]
        accepted = await tools._handle_query({"sql": "SELECT * FROM users"})
        assert len(accepted) == 1
        assert "Query Results" in accepted[0].text

    @pytest.mark.asyncio
    async def test_list_tables_tool(self, tools):
        """Table names returned by the manager appear in the tool output."""
        tools.db_manager.get_table_info.return_value = [
            {"table_name": "users", "table_type": "BASE TABLE"},
            {"table_name": "posts", "table_type": "BASE TABLE"},
        ]

        reply = await tools._handle_list_tables({"schema": "public"})

        assert len(reply) == 1
        body = reply[0].text
        assert "users" in body
        assert "posts" in body

    @pytest.mark.asyncio
    async def test_describe_table_tool(self, tools):
        """Column name, type and nullability are rendered per column."""
        id_column = {
            "column_name": "id",
            "data_type": "integer",
            "is_nullable": "NO",
            "column_default": "nextval('users_id_seq'::regclass)",
        }
        name_column = {
            "column_name": "name",
            "data_type": "character varying",
            "is_nullable": "YES",
            "column_default": None,
        }
        tools.db_manager.get_column_info.return_value = [id_column, name_column]

        reply = await tools._handle_describe_table(
            {"table_name": "users", "schema": "public"}
        )

        assert len(reply) == 1
        body = reply[0].text
        assert "id: integer NOT NULL" in body
        assert "name: character varying NULL" in body


class TestMCPServer:
    """Checks on ``PostgreSQLMCPServer`` bootstrap."""

    @pytest.fixture
    def server(self):
        """Return a fresh, uninitialised server instance."""
        return PostgreSQLMCPServer()

    @pytest.mark.asyncio
    async def test_server_initialization(self, server):
        """``initialize`` loads config and builds the manager, tools and server."""
        # NOTE(review): these patches replace the names in src.config,
        # src.database, src.tools and mcp.server. Whether
        # PostgreSQLMCPServer.initialize observes them depends on how
        # src.mcp_server imports those names -- confirm against that module.
        with patch('src.config.load_config') as mock_load_config:
            with patch('src.database.PostgreSQLManager') as mock_db_manager:
                with patch('src.tools.DatabaseTools') as mock_tools:
                    with patch('mcp.server.Server') as mock_mcp_server:
                        # Setup mocks
                        mock_load_config.return_value = (MagicMock(), MagicMock())
                        mock_db_instance = AsyncMock()
                        mock_db_manager.return_value = mock_db_instance

                        await server.initialize()

                        # Verify initialization calls
                        mock_load_config.assert_called_once()
                        mock_db_instance.initialize.assert_called_once()
                        mock_tools.assert_called_once()
                        mock_mcp_server.assert_called_once()


# Test running the main CLI
def test_main_help():
    """``--help`` makes the CLI exit with status 0."""
    with patch('sys.argv', ['main.py', '--help']), \
            pytest.raises(SystemExit) as exc_info:
        from main import main
        main()

    assert exc_info.value.code == 0


def test_main_version():
    """``--version`` makes the CLI exit with status 0."""
    with patch('sys.argv', ['main.py', '--version']), \
            pytest.raises(SystemExit) as exc_info:
        from main import main
        main()

    assert exc_info.value.code == 0


class TestQueryValidator:
    """Checks on ``QueryValidator`` analysis and report formatting."""

    @pytest.fixture
    def db_manager(self):
        """Return a mocked manager exposing the coroutines the validator may call."""
        manager = MagicMock()
        for coroutine_name in ("execute_query", "get_column_info"):
            setattr(manager, coroutine_name, AsyncMock())
        return manager

    @pytest.fixture
    def validator(self, db_manager):
        """Return the ``QueryValidator`` under test."""
        return QueryValidator(db_manager)

    @pytest.mark.asyncio
    async def test_valid_select_query(self, validator):
        """A simple, bounded SELECT is valid and of low complexity."""
        statement = "SELECT id, name FROM users WHERE active = true LIMIT 10;"

        outcome = await validator.validate_query(statement)

        assert outcome.is_valid is True
        assert outcome.query == statement
        assert outcome.estimated_complexity <= 5
        # NOTE(review): this only proves validation_results supports len();
        # the comparison itself can never fail.
        assert len(outcome.validation_results) >= 0

    @pytest.mark.asyncio
    async def test_dangerous_query_blocked(self, validator):
        """A DROP statement is invalid and reported as CRITICAL."""
        outcome = await validator.validate_query("DROP TABLE users;")
        findings = outcome.validation_results

        assert outcome.is_valid is False
        assert any(item.level == ValidationLevel.CRITICAL for item in findings)
        assert any("SELECT" in item.message for item in findings)

    @pytest.mark.asyncio
    async def test_sql_injection_detection(self, validator):
        """A stacked-statement payload produces ``sql_injection`` findings."""
        payload = "SELECT * FROM users WHERE id = 1; DROP TABLE users; --"

        outcome = await validator.validate_query(payload)
        injection_findings = [
            item
            for item in outcome.validation_results
            if item.category == "sql_injection"
        ]

        assert len(injection_findings) > 0
        assert any("SQL injection" in item.message for item in injection_findings)

    @pytest.mark.asyncio
    async def test_performance_warnings(self, validator):
        """``SELECT *`` and condition-less JOINs raise performance findings."""
        statement = "SELECT * FROM users u JOIN posts p JOIN comments c"

        outcome = await validator.validate_query(statement)
        messages = [
            item.message
            for item in outcome.validation_results
            if item.category == "performance"
        ]

        # Should warn about SELECT *
        assert any("SELECT *" in text for text in messages)

        # Should warn about missing JOIN conditions
        assert any("cartesian product" in text for text in messages)

    @pytest.mark.asyncio
    async def test_optimization_suggestions(self, validator):
        """At least one suggestion is produced, including EXPLAIN ANALYZE."""
        statement = (
            "SELECT name FROM users WHERE email = 'test@example.com' "
            "ORDER BY created_at LIMIT 5"
        )

        outcome = await validator.validate_query(statement)
        hints = outcome.optimization_suggestions

        assert len(hints) > 0
        assert any("EXPLAIN ANALYZE" in hint for hint in hints)

    @pytest.mark.asyncio
    async def test_complexity_calculation(self, validator):
        """A multi-join aggregate query scores higher than a trivial SELECT."""
        simple_sql = "SELECT id FROM users LIMIT 10"
        complex_sql = (
            " SELECT u.name, COUNT(p.id) as post_count "
            "FROM users u "
            "LEFT JOIN posts p ON u.id = p.user_id "
            "LEFT JOIN comments c ON p.id = c.post_id "
            "WHERE u.created_at > '2023-01-01' "
            "GROUP BY u.id, u.name "
            "HAVING COUNT(p.id) > 5 "
            "ORDER BY post_count DESC "
            "LIMIT 20 "
        )

        light = await validator.validate_query(simple_sql)
        heavy = await validator.validate_query(complex_sql)

        assert light.estimated_complexity < heavy.estimated_complexity
        assert heavy.estimated_complexity > 5

    def test_format_analysis_report(self, validator):
        """The formatted report contains each expected section and value."""
        from src.query_optimizer import ValidationResult

        def select_star_warning():
            # Build a fresh instance per use so the two lists below do not
            # share one object.
            return ValidationResult(
                level=ValidationLevel.WARNING,
                category="performance",
                message="SELECT * can be inefficient",
                suggestion="Specify only needed columns",
            )

        analysis = QueryAnalysis(
            query="SELECT * FROM test",
            is_valid=True,
            estimated_complexity=3,
            validation_results=[select_star_warning()],
            optimization_suggestions=["Add an index on frequently queried columns"],
            security_issues=[],
            performance_warnings=[select_star_warning()],
        )

        report = validator.format_analysis_report(analysis)

        assert "Query Analysis Report" in report
        assert "✅ Yes" in report  # Valid query
        assert "3/10" in report  # Complexity
        assert "Performance Warnings" in report
        assert "Optimization Suggestions" in report


class TestDatabaseToolsWithValidation:
    """Checks on how ``DatabaseTools`` uses its query validator."""

    @pytest.fixture
    def db_manager(self):
        """Return a mocked manager with an async ``execute_query``."""
        manager = MagicMock()
        manager.execute_query = AsyncMock()
        return manager

    @pytest.fixture
    def server_config(self):
        """Return a ``ServerConfig`` built with no arguments."""
        return ServerConfig()

    @pytest.fixture
    def tools(self, db_manager, server_config):
        """Return the ``DatabaseTools`` under test."""
        return DatabaseTools(db_manager, server_config)

    def test_validate_query_tool_exists(self, tools):
        """``validate_query`` is listed among the exposed tools."""
        assert "validate_query" in [tool.name for tool in tools.get_tools()]

    @pytest.mark.asyncio
    async def test_validate_query_tool(self, tools):
        """The handler forwards SQL and schema, and returns the report."""
        canned = QueryAnalysis(
            query="SELECT * FROM test",
            is_valid=True,
            estimated_complexity=2,
            validation_results=[],
            optimization_suggestions=["Use specific columns"],
            security_issues=[],
            performance_warnings=[],
        )

        with patch.object(
            tools.query_validator, 'validate_query', return_value=canned
        ) as mock_validate:
            reply = await tools._handle_validate_query(
                {"sql": "SELECT * FROM test", "schema": "public"}
            )

            assert len(reply) == 1
            assert "Query Analysis Report" in reply[0].text
            mock_validate.assert_called_once_with("SELECT * FROM test", "public")

    @pytest.mark.asyncio
    async def test_query_tool_with_validation_integration(self, tools):
        """Query output carries results plus validation feedback and tips."""
        # Mock the validator to return some warnings
        canned = QueryAnalysis(
            query="SELECT * FROM test",
            is_valid=True,
            estimated_complexity=8,  # High complexity
            validation_results=[],
            optimization_suggestions=["Add LIMIT clause", "Use specific columns"],
            security_issues=[],
            performance_warnings=[],
        )

        with patch.object(
            tools.query_validator, 'validate_query', return_value=canned
        ):
            # Mock successful query execution
            tools.db_manager.execute_query.return_value = [{"id": 1, "name": "test"}]

            reply = await tools._handle_query(
                {"sql": "SELECT * FROM test", "schema": "public"}
            )

            assert len(reply) == 1
            response_text = reply[0].text

            # Should include query results
            assert "Query Results" in response_text

            # Should include validation feedback for high complexity
            assert "High complexity query" in response_text

            # Should include optimization tips
            assert "Quick optimization tips" in response_text

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/abdou-ghonim/mcp-postgres'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.