Module agents.inference

VideoSDK Inference Gateway Plugins

Lightweight STT, TTS, and Realtime clients that connect to VideoSDK's Inference Gateway. All heavy lifting (API connections, resampling, etc.) is handled server-side.

Usage

from videosdk.inference import STT, TTS, LLM, Realtime, Denoise

Quick start with factory methods

stt = STT.google()
tts = TTS.sarvam(speaker="anushka")
realtime = Realtime.gemini(model="gemini-2.0-flash-exp")
denoise = Denoise.sanas()

Use with CascadingPipeline

pipeline = CascadingPipeline(stt=stt, llm=llm, tts=tts, denoise=denoise)

Use with RealTimePipeline

pipeline = RealTimePipeline(model=realtime)

Sub-modules

agents.inference.denoise
agents.inference.llm

VideoSDK Inference Gateway LLM Plugin …

agents.inference.realtime
agents.inference.stt
agents.inference.tts

Classes

class Denoise (*,
provider: str,
model_id: str,
sample_rate: int = 48000,
channels: int = 1,
chunk_ms: int = 10,
config: Dict[str, Any] | None = None,
base_url: str | None = None)
Expand source code
class Denoise(BaseDenoise):
    """
    VideoSDK Inference Gateway Denoise Plugin.

    A lightweight noise cancellation client that connects to VideoSDK's Inference Gateway.
    Supports SANAS and AI-Coustics noise cancellation through a unified interface.

    Example:
        # Using factory methods (recommended)
        denoise = Denoise.aicoustics(model_id="sparrow-xxs-48khz")

        # Using generic constructor
        denoise = Denoise(
            provider="aicoustics",
            model_id="sparrow-xxs-48khz",
            config={"sample_rate": 48000}
        )

        # Use in pipeline
        pipeline = CascadingPipeline(
            stt=DeepgramSTT(sample_rate=48000),
            llm=GoogleLLM(),
            tts=ElevenLabsTTS(),
            vad=SileroVAD(input_sample_rate=48000),
            turn_detector=TurnDetector(),
            denoise=denoise
        )
    """

    def __init__(
        self,
        *,
        provider: str,
        model_id: str,
        sample_rate: int = 48000,
        channels: int = 1,
        chunk_ms: int = 10,
        config: Dict[str, Any] | None = None,
        base_url: str | None = None,
    ) -> None:
        """
        Initialize the VideoSDK Inference Denoise plugin.

        Args:
            provider: Denoise provider name (e.g., "aicoustics")
            model_id: Model identifier for the provider
            sample_rate: Audio sample rate in Hz (default: 48000)
            channels: Number of audio channels (default: 1 for mono)
            chunk_ms: Duration of each audio chunk sent upstream, in
                milliseconds (default: 10)
            config: Provider-specific configuration dictionary
            base_url: Custom inference gateway URL (default: production gateway)

        Raises:
            ValueError: If the VIDEOSDK_AUTH_TOKEN environment variable is unset.
        """
        super().__init__()

        self._videosdk_token = os.getenv("VIDEOSDK_AUTH_TOKEN")
        if not self._videosdk_token:
            raise ValueError("VIDEOSDK_AUTH_TOKEN environment variable must be set")

        self.provider = provider
        self.model_id = model_id
        self.sample_rate = sample_rate
        self.channels = channels
        self.chunk_ms = chunk_ms
        self.config = config or {}
        self.base_url = base_url or VIDEOSDK_INFERENCE_URL

        # WebSocket state
        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._ws_task: Optional[asyncio.Task] = None
        self._config_sent: bool = False
        self.connected: bool = False
        self._shutting_down: bool = False

        # Audio
        self._send_buffer: bytearray = bytearray()
        self._audio_buffer: asyncio.Queue = asyncio.Queue(maxsize=20)
        # Created lazily inside denoise() so it binds to the running event loop.
        self._connect_lock: asyncio.Lock | None = None

        # Latency tracking
        # Maps send sequence number → send timestamp (monotonic)
        self._pending_chunks: dict[int, float] = {}
        self._send_seq: int = 0
        self._recv_seq: int = 0

        # Rolling window of round-trip latencies (ms)
        self._latency_window: deque[float] = deque(maxlen=LATENCY_WINDOW)

        # Stats
        self._stats = {
            "chunks_sent": 0,
            "bytes_sent": 0,
            "chunks_received": 0,
            "bytes_received": 0,
            "errors": 0,
            "reconnections": 0,
            "buffer_drops": 0,
            # Latency stats (ms)
            "latency_last_ms": 0.0,
            "latency_avg_ms": 0.0,
            "latency_min_ms": float("inf"),
            "latency_max_ms": 0.0,
            "latency_p95_ms": 0.0,
        }

        logger.info(
            f"[InferenceDenoise] Initialized: provider={provider}, "
            f"model={model_id}, sample_rate={sample_rate}Hz, channels={channels}"
        )

    # ==================== Factory Methods ====================

    @staticmethod
    def aicoustics(
        *,
        model_id: str = "sparrow-xxs-48khz",
        sample_rate: int = 48000,
        channels: int = 1,
        base_url: str | None = None,
    ) -> "Denoise":
        """
        Create a Denoise instance configured for AI-Coustics.

        Args:
            model_id: AI-Coustics model (default: "sparrow-xxs-48khz")
                Sparrow family (human-to-human, 48kHz):
                - "sparrow-xxs-48khz": Ultra-fast, 10ms latency, 1MB
                - "sparrow-s-48khz": Small, 30ms latency, 8.96MB
                - "sparrow-l-48khz": Large, best quality, 30ms latency, 35.1MB

                Quail family (human-to-machine, voice AI, 16kHz):
                - "quail-vf-l-16khz": Voice focus + STT optimization, 35MB
                - "quail-l-16khz": General purpose, 35MB
                - "quail-s-16khz": Faster, 8.88MB

            sample_rate: Audio sample rate in Hz
                - Sparrow models: 48000 Hz (default)
                - Quail models: 16000 Hz
            channels: Number of audio channels (default: 1 for mono)
            base_url: Custom inference gateway URL

        Returns:
            Configured Denoise instance for AI-Coustics

        Example:
            >>> # Ultra-fast for real-time calls
            >>> denoise = Denoise.aicoustics(model_id="sparrow-xxs-48khz")
            >>>
            >>> # Best quality for recordings
            >>> denoise = Denoise.aicoustics(model_id="sparrow-l-48khz")
            >>>
            >>> # Voice AI / STT optimization (16kHz)
            >>> denoise = Denoise.aicoustics(
            ...     model_id="quail-vf-l-16khz",
            ...     sample_rate=16000
            ... )
        """

        return Denoise(
            provider="aicoustics",
            model_id=model_id,
            sample_rate=sample_rate,
            channels=channels,
            chunk_ms=10,
            config={},
            base_url=base_url or VIDEOSDK_INFERENCE_URL,
        )

    @staticmethod
    def sanas(
        *,
        model_id: str = "VI_G_NC3.0",
        sample_rate: int = 16000,
        channels: int = 1,
        base_url: str | None = None,
    ) -> "Denoise":
        """
        Create a Denoise instance configured for Sanas.

        Args:
            model_id: Sanas model (default: "VI_G_NC3.0")

            sample_rate: Audio sample rate in Hz
                - VI_G_NC3.0 - 16000 for noise cancellation
            channels: Number of audio channels (default: 1 for mono)
            base_url: Custom inference gateway URL

        Returns:
            Configured Denoise instance for Sanas

        Example:
            >>> # Default Sanas noise-cancellation model
            >>> denoise = Denoise.sanas()
            >>>
            >>> # Explicit model and sample rate (16kHz)
            >>> denoise = Denoise.sanas(
            ...     model_id="VI_G_NC3.0",
            ...     sample_rate=16000
            ... )
        """

        return Denoise(
            provider="sanas",
            model_id=model_id,
            sample_rate=sample_rate,
            channels=channels,
            chunk_ms=20,
            config={},
            base_url=base_url or VIDEOSDK_INFERENCE_URL,
        )

    # ==================== Latency Helpers ====================

    def _record_latency(self, latency_ms: float) -> None:
        """Update all latency stats with a new measurement."""
        self._latency_window.append(latency_ms)

        self._stats["latency_last_ms"] = round(latency_ms, 2)
        self._stats["latency_min_ms"] = round(
            min(self._stats["latency_min_ms"], latency_ms), 2
        )
        self._stats["latency_max_ms"] = round(
            max(self._stats["latency_max_ms"], latency_ms), 2
        )
        self._stats["latency_avg_ms"] = round(
            sum(self._latency_window) / len(self._latency_window), 2
        )

        # p95 over rolling window
        if len(self._latency_window) >= 2:
            sorted_w = sorted(self._latency_window)
            p95_idx = int(len(sorted_w) * 0.95)
            self._stats["latency_p95_ms"] = round(sorted_w[p95_idx], 2)

        logger.debug(
            f"[InferenceDenoise] Latency: last={latency_ms:.1f}ms  "
            f"avg={self._stats['latency_avg_ms']}ms  "
            f"min={self._stats['latency_min_ms']}ms  "
            f"max={self._stats['latency_max_ms']}ms  "
            f"p95={self._stats['latency_p95_ms']}ms"
        )

    def _reset_latency_state(self) -> None:
        """Clear pending chunk map on reconnect so stale timestamps don't pollute stats."""
        self._pending_chunks.clear()
        self._send_seq = 0
        self._recv_seq = 0

    # ==================== Core Denoise ====================

    def _enqueue_denoised(self, data: bytes) -> None:
        """
        Queue one denoised audio piece for pickup by denoise().

        If the queue is full, the oldest entry is dropped first (and counted
        in stats) so that fresh audio is always preferred over stale audio.
        """
        if self._audio_buffer.full():
            try:
                self._audio_buffer.get_nowait()
                self._stats["buffer_drops"] += 1
            except asyncio.QueueEmpty:
                pass
        try:
            self._audio_buffer.put_nowait(data)
        except asyncio.QueueFull:
            pass

    async def denoise(self, audio_frames: bytes, **kwargs: Any) -> bytes:
        """
        Denoise one frame of PCM audio.

        Buffers incoming audio, streams fixed-size chunks to the gateway over
        WebSocket, and returns whatever denoised audio has arrived so far.
        Fails open: whenever the connection is not ready or an error occurs,
        the original audio is returned unchanged.

        Args:
            audio_frames: Raw 16-bit PCM audio bytes for one frame.

        Returns:
            Denoised audio of at most the input frame's length, or the
            original audio if no denoised output is available yet.
        """
        try:
            if self._connect_lock is None:
                self._connect_lock = asyncio.Lock()

            frame_size = len(audio_frames)

            if self._shutting_down:
                return audio_frames

            if not self._ws or self._ws.closed:
                # Another call is already connecting — don't block the audio path.
                if self._connect_lock.locked():
                    return audio_frames

                async with self._connect_lock:
                    if not self._ws or self._ws.closed:
                        try:
                            await self._connect_ws()
                            self.connected = True
                            self._stats["errors"] = 0
                            await self._send_config()

                            if not self._ws_task or self._ws_task.done():
                                self._ws_task = asyncio.create_task(
                                    self._listen_for_responses()
                                )
                            logger.info(
                                f"[InferenceDenoise] Ready (provider={self.provider})"
                            )
                        except Exception as e:
                            logger.error(f"[InferenceDenoise] Setup failed: {e}")
                            self._ws = None
                            self._config_sent = False
                            self._send_buffer.clear()
                            return audio_frames

            if not self._config_sent:
                return audio_frames

            # One chunk = chunk_ms worth of 16-bit samples across all channels.
            chunk_size = (self.chunk_ms * self.sample_rate // 1000) * self.channels * 2
            # Buffer the frame exactly once per call. (Previously the frame
            # that triggered a reconnect was also buffered inside the connect
            # branch above, duplicating audio after every reconnect.)
            self._send_buffer.extend(audio_frames)

            while len(self._send_buffer) >= chunk_size:
                chunk = bytes(self._send_buffer[:chunk_size])
                del self._send_buffer[:chunk_size]
                try:
                    await self._send_audio(chunk)
                except Exception as e:
                    logger.error(f"[InferenceDenoise] Send failed: {e} — resetting")
                    await asyncio.sleep(0.5)
                    self._ws = None
                    self._config_sent = False
                    self._send_buffer.clear()
                    self._reset_latency_state()
                    return audio_frames

            # Collect everything the listener has queued so far.
            denoised_chunks = []
            while not self._audio_buffer.empty():
                try:
                    denoised_chunks.append(self._audio_buffer.get_nowait())
                except asyncio.QueueEmpty:
                    break

            if denoised_chunks:
                all_denoised = b"".join(denoised_chunks)
                total = len(all_denoised)
                self._stats["chunks_received"] += len(denoised_chunks)
                self._stats["bytes_received"] += total

                if total > frame_size:
                    # Return exactly one frame; requeue the excess for later calls.
                    excess = all_denoised[frame_size:]
                    for i in range(0, len(excess), frame_size):
                        self._enqueue_denoised(excess[i : i + frame_size])
                    return all_denoised[:frame_size]

                return all_denoised

            return audio_frames

        except Exception as e:
            logger.error(f"[InferenceDenoise] Error in denoise: {e}", exc_info=True)
            self._stats["errors"] += 1
            return audio_frames

    # ==================== WebSocket ====================

    async def _connect_ws(self) -> None:
        """
        Open the WebSocket connection to the inference gateway.

        Creates the aiohttp session on first use and resets per-connection
        state (config flag, send buffer, latency tracking) on success.

        Raises:
            Exception: Propagates any connection failure to the caller.
        """
        try:
            if self._shutting_down:
                return

            if not self._session or self._session.closed:
                self._session = aiohttp.ClientSession()

            ws_url = (
                f"{self.base_url}/v1/denoise"
                f"?provider={self.provider}"
                f"&secret={self._videosdk_token}"
                f"&modelId={self.model_id}"
            )

            logger.info(
                f"[InferenceDenoise] Connecting to {self.base_url} "
                f"(provider={self.provider}, model={self.model_id})"
            )

            self._ws = await self._session.ws_connect(
                ws_url, timeout=aiohttp.ClientTimeout(total=10)
            )
            if self._shutting_down:
                await self._ws.close()
                return
            self._config_sent = False
            self._send_buffer.clear()
            self._reset_latency_state()
            logger.info("[InferenceDenoise] Connected successfully")

        except Exception as e:
            logger.error(f"[InferenceDenoise] Connection failed: {e}", exc_info=True)
            raise

    async def _send_config(self) -> None:
        """Send configuration message to the inference server."""
        if not self._ws or self._ws.closed:
            raise ConnectionError("WebSocket not connected")

        config_message = {
            "type": "config",
            "data": {
                "model": self.model_id,
                "sample_rate": self.sample_rate,
                "channels": self.channels,
                **self.config,
            },
        }
        await self._ws.send_str(json.dumps(config_message))
        self._config_sent = True
        logger.info(
            f"[InferenceDenoise] Config sent: "
            f"model={self.model_id}, sample_rate={self.sample_rate}Hz, channels={self.channels}"
        )

    async def _send_audio(self, audio_bytes: bytes) -> None:
        """Send one audio chunk, stamping it with a sequence number for latency tracking."""
        if not self._ws or self._ws.closed:
            raise ConnectionError("WebSocket not connected")

        seq = self._send_seq
        self._send_seq += 1

        # Record send timestamp BEFORE the await so network time is included
        self._pending_chunks[seq] = time.monotonic()

        await self._ws.send_str(
            json.dumps(
                {
                    "type": "audio",
                    "data": base64.b64encode(audio_bytes).decode("utf-8"),
                    "seq": seq,  # server will echo this back if it supports it
                }
            )
        )
        self._stats["chunks_sent"] += 1
        self._stats["bytes_sent"] += len(audio_bytes)

    def _resolve_latency(self, recv_seq: int | None = None) -> None:
        """
        Match a received chunk to a sent chunk and record the round-trip latency.

        If the server echoes 'seq', we match exactly.
        Otherwise we consume the oldest pending timestamp (FIFO approximation).
        """
        now = time.monotonic()

        if recv_seq is not None and recv_seq in self._pending_chunks:
            sent_at = self._pending_chunks.pop(recv_seq)
        elif self._pending_chunks:
            # FIFO: oldest sent chunk corresponds to oldest received chunk
            oldest_seq = min(self._pending_chunks)
            sent_at = self._pending_chunks.pop(oldest_seq)
        else:
            return  # no pending chunk to match

        latency_ms = (now - sent_at) * 1000
        self._record_latency(latency_ms)

    async def _listen_for_responses(self) -> None:
        """Background task to listen for WebSocket responses from the server."""
        if not self._ws:
            return

        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    await self._handle_message(msg.data)
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    # Binary frames carry raw denoised PCM — measure latency (FIFO)
                    self._resolve_latency()
                    self._enqueue_denoised(msg.data)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(
                        f"[InferenceDenoise] WebSocket error: {self._ws.exception()}"
                    )
                    break
                elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
                    logger.info("[InferenceDenoise] WebSocket closed by server")
                    break

        except asyncio.CancelledError:
            logger.debug("[InferenceDenoise] Listener cancelled")
        except Exception as e:
            logger.error(f"[InferenceDenoise] Listener error: {e}", exc_info=True)
        finally:
            # Mark the connection dead so denoise() reconnects on the next frame.
            self._ws = None
            self._config_sent = False
            logger.info("[InferenceDenoise] Listener exited — connection marked dead")

    async def _handle_message(self, raw_message: str) -> None:
        """
        Handle incoming messages from the inference server.

        Args:
            raw_message: Raw JSON message string from server
        """
        try:
            data = json.loads(raw_message)
            msg_type = data.get("type")

            if msg_type == "event":
                event_data = data.get("data", {})
                event_type = event_data.get("eventType")

                if event_type == "DENOISE_AUDIO":
                    audio_data = event_data.get("audio", "")
                    # Echo'd seq from server (optional — works without it too)
                    recv_seq = event_data.get("seq", None)
                    if audio_data:
                        self._resolve_latency(recv_seq)
                        self._enqueue_denoised(base64.b64decode(audio_data))
                # START_SPEECH, END_SPEECH, TRANSCRIPT silently ignored

            elif msg_type == "audio":
                audio_data = data.get("data", "")
                recv_seq = data.get("seq", None)
                if audio_data:
                    self._resolve_latency(recv_seq)
                    self._enqueue_denoised(base64.b64decode(audio_data))

            elif msg_type == "error":
                error_data = data.get("data", {})

                # Safely extract error message
                error_msg = (
                    error_data.get("error")
                    or error_data.get("message")
                    or json.dumps(error_data)
                    or "Unknown error"
                )

                self._stats["errors"] += 1

                logger.error(
                    f"[InferenceDenoise] Server error: {error_msg} "
                    f"(total: {self._stats['errors']})"
                )

                # Force reset connection on first error
                if self._stats["errors"] == 1:
                    self._send_buffer.clear()
                    self._config_sent = False

                    if self._ws and not self._ws.closed:
                        try:
                            await self._ws.close()
                        except Exception:
                            pass

                    self._ws = None

        except json.JSONDecodeError as e:
            logger.error(f"[InferenceDenoise] Failed to parse message: {e}")
        except Exception as e:
            logger.error(
                f"[InferenceDenoise] Message handling error: {e}", exc_info=True
            )

    async def _cleanup_connection(self) -> None:
        """Best-effort stop message + WebSocket close, then reset connection state."""
        if self._ws and not self._ws.closed:
            try:
                await asyncio.wait_for(
                    self._ws.send_str(json.dumps({"type": "stop"})), timeout=1.0
                )
                await asyncio.sleep(0.1)
            except Exception:
                pass
            try:
                await self._ws.close()
            except Exception:
                pass

        self._ws = None
        self._config_sent = False
        self._send_buffer.clear()

    # ==================== Utilities ====================

    def get_stats(self) -> Dict[str, Any]:
        """
        Get processing statistics.

        Returns:
            Dictionary containing processing statistics
        """
        return {
            **self._stats,
            "buffer_size": self._audio_buffer.qsize(),
            "pending_chunks": len(self._pending_chunks),
            "provider": self.provider,
            "model": self.model_id,
            "sample_rate": self.sample_rate,
            "channels": self.channels,
            "connected": self._ws is not None and not self._ws.closed,
        }

    def get_latency_stats(self) -> Dict[str, Any]:
        """Return only the latency-related stats — handy for logging/monitoring."""
        return {
            "last_ms": self._stats["latency_last_ms"],
            "avg_ms": self._stats["latency_avg_ms"],
            "min_ms": self._stats["latency_min_ms"],
            "max_ms": self._stats["latency_max_ms"],
            "p95_ms": self._stats["latency_p95_ms"],
            "samples": len(self._latency_window),
        }

    async def aclose(self) -> None:
        """
        Gracefully shut down: cancel the listener, close the WebSocket and
        session, drain buffered audio, and log final statistics.
        """
        logger.info(
            f"[InferenceDenoise] Closing (provider={self.provider}). "
            f"Final stats: {self.get_stats()}"
        )
        self._shutting_down = True
        # Log final latency summary on close
        lat = self.get_latency_stats()
        if lat["samples"] > 0:
            logger.info(
                f"[InferenceDenoise] Latency summary — "
                f"avg={lat['avg_ms']}ms  p95={lat['p95_ms']}ms  "
                f"min={lat['min_ms']}ms  max={lat['max_ms']}ms  "
                f"over {lat['samples']} samples"
            )

        if self._ws_task and not self._ws_task.done():
            self._ws_task.cancel()
            try:
                await asyncio.wait_for(self._ws_task, timeout=2.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass
            self._ws_task = None

        await self._cleanup_connection()

        if self._session and not self._session.closed:
            await self._session.close()
            self._session = None

        while not self._audio_buffer.empty():
            try:
                self._audio_buffer.get_nowait()
            except asyncio.QueueEmpty:
                break

        await super().aclose()
        logger.info("[InferenceDenoise] Closed successfully")

    @property
    def label(self) -> str:
        """Unique identifier label: provider and model joined under the plugin namespace."""
        return f"videosdk.inference.Denoise.{self.provider}.{self.model_id}"

VideoSDK Inference Gateway Denoise Plugin.

A lightweight noise cancellation client that connects to VideoSDK's Inference Gateway. Supports SANAS and AI-Coustics noise cancellation through a unified interface.

Example

Using factory methods (recommended)

denoise = Denoise.aicoustics(model_id="sparrow-xxs-48khz")

Using generic constructor

denoise = Denoise( provider="aicoustics", model_id="sparrow-xxs-48khz", config={"sample_rate": 48000} )

Use in pipeline

pipeline = CascadingPipeline( stt=DeepgramSTT(sample_rate=48000), llm=GoogleLLM(), tts=ElevenLabsTTS(), vad=SileroVAD(input_sample_rate=48000), turn_detector=TurnDetector(), denoise=denoise )

Initialize the VideoSDK Inference Denoise plugin.

Args

provider
Denoise provider name (e.g., "aicoustics")
model_id
Model identifier for the provider
sample_rate
Audio sample rate in Hz (default: 48000)
channels
Number of audio channels (default: 1 for mono)
config
Provider-specific configuration dictionary
base_url
Custom inference gateway URL (default: production gateway)

Ancestors

  • videosdk.agents.denoise.Denoise
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Static methods

def aicoustics(*,
model_id: str = 'sparrow-xxs-48khz',
sample_rate: int = 48000,
channels: int = 1,
base_url: str | None = None) ‑> Denoise
Expand source code
@staticmethod
def aicoustics(
    *,
    model_id: str = "sparrow-xxs-48khz",
    sample_rate: int = 48000,
    channels: int = 1,
    base_url: str | None = None,
) -> "Denoise":
    """Build a Denoise client preconfigured for the AI-Coustics provider.

    Args:
        model_id: AI-Coustics model name. The Sparrow family (48kHz) targets
            human-to-human audio — "sparrow-xxs-48khz" (ultra-fast, 10ms
            latency, 1MB), "sparrow-s-48khz" (30ms, 8.96MB), and
            "sparrow-l-48khz" (best quality, 30ms, 35.1MB). The Quail family
            (16kHz) targets voice-AI / human-to-machine use —
            "quail-vf-l-16khz" (voice focus + STT optimization, 35MB),
            "quail-l-16khz" (general purpose, 35MB), "quail-s-16khz"
            (faster, 8.88MB).
        sample_rate: Audio sample rate in Hz — 48000 (default) for Sparrow
            models, 16000 for Quail models.
        channels: Number of audio channels (default 1, mono).
        base_url: Optional override for the inference gateway URL.

    Returns:
        A Denoise instance wired to the "aicoustics" provider.

    Example:
        >>> # Ultra-fast for real-time calls
        >>> denoise = Denoise.aicoustics(model_id="sparrow-xxs-48khz")
        >>>
        >>> # Best quality for recordings
        >>> denoise = Denoise.aicoustics(model_id="sparrow-l-48khz")
        >>>
        >>> # Voice AI / STT optimization (16kHz)
        >>> denoise = Denoise.aicoustics(
        ...     model_id="quail-vf-l-16khz",
        ...     sample_rate=16000
        ... )
    """
    gateway = base_url or VIDEOSDK_INFERENCE_URL
    return Denoise(
        provider="aicoustics",
        model_id=model_id,
        sample_rate=sample_rate,
        channels=channels,
        chunk_ms=10,
        config={},
        base_url=gateway,
    )

Create a Denoise instance configured for AI-Coustics.

Args

model_id

AI-Coustics model (default: "sparrow-xxs-48khz") Sparrow family (human-to-human, 48kHz): - "sparrow-xxs-48khz": Ultra-fast, 10ms latency, 1MB - "sparrow-s-48khz": Small, 30ms latency, 8.96MB - "sparrow-l-48khz": Large, best quality, 30ms latency, 35.1MB

Quail family (human-to-machine, voice AI, 16kHz): - "quail-vf-l-16khz": Voice focus + STT optimization, 35MB - "quail-l-16khz": General purpose, 35MB - "quail-s-16khz": Faster, 8.88MB

sample_rate
Audio sample rate in Hz - Sparrow models: 48000 Hz (default) - Quail models: 16000 Hz
channels
Number of audio channels (default: 1 for mono)
base_url
Custom inference gateway URL

Returns

Configured Denoise instance for AI-Coustics

Example

>>> # Ultra-fast for real-time calls
>>> denoise = Denoise.aicoustics(model_id="sparrow-xxs-48khz")
>>>
>>> # Best quality for recordings
>>> denoise = Denoise.aicoustics(model_id="sparrow-l-48khz")
>>>
>>> # Voice AI / STT optimization (16kHz)
>>> denoise = Denoise.aicoustics(
...     model_id="quail-vf-l-16khz",
...     sample_rate=16000
... )
def sanas(*,
model_id: str = 'VI_G_NC3.0',
sample_rate: int = 16000,
channels: int = 1,
base_url: str | None = None) ‑> Denoise
Expand source code
@staticmethod
def sanas(
    *,
    model_id: str = "VI_G_NC3.0",
    sample_rate: int = 16000,
    channels: int = 1,
    base_url: str | None = None,
) -> "Denoise":
    """
    Create a Denoise instance configured for Sanas.

    Args:
        model_id: Sanas model (default: "VI_G_NC3.0")

        sample_rate: Audio sample rate in Hz
            - VI_G_NC3.0 - 16000 for noise cancellation
        channels: Number of audio channels (default: 1 for mono)
        base_url: Custom inference gateway URL

    Returns:
        Configured Denoise instance for Sanas

    Example:
        >>> # Default Sanas noise-cancellation model
        >>> denoise = Denoise.sanas()
        >>>
        >>> # Explicit model and sample rate (16kHz)
        >>> denoise = Denoise.sanas(
        ...     model_id="VI_G_NC3.0",
        ...     sample_rate=16000
        ... )
    """

    return Denoise(
        provider="sanas",
        model_id=model_id,
        sample_rate=sample_rate,
        channels=channels,
        chunk_ms=20,
        config={},
        base_url=base_url or VIDEOSDK_INFERENCE_URL,
    )

Create a Denoise instance configured for Sanas.

Args

model_id
Sanas model (default: "VI_G_NC3.0")
sample_rate
Audio sample rate in Hz - VI_G_NC3.0 - 16000 for noise cancellation
channels
Number of audio channels (default: 1 for mono)
base_url
Custom inference gateway URL

Returns

Configured Denoise instance for Sanas

Example

>>> # Default Sanas noise-cancellation model
>>> denoise = Denoise.sanas()
>>>
>>> # Explicit model and sample rate (16kHz)
>>> denoise = Denoise.sanas(
...     model_id="VI_G_NC3.0",
...     sample_rate=16000
... )

Instance variables

prop label : str
Expand source code
@property
def label(self) -> str:
    """Descriptive label for this Denoise instance (provider + model)."""
    parts = ("videosdk", "inference", "Denoise", self.provider, self.model_id)
    return ".".join(parts)

Get the Denoise provider label

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """
    Cleanup resources.

    Cancels the background websocket reader, closes the connection and
    HTTP session, drains any queued denoised audio, then closes the base
    class. Safe to call once at pipeline shutdown.
    """
    logger.info(
        f"[InferenceDenoise] Closing (provider={self.provider}). "
        f"Final stats: {self.get_stats()}"
    )
    # Flag checked by denoise() so new frames pass through untouched.
    self._shutting_down = True
    # Log final latency summary on close
    lat = self.get_latency_stats()
    if lat["samples"] > 0:
        logger.info(
            f"[InferenceDenoise] Latency summary — "
            f"avg={lat['avg_ms']}ms  p95={lat['p95_ms']}ms  "
            f"min={lat['min_ms']}ms  max={lat['max_ms']}ms  "
            f"over {lat['samples']} samples"
        )

    # Stop the reader task; bounded wait so shutdown can't hang.
    if self._ws_task and not self._ws_task.done():
        self._ws_task.cancel()
        try:
            await asyncio.wait_for(self._ws_task, timeout=2.0)
        except (asyncio.CancelledError, asyncio.TimeoutError):
            pass
        self._ws_task = None

    await self._cleanup_connection()

    if self._session and not self._session.closed:
        await self._session.close()
        self._session = None

    # Discard any denoised audio still queued for delivery.
    while not self._audio_buffer.empty():
        try:
            self._audio_buffer.get_nowait()
        except asyncio.QueueEmpty:
            break

    await super().aclose()
    logger.info("[InferenceDenoise] Closed successfully")

Cleanup resources

async def denoise(self, audio_frames: bytes, **kwargs: Any) ‑> bytes
Expand source code
async def denoise(self, audio_frames: bytes, **kwargs: Any) -> bytes:
    """
    Denoise one frame of PCM audio via the inference gateway.

    The websocket connection is opened lazily on the first call. Whenever
    the connection is unavailable (still connecting, config not yet sent,
    shutting down, or a send failed) the original frame is returned
    unchanged so the audio pipeline never stalls on the gateway.

    Args:
        audio_frames: Raw PCM bytes to denoise (chunk sizing multiplies
            by 2 bytes/sample, i.e. 16-bit samples assumed).
        **kwargs: Additional provider-specific arguments (unused here).

    Returns:
        Denoised audio (at most one input frame's worth), or the input
        unchanged when no denoised audio is available yet.
    """
    try:
        # Create the lock lazily so it is bound to the running event loop.
        if self._connect_lock is None:
            self._connect_lock = asyncio.Lock()

        frame_size = len(audio_frames)

        if self._shutting_down:
            return audio_frames

        # (Re)connect if the websocket is missing or closed.
        if not self._ws or self._ws.closed:
            # Another call is already connecting — pass through meanwhile.
            if self._connect_lock.locked():
                return audio_frames

            async with self._connect_lock:
                # Re-check after acquiring the lock: another task may have
                # finished connecting while we waited.
                if not self._ws or self._ws.closed:
                    try:
                        await self._connect_ws()
                        self.connected = True
                        self._stats["errors"] = 0
                        await self._send_config()

                        # Bytes per gateway chunk:
                        # samples-per-chunk * channels * 2 bytes/sample.
                        chunk_size = (
                            (self.chunk_ms * self.sample_rate // 1000)
                            * self.channels
                            * 2
                        )
                        self._send_buffer.extend(audio_frames)
                        if len(self._send_buffer) >= chunk_size:
                            first_chunk = bytes(self._send_buffer[:chunk_size])
                            del self._send_buffer[:chunk_size]
                            await self._send_audio(first_chunk)

                        if not self._ws_task or self._ws_task.done():
                            self._ws_task = asyncio.create_task(
                                self._listen_for_responses()
                            )
                        logger.info(
                            f"[InferenceDenoise] Ready (provider={self.provider})"
                        )
                    except Exception as e:
                        logger.error(f"[InferenceDenoise] Setup failed: {e}")
                        # Reset connection state; caller gets raw audio.
                        self._ws = None
                        self._config_sent = False
                        self._send_buffer.clear()
                        return audio_frames

        # Until the provider config is sent, pass audio through unchanged.
        if not self._config_sent:
            return audio_frames

        chunk_size = (self.chunk_ms * self.sample_rate // 1000) * self.channels * 2
        self._send_buffer.extend(audio_frames)

        # Flush every complete chunk to the gateway.
        while len(self._send_buffer) >= chunk_size:
            chunk = bytes(self._send_buffer[:chunk_size])
            del self._send_buffer[:chunk_size]
            try:
                await self._send_audio(chunk)
            except Exception as e:
                logger.error(f"[InferenceDenoise] Send failed: {e} — resetting")
                await asyncio.sleep(0.5)
                self._ws = None
                self._config_sent = False
                self._send_buffer.clear()
                self._reset_latency_state()
                return audio_frames

        # Drain everything the gateway has returned so far.
        denoised_chunks = []
        while not self._audio_buffer.empty():
            try:
                denoised_chunks.append(self._audio_buffer.get_nowait())
            except asyncio.QueueEmpty:
                break

        if denoised_chunks:
            all_denoised = b"".join(denoised_chunks)
            total = len(all_denoised)
            self._stats["chunks_received"] += len(denoised_chunks)
            self._stats["bytes_received"] += total

            # If more than one frame was drained, return the first frame
            # and requeue the excess in frame-sized pieces for later calls.
            if total > frame_size:
                excess = all_denoised[frame_size:]
                for i in range(0, len(excess), frame_size):
                    piece = excess[i : i + frame_size]
                    if self._audio_buffer.full():
                        # Drop the oldest piece to make room; count the drop.
                        try:
                            self._audio_buffer.get_nowait()
                            self._stats["buffer_drops"] += 1
                        except asyncio.QueueEmpty:
                            pass
                    try:
                        self._audio_buffer.put_nowait(piece)
                    except asyncio.QueueFull:
                        pass
                return all_denoised[:frame_size]

            return all_denoised

        # Nothing denoised yet — return the input unchanged.
        return audio_frames

    except Exception as e:
        logger.error(f"[InferenceDenoise] Error in denoise: {e}", exc_info=True)
        self._stats["errors"] += 1
        return audio_frames

Process audio frames to denoise them. Denoised audio frames should be sent via the on_denoised_audio callback.

Args

audio_frames
bytes of audio to process
**kwargs
Additional provider-specific arguments
def get_latency_stats(self) ‑> Dict[str, Any]
Expand source code
def get_latency_stats(self) -> Dict[str, Any]:
    """Expose just the latency metrics — convenient for logging/monitoring."""
    stats = self._stats
    summary = {
        key: stats[f"latency_{key}"]
        for key in ("last_ms", "avg_ms", "min_ms", "max_ms", "p95_ms")
    }
    summary["samples"] = len(self._latency_window)
    return summary

Return only the latency-related stats — handy for logging/monitoring.

def get_stats(self) ‑> Dict[str, Any]
Expand source code
def get_stats(self) -> Dict[str, Any]:
    """
    Get processing statistics.

    Returns:
        Dictionary combining the raw counters with buffer, model and
        connection information.
    """
    ws = self._ws
    snapshot = dict(self._stats)
    snapshot.update(
        buffer_size=self._audio_buffer.qsize(),
        pending_chunks=len(self._pending_chunks),
        provider=self.provider,
        model=self.model_id,
        sample_rate=self.sample_rate,
        channels=self.channels,
        connected=ws is not None and not ws.closed,
    )
    return snapshot

Get processing statistics.

Returns

Dictionary containing processing statistics

class LLM (*,
provider: str,
model_id: str,
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
base_url: Optional[str] = None,
config: Dict[str, Any] | None = None)
Expand source code
class LLM(BaseLLM):
    """
    VideoSDK Inference Gateway LLM Plugin.

    A lightweight LLM client that connects to VideoSDK's Inference Gateway via HTTP.
    Supports Google Gemini models through a unified interface.

    Example:
        # model_id is consistent with the STT/TTS plugin convention
        llm = LLM.google(model_id="gemini-2.0-flash")
        llm = LLM.google(model_id="gemini-2.5-pro")

        # Use with CascadingPipeline
        pipeline = CascadingPipeline(stt=stt, llm=llm, tts=tts)
    """

    def __init__(
        self,
        *,
        provider: str,
        model_id: str,
        temperature: float = 0.7,
        tool_choice: ToolChoice = "auto",
        max_output_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        base_url: Optional[str] = None,
        config: Dict[str, Any] | None = None,
    ) -> None:
        """
        Initialize the VideoSDK Inference LLM plugin.

        Args:
            provider: LLM provider name (e.g., "google")
            model_id: Model identifier (e.g., "gemini-2.0-flash") — consistent with STT/TTS
            temperature: Controls randomness in responses (0.0 to 1.0)
            tool_choice: Tool calling mode ("auto", "required", "none")
            max_output_tokens: Maximum tokens in model responses
            top_p: Nucleus sampling parameter (0.0 to 1.0)
            top_k: Limits tokens considered for each generation step
            presence_penalty: Penalizes token presence (-2.0 to 2.0)
            frequency_penalty: Penalizes token frequency (-2.0 to 2.0)
            base_url: Custom inference gateway URL
            config: Optional provider-specific config dict (stored as-is)

        Raises:
            ValueError: If the VIDEOSDK_AUTH_TOKEN environment variable is unset.
        """
        super().__init__()

        self._videosdk_token = os.getenv("VIDEOSDK_AUTH_TOKEN")
        if not self._videosdk_token:
            raise ValueError(
                "VIDEOSDK_AUTH_TOKEN environment variable must be set for authentication"
            )

        self.provider = provider
        self.model_id = model_id
        self.model = model_id  # OpenAI-compat alias used in request payload
        self.temperature = temperature
        self.tool_choice = tool_choice
        self.max_output_tokens = max_output_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.base_url = base_url or DEFAULT_LLM_HTTP_URL

        # HTTP session state
        self._session: Optional[aiohttp.ClientSession] = None
        self._cancelled: bool = False
        self.config = config or {}

    # ==================== Factory Methods ====================

    @staticmethod
    def google(
        *,
        model_id: str = "gemini-2.0-flash",
        config: Optional[Dict] = None,
        temperature: float = 0.7,
        tool_choice: ToolChoice = "auto",
        max_output_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        base_url: Optional[str] = None,
    ) -> "LLM":
        """
        Create an LLM instance configured for Google Gemini.

        Args:
            model_id: Gemini model identifier (default: "gemini-2.0-flash")
                Options: "gemini-2.0-flash", "gemini-2.0-flash-lite",
                         "gemini-2.5-flash-lite", "gemini-2.5-pro", etc.
            config: Optional extra config dict (merged on top of defaults)
            temperature: Controls randomness in responses (0.0 to 1.0)
            tool_choice: Tool calling mode ("auto", "required", "none")
            max_output_tokens: Maximum tokens in model responses
            top_p: Nucleus sampling parameter (0.0 to 1.0)
            top_k: Limits tokens considered for each generation step
            presence_penalty: Penalizes token presence (-2.0 to 2.0)
            frequency_penalty: Penalizes token frequency (-2.0 to 2.0)
            base_url: Custom inference gateway URL

        Returns:
            Configured LLM instance for Google Gemini
        """
        resolved_config: Dict[str, Any] = {"model_id": model_id}
        if config:
            resolved_config.update(config)

        return LLM(
            provider="google",
            model_id=model_id,
            temperature=temperature,
            tool_choice=tool_choice,
            max_output_tokens=max_output_tokens,
            top_p=top_p,
            top_k=top_k,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            base_url=base_url,
            config=resolved_config,
        )

    # ==================== Core Methods ====================

    async def chat(
        self,
        messages: ChatContext,
        tools: List[FunctionTool] | None = None,
        conversational_graph: Any | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[LLMResponse]:
        """
        Implement chat functionality using VideoSDK Inference Gateway.

        Args:
            messages: ChatContext containing conversation history
            tools: Optional list of function tools available to the model
            conversational_graph: Optional conversational graph for structured responses
            **kwargs: Additional arguments passed to the inference gateway

        Yields:
            LLMResponse objects containing the model's responses

        Raises:
            Exception: Re-raises any streaming/HTTP error after emitting "error".
        """
        self._cancelled = False

        try:
            # Convert messages to OpenAI-compatible format
            formatted_messages = await self._convert_messages_to_dict(messages)

            # Build request payload
            payload: Dict[str, Any] = {
                "model": self.model_id,
                "messages": formatted_messages,
                "stream": True,
                "temperature": self.temperature,
            }

            # Add optional parameters.
            # All optional knobs use an explicit `is not None` check so a
            # caller-supplied 0 is forwarded instead of silently dropped.
            if self.max_output_tokens is not None:
                payload["max_tokens"] = self.max_output_tokens
            if self.top_p is not None:
                payload["top_p"] = self.top_p
            if self.top_k is not None:
                payload["top_k"] = self.top_k
            if self.presence_penalty is not None:
                payload["presence_penalty"] = self.presence_penalty
            if self.frequency_penalty is not None:
                payload["frequency_penalty"] = self.frequency_penalty

            # Add conversational graph response format
            if conversational_graph:
                payload["response_format"] = {
                    "type": "json_object",
                    "schema": ConversationalGraphResponse.model_json_schema(),
                }

            # Add tools if provided
            if tools:
                formatted_tools = self._format_tools(tools)
                if formatted_tools:
                    payload["tools"] = formatted_tools
                    payload["tool_choice"] = self.tool_choice

            # Make streaming HTTP request
            async for response in self._stream_request(payload, conversational_graph):
                if self._cancelled:
                    break
                yield response

        except Exception as e:
            traceback.print_exc()
            if not self._cancelled:
                logger.error(f"[InferenceLLM] Error in chat: {e}")
                self.emit("error", e)
            raise

    async def _stream_request(
        self,
        payload: Dict[str, Any],
        conversational_graph: Any | None = None,
    ) -> AsyncIterator[LLMResponse]:
        """
        Make streaming HTTP request to the inference gateway.

        Args:
            payload: Request payload
            conversational_graph: Optional conversational graph for structured responses

        Yields:
            LLMResponse objects
        """
        # Lazily create and reuse one session across requests.
        if not self._session:
            self._session = aiohttp.ClientSession()

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._videosdk_token}",
        }

        url = f"{self.base_url}/v1/chat/completions?provider={self.provider}"

        current_content = ""
        streaming_state = {
            "in_response": False,
            "response_start_index": -1,
            "yielded_content_length": 0,
        }

        try:
            logger.debug(
                f"[InferenceLLM] Making request to {self.base_url} "
                f"(provider={self.provider}, model_id={self.model_id})"
            )

            async with self._session.post(
                url,
                json=payload,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=120, connect=30),
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"HTTP {response.status}: {error_text}")

                # Process SSE stream
                async for line in response.content:
                    if self._cancelled:
                        break

                    line_str = line.decode("utf-8").strip()
                    if not line_str:
                        continue

                    # Handle SSE format
                    if line_str.startswith("data:"):
                        data_str = line_str[5:].strip()
                        if data_str == "[DONE]":
                            break

                        try:
                            chunk = json.loads(data_str)
                            async for llm_response in self._process_chunk(
                                chunk,
                                current_content,
                                streaming_state,
                                conversational_graph,
                            ):
                                if llm_response.content:
                                    current_content += llm_response.content
                                yield llm_response
                        except json.JSONDecodeError as e:
                            logger.warning(f"[InferenceLLM] Failed to parse chunk: {e}")
                            continue

            # Handle conversational graph final response
            if current_content and conversational_graph and not self._cancelled:
                try:
                    parsed_json = json.loads(current_content.strip())
                    yield LLMResponse(
                        content="",
                        role=ChatRole.ASSISTANT,
                        metadata=parsed_json,
                    )
                except json.JSONDecodeError:
                    # Best-effort: stream may not be valid JSON; skip silently.
                    pass

        except aiohttp.ClientError as e:
            logger.error(f"[InferenceLLM] HTTP request failed: {e}")
            raise

    async def _process_chunk(
        self,
        chunk: Dict[str, Any],
        current_content: str,
        streaming_state: Dict[str, Any],
        conversational_graph: Any,
    ) -> AsyncIterator[LLMResponse]:
        """
        Process a single SSE chunk from the response stream.

        Args:
            chunk: Parsed JSON chunk
            current_content: Accumulated content so far
            streaming_state: State for conversational graph streaming
            conversational_graph: Optional conversational graph

        Yields:
            LLMResponse objects
        """
        choices = chunk.get("choices", [])
        if not choices:
            return

        choice = choices[0]
        delta = choice.get("delta", {})

        # Check for tool calls
        if "tool_calls" in delta:
            for tool_call in delta.get("tool_calls") or []:
                function_data = tool_call.get("function", {})
                function_name = function_data.get("name", "")
                function_args = function_data.get("arguments", "")

                if function_name:
                    try:
                        args_dict = json.loads(function_args) if function_args else {}
                    except json.JSONDecodeError:
                        # Arguments may arrive as a partial/raw string.
                        args_dict = function_args

                    function_call = {
                        "name": function_name,
                        "arguments": args_dict,
                    }
                    yield LLMResponse(
                        content="",
                        role=ChatRole.ASSISTANT,
                        metadata={"function_call": function_call},
                    )

        # Check for content
        content = delta.get("content", "")
        if content:
            if conversational_graph:
                full_content = current_content + content
                for (
                    content_chunk
                ) in conversational_graph.stream_conversational_graph_response(
                    full_content, streaming_state
                ):
                    yield LLMResponse(content=content_chunk, role=ChatRole.ASSISTANT)
            else:
                yield LLMResponse(content=content, role=ChatRole.ASSISTANT)

    async def cancel_current_generation(self) -> None:
        """Cancel the current LLM generation."""
        self._cancelled = True
        logger.debug("[InferenceLLM] Generation cancelled")

    # ==================== Message Conversion ====================

    async def _convert_messages_to_dict(
        self, messages: ChatContext
    ) -> List[Dict[str, Any]]:
        """
        Convert ChatContext to OpenAI-compatible message format.

        Args:
            messages: ChatContext containing conversation history

        Returns:
            List of message dictionaries
        """
        formatted_messages = []

        for item in messages.items:
            if isinstance(item, ChatMessage):
                role = self._map_role(item.role)
                content = await self._format_content(item.content)
                formatted_messages.append({"role": role, "content": content})

            elif isinstance(item, FunctionCall):
                formatted_messages.append(
                    {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [
                            {
                                # NOTE(review): id derived from the tool name only —
                                # two calls to the same tool share an id; confirm the
                                # gateway tolerates this before relying on it.
                                "id": f"call_{item.name}",
                                "type": "function",
                                "function": {
                                    "name": item.name,
                                    "arguments": (
                                        item.arguments
                                        if isinstance(item.arguments, str)
                                        else json.dumps(item.arguments)
                                    ),
                                },
                            }
                        ],
                    }
                )

            elif isinstance(item, FunctionCallOutput):
                formatted_messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": f"call_{item.name}",
                        "content": str(item.output),
                    }
                )

        return formatted_messages

    def _map_role(self, role: ChatRole) -> str:
        """Map ChatRole to OpenAI role string (unknown roles fall back to "user")."""
        role_map = {
            ChatRole.SYSTEM: "system",
            ChatRole.USER: "user",
            ChatRole.ASSISTANT: "assistant",
        }
        return role_map.get(role, "user")

    async def _format_content(
        self, content: Union[str, List[ChatContent]]
    ) -> Union[str, List[Dict[str, Any]]]:
        """
        Format message content to OpenAI-compatible format.

        Args:
            content: String or list of ChatContent

        Returns:
            Formatted content
        """
        if isinstance(content, str):
            return content

        # Single plain-string part collapses to a bare string.
        if len(content) == 1 and isinstance(content[0], str):
            return content[0]

        formatted_parts = []
        for part in content:
            if isinstance(part, str):
                formatted_parts.append({"type": "text", "text": part})
            elif isinstance(part, ImageContent):
                image_url = part.to_data_url()
                image_part = {
                    "type": "image_url",
                    "image_url": {"url": image_url},
                }
                if part.inference_detail != "auto":
                    image_part["image_url"]["detail"] = part.inference_detail
                formatted_parts.append(image_part)

        return formatted_parts if formatted_parts else ""

    # ==================== Tool Formatting ====================

    def _format_tools(self, tools: List[FunctionTool]) -> List[Dict[str, Any]]:
        """
        Format function tools to OpenAI-compatible format.

        Args:
            tools: List of FunctionTool objects

        Returns:
            List of formatted tool dictionaries
        """
        formatted_tools = []

        for tool in tools:
            if not is_function_tool(tool):
                continue

            try:
                gemini_schema = build_gemini_schema(tool)
                formatted_tools.append(
                    {
                        "type": "function",
                        "function": {
                            "name": gemini_schema.get("name", ""),
                            "description": gemini_schema.get("description", ""),
                            "parameters": gemini_schema.get("parameters", {}),
                        },
                    }
                )
            except Exception as e:
                logger.error(f"[InferenceLLM] Failed to format tool: {e}")
                continue

        return formatted_tools

    # ==================== Cleanup ====================

    async def aclose(self) -> None:
        """Clean up all resources."""
        logger.info(f"[InferenceLLM] Closing LLM (provider={self.provider})")

        self._cancelled = True

        if self._session and not self._session.closed:
            await self._session.close()
            self._session = None

        await super().aclose()

        logger.info("[InferenceLLM] Closed successfully")

    # ==================== Properties ====================

    @property
    def label(self) -> str:
        """Get a descriptive label for this LLM instance."""
        return f"videosdk.inference.LLM.{self.provider}.{self.model_id}"

