Skip to main content
Glama
test_security_refactor.py (13.9 kB)
"""Tests for security refactoring."""

import asyncio
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock

import pytest

from amazon_ads_mcp.auth.oauth_state_store import OAuthStateStore
from amazon_ads_mcp.auth.secure_token_store import SecureTokenStore
from amazon_ads_mcp.exceptions import (
    OAuthStateError,
    # Aliased so the project's exception does not shadow the builtin
    # ``TimeoutError`` for the rest of this module.
    TimeoutError as AdsTimeoutError,
    APIError,
    ToolExecutionError,
)
from amazon_ads_mcp.utils.async_compat import (
    CompatibleEventLoopPolicy,
    ensure_event_loop,
    run_async_in_sync,
    AsyncContextManager,
)
from amazon_ads_mcp.utils.response_wrapper import ResponseWrapper
from amazon_ads_mcp.utils.sampling_wrapper import SamplingHandlerWrapper


class TestOAuthStateStore:
    """Test OAuth state store functionality."""

    def test_generate_state(self):
        """Test state generation with HMAC signature."""
        store = OAuthStateStore(secret_key="test_secret")
        state = store.generate_state(
            auth_url="https://example.com/auth",
            user_agent="TestAgent/1.0",
            ip_address="192.168.1.1",
        )

        assert state is not None
        assert "." in state  # Should have signature separator
        assert len(state) > 40  # Should be reasonably long

    def test_validate_state_success(self):
        """Test successful state validation."""
        store = OAuthStateStore(secret_key="test_secret")
        state = store.generate_state(
            auth_url="https://example.com/auth",
            user_agent="TestAgent/1.0",
        )

        is_valid, error = store.validate_state(state, user_agent="TestAgent/1.0")
        assert is_valid is True
        assert error is None

    def test_validate_state_invalid(self):
        """Test invalid state validation."""
        store = OAuthStateStore(secret_key="test_secret")

        # Test with completely invalid state
        is_valid, error = store.validate_state("invalid_state")
        assert is_valid is False
        assert error == "Invalid or expired state"

    def test_validate_state_tampered(self):
        """Test tampered state detection."""
        store = OAuthStateStore(secret_key="test_secret")
        state = store.generate_state(auth_url="https://example.com/auth")

        # Tamper with the signature (the part after the last ".").
        base, _ = state.rsplit(".", 1)
        tampered_state = f"{base}.tampered_signature"

        is_valid, error = store.validate_state(tampered_state)
        assert is_valid is False
        # Tampering the signature changes the state token; store lookup fails first
        assert error in ("Invalid or expired state", "Invalid state signature")

    def test_validate_state_reuse_prevention(self):
        """Test that states cannot be reused."""
        store = OAuthStateStore(secret_key="test_secret")
        state = store.generate_state(auth_url="https://example.com/auth")

        # First validation should succeed
        is_valid, error = store.validate_state(state)
        assert is_valid is True

        # Second validation should fail
        is_valid, error = store.validate_state(state)
        assert is_valid is False
        assert "already used" in error

    def test_state_expiration(self):
        """Test state expiration."""
        store = OAuthStateStore(secret_key="test_secret")
        state = store.generate_state(
            auth_url="https://example.com/auth",
            ttl_minutes=0,  # Expire immediately
        )

        # Force expiration by back-dating the stored entry (private access)
        # so the test does not have to wait for real time to pass.
        entry = store._memory_store[state]
        entry.expires_at = datetime.now(timezone.utc) - timedelta(minutes=1)

        is_valid, error = store.validate_state(state)
        assert is_valid is False
        assert "expired" in error.lower()

    def test_persistence(self, tmp_path):
        """Test state persistence to file."""
        store_path = tmp_path / "oauth_states.json"

        store1 = OAuthStateStore(secret_key="test_secret", store_path=store_path)
        state = store1.generate_state(auth_url="https://example.com/auth")

        # Create new store instance
        store2 = OAuthStateStore(secret_key="test_secret", store_path=store_path)

        # Should be able to validate state from first store
        is_valid, error = store2.validate_state(state)
        assert is_valid is True
        assert error is None


