"""
Threat Intelligence API clients and aggregation logic
"""
import asyncio
import hashlib
import json
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import httpx
from asyncio_throttle import Throttler
from .config import get_settings
# Module-level logger; ThreatIntelAggregator.analyze_indicator reports per-source query failures through it.
logger: logging.Logger = logging.getLogger(__name__)
class ThreatIntelCache:
    """Simple in-memory TTL cache for threat intelligence results.

    Entries are keyed by ``(indicator, source)`` and expire ``ttl`` seconds
    after they were stored. Expired entries are dropped when read, and every
    write also sweeps expired entries from the oldest end of the cache. As a
    result, indicators that are never queried again do not accumulate forever.
    """

    def __init__(self, ttl: int = 3600):
        """
        Args:
            ttl: Time-to-live of each entry, in seconds.
        """
        from collections import OrderedDict

        # Kept in write order (``set`` moves rewritten keys to the end) so
        # ``_evict_expired`` can stop at the first entry that is still fresh.
        self.cache: Dict[str, Dict[str, Any]] = OrderedDict()
        self.ttl = ttl

    def _get_key(self, indicator: str, source: str) -> str:
        """Return the cache key for an ``(indicator, source)`` pair."""
        return hashlib.md5(f"{indicator}:{source}".encode()).hexdigest()

    def _is_fresh(self, entry: Dict[str, Any], now: datetime) -> bool:
        """Return True if ``entry`` was stored less than ``ttl`` seconds before ``now``."""
        return now - entry["timestamp"] < timedelta(seconds=self.ttl)

    def _evict_expired(self, now: datetime) -> None:
        """Delete expired entries, scanning from the oldest write forward."""
        expired = []
        for key, entry in self.cache.items():
            if self._is_fresh(entry, now):
                # Later entries were written later, so they are fresh as well
                # (assuming the wall clock does not step backwards; if it
                # does, a few stale entries merely linger until read).
                break
            expired.append(key)
        for key in expired:
            del self.cache[key]

    def get(self, indicator: str, source: str) -> Optional[Dict[str, Any]]:
        """Return the cached result for ``(indicator, source)``, or None.

        An entry older than ``ttl`` seconds is deleted and treated as a miss.
        """
        key = self._get_key(indicator, source)
        entry = self.cache.get(key)
        if entry is None:
            return None
        if self._is_fresh(entry, datetime.now()):
            return entry["data"]
        del self.cache[key]
        return None

    def set(self, indicator: str, source: str, data: Dict[str, Any]) -> None:
        """Store ``data`` for ``(indicator, source)``, stamped with the current time.

        Expired entries are purged first, which bounds memory to roughly the
        set of indicators written within the last ``ttl`` seconds.
        """
        now = datetime.now()
        self._evict_expired(now)
        key = self._get_key(indicator, source)
        self.cache[key] = {"data": data, "timestamp": now}
        # A rewrite keeps its old position in the mapping; move it to the end
        # so that write order is preserved for the sweep above.
        self.cache.move_to_end(key)