VideoSDK Inference Gateway LLM Plugin.

A lightweight LLM client that connects to VideoSDK's Inference Gateway via HTTP. Supports Google Gemini models through a unified interface.

Example

model_id is consistent with the STT/TTS plugin convention

llm = LLM.google(model_id="gemini-2.0-flash") llm = LLM.google(model_id="gemini-2.5-pro")

Use with CascadingPipeline

pipeline = CascadingPipeline(stt=stt, llm=llm, tts=tts)

Initialize the VideoSDK Inference LLM plugin.

Args

provider
LLM provider name (e.g., "google")
model_id
Model identifier (e.g., "gemini-2.0-flash") — consistent with STT/TTS
temperature
Controls randomness in responses (0.0 to 1.0)
tool_choice
Tool calling mode ("auto", "required", "none")
max_output_tokens
Maximum tokens in model responses
top_p
Nucleus sampling parameter (0.0 to 1.0)
top_k
Limits tokens considered for each generation step
presence_penalty
Penalizes token presence (-2.0 to 2.0)
frequency_penalty
Penalizes token frequency (-2.0 to 2.0)
base_url
Custom inference gateway URL

Ancestors

  • videosdk.agents.llm.llm.LLM
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Static methods

def google(*,
model_id: str = 'gemini-2.0-flash',
config: Optional[Dict] = None,
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
base_url: Optional[str] = None) ‑> LLM
Expand source code
@staticmethod
def google(
    *,
    model_id: str = "gemini-2.0-flash",
    config: Optional[Dict] = None,
    temperature: float = 0.7,
    tool_choice: ToolChoice = "auto",
    max_output_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    base_url: Optional[str] = None,
) -> "LLM":
    """
    Build an LLM targeting Google Gemini through the inference gateway.

    Args:
        model_id: Gemini model identifier (default: "gemini-2.0-flash").
            Options include "gemini-2.0-flash", "gemini-2.0-flash-lite",
            "gemini-2.5-flash-lite", "gemini-2.5-pro", etc.
        config: Extra config entries layered on top of the defaults.
        temperature: Sampling temperature (0.0 to 1.0).
        tool_choice: Tool calling mode ("auto", "required", "none").
        max_output_tokens: Cap on tokens per response.
        top_p: Nucleus sampling parameter (0.0 to 1.0).
        top_k: Candidate-token limit for each generation step.
        presence_penalty: Presence penalty (-2.0 to 2.0).
        frequency_penalty: Frequency penalty (-2.0 to 2.0).
        base_url: Custom inference gateway URL.

    Returns:
        An LLM instance configured for the "google" provider.
    """
    # Start from the defaults and let caller-supplied entries override.
    merged_config: Dict[str, Any] = {"model_id": model_id, **(config or {})}

    return LLM(
        provider="google",
        model_id=model_id,
        temperature=temperature,
        tool_choice=tool_choice,
        max_output_tokens=max_output_tokens,
        top_p=top_p,
        top_k=top_k,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        base_url=base_url,
        config=merged_config,
    )

