Module agents.llm.fallback_llm

Classes

class FallbackLLM (providers: List[LLM],
temporary_disable_sec: float = 60.0,
permanent_disable_after_attempts: int = 3,
latency_threshold_ms: float | None = None,
consecutive_latency_hits: int = 3)
Expand source code
class FallbackLLM(LLM, FallbackBase):
    """LLM wrapper that fails over to backup providers on errors or latency degradation.

    Requests are delegated to ``self.active_provider``. When that provider
    raises during ``chat`` or emits an ``"error"`` event, the wrapper asks
    ``FallbackBase._switch_provider`` to pick another provider. Before each
    ``chat`` call ``check_recovery()`` is invoked so that the base class can
    attempt recovery of higher-priority providers.

    The ``"error"`` listener is attached only to the currently active
    provider and is moved whenever the active provider changes.
    """

    def __init__(
        self,
        providers: List[LLM],
        temporary_disable_sec: float = 60.0,
        permanent_disable_after_attempts: int = 3,
        latency_threshold_ms: Optional[float] = None,
        consecutive_latency_hits: int = 3,
    ):
        """Initialise both base classes and register event listeners.

        Args:
            providers: Candidate LLM providers, forwarded to ``FallbackBase``
                (presumably in priority order -- confirm against FallbackBase).
            temporary_disable_sec: Forwarded unchanged to ``FallbackBase``.
            permanent_disable_after_attempts: Forwarded unchanged to
                ``FallbackBase``.
            latency_threshold_ms: Forwarded to ``FallbackBase``. When ``None``
                no ``TURN_METRICS_ADDED`` listener is registered, so latency
                is never recorded by this wrapper.
            consecutive_latency_hits: Forwarded unchanged to ``FallbackBase``.
        """
        LLM.__init__(self)
        # Strong references to fire-and-forget tasks. The event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._background_tasks = set()
        FallbackBase.__init__(
            self,
            providers,
            "LLM",
            temporary_disable_sec=temporary_disable_sec,
            permanent_disable_after_attempts=permanent_disable_after_attempts,
            latency_threshold_ms=latency_threshold_ms,
            consecutive_latency_hits=consecutive_latency_hits,
        )
        self._setup_event_listeners()
        self._setup_latency_listener()

    def _spawn_tracked_task(self, coro):
        """Schedule *coro* as a task and hold a reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _setup_event_listeners(self):
        """Subscribe to ``"error"`` events of the currently active provider."""
        self.active_provider.on("error", self._on_provider_error)

    def _setup_latency_listener(self):
        """Subscribe to global turn metrics, but only if a latency threshold is set."""
        if self.latency_threshold_ms is None:
            return
        global_event_emitter.on("TURN_METRICS_ADDED", self._on_turn_metrics)

    def _on_turn_metrics(self, event: dict):
        """Record the turn's ``ttft`` metric unless the turn involved tool calls.

        Turns whose metrics list any ``functionToolsCalled`` or
        ``mcpToolMetrics`` entries are skipped, as are turns without a
        ``ttft`` value.
        """
        metrics = event.get("metrics") or {}
        function_tools = metrics.get("functionToolsCalled") or []
        mcp_tools = metrics.get("mcpToolMetrics") or []
        if function_tools or mcp_tools:
            return

        ttft = metrics.get("ttft")
        if ttft is None:
            return
        # NOTE(review): ttft is presumably in milliseconds to match
        # latency_threshold_ms -- confirm against the metrics producer.
        self._spawn_tracked_task(self._record_latency(float(ttft)))

    def _on_provider_error(self, error_msg):
        """Handle an ``"error"`` event emitted by the active provider.

        The active provider is captured synchronously so the scheduled handler
        acts on the provider that was active when the event fired.
        """
        failed_p = self.active_provider
        self._spawn_tracked_task(self._handle_async_error(str(error_msg), failed_p))

    async def _handle_async_error(self, error_msg: str, failed_provider: Any):
        """Switch away from *failed_provider*, then re-emit the error on this wrapper."""
        await self._switch_provider(f"Async Error: {error_msg}", failed_provider=failed_provider)
        self.emit("error", error_msg)

    async def _switch_provider(self, reason: str, failed_provider: Any = None):
        """Delegate the switch to ``FallbackBase`` and move the error listener.

        Args:
            reason: Description of why the switch is requested.
            failed_provider: Provider that failed; defaults to the active one.

        Returns:
            ``True`` if the base class reported a successful switch, else ``False``.
        """
        provider_to_cleanup = failed_provider if failed_provider else self.active_provider
        try:
            provider_to_cleanup.off("error", self._on_provider_error)
        except Exception:
            # Best effort: the listener may not be registered on this provider.
            pass

        active_before = self.active_provider
        switched = await super()._switch_provider(reason, failed_provider)
        active_after = self.active_provider

        if not switched:
            return False
        if active_before != active_after:
            self.active_provider.on("error", self._on_provider_error)
        return True

    async def chat(self, messages: ChatContext, **kwargs) -> AsyncIterator[LLMResponse]:
        """Stream a chat response, failing over to other providers on error.

        Calls ``check_recovery()`` first, then loops: the active provider's
        ``chat`` stream is forwarded chunk by chunk. If it raises, a provider
        switch is attempted, an ``"error"`` event carrying ``str(exc)`` is
        emitted, and the request is retried on the new active provider.

        Chunks already yielded by a provider that later fails are not
        retracted; the retry starts the request again from the beginning.

        Raises:
            Exception: The provider's exception, when no switch was possible.
        """
        self.check_recovery()

        while True:
            current_provider = self.active_provider
            try:
                async for chunk in current_provider.chat(messages, **kwargs):
                    yield chunk
                return
            except Exception as e:
                switched = await self._switch_provider(str(e), failed_provider=current_provider)
                self.emit("error", str(e))
                if not switched:
                    raise

    async def cancel_current_generation(self) -> None:
        """Cancel the in-flight generation on the active provider."""
        await self.active_provider.cancel_current_generation()

    async def aclose(self) -> None:
        """Remove the metrics listener and close every provider and the base class.

        Every provider is closed even if an earlier one fails. The first
        provider error is re-raised once the base-class cleanup has completed.
        """
        if self.latency_threshold_ms is not None:
            try:
                global_event_emitter.off("TURN_METRICS_ADDED", self._on_turn_metrics)
            except Exception:
                # Best effort: the listener may already have been removed.
                pass

        first_error = None
        for p in self.providers:
            try:
                await p.aclose()
            except Exception as exc:
                # Keep closing the remaining providers; report the first failure.
                if first_error is None:
                    first_error = exc
        await super().aclose()
        if first_error is not None:
            raise first_error

LLM wrapper that automatically fails over to backup providers on errors or latency degradation, and attempts recovery of higher-priority ones.

Initialize the LLM base class.

Ancestors

Methods

async def chat(self,
messages: ChatContext,
**kwargs) ‑> AsyncIterator[LLMResponse]
Expand source code
async def chat(self, messages: ChatContext, **kwargs) -> AsyncIterator[LLMResponse]:
    """Stream a chat response, failing over to other providers on error.

    Calls ``check_recovery()`` first, then loops: the active provider's
    ``chat`` stream is forwarded chunk by chunk. If it raises, a provider
    switch is attempted, an ``"error"`` event carrying ``str(exc)`` is
    emitted, and the request is retried on the new active provider.

    Chunks already yielded by a provider that later fails are not
    retracted; the retry starts the request again from the beginning.

    Raises:
        Exception: The provider's exception, when no switch was possible.
    """
    self.check_recovery()

    while True:
        current_provider = self.active_provider
        try:
            async for chunk in current_provider.chat(messages, **kwargs):
                yield chunk
            return
        except Exception as e:
            switched = await self._switch_provider(str(e), failed_provider=current_provider)
            self.emit("error", str(e))
            if not switched:
                # Bare raise re-raises the active exception unchanged.
                raise

Attempts to chat with the current provider. Loops until one provider succeeds or all of them fail. Checks for recovery of primary providers before starting.

Inherited members