Skip to main content
Glama

Alibaba Cloud DMS MCP Server

Official
by aliyun
test_server.py (19.5 kB)
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from alibabacloud_dms_enterprise20181101 import models as dms_models
from alibabacloud_tea_openapi import models as open_api_models
from mcp.server.fastmcp import FastMCP

# Import the functions and classes to be tested
from alibabacloud_dms_mcp_server.server import (
    create_client,
    add_instance,
    get_instance,
    search_database,
    get_database,
    list_tables,
    get_meta_table_detail_info,
    execute_script,
    nl2sql,
    ToolRegistry,
    lifespan,
    InstanceInfo,
    InstanceDetail,
    DatabaseInfo,
    DatabaseDetail,
    TableDetail,
    ExecuteScriptResult,
    ResultSet,
    SqlResult,
    MyBaseModel,
)


# --- Fixtures ---

@pytest.fixture
def mock_dms_client():
    """Patch ``server.create_client`` and yield the MagicMock client it returns.

    The DMS client methods stubbed by these tests are declared explicitly so a
    test can set ``return_value`` / ``side_effect`` on them and assert on calls.
    """
    with patch("alibabacloud_dms_mcp_server.server.create_client") as mock_create_client:
        client_instance = MagicMock()
        # Mock specific client methods as needed for tests
        client_instance.simply_add_instance = MagicMock()
        client_instance.get_instance = MagicMock()
        client_instance.search_database = MagicMock()
        client_instance.get_database = MagicMock()
        client_instance.list_tables = MagicMock()
        client_instance.get_meta_table_detail_info = MagicMock()
        client_instance.execute_script = MagicMock()
        client_instance.generate_sql_from_nl = MagicMock()
        mock_create_client.return_value = client_instance
        yield client_instance


@pytest.fixture
def mcp_app():
    """Create a FastMCP app whose ``state.default_database_id`` starts as None."""
    app = FastMCP("TestApp")

    # Attach a plain mutable object as ``app.state``; the ToolRegistry and
    # lifespan tests read and write ``default_database_id`` on it.
    class AppState:
        pass

    app.state = AppState()
    app.state.default_database_id = None
    return app


# --- Helper Functions for Mock Responses ---

def create_mock_openapi_response(body_data: dict, status_code: int = 200) -> MagicMock:
    """Build a fake OpenAPI response whose ``body.to_map()`` returns *body_data*.

    Args:
        body_data: Dict returned by ``response.body.to_map()``.
        status_code: Value exposed as ``response.status_code``.

    Returns:
        A MagicMock shaped like an OpenAPI SDK response.
    """
    response = MagicMock()
    response.status_code = status_code
    response.body = MagicMock()
    response.body.to_map = MagicMock(return_value=body_data)
    return response


# --- Tests for Core Logic Functions ---
@pytest.mark.asyncio
async def test_add_instance_success(mock_dms_client):
    """add_instance returns an InstanceInfo and forwards user/host to the client."""
    mock_response_body = {
        "instance_id": "dms-instance-123",
        "host": "test-host.com",
        "port": "3306",
    }
    mock_dms_client.simply_add_instance.return_value = create_mock_openapi_response(mock_response_body)

    result = await add_instance(
        db_user="test_user",
        db_password="test_password",
        host="test-host.com",
        port="3306",
    )

    assert isinstance(result, InstanceInfo)
    assert result.instance_id == "dms-instance-123"
    assert result.host == "test-host.com"
    mock_dms_client.simply_add_instance.assert_called_once()
    call_args = mock_dms_client.simply_add_instance.call_args[0][0]
    assert call_args.database_user == "test_user"
    assert call_args.host == "test-host.com"


@pytest.mark.asyncio
async def test_add_instance_missing_user_raises_error():
    """An empty db_user is rejected with ValueError before any client call."""
    with pytest.raises(ValueError, match="db_user must be a non-empty string"):
        await add_instance(db_user="", db_password="password")


@pytest.mark.asyncio
async def test_get_instance_success(mock_dms_client):
    """get_instance maps the response's ``Instance`` entry into an InstanceDetail."""
    mock_response_body = {
        "Instance": {
            "InstanceId": "rm-123",
            "State": "NORMAL",
            "InstanceType": "MySQL",
            "InstanceAlias": "My Test DB",
        }
    }
    mock_dms_client.get_instance.return_value = create_mock_openapi_response(mock_response_body)

    result = await get_instance(host="test-host.com", port="3306")

    assert isinstance(result, InstanceDetail)
    assert result.InstanceId == "rm-123"
    assert result.InstanceType == "MySQL"