Create an LLM instance configured for Google Gemini.

Args

model_id
Gemini model identifier (default: "gemini-2.0-flash") Options: "gemini-2.0-flash", "gemini-2.0-flash-lite", "gemini-2.5-flash-lite", "gemini-2.5-pro", etc.
config
Optional extra config dict (merged on top of defaults)
temperature
Controls randomness in responses (0.0 to 1.0)
tool_choice
Tool calling mode ("auto", "required", "none")
max_output_tokens
Maximum tokens in model responses
top_p
Nucleus sampling parameter (0.0 to 1.0)
top_k
Limits tokens considered for each generation step
presence_penalty
Penalizes token presence (-2.0 to 2.0)
frequency_penalty
Penalizes token frequency (-2.0 to 2.0)
base_url
Custom inference gateway URL

Returns

Configured LLM instance for Google Gemini

Instance variables

prop label : str
Expand source code
@property
def label(self) -> str:
    """Descriptive label identifying this LLM's provider and model."""
    parts = ("videosdk", "inference", "LLM", self.provider, self.model_id)
    return ".".join(parts)

Get a descriptive label for this LLM instance.

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Release the HTTP session and shut down the base LLM cleanly."""
    logger.info(f"[InferenceLLM] Closing LLM (provider={self.provider})")
    # Stop any in-flight generation before tearing the session down.
    self._cancelled = True
    session = self._session
    if session and not session.closed:
        await session.close()
        self._session = None
    await super().aclose()
    logger.info("[InferenceLLM] Closed successfully")

Clean up all resources.

async def cancel_current_generation(self) ‑> None
Expand source code
async def cancel_current_generation(self) -> None:
    """Cancel the current LLM generation.

    Sets the cancellation flag that the streaming loops check on each
    chunk; the in-flight request is abandoned at the next chunk boundary.
    """
    self._cancelled = True
    logger.debug("[InferenceLLM] Generation cancelled")

Cancel the current LLM generation.

async def chat(self,
messages: ChatContext,
tools: List[FunctionTool] | None = None,
conversational_graph: Any | None = None,
**kwargs: Any) ‑> AsyncIterator[videosdk.agents.llm.llm.LLMResponse]
Expand source code
async def chat(
    self,
    messages: ChatContext,
    tools: List[FunctionTool] | None = None,
    conversational_graph: Any | None = None,
    **kwargs: Any,
) -> AsyncIterator[LLMResponse]:
    """
    Stream chat completions from the VideoSDK Inference Gateway.

    Args:
        messages: ChatContext holding the conversation so far
        tools: Optional function tools the model may call
        conversational_graph: Optional conversational graph for structured responses
        **kwargs: Extra arguments forwarded to the inference gateway

    Yields:
        LLMResponse objects as the model streams its answer
    """
    self._cancelled = False

    try:
        # The gateway expects OpenAI-style message dicts.
        message_dicts = await self._convert_messages_to_dict(messages)

        request_body: Dict[str, Any] = {
            "model": self.model_id,
            "messages": message_dicts,
            "stream": True,
            "temperature": self.temperature,
        }

        # Optional generation knobs — included only when configured.
        # (max_output_tokens keeps its original truthiness check, so 0 is omitted.)
        if self.max_output_tokens:
            request_body["max_tokens"] = self.max_output_tokens
        for key, value in (
            ("top_p", self.top_p),
            ("top_k", self.top_k),
            ("presence_penalty", self.presence_penalty),
            ("frequency_penalty", self.frequency_penalty),
        ):
            if value is not None:
                request_body[key] = value

        # Structured output: constrain the response to the graph's JSON schema.
        if conversational_graph:
            request_body["response_format"] = {
                "type": "json_object",
                "schema": ConversationalGraphResponse.model_json_schema(),
            }

        # Tool calling: only attach tool_choice when tools were formatted.
        if tools:
            declarations = self._format_tools(tools)
            if declarations:
                request_body["tools"] = declarations
                request_body["tool_choice"] = self.tool_choice

        # Relay streamed responses, stopping early if cancelled mid-stream.
        async for chunk in self._stream_request(request_body, conversational_graph):
            if self._cancelled:
                break
            yield chunk

    except Exception as e:
        traceback.print_exc()
        if not self._cancelled:
            logger.error(f"[InferenceLLM] Error in chat: {e}")
            self.emit("error", e)
        raise

Implement chat functionality using VideoSDK Inference Gateway.

Args

messages
ChatContext containing conversation history
tools
Optional list of function tools available to the model
conversational_graph
Optional conversational graph for structured responses
**kwargs
Additional arguments passed to the inference gateway

Yields

LLMResponse objects containing the model's responses

class Realtime (*,
provider: str,
model: str,
config: GeminiRealtimeConfig | None = None,
base_url: str | None = None)
Expand source code
class Realtime(RealtimeBaseModel[RealtimeEventTypes]):
    """
    VideoSDK Inference Gateway Realtime Plugin.

    A lightweight multimodal realtime client that connects to VideoSDK's Inference Gateway.
    Supports Gemini's realtime model for audio-first communication.

    Example:
        # Using factory method (recommended)
        model = Realtime.gemini(
            model="gemini-2.0-flash-exp",
            voice="Puck",
            language_code="en-US",
        )

        # Use with RealTimePipeline
        pipeline = RealTimePipeline(model=model)
    """

    def __init__(
        self,
        *,
        provider: str,
        model: str,
        config: GeminiRealtimeConfig | None = None,
        base_url: str | None = None,
    ) -> None:
        """
        Initialize the VideoSDK Inference Realtime plugin.

        Args:
            provider: Realtime provider name (currently only "gemini" supported)
            model: Model identifier (e.g., "gemini-2.0-flash-exp")
            config: Provider-specific configuration
            base_url: Custom inference gateway URL (default: production gateway)

        Raises:
            ValueError: If VIDEOSDK_AUTH_TOKEN is not set in the environment.
        """
        super().__init__()

        self._videosdk_token = os.getenv("VIDEOSDK_AUTH_TOKEN")
        if not self._videosdk_token:
            raise ValueError(
                "VIDEOSDK_AUTH_TOKEN environment variable must be set for authentication"
            )

        self.provider = provider
        self.model = model
        # Strip any "models/" (or similar) prefix for the gateway query parameter.
        self.model_id = model.split("/")[-1] if "/" in model else model
        self.config = config or GeminiRealtimeConfig()
        self.base_url = base_url or VIDEOSDK_INFERENCE_URL

        # WebSocket state
        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._ws_task: Optional[asyncio.Task] = None
        self._config_sent: bool = False
        self._closing: bool = False

        # Audio state: input arrives at 48kHz, Gemini expects 24kHz.
        self.target_sample_rate = 24000
        self.input_sample_rate = 48000

        # Speaking state tracking
        self._user_speaking: bool = False
        self._agent_speaking: bool = False

        # Agent configuration (overridden via set_agent())
        self._instructions: str = (
            "You are a helpful voice assistant that can answer questions and help with tasks."
        )
        self.tools: List[FunctionTool] = []
        self.tools_formatted: List[Dict] = []

    # ==================== Factory Methods ====================

    @staticmethod
    def gemini(
        *,
        model: str = "gemini-2.5-flash-native-audio-preview-12-2025",
        voice: Voice = "Puck",
        language_code: str = "en-US",
        temperature: float | None = None,
        top_p: float | None = None,
        top_k: float | None = None,
        candidate_count: int = 1,
        max_output_tokens: int | None = None,
        presence_penalty: float | None = None,
        frequency_penalty: float | None = None,
        response_modalities: List[str] | None = None,
        base_url: str | None = None,
    ) -> "Realtime":
        """
        Create a Realtime instance configured for Google Gemini.

        Args:
            model: Gemini model identifier (default: "gemini-2.5-flash-native-audio-preview-12-2025")
            voice: Voice ID for audio output. Options: 'Puck', 'Charon', 'Kore', 'Fenrir', 'Aoede'
            language_code: Language code for speech synthesis (default: "en-US")
            temperature: Controls randomness in responses (0.0 to 1.0)
            top_p: Nucleus sampling parameter (0.0 to 1.0)
            top_k: Limits tokens considered for each generation step
            candidate_count: Number of response candidates (default: 1)
            max_output_tokens: Maximum tokens in model responses
            presence_penalty: Penalizes token presence (-2.0 to 2.0)
            frequency_penalty: Penalizes token frequency (-2.0 to 2.0)
            response_modalities: Response types ["TEXT", "AUDIO"] (default: ["AUDIO"])
            base_url: Custom inference gateway URL

        Returns:
            Configured Realtime instance for Gemini
        """
        config = GeminiRealtimeConfig(
            voice=voice,
            language_code=language_code,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            candidate_count=candidate_count,
            max_output_tokens=max_output_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            response_modalities=response_modalities or ["AUDIO"],
        )

        return Realtime(
            provider="google",
            model=model,
            config=config,
            base_url=base_url,
        )

    # ==================== Agent Setup ====================

    def set_agent(self, agent: Agent) -> None:
        """Set agent instructions and tools."""
        self._instructions = agent.instructions
        self.tools = agent.tools or []
        self.tools_formatted = self._convert_tools_to_format(self.tools)

    def _convert_tools_to_format(self, tools: List[FunctionTool]) -> List[Dict]:
        """Convert tool definitions to Gemini format."""
        function_declarations = []

        for tool in tools:
            if not is_function_tool(tool):
                continue

            try:
                function_declaration = build_gemini_schema(tool)
                function_declarations.append(function_declaration)
            except Exception as e:
                # A single malformed tool should not break agent setup.
                logger.error(f"[InferenceRealtime] Failed to format tool {tool}: {e}")
                continue

        return function_declarations

    # ==================== Connection Management ====================

    async def connect(self) -> None:
        """Connect to the inference gateway."""
        if self._ws and not self._ws.closed:
            await self._cleanup_connection()

        self._closing = False

        try:
            # The audio track is assigned externally by the pipeline; we only
            # warn here because audio output cannot be delivered without it.
            if not self.audio_track and "AUDIO" in self.config.response_modalities:
                logger.warning("[InferenceRealtime] audio_track not set — it should be assigned externally by the pipeline before connect().")

            # Connect to WebSocket
            await self._connect_ws()

            # Start listening for responses
            if not self._ws_task or self._ws_task.done():
                self._ws_task = asyncio.create_task(
                    self._listen_for_responses(), name="inference-realtime-listener"
                )

            logger.info(
                f"[InferenceRealtime] Connected to inference gateway (provider={self.provider})"
            )

        except Exception as e:
            self.emit("error", f"Error connecting to inference gateway: {e}")
            logger.error(f"[InferenceRealtime] Connection error: {e}")
            raise

    async def _connect_ws(self) -> None:
        """Establish WebSocket connection to the inference gateway."""
        if not self._session:
            self._session = aiohttp.ClientSession()

        # Build WebSocket URL with query parameters
        ws_url = (
            f"{self.base_url}/v1/llm"
            f"?provider={self.provider}"
            f"&secret={self._videosdk_token}"
            f"&modelId={self.model_id}"
        )

        try:
            logger.info(f"[InferenceRealtime] Connecting to {self.base_url}")
            self._ws = await self._session.ws_connect(ws_url)
            # A fresh connection always needs a new config message.
            self._config_sent = False
            logger.info("[InferenceRealtime] WebSocket connected successfully")
        except Exception as e:
            logger.error(f"[InferenceRealtime] WebSocket connection failed: {e}")
            raise

    async def _send_config(self) -> None:
        """Send configuration to the inference server."""
        if not self._ws:
            return

        config_data = {
            "model": (
                f"models/{self.model}"
                if not self.model.startswith("models/")
                else self.model
            ),
            "instructions": self._instructions,
            "voice": self.config.voice,
            "language_code": self.config.language_code,
            "temperature": self.config.temperature,
            "top_p": self.config.top_p,
            "top_k": self.config.top_k,
            "candidate_count": self.config.candidate_count,
            "max_output_tokens": self.config.max_output_tokens,
            "presence_penalty": self.config.presence_penalty,
            "frequency_penalty": self.config.frequency_penalty,
            "response_modalities": self.config.response_modalities,
        }

        # Add tools if available
        if self.tools_formatted:
            config_data["tools"] = self.tools_formatted

        config_message = {
            "type": "config",
            "data": config_data,
        }

        try:
            await self._ws.send_str(json.dumps(config_message))
            self._config_sent = True
            logger.info(
                f"[InferenceRealtime] Config sent: voice={self.config.voice}, modalities={self.config.response_modalities}"
            )

        except Exception as e:
            logger.error(f"[InferenceRealtime] Failed to send config: {e}")
            raise

    # ==================== Audio Input ====================

    async def handle_audio_input(self, audio_data: bytes) -> None:
        """Handle incoming audio data from the user."""
        if not self._ws or self._closing:
            return

        if self.current_utterance and not self.current_utterance.is_interruptible:
            return

        if "AUDIO" not in self.config.response_modalities:
            return

        try:
            # Reconnect transparently if the server dropped the connection.
            if self._ws.closed:
                await self._connect_ws()
                if not self._ws_task or self._ws_task.done():
                    self._ws_task = asyncio.create_task(self._listen_for_responses())

            if not self._config_sent:
                await self._send_config()

            # Resample audio from 48kHz to 24kHz (expected by Gemini)
            audio_array = np.frombuffer(audio_data, dtype=np.int16)
            audio_array = signal.resample(
                audio_array,
                int(
                    len(audio_array) * self.target_sample_rate / self.input_sample_rate
                ),
            )
            audio_data = audio_array.astype(np.int16).tobytes()

            # Send audio as base64
            audio_message = {
                "type": "audio",
                "data": base64.b64encode(audio_data).decode("utf-8"),
            }

            await self._ws.send_str(json.dumps(audio_message))

        except Exception as e:
            logger.error(f"[InferenceRealtime] Error sending audio: {e}")
            self.emit("error", str(e))

    async def handle_video_input(self, video_data: av.VideoFrame) -> None:
        """Handle incoming video data from the user (vision mode)."""
        if not self._ws or self._closing:
            return

        try:
            if not video_data or not video_data.planes:
                return

            # Rate limit video frames to at most 2 fps.
            now = time.monotonic()

            if (
                hasattr(self, "_last_video_frame")
                and (now - self._last_video_frame) < 0.5
            ):

                return
            self._last_video_frame = now

            # Encode frame as JPEG
            processed_jpeg = encode_image(video_data, DEFAULT_IMAGE_ENCODE_OPTIONS)
            if not processed_jpeg or len(processed_jpeg) < 100:
                logger.warning("[InferenceRealtime] Invalid JPEG data generated")
                return

            # Send video as base64
            video_message = {
                "type": "video",
                "data": base64.b64encode(processed_jpeg).decode("utf-8"),
            }

            await self._ws.send_str(json.dumps(video_message))

        except Exception as e:
            logger.error(f"[InferenceRealtime] Error sending video: {e}")

    # ==================== Text Messages ====================

    async def send_message(self, message: str) -> None:
        """Send a text message to get audio response."""
        if not self._ws or self._closing:
            logger.warning("[InferenceRealtime] Cannot send message: not connected")
            return

        try:
            if not self._config_sent:
                await self._send_config()

            text_message = {
                "type": "text",
                "data": f"Please start the conversation by saying exactly this, without any additional text: '{message}'",
            }

            await self._ws.send_str(json.dumps(text_message))
            logger.debug(f"[InferenceRealtime] Sent message: {message[:50]}...")

        except Exception as e:
            logger.error(f"[InferenceRealtime] Error sending message: {e}")
            self.emit("error", str(e))

    async def send_text_message(self, message: str) -> None:
        """Send a text message for text-only communication."""
        if not self._ws or self._closing:
            logger.warning("[InferenceRealtime] Cannot send text: not connected")
            return

        try:
            if not self._config_sent:
                await self._send_config()

            text_message = {
                "type": "text",
                "data": message,
            }

            await self._ws.send_str(json.dumps(text_message))

        except Exception as e:
            logger.error(f"[InferenceRealtime] Error sending text message: {e}")
            self.emit("error", str(e))

    async def send_message_with_frames(
        self, message: str, frames: List[av.VideoFrame]
    ) -> None:
        """Send a text message with video frames for vision-enabled communication."""
        if not self._ws or self._closing:
            logger.warning(
                "[InferenceRealtime] Cannot send message with frames: not connected"
            )
            # BUGFIX: previously fell through and dereferenced a missing/closed
            # websocket; bail out like send_message()/send_text_message() do.
            return

        try:
            if not self._config_sent:
                await self._send_config()

            # Encode frames as base64; skip frames that fail or look invalid.
            encoded_frames = []
            for frame in frames:
                try:
                    processed_jpeg = encode_image(frame, DEFAULT_IMAGE_ENCODE_OPTIONS)
                    if processed_jpeg and len(processed_jpeg) >= 100:
                        encoded_frames.append(
                            base64.b64encode(processed_jpeg).decode("utf-8")
                        )

                except Exception as e:
                    logger.error(f"[InferenceRealtime] Error encoding frame: {e}")

            # Send message with frames
            message_with_frames = {
                "type": "text_with_frames",
                "data": {
                    "text": message,
                    "frames": encoded_frames,
                },
            }

            await self._ws.send_str(json.dumps(message_with_frames))

        except Exception as e:
            logger.error(f"[InferenceRealtime] Error sending message with frames: {e}")
            self.emit("error", str(e))

    # ==================== Interruption ====================

    async def interrupt(self) -> None:
        """Interrupt current response."""
        if not self._ws or self._closing:
            return

        if self.current_utterance and not self.current_utterance.is_interruptible:
            logger.info(
                "[InferenceRealtime] Utterance not interruptible, skipping interrupt"
            )
            return

        try:
            interrupt_message = {"type": "interrupt"}
            await self._ws.send_str(json.dumps(interrupt_message))

            self.emit("agent_speech_ended", {})
            metrics_collector.on_interrupted()

            # Flush any queued audio so playback stops immediately.
            if self.audio_track and "AUDIO" in self.config.response_modalities:
                self.audio_track.interrupt()

            logger.debug("[InferenceRealtime] Sent interrupt signal")

        except Exception as e:
            logger.error(f"[InferenceRealtime] Interrupt error: {e}")
            self.emit("error", str(e))

    # ==================== Response Handling ====================

    async def _listen_for_responses(self) -> None:
        """Background task to listen for WebSocket responses from the server."""
        if not self._ws:
            return

        # Transcription text is accumulated across messages within a turn.
        accumulated_input_text = ""
        accumulated_output_text = ""

        try:
            async for msg in self._ws:
                if self._closing:
                    break

                if msg.type == aiohttp.WSMsgType.TEXT:

                    accumulated_input_text, accumulated_output_text = (
                        await self._handle_message(
                            msg.data, accumulated_input_text, accumulated_output_text
                        )
                    )
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(
                        f"[InferenceRealtime] WebSocket error: {self._ws.exception()}"
                    )
                    self.emit("error", f"WebSocket error: {self._ws.exception()}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("[InferenceRealtime] WebSocket closed by server")
                    break

        except asyncio.CancelledError:
            logger.debug("[InferenceRealtime] WebSocket listener cancelled")
        except Exception as e:
            logger.error(f"[InferenceRealtime] Error in WebSocket listener: {e}")
            self.emit("error", str(e))
        finally:
            await self._cleanup_connection()

    async def _handle_message(
        self,
        raw_message: str,
        accumulated_input_text: str,
        accumulated_output_text: str,
    ) -> tuple[str, str]:
        """Handle incoming messages from the inference server."""
        try:
            data = json.loads(raw_message)
            msg_type = data.get("type")

            if msg_type == "audio":
                await self._handle_audio_data(data.get("data", {}))

            elif msg_type == "event":
                event_data = data.get("data", {})

                accumulated_input_text, accumulated_output_text = (
                    await self._handle_event(
                        event_data, accumulated_input_text, accumulated_output_text
                    )
                )

            elif msg_type == "error":
                error_msg = data.get("data", {}).get("error") or data.get(
                    "message", "Unknown error"
                )

                logger.error(f"[InferenceRealtime] Server error: {error_msg}")
                self.emit("error", error_msg)

        except json.JSONDecodeError as e:
            logger.error(f"[InferenceRealtime] Failed to parse message: {e}")
        except Exception as e:
            logger.error(f"[InferenceRealtime] Error handling message: {e}")

        return accumulated_input_text, accumulated_output_text

    async def _handle_audio_data(self, audio_data: Dict[str, Any]) -> None:
        """Handle audio data received from the inference server."""
        audio_b64 = audio_data.get("audio")
        if not audio_b64:
            return

        if not self.audio_track:
            logger.warning("[InferenceRealtime] Audio track not available")
            return

        try:
            audio_bytes = base64.b64decode(audio_b64)

            # Ensure even number of bytes (16-bit PCM frames)
            if len(audio_bytes) % 2 != 0:
                audio_bytes += b"\x00"

            if not self._agent_speaking:
                self._agent_speaking = True
                self.emit("agent_speech_started", {})
                metrics_collector.on_agent_speech_start()

            await self.audio_track.add_new_bytes(audio_bytes)

        except Exception as e:
            logger.error(f"[InferenceRealtime] Error processing audio data: {e}")

    async def _handle_event(
        self,
        event_data: Dict[str, Any],
        accumulated_input_text: str,
        accumulated_output_text: str,
    ) -> tuple[str, str]:
        """Handle event messages from the inference server."""
        event_type = event_data.get("eventType")

        if event_type == "user_speech_started":
            if not self._user_speaking:
                self._user_speaking = True
                metrics_collector.on_user_speech_start()
                metrics_collector.start_turn()
                self.emit("user_speech_started", {"type": "done"})

        elif event_type == "user_speech_ended":
            if self._user_speaking:
                self._user_speaking = False
                metrics_collector.on_user_speech_end()
                self.emit("user_speech_ended", {})

        elif event_type == "agent_speech_started":
            if not self._agent_speaking:
                self._agent_speaking = True
                self.emit("agent_speech_started", {})
                metrics_collector.on_agent_speech_start()

        elif event_type == "agent_speech_ended":
            if self._agent_speaking:
                self._agent_speaking = False
                self.emit("agent_speech_ended", {})
                metrics_collector.on_agent_speech_end()
                metrics_collector.schedule_turn_complete(timeout=1.0)

        elif event_type == "input_transcription":
            # NOTE(review): input transcription replaces the accumulator while
            # output transcription appends — presumably the server sends full
            # input snapshots but incremental output deltas; confirm with gateway.
            text = event_data.get("text", "")
            if text:
                accumulated_input_text = text
                global_event_emitter.emit(
                    "input_transcription",
                    {"text": accumulated_input_text, "is_final": False},
                )

        elif event_type == "output_transcription":
            text = event_data.get("text", "")
            if text:
                accumulated_output_text += text
                global_event_emitter.emit(
                    "output_transcription",
                    {"text": accumulated_output_text, "is_final": False},
                )

        elif event_type == "user_transcript":
            text = event_data.get("text", "")
            if text:
                metrics_collector.set_user_transcript(text)
                self.emit(
                    "realtime_model_transcription",
                    {"role": "user", "text": text, "is_final": True},
                )
                accumulated_input_text = ""

        elif event_type == "text_response":
            text = event_data.get("text", "")
            is_final = event_data.get("is_final", False)
            if text and is_final:
                metrics_collector.set_agent_response(text)
                self.emit(
                    "realtime_model_transcription",
                    {"role": "agent", "text": text, "is_final": True},
                )

                global_event_emitter.emit(
                    "text_response", {"type": "done", "text": text}
                )

                accumulated_output_text = ""

        elif event_type == "response_interrupted":
            if self.audio_track and "AUDIO" in self.config.response_modalities:
                self.audio_track.interrupt()

        return accumulated_input_text, accumulated_output_text

    # ==================== Cleanup ====================

    async def _cleanup_connection(self) -> None:
        """Clean up WebSocket connection resources."""
        if self._ws and not self._ws.closed:
            try:
                # Best-effort stop message before closing
                stop_message = {"type": "stop"}
                await self._ws.send_str(json.dumps(stop_message))
            except Exception:
                pass

            try:
                await self._ws.close()
            except Exception:
                pass

        self._ws = None
        self._config_sent = False

    async def aclose(self) -> None:
        """Clean up all resources."""
        logger.info(f"[InferenceRealtime] Closing (provider={self.provider})")

        self._closing = True

        # Cancel listener task
        if self._ws_task:
            self._ws_task.cancel()
            try:
                await self._ws_task
            except asyncio.CancelledError:
                pass
            self._ws_task = None

        # Close WebSocket
        await self._cleanup_connection()

        # Close HTTP session
        if self._session and not self._session.closed:
            await self._session.close()
            self._session = None

        await super().aclose()  # Nulls audio_track and loop

        logger.info("[InferenceRealtime] Closed successfully")

    # ==================== Properties ====================

    @property
    def label(self) -> str:
        """Get a descriptive label for this Realtime instance."""
        return f"videosdk.inference.Realtime.{self.provider}.{self.model_id}"

VideoSDK Inference Gateway Realtime Plugin.

A lightweight multimodal realtime client that connects to VideoSDK's Inference Gateway. Supports Gemini's realtime model for audio-first communication.

Example

Using factory method (recommended)

model = Realtime.gemini( model="gemini-2.0-flash-exp", voice="Puck", language_code="en-US", )

Use with RealTimePipeline

pipeline = RealTimePipeline(model=model)

Initialize the VideoSDK Inference Realtime plugin.

Args

provider
Realtime provider name (currently only "gemini" supported)
model
Model identifier (e.g., "gemini-2.0-flash-exp")
config
Provider-specific configuration
base_url
Custom inference gateway URL (default: production gateway)

Ancestors

  • videosdk.agents.realtime_base_model.RealtimeBaseModel
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic
  • abc.ABC

Static methods

def gemini(*,
model: str = 'gemini-2.5-flash-native-audio-preview-12-2025',
voice: Voice = 'Puck',
language_code: str = 'en-US',
temperature: float | None = None,
top_p: float | None = None,
top_k: float | None = None,
candidate_count: int = 1,
max_output_tokens: int | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
response_modalities: List[str] | None = None,
base_url: str | None = None) ‑> Realtime
Expand source code
@staticmethod
def gemini(
    *,
    model: str = "gemini-2.5-flash-native-audio-preview-12-2025",
    voice: Voice = "Puck",
    language_code: str = "en-US",
    temperature: float | None = None,
    top_p: float | None = None,
    top_k: float | None = None,
    candidate_count: int = 1,
    max_output_tokens: int | None = None,
    presence_penalty: float | None = None,
    frequency_penalty: float | None = None,
    response_modalities: List[str] | None = None,
    base_url: str | None = None,
) -> "Realtime":
    """
    Create a Realtime instance configured for Google Gemini.

    Args:
        model: Gemini model identifier (default: "gemini-2.5-flash-native-audio-preview-12-2025")
        voice: Voice ID for audio output. Options: 'Puck', 'Charon', 'Kore', 'Fenrir', 'Aoede'
        language_code: Language code for speech synthesis (default: "en-US")
        temperature: Controls randomness in responses (0.0 to 1.0)
        top_p: Nucleus sampling parameter (0.0 to 1.0)
        top_k: Limits tokens considered for each generation step
        candidate_count: Number of response candidates (default: 1)
        max_output_tokens: Maximum tokens in model responses
        presence_penalty: Penalizes token presence (-2.0 to 2.0)
        frequency_penalty: Penalizes token frequency (-2.0 to 2.0)
        response_modalities: Response types ["TEXT", "AUDIO"] (default: ["AUDIO"])
        base_url: Custom inference gateway URL

    Returns:
        Configured Realtime instance for Gemini
    """
    config = GeminiRealtimeConfig(
        voice=voice,
        language_code=language_code,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        candidate_count=candidate_count,
        max_output_tokens=max_output_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        response_modalities=response_modalities or ["AUDIO"],
    )

    # Note: the gateway provider string is "google", not "gemini".
    return Realtime(
        provider="google",
        model=model,
        config=config,
        base_url=base_url,
    )

Create a Realtime instance configured for Google Gemini.

Args

model
Gemini model identifier (default: "gemini-2.5-flash-native-audio-preview-12-2025")
voice
Voice ID for audio output. Options: 'Puck', 'Charon', 'Kore', 'Fenrir', 'Aoede'
language_code
Language code for speech synthesis (default: "en-US")
temperature
Controls randomness in responses (0.0 to 1.0)
top_p
Nucleus sampling parameter (0.0 to 1.0)
top_k
Limits tokens considered for each generation step
candidate_count
Number of response candidates (default: 1)
max_output_tokens
Maximum tokens in model responses
presence_penalty
Penalizes token presence (-2.0 to 2.0)
frequency_penalty
Penalizes token frequency (-2.0 to 2.0)
response_modalities
Response types ["TEXT", "AUDIO"] (default: ["AUDIO"])
base_url
Custom inference gateway URL

Returns

Configured Realtime instance for Gemini

Instance variables

prop label : str
Expand source code
@property
def label(self) -> str:
    """Descriptive identifier for this instance: plugin namespace, provider, model."""
    parts = ("videosdk.inference.Realtime", self.provider, self.model_id)
    return ".".join(parts)

Get a descriptive label for this Realtime instance.

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Clean up all resources.

    Teardown order matters: stop the listener task first, then close the
    WebSocket, then the HTTP session, and finally defer to the base class,
    which nulls out the audio track and event loop.
    """
    logger.info(f"[InferenceRealtime] Closing (provider={self.provider})")

    self._closing = True

    # Cancel listener task and wait for it to acknowledge the cancellation
    # so it cannot race with the connection teardown below.
    if self._ws_task:
        self._ws_task.cancel()
        try:
            await self._ws_task
        except asyncio.CancelledError:
            pass
        self._ws_task = None

    # Close WebSocket
    await self._cleanup_connection()

    # Close HTTP session
    if self._session and not self._session.closed:
        await self._session.close()
        self._session = None

    await super().aclose()  # Nulls audio_track and loop

    logger.info("[InferenceRealtime] Closed successfully")