class VirusTotalClient:
    """VirusTotal API v3 client.

    Each ``check_*`` method does three things:

    * fetches one object report;
    * reduces its ``last_analysis_stats`` to a malicious / total engine count;
    * returns the result in the common dictionary format used by the aggregator.

    All requests go through the shared ``throttler``. Non-2xx responses raise
    via ``response.raise_for_status()``.
    """

    def __init__(self, api_key: str, throttler: Throttler):
        self.api_key = api_key
        self.throttler = throttler
        self.base_url = "https://www.virustotal.com/api/v3"
        self.headers = {"x-apikey": self.api_key}

    async def _get_object(self, collection: str, object_id: str) -> Dict[str, Any]:
        """GET ``{base_url}/{collection}/{object_id}`` and return the decoded JSON.

        ``object_id`` is percent-encoded as a single path segment. An
        indicator containing ``/``, ``?`` or ``#`` therefore cannot redirect
        this API-key-authenticated request to a different endpoint.
        """
        from urllib.parse import quote

        async with self.throttler:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{self.base_url}/{collection}/{quote(object_id, safe='')}",
                    headers=self.headers,
                    timeout=30,
                )
                response.raise_for_status()
                return response.json()

    @staticmethod
    def _build_result(
        indicator: str,
        indicator_type: str,
        data: Dict[str, Any],
        include_detections: bool = False,
    ) -> Dict[str, Any]:
        """Summarize a v3 object report into the common result dictionary.

        Args:
            indicator: The value that was looked up.
            indicator_type: One of "ip", "domain", "hash", "url".
            data: Decoded JSON response from ``_get_object``.
            include_detections: Add a ``"detections": "malicious/total"`` field.
        """
        attributes = (data.get("data") or {}).get("attributes") or {}
        stats = attributes.get("last_analysis_stats") or {}
        malicious_count = stats.get("malicious", 0)
        # Every verdict category (harmless, undetected, ...) counts as an engine
        total_count = sum(stats.values())

        result: Dict[str, Any] = {
            "source": "VirusTotal",
            "indicator": indicator,
            "type": indicator_type,
            "malicious_count": malicious_count,
            "total_engines": total_count,
        }
        if include_detections:
            result["detections"] = f"{malicious_count}/{total_count}"
        result["reputation"] = "malicious" if malicious_count > 0 else "clean"
        result["details"] = data
        result["timestamp"] = datetime.now().isoformat()
        return result

    async def check_ip(self, ip: str) -> Dict[str, Any]:
        """Check IP reputation using API v3."""
        data = await self._get_object("ip_addresses", ip)
        return self._build_result(ip, "ip", data)

    async def check_domain(self, domain: str) -> Dict[str, Any]:
        """Check domain reputation using API v3."""
        data = await self._get_object("domains", domain)
        return self._build_result(domain, "domain", data)

    async def check_hash(self, file_hash: str) -> Dict[str, Any]:
        """Check file hash reputation using API v3 (result includes ``detections``)."""
        data = await self._get_object("files", file_hash)
        return self._build_result(file_hash, "hash", data, include_detections=True)

    async def check_url(self, url: str) -> Dict[str, Any]:
        """Check URL reputation using API v3."""
        import base64

        # v3 addresses a URL by its URL-safe base64 encoding with padding removed
        url_id = base64.urlsafe_b64encode(url.encode()).decode().strip("=")
        data = await self._get_object("urls", url_id)
        return self._build_result(url, "url", data)
class AbuseIPDBClient:
    """AbuseIPDB API v2 client for IP reputation lookups."""

    def __init__(
        self,
        api_key: str,
        throttler: Throttler,
        *,
        max_age_days: int = 90,
        malicious_threshold: int = 25,
    ):
        """
        Args:
            api_key: AbuseIPDB API key, sent in the ``Key`` request header.
            throttler: Shared rate limiter that every request passes through.
            max_age_days: Value sent as the ``maxAgeInDays`` query parameter.
            malicious_threshold: An IP is labelled "malicious" when its
                ``abuseConfidencePercentage`` is strictly greater than this.
        """
        self.api_key = api_key
        self.throttler = throttler
        self.base_url = "https://api.abuseipdb.com/api/v2"
        self.max_age_days = max_age_days
        self.malicious_threshold = malicious_threshold

    async def check_ip(self, ip: str) -> Dict[str, Any]:
        """Check IP reputation.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response.
            KeyError: if the response JSON has no ``data`` object.
        """
        async with self.throttler:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{self.base_url}/check",
                    headers={"Key": self.api_key, "Accept": "application/json"},
                    params={
                        "ipAddress": ip,
                        "maxAgeInDays": self.max_age_days,
                        "verbose": "",
                    },
                    timeout=30,
                )
                response.raise_for_status()
                data = response.json()["data"]

        # Treat a present-but-null score as 0 instead of failing on the comparison
        confidence = data.get("abuseConfidencePercentage") or 0
        return {
            "source": "AbuseIPDB",
            "indicator": ip,
            "type": "ip",
            "confidence": confidence,
            "reputation": (
                "malicious" if confidence > self.malicious_threshold else "clean"
            ),
            "country": data.get("countryCode"),
            "isp": data.get("isp"),
            "details": data,
            "timestamp": datetime.now().isoformat(),
        }