@pytest.mark.asyncio
async def test_search_database_success(mock_dms_client):
    """search_database returns one DatabaseInfo per hit.

    The MySQL hit keeps its bare schema name, while the PostgreSQL hit is
    expected to be reported as ``catalog.schema``.
    """
    mock_response_body = {
        "SearchDatabaseList": {
            "SearchDatabase": [
                {"DatabaseId": "db1", "Host": "host1", "Port": "3306", "DbType": "MySQL",
                 "SchemaName": "schema1", "CatalogName": "def"},
                {"DatabaseId": "db2", "Host": "host2", "Port": "5432", "DbType": "PostgreSQL",
                 "SchemaName": "public", "CatalogName": "pg_catalog"},
            ]
        },
        "TotalCount": 2,
    }
    mock_dms_client.search_database.return_value = create_mock_openapi_response(mock_response_body)

    results = await search_database(search_key="test_db")

    assert len(results) == 2
    assert isinstance(results[0], DatabaseInfo)
    assert results[0].DatabaseId == "db1"
    assert results[0].SchemaName == "schema1"
    assert results[1].DatabaseId == "db2"
    assert results[1].SchemaName == "pg_catalog.public"


@pytest.mark.asyncio
async def test_get_database_success(mock_dms_client):
    """get_database maps the response's ``Database`` entry into a DatabaseDetail."""
    mock_response_body = {
        "Database": {
            "DatabaseId": "db-guid-123",
            "SchemaName": "my_schema",
            "DbType": "MySQL",
            "InstanceId": "inst-id-456",
        }
    }
    mock_dms_client.get_database.return_value = create_mock_openapi_response(mock_response_body)

    result = await get_database(host="test-host.com", port="3306", schema_name="my_schema")

    assert isinstance(result, DatabaseDetail)
    assert result.DatabaseId == "db-guid-123"
    assert result.SchemaName == "my_schema"


@pytest.mark.asyncio
async def test_list_tables_success(mock_dms_client):
    """list_tables returns the raw response dict including ``TableList``."""
    mock_response_body = {
        "TableList": {
            "Table": [
                {"TableName": "users", "TableGuid": "guid1"},
                {"TableName": "products", "TableGuid": "guid2"},
            ]
        },
        "TotalCount": 2,
    }
    mock_dms_client.list_tables.return_value = create_mock_openapi_response(mock_response_body)

    result = await list_tables(database_id="db-guid-123", search_name="user")

    assert "TableList" in result
    assert len(result["TableList"]["Table"]) == 2


@pytest.mark.asyncio
async def test_get_meta_table_detail_info_success(mock_dms_client):
    """get_meta_table_detail_info exposes column and index lists via TableDetail."""
    mock_response_body = {
        "DetailInfo": {
            "ColumnList": [
                {"ColumnName": "id", "ColumnType": "int"},
                {"ColumnName": "name", "ColumnType": "varchar"},
            ],
            "IndexList": [{"IndexName": "PRIMARY", "IndexColumns": ["id"]}],
        }
    }
    mock_dms_client.get_meta_table_detail_info.return_value = create_mock_openapi_response(mock_response_body)

    result = await get_meta_table_detail_info(table_guid="guid.schema.table")

    assert isinstance(result, TableDetail)
    assert len(result.ColumnList) == 2
    assert result.ColumnList[0]["ColumnName"] == "id"
    assert len(result.IndexList) == 1


@pytest.mark.asyncio
async def test_execute_script_success(mock_dms_client):
    """A successful result set is parsed and gets a Markdown rendering."""
    mock_response_body = {
        "RequestId": "req-123",
        "Success": True,
        "Results": [
            {
                "Success": True,
                "ColumnNames": ["id", "name"],
                "RowCount": 1,
                "Rows": [{"id": 1, "name": "Alice"}],
            }
        ],
    }
    mock_dms_client.execute_script.return_value = create_mock_openapi_response(mock_response_body)

    result = await execute_script(database_id="db-guid-123", script="SELECT * FROM users")

    assert isinstance(result, ExecuteScriptResult)
    assert result.Success is True
    assert len(result.Results) == 1
    assert result.Results[0].Success is True
    assert result.Results[0].ColumnNames == ["id", "name"]
    assert result.Results[0].MarkdownTable is not None