Clean up all resources.

async def connect(self) ‑> None
Expand source code
async def connect(self) -> None:
    """Connect to the inference gateway.

    Any existing open WebSocket is torn down first; then a fresh
    connection is established and the background response listener is
    (re)started. Raises whatever the underlying connect raises, after
    emitting an "error" event.
    """
    if self._ws and not self._ws.closed:
        await self._cleanup_connection()

    self._closing = False

    try:
        # Audio output requires an externally-assigned track; warn (but do
        # not fail) if the pipeline has not set one yet.
        if not self.audio_track and "AUDIO" in self.config.response_modalities:
            logger.warning("[InferenceRealtime] audio_track not set — it should be assigned externally by the pipeline before connect().")

        # Connect to WebSocket
        await self._connect_ws()

        # Start listening for responses — at most one listener task at a time.
        if not self._ws_task or self._ws_task.done():
            self._ws_task = asyncio.create_task(
                self._listen_for_responses(), name="inference-realtime-listener"
            )

        logger.info(
            f"[InferenceRealtime] Connected to inference gateway (provider={self.provider})"
        )

    except Exception as e:
        self.emit("error", f"Error connecting to inference gateway: {e}")
        logger.error(f"[InferenceRealtime] Connection error: {e}")
        raise

Connect to the inference gateway.

