#!/usr/bin/env python3
"""
Integration tests for SSRF (Server-Side Request Forgery) vulnerability protection.
This test suite validates that the MCP server properly protects against URL injection
and SSRF attacks by:
1. Blocking access to internal/private networks
2. Preventing protocol smuggling
3. Properly encoding context parameters to prevent injection
4. Allowing localhost only in test/development mode
"""
import os
import sys
from unittest.mock import patch
from urllib.parse import quote
import pytest
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from schema_registry_common import (
MultiRegistryManager,
RegistryClient,
RegistryConfig,
SingleRegistryManager,
validate_url,
)
class TestSSRFProtection:
    """Test suite for SSRF vulnerability protection.

    Exercises ``validate_url`` scheme filtering, the URL validation done when
    a ``RegistryClient`` or a registry manager is constructed, and the
    percent-encoding that ``RegistryClient.build_context_url`` applies to the
    context path segment.
    """

    def test_validate_url_blocks_non_http_protocols(self):
        """Test that only http and https protocols are allowed."""
        # Should block dangerous protocols
        blocked_urls = [
            "file:///etc/passwd",
            "ftp://example.com",
            "gopher://example.com",
            "dict://example.com",
            "sftp://example.com",
            "ldap://example.com",
            "jar:file:///tmp/test.jar!/",
        ]
        for url in blocked_urls:
            assert not validate_url(url), f"Should block: {url}"

        # Should allow http and https
        allowed_urls = [
            "http://example.com",
            "https://example.com",
            "http://schema-registry.example.com:8081",
            "https://schema-registry.example.com:8081",
        ]
        for url in allowed_urls:
            assert validate_url(url), f"Should allow: {url}"

    def test_validate_url_test_mode_detection(self):
        """Test that the function correctly detects test vs production mode."""
        # In test mode (current environment), localhost should be allowed
        assert validate_url("http://localhost:8081")

        # Verify that ALLOW_LOCALHOST environment variable works
        with patch.dict(os.environ, {"ALLOW_LOCALHOST": "true"}):
            assert validate_url("http://localhost:8081")
            assert validate_url("http://127.0.0.1:8081")

        # Verify that TESTING environment variable works
        with patch.dict(os.environ, {"TESTING": "true"}):
            assert validate_url("http://localhost:8081")
            assert validate_url("http://127.0.0.1:8081")

    def test_registry_client_validates_url_on_init(self):
        """Test that RegistryClient validates URL during initialization."""
        # Should raise ValueError for invalid protocols
        unsafe_urls = [
            "file:///etc/passwd",
            "gopher://example.com",
            "dict://example.com",
        ]
        for unsafe_url in unsafe_urls:
            # Each URL gets its own raises-context so one failure cannot
            # mask the remaining cases.
            with pytest.raises(ValueError, match="Invalid or unsafe registry URL"):
                config = RegistryConfig(name="test", url=unsafe_url)
                RegistryClient(config)

    def test_build_context_url_prevents_injection(self):
        """Test that build_context_url properly encodes context to prevent injection."""
        # Create a valid client (localhost allowed in test mode)
        config = RegistryConfig(name="test", url="http://localhost:8081")
        client = RegistryClient(config)

        # Test normal context
        url = client.build_context_url("/subjects", "my-context")
        assert url == "http://localhost:8081/contexts/my-context/subjects"

        # Test that special characters are properly encoded
        test_cases = [
            # (input_context, expected_encoded_part)
            ("../../../etc/passwd", "..%2F..%2F..%2Fetc%2Fpasswd"),
            ("context/admin", "context%2Fadmin"),
            ("context#fragment", "context%23fragment"),
            ("context?query=value", "context%3Fquery%3Dvalue"),
            # The input must contain a literal "&" to match the "%26" in the
            # expected value.
            ("context&param=value", "context%26param%3Dvalue"),
            ("context;param=value", "context%3Bparam%3Dvalue"),
            ("context\nHost: evil.com", "context%0AHost%3A%20evil.com"),
            ("context\r\nHost: evil.com", "context%0D%0AHost%3A%20evil.com"),
            (".\\admin", ".%5Cadmin"),
            ("..\\..\\", "..%5C..%5C"),
        ]
        for input_context, expected_encoded in test_cases:
            url = client.build_context_url("/subjects", input_context)
            # Verify the context is properly encoded
            assert f"/contexts/{expected_encoded}/subjects" in url, f"Failed for input: {input_context}"

    def test_single_registry_manager_validates_url(self):
        """Test that SingleRegistryManager validates registry URLs."""
        # Test with invalid protocol URL
        with patch.dict(
            os.environ,
            {
                "SCHEMA_REGISTRY_URL": "file:///etc/passwd",
            },
            clear=False,
        ):
            manager = SingleRegistryManager()
            # Should not load the registry due to invalid URL
            # Note: In test mode, the registry count might vary based on other env vars
            # The key is that the file:// URL should not be loaded
            if "default" in manager.registries:
                assert manager.registries["default"].config.url != "file:///etc/passwd"

    def test_multi_registry_manager_validates_urls(self):
        """Test that MultiRegistryManager validates all registry URLs."""
        # Test with mix of valid and invalid URLs
        with patch.dict(
            os.environ,
            {
                "SCHEMA_REGISTRY_NAME_1": "valid",
                "SCHEMA_REGISTRY_URL_1": "https://schema-registry.example.com",
                "SCHEMA_REGISTRY_NAME_2": "invalid-protocol",
                "SCHEMA_REGISTRY_URL_2": "ftp://schema-registry.example.com",
                "SCHEMA_REGISTRY_NAME_3": "another-invalid",
                "SCHEMA_REGISTRY_URL_3": "dict://schema-registry.example.com",
            },
        ):
            manager = MultiRegistryManager()
            # Should load the valid registry
            assert "valid" in manager.registries
            # Should not load invalid protocol registries
            assert "invalid-protocol" not in manager.registries
            assert "another-invalid" not in manager.registries

    def test_url_validation_handles_edge_cases(self):
        """Test URL validation with edge cases."""
        # Empty URL
        assert not validate_url("")
        # None URL
        assert not validate_url(None)

        # Malformed URLs
        assert not validate_url("not-a-url")
        assert not validate_url("://example.com")

        # URLs with authentication (should be allowed)
        assert validate_url("http://user:pass@example.com:8081")
        assert validate_url("https://user:pass@example.com:8081")

    def test_context_encoding_matches_url_standard(self):
        """Test that context encoding follows URL encoding standards."""
        config = RegistryConfig(name="test", url="http://localhost:8081")
        client = RegistryClient(config)

        # Test that the encoding matches Python's quote function
        test_contexts = [
            "simple-context",
            "context with spaces",
            "context/with/slashes",
            "context?with=query",
            "context#with#hash",
            "../../../etc/passwd",
            "context\nwith\nnewlines",
            "context\rwith\rreturns",
            "special!@#$%^&*()chars",
        ]
        for context in test_contexts:
            url = client.build_context_url("/subjects", context)
            # safe="" makes quote() encode "/" as well, so the context stays
            # within a single path segment.
            expected_encoded = quote(context, safe="")
            expected_url = f"http://localhost:8081/contexts/{expected_encoded}/subjects"
            assert url == expected_url, f"Encoding mismatch for context: {context}"

    def test_ssrf_protection_documentation(self):
        """Test that demonstrates the SSRF protection features."""
        # This test serves as documentation of the security features

        # 1. Protocol whitelisting - only http/https allowed
        dangerous_protocols = [
            "file:///etc/passwd",
            "ftp://internal.server",
            "gopher://internal.server",
            "dict://internal.server:2628",
            "sftp://internal.server",
            "ldap://internal.server",
            "jar:file:///app.jar!/",
        ]
        for url in dangerous_protocols:
            assert not validate_url(url), f"Should block dangerous protocol: {url}"

        # 2. Context injection prevention through URL encoding
        config = RegistryConfig(name="test", url="http://localhost:8081")
        client = RegistryClient(config)
        injection_attempts = [
            "../../../admin",  # Path traversal
            "context\r\nHost: evil.com",  # Header injection
            "context%0d%0aHost:%20evil.com",  # Encoded header injection
            "';DROP TABLE subjects;--",  # SQL injection attempt
            "<script>alert('xss')</script>",  # XSS attempt
        ]
        for attempt in injection_attempts:
            url = client.build_context_url("/subjects", attempt)
            # Verify the dangerous characters are encoded
            assert attempt not in url, f"Dangerous input not encoded: {attempt}"
            # Verify URL structure is maintained
            assert "/contexts/" in url
            assert "/subjects" in url

    def test_production_mode_simulation(self):
        """Test to demonstrate how production mode would work.

        NOTE(review): this test makes no assertions; it only records which
        URLs are expected to be rejected outside test mode.
        """
        # This test documents the expected behavior in production
        # In actual production, these assertions would pass:

        # Production mode indicators:
        # - TESTING not set or set to false
        # - CI not set or set to false
        # - ALLOW_LOCALHOST not set or set to false
        # - PYTEST_CURRENT_TEST not set
        # - Working directory doesn't contain 'test'
        # - Script name doesn't contain 'test'

        # In production, these would be blocked:
        localhost_urls = [
            "http://localhost:8081",
            "http://127.0.0.1:8081",
            "http://0.0.0.0:8081",
            "http://[::1]:8081",
            "http://[::]:8081",
        ]
        private_ip_urls = [
            "http://10.0.0.1:8081",
            "http://172.16.0.1:8081",
            "http://192.168.1.1:8081",
        ]

        # Document that in production these would return False
        # (In test environment they return True)
        for url in localhost_urls + private_ip_urls:
            # In production: assert not validate_url(url)
            # In test: validate_url(url) returns True
            pass
if __name__ == "__main__":
    # Allow running this file directly as a script: hand this file to pytest
    # with the verbose flag.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)