Skip to main content
Glama
test_auth_providers.py21.5 kB
"""Tests for the authentication provider architecture.""" import os from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from amazon_ads_mcp.auth.base import BaseAmazonAdsProvider, ProviderConfig from amazon_ads_mcp.auth.providers.direct import DirectAmazonAdsProvider from amazon_ads_mcp.auth.providers.openbridge import OpenBridgeProvider from amazon_ads_mcp.auth.registry import ProviderRegistry, register_provider from amazon_ads_mcp.models import Token class TestProviderRegistry: """Tests for the provider registry system.""" def test_registry_lists_builtin_providers(self): """Test that built-in providers are registered.""" providers = ProviderRegistry.list_providers() assert "direct" in providers assert "openbridge" in providers def test_registry_creates_direct_provider(self): """Test creating a direct provider through registry.""" config = ProviderConfig( client_id="test_client", client_secret="test_secret", refresh_token="test_refresh", region="na" ) provider = ProviderRegistry.create_provider("direct", config) assert isinstance(provider, DirectAmazonAdsProvider) assert provider.client_id == "test_client" assert provider.region == "na" def test_registry_creates_openbridge_provider(self): """Test creating an OpenBridge provider through registry.""" config = ProviderConfig(refresh_token="test_openbridge_token") provider = ProviderRegistry.create_provider("openbridge", config) assert isinstance(provider, OpenBridgeProvider) assert provider.refresh_token == "test_openbridge_token" def test_registry_raises_for_unknown_provider(self): """Test that registry raises error for unknown provider.""" config = ProviderConfig() with pytest.raises(ValueError, match="Unknown provider type"): ProviderRegistry.create_provider("unknown", config) def test_custom_provider_registration(self): """Test registering a custom provider.""" @register_provider("test_custom") class TestCustomProvider(BaseAmazonAdsProvider): def __init__(self, config: ProviderConfig): self.config = config @property def provider_type(self) -> str: return "test_custom" @property def region(self) -> str: return "na" async def initialize(self): pass async def get_token(self) -> Token: return Token( value="test", expires_at=datetime.now() + timedelta(hours=1), token_type="Bearer" ) async def validate_token(self, token: Token) -> bool: return True async def get_headers(self) -> dict: return {"test": "header"} async def close(self): pass # Check it's registered providers = ProviderRegistry.list_providers() assert "test_custom" in providers # Create instance config = ProviderConfig() provider = ProviderRegistry.create_provider("test_custom", config) assert isinstance(provider, TestCustomProvider) class TestDirectProvider: """Tests for the Direct Amazon Ads provider.""" @pytest.fixture def direct_provider(self): """Create a direct provider instance.""" config = ProviderConfig( client_id="test_client", client_secret="test_secret", refresh_token="test_refresh", profile_id="test_profile", region="na" ) return DirectAmazonAdsProvider(config) def test_direct_provider_initialization(self, direct_provider): """Test direct provider initializes correctly.""" assert direct_provider.client_id == "test_client" assert direct_provider.client_secret == "test_secret" assert direct_provider.refresh_token == "test_refresh" assert direct_provider.profile_id == "test_profile" assert direct_provider.region == "na" assert direct_provider.get_oauth_endpoint() == "https://api.amazon.com/auth/o2/token" def test_direct_provider_region_endpoints(self): """Test that regions map to correct auth endpoints.""" na_config = ProviderConfig( client_id="c", client_secret="s", refresh_token="r", region="na" ) na_provider = DirectAmazonAdsProvider(na_config) assert na_provider.get_oauth_endpoint() == "https://api.amazon.com/auth/o2/token" eu_config = ProviderConfig( client_id="c", client_secret="s", refresh_token="r", region="eu" ) eu_provider = DirectAmazonAdsProvider(eu_config) assert eu_provider.get_oauth_endpoint() == "https://api.amazon.co.uk/auth/o2/token" fe_config = ProviderConfig( client_id="c", client_secret="s", refresh_token="r", region="fe" ) fe_provider = DirectAmazonAdsProvider(fe_config) assert fe_provider.get_oauth_endpoint() == "https://api.amazon.co.jp/auth/o2/token" @pytest.mark.asyncio async def test_direct_provider_token_refresh(self, direct_provider): """Test token refresh for direct provider.""" mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { "access_token": "new_access_token", "expires_in": 3600 } with patch("httpx.AsyncClient.post", return_value=mock_response): with patch.object(direct_provider, "_get_client", new_callable=AsyncMock) as mock_get_client: mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_get_client.return_value = mock_client token = await direct_provider._refresh_access_token() assert token.value == "new_access_token" assert token.token_type == "Bearer" # Check expiration is roughly 1 hour from now expected_expiry = datetime.now(timezone.utc) + timedelta(seconds=3600) assert abs((token.expires_at - expected_expiry).total_seconds()) < 5 # Clean up the token to avoid affecting other tests direct_provider._access_token = None @pytest.mark.asyncio async def test_direct_provider_token_caching(self, direct_provider): """Test that tokens are cached and reused.""" # Reset AuthManager to ensure clean state from amazon_ads_mcp.auth.manager import AuthManager AuthManager.reset() # Set a cached token direct_provider._access_token = Token( value="cached_token", expires_at=datetime.now(timezone.utc) + timedelta(hours=1), token_type="Bearer" ) # Mock _refresh_access_token to prevent it from being called with patch.object(direct_provider, "_refresh_access_token", new_callable=AsyncMock) as mock_refresh: # Get token should return cached one without making request token = await direct_provider.get_token() assert token.value == "cached_token" # Ensure _refresh_access_token was not called mock_refresh.assert_not_called() @pytest.mark.asyncio async def test_direct_provider_expired_token_refresh(self, direct_provider): """Test that expired tokens trigger refresh.""" # Reset AuthManager to ensure clean state from amazon_ads_mcp.auth.manager import AuthManager AuthManager.reset() # Set an expired cached token direct_provider._access_token = Token( value="expired_token", expires_at=datetime.now(timezone.utc) - timedelta(hours=1), token_type="Bearer" ) mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { "access_token": "fresh_token", "expires_in": 3600 } with patch.object(direct_provider, "_get_client", new_callable=AsyncMock) as mock_get_client: mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_get_client.return_value = mock_client # Mock _refresh_access_token to return our expected token with patch.object(direct_provider, "_refresh_access_token", return_value=Token( value="fresh_token", expires_at=datetime.now(timezone.utc) + timedelta(hours=1), token_type="Bearer" )): token = await direct_provider.get_token() assert token.value == "fresh_token" @pytest.mark.asyncio async def test_direct_provider_list_identities(self, direct_provider): """Test that direct provider returns single synthetic identity.""" identities = await direct_provider.list_identities() assert len(identities) == 1 identity = identities[0] assert identity.id == "direct-auth" assert identity.type == "amazon_ads_direct" assert identity.attributes["auth_method"] == "direct" assert identity.attributes["client_id"] == "test_client" assert identity.attributes["region"] == "na" @pytest.mark.asyncio async def test_direct_provider_get_credentials(self, direct_provider): """Test getting credentials from direct provider.""" # Reset AuthManager to ensure clean state from amazon_ads_mcp.auth.manager import AuthManager AuthManager.reset() mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { "access_token": "test_token", "expires_in": 3600 } with patch.object(direct_provider, "_get_client", new_callable=AsyncMock) as mock_get_client: mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_get_client.return_value = mock_client # Mock _refresh_access_token to return our expected token with patch.object(direct_provider, "_refresh_access_token", return_value=Token( value="test_token", expires_at=datetime.now(timezone.utc) + timedelta(hours=1), token_type="Bearer" )): credentials = await direct_provider.get_identity_credentials("direct-auth") assert credentials.identity_id == "direct-auth" assert credentials.access_token == "test_token" assert "Authorization" in credentials.headers assert credentials.headers["Authorization"] == "Bearer test_token" assert credentials.headers["Amazon-Advertising-API-ClientId"] == "test_client" assert credentials.headers["Amazon-Advertising-API-Scope"] == "test_profile" class TestOpenBridgeProvider: """Tests for the OpenBridge provider.""" @pytest.fixture def openbridge_provider(self): """Create an OpenBridge provider instance.""" config = ProviderConfig(refresh_token="test_openbridge_refresh") return OpenBridgeProvider(config) def test_openbridge_provider_initialization(self, openbridge_provider): """Test OpenBridge provider initializes correctly.""" assert openbridge_provider.refresh_token == "test_openbridge_refresh" assert openbridge_provider.auth_base_url == "https://authentication.api.openbridge.io" assert openbridge_provider.identity_base_url == "https://remote-identity.api.openbridge.io" assert openbridge_provider.service_base_url == "https://service.api.openbridge.io" def test_openbridge_provider_custom_endpoints(self, monkeypatch): """Test OpenBridge provider with custom endpoints.""" config = ProviderConfig( refresh_token="test", auth_base_url="https://custom.auth.url", identity_base_url="https://custom.identity.url", service_base_url="https://custom.service.url" ) provider = OpenBridgeProvider(config) assert provider.auth_base_url == "https://custom.auth.url" assert provider.identity_base_url == "https://custom.identity.url" assert provider.service_base_url == "https://custom.service.url" @pytest.mark.asyncio async def test_openbridge_jwt_conversion(self, openbridge_provider): """Test converting refresh token to JWT.""" mock_response = MagicMock() mock_response.status_code = 202 mock_response.json.return_value = { "data": { "attributes": { "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHBpcmVzX2F0IjoxNzAwMDAwMDAwLCJ1c2VyX2lkIjoiMTIzIiwiYWNjb3VudF9pZCI6IjQ1NiJ9.test" } } } with patch.object(openbridge_provider, "_get_client", new_callable=AsyncMock) as mock_get_client: mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_get_client.return_value = mock_client with patch("jwt.decode") as mock_decode: mock_decode.return_value = { "expires_at": 1700000000, "user_id": "123", "account_id": "456" } token = await openbridge_provider._refresh_jwt_token() assert token.value.startswith("eyJ") assert token.token_type == "Bearer" assert token.metadata["user_id"] == "123" assert token.metadata["account_id"] == "456" @pytest.mark.asyncio async def test_openbridge_list_identities(self, openbridge_provider): """Test listing identities from OpenBridge.""" # Mock JWT token openbridge_provider._jwt_token = Token( value="test_jwt", expires_at=datetime.now(timezone.utc) + timedelta(hours=1), token_type="Bearer" ) mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { "data": [ { "id": "identity-1", "type": "remote_identity", "attributes": {"name": "Test Identity 1", "region": "na"} }, { "id": "identity-2", "type": "remote_identity", "attributes": {"name": "Test Identity 2", "region": "eu"} } ], "links": {} } mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) with patch.object(openbridge_provider, "_get_client", return_value=mock_client): identities = await openbridge_provider.list_identities() assert len(identities) == 2 assert identities[0].id == "identity-1" assert identities[0].attributes["name"] == "Test Identity 1" assert identities[1].id == "identity-2" assert identities[1].attributes["region"] == "eu" @pytest.mark.asyncio async def test_openbridge_identity_caching(self, openbridge_provider): """Test that identities are cached.""" # This test verifies the caching logic exists # The actual OpenBridge provider doesn't have a simple _identities_cache # but uses a more complex caching mechanism # Check that the provider has caching attributes assert hasattr(openbridge_provider, '_identities_cache') # Mock the fetch to avoid real API calls from amazon_ads_mcp.models import Identity test_identities = [Identity(id="test-1", type="test", attributes={})] with patch.object(openbridge_provider, '_fetch_identities', return_value=test_identities): # First call should fetch identities1 = await openbridge_provider.list_identities() assert len(identities1) == 1 assert identities1[0].id == "test-1" class TestAuthManagerIntegration: """Integration tests for auth manager with providers.""" @pytest.mark.asyncio async def test_auth_manager_auto_detect_direct(self, monkeypatch): """Test auth manager auto-detects direct credentials.""" # Explicitly set AUTH_METHOD to override auto-detection from .env file monkeypatch.setenv("AUTH_METHOD", "direct") monkeypatch.setenv("AMAZON_AD_API_CLIENT_ID", "test_client") monkeypatch.setenv("AMAZON_AD_API_CLIENT_SECRET", "test_secret") monkeypatch.setenv("AMAZON_AD_API_REFRESH_TOKEN", "test_refresh") from amazon_ads_mcp.auth.manager import AuthManager AuthManager.reset() # Reset singleton manager = AuthManager() assert isinstance(manager.provider, DirectAmazonAdsProvider) @pytest.mark.asyncio async def test_auth_manager_explicit_openbridge(self, monkeypatch): """Test auth manager uses OpenBridge when explicitly configured.""" # Clear direct credentials (both new and legacy environment variables) # Remove them entirely to override .env file values monkeypatch.delenv("AMAZON_AD_API_CLIENT_ID", raising=False) monkeypatch.delenv("AMAZON_AD_API_CLIENT_SECRET", raising=False) monkeypatch.delenv("AMAZON_AD_API_REFRESH_TOKEN", raising=False) monkeypatch.delenv("AMAZON_ADS_CLIENT_ID", raising=False) monkeypatch.delenv("AMAZON_ADS_CLIENT_SECRET", raising=False) monkeypatch.delenv("AMAZON_ADS_REFRESH_TOKEN", raising=False) # Set OpenBridge credentials monkeypatch.setenv("OPENBRIDGE_REFRESH_TOKEN", "test_ob_token") # Explicitly set AUTH_METHOD to openbridge to override auto-detection monkeypatch.setenv("AUTH_METHOD", "openbridge") from amazon_ads_mcp.auth.manager import AuthManager AuthManager.reset() # Reset singleton manager = AuthManager() assert isinstance(manager.provider, OpenBridgeProvider) @pytest.mark.asyncio async def test_auth_manager_explicit_method(self, monkeypatch): """Test auth manager uses explicit AUTH_METHOD.""" monkeypatch.setenv("AUTH_METHOD", "direct") monkeypatch.setenv("AMAZON_AD_API_CLIENT_ID", "test_client") monkeypatch.setenv("AMAZON_AD_API_CLIENT_SECRET", "test_secret") monkeypatch.setenv("AMAZON_AD_API_REFRESH_TOKEN", "test_refresh") from amazon_ads_mcp.auth.manager import AuthManager AuthManager.reset() # Reset singleton manager = AuthManager() assert isinstance(manager.provider, DirectAmazonAdsProvider) def test_auth_manager_no_credentials_error(self, monkeypatch): """Test auth manager raises error when no credentials.""" # Note: This test may not work properly if there's a .env file with credentials. # In a real test environment, we'd mock the settings loading to avoid this. # For now, we skip this test if a .env file is present. if os.path.exists(".env"): pytest.skip(".env file present - would interfere with test") # Clear all auth environment variables for key in ["AUTH_METHOD", "AMAZON_AD_API_CLIENT_ID", "AMAZON_AD_API_CLIENT_SECRET", "AMAZON_AD_API_REFRESH_TOKEN", "OPENBRIDGE_REFRESH_TOKEN", "AMAZON_ADS_CLIENT_ID", "AMAZON_ADS_CLIENT_SECRET", "AMAZON_ADS_REFRESH_TOKEN"]: monkeypatch.delenv(key, raising=False) from amazon_ads_mcp.auth.manager import AuthManager AuthManager.reset() # Reset singleton with pytest.raises(ValueError, match="No authentication method configured"): AuthManager() class TestProviderConfig: """Tests for ProviderConfig class.""" def test_provider_config_dict_access(self): """Test ProviderConfig works like a dict.""" config = ProviderConfig( client_id="test", custom_field="value" ) assert config.get("client_id") == "test" assert config.get("custom_field") == "value" assert config.get("missing", "default") == "default" def test_provider_config_attribute_access(self): """Test ProviderConfig allows attribute access.""" config = ProviderConfig( client_id="test", api_key="key123" ) assert config.client_id == "test" assert config.api_key == "key123" assert config.get("client_id") == config.client_id

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/KuudoAI/amazon_ads_mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server