class TestSecureTokenStore:
    """Test secure token storage."""

    def test_store_and_retrieve(self, tmp_path):
        """Test storing and retrieving tokens."""
        store = SecureTokenStore(
            storage_path=tmp_path / "tokens.enc",
            encryption_key="test_password",
        )

        store.store_token(
            token_id="test_token",
            token_value="secret_value_123",
            token_type="refresh",
            expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
        )

        token = store.get_token("test_token")
        assert token is not None
        assert token["value"] == "secret_value_123"
        assert token["type"] == "refresh"

    def test_encryption(self, tmp_path):
        """Test that tokens are encrypted on disk."""
        storage_path = tmp_path / "tokens.enc"
        store = SecureTokenStore(
            storage_path=storage_path,
            encryption_key="test_password",
        )

        store.store_token(
            token_id="sensitive_token",
            token_value="super_secret_value",
            token_type="access",
        )

        # Read raw file content
        raw_content = storage_path.read_bytes()

        # Should not contain the plaintext token
        assert b"super_secret_value" not in raw_content
        assert b"sensitive_token" not in raw_content  # ID should also be encrypted

    def test_expiration(self, tmp_path):
        """Test token expiration."""
        store = SecureTokenStore(
            storage_path=tmp_path / "tokens.enc",
            encryption_key="test_password",
        )

        # Store expired token
        store.store_token(
            token_id="expired_token",
            token_value="old_value",
            expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
        )

        # Should not retrieve expired token
        token = store.get_token("expired_token")
        assert token is None

    def test_persistence_across_instances(self, tmp_path):
        """Test token persistence across store instances."""
        storage_path = tmp_path / "tokens.enc"

        # Store token with first instance
        store1 = SecureTokenStore(
            storage_path=storage_path,
            encryption_key="test_password",
        )
        store1.store_token(
            token_id="persistent_token",
            token_value="persistent_value",
        )

        # Retrieve with second instance
        store2 = SecureTokenStore(
            storage_path=storage_path,
            encryption_key="test_password",
        )
        token = store2.get_token("persistent_token")
        assert token is not None
        assert token["value"] == "persistent_value"

    def test_wrong_key_fails(self, tmp_path):
        """Test that wrong encryption key fails gracefully."""
        storage_path = tmp_path / "tokens.enc"

        # Store with one key
        store1 = SecureTokenStore(
            storage_path=storage_path,
            encryption_key="correct_password",
        )
        store1.store_token(token_id="test", token_value="value")

        # Try to load with wrong key
        store2 = SecureTokenStore(
            storage_path=storage_path,
            encryption_key="wrong_password",
        )

        # Should start fresh, not crash
        token = store2.get_token("test")
        assert token is None


class TestAsyncCompatibility:
    """Test async compatibility utilities.

    These tests change process-global asyncio state (the event loop policy
    and the current event loop), so their cleanup runs in ``finally`` blocks
    to avoid leaking that state into later tests when an assertion fails.
    """

    def test_compatible_event_loop_policy(self):
        """Test compatible event loop policy."""
        previous_policy = asyncio.get_event_loop_policy()
        asyncio.set_event_loop_policy(CompatibleEventLoopPolicy())
        loop = None
        try:
            # Should create loop when needed
            loop = asyncio.get_event_loop()
            assert loop is not None
            assert not loop.is_closed()
        finally:
            # Clean up even if an assertion above failed.
            if loop is not None:
                loop.close()
            asyncio.set_event_loop(None)
            asyncio.set_event_loop_policy(previous_policy)

    def test_ensure_event_loop(self):
        """Test ensure_event_loop function."""
        # Clear any existing loop
        try:
            existing = asyncio.get_event_loop()
            existing.close()
        except RuntimeError:
            pass
        asyncio.set_event_loop(None)

        # Should create new loop
        loop = ensure_event_loop()
        try:
            assert loop is not None
            assert not loop.is_closed()
        finally:
            # Clean up even if an assertion above failed.
            if loop is not None:
                loop.close()
            asyncio.set_event_loop(None)

    def test_run_async_in_sync(self):
        """Test running async function from sync context."""

        async def async_func(value):
            await asyncio.sleep(0.01)
            return value * 2

        result = run_async_in_sync(async_func, 21)
        assert result == 42

    def test_async_context_manager(self):
        """Test AsyncContextManager."""

        async def async_task():
            await asyncio.sleep(0.01)
            return "completed"

        with AsyncContextManager() as ctx:
            result = ctx.run(async_task())
            assert result == "completed"


