Skip to main content
Glama
test_oauth_discovery.py • 30.4 kB
#!/usr/bin/env python3
"""
OAuth Discovery Endpoints Testing for Kafka Schema Registry MCP Server

Tests the OAuth 2.0 discovery endpoints that enable MCP client auto-configuration:
- /.well-known/oauth-authorization-server (RFC 8414)
- /.well-known/oauth-protected-resource (RFC 9728)
- /.well-known/jwks.json (RFC 7517)
"""

import json
import os
import subprocess
import sys
import time
from typing import Callable, Optional, Tuple
from urllib.parse import urlparse

import requests

# Add parent directory to path to import modules
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class OAuthDiscoveryTest:
    """Start a local MCP server and validate its OAuth discovery endpoints.

    Each ``test_*`` method returns ``True`` on success and ``False`` on failure;
    ``run_test`` records the outcome in ``self.test_results`` as
    ``(test_name, passed, error_message_or_None)`` tuples.
    """

    def __init__(self):
        self.test_results = []
        self.server_process = None
        self.server_url = "http://localhost:8899"  # Use different port to avoid conflicts

    def run_test(self, test_name: str, test_func: Callable[[], bool]) -> bool:
        """Run a single test and track results.

        Args:
            test_name: Human-readable name used in the output and the summary.
            test_func: Zero-argument callable; a truthy return value means pass.

        Returns:
            True if the test passed, False if it returned a falsy value or raised.
        """
        try:
            print(f"\n🧪 Running: {test_name}")
            result = test_func()
            if result:
                print(f"✅ {test_name} PASSED")
                self.test_results.append((test_name, True, None))
                return True
            print(f"❌ {test_name} FAILED")
            self.test_results.append((test_name, False, "Test returned False"))
            return False
        except Exception as e:
            # A crashing test must not abort the whole suite; record and move on.
            print(f"❌ {test_name} FAILED with exception: {e}")
            self.test_results.append((test_name, False, str(e)))
            return False

    def setup_test_server(self, enable_auth: bool = False) -> bool:
        """Start a test remote MCP server with specified OAuth configuration.

        Any server started by a previous call is stopped first, and a server that
        fails to become healthy is stopped before returning, so no child process
        is left behind.

        Args:
            enable_auth: Value exported to the server as ``ENABLE_AUTH``.

        Returns:
            True once ``/health`` answers with 200 or 503, False otherwise.
        """
        try:
            # Never orphan a server left over from an earlier (possibly failed) call.
            if self.server_process is not None:
                self.teardown_test_server()

            print(f"🚀 Starting test server (OAuth enabled: {enable_auth})...")

            # Set environment variables for the test
            env = os.environ.copy()
            env.update(
                {
                    "MCP_TRANSPORT": "streamable-http",
                    "MCP_HOST": "localhost",
                    "MCP_PORT": "8899",
                    "ENABLE_AUTH": "true" if enable_auth else "false",
                    "AUTH_PROVIDER": "azure",
                    "AUTH_AUDIENCE": "test-audience",
                    "AZURE_TENANT_ID": "test-tenant-123",
                    "OKTA_DOMAIN": "test-domain.okta.com",
                    "AUTH_GITHUB_CLIENT_ID": "test-github-client",
                    "SCHEMA_REGISTRY_URL": "http://localhost:38081",  # Use test registry
                }
            )

            # Resolve the server script from the absolute path of this file so the
            # lookup works regardless of the current working directory.
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            cmd = [sys.executable, os.path.join(project_root, "remote-mcp-server.py")]

            # The output is never consumed here; discarding it (instead of using
            # pipes) keeps the child from blocking once a pipe buffer fills up.
            self.server_process = subprocess.Popen(
                cmd,
                env=env,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )

            # Wait for server to start and detect which port it's actually using
            deadline = time.monotonic() + 30
            while time.monotonic() < deadline:
                if self.server_process.poll() is not None:
                    print(f"❌ Test server exited early with code {self.server_process.returncode}")
                    self.teardown_test_server()
                    return False

                for test_port in (8899, 8000):  # Try the intended port first
                    try:
                        response = requests.get(f"http://localhost:{test_port}/health", timeout=2)
                    except requests.exceptions.RequestException:
                        continue  # Not listening (yet) on this port
                    if response.status_code in (200, 503):  # 503 is OK if registries not available
                        self.server_url = f"http://localhost:{test_port}"
                        print(f"✅ Test server started on {self.server_url}")
                        return True

                time.sleep(1)

            print("❌ Test server failed to start within 30 seconds")
            self.teardown_test_server()
            return False

        except Exception as e:
            print(f"❌ Failed to start test server: {e}")
            self.teardown_test_server()
            return False

    def teardown_test_server(self):
        """Stop the test server, force-killing it if it ignores termination for 10 s."""
        if self.server_process:
            try:
                print("🛑 Stopping test server...")
                self.server_process.terminate()
                self.server_process.wait(timeout=10)
                print("✅ Test server stopped")
            except subprocess.TimeoutExpired:
                print("⚠️ Force killing test server...")
                self.server_process.kill()
                self.server_process.wait()
            except Exception as e:
                print(f"⚠️ Error stopping server: {e}")
            finally:
                self.server_process = None

    def _fetch_discovery_document(
        self, path: str
    ) -> Tuple[Optional[bool], Optional[requests.Response], Optional[dict]]:
        """GET a discovery document and perform the checks shared by all endpoint tests.

        Args:
            path: Endpoint path appended to ``self.server_url``.

        Returns:
            ``(verdict, response, data)``. ``verdict`` is True (404, i.e. OAuth
            disabled) or False (unexpected status / invalid JSON) when the test is
            already decided, and None when ``data`` holds the parsed JSON body that
            the caller should validate further.

        Raises:
            requests.exceptions.RequestException: If the HTTP request itself fails.
        """
        response = requests.get(f"{self.server_url}{path}", timeout=10)

        print(f" Status Code: {response.status_code}")
        print(f" Content-Type: {response.headers.get('Content-Type', 'N/A')}")

        if response.status_code == 404:
            print(" ℹ️ Got 404 - this is expected when OAuth is disabled")
            return True, response, None

        if response.status_code != 200:
            print(f" ❌ Expected 200 or 404, got {response.status_code}")
            return False, response, None

        # Validate JSON response
        try:
            data = response.json()
        except json.JSONDecodeError:
            print(" ❌ Response is not valid JSON")
            return False, response, None

        return None, response, data

    def test_oauth_authorization_server_endpoint(self) -> bool:
        """Test /.well-known/oauth-authorization-server endpoint.

        Fails on a missing required field or when S256 PKCE is not advertised;
        missing MCP extensions, scopes and headers only produce warnings.
        """
        print("🔍 Testing OAuth Authorization Server discovery endpoint...")

        try:
            verdict, response, data = self._fetch_discovery_document(
                "/.well-known/oauth-authorization-server"
            )
            if verdict is not None:
                return verdict

            # Check required fields per RFC 8414
            for field in ("issuer", "scopes_supported"):
                if field not in data:
                    print(f" ❌ Missing required field: {field}")
                    return False
                print(f" ✅ Found required field: {field}")

            # Check MCP-specific extensions
            for field in ("mcp_server_version", "mcp_transport", "mcp_endpoints"):
                if field in data:
                    print(f" ✅ Found MCP extension: {field}")
                else:
                    print(f" ⚠️ Missing MCP extension: {field}")

            # Validate scopes
            scopes = data.get("scopes_supported", [])
            for scope in ("read", "write", "admin"):
                if scope in scopes:
                    print(f" ✅ Found expected scope: {scope}")
                else:
                    print(f" ⚠️ Missing expected scope: {scope}")

            # Check PKCE requirements (mandatory per MCP spec)
            if "code_challenge_methods_supported" not in data:
                print(" ❌ No PKCE methods advertised")
                return False

            pkce_methods = data["code_challenge_methods_supported"]
            print(f" ✅ PKCE methods supported: {pkce_methods}")

            # Verify S256 is supported (mandatory)
            if "S256" not in pkce_methods:
                print(" ❌ S256 method not supported (should be mandatory)")
                return False
            print(" ✅ S256 method supported (required)")

            # Verify plain is NOT supported (less secure)
            if "plain" in pkce_methods:
                print(" ⚠️ Plain method supported (not recommended for security)")
            else:
                print(" ✅ Plain method not supported (secure configuration)")

            # Check if PKCE is marked as required
            if data.get("require_pkce") is True:
                print(" ✅ PKCE marked as required (MCP compliant)")
            else:
                print(" ⚠️ PKCE not explicitly marked as required")

            # Check CORS headers
            for header in ("Access-Control-Allow-Origin", "Cache-Control"):
                if header in response.headers:
                    print(f" ✅ Found header: {header}: {response.headers[header]}")
                else:
                    print(f" ⚠️ Missing header: {header}")

            print(" ✅ OAuth authorization server endpoint validation passed")
            return True

        except requests.exceptions.RequestException as e:
            print(f" ❌ Request failed: {e}")
            return False

    def test_oauth_protected_resource_endpoint(self) -> bool:
        """Test /.well-known/oauth-protected-resource endpoint.

        Fails only on a missing required field; everything else is reported as
        a warning.
        """
        print("🛡️ Testing OAuth Protected Resource discovery endpoint...")

        try:
            verdict, _response, data = self._fetch_discovery_document(
                "/.well-known/oauth-protected-resource"
            )
            if verdict is not None:
                return verdict

            # Check required fields per RFC 9728
            for field in ("resource", "authorization_servers", "scopes_supported"):
                if field not in data:
                    print(f" ❌ Missing required field: {field}")
                    return False
                print(f" ✅ Found required field: {field}")

            # Check MCP-specific fields
            for field in ("mcp_server_info", "scope_descriptions", "protected_endpoints"):
                if field in data:
                    print(f" ✅ Found MCP extension: {field}")
                else:
                    print(f" ⚠️ Missing MCP extension: {field}")

            # Validate server info
            if "mcp_server_info" in data:
                server_info = data["mcp_server_info"]
                for field in ("name", "version", "transport", "tools_count"):
                    if field in server_info:
                        print(f" ✅ Server info contains: {field}: {server_info[field]}")
                    else:
                        print(f" ⚠️ Server info missing: {field}")

            # Validate scope descriptions
            if "scope_descriptions" in data:
                scope_desc = data["scope_descriptions"]
                for scope in ("read", "write", "admin"):
                    if scope in scope_desc:
                        print(f" ✅ Scope description for '{scope}': {scope_desc[scope][:50]}...")
                    else:
                        print(f" ⚠️ Missing scope description: {scope}")

            # Check PKCE requirements (should also be in protected resource metadata)
            if data.get("require_pkce") is True:
                print(" ✅ PKCE marked as required in protected resource")
            else:
                print(" ⚠️ PKCE not marked as required in protected resource")

            if "pkce_code_challenge_methods" in data:
                pkce_methods = data["pkce_code_challenge_methods"]
                print(f" ✅ PKCE methods in protected resource: {pkce_methods}")
                if "S256" in pkce_methods:
                    print(" ✅ S256 method in protected resource (secure)")
                else:
                    print(" ⚠️ S256 method missing from protected resource")

            if data.get("pkce_note"):
                print(f" ✅ PKCE note: {data['pkce_note']}")

            print(" ✅ OAuth protected resource endpoint validation passed")
            return True

        except requests.exceptions.RequestException as e:
            print(f" ❌ Request failed: {e}")
            return False

    def test_jwks_endpoint(self) -> bool:
        """Test /.well-known/jwks.json endpoint.

        Fails when the ``keys`` member is missing or is not a JSON array.
        """
        print("🔑 Testing JWKS discovery endpoint...")

        try:
            verdict, response, data = self._fetch_discovery_document("/.well-known/jwks.json")
            if verdict is not None:
                return verdict

            # Check JWKS structure per RFC 7517
            if "keys" not in data:
                print(" ❌ Missing required 'keys' field")
                return False

            keys = data["keys"]
            if not isinstance(keys, list):
                print(" ❌ 'keys' field must be a JSON array")
                return False
            print(f" ✅ Found 'keys' field with {len(keys)} keys")

            # For our implementation, keys might be empty (proxy mode) or contain a note
            if not keys and "note" in data:
                print(f" ✅ Empty keys with note: {data['note']}")
            elif keys:
                print(f" ✅ Found {len(keys)} key(s) in JWKS")

                # Validate first key structure if present
                first_key = keys[0]
                for field in ("kty", "kid"):  # Basic required fields
                    if field in first_key:
                        print(f" ✅ Key contains: {field}")
                    else:
                        print(f" ⚠️ Key missing: {field}")

            # Check caching headers
            cache_header = response.headers.get("Cache-Control")
            if cache_header:
                print(f" ✅ Found Cache-Control header: {cache_header}")
            else:
                print(" ⚠️ Missing Cache-Control header")

            print(" ✅ JWKS endpoint validation passed")
            return True

        except requests.exceptions.RequestException as e:
            print(f" ❌ Request failed: {e}")
            return False

    def test_discovery_consistency(self) -> bool:
        """Test consistency between discovery endpoints.

        Mismatches are reported as warnings only; the test fails solely when an
        exception is raised while fetching or parsing the documents.
        """
        print("🔄 Testing discovery endpoint consistency...")

        try:
            # Get data from both endpoints
            auth_server_resp = requests.get(
                f"{self.server_url}/.well-known/oauth-authorization-server", timeout=10
            )
            protected_resource_resp = requests.get(
                f"{self.server_url}/.well-known/oauth-protected-resource", timeout=10
            )

            # Both should have same success status
            if auth_server_resp.status_code != protected_resource_resp.status_code:
                print(
                    f" ⚠️ Status code mismatch: auth_server={auth_server_resp.status_code}, "
                    f"protected_resource={protected_resource_resp.status_code}"
                )
                # This might be OK in some cases, so we continue

            if auth_server_resp.status_code != 200 or protected_resource_resp.status_code != 200:
                print(" ℹ️ One or both endpoints returned non-200, skipping consistency check")
                return True

            auth_data = auth_server_resp.json()
            resource_data = protected_resource_resp.json()

            # Check scope consistency
            auth_scopes = set(auth_data.get("scopes_supported", []))
            resource_scopes = set(resource_data.get("scopes_supported", []))

            if auth_scopes != resource_scopes:
                print(f" ⚠️ Scope mismatch - Auth server: {auth_scopes}, Protected resource: {resource_scopes}")
            else:
                print(f" ✅ Scopes consistent across endpoints: {auth_scopes}")

            # Check issuer consistency
            auth_issuer = auth_data.get("issuer")
            auth_servers = resource_data.get("authorization_servers", [])

            if auth_issuer and auth_issuer in auth_servers:
                print(f" ✅ Issuer consistency: {auth_issuer}")
            elif auth_issuer:
                print(f" ⚠️ Issuer '{auth_issuer}' not found in authorization_servers: {auth_servers}")

            # Check MCP version consistency
            auth_version = auth_data.get("mcp_server_version")
            resource_version = resource_data.get("mcp_server_info", {}).get("version")

            if auth_version and resource_version and auth_version == resource_version:
                print(f" ✅ MCP version consistent: {auth_version}")
            elif auth_version and resource_version:
                print(f" ⚠️ MCP version mismatch: auth={auth_version}, resource={resource_version}")

            print(" ✅ Discovery endpoint consistency check completed")
            return True

        except Exception as e:
            print(f" ❌ Consistency check failed: {e}")
            return False

    def test_pkce_mandatory_requirements(self) -> bool:
        """Test that PKCE is properly marked as mandatory per MCP specification.

        Note: FastMCP may override the authorization server endpoint, but the
        protected resource endpoint is more important for MCP clients to discover
        PKCE requirements.
        """
        print("🛡️ Testing PKCE mandatory requirements...")

        try:
            # Test authorization server metadata
            auth_server_resp = requests.get(
                f"{self.server_url}/.well-known/oauth-authorization-server", timeout=10
            )

            if auth_server_resp.status_code == 404:
                print(" ℹ️ OAuth disabled, skipping PKCE validation")
                return True

            if auth_server_resp.status_code != 200:
                print(f" ❌ Authorization server endpoint failed: {auth_server_resp.status_code}")
                return False

            auth_data = auth_server_resp.json()

            # Test protected resource metadata
            protected_resp = requests.get(
                f"{self.server_url}/.well-known/oauth-protected-resource", timeout=10
            )

            if protected_resp.status_code != 200:
                print(f" ❌ Protected resource endpoint failed: {protected_resp.status_code}")
                return False

            protected_data = protected_resp.json()

            # PKCE validation for authorization server
            print(" 🔍 Validating PKCE in authorization server metadata...")

            # Check code challenge methods
            pkce_methods = auth_data.get("code_challenge_methods_supported", [])
            if "S256" not in pkce_methods:
                print(" ❌ S256 not in code_challenge_methods_supported")
                return False
            print(" ✅ S256 method supported")

            # Check that plain method is NOT supported (security best practice)
            if "plain" in pkce_methods:
                print(" ⚠️ Plain method supported (not recommended, but allowed)")
            else:
                print(" ✅ Plain method not supported (secure configuration)")

            # Check require_pkce flag (may not be present in FastMCP's built-in endpoint)
            if auth_data.get("require_pkce") is True:
                print(" ✅ require_pkce set to true")
            else:
                print(" ⚠️ require_pkce not set in authorization server (FastMCP limitation)")
                print(" ℹ️ Will verify PKCE requirements in protected resource endpoint")

            # PKCE validation for protected resource
            print(" 🔍 Validating PKCE in protected resource metadata...")

            if protected_data.get("require_pkce") is not True:
                print(" ❌ require_pkce not set in protected resource")
                return False
            print(" ✅ require_pkce set in protected resource")

            # Check PKCE methods in protected resource
            resource_pkce_methods = protected_data.get("pkce_code_challenge_methods", [])
            if "S256" not in resource_pkce_methods:
                print(" ❌ S256 not in protected resource pkce_code_challenge_methods")
                return False
            print(" ✅ S256 method in protected resource")

            # Check PKCE note; tolerate an explicit null so the check reports a
            # warning instead of raising on ``None.lower()``.
            pkce_note = protected_data.get("pkce_note") or ""
            note_mentions_mandatory = "mandatory" in pkce_note.lower()
            if not note_mentions_mandatory:
                print(f" ⚠️ PKCE note doesn't mention 'mandatory': {pkce_note}")
            else:
                print(" ✅ PKCE note mentions mandatory requirement")

            # Consistency check between endpoints
            if set(pkce_methods) != set(resource_pkce_methods):
                print(
                    f" ⚠️ PKCE method mismatch between endpoints: "
                    f"auth={pkce_methods}, resource={resource_pkce_methods}"
                )
            else:
                print(" ✅ PKCE methods consistent between endpoints")

            # Final validation: ensure at least the protected resource properly advertises PKCE
            pkce_compliant = (
                protected_data.get("require_pkce") is True
                and "S256" in resource_pkce_methods
                and note_mentions_mandatory
            )

            if pkce_compliant:
                print(" ✅ PKCE mandatory requirements validation passed")
                print(" ℹ️ Protected resource endpoint properly advertises PKCE requirements")
                return True

            print(" ❌ PKCE requirements not properly advertised in protected resource")
            return False

        except requests.exceptions.RequestException as e:
            print(f" ❌ Request failed: {e}")
            return False
        except Exception as e:
            print(f" ❌ PKCE validation failed: {e}")
            return False

    def test_discovery_with_different_providers(self) -> bool:
        """Test generic OAuth 2.1 discovery approach for different providers.

        This is an offline check: it only validates the issuer URL format and
        prints the discovery URLs that would be derived from it.
        """
        print("🌐 Testing generic OAuth 2.1 discovery approach...")

        # Test that generic discovery works for different provider URL patterns
        provider_examples = {
            "Azure AD": "https://login.microsoftonline.com/tenant-id/v2.0",
            "Google": "https://accounts.google.com",
            "Okta": "https://domain.okta.com/oauth2/default",
            "Keycloak": "https://keycloak.example.com/realms/realm-name",
            "GitHub": "https://github.com",  # Special case - uses fallback
        }

        for provider_name, issuer_url in provider_examples.items():
            print(f" Testing {provider_name}: {issuer_url}")

            try:
                # Validate URL pattern
                parsed = urlparse(issuer_url)
                if not parsed.scheme or not parsed.netloc:
                    print(f" ❌ Invalid URL format for {provider_name}")
                    return False

                # Check for OAuth 2.1 discovery endpoint construction.
                # The OIDC well-known suffix is spelled with a hyphen.
                discovery_endpoint = f"{issuer_url}/.well-known/oauth-authorization-server"
                oidc_endpoint = f"{issuer_url}/.well-known/openid-configuration"

                print(f" ✅ {provider_name} discovery endpoints:")
                print(f" - OAuth 2.1: {discovery_endpoint}")
                print(f" - OIDC: {oidc_endpoint}")

                # Special validation for GitHub (should use fallback)
                if provider_name == "GitHub":
                    print(f" ℹ️ {provider_name} will use fallback configuration (no standard discovery)")
                else:
                    print(f" ✅ {provider_name} should work with standard OAuth 2.1 discovery")

            except Exception as e:
                print(f" ❌ {provider_name} validation failed: {e}")
                return False

        print(" ✅ Generic OAuth 2.1 discovery approach validated")
        print(" 🚀 No provider-specific configuration needed!")
        return True

    def test_discovery_error_handling(self) -> bool:
        """Test discovery endpoint error handling.

        Unknown ``/.well-known`` paths are expected to return 404; any other
        outcome is reported as a warning and never fails the test.
        """
        print("⚠️ Testing discovery endpoint error handling...")

        # Test invalid endpoints
        invalid_endpoints = [
            "/.well-known/invalid-endpoint",
            "/.well-known/oauth-invalid",
            "/.well-known/jwks-invalid",
        ]

        for endpoint in invalid_endpoints:
            try:
                response = requests.get(f"{self.server_url}{endpoint}", timeout=5)
            except requests.exceptions.RequestException as e:
                print(f" ⚠️ Request to invalid endpoint '{endpoint}' failed: {e}")
                continue

            if response.status_code == 404:
                print(f" ✅ Invalid endpoint '{endpoint}' correctly returns 404")
            else:
                print(f" ⚠️ Invalid endpoint '{endpoint}' returned {response.status_code}")

        print(" ✅ Error handling validation completed")
        return True

    def run_all_tests(self) -> bool:
        """Run all OAuth discovery tests.

        Returns:
            True when every executed test passed, False when any test failed or
            no test server could be started.
        """
        print("🚀 Starting OAuth Discovery Endpoints Test Suite")
        print("=" * 60)

        # First test with OAuth disabled to ensure basic functionality
        if not self.setup_test_server(enable_auth=False):
            print("❌ Failed to setup basic test server")
            return False

        self.teardown_test_server()

        # Now test with OAuth enabled
        oauth_enabled = self.setup_test_server(enable_auth=True)
        if not oauth_enabled:
            print("⚠️ Failed to setup test server with OAuth enabled - running basic tests only")
            if not self.setup_test_server(enable_auth=False):
                print("❌ Failed to setup even basic test server")
                return False

        try:
            if oauth_enabled:
                # Full OAuth tests
                tests = [
                    ("OAuth Authorization Server Endpoint", self.test_oauth_authorization_server_endpoint),
                    ("OAuth Protected Resource Endpoint", self.test_oauth_protected_resource_endpoint),
                    ("JWKS Endpoint", self.test_jwks_endpoint),
                    ("PKCE Mandatory Requirements", self.test_pkce_mandatory_requirements),
                    ("Discovery Consistency", self.test_discovery_consistency),
                    ("Generic OAuth 2.1 Discovery", self.test_discovery_with_different_providers),
                    ("Error Handling", self.test_discovery_error_handling),
                ]
            else:
                # Basic tests without OAuth
                tests = [
                    ("OAuth Disabled - Authorization Server 404", self.test_oauth_authorization_server_endpoint),
                    ("OAuth Disabled - Protected Resource 404", self.test_oauth_protected_resource_endpoint),
                    ("OAuth Disabled - JWKS 404", self.test_jwks_endpoint),
                    ("Error Handling", self.test_discovery_error_handling),
                ]

            for test_name, test_func in tests:
                self.run_test(test_name, test_func)

        finally:
            self.teardown_test_server()

        # Generate summary
        self.print_summary()

        # Return overall success
        return all(result[1] for result in self.test_results)

    def print_summary(self):
        """Print test execution summary (totals, success rate and failed tests)."""
        print("\n" + "=" * 60)
        print("🏁 OAuth Discovery Test Summary")
        print("=" * 60)

        total_tests = len(self.test_results)
        passed_tests = sum(1 for result in self.test_results if result[1])
        failed_tests = total_tests - passed_tests

        print(f"Total Tests: {total_tests}")
        print(f"Passed: {passed_tests}")
        print(f"Failed: {failed_tests}")
        # Guard against division by zero when no test was executed.
        print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%" if total_tests > 0 else "Success Rate: 0%")

        if failed_tests > 0:
            print("\n❌ Failed Tests:")
            for test_name, passed, error in self.test_results:
                if not passed:
                    print(f" - {test_name}: {error}")

        print(f"\n{'🎉 All tests passed!' if failed_tests == 0 else '⚠️ Some tests failed'}")


def main():
    """Run the OAuth discovery tests and return a process exit code (0 = success)."""
    tester = OAuthDiscoveryTest()

    try:
        success = tester.run_all_tests()
        return 0 if success else 1
    except KeyboardInterrupt:
        print("\n🛑 Tests interrupted by user")
        tester.teardown_test_server()
        return 1
    except Exception as e:
        print(f"\n❌ Test suite failed with unexpected error: {e}")
        tester.teardown_test_server()
        return 1


if __name__ == "__main__":
    # sys.exit is always available; the bare ``exit`` helper is injected by the
    # ``site`` module and is not guaranteed to exist.
    sys.exit(main())

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/aywengo/kafka-schema-reg-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.