@pytest.mark.asyncio
async def test_execute_script_failure_in_results(mock_dms_client):
    """A failed individual result set yields no Markdown table."""
    mock_response_body = {
        "RequestId": "req-456",
        "Success": True,  # Overall success can be true even if one script part fails
        "Results": [
            {
                "Success": False,  # Individual result failed
                "ErrorMessage": "Syntax error",
            }
        ],
    }
    mock_dms_client.execute_script.return_value = create_mock_openapi_response(mock_response_body)

    result = await execute_script(database_id="db-guid-123", script="INVALID SQL")

    assert result.Success is True
    assert len(result.Results) == 1
    assert result.Results[0].Success is False
    assert result.Results[0].MarkdownTable is None


@pytest.mark.asyncio
async def test_nl2sql_success(mock_dms_client):
    """nl2sql returns the generated SQL from ``Data.Sql`` as a SqlResult."""
    mock_response_body = {
        "Data": {"Sql": "SELECT id, name FROM users WHERE age > 30"}
    }
    mock_dms_client.generate_sql_from_nl.return_value = create_mock_openapi_response(mock_response_body)

    result = await nl2sql(database_id="db-guid-123", question="show users older than 30")

    assert isinstance(result, SqlResult)
    assert result.sql == "SELECT id, name FROM users WHERE age > 30"


# --- Tests for ToolRegistry ---

async def _find_tool(app, name):
    """Return the tool registered on *app* under *name*, failing clearly if absent.

    A bare ``next(...)`` inside a coroutine would surface a missing tool as
    ``RuntimeError: coroutine raised StopIteration`` (PEP 479).
    """
    tools = await app.list_tools()
    tool = next((t for t in tools if t.name == name), None)
    assert tool is not None, f"tool {name!r} was not registered"
    return tool


@pytest.mark.asyncio
async def test_tool_registry_full_toolset(mcp_app):
    """With no default database configured, the full toolset is registered."""
    registry = ToolRegistry(mcp=mcp_app)
    registry.register_tools()  # default_database_id is None

    # Check if all expected tools are registered
    expected_tool_names = [
        "addInstance", "getInstance", "searchDatabase", "getDatabase",
        "listTables", "getTableDetailInfo", "executeScript", "generateSql",
    ]
    tools = await mcp_app.list_tools()
    registered_tool_names = [tool.name for tool in tools]
    for name in expected_tool_names:
        assert name in registered_tool_names


@pytest.mark.asyncio
async def test_tool_registry_configured_toolset(mcp_app, mock_dms_client):
    """With a default database configured, only the configured toolset is registered."""
    mcp_app.state.default_database_id = "configured_db_id_123"
    registry = ToolRegistry(mcp=mcp_app)
    registry.register_tools()

    expected_tool_names = ["listTables", "getTableDetailInfo", "executeScript", "askDatabase"]
    full_only_tool_names = ["addInstance", "getInstance", "searchDatabase", "getDatabase", "generateSql"]
    tools = await mcp_app.list_tools()
    registered_tool_names = [tool.name for tool in tools]

    for name in expected_tool_names:
        assert name in registered_tool_names
    # Ensure the full set is not registered. This must be checked against the
    # registered names; checking the expected names could never fail.
    for name in full_only_tool_names:
        assert name not in registered_tool_names


@pytest.mark.asyncio
async def test_tool_registry_ask_database_configured_success(mcp_app, mock_dms_client):
    """askDatabase generates SQL, executes it, and returns the rendered result."""
    mcp_app.state.default_database_id = "configured_db_id_ask"
    registry = ToolRegistry(mcp=mcp_app)
    registry.register_tools()
    ask_database_tool = await _find_tool(mcp_app, "askDatabase")

    # Mock nl2sql response
    nl_response_body = {"Data": {"Sql": "SELECT * FROM test_table"}}
    mock_dms_client.generate_sql_from_nl.return_value = create_mock_openapi_response(nl_response_body)

    # Mock execute_script response
    exec_response_body = {
        "RequestId": "req-ask",
        "Success": True,
        "Results": [
            {"Success": True, "ColumnNames": ["col1"], "RowCount": 1, "Rows": [{"col1": "val1"}]}
        ],
    }
    mock_dms_client.execute_script.return_value = create_mock_openapi_response(exec_response_body)

    result_str = await mcp_app.call_tool(ask_database_tool.name, arguments={"question": "show me the data"})

    assert "val1" in str(result_str)  # Check if markdown table string contains the value
    mock_dms_client.generate_sql_from_nl.assert_called_once()
    mock_dms_client.execute_script.assert_called_once()
    assert mock_dms_client.execute_script.call_args[0][0].script == "SELECT * FROM test_table"


