#!/usr/bin/env python3
"""
OAuth Discovery Endpoints Testing for Kafka Schema Registry MCP Server
Tests the OAuth 2.0 discovery endpoints that enable MCP client auto-configuration:
- /.well-known/oauth-authorization-server (RFC 8414)
- /.well-known/oauth-protected-resource (RFC 9728)
- /.well-known/jwks.json (RFC 7517)
"""
import json
import os
import subprocess
import sys
import time
import requests
# Put the project root (the directory above the one holding this script) at the
# front of the import path so project modules resolve when run as a script.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _PROJECT_ROOT)
class OAuthDiscoveryTest:
    """OAuth discovery endpoints test class.

    Starts ``remote-mcp-server.py`` as a subprocess, probes its
    ``/.well-known/*`` discovery endpoints over HTTP, and records one
    ``(test_name, passed, error)`` tuple per executed test in
    ``self.test_results``.
    """

    def __init__(self):
        # One (test_name, passed, error_message_or_None) tuple per executed test.
        self.test_results = []
        # Popen handle of the server under test; None while no server is running.
        self.server_process = None
        # Temporary file receiving the server's stdout/stderr; None while no server is running.
        self._server_log = None
        self.server_url = "http://localhost:8899"  # Use different port to avoid conflicts

    def run_test(self, test_name: str, test_func) -> bool:
        """Run a single test and track results.

        Args:
            test_name: Label used in console output and in ``test_results``.
            test_func: Zero-argument callable returning a truthy value on success.

        Returns:
            True if ``test_func`` returned a truthy value; False if it returned
            a falsy value or raised. Exceptions are recorded, never propagated.
        """
        try:
            print(f"\n🧪 Running: {test_name}")
            result = test_func()
            if result:
                print(f"✅ {test_name} PASSED")
                self.test_results.append((test_name, True, None))
                return True
            print(f"❌ {test_name} FAILED")
            self.test_results.append((test_name, False, "Test returned False"))
            return False
        except Exception as e:
            print(f"❌ {test_name} FAILED with exception: {e}")
            self.test_results.append((test_name, False, str(e)))
            return False

    def setup_test_server(self, enable_auth: bool = False) -> bool:
        """Start a test remote MCP server with the specified OAuth configuration.

        Launches ``remote-mcp-server.py`` (located one directory above this
        file) with test environment variables. It then polls ``/health`` on
        ports 8899 and 8000 once per second for up to 30 seconds. On success
        ``self.server_url`` points at whichever port answered.

        The server's stdout/stderr go to a temporary file instead of an unread
        ``subprocess.PIPE``, because an unread pipe can fill up and block the
        child. If startup fails (early exit, timeout, or error), the tail of
        that output is printed and the process is stopped. This keeps a failed
        server from holding the port for later runs.

        Args:
            enable_auth: Exported to the server as ``ENABLE_AUTH``.

        Returns:
            True once a health check answers with 200 or 503, otherwise False.
        """
        import tempfile  # local import keeps this class self-contained

        try:
            print(f"🚀 Starting test server (OAuth enabled: {enable_auth})...")
            # Set environment variables for the test
            env = os.environ.copy()
            env.update(
                {
                    "MCP_TRANSPORT": "streamable-http",
                    "MCP_HOST": "localhost",
                    "MCP_PORT": "8899",
                    "ENABLE_AUTH": "true" if enable_auth else "false",
                    "AUTH_PROVIDER": "azure",
                    "AUTH_AUDIENCE": "test-audience",
                    "AZURE_TENANT_ID": "test-tenant-123",
                    "OKTA_DOMAIN": "test-domain.okta.com",
                    "AUTH_GITHUB_CLIENT_ID": "test-github-client",
                    "SCHEMA_REGISTRY_URL": "http://localhost:38081",  # Use test registry
                }
            )
            # abspath: a bare relative __file__ would otherwise collapse to "" and
            # the script would be looked up relative to the current directory.
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            cmd = [sys.executable, os.path.join(project_root, "remote-mcp-server.py")]
            self._server_log = tempfile.TemporaryFile(mode="w+", encoding="utf-8", errors="replace")
            self.server_process = subprocess.Popen(
                cmd, env=env, stdout=self._server_log, stderr=subprocess.STDOUT
            )
            # Wait for server to start and detect which port it's actually using
            for _ in range(30):  # Wait up to 30 seconds
                exit_code = self.server_process.poll()
                if exit_code is not None:
                    # No point polling a server that already died.
                    print(f"❌ Test server exited early with code {exit_code}")
                    self._abort_server_startup()
                    return False
                for test_port in (8899, 8000):  # Try the intended port first
                    try:
                        response = requests.get(f"http://localhost:{test_port}/health", timeout=2)
                    except requests.exceptions.RequestException:
                        continue
                    if response.status_code in (200, 503):  # 503 is OK if registries not available
                        self.server_url = f"http://localhost:{test_port}"
                        print(f"✅ Test server started on {self.server_url}")
                        return True
                time.sleep(1)
            print("❌ Test server failed to start within 30 seconds")
            self._abort_server_startup()
            return False
        except Exception as e:
            print(f"❌ Failed to start test server: {e}")
            self._abort_server_startup()
            return False

    def _abort_server_startup(self):
        """Stop a server that failed to start, show its output, and release the log."""
        self._stop_server_process()
        self._print_server_log_tail()
        self._close_server_log()

    def _print_server_log_tail(self, max_chars: int = 2000):
        """Print up to ``max_chars`` of the most recent captured server output."""
        if self._server_log is None:
            return
        try:
            self._server_log.seek(0)
            output = self._server_log.read()
        except (OSError, ValueError) as e:  # ValueError: operation on a closed file
            print(f"⚠️ Could not read server output: {e}")
            return
        if output.strip():
            print(" Server output (tail):")
            print(output[-max_chars:])

    def _stop_server_process(self):
        """Terminate the server subprocess, escalating to kill after 10 seconds."""
        if self.server_process is None:
            return
        try:
            print("🛑 Stopping test server...")
            self.server_process.terminate()
            self.server_process.wait(timeout=10)
            print("✅ Test server stopped")
        except subprocess.TimeoutExpired:
            print("⚠️ Force killing test server...")
            self.server_process.kill()
            self.server_process.wait()
        except Exception as e:
            print(f"⚠️ Error stopping server: {e}")
        finally:
            self.server_process = None

    def _close_server_log(self):
        """Close and forget the temporary server output file, if any."""
        if self._server_log is not None:
            self._server_log.close()
            self._server_log = None

    def teardown_test_server(self):
        """Stop the test server (if running) and discard its captured output.

        Safe to call repeatedly; it is a no-op when no server is running.
        """
        self._stop_server_process()
        self._close_server_log()

    def test_oauth_authorization_server_endpoint(self) -> bool:
        """Test /.well-known/oauth-authorization-server endpoint (RFC 8414).

        A 404 is accepted (OAuth disabled). A 200 response must be JSON, must
        contain ``issuer`` and ``scopes_supported``, and must advertise the
        PKCE ``S256`` method. Missing MCP extensions, expected scopes,
        ``require_pkce`` and response headers only produce warnings.

        Returns:
            True if the endpoint is absent or valid, False otherwise.
        """
        print("🔍 Testing OAuth Authorization Server discovery endpoint...")
        try:
            response = requests.get(f"{self.server_url}/.well-known/oauth-authorization-server", timeout=10)
            print(f" Status Code: {response.status_code}")
            print(f" Content-Type: {response.headers.get('Content-Type', 'N/A')}")
            # NOTE(review): 404 is accepted even when the suite started the
            # server with OAuth enabled, so this test cannot catch a missing endpoint.
            if response.status_code == 404:
                print(" ℹ️ Got 404 - this is expected when OAuth is disabled")
                return True
            if response.status_code != 200:
                print(f" ❌ Expected 200 or 404, got {response.status_code}")
                return False
            # Validate JSON response. requests raises a ValueError subclass whose
            # concrete type depends on the requests version and on simplejson.
            try:
                data = response.json()
            except ValueError:
                print(" ❌ Response is not valid JSON")
                return False
            # RFC 8414 makes `issuer` REQUIRED and `scopes_supported` RECOMMENDED;
            # this suite treats both as mandatory.
            required_fields = ["issuer", "scopes_supported"]
            for field in required_fields:
                if field not in data:
                    print(f" ❌ Missing required field: {field}")
                    return False
                print(f" ✅ Found required field: {field}")
            # Check MCP-specific extensions
            mcp_fields = ["mcp_server_version", "mcp_transport", "mcp_endpoints"]
            for field in mcp_fields:
                if field in data:
                    print(f" ✅ Found MCP extension: {field}")
                else:
                    print(f" ⚠️ Missing MCP extension: {field}")
            # Validate scopes
            scopes = data.get("scopes_supported", [])
            expected_scopes = ["read", "write", "admin"]
            for scope in expected_scopes:
                if scope in scopes:
                    print(f" ✅ Found expected scope: {scope}")
                else:
                    print(f" ⚠️ Missing expected scope: {scope}")
            # Check PKCE requirements (mandatory per MCP spec)
            if "code_challenge_methods_supported" not in data:
                print(" ❌ No PKCE methods advertised")
                return False
            pkce_methods = data["code_challenge_methods_supported"]
            print(f" ✅ PKCE methods supported: {pkce_methods}")
            # Verify S256 is supported (mandatory)
            if "S256" in pkce_methods:
                print(" ✅ S256 method supported (required)")
            else:
                print(" ❌ S256 method not supported (should be mandatory)")
                return False
            # Verify plain is NOT supported (less secure)
            if "plain" in pkce_methods:
                print(" ⚠️ Plain method supported (not recommended for security)")
            else:
                print(" ✅ Plain method not supported (secure configuration)")
            # Check if PKCE is marked as required
            if data.get("require_pkce") is True:
                print(" ✅ PKCE marked as required (MCP compliant)")
            else:
                print(" ⚠️ PKCE not explicitly marked as required")
            # Check CORS and caching headers (warnings only)
            expected_headers = ["Access-Control-Allow-Origin", "Cache-Control"]
            for header in expected_headers:
                if header in response.headers:
                    print(f" ✅ Found header: {header}: {response.headers[header]}")
                else:
                    print(f" ⚠️ Missing header: {header}")
            print(" ✅ OAuth authorization server endpoint validation passed")
            return True
        except requests.exceptions.RequestException as e:
            print(f" ❌ Request failed: {e}")
            return False

    def test_oauth_protected_resource_endpoint(self) -> bool:
        """Test /.well-known/oauth-protected-resource endpoint (RFC 9728).

        A 404 is accepted (OAuth disabled). A 200 response must be JSON and
        must contain ``resource``, ``authorization_servers`` and
        ``scopes_supported``. MCP extensions and PKCE hints only produce
        warnings.

        Returns:
            True if the endpoint is absent or valid, False otherwise.
        """
        print("🛡️ Testing OAuth Protected Resource discovery endpoint...")
        try:
            response = requests.get(f"{self.server_url}/.well-known/oauth-protected-resource", timeout=10)
            print(f" Status Code: {response.status_code}")
            print(f" Content-Type: {response.headers.get('Content-Type', 'N/A')}")
            if response.status_code == 404:
                print(" ℹ️ Got 404 - this is expected when OAuth is disabled")
                return True
            if response.status_code != 200:
                print(f" ❌ Expected 200 or 404, got {response.status_code}")
                return False
            # Validate JSON response (any requests JSON error is a ValueError).
            try:
                data = response.json()
            except ValueError:
                print(" ❌ Response is not valid JSON")
                return False
            # RFC 9728 only makes `resource` REQUIRED; this suite additionally
            # requires `authorization_servers` and `scopes_supported`.
            required_fields = ["resource", "authorization_servers", "scopes_supported"]
            for field in required_fields:
                if field not in data:
                    print(f" ❌ Missing required field: {field}")
                    return False
                print(f" ✅ Found required field: {field}")
            # Check MCP-specific fields
            mcp_fields = [
                "mcp_server_info",
                "scope_descriptions",
                "protected_endpoints",
            ]
            for field in mcp_fields:
                if field in data:
                    print(f" ✅ Found MCP extension: {field}")
                else:
                    print(f" ⚠️ Missing MCP extension: {field}")
            # Validate server info
            if "mcp_server_info" in data:
                server_info = data["mcp_server_info"]
                info_fields = ["name", "version", "transport", "tools_count"]
                for field in info_fields:
                    if field in server_info:
                        print(f" ✅ Server info contains: {field}: {server_info[field]}")
                    else:
                        print(f" ⚠️ Server info missing: {field}")
            # Validate scope descriptions
            if "scope_descriptions" in data:
                scope_desc = data["scope_descriptions"]
                expected_scopes = ["read", "write", "admin"]
                for scope in expected_scopes:
                    if scope in scope_desc:
                        # str(): a non-string description must not crash the slice.
                        print(f" ✅ Scope description for '{scope}': {str(scope_desc[scope])[:50]}...")
                    else:
                        print(f" ⚠️ Missing scope description: {scope}")
            # Check PKCE requirements (should also be in protected resource metadata)
            if data.get("require_pkce") is True:
                print(" ✅ PKCE marked as required in protected resource")
            else:
                print(" ⚠️ PKCE not marked as required in protected resource")
            if "pkce_code_challenge_methods" in data:
                pkce_methods = data["pkce_code_challenge_methods"]
                print(f" ✅ PKCE methods in protected resource: {pkce_methods}")
                if "S256" in pkce_methods:
                    print(" ✅ S256 method in protected resource (secure)")
                else:
                    print(" ⚠️ S256 method missing from protected resource")
            if data.get("pkce_note"):
                print(f" ✅ PKCE note: {data['pkce_note']}")
            print(" ✅ OAuth protected resource endpoint validation passed")
            return True
        except requests.exceptions.RequestException as e:
            print(f" ❌ Request failed: {e}")
            return False

    def test_jwks_endpoint(self) -> bool:
        """Test /.well-known/jwks.json endpoint (JWK Set format, RFC 7517).

        A 404 is accepted (OAuth disabled). A 200 response must be JSON with a
        ``keys`` member. An empty key list is allowed, and key fields and the
        ``Cache-Control`` header only produce warnings.

        Returns:
            True if the endpoint is absent or valid, False otherwise.
        """
        print("🔑 Testing JWKS discovery endpoint...")
        try:
            response = requests.get(f"{self.server_url}/.well-known/jwks.json", timeout=10)
            print(f" Status Code: {response.status_code}")
            print(f" Content-Type: {response.headers.get('Content-Type', 'N/A')}")
            if response.status_code == 404:
                print(" ℹ️ Got 404 - this is expected when OAuth is disabled")
                return True
            if response.status_code != 200:
                print(f" ❌ Expected 200 or 404, got {response.status_code}")
                return False
            # Validate JSON response (any requests JSON error is a ValueError).
            try:
                data = response.json()
            except ValueError:
                print(" ❌ Response is not valid JSON")
                return False
            # Check JWKS structure per RFC 7517
            if "keys" not in data:
                print(" ❌ Missing required 'keys' field")
                return False
            keys = data["keys"]
            print(f" ✅ Found 'keys' field with {len(keys)} keys")
            # For our implementation, keys might be empty (proxy mode) or contain a note
            if len(keys) == 0 and "note" in data:
                print(f" ✅ Empty keys with note: {data['note']}")
            elif len(keys) > 0:
                print(f" ✅ Found {len(keys)} key(s) in JWKS")
                # Validate first key structure if present
                first_key = keys[0]
                key_fields = ["kty", "kid"]  # Basic required fields
                for field in key_fields:
                    if field in first_key:
                        print(f" ✅ Key contains: {field}")
                    else:
                        print(f" ⚠️ Key missing: {field}")
            # Check caching headers
            cache_header = response.headers.get("Cache-Control")
            if cache_header:
                print(f" ✅ Found Cache-Control header: {cache_header}")
            else:
                print(" ⚠️ Missing Cache-Control header")
            print(" ✅ JWKS endpoint validation passed")
            return True
        except requests.exceptions.RequestException as e:
            print(f" ❌ Request failed: {e}")
            return False

    def test_discovery_consistency(self) -> bool:
        """Test consistency between the two OAuth discovery endpoints.

        Compares scopes, issuer membership in ``authorization_servers``, and the
        MCP server version. Mismatches only produce warnings. The check is
        skipped unless both endpoints return 200.

        Returns:
            True unless a request or JSON parse fails.
        """
        print("🔄 Testing discovery endpoint consistency...")
        try:
            # Get data from both endpoints
            auth_server_resp = requests.get(f"{self.server_url}/.well-known/oauth-authorization-server", timeout=10)
            protected_resource_resp = requests.get(
                f"{self.server_url}/.well-known/oauth-protected-resource", timeout=10
            )
            # Both should have same success status
            if auth_server_resp.status_code != protected_resource_resp.status_code:
                print(
                    f" ⚠️ Status code mismatch: auth_server={auth_server_resp.status_code}, protected_resource={protected_resource_resp.status_code}"
                )
                # This might be OK in some cases, so we continue
            if auth_server_resp.status_code != 200 or protected_resource_resp.status_code != 200:
                print(" ℹ️ One or both endpoints returned non-200, skipping consistency check")
                return True
            auth_data = auth_server_resp.json()
            resource_data = protected_resource_resp.json()
            # Check scope consistency
            auth_scopes = set(auth_data.get("scopes_supported", []))
            resource_scopes = set(resource_data.get("scopes_supported", []))
            if auth_scopes != resource_scopes:
                print(f" ⚠️ Scope mismatch - Auth server: {auth_scopes}, Protected resource: {resource_scopes}")
            else:
                print(f" ✅ Scopes consistent across endpoints: {auth_scopes}")
            # Check issuer consistency
            auth_issuer = auth_data.get("issuer")
            auth_servers = resource_data.get("authorization_servers", [])
            if auth_issuer and auth_issuer in auth_servers:
                print(f" ✅ Issuer consistency: {auth_issuer}")
            elif auth_issuer:
                print(f" ⚠️ Issuer '{auth_issuer}' not found in authorization_servers: {auth_servers}")
            # Check MCP version consistency (`or {}` also tolerates an explicit null)
            auth_version = auth_data.get("mcp_server_version")
            resource_version = (resource_data.get("mcp_server_info") or {}).get("version")
            if auth_version and resource_version:
                if auth_version == resource_version:
                    print(f" ✅ MCP version consistent: {auth_version}")
                else:
                    print(f" ⚠️ MCP version mismatch: auth={auth_version}, resource={resource_version}")
            print(" ✅ Discovery endpoint consistency check completed")
            return True
        except Exception as e:
            print(f" ❌ Consistency check failed: {e}")
            return False

    def test_pkce_mandatory_requirements(self) -> bool:
        """Test that PKCE is properly marked as mandatory per MCP specification.

        Note: FastMCP may override the authorization server endpoint, but the protected
        resource endpoint is more important for MCP clients to discover PKCE requirements.

        Hard requirements:
        - ``S256`` appears in the authorization server's ``code_challenge_methods_supported``.
        - The protected resource metadata sets ``require_pkce`` to true.
        - The protected resource metadata lists ``S256`` in ``pkce_code_challenge_methods``.
        - The protected resource metadata has a ``pkce_note`` mentioning "mandatory".

        A 404 from the authorization server endpoint skips the check (OAuth disabled).

        Returns:
            True if the requirements hold or OAuth is disabled, False otherwise.
        """
        print("🛡️ Testing PKCE mandatory requirements...")
        try:
            # Test authorization server metadata
            auth_server_resp = requests.get(f"{self.server_url}/.well-known/oauth-authorization-server", timeout=10)
            if auth_server_resp.status_code == 404:
                print(" ℹ️ OAuth disabled, skipping PKCE validation")
                return True
            if auth_server_resp.status_code != 200:
                print(f" ❌ Authorization server endpoint failed: {auth_server_resp.status_code}")
                return False
            auth_data = auth_server_resp.json()
            # Test protected resource metadata
            protected_resp = requests.get(f"{self.server_url}/.well-known/oauth-protected-resource", timeout=10)
            if protected_resp.status_code != 200:
                print(f" ❌ Protected resource endpoint failed: {protected_resp.status_code}")
                return False
            protected_data = protected_resp.json()
            # PKCE validation for authorization server
            print(" 🔍 Validating PKCE in authorization server metadata...")
            pkce_methods = auth_data.get("code_challenge_methods_supported", [])
            if "S256" not in pkce_methods:
                print(" ❌ S256 not in code_challenge_methods_supported")
                return False
            print(" ✅ S256 method supported")
            # Check that plain method is NOT supported (security best practice)
            if "plain" in pkce_methods:
                print(" ⚠️ Plain method supported (not recommended, but allowed)")
            else:
                print(" ✅ Plain method not supported (secure configuration)")
            # Check require_pkce flag (may not be present in FastMCP's built-in endpoint)
            if auth_data.get("require_pkce") is True:
                print(" ✅ require_pkce set to true")
            else:
                print(" ⚠️ require_pkce not set in authorization server (FastMCP limitation)")
                print(" ℹ️ Will verify PKCE requirements in protected resource endpoint")
            # PKCE validation for protected resource
            print(" 🔍 Validating PKCE in protected resource metadata...")
            if protected_data.get("require_pkce") is not True:
                print(" ❌ require_pkce not set in protected resource")
                return False
            print(" ✅ require_pkce set in protected resource")
            resource_pkce_methods = protected_data.get("pkce_code_challenge_methods", [])
            if "S256" not in resource_pkce_methods:
                print(" ❌ S256 not in protected resource pkce_code_challenge_methods")
                return False
            print(" ✅ S256 method in protected resource")
            # `or ""` also covers an explicit null, which would crash .lower().
            pkce_note = protected_data.get("pkce_note") or ""
            note_mentions_mandatory = "mandatory" in str(pkce_note).lower()
            if note_mentions_mandatory:
                print(" ✅ PKCE note mentions mandatory requirement")
            else:
                print(f" ⚠️ PKCE note doesn't mention 'mandatory': {pkce_note}")
            # Consistency check between endpoints
            if set(pkce_methods) != set(resource_pkce_methods):
                print(
                    f" ⚠️ PKCE method mismatch between endpoints: auth={pkce_methods}, resource={resource_pkce_methods}"
                )
            else:
                print(" ✅ PKCE methods consistent between endpoints")
            # Final validation: require_pkce and S256 were already enforced above,
            # so the remaining condition is the "mandatory" wording of the note.
            if note_mentions_mandatory:
                print(" ✅ PKCE mandatory requirements validation passed")
                print(" ℹ️ Protected resource endpoint properly advertises PKCE requirements")
                return True
            print(" ❌ PKCE requirements not properly advertised in protected resource")
            return False
        except requests.exceptions.RequestException as e:
            print(f" ❌ Request failed: {e}")
            return False
        except Exception as e:
            print(f" ❌ PKCE validation failed: {e}")
            return False

    def test_discovery_with_different_providers(self) -> bool:
        """Test generic OAuth 2.1 discovery approach for different providers.

        Purely local: it validates example issuer URLs and prints the discovery
        URLs a client would derive from them. No provider is contacted.

        Returns:
            True if every example issuer URL has a scheme and host.
        """
        from urllib.parse import urlparse  # hoisted out of the loop below

        print("🌐 Testing generic OAuth 2.1 discovery approach...")
        # Test that generic discovery works for different provider URL patterns
        provider_examples = {
            "Azure AD": "https://login.microsoftonline.com/tenant-id/v2.0",
            "Google": "https://accounts.google.com",
            "Okta": "https://domain.okta.com/oauth2/default",
            "Keycloak": "https://keycloak.example.com/realms/realm-name",
            "GitHub": "https://github.com",  # Special case - uses fallback
        }
        for provider_name, issuer_url in provider_examples.items():
            print(f" Testing {provider_name}: {issuer_url}")
            try:
                # Validate URL pattern
                parsed = urlparse(issuer_url)
                if not parsed.scheme or not parsed.netloc:
                    print(f" ❌ Invalid URL format for {provider_name}")
                    return False
                # RFC 8414 and OpenID Connect Discovery 1.0 well-known paths
                # (the OIDC suffix is hyphenated: openid-configuration).
                discovery_endpoint = f"{issuer_url}/.well-known/oauth-authorization-server"
                oidc_endpoint = f"{issuer_url}/.well-known/openid-configuration"
                print(f" ✅ {provider_name} discovery endpoints:")
                print(f" - OAuth 2.1: {discovery_endpoint}")
                print(f" - OIDC: {oidc_endpoint}")
                # Special validation for GitHub (should use fallback)
                if provider_name == "GitHub":
                    print(f" ℹ️ {provider_name} will use fallback configuration (no standard discovery)")
                else:
                    print(f" ✅ {provider_name} should work with standard OAuth 2.1 discovery")
            except Exception as e:
                print(f" ❌ {provider_name} validation failed: {e}")
                return False
        print(" ✅ Generic OAuth 2.1 discovery approach validated")
        print(" 🎉 No provider-specific configuration needed!")
        return True

    def test_discovery_error_handling(self) -> bool:
        """Test discovery endpoint error handling.

        Requests a few non-existent ``/.well-known`` paths and reports whether
        each returns 404. Unexpected statuses and request errors only produce
        warnings.

        Returns:
            Always True.
        """
        print("⚠️ Testing discovery endpoint error handling...")
        invalid_endpoints = [
            "/.well-known/invalid-endpoint",
            "/.well-known/oauth-invalid",
            "/.well-known/jwks-invalid",
        ]
        for endpoint in invalid_endpoints:
            try:
                response = requests.get(f"{self.server_url}{endpoint}", timeout=5)
                if response.status_code == 404:
                    print(f" ✅ Invalid endpoint '{endpoint}' correctly returns 404")
                else:
                    print(f" ⚠️ Invalid endpoint '{endpoint}' returned {response.status_code}")
            except requests.exceptions.RequestException as e:
                print(f" ⚠️ Request to invalid endpoint '{endpoint}' failed: {e}")
        print(" ✅ Error handling validation completed")
        return True

    def run_all_tests(self) -> bool:
        """Run all OAuth discovery tests.

        The suite first starts and stops a server with OAuth disabled as a smoke
        check. It then runs the full suite against an OAuth-enabled server. If
        that server cannot start, it falls back to the basic "OAuth disabled"
        suite.

        Returns:
            True if every executed test passed, False otherwise or when no
            server could be started.
        """
        print("🚀 Starting OAuth Discovery Endpoints Test Suite")
        print("=" * 60)
        # First test with OAuth disabled to ensure basic functionality
        if not self.setup_test_server(enable_auth=False):
            print("❌ Failed to setup basic test server")
            return False
        self.teardown_test_server()
        # Now test with OAuth enabled
        oauth_enabled = self.setup_test_server(enable_auth=True)
        if not oauth_enabled:
            print("⚠️ Failed to setup test server with OAuth enabled - running basic tests only")
            if not self.setup_test_server(enable_auth=False):
                print("❌ Failed to setup even basic test server")
                return False
        try:
            if oauth_enabled:
                # Full OAuth tests
                tests = [
                    ("OAuth Authorization Server Endpoint", self.test_oauth_authorization_server_endpoint),
                    ("OAuth Protected Resource Endpoint", self.test_oauth_protected_resource_endpoint),
                    ("JWKS Endpoint", self.test_jwks_endpoint),
                    ("PKCE Mandatory Requirements", self.test_pkce_mandatory_requirements),
                    ("Discovery Consistency", self.test_discovery_consistency),
                    ("Generic OAuth 2.1 Discovery", self.test_discovery_with_different_providers),
                    ("Error Handling", self.test_discovery_error_handling),
                ]
            else:
                # Basic tests without OAuth
                tests = [
                    ("OAuth Disabled - Authorization Server 404", self.test_oauth_authorization_server_endpoint),
                    ("OAuth Disabled - Protected Resource 404", self.test_oauth_protected_resource_endpoint),
                    ("OAuth Disabled - JWKS 404", self.test_jwks_endpoint),
                    ("Error Handling", self.test_discovery_error_handling),
                ]
            for test_name, test_func in tests:
                self.run_test(test_name, test_func)
        finally:
            self.teardown_test_server()
        # Generate summary
        self.print_summary()
        # Return overall success
        return all(passed for _, passed, _ in self.test_results)

    def print_summary(self):
        """Print totals, success rate, and the names/errors of failed tests."""
        print("\n" + "=" * 60)
        print("📊 OAuth Discovery Test Summary")
        print("=" * 60)
        total_tests = len(self.test_results)
        passed_tests = sum(1 for result in self.test_results if result[1])
        failed_tests = total_tests - passed_tests
        print(f"Total Tests: {total_tests}")
        print(f"Passed: {passed_tests}")
        print(f"Failed: {failed_tests}")
        print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%" if total_tests > 0 else "Success Rate: 0%")
        if failed_tests > 0:
            print("\n❌ Failed Tests:")
            for test_name, passed, error in self.test_results:
                if not passed:
                    print(f" - {test_name}: {error}")
        print(f"\n{'🎉 All tests passed!' if failed_tests == 0 else '⚠️ Some tests failed'}")
def main():
    """Run the OAuth discovery tests and return a process exit code.

    Returns:
        0 when every test passed; 1 on any test failure, a keyboard
        interrupt, or an unexpected error.
    """
    tester = OAuthDiscoveryTest()
    try:
        success = tester.run_all_tests()
        return 0 if success else 1
    except KeyboardInterrupt:
        print("\n🛑 Tests interrupted by user")
        return 1
    except Exception as e:
        print(f"\n❌ Test suite failed with unexpected error: {e}")
        return 1
    finally:
        # Always make sure no server is left running. teardown_test_server
        # returns immediately when server_process is None (e.g. after
        # run_all_tests already stopped it).
        tester.teardown_test_server()
if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is absent under `python -S`.
    sys.exit(main())