class TestResponseWrapper:
    """Test response wrapper functionality."""

    def test_response_wrapper_basic(self):
        """Test basic response wrapper functionality."""
        import httpx

        # Create mock response
        response = httpx.Response(
            200,
            headers={"content-type": "application/json"},
            content=b'{"key": "value"}',
        )

        wrapper = ResponseWrapper(response)
        assert wrapper.status_code == 200
        assert wrapper.json() == {"key": "value"}

    def test_response_wrapper_modification(self):
        """Test response content modification."""
        import httpx

        response = httpx.Response(
            200,
            headers={"content-type": "application/json"},
            content=b'{"old": "value"}',
        )

        wrapper = ResponseWrapper(response)
        wrapper.set_json({"new": "value"})

        assert wrapper.json() == {"new": "value"}
        assert wrapper.content == b'{"new": "value"}'

    def test_response_wrapper_modify_json(self):
        """Test JSON modification with function."""
        import httpx

        response = httpx.Response(
            200,
            headers={"content-type": "application/json"},
            content=b'{"count": 10}',
        )

        wrapper = ResponseWrapper(response)
        wrapper.modify_json(lambda data: {**data, "count": data["count"] * 2})

        assert wrapper.json() == {"count": 20}


class TestStructuredExceptions:
    """Test structured exception classes."""

    def test_oauth_state_error(self):
        """Test OAuthStateError."""
        error = OAuthStateError("Invalid state")
        assert error.code == "OAUTH_STATE_ERROR"
        assert error.message == "Invalid state"

        error_dict = error.to_dict()
        assert error_dict["error"] == "OAUTH_STATE_ERROR"
        assert error_dict["message"] == "Invalid state"

    def test_timeout_error(self):
        """Test the project's TimeoutError (imported as AdsTimeoutError)."""
        error = AdsTimeoutError("Request timed out", operation="list_campaigns")
        assert error.code == "TIMEOUT_ERROR"
        assert error.details["operation"] == "list_campaigns"

    def test_api_error(self):
        """Test APIError."""
        error = APIError(
            "API request failed",
            status_code=404,
            response_body="Not found",
        )
        assert error.code == "API_ERROR"
        assert error.status_code == 404
        assert error.details["status_code"] == 404
        assert error.details["response_body"] == "Not found"

    def test_tool_execution_error(self):
        """Test ToolExecutionError."""
        original = ValueError("Original error")
        error = ToolExecutionError(
            "Tool failed",
            tool_name="test_tool",
            original_error=original,
        )
        assert error.code == "TOOL_EXECUTION_ERROR"
        assert error.tool_name == "test_tool"
        assert error.details["tool"] == "test_tool"
        assert "ValueError" in error.details["error_type"]


class TestSamplingWrapper:
    """Test sampling wrapper functionality."""

    @pytest.mark.asyncio
    async def test_sampling_wrapper_with_handler(self):
        """Test sampling wrapper with configured handler."""

        # Create mock handler
        async def mock_handler(messages, params, context):
            return MagicMock(content="sampled response")

        wrapper = SamplingHandlerWrapper()
        wrapper.set_handler(mock_handler)

        assert wrapper.has_handler() is True

        # Mock context that doesn't support sampling
        mock_ctx = MagicMock()
        mock_ctx.sample = AsyncMock(side_effect=Exception("does not support sampling"))
        mock_ctx.request_context = {}

        result = await wrapper.sample(messages="test message", ctx=mock_ctx)

        assert result == "sampled response"

    @pytest.mark.asyncio
    async def test_sampling_wrapper_no_handler(self):
        """Test sampling wrapper without handler."""
        wrapper = SamplingHandlerWrapper()
        assert wrapper.has_handler() is False

        mock_ctx = MagicMock()
        mock_ctx.sample = AsyncMock(side_effect=Exception("does not support sampling"))

        with pytest.raises(Exception) as exc_info:
            await wrapper.sample(messages="test message", ctx=mock_ctx)

        assert "no server-side fallback is configured" in str(exc_info.value)

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/KuudoAI/amazon_ads_mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.