async def handle_audio_input(self, audio_data: bytes) ‑> None
Expand source code
async def handle_audio_input(self, audio_data: bytes) -> None:
    """Handle incoming audio data from the user.

    Resamples the PCM16 input from ``self.input_sample_rate`` to
    ``self.target_sample_rate`` and forwards it to the gateway as a
    base64-encoded "audio" message, reconnecting lazily if the
    WebSocket has closed.
    """
    if not self._ws or self._closing:
        return

    # Don't feed user audio while a non-interruptible utterance plays.
    if self.current_utterance and not self.current_utterance.is_interruptible:
        return

    if "AUDIO" not in self.config.response_modalities:
        return

    try:
        # Ensure connection and config
        if self._ws.closed:
            await self._connect_ws()
            if not self._ws_task or self._ws_task.done():
                self._ws_task = asyncio.create_task(self._listen_for_responses())

        if not self._config_sent:
            await self._send_config()

        # Resample to the provider's expected rate (e.g. 48kHz -> 24kHz for
        # Gemini); rates come from input_sample_rate / target_sample_rate.
        audio_array = np.frombuffer(audio_data, dtype=np.int16)
        audio_array = signal.resample(
            audio_array,
            int(
                len(audio_array) * self.target_sample_rate / self.input_sample_rate
            ),
        )
        audio_data = audio_array.astype(np.int16).tobytes()

        # Send audio as base64
        audio_message = {
            "type": "audio",
            "data": base64.b64encode(audio_data).decode("utf-8"),
        }

        await self._ws.send_str(json.dumps(audio_message))

    except Exception as e:
        logger.error(f"[InferenceRealtime] Error sending audio: {e}")
        self.emit("error", str(e))

Handle incoming audio data from the user.

async def handle_video_input(self, video_data: av.VideoFrame) ‑> None
Expand source code
async def handle_video_input(self, video_data: av.VideoFrame) -> None:
    """Forward one user video frame to the gateway as a base64 JPEG (vision mode)."""
    if not self._ws or self._closing:
        return

    try:
        # Ignore empty frames.
        if not video_data or not video_data.planes:
            return

        # Throttle: at most one frame every 500 ms. The -inf sentinel makes
        # the very first frame always pass.
        now = time.monotonic()
        previous = getattr(self, "_last_video_frame", float("-inf"))
        if now - previous < 0.5:
            return
        self._last_video_frame = now

        # JPEG-encode and reject suspiciously small payloads.
        jpeg_bytes = encode_image(video_data, DEFAULT_IMAGE_ENCODE_OPTIONS)
        if not jpeg_bytes or len(jpeg_bytes) < 100:
            logger.warning("[InferenceRealtime] Invalid JPEG data generated")
            return

        payload = {
            "type": "video",
            "data": base64.b64encode(jpeg_bytes).decode("utf-8"),
        }
        await self._ws.send_str(json.dumps(payload))

    except Exception as e:
        logger.error(f"[InferenceRealtime] Error sending video: {e}")

Handle incoming video data from the user (vision mode).

async def interrupt(self) ‑> None
Expand source code
async def interrupt(self) -> None:
    """Ask the gateway to cancel the in-progress response and flush local audio."""
    if not self._ws or self._closing:
        return

    utterance = self.current_utterance
    if utterance and not utterance.is_interruptible:
        logger.info(
            "[InferenceRealtime] Utterance not interruptible, skipping interrupt"
        )
        return

    try:
        await self._ws.send_str(json.dumps({"type": "interrupt"}))

        self.emit("agent_speech_ended", {})
        metrics_collector.on_interrupted()

        # Also drop any audio already queued on the local track.
        if self.audio_track and "AUDIO" in self.config.response_modalities:
            self.audio_track.interrupt()

        logger.debug("[InferenceRealtime] Sent interrupt signal")

    except Exception as e:
        logger.error(f"[InferenceRealtime] Interrupt error: {e}")
        self.emit("error", str(e))

Interrupt current response.

async def send_message(self, message: str) ‑> None
Expand source code
async def send_message(self, message: str) -> None:
    """Send greeting-style text; the model is instructed to speak it verbatim."""
    if not self._ws or self._closing:
        logger.warning("[InferenceRealtime] Cannot send message: not connected")
        return

    try:
        if not self._config_sent:
            await self._send_config()

        payload = {
            "type": "text",
            "data": f"Please start the conversation by saying exactly this, without any additional text: '{message}'",
        }
        await self._ws.send_str(json.dumps(payload))
        logger.debug(f"[InferenceRealtime] Sent message: {message[:50]}...")

    except Exception as e:
        logger.error(f"[InferenceRealtime] Error sending message: {e}")
        self.emit("error", str(e))

Send a text message to get audio response.

async def send_message_with_frames(self, message: str, frames: List[av.VideoFrame]) ‑> None
Expand source code
async def send_message_with_frames(
    self, message: str, frames: List[av.VideoFrame]
) -> None:
    """Send a text message with video frames for vision-enabled communication.

    Args:
        message: Text prompt to send alongside the frames.
        frames: Video frames to JPEG-encode and attach; frames that fail
            to encode or look truncated (< 100 bytes) are skipped.
    """
    if not self._ws or self._closing:
        logger.warning(
            "[InferenceRealtime] Cannot send message with frames: not connected"
        )
        # Bug fix: previously fell through here and dereferenced a
        # missing/closing WebSocket; bail out like the other send_* methods.
        return

    try:
        if not self._config_sent:
            await self._send_config()

        # Encode frames as base64 JPEG, dropping invalid ones.
        encoded_frames = []
        for frame in frames:
            try:
                processed_jpeg = encode_image(frame, DEFAULT_IMAGE_ENCODE_OPTIONS)
                if processed_jpeg and len(processed_jpeg) >= 100:
                    encoded_frames.append(
                        base64.b64encode(processed_jpeg).decode("utf-8")
                    )

            except Exception as e:
                logger.error(f"[InferenceRealtime] Error encoding frame: {e}")

        # Send message with frames
        message_with_frames = {
            "type": "text_with_frames",
            "data": {
                "text": message,
                "frames": encoded_frames,
            },
        }

        await self._ws.send_str(json.dumps(message_with_frames))

    except Exception as e:
        logger.error(f"[InferenceRealtime] Error sending message with frames: {e}")
        self.emit("error", str(e))

Send a text message with video frames for vision-enabled communication.

async def send_text_message(self, message: str) ‑> None
Expand source code
async def send_text_message(self, message: str) -> None:
    """Forward raw text to the gateway for text-only interaction."""
    if not self._ws or self._closing:
        logger.warning("[InferenceRealtime] Cannot send text: not connected")
        return

    try:
        if not self._config_sent:
            await self._send_config()

        await self._ws.send_str(
            json.dumps({"type": "text", "data": message})
        )

    except Exception as e:
        logger.error(f"[InferenceRealtime] Error sending text message: {e}")
        self.emit("error", str(e))

Send a text message for text-only communication.

def set_agent(self, agent: Agent) ‑> None
Expand source code
def set_agent(self, agent: Agent) -> None:
    """Adopt the agent's instructions and tool set for this session."""
    tool_list = agent.tools or []
    self._instructions = agent.instructions
    self.tools = tool_list
    self.tools_formatted = self._convert_tools_to_format(tool_list)

Set agent instructions and tools.