@pytest.mark.asyncio
async def test_tool_registry_ask_database_nl_fails(mcp_app, mock_dms_client):
    """askDatabase reports an error and skips execution when no SQL is generated."""
    mcp_app.state.default_database_id = "configured_db_id_ask_fail_nl"
    registry = ToolRegistry(mcp=mcp_app)
    registry.register_tools()
    ask_database_tool = await _find_tool(mcp_app, "askDatabase")

    # Mock nl2sql to return no SQL
    mock_dms_client.generate_sql_from_nl.return_value = create_mock_openapi_response({"Data": {"Sql": None}})

    result_str = await mcp_app.call_tool(ask_database_tool.name, arguments={"question": "bad question"})

    assert "Error: Could not generate an SQL query" in str(result_str)
    mock_dms_client.generate_sql_from_nl.assert_called_once()
    mock_dms_client.execute_script.assert_not_called()


@pytest.mark.asyncio
async def test_tool_registry_ask_database_exec_fails(mcp_app, mock_dms_client):
    """askDatabase reports the exception text when query execution raises."""
    mcp_app.state.default_database_id = "configured_db_id_ask_fail_exec"
    registry = ToolRegistry(mcp=mcp_app)
    registry.register_tools()
    ask_database_tool = await _find_tool(mcp_app, "askDatabase")

    mock_dms_client.generate_sql_from_nl.return_value = create_mock_openapi_response({"Data": {"Sql": "SELECT 1"}})
    # Mock execute_script to raise an exception
    mock_dms_client.execute_script.side_effect = Exception("DB execution error")

    result_str = await mcp_app.call_tool(ask_database_tool.name, arguments={"question": "show me the data"})

    assert "Error: An issue occurred while executing the query: DB execution error" in str(result_str)
    mock_dms_client.generate_sql_from_nl.assert_called_once()
    mock_dms_client.execute_script.assert_called_once()


# --- Tests for Lifespan ---

@pytest.mark.asyncio
async def test_lifespan_with_connection_string(mcp_app, mock_dms_client):
    """lifespan resolves CONNECTION_STRING to a database id and removes it on exit."""
    # Mock a successful get_instance response
    instance_response = {
        "Instance": {
            "InstanceId": "rm-test123",
            "State": "NORMAL",
            "InstanceType": "MySQL",
        }
    }
    mock_dms_client.get_instance.return_value = create_mock_openapi_response(instance_response)

    # Mock a successful get_database response that includes a DatabaseId
    db_response = {
        "Database": {
            "DatabaseId": "db-test-123",
            "SchemaName": "test_db",
            "DbType": "MySQL",
        }
    }
    mock_dms_client.get_database.return_value = create_mock_openapi_response(db_response)

    # Use CONNECTION_STRING (which replaced the former DATABASE_ID setting)
    with patch.dict(os.environ, {"CONNECTION_STRING": "test_db@localhost:3306"}):
        with patch("alibabacloud_dms_mcp_server.server.ToolRegistry.register_tools") as mock_register:
            async with lifespan(mcp_app):
                assert hasattr(mcp_app.state, "default_database_id")
                assert mcp_app.state.default_database_id == "db-test-123"
                mock_register.assert_called_once()
            assert not hasattr(mcp_app.state, "default_database_id")  # Verify cleanup


@pytest.mark.asyncio
async def test_lifespan_without_connection_string(mcp_app):
    """An empty CONNECTION_STRING leaves default_database_id as None during lifespan."""
    # Ensure the environment variable is empty
    with patch.dict(os.environ, {"CONNECTION_STRING": ""}):
        with patch("alibabacloud_dms_mcp_server.server.ToolRegistry.register_tools") as mock_register:
            async with lifespan(mcp_app):
                assert hasattr(mcp_app.state, "default_database_id")
                assert mcp_app.state.default_database_id is None
                mock_register.assert_called_once()
            assert not hasattr(mcp_app.state, "default_database_id")


