Skip to main content
Glama
test_registry_fallback.py (4.79 kB)
"""Tests for the fallback-chain helpers on ``ModelProviderRegistry``.

The tests register a stub provider under ``ProviderType.CUSTOM`` and then
exercise ``call_with_fallback``, ``get_telemetry`` and
``_auggie_fallback_chain`` on the registry class.
"""

import json
import os
from pathlib import Path
import types

import pytest

from providers.registry import ModelProviderRegistry as R
from providers.base import ModelProvider, ProviderType


class DummyProvider(ModelProvider):
    """In-memory provider stub that never talks to a real backend.

    The set of accepted model names is driven entirely by the class attribute
    ``SUPPORTED_MODELS``; a test that needs a different model list subclasses
    this stub and overrides only that attribute.
    """

    SUPPORTED_MODELS = {"modelA": None, "modelB": None}

    def __init__(self, api_key: str = "", base_url: str | None = None):
        # ``base_url`` is accepted for signature compatibility and ignored.
        self._api_key = api_key

    def get_provider_type(self) -> ProviderType:
        """Return ``ProviderType.CUSTOM``, the slot the tests register under."""
        return ProviderType.CUSTOM

    def validate_model_name(self, model_name: str) -> bool:
        """Return True when ``model_name`` is a key of ``SUPPORTED_MODELS``."""
        return model_name in self.SUPPORTED_MODELS

    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return the supported model names; ``respect_restrictions`` is ignored."""
        return list(self.SUPPORTED_MODELS.keys())

    def get_capabilities(self, model_name: str):
        """Return a bare namespace exposing only ``model_name``."""
        return types.SimpleNamespace(model_name=model_name)

    def get_preferred_model(self, category, allowed_models):
        """Express no preference (always ``None``)."""
        return None

    def generate_content(self, prompt, model_name, **kwargs):
        """Not implemented; the tests in this module never generate content."""
        raise NotImplementedError

    def count_tokens(self, text: str, model_name: str) -> int:
        """Return the character count of ``text`` as a stand-in token count."""
        return len(text)


def test_fallback_chain_order_and_retry(monkeypatch):
    """A failure on the first model falls through to the next one in the chain.

    ``call_fn`` raises for ``modelA`` and succeeds for ``modelB``; the test
    checks the returned response, the per-model call counts, and the telemetry
    counters reported by ``R.get_telemetry()``.
    """
    R.reset_for_testing()

    # A single stub provider that serves both models in the chain.
    R.register_provider(ProviderType.CUSTOM, DummyProvider)

    # The returned instance is not needed here. The call is kept from the
    # original setup -- presumably it makes the registry instantiate the
    # provider before the fallback runs (TODO confirm against the registry).
    R.get_provider(ProviderType.CUSTOM)

    calls = {"modelA": 0, "modelB": 0}

    def call_fn(model):
        # Count every attempt, fail on modelA, succeed on anything else.
        calls[model] += 1
        if model == "modelA":
            raise RuntimeError("fail A")
        return types.SimpleNamespace(
            content="ok",
            usage={"input_tokens": 10, "output_tokens": 5},
        )

    # Pin the fallback order to [modelA, modelB] regardless of configuration.
    monkeypatch.setattr(R, "_auggie_fallback_chain", lambda c: ["modelA", "modelB"])

    resp = R.call_with_fallback(None, call_fn)

    assert resp.content == "ok"
    # Separate asserts so a failure names the model whose count is wrong.
    assert calls["modelA"] == 1
    assert calls["modelB"] == 1

    tel = R.get_telemetry()
    assert tel["modelA"]["failure"] >= 1
    assert tel["modelB"]["success"] >= 1
    assert tel["modelB"]["input_tokens"] == 10
    assert tel["modelB"]["output_tokens"] == 5


def test_fallback_chain_from_config(tmp_path, monkeypatch):
    """The fallback order written to ``auggie-config.json`` drives the chain."""
    # Write a config file that declares a fallback order per category key.
    cfg = {"auggie": {"fallback": {"chat": ["m1", "m2"], "reasoning": ["m3"]}}}
    p = tmp_path / "auggie-config.json"
    p.write_text(json.dumps(cfg), encoding="utf-8")

    # Imported inside the test, as in the original module.
    from auggie.config import load_auggie_config

    load_auggie_config(p)

    # Register a provider that claims to support every configured model.
    R.reset_for_testing()

    class Prov(DummyProvider):
        SUPPORTED_MODELS = {"m1": None, "m2": None, "m3": None}

    R.register_provider(ProviderType.CUSTOM, Prov)

    # Per the original author's note, FAST_RESPONSE maps to the "chat" key.
    from tools.models import ToolModelCategory as Cat

    chain = R._auggie_fallback_chain(Cat.FAST_RESPONSE)
    assert chain[:2] == ["m1", "m2"]


def test_noop_when_no_config(monkeypatch):
    """With ``path=None`` passed to the loader, the chain is still non-empty."""
    from auggie.config import load_auggie_config

    load_auggie_config(path=None)

    R.reset_for_testing()

    class Prov(DummyProvider):
        SUPPORTED_MODELS = {"x": None}

    R.register_provider(ProviderType.CUSTOM, Prov)

    # Only the shape is checked: a list with at least one entry.
    chain = R._auggie_fallback_chain(None)
    assert isinstance(chain, list)
    assert len(chain) >= 1

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Zazzles2908/EX_AI-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.