class URLhausClient:
    """URLhaus API client (abuse.ch).

    Both lookups translate URLhaus' ``query_status`` into a reputation:

    * "ok" (the URL or payload is listed) -> "malicious";
    * a not-found status -> "clean";
    * any other status, e.g. the API rejecting the input -> "unknown".

    The last rule means a failed lookup is never reported as clean.
    """

    # query_status values meaning "query accepted, nothing listed".
    # NOTE(review): taken from the URLhaus API docs — confirm if the API changes.
    _NOT_FOUND_STATUSES = frozenset({"no_results"})

    def __init__(self, throttler: Throttler):
        self.throttler = throttler
        self.base_url = "https://urlhaus-api.abuse.ch/v1"

    async def _post(self, endpoint: str, form: Dict[str, str]) -> Dict[str, Any]:
        """POST ``form`` to ``{base_url}/{endpoint}/`` and return the decoded JSON."""
        async with self.throttler:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    f"{self.base_url}/{endpoint}/", data=form, timeout=30
                )
                response.raise_for_status()
                return response.json()

    @classmethod
    def _reputation(cls, status: Optional[str]) -> str:
        """Map a ``query_status`` value to "malicious", "clean" or "unknown"."""
        if status == "ok":
            return "malicious"
        if status in cls._NOT_FOUND_STATUSES:
            return "clean"
        return "unknown"

    def _build_result(
        self, indicator: str, indicator_type: str, data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Wrap a URLhaus response in the common result dictionary."""
        status = data.get("query_status")
        return {
            "source": "URLhaus",
            "indicator": indicator,
            "type": indicator_type,
            "status": status,
            "reputation": self._reputation(status),
            "details": data,
            "timestamp": datetime.now().isoformat(),
        }

    async def check_url(self, url: str) -> Dict[str, Any]:
        """Check URL reputation."""
        data = await self._post("url", {"url": url})
        return self._build_result(url, "url", data)

    async def check_hash(self, file_hash: str) -> Dict[str, Any]:
        """Check file hash reputation.

        32-character hashes are sent as ``md5_hash``; all others as
        ``sha256_hash``. NOTE(review): URLhaus documents only these two
        payload fields, so a SHA-1 is expected to come back with an
        invalid-input status (-> "unknown").
        """
        field = "md5_hash" if len(file_hash) == 32 else "sha256_hash"
        data = await self._post("payload", {field: file_hash})
        return self._build_result(file_hash, "hash", data)
class AlienVaultOTXClient:
    """AlienVault OTX API client.

    An indicator is reported as "malicious" when it appears in at least
    one OTX pulse (``pulse_info.count > 0``); otherwise it is reported as "clean".
    """

    def __init__(self, api_key: str, throttler: Throttler):
        self.api_key = api_key
        self.throttler = throttler
        self.base_url = "https://otx.alienvault.com/api/v1"
        self.headers = {"X-OTX-API-KEY": self.api_key}

    async def _get_general(self, section: str, indicator: str) -> Dict[str, Any]:
        """GET the ``general`` section for ``indicator`` and return the decoded JSON.

        The indicator is percent-encoded as a single path segment. A value
        containing ``/`` or ``?`` therefore cannot change which
        API-key-authenticated endpoint is called.
        """
        from urllib.parse import quote

        async with self.throttler:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{self.base_url}/indicators/{section}/"
                    f"{quote(indicator, safe='')}/general",
                    headers=self.headers,
                    timeout=30,
                )
                response.raise_for_status()
                return response.json()

    @staticmethod
    def _build_result(
        indicator: str, indicator_type: str, data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Wrap an OTX ``general`` response in the common result dictionary."""
        pulse_count = (data.get("pulse_info") or {}).get("count") or 0
        return {
            "source": "AlienVault OTX",
            "indicator": indicator,
            "type": indicator_type,
            "pulse_count": pulse_count,
            "reputation": "malicious" if pulse_count > 0 else "clean",
            "details": data,
            "timestamp": datetime.now().isoformat(),
        }

    async def check_ip(self, ip: str) -> Dict[str, Any]:
        """Check IP reputation using OTX."""
        data = await self._get_general("IPv4", ip)
        return self._build_result(ip, "ip", data)

    async def check_domain(self, domain: str) -> Dict[str, Any]:
        """Check domain reputation using OTX."""
        data = await self._get_general("domain", domain)
        return self._build_result(domain, "domain", data)
class ThreatFoxClient:
    """ThreatFox API client (abuse.ch).

    ``query_status`` is mapped to a reputation:

    * "ok" (the IOC is listed) -> "malicious";
    * a not-found status -> "clean";
    * any other status, e.g. rejected input -> "unknown".

    The last rule means an error is never reported as clean.
    """

    # query_status values meaning "query accepted, nothing found".
    # NOTE(review): ThreatFox is believed to answer "no_result"; the plural
    # is accepted defensively — confirm against the ThreatFox API docs.
    _NOT_FOUND_STATUSES = frozenset({"no_result", "no_results"})

    def __init__(self, throttler: Throttler):
        self.throttler = throttler
        self.base_url = "https://threatfox-api.abuse.ch/api/v1"

    async def _query(self, payload: Dict[str, str]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to the ThreatFox API and return the decoded JSON."""
        async with self.throttler:
            async with httpx.AsyncClient() as client:
                response = await client.post(self.base_url, json=payload, timeout=30)
                response.raise_for_status()
                return response.json()

    @classmethod
    def _reputation(cls, status: Optional[str]) -> str:
        """Map a ``query_status`` value to "malicious", "clean" or "unknown"."""
        if status == "ok":
            return "malicious"
        if status in cls._NOT_FOUND_STATUSES:
            return "clean"
        return "unknown"

    def _build_result(
        self, indicator: str, indicator_type: str, data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Wrap a ThreatFox response in the common result dictionary."""
        query_status = data.get("query_status")
        return {
            "source": "ThreatFox",
            "indicator": indicator,
            "type": indicator_type,
            "status": query_status,
            "reputation": self._reputation(query_status),
            "details": data,
            "timestamp": datetime.now().isoformat(),
        }

    async def check_hash(self, file_hash: str) -> Dict[str, Any]:
        """Check file hash in ThreatFox."""
        data = await self._query({"query": "search_hash", "hash": file_hash})
        return self._build_result(file_hash, "hash", data)

    async def check_ip(self, ip: str) -> Dict[str, Any]:
        """Check IP in ThreatFox."""
        data = await self._query({"query": "search_ioc", "search_term": ip})
        return self._build_result(ip, "ip", data)
class MalwareBazaarClient:
    """MalwareBazaar API client (abuse.ch).

    ``query_status`` is mapped to a reputation:

    * "ok" (the sample is known) -> "malicious";
    * a not-found status -> "clean";
    * any other status, e.g. a malformed hash -> "unknown".

    The last rule means a rejected query is never reported as clean.
    """

    # query_status value meaning "valid hash, no such sample".
    # NOTE(review): taken from the MalwareBazaar API docs — confirm if it changes.
    _NOT_FOUND_STATUSES = frozenset({"hash_not_found"})

    def __init__(self, throttler: Throttler):
        self.throttler = throttler
        self.base_url = "https://mb-api.abuse.ch/api/v1"

    @classmethod
    def _reputation(cls, status: Optional[str]) -> str:
        """Map a ``query_status`` value to "malicious", "clean" or "unknown"."""
        if status == "ok":
            return "malicious"
        if status in cls._NOT_FOUND_STATUSES:
            return "clean"
        return "unknown"

    async def check_hash(self, file_hash: str) -> Dict[str, Any]:
        """Check file hash in MalwareBazaar."""
        async with self.throttler:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.base_url,
                    data={"query": "get_info", "hash": file_hash},
                    timeout=30,
                )
                response.raise_for_status()
                data = response.json()

        query_status = data.get("query_status")
        return {
            "source": "MalwareBazaar",
            "indicator": file_hash,
            "type": "hash",
            "status": query_status,
            "reputation": self._reputation(query_status),
            "details": data,
            "timestamp": datetime.now().isoformat(),
        }
class ThreatIntelAggregator:
    """Main threat intelligence aggregator.

    The aggregator works in four steps:

    1. Classify an indicator.
    2. Ask every configured client that supports that indicator type.
    3. Reuse cached per-source results where available.
    4. Merge everything into one report.
    """

    # Indicator type -> name of the client coroutine method handling it.
    # Clients lacking the method are not asked about that type at all.
    _TYPE_METHODS: Dict[str, str] = {
        "ip": "check_ip",
        "domain": "check_domain",
        "hash": "check_hash",
        "url": "check_url",
    }
    # Hex-digest lengths of MD5, SHA-1 and SHA-256.
    _HASH_LENGTHS = frozenset({32, 40, 64})
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")
    # Characters that cannot appear in a bare domain name.
    _NON_DOMAIN_CHARS = frozenset("/\\?#@:")

    def __init__(self):
        self.settings = get_settings()
        self.cache = ThreatIntelCache(ttl=self.settings.cache_ttl)
        # Create throttler for rate limiting
        self.throttler = Throttler(
            rate_limit=self.settings.max_requests_per_minute, period=60
        )
        # Initialize clients based on available API keys
        self.clients = {}
        if self.settings.virustotal_api_key:
            self.clients["virustotal"] = VirusTotalClient(
                self.settings.virustotal_api_key, self.throttler
            )
        if self.settings.abuseipdb_api_key:
            self.clients["abuseipdb"] = AbuseIPDBClient(
                self.settings.abuseipdb_api_key, self.throttler
            )
        # URLhaus doesn't require API key
        self.clients["urlhaus"] = URLhausClient(self.throttler)
        # Add new threat intelligence sources
        if self.settings.otx_api_key:
            self.clients["alienvault_otx"] = AlienVaultOTXClient(
                self.settings.otx_api_key, self.throttler
            )
        # Add free threat intelligence sources
        self.clients["threatfox"] = ThreatFoxClient(self.throttler)
        self.clients["malwarebazaar"] = MalwareBazaarClient(self.throttler)

    def _detect_indicator_type(self, indicator: str) -> str:
        """Classify ``indicator`` as "ip", "hash", "url", "domain" or "unknown".

        Checks run in that order:

        * A hash must be exactly 32, 40 or 64 hex digits, so a URL or domain
          that merely has one of those lengths is not taken for a hash.
        * The URL scheme test is case-insensitive.
        * A domain must contain an inner "." and no whitespace or URL
          delimiters (``/ \\ ? # @ :``).
        """
        if self._is_ip(indicator):
            return "ip"
        if len(indicator) in self._HASH_LENGTHS and all(
            ch in self._HEX_DIGITS for ch in indicator
        ):
            return "hash"
        if indicator.lower().startswith(("http://", "https://")):
            return "url"
        if "." in indicator.strip(".") and not any(
            ch.isspace() or ch in self._NON_DOMAIN_CHARS for ch in indicator
        ):
            return "domain"
        return "unknown"

    def _is_ip(self, indicator: str) -> bool:
        """Return True if ``indicator`` is a dotted-quad IPv4 address.

        Each of the four parts must be ASCII digits only and at most 255.
        This rejects signs, embedded whitespace and non-ASCII digits, all of
        which ``int()`` would otherwise accept.
        """
        parts = indicator.split(".")
        return len(parts) == 4 and all(
            part.isascii() and part.isdigit() and int(part) <= 255 for part in parts
        )

    def _sources_for(self, indicator_type: str) -> Dict[str, Any]:
        """Return ``{source_name: bound lookup coroutine}`` for clients supporting the type."""
        method_name = self._TYPE_METHODS.get(indicator_type)
        if method_name is None:
            return {}
        return {
            name: getattr(client, method_name)
            for name, client in self.clients.items()
            if hasattr(client, method_name)
        }

    async def _query_source(self, source_name: str, lookup: Any, indicator: str) -> tuple:
        """Return ``(result, error_message)`` for one source; at most one is truthy.

        A cached result is returned without calling ``lookup``. A fresh,
        non-empty result is written to the cache.
        """
        cached_result = self.cache.get(indicator, source_name)
        if cached_result:
            return cached_result, None
        try:
            result = await lookup(indicator)
        except Exception as e:
            # One failing source must not abort the others: record and move on
            error_msg = f"Error querying {source_name}: {str(e)}"
            logger.error(error_msg)
            return None, error_msg
        if result:
            self.cache.set(indicator, source_name, result)
        return result, None

    async def analyze_indicator(self, indicator: str) -> Dict[str, Any]:
        """Analyze an indicator across all threat intelligence sources supporting its type.

        Surrounding whitespace is stripped before classification. The
        supporting sources are queried concurrently; each client still goes
        through the shared throttler.

        Returns:
            The aggregated report from ``_aggregate_results``. When the type
            cannot be determined, a short error report is returned instead.
        """
        indicator = indicator.strip()
        indicator_type = self._detect_indicator_type(indicator)
        if indicator_type == "unknown":
            return {
                "indicator": indicator,
                "type": "unknown",
                "error": "Unable to determine indicator type",
                "timestamp": datetime.now().isoformat(),
            }

        sources = self._sources_for(indicator_type)
        # gather preserves task order, so results keep the client order
        outcomes = await asyncio.gather(
            *(
                self._query_source(name, lookup, indicator)
                for name, lookup in sources.items()
            )
        )
        results = [result for result, _ in outcomes if result]
        errors = [error for _, error in outcomes if error]
        return self._aggregate_results(
            indicator, indicator_type, results, errors, sources_checked=len(sources)
        )

    def _aggregate_results(
        self,
        indicator: str,
        indicator_type: str,
        results: List[Dict[str, Any]],
        errors: List[str],
        sources_checked: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Aggregate results from multiple sources into a unified report.

        Args:
            indicator: The analyzed indicator.
            indicator_type: Its detected type.
            results: Per-source result dictionaries.
            errors: Per-source error messages.
            sources_checked: Number of sources that were queried. Defaults to
                the number of configured clients supporting ``indicator_type``.

        Reputation rules:

        * "malicious" if any source says so, with confidence equal to the
          share of responding sources that did;
        * otherwise "clean" if any source says so, with confidence computed
          the same way;
        * otherwise "unknown".
        """
        if sources_checked is None:
            sources_checked = len(self._sources_for(indicator_type))

        if not results:
            return {
                "indicator": indicator,
                "type": indicator_type,
                "overall_reputation": "unknown",
                "confidence": 0,
                "sources_checked": sources_checked,
                "sources_responded": 0,
                "errors": errors,
                "timestamp": datetime.now().isoformat(),
            }

        malicious_count = sum(1 for r in results if r.get("reputation") == "malicious")
        clean_count = sum(1 for r in results if r.get("reputation") == "clean")
        if malicious_count > 0:
            overall_reputation = "malicious"
            confidence = (malicious_count / len(results)) * 100
        elif clean_count > 0:
            overall_reputation = "clean"
            confidence = (clean_count / len(results)) * 100
        else:
            overall_reputation = "unknown"
            confidence = 0

        # Sorted so the report is deterministic across runs
        countries = sorted({r["country"] for r in results if r.get("country")})
        isps = sorted({r["isp"] for r in results if r.get("isp")})
        return {
            "indicator": indicator,
            "type": indicator_type,
            "overall_reputation": overall_reputation,
            "confidence": round(confidence, 2),
            "sources_checked": sources_checked,
            "sources_responded": len(results),
            "malicious_sources": malicious_count,
            "clean_sources": clean_count,
            "countries": countries,
            "isps": isps,
            "detailed_results": results,
            "errors": errors,
            "timestamp": datetime.now().isoformat(),
        }