# test_threat_intel.py (7.57 kB)
"""
Tests for threat intelligence functionality
"""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cyber_sentinel.threat_intel import (
AbuseIPDBClient,
ThreatIntelAggregator,
ThreatIntelCache,
URLhausClient,
VirusTotalClient,
)
class TestThreatIntelCache:
    """Unit tests for ``ThreatIntelCache``: key derivation, storage and TTL expiry."""

    def test_cache_key_generation(self):
        cache = ThreatIntelCache()
        vt_key = cache._get_key("8.8.8.8", "virustotal")
        abuse_key = cache._get_key("8.8.8.8", "abuseipdb")
        vt_key_again = cache._get_key("8.8.8.8", "virustotal")

        # The source name must participate in the key...
        assert vt_key != abuse_key
        # ...and deriving a key for the same (indicator, source) is deterministic.
        assert vt_key == vt_key_again

    def test_cache_set_and_get(self):
        cache = ThreatIntelCache(ttl=3600)
        payload = {"result": "test"}

        # Nothing has been stored yet, so the lookup misses.
        assert cache.get("8.8.8.8", "virustotal") is None

        # After storing with a long TTL the same data comes back.
        cache.set("8.8.8.8", "virustotal", payload)
        assert cache.get("8.8.8.8", "virustotal") == payload

    def test_cache_expiry(self):
        # A zero TTL should make an entry stale as soon as it is stored.
        # NOTE(review): this relies on ThreatIntelCache treating an entry whose
        # age is exactly 0 as expired (strict comparison); confirm in get().
        cache = ThreatIntelCache(ttl=0)
        cache.set("8.8.8.8", "virustotal", {"result": "test"})
        assert cache.get("8.8.8.8", "virustotal") is None
class TestThreatIntelAggregator:
    """Unit tests for ``ThreatIntelAggregator`` classification and result scoring."""

    def test_indicator_type_detection(self):
        aggregator = ThreatIntelAggregator()
        cases = [
            # IPv4 addresses
            ("8.8.8.8", "ip"),
            ("192.168.1.1", "ip"),
            # bare domains
            ("example.com", "domain"),
            ("sub.example.com", "domain"),
            # URLs carrying a scheme
            ("https://example.com", "url"),
            ("http://example.com/path", "url"),
            # MD5, SHA1 and SHA256 hex digests
            ("d41d8cd98f00b204e9800998ecf8427e", "hash"),
            ("da39a3ee5e6b4b0d3255bfef95601890afd80709", "hash"),
            (
                "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
                "hash",
            ),
            # anything unrecognised
            ("invalid", "unknown"),
        ]
        for indicator, expected in cases:
            # Keep the call inline so pytest's assertion rewriting shows the
            # offending indicator when a case fails.
            assert aggregator._detect_indicator_type(indicator) == expected

    def test_ip_validation(self):
        aggregator = ThreatIntelAggregator()
        well_formed = ["8.8.8.8", "192.168.1.1", "0.0.0.0", "255.255.255.255"]
        malformed = ["256.1.1.1", "1.1.1", "example.com", "not.an.ip.address"]

        for candidate in well_formed:
            assert aggregator._is_ip(candidate) is True
        for candidate in malformed:
            assert aggregator._is_ip(candidate) is False

    def test_result_aggregation(self):
        aggregator = ThreatIntelAggregator()

        # Two of the three sources report the indicator as malicious.
        split_verdict = [
            {"reputation": "malicious", "source": "source1"},
            {"reputation": "clean", "source": "source2"},
            {"reputation": "malicious", "source": "source3"},
        ]
        summary = aggregator._aggregate_results("8.8.8.8", "ip", split_verdict, [])
        assert summary["overall_reputation"] == "malicious"
        assert summary["malicious_sources"] == 2
        assert summary["clean_sources"] == 1
        assert summary["confidence"] == 66.67  # 2/3 * 100, rounded

        # Every source agrees the indicator is clean.
        unanimous_clean = [
            {"reputation": "clean", "source": "source1"},
            {"reputation": "clean", "source": "source2"},
        ]
        summary = aggregator._aggregate_results("8.8.8.8", "ip", unanimous_clean, [])
        assert summary["overall_reputation"] == "clean"
        assert summary["confidence"] == 100.0

    @pytest.mark.asyncio
    async def test_analyze_unknown_indicator(self):
        outcome = await ThreatIntelAggregator().analyze_indicator("invalid_indicator")

        assert outcome["type"] == "unknown"
        assert "error" in outcome
        assert "Unable to determine indicator type" in outcome["error"]
@pytest.mark.asyncio
class TestAPIClients:
    """Tests for the individual provider clients, with ``httpx.AsyncClient`` stubbed."""

    @staticmethod
    def _fake_response(payload):
        """Build a stand-in HTTP response whose ``json()`` returns *payload*."""
        response = MagicMock()
        response.json.return_value = payload
        response.raise_for_status = MagicMock()
        return response

    async def test_virustotal_client_ip_check(self):
        """VirusTotal should report a clean IP when nothing is detected."""
        client = VirusTotalClient("test_api_key", AsyncMock())
        response = self._fake_response({"detected_urls": [], "response_code": 1})

        with patch("httpx.AsyncClient") as async_client_cls:
            # Object bound by ``async with httpx.AsyncClient() as ...``.
            session = async_client_cls.return_value.__aenter__.return_value
            session.get.return_value = response
            result = await client.check_ip("8.8.8.8")

        assert result["source"] == "VirusTotal"
        assert result["indicator"] == "8.8.8.8"
        assert result["type"] == "ip"
        assert result["reputation"] == "clean"

    async def test_abuseipdb_client_ip_check(self):
        """AbuseIPDB should report a clean IP and surface country/ISP metadata."""
        client = AbuseIPDBClient("test_api_key", AsyncMock())
        response = self._fake_response(
            {
                "data": {
                    "abuseConfidencePercentage": 0,
                    "countryCode": "US",
                    "isp": "Google LLC",
                }
            }
        )

        with patch("httpx.AsyncClient") as async_client_cls:
            # Object bound by ``async with httpx.AsyncClient() as ...``.
            session = async_client_cls.return_value.__aenter__.return_value
            session.get.return_value = response
            result = await client.check_ip("8.8.8.8")

        assert result["source"] == "AbuseIPDB"
        assert result["indicator"] == "8.8.8.8"
        assert result["type"] == "ip"
        assert result["reputation"] == "clean"
        assert result["country"] == "US"
        assert result["isp"] == "Google LLC"

    async def test_urlhaus_client_url_check(self):
        """URLhaus should report a clean URL when the query has no results."""
        client = URLhausClient(AsyncMock())
        response = self._fake_response({"query_status": "no_results"})

        with patch("httpx.AsyncClient") as async_client_cls:
            # URLhaus is queried with POST rather than GET.
            session = async_client_cls.return_value.__aenter__.return_value
            session.post.return_value = response
            result = await client.check_url("https://example.com")

        assert result["source"] == "URLhaus"
        assert result["indicator"] == "https://example.com"
        assert result["type"] == "url"
        assert result["reputation"] == "clean"
if __name__ == "__main__":
    # pytest.main() returns the session exit status instead of exiting.
    # Propagate it so a direct `python test_threat_intel.py` run reports
    # failures to the shell/CI rather than always exiting 0.
    raise SystemExit(pytest.main([__file__]))