"""
Production-Grade LLM Manager with Multi-Provider Fallback
Supports: Euri -> Deepseek -> Gemini -> Claude
Features:
- Automatic fallback on failure
- Circuit breaker pattern
- Rate limiting
- Health checks
- Cost tracking
- Retry logic with exponential backoff
- Provider caching
"""
import asyncio
import time
from typing import Optional, Dict, List, Any, Callable
from enum import Enum
from datetime import datetime, timedelta
import structlog
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type
)
# LLM Client imports
import openai # For Euri and Deepseek (OpenAI-compatible)
import anthropic
import google.generativeai as genai
from ..models.llm_models import (
LLMRequest,
LLMResponse,
LLMProvider,
ProviderHealth,
LLMUsageMetrics
)
from .circuit_breaker import CircuitBreaker
from .rate_limiter import RateLimiter
logger = structlog.get_logger(__name__)
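# Expected configuration shape (illustrative sketch only -- the key names below are the ones
# this module reads via config.get(); actual values come from your own settings source):
#
# config = {
#     "llm": {
#         "providers": [
#             {"name": "euron", "priority": 1, "enabled": True, "api_key": "...", "model": "..."},
#             {"name": "deepseek", "priority": 2, "enabled": True, "api_key": "...", "model": "..."},
#             {"name": "gemini", "priority": 3, "enabled": True, "api_key": "...", "model": "gemini-pro"},
#             {"name": "claude", "priority": 4, "enabled": True, "api_key": "...",
#              "model": "claude-3-5-sonnet-20241022"},
#         ],
#         "circuit_breaker": {"failure_threshold": 5, "timeout": 60, "recovery_threshold": 3},
#         "rate_limit": {"calls_per_minute": 100, "burst_limit": 10},
#     }
# }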
class ProviderStatus(Enum):
"""Provider availability status"""
HEALTHY = "healthy"
DEGRADED = "degraded"
FAILED = "failed"
CIRCUIT_OPEN = "circuit_open"
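    # Transitions in this module: HEALTHY after a successful call, FAILED after a failure while the
    # circuit breaker is still closed, CIRCUIT_OPEN once the breaker trips.
    # DEGRADED is defined for callers/health checks but is not set anywhere in this module.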
class LLMProviderClient:
"""Individual LLM provider client with health tracking"""
def __init__(
self,
name: str,
config: Dict[str, Any],
circuit_breaker: CircuitBreaker,
rate_limiter: RateLimiter
):
self.name = name
self.config = config
self.circuit_breaker = circuit_breaker
self.rate_limiter = rate_limiter
self.status = ProviderStatus.HEALTHY
self.last_success = datetime.now()
self.last_failure: Optional[datetime] = None
self.total_calls = 0
self.successful_calls = 0
self.failed_calls = 0
self.total_cost = 0.0
# Initialize provider-specific client
self._init_client()
def _init_client(self):
"""Initialize provider-specific API client"""
try:
if self.name == "euron":
self.client = openai.AsyncOpenAI(
api_key=self.config["api_key"],
base_url=self.config.get("api_base", "https://api.euron.one/api/v1/euri"),
timeout=self.config.get("timeout", 30)
)
elif self.name == "deepseek":
self.client = openai.AsyncOpenAI(
api_key=self.config["api_key"],
base_url=self.config.get("api_base", "https://api.deepseek.com/v1"),
timeout=self.config.get("timeout", 30)
)
elif self.name == "gemini":
genai.configure(api_key=self.config["api_key"])
self.client = genai.GenerativeModel(
self.config.get("model", "gemini-pro")
)
            elif self.name == "claude":
                self.client = anthropic.AsyncAnthropic(
                    api_key=self.config["api_key"],
                    timeout=self.config.get("timeout", 30)
                )
            else:
                raise ValueError(f"Unknown provider: {self.name}")
            logger.info(f"Initialized {self.name} client", provider=self.name)
except Exception as e:
logger.error(
f"Failed to initialize {self.name} client",
provider=self.name,
error=str(e)
)
self.status = ProviderStatus.FAILED
raise
async def generate(
self,
prompt: str,
system_prompt: Optional[str] = None,
**kwargs
) -> LLMResponse:
"""
Generate completion from this provider
Args:
prompt: User prompt
system_prompt: System prompt (optional)
**kwargs: Additional generation parameters
Returns:
LLMResponse with generated content
"""
# Check circuit breaker
if not self.circuit_breaker.can_proceed():
logger.warning(
f"Circuit breaker open for {self.name}",
provider=self.name
)
raise Exception(f"Circuit breaker open for {self.name}")
# Check rate limit
if not await self.rate_limiter.acquire():
logger.warning(
f"Rate limit exceeded for {self.name}",
provider=self.name
)
raise Exception(f"Rate limit exceeded for {self.name}")
start_time = time.time()
self.total_calls += 1
try:
# Call provider-specific API
if self.name in ["euron", "deepseek"]:
response = await self._generate_openai_compatible(
prompt, system_prompt, **kwargs
)
elif self.name == "gemini":
response = await self._generate_gemini(
prompt, system_prompt, **kwargs
)
elif self.name == "claude":
response = await self._generate_claude(
prompt, system_prompt, **kwargs
)
else:
raise ValueError(f"Unknown provider: {self.name}")
# Track success
self.successful_calls += 1
self.last_success = datetime.now()
self.circuit_breaker.record_success()
self.status = ProviderStatus.HEALTHY
            # Attach latency to the response (the _generate_* helpers leave it as 0)
            latency = time.time() - start_time
            response.latency = latency
            logger.info(
                f"Successfully generated from {self.name}",
                provider=self.name,
                latency=f"{latency:.2f}s",
                tokens=response.usage.total_tokens
            )
            return response
except Exception as e:
# Track failure
self.failed_calls += 1
self.last_failure = datetime.now()
self.circuit_breaker.record_failure()
if self.circuit_breaker.is_open():
self.status = ProviderStatus.CIRCUIT_OPEN
else:
self.status = ProviderStatus.FAILED
logger.error(
f"Failed to generate from {self.name}",
provider=self.name,
error=str(e),
status=self.status.value
)
raise
async def _generate_openai_compatible(
self,
prompt: str,
system_prompt: Optional[str],
**kwargs
) -> LLMResponse:
"""Generate using OpenAI-compatible API (Euron, Deepseek)"""
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
response = await self.client.chat.completions.create(
model=self.config.get("model", "default"),
messages=messages,
temperature=kwargs.get("temperature", 0.7),
max_tokens=kwargs.get("max_tokens", 2000),
**{k: v for k, v in kwargs.items() if k not in ["temperature", "max_tokens"]}
)
content = response.choices[0].message.content
usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens
}
# Estimate cost (placeholder - adjust based on actual pricing)
cost = self._estimate_cost(usage["total_tokens"])
self.total_cost += cost
return LLMResponse(
provider=self.name,
content=content,
usage=usage,
cost=cost,
latency=0, # Will be set by caller
model=self.config.get("model", "default")
)
async def _generate_gemini(
self,
prompt: str,
system_prompt: Optional[str],
**kwargs
) -> LLMResponse:
"""Generate using Google Gemini API"""
full_prompt = prompt
if system_prompt:
full_prompt = f"{system_prompt}\n\n{prompt}"
        # Run the synchronous Gemini call in a thread executor so it doesn't block the event loop
        loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None,
lambda: self.client.generate_content(
full_prompt,
generation_config={
"temperature": kwargs.get("temperature", 0.7),
"max_output_tokens": kwargs.get("max_tokens", 2000),
}
)
)
content = response.text
        usage = {
            "prompt_tokens": 0,  # token counts are not extracted from the Gemini response here
            "completion_tokens": 0,
            "total_tokens": len(content.split())  # rough whitespace-split estimate
        }
cost = self._estimate_cost(usage["total_tokens"])
self.total_cost += cost
return LLMResponse(
provider=self.name,
content=content,
usage=usage,
cost=cost,
latency=0,
model=self.config.get("model", "gemini-pro")
)
async def _generate_claude(
self,
prompt: str,
system_prompt: Optional[str],
**kwargs
) -> LLMResponse:
"""Generate using Anthropic Claude API"""
response = await self.client.messages.create(
model=self.config.get("model", "claude-3-5-sonnet-20241022"),
system=system_prompt or "",
messages=[{"role": "user", "content": prompt}],
temperature=kwargs.get("temperature", 0.7),
max_tokens=kwargs.get("max_tokens", 2000)
)
content = response.content[0].text
usage = {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
"total_tokens": response.usage.input_tokens + response.usage.output_tokens
}
cost = self._estimate_cost(usage["total_tokens"])
self.total_cost += cost
return LLMResponse(
provider=self.name,
content=content,
usage=usage,
cost=cost,
latency=0,
model=self.config.get("model", "claude-3-5-sonnet-20241022")
)
def _estimate_cost(self, total_tokens: int) -> float:
"""
Estimate cost based on token count
This is a placeholder - update with actual pricing for each provider
"""
# Example pricing (per 1M tokens):
pricing = {
"euron": 0.50,
"deepseek": 0.10,
"gemini": 0.25,
"claude": 3.00
}
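        # Illustrative arithmetic (the rates above are placeholders, not real price sheets):
        # e.g. 12,000 tokens on "deepseek" -> (12_000 / 1_000_000) * 0.10 = $0.0012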
rate = pricing.get(self.name, 1.0)
return (total_tokens / 1_000_000) * rate
def get_health(self) -> ProviderHealth:
"""Get provider health metrics"""
success_rate = (
self.successful_calls / self.total_calls
if self.total_calls > 0
else 0.0
)
return ProviderHealth(
provider=self.name,
status=self.status.value,
success_rate=success_rate,
total_calls=self.total_calls,
successful_calls=self.successful_calls,
failed_calls=self.failed_calls,
total_cost=self.total_cost,
last_success=self.last_success,
last_failure=self.last_failure,
circuit_breaker_state=self.circuit_breaker.state.value
)
class LLMManager:
"""
Production-grade LLM manager with automatic fallback
Features:
- Multi-provider fallback (Euri -> Deepseek -> Gemini -> Claude)
- Circuit breaker per provider
- Rate limiting
- Health checks
- Cost tracking
- Provider caching
"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.providers: List[LLMProviderClient] = []
self.provider_map: Dict[str, LLMProviderClient] = {}
self.last_successful_provider: Optional[str] = None
self.total_requests = 0
self.fallback_count = 0
# Initialize providers in priority order
self._init_providers()
logger.info(
"LLM Manager initialized",
providers=[p.name for p in self.providers]
)
def _init_providers(self):
"""Initialize all configured providers"""
provider_configs = self.config.get("llm", {}).get("providers", [])
        # Sort by priority (lower value = higher priority) without mutating the caller's config
        provider_configs = sorted(provider_configs, key=lambda x: x.get("priority", 999))
for provider_config in provider_configs:
if not provider_config.get("enabled", True):
continue
name = provider_config["name"]
try:
# Create circuit breaker for this provider
cb_config = self.config.get("llm", {}).get("circuit_breaker", {})
circuit_breaker = CircuitBreaker(
failure_threshold=cb_config.get("failure_threshold", 5),
timeout=cb_config.get("timeout", 60),
recovery_threshold=cb_config.get("recovery_threshold", 3),
name=f"{name}_circuit"
)
# Create rate limiter for this provider
rl_config = self.config.get("llm", {}).get("rate_limit", {})
rate_limiter = RateLimiter(
calls_per_minute=rl_config.get("calls_per_minute", 100),
burst_limit=rl_config.get("burst_limit", 10)
)
# Create provider client
provider = LLMProviderClient(
name=name,
config=provider_config,
circuit_breaker=circuit_breaker,
rate_limiter=rate_limiter
)
self.providers.append(provider)
self.provider_map[name] = provider
logger.info(
f"Initialized provider {name}",
priority=provider_config.get("priority")
)
except Exception as e:
logger.error(
f"Failed to initialize provider {name}",
error=str(e)
)
async def generate(
self,
prompt: str,
system_prompt: Optional[str] = None,
force_provider: Optional[str] = None,
**kwargs
) -> LLMResponse:
"""
Generate completion with automatic fallback
Args:
prompt: User prompt
system_prompt: System prompt (optional)
force_provider: Force specific provider (optional)
**kwargs: Additional generation parameters
Returns:
LLMResponse from successful provider
Raises:
Exception if all providers fail
"""
self.total_requests += 1
# If specific provider requested, try only that one
if force_provider:
provider = self.provider_map.get(force_provider)
if not provider:
raise ValueError(f"Provider {force_provider} not found")
logger.info(
f"Using forced provider: {force_provider}",
provider=force_provider
)
return await provider.generate(prompt, system_prompt, **kwargs)
# Try providers in order, with fallback
providers_to_try = self.providers.copy()
# Optimization: Try last successful provider first
if self.last_successful_provider:
last_provider = self.provider_map.get(self.last_successful_provider)
if last_provider and last_provider.status == ProviderStatus.HEALTHY:
providers_to_try.remove(last_provider)
providers_to_try.insert(0, last_provider)
errors = []
for i, provider in enumerate(providers_to_try):
# Skip if circuit breaker is open
if provider.status == ProviderStatus.CIRCUIT_OPEN:
logger.warning(
f"Skipping {provider.name} - circuit breaker open",
provider=provider.name
)
continue
try:
logger.info(
f"Attempting provider {provider.name}",
provider=provider.name,
attempt=i + 1,
total=len(providers_to_try)
)
response = await provider.generate(prompt, system_prompt, **kwargs)
# Success!
self.last_successful_provider = provider.name
if i > 0:
self.fallback_count += 1
logger.info(
f"Fallback successful to {provider.name}",
provider=provider.name,
fallback_position=i + 1
)
return response
except Exception as e:
error_msg = f"{provider.name}: {str(e)}"
errors.append(error_msg)
logger.warning(
f"Provider {provider.name} failed, trying next",
provider=provider.name,
error=str(e)
)
# Continue to next provider
continue
        # All providers failed (or were skipped because their circuit breakers were open)
        logger.error(
            "All LLM providers failed",
            errors=errors,
            total_providers=len(providers_to_try)
        )
        raise Exception(
            f"All LLM providers failed. Errors: {'; '.join(errors) or 'all circuit breakers open'}"
        )
async def health_check(self) -> Dict[str, Any]:
"""Get health status of all providers"""
health_statuses = {}
for provider in self.providers:
health_statuses[provider.name] = provider.get_health()
return {
"total_requests": self.total_requests,
"fallback_count": self.fallback_count,
"fallback_rate": (
self.fallback_count / self.total_requests
if self.total_requests > 0
else 0.0
),
"providers": health_statuses,
"last_successful_provider": self.last_successful_provider
}
def get_usage_metrics(self) -> LLMUsageMetrics:
"""Get aggregated usage metrics across all providers"""
total_calls = sum(p.total_calls for p in self.providers)
successful_calls = sum(p.successful_calls for p in self.providers)
failed_calls = sum(p.failed_calls for p in self.providers)
total_cost = sum(p.total_cost for p in self.providers)
per_provider_metrics = {
p.name: {
"calls": p.total_calls,
"success_rate": (
p.successful_calls / p.total_calls
if p.total_calls > 0
else 0.0
),
"cost": p.total_cost
}
for p in self.providers
}
return LLMUsageMetrics(
total_requests=self.total_requests,
total_calls=total_calls,
successful_calls=successful_calls,
failed_calls=failed_calls,
total_cost=total_cost,
fallback_count=self.fallback_count,
per_provider=per_provider_metrics
)
    async def close(self):
        """Cleanup resources"""
        logger.info("Closing LLM Manager")
        # Close underlying SDK clients that expose an async close() (OpenAI-compatible, Anthropic)
        for provider in self.providers:
            close_fn = getattr(provider.client, "close", None)
            if close_fn and asyncio.iscoroutinefunction(close_fn):
                await close_fn()
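# Minimal usage sketch (illustrative only; the config shape is sketched near the top of this
# module, and "euron" as a force_provider value assumes that provider is configured):
#
# async def _demo():
#     manager = LLMManager(config)
#     # Normal call: walks the priority-ordered provider list with automatic fallback
#     response = await manager.generate(
#         prompt="Summarize the circuit breaker pattern in two sentences.",
#         system_prompt="You are a concise technical writer.",
#         temperature=0.3,
#         max_tokens=300,
#     )
#     print(response.provider, response.content)
#     # Pin a single provider (no fallback), then inspect aggregate health
#     forced = await manager.generate("ping", force_provider="euron")
#     print(forced.provider)
#     print(await manager.health_check())
#     await manager.close()
#
# if __name__ == "__main__":
#     asyncio.run(_demo())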