class STT (*,
provider: str,
model_id: str,
language: str = 'en-US',
config: Dict[str, Any] | None = None,
enable_streaming: bool = True,
base_url: str | None = None)
Expand source code
class STT(BaseSTT):
    """
    VideoSDK Inference Gateway STT Plugin.

    A lightweight Speech-to-Text client that connects to VideoSDK's Inference Gateway.
    Supports multiple providers (Google, Sarvam, Deepgram) through a unified interface.

    Example:
        # Using factory methods (recommended)
        stt = STT.google(language="en-US")
        stt = STT.sarvam(language="en-IN")

        # Using generic constructor
        stt = STT(provider="google", model_id="chirp_3", config={"language": "en-US"})
    """

    def __init__(
        self,
        *,
        # provider: InferenceProvider,
        provider: str,
        model_id: str,
        language: str = "en-US",
        config: Dict[str, Any] | None = None,
        enable_streaming: bool = True,
        base_url: str | None = None,
    ) -> None:
        """
        Initialize the VideoSDK Inference STT plugin.

        Args:
            provider: STT provider name (e.g., "google", "sarvam", "deepgram")
            model_id: Model identifier for the provider
            language: Language code (default: "en-US")
            config: Provider-specific configuration dictionary
            enable_streaming: Enable streaming transcription (default: True)
            base_url: Custom inference gateway URL (default: production gateway)

        Raises:
            ValueError: If the VIDEOSDK_AUTH_TOKEN environment variable is unset.
        """
        super().__init__()

        # The auth token is mandatory; it is forwarded to the gateway on connect.
        self._videosdk_token = os.getenv("VIDEOSDK_AUTH_TOKEN")
        if not self._videosdk_token:
            raise ValueError(
                "VIDEOSDK_AUTH_TOKEN environment variable must be set for authentication"
            )

        self.provider = provider
        self.model_id = model_id
        self.language = language
        self.config = config or {}
        self.enable_streaming = enable_streaming
        self.base_url = base_url or VIDEOSDK_INFERENCE_URL

        # WebSocket state
        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._ws_task: Optional[asyncio.Task] = None
        self._config_sent: bool = False

        # Speech state tracking
        self._is_speaking: bool = False
        self._last_transcript: str = ""

        # Metrics tracking
        self._stt_start_time: Optional[float] = None
        self._connecting: bool = False

    # ==================== Factory Methods ====================

    @staticmethod
    def google(
        *,
        model_id: str = "chirp_3",
        language: str = "en-US",
        languages: list[str] | None = None,
        interim_results: bool = True,
        punctuate: bool = True,
        location: str = "asia-south1",
        input_sample_rate: int = 48000,
        output_sample_rate: int = 16000,
        enable_streaming: bool = True,
        base_url: str | None = None,
        config: Optional[Dict] = None,
    ) -> "STT":
        """
        Create an STT instance configured for Google Cloud Speech-to-Text.

        Args:
            model_id: Google STT model (default: "chirp_3"). Options: "chirp_3", "latest_long", "latest_short"
            language: Primary language code (default: "en-US")
            languages: List of languages for auto-detection (default: [language])
            interim_results: Return interim transcription results (default: True)
            punctuate: Add punctuation to transcripts (default: True)
            location: Google Cloud region (default: "asia-south1")
            input_sample_rate: Input audio sample rate (default: 48000)
            output_sample_rate: Output sample rate for processing (default: 16000)
            enable_streaming: Enable streaming mode (default: True)
            base_url: Custom inference gateway URL

        Returns:
            Configured STT instance for Google
        """
        # Entries in the caller-supplied `config` dict override the explicit
        # keyword arguments above (it is spread last).
        config = {
            "model": model_id,
            "language": language,
            "languages": languages or [language],
            "input_sample_rate": input_sample_rate,
            "output_sample_rate": output_sample_rate,
            "interim_results": interim_results,
            "punctuate": punctuate,
            "location": location,
            **(config or {}),
        }
        return STT(
            provider="google",
            model_id=model_id,
            language=language,
            config=config,
            enable_streaming=enable_streaming,
            base_url=base_url,
        )

    @staticmethod
    def sarvam(
        *,
        model_id: str = "saarika:v2.5",
        language: str = "en-IN",
        input_sample_rate: int = 48000,
        output_sample_rate: int = 16000,
        enable_streaming: bool = True,
        base_url: str | None = None,
        config: Optional[Dict] = None,
    ) -> "STT":
        """
        Create an STT instance configured for Sarvam AI.

        Args:
            model_id: Sarvam model (default: "saarika:v2.5")
            language: Language code (default: "en-IN"). Supports Indian languages.
            input_sample_rate: Input audio sample rate (default: 48000)
            output_sample_rate: Output sample rate for processing (default: 16000)
            enable_streaming: Enable streaming mode (default: True)
            base_url: Custom inference gateway URL

        Returns:
            Configured STT instance for Sarvam AI
        """
        # Caller-supplied `config` entries override the explicit kwargs.
        config = {
            "model": model_id,
            "language": language,
            "input_sample_rate": input_sample_rate,
            "output_sample_rate": output_sample_rate,
            **(config or {}),
        }
        return STT(
            provider="sarvam",
            model_id=model_id,
            language=language,
            config=config,
            enable_streaming=enable_streaming,
            base_url=base_url,
        )

    @staticmethod
    def deepgram(
        *,
        model_id: str = "nova-2",
        language: str = "en-US",
        input_sample_rate: int = 48000,
        interim_results: bool = True,
        punctuate: bool = True,
        smart_format: bool = True,
        endpointing: int = 50,
        enable_streaming: bool = True,
        base_url: str | None = None,
        config: Optional[Dict] = None,
        eager_eot_threshold: float = 0.6,
        eot_threshold: float = 0.8,
        eot_timeout_ms: int = 7000,
        keyterm: list[str] | None = None,
    ) -> "STT":
        """
        Create an STT instance configured for Deepgram.

        Args:
            model_id: Deepgram model (default: "nova-2")
            language: Language code (default: "en-US")
            input_sample_rate: Input audio sample rate (default: 48000)
            interim_results: Return interim transcription results (default: True)
            punctuate: Add punctuation to transcripts (default: True)
            smart_format: Enable smart formatting (default: True)
            endpointing: Endpointing threshold in ms (default: 50)
            enable_streaming: Enable streaming mode (default: True)
            base_url: Custom inference gateway URL

        Returns:
            Configured STT instance for Deepgram
        """
        # config = config or DeepgramSTTInferenceConfig()

        # default_config = config.model_dump(exclude_none=True)
        # _config = {
        #     "model": model_id,
        #     "language": language,
        #     **default_config,
        # }
        # NOTE(review): "keyterm" is sent even when None (serializes to JSON
        # null) — confirm the gateway tolerates that.
        config = {
            "model": model_id,
            "language": language,
            "input_sample_rate": input_sample_rate,
            "interim_results": interim_results,
            "punctuate": punctuate,
            "smart_format": smart_format,
            "endpointing": endpointing,
            "eager_eot_threshold": eager_eot_threshold,
            "eot_threshold": eot_threshold,
            "eot_timeout_ms": eot_timeout_ms,
            "keyterm": keyterm,
            **(config or {}),
        }
        return STT(
            provider="deepgram",
            model_id=model_id,
            language=language,
            enable_streaming=enable_streaming,
            config=config,
            base_url=base_url,
        )

    # ==================== Core Methods ====================

    async def flush(self) -> None:
        """Signal end-of-speech to the inference server."""
        if not self._ws or self._ws.closed:
            return
        try:
            await self._ws.send_str(json.dumps({"type": "flush"}))
        except Exception as e:
            logger.debug(f"[InferenceSTT] Flush error: {e}")

    async def process_audio(
        self,
        audio_frames: bytes,
        language: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Process audio frames and send to the inference server.

        Lazily opens the WebSocket on first use and sends the provider
        config before any audio.

        Args:
            audio_frames: Raw PCM audio bytes (16-bit, typically 48kHz stereo)
            language: Optional language override
            **kwargs: Additional arguments (unused)
        """
        if not self.enable_streaming:
            logger.warning("Non-streaming mode not yet supported for inference STT")
            return

        try:
            # Drop audio while another call is mid-connect to avoid opening
            # duplicate WebSockets from overlapping invocations.
            if self._connecting:
                return

            if not self._ws or self._ws.closed:
                self._connecting = True
                try:
                    await self._connect_ws()
                    if not self._ws_task or self._ws_task.done():
                        self._ws_task = asyncio.create_task(
                            self._listen_for_responses()
                        )
                finally:
                    self._connecting = False

            if not self._config_sent:
                await self._send_config()

            await self._send_audio(audio_frames)
            # Yield to the event loop so the listener task gets a turn.
            await asyncio.sleep(0)

        except Exception as e:
            logger.error(f"[InferenceSTT] Error in process_audio: {e}")
            self.emit("error", str(e))
            await self._cleanup_connection()

    async def _connect_ws(self) -> None:
        """Establish WebSocket connection to the inference gateway."""
        if not self._session:
            self._session = aiohttp.ClientSession()

        # Build WebSocket URL with query parameters.
        # NOTE(review): the auth token travels in the query string ("secret=");
        # query strings can end up in server/proxy logs — confirm the gateway
        # requires this transport.
        ws_url = (
            f"{self.base_url}/v1/stt"
            f"?provider={self.provider}"
            f"&secret={self._videosdk_token}"
            f"&modelId={self.model_id}"
        )
        try:
            logger.info(
                f"[InferenceSTT] Connecting to {self.base_url} (provider={self.provider})"
            )
            # heartbeat=20: aiohttp sends a WS ping every 20s to keep the
            # connection alive.
            self._ws = await self._session.ws_connect(ws_url, heartbeat=20)
            self._config_sent = False
            logger.info(f"[InferenceSTT] Connected successfully")
        except Exception as e:
            logger.error(f"[InferenceSTT] Connection failed: {e}")
            raise

    async def _send_config(self) -> None:
        """Send configuration message to the inference server."""
        if not self._ws:
            return

        config_message = {"type": "config", "data": self.config}
        try:
            await self._ws.send_str(json.dumps(config_message))
            self._config_sent = True
            logger.info(f"[InferenceSTT] Config sent: {self.config}")
        except Exception as e:
            logger.error(f"[InferenceSTT] Failed to send config: {e}")
            raise

    async def _send_audio(self, audio_bytes: bytes) -> None:
        """Send audio data to the inference server."""
        if not self._ws:
            return

        # Track STT start time for metrics (reset when a final transcript lands)
        if self._stt_start_time is None:
            self._stt_start_time = time.perf_counter()

        # Encode audio as base64 for JSON transmission
        audio_message = {
            "type": "audio",
            "data": base64.b64encode(audio_bytes).decode("utf-8"),
        }
        try:
            await self._ws.send_str(json.dumps(audio_message))
        except Exception as e:
            logger.error(f"[InferenceSTT] Failed to send audio: {e}")
            raise

    async def _listen_for_responses(self) -> None:
        """Background task to listen for WebSocket responses from the server."""
        if not self._ws:
            return

        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    await self._handle_message(msg.data)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(
                        f"[InferenceSTT] WebSocket error: {self._ws.exception()}"
                    )
                    self.emit("error", f"WebSocket error: {self._ws.exception()}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("[InferenceSTT] WebSocket closed by server")
                    break

        except asyncio.CancelledError:
            logger.debug("[InferenceSTT] WebSocket listener cancelled")
        except Exception as e:
            logger.error(f"[InferenceSTT] Error in WebSocket listener: {e}")
            self.emit("error", str(e))
        finally:
            # Always tear down the connection when the listener exits, so the
            # next process_audio() call reconnects cleanly.
            await self._cleanup_connection()

    async def _handle_message(self, raw_message: str) -> None:
        """
        Handle incoming messages from the inference server.

        Args:
            raw_message: Raw JSON message string from server
        """
        try:
            data = json.loads(raw_message)
            msg_type = data.get("type")

            if msg_type == "event":
                await self._handle_event(data.get("data", {}))
            elif msg_type == "error":
                error_msg = data.get("data", {}).get("error") or data.get(
                    "message", "Unknown error"
                )
                logger.error(f"[InferenceSTT] Server error: {error_msg}")
                self.emit("error", error_msg)

        except json.JSONDecodeError as e:
            logger.error(f"[InferenceSTT] Failed to parse message: {e}")
        except Exception as e:
            logger.error(f"[InferenceSTT] Error handling message: {e}")

    async def _handle_event(self, event_data: Dict[str, Any]) -> None:
        """
        Handle event messages from the inference server.

        Args:
            event_data: Event data dictionary
        """
        event_type = event_data.get("eventType")

        if event_type == "TRANSCRIPT":
            text = event_data.get("text", "")
            language = event_data.get("language", self.language)
            is_final = event_data.get("is_final", True)
            confidence = event_data.get("confidence", 1.0)

            # A short noise burst (e.g. a cough) can trigger end-of-speech and
            # forward an empty transcript to the LLM, which downstream TTS
            # (e.g. Sarvam) cannot process. A content-length guard was
            # considered here but is currently disabled:
            # if not self._has_enough_content(text):
            #     logger.debug(f"[InferenceSTT] Invalid transcript skipped: '{text}'")
            #     return
            if text.strip():
                logger.info(f"[STT] {text} | Final: {is_final}")
            self._last_transcript = text.strip()

            response = STTResponse(
                event_type=(
                    SpeechEventType.FINAL if is_final else SpeechEventType.INTERIM
                ),
                data=SpeechData(
                    text=text.strip(),
                    language=language,
                    confidence=confidence,
                ),
                metadata={
                    "provider": self.provider,
                    "model": self.model_id,
                },
            )

            if self._transcript_callback:
                await self._transcript_callback(response)

            # Final transcript closes out the current metrics window.
            if is_final:
                self._stt_start_time = None

        elif event_type == "START_SPEECH":
            if not self._is_speaking:
                self._is_speaking = True
                global_event_emitter.emit("speech_started")
                logger.debug("[InferenceSTT] Speech started")

        elif event_type == "END_SPEECH":
            if self._is_speaking:
                self._is_speaking = False
                # Only announce speech_stopped if we actually transcribed
                # something; otherwise downstream would react to silence.
                if self._last_transcript:
                    global_event_emitter.emit("speech_stopped")
                    logger.debug("[InferenceSTT] Speech ended (transcript present)")
                else:
                    logger.debug(
                        "[InferenceSTT] Speech ended but no transcript — suppressing speech_stopped"
                    )
        elif event_type == "DENOISE_AUDIO":
            # Denoise events belong to the Denoise plugin, not STT.
            logger.warning("[STT] Received DENOISE_AUDIO event — ignoring")
            return

    async def _cleanup_connection(self) -> None:
        """Clean up WebSocket connection resources."""
        if self._ws and not self._ws.closed:
            try:
                await self._ws.close()
            except Exception:
                pass
        self._ws = None
        self._config_sent = False

    async def aclose(self) -> None:
        """Clean up all resources: listener task, WebSocket, HTTP session."""
        logger.info(f"[InferenceSTT] Closing STT (provider={self.provider})")

        # Cancel listener task
        if self._ws_task:
            self._ws_task.cancel()
            try:
                await self._ws_task
            except asyncio.CancelledError:
                pass
            self._ws_task = None

        # Close WebSocket
        await self._cleanup_connection()

        # Close HTTP session
        if self._session and not self._session.closed:
            await self._session.close()
            self._session = None

        # Call parent cleanup
        await super().aclose()
        logger.info(f"[InferenceSTT] Closed successfully")

    # ==================== Properties ====================

    @property
    def label(self) -> str:
        """Get a descriptive label for this STT instance."""
        return f"videosdk.inference.STT.{self.provider}.{self.model_id}"

VideoSDK Inference Gateway STT Plugin.

A lightweight Speech-to-Text client that connects to VideoSDK's Inference Gateway. Supports multiple providers (Google, Sarvam, Deepgram) through a unified interface.

Example

Using factory methods (recommended)

stt = STT.google(language="en-US") stt = STT.sarvam(language="en-IN")

Using generic constructor

stt = STT(provider="google", model_id="chirp_3", config={"language": "en-US"})

Initialize the VideoSDK Inference STT plugin.

Args

provider
STT provider name (e.g., "google", "sarvam", "deepgram")
model_id
Model identifier for the provider
language
Language code (default: "en-US")
config
Provider-specific configuration dictionary
enable_streaming
Enable streaming transcription (default: True)
base_url
Custom inference gateway URL (default: production gateway)

Ancestors

  • videosdk.agents.stt.stt.STT
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Static methods

def deepgram(*,
model_id: str = 'nova-2',
language: str = 'en-US',
input_sample_rate: int = 48000,
interim_results: bool = True,
punctuate: bool = True,
smart_format: bool = True,
endpointing: int = 50,
enable_streaming: bool = True,
base_url: str | None = None,
config: Optional[Dict] = None,
eager_eot_threshold: float = 0.6,
eot_threshold: float = 0.8,
eot_timeout_ms: int = 7000,
keyterm: list[str] | None = None) ‑> STT
Expand source code
@staticmethod
def deepgram(
    *,
    model_id: str = "nova-2",
    language: str = "en-US",
    input_sample_rate: int = 48000,
    interim_results: bool = True,
    punctuate: bool = True,
    smart_format: bool = True,
    endpointing: int = 50,
    enable_streaming: bool = True,
    base_url: str | None = None,
    config: Optional[Dict] = None,
    eager_eot_threshold: float = 0.6,
    eot_threshold: float = 0.8,
    eot_timeout_ms: int = 7000,
    keyterm: list[str] | None = None,
) -> "STT":
    """
    Create an STT instance configured for Deepgram.

    Args:
        model_id: Deepgram model (default: "nova-2")
        language: Language code (default: "en-US")
        input_sample_rate: Input audio sample rate (default: 48000)
        interim_results: Return interim transcription results (default: True)
        punctuate: Add punctuation to transcripts (default: True)
        smart_format: Enable smart formatting (default: True)
        endpointing: Endpointing threshold in ms (default: 50)
        enable_streaming: Enable streaming mode (default: True)
        base_url: Custom inference gateway URL
        config: Extra provider options; entries here override the keyword arguments
        eager_eot_threshold: Eager end-of-turn confidence threshold (default: 0.6)
        eot_threshold: End-of-turn confidence threshold (default: 0.8)
        eot_timeout_ms: End-of-turn timeout in milliseconds (default: 7000)
        keyterm: Optional key terms to boost; omitted from the config when
            not provided

    Returns:
        Configured STT instance for Deepgram
    """
    merged_config = {
        "model": model_id,
        "language": language,
        "input_sample_rate": input_sample_rate,
        "interim_results": interim_results,
        "punctuate": punctuate,
        "smart_format": smart_format,
        "endpointing": endpointing,
        "eager_eot_threshold": eager_eot_threshold,
        "eot_threshold": eot_threshold,
        "eot_timeout_ms": eot_timeout_ms,
        # Fix: only include "keyterm" when supplied — a None value would
        # serialize to JSON null in the config message sent to the gateway.
        **({"keyterm": keyterm} if keyterm is not None else {}),
        **(config or {}),
    }
    return STT(
        provider="deepgram",
        model_id=model_id,
        language=language,
        enable_streaming=enable_streaming,
        config=merged_config,
        base_url=base_url,
    )

Create an STT instance configured for Deepgram.

Args

model_id
Deepgram model (default: "nova-2")
language
Language code (default: "en-US")
input_sample_rate
Input audio sample rate (default: 48000)
interim_results
Return interim transcription results (default: True)
punctuate
Add punctuation to transcripts (default: True)
smart_format
Enable smart formatting (default: True)
endpointing
Endpointing threshold in ms (default: 50)
enable_streaming
Enable streaming mode (default: True)
base_url
Custom inference gateway URL

Returns

Configured STT instance for Deepgram

def google(*,
model_id: str = 'chirp_3',
language: str = 'en-US',
languages: list[str] | None = None,
interim_results: bool = True,
punctuate: bool = True,
location: str = 'asia-south1',
input_sample_rate: int = 48000,
output_sample_rate: int = 16000,
enable_streaming: bool = True,
base_url: str | None = None,
config: Optional[Dict] = None) ‑> STT
Expand source code
@staticmethod
def google(
    *,
    model_id: str = "chirp_3",
    language: str = "en-US",
    languages: list[str] | None = None,
    interim_results: bool = True,
    punctuate: bool = True,
    location: str = "asia-south1",
    input_sample_rate: int = 48000,
    output_sample_rate: int = 16000,
    enable_streaming: bool = True,
    base_url: str | None = None,
    config: Optional[Dict] = None,
) -> "STT":
    """
    Build an STT instance backed by Google Cloud Speech-to-Text.

    Args:
        model_id: Google STT model (default: "chirp_3"). Options: "chirp_3", "latest_long", "latest_short"
        language: Primary language code (default: "en-US")
        languages: Language candidates for auto-detection (defaults to [language])
        interim_results: Emit interim transcription results (default: True)
        punctuate: Add punctuation to transcripts (default: True)
        location: Google Cloud region (default: "asia-south1")
        input_sample_rate: Incoming audio sample rate (default: 48000)
        output_sample_rate: Sample rate used for processing (default: 16000)
        enable_streaming: Enable streaming mode (default: True)
        base_url: Custom inference gateway URL
        config: Extra provider options; overrides the defaults above on key collision

    Returns:
        Configured STT instance for Google
    """
    resolved: Dict = {
        "model": model_id,
        "language": language,
        "languages": languages or [language],
        "input_sample_rate": input_sample_rate,
        "output_sample_rate": output_sample_rate,
        "interim_results": interim_results,
        "punctuate": punctuate,
        "location": location,
    }
    # Caller-supplied options win over the defaults built above.
    if config:
        resolved.update(config)
    return STT(
        provider="google",
        model_id=model_id,
        language=language,
        config=resolved,
        enable_streaming=enable_streaming,
        base_url=base_url,
    )

Create an STT instance configured for Google Cloud Speech-to-Text.

Args

model_id
Google STT model (default: "chirp_3"). Options: "chirp_3", "latest_long", "latest_short"
language
Primary language code (default: "en-US")
languages
List of languages for auto-detection (default: [language])
interim_results
Return interim transcription results (default: True)
punctuate
Add punctuation to transcripts (default: True)
location
Google Cloud region (default: "asia-south1")
input_sample_rate
Input audio sample rate (default: 48000)
output_sample_rate
Output sample rate for processing (default: 16000)
enable_streaming
Enable streaming mode (default: True)
base_url
Custom inference gateway URL

Returns

Configured STT instance for Google

def sarvam(*,
model_id: str = 'saarika:v2.5',
language: str = 'en-IN',
input_sample_rate: int = 48000,
output_sample_rate: int = 16000,
enable_streaming: bool = True,
base_url: str | None = None,
config: Optional[Dict] = None) ‑> STT
Expand source code
@staticmethod
def sarvam(
    *,
    model_id: str = "saarika:v2.5",
    language: str = "en-IN",
    input_sample_rate: int = 48000,
    output_sample_rate: int = 16000,
    enable_streaming: bool = True,
    base_url: str | None = None,
    config: Optional[Dict] = None,
) -> "STT":
    """
    Build an STT instance backed by Sarvam AI.

    Args:
        model_id: Sarvam model (default: "saarika:v2.5")
        language: Language code (default: "en-IN"). Supports Indian languages.
        input_sample_rate: Incoming audio sample rate (default: 48000)
        output_sample_rate: Sample rate used for processing (default: 16000)
        enable_streaming: Enable streaming mode (default: True)
        base_url: Custom inference gateway URL
        config: Extra provider options; overrides the defaults above on key collision

    Returns:
        Configured STT instance for Sarvam AI
    """
    resolved: Dict = {
        "model": model_id,
        "language": language,
        "input_sample_rate": input_sample_rate,
        "output_sample_rate": output_sample_rate,
    }
    # Caller-supplied options win over the defaults built above.
    if config:
        resolved.update(config)
    return STT(
        provider="sarvam",
        model_id=model_id,
        language=language,
        config=resolved,
        enable_streaming=enable_streaming,
        base_url=base_url,
    )

Create an STT instance configured for Sarvam AI.

Args

model_id
Sarvam model (default: "saarika:v2.5")
language
Language code (default: "en-IN"). Supports Indian languages.
input_sample_rate
Input audio sample rate (default: 48000)
output_sample_rate
Output sample rate for processing (default: 16000)
enable_streaming
Enable streaming mode (default: True)
base_url
Custom inference gateway URL

Returns

Configured STT instance for Sarvam AI

Instance variables

prop label : str
Expand source code
@property
def label(self) -> str:
    """Descriptive identifier for this STT instance: plugin, provider and model."""
    return "videosdk.inference.STT.{}.{}".format(self.provider, self.model_id)

Get a descriptive label for this STT instance.

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """
    Clean up all resources held by this STT instance.

    Cancels the WebSocket listener task, closes the WebSocket connection,
    closes the shared HTTP session, then delegates to the base class for
    any remaining cleanup.
    """
    logger.info(f"[InferenceSTT] Closing STT (provider={self.provider})")

    # Cancel the background task that listens for server responses.
    if self._ws_task:
        self._ws_task.cancel()
        try:
            await self._ws_task
        except asyncio.CancelledError:
            # Expected: the task was cancelled on purpose.
            pass
        self._ws_task = None

    # Close WebSocket
    await self._cleanup_connection()

    # Close HTTP session
    if self._session and not self._session.closed:
        await self._session.close()
        self._session = None

    # Call parent cleanup
    await super().aclose()
    # FIX: dropped a stray f-string prefix — the literal has no placeholders.
    logger.info("[InferenceSTT] Closed successfully")

Clean up all resources.

async def flush(self) ‑> None
Expand source code
async def flush(self) -> None:
    """Notify the inference server that the current utterance has ended."""
    ws = self._ws
    if not ws or ws.closed:
        # Nothing to flush without a live connection.
        return
    try:
        await ws.send_str(json.dumps({"type": "flush"}))
    except Exception as e:
        logger.debug(f"[InferenceSTT] Flush error: {e}")

Signal end-of-speech to the inference server.

async def process_audio(self, audio_frames: bytes, language: Optional[str] = None, **kwargs: Any) ‑> None
Expand source code
async def process_audio(
    self,
    audio_frames: bytes,
    language: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """
    Process audio frames and send to the inference server.

    Lazily opens the WebSocket connection (and its listener task) on the
    first call, sends the provider config once per connection, then streams
    the raw audio bytes. On any failure the connection is torn down so the
    next call reconnects from scratch.

    Args:
        audio_frames: Raw PCM audio bytes (16-bit, typically 48kHz stereo)
        language: Optional language override (currently unused by this method)
        **kwargs: Additional arguments (unused)
    """
    # logger.info(f"[STT DEBUG] process_audio called | bytes={len(audio_frames)}")
    if not self.enable_streaming:
        logger.warning("Non-streaming mode not yet supported for inference STT")
        return

    try:
        # Frames arriving while a connect is already in flight are dropped
        # rather than queued — presumably acceptable for live audio; verify.
        if self._connecting:
            return

        if not self._ws or self._ws.closed:
            self._connecting = True
            try:
                await self._connect_ws()
                # Start (or restart) the response listener alongside the socket.
                if not self._ws_task or self._ws_task.done():
                    self._ws_task = asyncio.create_task(
                        self._listen_for_responses()
                    )
            finally:
                # Always clear the flag, even if the connect attempt raised.
                self._connecting = False

        # Provider config must be sent once per connection, before any audio.
        if not self._config_sent:
            await self._send_config()

        await self._send_audio(audio_frames)
        # Yield to the event loop so send/receive tasks can make progress.
        await asyncio.sleep(0)

    except Exception as e:
        logger.error(f"[InferenceSTT] Error in process_audio: {e}")
        self.emit("error", str(e))
        # Tear down so the next frame triggers a clean reconnect.
        await self._cleanup_connection()

Process audio frames and send to the inference server.

Args

audio_frames
Raw PCM audio bytes (16-bit, typically 48kHz stereo)
language
Optional language override
**kwargs
Additional arguments (unused)
class TTS (*,
provider: str,
model_id: str,
voice_id: str | None = None,
language: str = 'en-US',
config: Dict[str, Any] | None = None,
enable_streaming: bool = True,
sample_rate: int = 24000,
base_url: str | None = None)
Expand source code
class TTS(BaseTTS):
    """
    VideoSDK Inference Gateway TTS Plugin.

    A lightweight text-to-speech client that streams synthesis requests to
    VideoSDK's Inference Gateway over a WebSocket and forwards the returned
    audio to the attached audio track. Supports Google, Sarvam AI, Cartesia
    and Deepgram through a unified interface (see the factory methods).
    """

    def __init__(
        self,
        *,
        provider: str,
        model_id: str,
        voice_id: str | None = None,
        language: str = "en-US",
        config: Dict[str, Any] | None = None,
        enable_streaming: bool = True,
        sample_rate: int = DEFAULT_SAMPLE_RATE,
        base_url: str | None = None,
    ) -> None:
        """
        Initialize the VideoSDK Inference TTS plugin.

        Args:
            provider: TTS provider name (e.g., "google", "sarvamai", "deepgram")
            model_id: Model identifier for the provider
            voice_id: Voice identifier
            language: Language code (default: "en-US")
            config: Provider-specific configuration dictionary
            enable_streaming: Enable streaming synthesis (default: True)
            sample_rate: Audio sample rate (default: 24000)
            base_url: Custom inference gateway URL (default: production gateway)

        Raises:
            ValueError: If the VIDEOSDK_AUTH_TOKEN environment variable is unset.
        """
        super().__init__(sample_rate=sample_rate, num_channels=DEFAULT_CHANNELS)

        self._videosdk_token = os.getenv("VIDEOSDK_AUTH_TOKEN")
        if not self._videosdk_token:
            raise ValueError("VIDEOSDK_AUTH_TOKEN environment variable must be set")

        self.provider = provider
        self.model_id = model_id
        self.voice_id = voice_id
        self.language = language
        self.config = config or {}
        self.enable_streaming = enable_streaming
        self.base_url = base_url or VIDEOSDK_INFERENCE_URL

        # WebSocket state
        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._recv_task: Optional[asyncio.Task] = None
        self._keepalive_task: Optional[asyncio.Task] = None
        self._config_sent: bool = False
        self._connection_lock = asyncio.Lock()

        # Synthesis state
        self._interrupted: bool = False
        self._first_chunk_sent: bool = False
        self._has_error: bool = False

        # Monotonically increasing id per synthesize() call; audio belonging
        # to a synthesis at or below _interrupted_at_id is discarded.
        self._synthesis_id: int = 0
        self._interrupted_at_id: int = -1

    # ==================== Factory Methods ====================

    @staticmethod
    def google(
        *,
        model_id="Chirp3-HD",
        voice_id="Achernar",
        language="en-US",
        speed=1.0,
        pitch=0.0,
        sample_rate=24000,
        enable_streaming=True,
        base_url=None,
        config=None,
    ) -> "TTS":
        """
        Create a TTS instance configured for Google Cloud Text-to-Speech.

        Args:
            model_id: Google TTS model (default: "Chirp3-HD")
            voice_id: Voice name (default: "Achernar")
            language: Language code (default: "en-US")
            speed: Speech speed (default: 1.0)
            pitch: Voice pitch (default: 0.0)
            sample_rate: Audio sample rate (default: 24000)
            enable_streaming: Enable streaming mode (default: True)
            base_url: Custom inference gateway URL

        Returns:
            Configured TTS instance for Google
        """
        # Build voice_name like: en-US-Chirp3-HD-Achernar
        voice_name = f"{language}-{model_id}-{voice_id}"

        config = {
            "voice_name": voice_name,
            "language_code": language,
            "speed": speed,
            "pitch": pitch,
            "sample_rate": sample_rate,
            "model_id": model_id,
            **(config or {}),
        }
        return TTS(
            provider="google",
            model_id=model_id,
            voice_id=voice_id,
            language=language,
            config=config,
            enable_streaming=enable_streaming,
            sample_rate=sample_rate,
            base_url=base_url,
        )

    @staticmethod
    def sarvam(
        *,
        model_id="bulbul:v2",
        speaker="anushka",
        language="en-IN",
        sample_rate=24000,
        enable_streaming=True,
        base_url=None,
        config=None,
    ) -> "TTS":
        """
        Create a TTS instance configured for Sarvam AI.

        Args:
            model_id: Sarvam model (default: "bulbul:v2")
            speaker: Speaker voice (default: "anushka")
            language: Language code (default: "en-IN")
            sample_rate: Audio sample rate (default: 24000)
            enable_streaming: Enable streaming mode (default: True)
            base_url: Custom inference gateway URL

        Returns:
            Configured TTS instance for Sarvam AI
        """
        config = {
            "model": model_id,
            "language": language,
            "speaker": speaker,
            "sample_rate": sample_rate,
            **(config or {}),
        }
        return TTS(
            provider="sarvamai",
            model_id=model_id,
            voice_id=speaker,
            language=language,
            config=config,
            enable_streaming=enable_streaming,
            sample_rate=sample_rate,
            base_url=base_url,
        )

    @staticmethod
    def cartesia(
        *,
        model_id="sonic-2",
        voice_id="faf0731e-dfb9-4cfc-8119-259a79b27e12",
        language="en",
        sample_rate=24000,
        enable_streaming=True,
        base_url=None,
        config=None,
    ) -> "TTS":
        """
        Create a TTS instance configured for Cartesia.

        Args:
            model_id: Cartesia model (default: "sonic-2")
            voice_id: Voice ID (string) or voice embedding (list of floats)
                     (default: "faf0731e-dfb9-4cfc-8119-259a79b27e12")
            language: Language code (default: "en")
            sample_rate: Audio sample rate (default: 24000)
            enable_streaming: Enable streaming mode (default: True)
            base_url: Custom inference gateway URL

        Returns:
            Configured TTS instance for Cartesia
        """
        config = {
            "model": model_id,
            "language": language,
            "voice": voice_id,
            "sample_rate": sample_rate,
            **(config or {}),
        }
        return TTS(
            provider="cartesia",
            model_id=model_id,
            voice_id=str(voice_id),
            language=language,
            config=config,
            enable_streaming=enable_streaming,
            sample_rate=sample_rate,
            base_url=base_url,
        )

    @staticmethod
    def deepgram(
        *,
        model_id="aura-2",
        voice_id="amalthea",
        language="en",
        encoding="linear16",
        sample_rate=24000,
        container="none",
        bit_rate=None,
        enable_streaming=True,
        base_url=None,
        config=None,
    ) -> "TTS":
        """
        Create a TTS instance configured for Deepgram Aura.

        Args:
            model_id: Deepgram Aura model (default: "aura-2")
            voice_id: Aura voice (default: "amalthea")
            language: Language code (default: "en")
            encoding: Audio encoding format (default: "linear16")
            sample_rate: Audio sample rate in Hz (default: 24000)
            container: Container format (default: "none" for raw audio)
            bit_rate: Bitrate in bps for compressed formats (optional)
            enable_streaming: Enable streaming mode (default: True)
            base_url: Custom inference gateway URL (optional)

        Returns:
            Configured TTS instance for Deepgram
        """
        config = {
            "model": model_id,
            "encoding": encoding,
            "sample_rate": sample_rate,
            "container": container,
            "voice_id": voice_id,
            "language": language,
            **(config or {}),
        }
        if bit_rate is not None:
            config["bit_rate"] = bit_rate
        return TTS(
            provider="deepgram",
            model_id=model_id,
            voice_id=voice_id,
            # FIX: was hard-coded to "en", silently ignoring the `language`
            # parameter (the config dict already received the real value).
            language=language,
            config=config,
            enable_streaming=enable_streaming,
            sample_rate=sample_rate,
            base_url=base_url,
        )

    # ==================== Core ====================

    def reset_first_audio_tracking(self) -> None:
        """Reset the first-audio flag so the next chunk fires the first-audio callback."""
        self._first_chunk_sent = False

    async def warmup(self) -> None:
        """
        Pre-warm the WebSocket connection before the first synthesis request.
        Call this right after session start to eliminate cold-start latency
        (~3-4s) on the first user turn.
        """
        logger.info(f"[InferenceTTS] Warming up connection (provider={self.provider})")
        try:
            await self._ensure_connection()
            logger.info("[InferenceTTS] Warmup complete — connection ready")
        except Exception as e:
            # Non-fatal: synthesize() will reconnect on demand.
            logger.warning(f"[InferenceTTS] Warmup failed (non-fatal): {e}")

    async def synthesize(
        self, text: AsyncIterator[str] | str, voice_id=None, **kwargs
    ) -> None:
        """
        Synthesize speech from text.

        Retries once on ConnectionError when the full text is available
        (plain string input); streamed input cannot be replayed.

        Args:
            text: Text to synthesize (string or async iterator of strings)
            voice_id: Optional voice override
            **kwargs: Additional arguments
        """
        if not self.audio_track or not self.loop:
            logger.error("[InferenceTTS] Audio track or event loop not initialized")
            return

        self._synthesis_id += 1
        current_id = self._synthesis_id

        self._interrupted = False
        self.reset_first_audio_tracking()
        logger.debug(f"[InferenceTTS] New synthesis started, id={current_id}")

        if isinstance(text, str):
            if not text.strip():
                logger.debug("[InferenceTTS] Skipping synthesis — empty text")
                return
            if not _has_enough_content(text, self.provider):
                logger.warning(
                    f"[InferenceTTS] Skipping — text too short for {self.provider}: '{text}'"
                )
                return

        text_for_retry = text if isinstance(text, str) else None

        for attempt in range(2):
            # Abort if a newer synthesis has already started
            if self._synthesis_id != current_id:
                logger.debug("[InferenceTTS] Synthesis superseded — aborting")
                return

            try:
                await self._ensure_connection()

                if self._synthesis_id != current_id:
                    logger.debug(
                        "[InferenceTTS] Synthesis superseded after connect — aborting"
                    )
                    return

                if isinstance(text, str):
                    await self._send_text(text, current_id)
                else:
                    await self._send_text_stream(text, current_id)
                return

            except ConnectionError as e:
                # Only plain-string input can be retried; an async iterator
                # has already been partially consumed.
                if attempt == 0 and text_for_retry is not None:
                    logger.warning(
                        f"[InferenceTTS] Connection lost mid-synthesis, retrying... ({e})"
                    )
                    self._has_error = True
                    await asyncio.sleep(0.05)
                    continue
                logger.error(f"[InferenceTTS] Synthesis failed after retry: {e}")
                self.emit("error", str(e))
                return

            except Exception as e:
                logger.error(f"[InferenceTTS] Synthesis error: {e}")
                self.emit("error", str(e))
                return

    # ==================== Connection ====================

    def _is_connection_alive(self) -> bool:
        """True when the socket is open, the receive loop runs, and no provider error is pending."""
        return (
            self._ws is not None
            and not self._ws.closed
            and self._recv_task is not None
            and not self._recv_task.done()
            and not self._has_error
        )

    async def _ensure_connection(self) -> None:
        """Ensure WebSocket connection is established."""
        async with self._connection_lock:
            if self._is_connection_alive():
                logger.info("[InferenceTTS] Connection alive, reusing")
                return
            logger.info("[InferenceTTS] Connection not alive — reconnecting...")
            await self._teardown_connection()
            await self._connect_ws()
            self._recv_task = asyncio.create_task(self._recv_loop())
            await self._send_config()
            # FIX 1: Reduced from 100ms to 10ms — just enough for the server
            # to process the config frame before the first text arrives.
            await asyncio.sleep(0.01)
            self._has_error = False

    async def _connect_ws(self) -> None:
        """Establish WebSocket connection to the inference gateway."""
        if not self._session:
            self._session = aiohttp.ClientSession()

        ws_url = (
            f"{self.base_url}/v1/tts"
            f"?provider={self.provider}"
            f"&secret={self._videosdk_token}"
            f"&modelId={self.model_id}"
        )
        if self.voice_id:
            ws_url += f"&voiceId={self.voice_id}"

        logger.info(
            f"[InferenceTTS] Connecting to {self.base_url} (provider={self.provider})"
        )
        self._ws = await self._session.ws_connect(ws_url, heartbeat=20)
        self._config_sent = False
        logger.info("[InferenceTTS] Connected successfully")

        if self._keepalive_task and not self._keepalive_task.done():
            self._keepalive_task.cancel()
        self._keepalive_task = asyncio.create_task(self._keepalive())

    async def _keepalive(self) -> None:
        """Ping every 15s — Sarvam closes idle connections at ~60s."""
        while True:
            await asyncio.sleep(15)
            if self._ws and not self._ws.closed:
                try:
                    await self._ws.ping()
                except Exception:
                    logger.warning("[InferenceTTS] Keepalive ping failed")
                    break
            else:
                break

    async def _teardown_connection(self) -> None:
        """Cancel background tasks and close the WebSocket (idempotent)."""
        if self._keepalive_task and not self._keepalive_task.done():
            self._keepalive_task.cancel()
            try:
                await self._keepalive_task
            except asyncio.CancelledError:
                pass
            self._keepalive_task = None

        recv_task = self._recv_task
        self._recv_task = None
        if recv_task and not recv_task.done():
            recv_task.cancel()
            try:
                await recv_task
            except asyncio.CancelledError:
                pass

        ws = self._ws
        self._ws = None
        self._config_sent = False
        if ws and not ws.closed:
            try:
                await ws.close()
            except Exception:
                pass

    # ==================== Send ====================

    async def _send_config(self) -> None:
        """Send the provider config frame; must precede any text on a new connection."""
        if not self._ws or self._ws.closed:
            raise ConnectionError("WebSocket not connected")
        await self._ws.send_str(json.dumps({"type": "config", "data": self.config}))
        self._config_sent = True
        logger.info(f"[InferenceTTS] Config sent: {self.config}")

    async def _send_text(self, text: str, synthesis_id: int) -> None:
        """Send one text chunk followed by a flush, unless superseded/interrupted."""
        text = text.strip()
        if not text:
            return
        # Guard: don't send if superseded or interrupted
        if self._synthesis_id != synthesis_id or self._interrupted:
            return
        if not _has_enough_content(text, self.provider):
            logger.warning(
                f"[InferenceTTS] Dropping short chunk for {self.provider}: '{text}'"
            )
            return
        if not self._ws or self._ws.closed:
            raise ConnectionError("WebSocket closed before text could be sent")
        await self._ws.send_str(json.dumps({"type": "text", "data": text}))
        await self._ws.send_str(json.dumps({"type": "flush"}))
        logger.debug(f"[InferenceTTS] Sent text + flush: '{text[:80]}'")

    async def _send_text_stream(
        self, text_iterator: AsyncIterator[str], synthesis_id: int
    ) -> None:
        """Buffer a token stream into small word batches and forward them via _send_text."""
        buffer = []
        # FIX 2: Reduced from 3 words / 300ms to 2 words / 100ms.
        # This gets the first chunk to TTS ~200ms sooner on every streamed turn.
        MIN_WORDS = 2
        MAX_DELAY = 0.1
        last_send_time = asyncio.get_event_loop().time()
        first_chunk_sent = False

        try:
            async for chunk in text_iterator:
                # Stop immediately if interrupted or superseded
                if self._interrupted or self._synthesis_id != synthesis_id:
                    logger.debug(
                        "[InferenceTTS] Stream aborted — interrupted or superseded"
                    )
                    return

                if not chunk or not chunk.strip():
                    continue
                buffer.extend(chunk.split())
                now = asyncio.get_event_loop().time()

                combined = " ".join(buffer).strip()

                if not first_chunk_sent:
                    has_content = _has_enough_content(combined, self.provider)
                    should_send = has_content and (
                        len(buffer) >= MIN_WORDS or (now - last_send_time > MAX_DELAY)
                    )
                else:
                    should_send = len(buffer) >= MIN_WORDS or (
                        now - last_send_time > MAX_DELAY
                    )

                if should_send and combined:
                    await self._send_text(combined, synthesis_id)
                    first_chunk_sent = True
                    buffer.clear()
                    last_send_time = now

            # Flush remaining buffer only if still active
            if buffer and not self._interrupted and self._synthesis_id == synthesis_id:
                combined = " ".join(buffer).strip()
                if combined:
                    await self._send_text(combined, synthesis_id)
            # FIX 3: Removed the redundant trailing flush here.
            # _send_text already sends a flush after every chunk, so this
            # was causing a duplicate flush on the final chunk every time.

        except Exception as e:
            logger.error(f"[InferenceTTS] Stream send error: {e}")
            raise

    # ==================== Receive ====================

    async def _recv_loop(self) -> None:
        """Drain incoming frames: JSON control messages and raw binary audio."""
        logger.debug("[InferenceTTS] Receive loop started")
        try:
            while self._ws and not self._ws.closed:
                msg = await self._ws.receive()
                if msg.type == aiohttp.WSMsgType.TEXT:
                    await self._handle_message(msg.data)
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    if (
                        not self._interrupted
                        and self._synthesis_id > self._interrupted_at_id
                        and self.audio_track
                    ):
                        await self.audio_track.add_new_bytes(msg.data)
                    else:
                        logger.debug(
                            "[InferenceTTS] Discarding stale binary audio chunk"
                        )
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(
                        f"[InferenceTTS] WebSocket error: {self._ws.exception()}"
                    )
                    break
                elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
                    logger.info("[InferenceTTS] WebSocket closed by server")
                    break
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"[InferenceTTS] Receive loop error: {e}")
        finally:
            logger.info("[InferenceTTS] Receive loop terminated")
            self._ws = None
            self._config_sent = False

    async def _handle_message(self, raw_message: str) -> None:
        """Handle incoming messages from the inference server."""
        try:
            if not raw_message or not raw_message.strip():
                return

            try:
                data = json.loads(raw_message)
            except json.JSONDecodeError:
                logger.debug(
                    f"[InferenceTTS] Received non-JSON message: {raw_message[:100]}"
                )
                if "success" in raw_message.lower() or "ok" in raw_message.lower():
                    logger.debug("[InferenceTTS] Received acknowledgment")
                    return
                logger.warning(
                    f"[InferenceTTS] Unexpected non-JSON message: {raw_message[:200]}"
                )
                return

            msg_type = data.get("type")

            if msg_type == "audio":
                await self._handle_audio(data.get("data", {}))

            elif msg_type == "event":
                if data.get("data", {}).get("eventType") == "TTS_COMPLETE":
                    logger.debug("[InferenceTTS] Synthesis completed")

            elif msg_type == "error":
                error_msg = data.get("data", {}).get("error") or data.get(
                    "message", "Unknown error"
                )
                logger.error(f"[InferenceTTS] Server error: {error_msg}")
                self.emit("error", error_msg)
                logger.warning(
                    "[InferenceTTS] Forcing full reconnect due to provider error"
                )
                self._has_error = True
                if self._ws and not self._ws.closed:
                    try:
                        await self._ws.close()
                    except Exception:
                        pass

        # FIX: removed an unreachable duplicate `except json.JSONDecodeError: pass`
        # — the inner try already returns on decode failure, so it could only
        # silently swallow decode errors raised elsewhere in the handler.
        except Exception as e:
            logger.error(f"[InferenceTTS] Message handling error: {e}")

    async def _handle_audio(self, audio_data: Dict[str, Any]) -> None:
        """Decode a base64 audio payload and push it to the audio track."""
        if self._interrupted or self._synthesis_id <= self._interrupted_at_id:
            logger.debug("[InferenceTTS] Discarding stale/interrupted audio")
            return
        if not audio_data:
            return
        audio_b64 = audio_data.get("audio")
        if not audio_b64:
            return
        try:
            audio_bytes = base64.b64decode(audio_b64)
            audio_bytes = self._remove_wav_header(audio_bytes)
            if not self._first_chunk_sent and self._first_audio_callback:
                self._first_chunk_sent = True
                await self._first_audio_callback()
            if self.audio_track and not self._interrupted:
                await self.audio_track.add_new_bytes(audio_bytes)
        except Exception as e:
            logger.error(f"[InferenceTTS] Audio processing error: {e}")

    def _remove_wav_header(self, audio_bytes: bytes) -> bytes:
        """Remove WAV header if present."""
        if audio_bytes.startswith(b"RIFF"):
            data_pos = audio_bytes.find(b"data")
            if data_pos != -1:
                # Skip "data" tag (4 bytes) + chunk-size field (4 bytes).
                return audio_bytes[data_pos + 8 :]
        return audio_bytes

    # ==================== Control ====================

    async def interrupt(self) -> None:
        """Interrupt ongoing synthesis."""
        self._interrupted = True
        self._interrupted_at_id = self._synthesis_id
        logger.debug(
            f"[InferenceTTS] Stamped interrupted_at_id={self._interrupted_at_id}"
        )

        if self.audio_track:
            self.audio_track.interrupt()

    async def aclose(self) -> None:
        """Tear down the connection, close the HTTP session, and run base cleanup."""
        logger.info(f"[InferenceTTS] Closing TTS (provider={self.provider})")
        self._interrupted = True
        await self._teardown_connection()
        if self._session and not self._session.closed:
            await self._session.close()
            self._session = None
        await super().aclose()
        logger.info("[InferenceTTS] Closed successfully")

    @property
    def label(self) -> str:
        """Descriptive identifier for this TTS instance: plugin, provider and model."""
        return f"videosdk.inference.TTS.{self.provider}.{self.model_id}"

Base class for Text-to-Speech implementations

Initialize the VideoSDK Inference TTS plugin.

Args

provider
TTS provider name (e.g., "google", "sarvamai", "deepgram")
model_id
Model identifier for the provider
voice_id
Voice identifier
language
Language code (default: "en-US")
config
Provider-specific configuration dictionary
enable_streaming
Enable streaming synthesis (default: True)
sample_rate
Audio sample rate (default: 24000)
base_url
Custom inference gateway URL (default: production gateway)

Ancestors

  • videosdk.agents.tts.tts.TTS
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Static methods

def cartesia(*,
model_id='sonic-2',
voice_id='faf0731e-dfb9-4cfc-8119-259a79b27e12',
language='en',
sample_rate=24000,
enable_streaming=True,
base_url=None,
config=None) ‑> TTS
Expand source code
@staticmethod
def cartesia(
    *,
    model_id="sonic-2",
    voice_id="faf0731e-dfb9-4cfc-8119-259a79b27e12",
    language="en",
    sample_rate=24000,
    enable_streaming=True,
    base_url=None,
    config=None,
) -> "TTS":
    """
    Create a TTS instance configured for Cartesia.

    Args:
        model_id: Cartesia model (default: "sonic-2")
        voice_id: Voice ID (string) or voice embedding (list of floats)
                 (default: "faf0731e-dfb9-4cfc-8119-259a79b27e12")
        language: Language code (default: "en")
        sample_rate: Audio sample rate (default: 24000)
        enable_streaming: Enable streaming mode (default: True)
        base_url: Custom inference gateway URL
        config: Extra provider-specific options merged over the defaults

    Returns:
        Configured TTS instance for Cartesia
    """
    # Caller-supplied config entries take precedence over the defaults.
    config = {
        "model": model_id,
        "language": language,
        "voice": voice_id,
        "sample_rate": sample_rate,
        **(config or {}),
    }
    return TTS(
        provider="cartesia",
        model_id=model_id,
        # voice_id may be a list embedding; the TTS field is a string.
        voice_id=str(voice_id),
        language=language,
        config=config,
        enable_streaming=enable_streaming,
        sample_rate=sample_rate,
        base_url=base_url,
    )

Create a TTS instance configured for Cartesia.

Args

model_id
Cartesia model (default: "sonic-2")
voice_id
Voice ID (string) or voice embedding (list of floats) (default: "faf0731e-dfb9-4cfc-8119-259a79b27e12")
language
Language code (default: "en")
sample_rate
Audio sample rate (default: 24000)
enable_streaming
Enable streaming mode (default: True)
base_url
Custom inference gateway URL

Returns

Configured TTS instance for Cartesia

def deepgram(*,
model_id='aura-2',
voice_id='amalthea',
language='en',
encoding='linear16',
sample_rate=24000,
container='none',
bit_rate=None,
enable_streaming=True,
base_url=None,
config=None) ‑> TTS
Expand source code
@staticmethod
def deepgram(
    *,
    model_id="aura-2",
    voice_id="amalthea",
    language="en",
    encoding="linear16",
    sample_rate=24000,
    container="none",
    bit_rate=None,
    enable_streaming=True,
    base_url=None,
    config=None,
) -> "TTS":
    """
    Create a TTS instance configured for Deepgram Aura.

    Args:
        model_id: Deepgram Aura model (default: "aura-2")
        voice_id: Deepgram voice (default: "amalthea")
        language: Language code (default: "en")
        encoding: Audio encoding format (default: "linear16")
        sample_rate: Audio sample rate in Hz (default: 24000)
        container: Container format (default: "none" for raw audio)
        bit_rate: Bitrate in bps for compressed formats (optional)
        enable_streaming: Enable streaming mode (default: True)
        base_url: Custom inference gateway URL (optional)
        config: Extra provider-specific options merged over the defaults

    Returns:
        Configured TTS instance for Deepgram
    """
    # Caller-supplied config entries override the defaults built here.
    config = {
        "model": model_id,
        "encoding": encoding,
        "sample_rate": sample_rate,
        "container": container,
        "voice_id": voice_id,
        "language": language,
        **(config or {}),
    }
    if bit_rate is not None:
        config["bit_rate"] = bit_rate
    return TTS(
        provider="deepgram",
        model_id=model_id,
        voice_id=voice_id,
        # Bug fix: was hard-coded to "en", silently ignoring the caller's
        # language parameter (the config dict already honored it).
        language=language,
        config=config,
        enable_streaming=enable_streaming,
        sample_rate=sample_rate,
        base_url=base_url,
    )

Create a TTS instance configured for Deepgram Aura.

Args

model_id
Deepgram Aura model (default: "aura-2")
voice_id
Deepgram voice (default: "amalthea")
language
Language code (default: "en")
encoding
Audio encoding format (default: "linear16")
sample_rate
Audio sample rate in Hz (default: 24000)
container
Container format (default: "none" for raw audio)
bit_rate
Bitrate in bps for compressed formats (optional)
enable_streaming
Enable streaming mode (default: True)
base_url
Custom inference gateway URL (optional)

Returns

Configured TTS instance for Deepgram

def google(*,
model_id='Chirp3-HD',
voice_id='Achernar',
language='en-US',
speed=1.0,
pitch=0.0,
sample_rate=24000,
enable_streaming=True,
base_url=None,
config=None) ‑> TTS
Expand source code
@staticmethod
def google(
    *,
    model_id="Chirp3-HD",
    voice_id="Achernar",
    language="en-US",
    speed=1.0,
    pitch=0.0,
    sample_rate=24000,
    enable_streaming=True,
    base_url=None,
    config=None,
) -> "TTS":
    """
    Build a TTS instance backed by Google Cloud Text-to-Speech.

    Args:
        model_id: Google TTS model (default: "Chirp3-HD")
        voice_id: Voice name (default: "Achernar")
        language: Language code (default: "en-US")
        speed: Speech speed (default: 1.0)
        pitch: Voice pitch (default: 0.0)
        sample_rate: Audio sample rate (default: 24000)
        enable_streaming: Enable streaming mode (default: True)
        base_url: Custom inference gateway URL

    Returns:
        Configured TTS instance for Google
    """
    # Google expects a fully-qualified voice name, e.g. en-US-Chirp3-HD-Achernar.
    voice_name = "-".join((language, model_id, voice_id))

    defaults = {
        "voice_name": voice_name,
        "language_code": language,
        "speed": speed,
        "pitch": pitch,
        "sample_rate": sample_rate,
        "model_id": model_id,
    }
    # Caller-supplied config entries win over the defaults.
    merged = {**defaults, **(config or {})}
    return TTS(
        provider="google",
        model_id=model_id,
        voice_id=voice_id,
        language=language,
        config=merged,
        enable_streaming=enable_streaming,
        sample_rate=sample_rate,
        base_url=base_url,
    )

Create a TTS instance configured for Google Cloud Text-to-Speech.

Args

model_id
Google TTS model (default: "Chirp3-HD")
voice_id
Voice name (default: "Achernar")
language
Language code (default: "en-US")
speed
Speech speed (default: 1.0)
pitch
Voice pitch (default: 0.0)
sample_rate
Audio sample rate (default: 24000)
enable_streaming
Enable streaming mode (default: True)
base_url
Custom inference gateway URL

Returns

Configured TTS instance for Google

def sarvam(*,
model_id='bulbul:v2',
speaker='anushka',
language='en-IN',
sample_rate=24000,
enable_streaming=True,
base_url=None,
config=None) ‑> TTS
Expand source code
@staticmethod
def sarvam(
    *,
    model_id="bulbul:v2",
    speaker="anushka",
    language="en-IN",
    sample_rate=24000,
    enable_streaming=True,
    base_url=None,
    config=None,
) -> "TTS":
    """
    Build a TTS instance backed by Sarvam AI.

    Args:
        model_id: Sarvam model (default: "bulbul:v2")
        speaker: Speaker voice (default: "anushka")
        language: Language code (default: "en-IN")
        sample_rate: Audio sample rate (default: 24000)
        enable_streaming: Enable streaming mode (default: True)
        base_url: Custom inference gateway URL

    Returns:
        Configured TTS instance for Sarvam AI
    """
    defaults = {
        "model": model_id,
        "language": language,
        "speaker": speaker,
        "sample_rate": sample_rate,
    }
    # Caller-supplied config entries win over the defaults.
    merged = {**defaults, **(config or {})}
    return TTS(
        provider="sarvamai",
        model_id=model_id,
        voice_id=speaker,
        language=language,
        config=merged,
        enable_streaming=enable_streaming,
        sample_rate=sample_rate,
        base_url=base_url,
    )

Create a TTS instance configured for Sarvam AI.

Args

model_id
Sarvam model (default: "bulbul:v2")
speaker
Speaker voice (default: "anushka")
language
Language code (default: "en-IN")
sample_rate
Audio sample rate (default: 24000)
enable_streaming
Enable streaming mode (default: True)
base_url
Custom inference gateway URL

Returns

Configured TTS instance for Sarvam AI

Instance variables

prop label : str
Expand source code
@property
def label(self) -> str:
    return f"videosdk.inference.TTS.{self.provider}.{self.model_id}"

Get the TTS provider label

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Close the gateway connection and HTTP session, then the base class."""
    logger.info(f"[InferenceTTS] Closing TTS (provider={self.provider})")
    self._interrupted = True
    await self._teardown_connection()
    http = self._session
    if http and not http.closed:
        await http.close()
        self._session = None
    await super().aclose()
    logger.info("[InferenceTTS] Closed successfully")

Cleanup resources

async def interrupt(self) ‑> None
Expand source code
async def interrupt(self) -> None:
    """Abort the current synthesis and drop queued audio output."""
    self._interrupted = True
    # Remember which synthesis was active; chunks stamped with this id
    # are discarded when they arrive later.
    self._interrupted_at_id = self._synthesis_id
    logger.debug(
        f"[InferenceTTS] Stamped interrupted_at_id={self._interrupted_at_id}"
    )

    track = self.audio_track
    if track:
        track.interrupt()

Interrupt ongoing synthesis.

def reset_first_audio_tracking(self) ‑> None
Expand source code
def reset_first_audio_tracking(self) -> None:
    """Reset first-audio-chunk tracking before the next TTS task starts."""
    self._first_chunk_sent = False

Reset the first audio tracking state for next TTS task

async def synthesize(self, text: AsyncIterator[str] | str, voice_id=None, **kwargs) ‑> None
Expand source code
async def synthesize(
    self, text: AsyncIterator[str] | str, voice_id=None, **kwargs
) -> None:
    """
    Synthesize speech from text via the inference gateway.

    Failures are emitted on the "error" event rather than raised. A plain
    string input is retried once if the connection drops mid-synthesis;
    streaming input is never retried because the iterator may already be
    partially consumed.

    Args:
        text: Text to synthesize (string or async iterator of strings)
        voice_id: Optional voice override (not referenced in this body —
            presumably handled elsewhere; TODO confirm)
        **kwargs: Additional arguments
    """
    if not self.audio_track or not self.loop:
        logger.error("[InferenceTTS] Audio track or event loop not initialized")
        return

    # Monotonically increasing id: any later synthesize() call supersedes
    # this one, and the stale attempt bails out at the checkpoints below.
    self._synthesis_id += 1
    current_id = self._synthesis_id

    self._interrupted = False
    self.reset_first_audio_tracking()
    logger.debug(f"[InferenceTTS] New synthesis started, id={current_id}")

    # Cheap pre-checks only apply to plain strings; streams are validated
    # as chunks arrive.
    if isinstance(text, str):
        if not text.strip():
            logger.debug("[InferenceTTS] Skipping synthesis — empty text")
            return
        if not _has_enough_content(text, self.provider):
            logger.warning(
                f"[InferenceTTS] Skipping — text too short for {self.provider}: '{text}'"
            )
            return

    # Only plain strings are safe to retry; an async iterator may already
    # have been partially consumed by a failed first attempt.
    text_for_retry = text if isinstance(text, str) else None

    # At most two attempts: the original plus one retry after ConnectionError.
    for attempt in range(2):
        # Abort if a newer synthesis has already started
        if self._synthesis_id != current_id:
            logger.debug("[InferenceTTS] Synthesis superseded — aborting")
            return

        try:
            await self._ensure_connection()

            # Re-check after the (possibly slow) connect step.
            if self._synthesis_id != current_id:
                logger.debug(
                    "[InferenceTTS] Synthesis superseded after connect — aborting"
                )
                return

            if isinstance(text, str):
                await self._send_text(text, current_id)
            else:
                await self._send_text_stream(text, current_id)
            return

        except ConnectionError as e:
            if attempt == 0 and text_for_retry is not None:
                logger.warning(
                    f"[InferenceTTS] Connection lost mid-synthesis, retrying... ({e})"
                )
                self._has_error = True
                # Brief backoff before reconnecting.
                await asyncio.sleep(0.05)
                continue
            logger.error(f"[InferenceTTS] Synthesis failed after retry: {e}")
            self.emit("error", str(e))
            return

        except Exception as e:
            logger.error(f"[InferenceTTS] Synthesis error: {e}")
            self.emit("error", str(e))
            return

Synthesize speech from text.

Args

text
Text to synthesize (string or async iterator of strings)
voice_id
Optional voice override
**kwargs
Additional arguments
async def warmup(self) ‑> None
Expand source code
async def warmup(self) -> None:
    """
    Pre-open the WebSocket connection ahead of the first synthesis call.

    Invoke right after session start to avoid the ~3-4s cold-start latency
    on the first user turn. Failures are logged as warnings, never raised.
    """
    logger.info(f"[InferenceTTS] Warming up connection (provider={self.provider})")
    try:
        await self._ensure_connection()
        logger.info("[InferenceTTS] Warmup complete — connection ready")
    except Exception as e:
        # Non-fatal: the first synthesize() will simply pay the connect cost.
        logger.warning(f"[InferenceTTS] Warmup failed (non-fatal): {e}")

Pre-warm the WebSocket connection before the first synthesis request. Call this right after session start to eliminate cold-start latency (~3-4s) on the first user turn.