@pytest.mark.asyncio
async def test_lifespan_with_pg_connection_string(mcp_app, mock_dms_client):
    """A PostgreSQL ``catalog@host:port:schema`` string maps to get_database arguments."""
    # Mock a successful get_instance response
    instance_response = {
        "Instance": {
            "InstanceId": "pg-test123",
            "State": "NORMAL",
            "InstanceType": "PostgreSQL",
        }
    }
    mock_dms_client.get_instance.return_value = create_mock_openapi_response(instance_response)

    # Mock a successful get_database response that includes a DatabaseId
    db_response = {
        "Database": {
            "DatabaseId": "pg-db-test-456",
            "SchemaName": "pg_schema",
            "DbType": "PostgreSQL",
        }
    }
    mock_dms_client.get_database.return_value = create_mock_openapi_response(db_response)

    # PostgreSQL-style CONNECTION_STRING (catalog@host:port:schema)
    with patch.dict(os.environ, {"CONNECTION_STRING": "test_db@localhost:5432:pg_schema"}):
        with patch("alibabacloud_dms_mcp_server.server.ToolRegistry.register_tools") as mock_register:
            async with lifespan(mcp_app):
                assert hasattr(mcp_app.state, "default_database_id")
                assert mcp_app.state.default_database_id == "pg-db-test-456"
                mock_register.assert_called_once()

                # Verify get_database was called with the correct arguments
                call_args = mock_dms_client.get_database.call_args[0][0]
                assert call_args.host == "localhost"
                assert call_args.port == "5432"
                assert call_args.schema_name == "test_db"  # the catalog part is sent as schema_name
                assert call_args.sid == "pg_schema"  # the schema part is sent as sid
            assert not hasattr(mcp_app.state, "default_database_id")  # Verify cleanup


# --- Test ExecuteScriptResult __str__ ---

def test_execute_script_result_str_success_with_markdown():
    """str() returns the first result's Markdown table when available."""
    result = ExecuteScriptResult(
        RequestId="req1",
        Success=True,
        Results=[
            ResultSet(ColumnNames=["colA"], RowCount=1, Rows=[{"colA": "valA"}],
                      MarkdownTable="## Markdown Table", Success=True)
        ],
    )
    assert str(result) == "## Markdown Table"


def test_execute_script_result_str_success_no_markdown():
    """str() explains when the successful first result has no Markdown table."""
    result = ExecuteScriptResult(
        RequestId="req2",
        Success=True,
        Results=[
            ResultSet(ColumnNames=["colB"], RowCount=0, Rows=[], MarkdownTable=None, Success=True)
        ],
    )
    assert str(result) == "Result data is not available in Markdown format."


def test_execute_script_result_str_first_result_not_success():
    """str() reports a failed first result set even if the overall call succeeded."""
    result = ExecuteScriptResult(
        RequestId="req3",
        Success=True,  # Overall success
        Results=[
            ResultSet(ColumnNames=[], RowCount=0, Rows=[], MarkdownTable=None, Success=False)  # First result failed
        ],
    )
    assert str(result) == "The first result set was not successful."


def test_execute_script_result_str_overall_failure():
    """str() reports overall failure."""
    result = ExecuteScriptResult(
        RequestId="req4",
        Success=False,  # Overall failure
        Results=[],
    )
    assert str(result) == "Script execution failed."


def test_execute_script_result_str_success_no_results():
    """str() reports success with an empty result list."""
    result = ExecuteScriptResult(
        RequestId="req5",
        Success=True,
        Results=[],  # No results
    )
    assert str(result) == "Script executed successfully, but no results were returned."
# --- Test _format_as_markdown_table ---
from alibabacloud_dms_mcp_server.server import _format_as_markdown_table  # noqa: E402


def test_format_as_markdown_table_basic():
    """A header row, a ``---`` separator row, then one row per dict, newline-joined."""
    cols = ["ID", "Name"]
    rows = [{"ID": 1, "Name": "Alice"}, {"ID": 2, "Name": "Bob"}]
    # Built with an explicit join so the row separators and cell spacing are
    # visible and cannot be mangled by reindentation or whitespace collapsing.
    expected_md = "\n".join([
        "| ID | Name |",
        "| --- | --- |",
        "| 1 | Alice |",
        "| 2 | Bob |",
    ])
    assert _format_as_markdown_table(cols, rows) == expected_md


def test_format_as_markdown_table_empty():
    """No columns or no rows produces an empty string."""
    assert _format_as_markdown_table([], []) == ""
    assert _format_as_markdown_table(["ID"], []) == ""
    assert _format_as_markdown_table([], [{"ID": 1}]) == ""


def test_format_as_markdown_table_missing_keys():
    """Keys missing from a row render as empty cells."""
    cols = ["ID", "Name", "Age"]
    rows = [{"ID": 1, "Name": "Alice"}, {"ID": 2}]  # Bob is missing Name and Age
    # Missing values should be empty strings, so an empty cell is "|  |"
    # (the usual " | " delimiter around an empty value).
    expected_md = "\n".join([
        "| ID | Name | Age |",
        "| --- | --- | --- |",
        "| 1 | Alice |  |",
        "| 2 |  |  |",
    ])
    assert _format_as_markdown_table(cols, rows) == expected_md

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/aliyun/alibabacloud-dms-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.