Package videosdk.plugins.openai

Sub-modules

videosdk.plugins.openai.llm
videosdk.plugins.openai.realtime_api
videosdk.plugins.openai.stt
videosdk.plugins.openai.tts

Classes

class OpenAILLM (*,
api_key: str | None = None,
model: str = 'gpt-4o-mini',
base_url: str | None = None,
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_completion_tokens: int | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
organization: str | None = None,
project: str | None = None,
parallel_tool_calls: bool | None = None,
timeout: httpx.Timeout | None = None,
extra_headers: dict | None = None,
extra_query: dict | None = None,
extra_body: dict | None = None,
client: openai.AsyncOpenAI | None = None,
max_retries: int = 0,
reasoning_effort: "Literal['none', 'low', 'medium', 'high'] | None" = None,
verbosity: "Literal['low', 'medium', 'high'] | None" = None,
streaming: bool = False,
store: bool = False,
wss_url: str | None = None)
Expand source code
class OpenAILLM(LLM):

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = "gpt-4o-mini",
        base_url: str | None = None,
        temperature: float = 0.7,
        tool_choice: ToolChoice = "auto",
        max_completion_tokens: int | None = None,
        top_p: float | None = None,
        frequency_penalty: float | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        organization: str | None = None,
        project: str | None = None,
        parallel_tool_calls: bool | None = None,
        timeout: httpx.Timeout | None = None,
        extra_headers: dict | None = None,
        extra_query: dict | None = None,
        extra_body: dict | None = None,
        client: openai.AsyncOpenAI | None = None,
        max_retries: int = 0,
        reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
        verbosity: Literal["low", "medium", "high"] | None = None,
        streaming: bool = False,
        store: bool = False,
        wss_url: str | None = None,
    ) -> None:
        """Set up an OpenAI-backed LLM.

        Args:
            api_key: OpenAI API key; ``OPENAI_API_KEY`` is read when omitted. The key
                is retained even when ``client`` is supplied, because the WebSocket
                (streaming) path authenticates with it directly.
            model: Chat model name (default ``"gpt-4o-mini"``).
            base_url: Alternative API base URL for the internally built client.
            temperature: Sampling temperature (default 0.7).
            tool_choice: Tool-selection policy sent alongside tools (default ``"auto"``).
            max_completion_tokens: Upper bound on generated tokens.
            top_p: Nucleus-sampling probability mass.
            frequency_penalty: Frequency-based repetition penalty.
            presence_penalty: Presence-based repetition penalty.
            seed: Sampling seed.
            organization: Organisation ID; ``OPENAI_ORG_ID`` is read when omitted.
            project: Project ID; ``OPENAI_PROJECT_ID`` is read when omitted.
            parallel_tool_calls: Whether several tools may be called in one turn.
            timeout: ``httpx.Timeout`` for the internal HTTP client. When omitted the
                client uses connect 15 s, read 10 s, write 5 s, pool 5 s.
            extra_headers: Extra HTTP headers forwarded with requests.
            extra_query: Extra query-string parameters forwarded with requests.
            extra_body: Extra JSON body fields forwarded with requests.
            client: Ready-made ``openai.AsyncOpenAI`` to use instead of building one.
                When given, ``base_url``, ``organization``, ``project``, ``timeout``
                and ``max_retries`` have no effect and ``api_key`` is not used to
                configure the client.
            max_retries: Automatic retry count for the internal client (default 0).
            reasoning_effort: Reasoning depth for reasoning / GPT-5 models
                (``"none"``, ``"low"``, ``"medium"``, ``"high"``); ``None`` keeps the
                model's default.
            verbosity: Output verbosity for reasoning / GPT-5 models (``"low"``,
                ``"medium"``, ``"high"``); ``None`` keeps the model's default.
            streaming: Use the WebSocket Responses API rather than HTTP chat
                completions. The socket is reused between turns and chained via
                ``previous_response_id``. Default ``False``.
            store: Streaming mode only. Whether the server persists responses
                (default ``False``).
            wss_url: Streaming mode only. Custom WebSocket Responses URL.

        Raises:
            ValueError: No API key is available and either ``client`` was not
                supplied or ``streaming`` is enabled.
        """
        super().__init__()

        # Generation options shared by the HTTP and WebSocket paths.
        self.model = model
        self.temperature = temperature
        self.tool_choice = tool_choice
        self.max_completion_tokens = max_completion_tokens
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.seed = seed
        self.parallel_tool_calls = parallel_tool_calls
        self.reasoning_effort = reasoning_effort
        self.verbosity = verbosity

        # Per-request passthroughs.
        self.extra_headers = extra_headers
        self.extra_query = extra_query
        self.extra_body = extra_body

        self._cancelled = False

        # WebSocket Responses API settings.
        self.streaming = streaming
        self.store = store
        self._wss_url = wss_url or OPENAI_RESPONSES_WSS_URL
        # Kept even when ``client`` is injected: the WebSocket path needs it.
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")

        # Connection / chain state, populated lazily by the WebSocket path.
        self._ws_session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._ws_lock = asyncio.Lock()
        self._drain_until_response_created: bool = False
        self._previous_response_id: Optional[str] = None
        self._last_seen_items_count: int = 0
        self._prewarm_task: Optional[asyncio.Task] = None

        self._owns_client = client is None
        if client is None:
            if not self.api_key:
                raise ValueError(
                    "OpenAI API key must be provided either through api_key parameter "
                    "or OPENAI_API_KEY environment variable"
                )
            request_timeout = timeout or httpx.Timeout(connect=15.0, read=10.0, write=5.0, pool=5.0)
            client = openai.AsyncOpenAI(
                api_key=self.api_key,
                base_url=base_url or None,
                organization=organization or os.getenv("OPENAI_ORG_ID"),
                project=project or os.getenv("OPENAI_PROJECT_ID"),
                max_retries=max_retries,
                http_client=httpx.AsyncClient(
                    timeout=request_timeout,
                    follow_redirects=True,
                    limits=httpx.Limits(
                        max_connections=50,
                        max_keepalive_connections=50,
                        keepalive_expiry=120,
                    ),
                ),
            )
        self._client = client

        if not self.streaming:
            return

        if not self.api_key:
            raise ValueError(
                "streaming=True requires an OpenAI API key (api_key parameter or "
                "OPENAI_API_KEY env var). The pre-built client cannot be introspected for it."
            )

        # If an event loop is already running, start connecting now so the first
        # chat() call does not pay for the TLS + WebSocket handshake.
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        self._prewarm_task = running_loop.create_task(self._prewarm_safely())

    async def _prewarm_safely(self) -> None:
        try:
            await self._ensure_ws()
        except Exception as e:
            logger.warning("OpenAI WSS prewarm failed (will retry on first chat): %s", e)

    async def prewarm(
        self,
        *,
        instructions: str | None = None,
        tools: list[FunctionTool] | None = None,
    ) -> None:
        """Eagerly establish the WSS connection (and optionally prime request state).

        Call this once after constructing ``OpenAILLM(streaming=True)`` to avoid
        paying the TLS + WebSocket handshake on the first ``chat()`` call.

        When non-empty ``instructions`` and/or ``tools`` are provided, this also
        sends a warmup ``response.create`` with ``generate: false``. It then reads
        events until the warmup reaches a terminal event. The resulting response id
        is stored in ``_previous_response_id`` so the first real turn can chain
        from it.

        No-op when ``streaming=False``.

        Args:
            instructions: Optional system instructions sent as the warmup input.
            tools: Optional tools included in the warmup request.

        Raises:
            Exceptions from ``_ensure_ws`` (connection failure) propagate. Failures
            while sending or reading the warmup are logged, not raised.
        """
        if not self.streaming:
            return

        # Let the constructor-scheduled prewarm finish first so we don't race it
        # into opening a second socket. ``_prewarm_safely`` already logs its own
        # failures, and ``_ensure_ws`` below retries the connect in that case.
        pending = self._prewarm_task
        if pending is not None and not pending.done():
            try:
                await pending
            except Exception:
                pass

        ws = await self._ensure_ws()

        # Nothing to prime. An empty instructions string with no tools would only
        # send an empty warmup request.
        if not instructions and not tools:
            return

        warmup_input: list[dict] = []
        if instructions:
            warmup_input.append(
                {
                    "type": "message",
                    "role": "system",
                    "content": [{"type": "input_text", "text": instructions}],
                }
            )

        payload = self._build_responses_payload(
            input_items=warmup_input,
            previous_response_id=None,
            tools=tools,
            conversational_graph=None,
            extra={"generate": False},
        )

        try:
            await ws.send_str(json.dumps(payload))
        except Exception as e:
            logger.warning("OpenAI WSS warmup send failed: %s", e)
            return

        try:
            async for msg in ws:
                if msg.type != aiohttp.WSMsgType.TEXT:
                    if msg.type in (
                        aiohttp.WSMsgType.CLOSED,
                        aiohttp.WSMsgType.CLOSE,
                        aiohttp.WSMsgType.CLOSING,
                        aiohttp.WSMsgType.ERROR,
                    ):
                        break
                    continue
                try:
                    event = json.loads(msg.data)
                except json.JSONDecodeError:
                    continue
                etype = event.get("type")
                if etype == "response.created":
                    self._previous_response_id = (
                        (event.get("response") or {}).get("id")
                    )
                elif etype == "response.completed":
                    resp = event.get("response") or {}
                    self._previous_response_id = (
                        resp.get("id") or self._previous_response_id
                    )
                    break
                elif etype == "error":
                    err = event.get("error") or {}
                    logger.warning("OpenAI WSS warmup error: %s", err.get("message") or err)
                    self._previous_response_id = None
                    break
                elif etype in ("response.failed", "response.incomplete"):
                    # Non-successful terminal states. Without this branch the loop
                    # would keep waiting for a "response.completed" that never comes.
                    # The id is cleared so the first real turn does not chain from a
                    # response that did not finish.
                    resp = event.get("response") or {}
                    logger.warning(
                        "OpenAI WSS warmup ended with %s: %s",
                        etype,
                        resp.get("error") or resp.get("incomplete_details"),
                    )
                    self._previous_response_id = None
                    break
        except Exception as e:
            logger.warning("OpenAI WSS warmup read failed: %s", e)
            self._previous_response_id = None

    def _is_reasoning_model(self) -> bool:
        """Return True if the configured model is a reasoning / GPT-5 family model
        that requires special parameter handling."""
        model_lower = self.model.lower()
        if model_lower.startswith(("o1", "o3", "o4")):
            return True
        if model_lower.startswith("gpt-5"):
            return True
        return False

    @staticmethod
    def azure(
        *,
        model: str = "gpt-4o-mini",
        azure_endpoint: str | None = None,
        azure_deployment: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        azure_ad_token: str | None = None,
        organization: str | None = None,
        project: str | None = None,
        base_url: str | None = None,
        temperature: float = 0.7,
        tool_choice: ToolChoice = "auto",
        max_completion_tokens: int | None = None,
        top_p: float | None = None,
        frequency_penalty: float | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        parallel_tool_calls: bool | None = None,
        timeout: httpx.Timeout | None = None,
        extra_headers: dict | None = None,
        extra_query: dict | None = None,
        extra_body: dict | None = None,
        client: openai.AsyncAzureOpenAI | None = None,
        max_retries: int = 0,
        reasoning_effort: Literal["none", "low", "medium", "high"] | None = "none",
        verbosity: Literal["low", "medium", "high"] | None = "low",
    ) -> "OpenAILLM":
        """
        Create a new instance of Azure OpenAI LLM.

        Automatically infers the following from environment variables when not provided:
        - ``api_key`` from ``AZURE_OPENAI_API_KEY``
        - ``organization`` from ``OPENAI_ORG_ID``
        - ``project`` from ``OPENAI_PROJECT_ID``
        - ``azure_ad_token`` from ``AZURE_OPENAI_AD_TOKEN``
        - ``api_version`` from ``OPENAI_API_VERSION``
        - ``azure_endpoint`` from ``AZURE_OPENAI_ENDPOINT``
        - ``azure_deployment`` from ``AZURE_OPENAI_DEPLOYMENT`` (falls back to ``model``)

        ``base_url`` is an alternative to ``azure_endpoint``. The Azure SDK client
        rejects receiving both, so when ``base_url`` is given it is used as-is and
        no endpoint (argument or environment) is forwarded.

        Pass ``client`` to supply a pre-built ``openai.AsyncAzureOpenAI`` instance.
        When ``client`` is provided, connection/credential params are ignored.

        Raises:
            ValueError: If neither an endpoint nor ``base_url`` is available, or if
                neither an API key nor an Azure AD token is available (only checked
                when ``client`` is not supplied).
        """
        # Options forwarded to OpenAILLM no matter how the client is obtained.
        llm_kwargs = dict(
            model=model,
            temperature=temperature,
            tool_choice=tool_choice,
            max_completion_tokens=max_completion_tokens,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            seed=seed,
            parallel_tool_calls=parallel_tool_calls,
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            reasoning_effort=reasoning_effort,
            verbosity=verbosity,
        )

        if client is not None:
            return OpenAILLM(client=client, **llm_kwargs)

        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        azure_deployment = azure_deployment or os.getenv("AZURE_OPENAI_DEPLOYMENT")
        api_version = api_version or os.getenv("OPENAI_API_VERSION")
        api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
        azure_ad_token = azure_ad_token or os.getenv("AZURE_OPENAI_AD_TOKEN")
        organization = organization or os.getenv("OPENAI_ORG_ID")
        project = project or os.getenv("OPENAI_PROJECT_ID")

        if not azure_deployment:
            azure_deployment = model

        if not azure_endpoint and not base_url:
            raise ValueError(
                "Azure endpoint must be provided either through azure_endpoint parameter "
                "or AZURE_OPENAI_ENDPOINT environment variable"
            )

        if not api_key and not azure_ad_token:
            raise ValueError("Either API key or Azure AD token must be provided")

        _timeout = timeout or httpx.Timeout(connect=15.0, read=10.0, write=5.0, pool=5.0)
        azure_client = openai.AsyncAzureOpenAI(
            max_retries=max_retries,
            # base_url and azure_endpoint are mutually exclusive in the SDK; an
            # explicit base_url takes precedence over any (env-derived) endpoint.
            azure_endpoint=None if base_url else azure_endpoint,
            azure_deployment=azure_deployment,
            api_version=api_version,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            organization=organization,
            project=project,
            base_url=base_url,
            timeout=_timeout,
        )

        return OpenAILLM(client=azure_client, **llm_kwargs)

    async def chat(
        self,
        messages: ChatContext,
        tools: list[FunctionTool] | None = None,
        conversational_graph: Any | None = None,
        **kwargs: Any
    ) -> AsyncIterator[LLMResponse]:
        """
        Stream the model's reply to ``messages``.

        Resets the cancellation flag, then delegates to the WebSocket Responses
        implementation when ``self.streaming`` is true, or to the HTTP chat
        completions implementation otherwise. Every LLMResponse from the chosen
        backend is re-yielded unchanged.
        """
        self._cancelled = False
        backend = self._chat_websocket if self.streaming else self._chat_http
        async for response in backend(
            messages, tools=tools, conversational_graph=conversational_graph, **kwargs
        ):
            yield response

    async def _chat_http(
        self,
        messages: ChatContext,
        tools: list[FunctionTool] | None = None,
        conversational_graph: Any | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[LLMResponse]:
        """Stream a reply from OpenAI's Chat Completions HTTP endpoint.

        Args:
            messages: ChatContext containing conversation history.
            tools: Optional list of function tools available to the model. Entries
                rejected by ``is_function_tool`` are skipped.
            conversational_graph: Optional graph object. When set, the request asks
                for JSON output matching ``conversational_graph._get_graph_schema()``,
                and text is streamed through
                ``conversational_graph.stream_conversational_graph_response``.
            **kwargs: Extra request parameters. They are merged last, so they
                override any computed parameter with the same name.

        Yields:
            LLMResponse objects of several kinds:
            - text deltas;
            - usage updates (``metadata["usage"]``);
            - tool calls (``metadata["function_call"]``), in the model's tool-call
              index order;
            - in graph mode, a final response carrying ``metadata["graph_response"]``,
              or the raw text if it is not valid JSON.

        Raises:
            Re-raises any exception from the request or the stream. Unless the
            generation was cancelled, the exception is also emitted as an
            ``"error"`` event.
        """
        completion_params = self._build_chat_completion_params(
            messages, tools, conversational_graph
        )
        # Pass-through overrides from caller
        completion_params.update(kwargs)

        # Passthrough extra headers / query / body
        create_kwargs: dict = {}
        if self.extra_headers:
            create_kwargs["extra_headers"] = self.extra_headers
        if self.extra_query:
            create_kwargs["extra_query"] = self.extra_query
        if self.extra_body:
            create_kwargs["extra_body"] = self.extra_body

        response_stream = None
        try:
            response_stream = await self._client.chat.completions.create(
                **completion_params, **create_kwargs
            )
            current_content = ""
            # Streamed tool-call fragments, keyed by the delta's ``index``.
            pending_tool_calls: dict[int, dict] = {}
            streaming_state = {
                "in_response": False,
                "response_start_index": -1,
                "yielded_content_length": 0,
            }

            # NOTE: this single dict is shared by every yielded response's metadata
            # and is updated in place when the usage chunk arrives.
            usage_metadata: dict = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "prompt_cached_tokens": 0,
                "reasoning_tokens": 0,
                "request_id": None,
                "model": self.model,
            }

            async for chunk in response_stream:
                if self._cancelled:
                    break

                usage = getattr(chunk, "usage", None)
                if usage is not None:
                    usage_metadata["prompt_tokens"] = usage.prompt_tokens or 0
                    usage_metadata["completion_tokens"] = usage.completion_tokens or 0
                    usage_metadata["total_tokens"] = usage.total_tokens or 0
                    usage_metadata["request_id"] = getattr(chunk, "id", None)
                    usage_metadata["model"] = getattr(chunk, "model", self.model)

                    prompt_details = getattr(usage, "prompt_tokens_details", None)
                    if prompt_details:
                        usage_metadata["prompt_cached_tokens"] = getattr(
                            prompt_details, "cached_tokens", 0
                        ) or 0
                    completion_details = getattr(usage, "completion_tokens_details", None)
                    if completion_details:
                        usage_metadata["reasoning_tokens"] = getattr(
                            completion_details, "reasoning_tokens", 0
                        ) or 0

                    yield LLMResponse(content="", role=ChatRole.ASSISTANT, metadata={"usage": usage_metadata})

                if not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = choice.delta

                # Accumulate tool call fragments per index
                if delta.tool_calls:
                    for tc in delta.tool_calls:
                        fn = tc.function
                        entry = pending_tool_calls.get(tc.index)
                        if entry is None:
                            pending_tool_calls[tc.index] = {
                                "id": tc.id or "",
                                "name": (fn.name or "") if fn else "",
                                "arguments": (fn.arguments or "") if fn else "",
                            }
                            continue
                        # Keep an id that only shows up in a later fragment.
                        if tc.id and not entry["id"]:
                            entry["id"] = tc.id
                        if fn:
                            if fn.name:
                                entry["name"] += fn.name
                            if fn.arguments:
                                entry["arguments"] += fn.arguments

                # Emit all accumulated tool calls once the model signals it is done
                if choice.finish_reason == "tool_calls" and pending_tool_calls:
                    for tool_response in self._http_tool_call_responses(
                        pending_tool_calls, usage_metadata
                    ):
                        yield tool_response
                    pending_tool_calls = {}

                elif delta.content is not None:
                    current_content += delta.content
                    if conversational_graph:
                        for content_chunk in conversational_graph.stream_conversational_graph_response(
                            current_content, streaming_state
                        ):
                            yield LLMResponse(
                                content=content_chunk,
                                role=ChatRole.ASSISTANT,
                                metadata={"usage": usage_metadata},
                            )
                    else:
                        yield LLMResponse(
                            content=delta.content,
                            role=ChatRole.ASSISTANT,
                            metadata={"usage": usage_metadata},
                        )

            # Flush any tool calls not yet emitted (stream ended without explicit finish_reason)
            if pending_tool_calls and not self._cancelled:
                for tool_response in self._http_tool_call_responses(
                    pending_tool_calls, usage_metadata
                ):
                    yield tool_response

            if current_content and not self._cancelled and conversational_graph:
                try:
                    parsed_json = json.loads(current_content.strip())
                    yield LLMResponse(
                        content="",
                        role=ChatRole.ASSISTANT,
                        metadata={"usage": usage_metadata, "graph_response": parsed_json}
                    )
                except json.JSONDecodeError:
                    yield LLMResponse(
                        content=current_content,
                        role=ChatRole.ASSISTANT,
                        metadata={"usage": usage_metadata}
                    )

        except Exception as e:
            if not self._cancelled:
                self.emit("error", e)
            raise
        finally:
            if response_stream is not None:
                # Best-effort release of the HTTP stream; close errors are irrelevant
                # once the generation is over.
                try:
                    await response_stream.close()
                except Exception:
                    pass

    def _build_chat_completion_params(
        self,
        messages: ChatContext,
        tools: list[FunctionTool] | None,
        conversational_graph: Any | None,
    ) -> dict:
        """Assemble ``chat.completions.create`` parameters (before caller overrides)."""
        is_reasoning = self._is_reasoning_model()
        params: dict = {
            "model": self.model,
            "messages": messages.to_openai_messages(reasoning_model=is_reasoning),
            "stream": True,
            "stream_options": {"include_usage": True},
        }
        if self.max_completion_tokens is not None:
            params["max_completion_tokens"] = self.max_completion_tokens

        if is_reasoning:
            # Sampling knobs are deliberately not sent to reasoning models; they
            # take reasoning_effort / verbosity instead.
            if self.reasoning_effort is not None:
                params["reasoning_effort"] = self.reasoning_effort
            if self.verbosity is not None:
                # Chat Completions takes ``verbosity`` as a top-level parameter.
                # The nested ``text`` object belongs to the Responses API and is not
                # an accepted ``chat.completions.create`` argument.
                params["verbosity"] = self.verbosity
        else:
            params["temperature"] = self.temperature
            if self.top_p is not None:
                params["top_p"] = self.top_p
            if self.frequency_penalty is not None:
                params["frequency_penalty"] = self.frequency_penalty
            if self.presence_penalty is not None:
                params["presence_penalty"] = self.presence_penalty

        if self.seed is not None:
            params["seed"] = self.seed

        if conversational_graph:
            params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "conversational_graph_response",
                    "strict": True,
                    "schema": conversational_graph._get_graph_schema()
                }
            }

        # Modern tools API (replaces deprecated functions/function_call)
        if tools:
            formatted_tools = []
            for tool in tools:
                if not is_function_tool(tool):
                    continue
                try:
                    formatted_tools.append({"type": "function", "function": build_openai_schema(tool)})
                except Exception as e:
                    self.emit("error", f"Failed to format tool {tool}: {e}")

            if formatted_tools:
                params["tools"] = formatted_tools
                # "auto" | "required" | "none" or {"type": "function", "function": {"name": ...}},
                # forwarded unchanged.
                params["tool_choice"] = self.tool_choice
                if self.parallel_tool_calls is not None:
                    params["parallel_tool_calls"] = self.parallel_tool_calls

        return params

    def _http_tool_call_responses(
        self, pending_tool_calls: dict[int, dict], usage_metadata: dict
    ) -> list[LLMResponse]:
        """Turn accumulated tool-call fragments into ``function_call`` responses.

        Calls come back in ascending stream ``index`` order, which is the order the
        model produced them. Arguments that are not valid JSON are reported through
        an ``"error"`` event and replaced with ``{}``. An empty arguments string is
        treated as ``{}`` without an error.
        """
        responses: list[LLMResponse] = []
        for _, tc_data in sorted(pending_tool_calls.items()):
            raw_arguments = tc_data["arguments"]
            try:
                args = json.loads(raw_arguments) if raw_arguments.strip() else {}
            except json.JSONDecodeError:
                self.emit("error", f"Failed to parse tool call arguments: {raw_arguments}")
                args = {}
            responses.append(
                LLMResponse(
                    content="",
                    role=ChatRole.ASSISTANT,
                    metadata={
                        "function_call": {"name": tc_data["name"], "arguments": args, "id": tc_data["id"]},
                        "usage": usage_metadata,
                    }
                )
            )
        return responses

    # ------------------------------------------------------------------
    # WSS Responses API path
    # ------------------------------------------------------------------

    async def _ensure_ws(self) -> aiohttp.ClientWebSocketResponse:
        """Ensure a live WebSocket connection to the Responses endpoint.

        The slow ``ws_connect`` (TLS + WS handshake) runs OUTSIDE ``_ws_lock``
        so a concurrent ``_close_ws()`` (driven by ``cancel_current_generation``)
        can acquire the lock and proceed even while a handshake is in flight.
        Without this, a barge-in landing during the first chat call's handshake
        would block the cancel for the full ws_connect duration.

        Two concurrent callers may both perform a handshake. Whichever installs
        its socket first wins; the other closes its own socket and returns the
        installed one.

        NOTE(review): a ``_close_ws()`` that runs during the handshake clears
        ``self._ws``, so the freshly connected socket is still installed when the
        handshake completes. Confirm this is the intended post-cancel behavior.

        Returns:
            The open ``aiohttp.ClientWebSocketResponse`` stored in ``self._ws``.

        Raises:
            Whatever ``ws_connect`` raises (connection/handshake failure); it is
            not caught here.
        """
        # Fast path without the lock: reuse an already-open socket.
        ws = self._ws
        if ws is not None and not ws.closed:
            return ws

        async with self._ws_lock:
            # Re-check under the lock: another coroutine may have connected meanwhile.
            if self._ws is not None and not self._ws.closed:
                return self._ws
            if self._ws_session is None or self._ws_session.closed:
                self._ws_session = aiohttp.ClientSession()
            # Capture the session locally so the handshake below uses it even if
            # self._ws_session is replaced while the lock is released.
            connecting_session = self._ws_session

        # Caller-supplied extra headers are applied last and can override Authorization.
        headers = {"Authorization": f"Bearer {self.api_key}"}
        if self.extra_headers:
            headers.update(self.extra_headers)

        new_ws = await connecting_session.ws_connect(
            self._wss_url,
            headers=headers,
            autoping=True,
            heartbeat=30,
            autoclose=False,
            timeout=30,
        )

        async with self._ws_lock:
            # Lost the race to another connector: discard our socket (best-effort).
            if self._ws is not None and not self._ws.closed:
                try:
                    await new_ws.close()
                except Exception:
                    pass
                return self._ws
            self._ws = new_ws
            # Fresh connection — chain state is invalid.
            self._previous_response_id = None
            self._last_seen_items_count = 0
            return self._ws

    async def _close_ws(self) -> None:
        async with self._ws_lock:
            if self._ws is not None:
                try:
                    await self._ws.close()
                except Exception:
                    pass
                self._ws = None
            self._previous_response_id = None
            self._last_seen_items_count = 0

    def _build_responses_payload(
        self,
        *,
        input_items: list[dict],
        previous_response_id: str | None,
        tools: list[FunctionTool] | None,
        conversational_graph: Any | None,
        extra: dict,
    ) -> dict:
        """Assemble a ``response.create`` event for the WSS Responses API.

        Args:
            input_items: Already-converted Responses input items.
            previous_response_id: When truthy, chains this request onto that
                response.
            tools: Candidate tools. Entries rejected by ``is_function_tool``
                are skipped; entries whose schema cannot be built emit an
                ``"error"`` event and are skipped.
            conversational_graph: When truthy, a strict ``json_schema`` text
                format built from its ``_get_graph_schema()`` is requested.
            extra: Per-call fields merged last (after ``self.extra_body``), so
                they override anything set above.

        Returns:
            The event dict, ready to be JSON-encoded and sent.
        """

        def as_responses_tool(tool: Any) -> dict:
            # Flatten the OpenAI function schema into the Responses tool shape.
            schema = build_openai_schema(tool)
            spec = {
                "type": "function",
                "name": schema["name"],
                "description": schema.get("description", ""),
                "parameters": schema.get("parameters", {"type": "object", "properties": {}}),
            }
            if schema.get("strict") is not None:
                spec["strict"] = schema["strict"]
            return spec

        payload: dict = {
            "type": "response.create",
            "model": self.model,
            "store": self.store,
            "input": input_items,
        }
        if previous_response_id:
            payload["previous_response_id"] = previous_response_id

        if self._is_reasoning_model():
            # Reasoning models get effort/verbosity knobs instead of sampling params.
            if self.max_completion_tokens is not None:
                payload["max_output_tokens"] = self.max_completion_tokens
            effort = self.reasoning_effort
            if effort is not None and effort != "none":
                payload["reasoning"] = {"effort": effort}
            if self.verbosity is not None:
                payload.setdefault("text", {})["verbosity"] = self.verbosity
        else:
            payload["temperature"] = self.temperature
            if self.max_completion_tokens is not None:
                payload["max_output_tokens"] = self.max_completion_tokens
            if self.top_p is not None:
                payload["top_p"] = self.top_p

        if self.seed is not None:
            payload["seed"] = self.seed

        if conversational_graph:
            payload.setdefault("text", {})["format"] = {
                "type": "json_schema",
                "name": "conversational_graph_response",
                "strict": True,
                "schema": conversational_graph._get_graph_schema(),
            }

        responses_tools: list[dict] = []
        for tool in tools or ():
            if not is_function_tool(tool):
                continue
            try:
                responses_tools.append(as_responses_tool(tool))
            except Exception as e:
                self.emit("error", f"Failed to format tool {tool}: {e}")
        if responses_tools:
            payload["tools"] = responses_tools
            payload["tool_choice"] = self.tool_choice
            if self.parallel_tool_calls is not None:
                payload["parallel_tool_calls"] = self.parallel_tool_calls

        # Instance-wide extras first, then per-call extras, so per-call wins.
        for overrides in (self.extra_body, extra):
            if overrides:
                payload.update(overrides)
        return payload

    def _slice_incremental_items(self, all_items: list) -> list:
        """Return the items added since the last completed turn that still
        need to be sent on a chained request.

        Candidates are ``all_items[self._last_seen_items_count:]``. From those,
        ``FunctionCall`` items and assistant-role ``ChatMessage`` items are
        dropped, on the assumption that they were produced by the chained
        previous response and are already part of its server-side state.
        """

        def known_to_server(item) -> bool:
            if isinstance(item, FunctionCall):
                return True
            return isinstance(item, ChatMessage) and item.role == ChatRole.ASSISTANT

        start = self._last_seen_items_count
        return [item for item in all_items[start:] if not known_to_server(item)]

    async def _chat_websocket(
        self,
        messages: ChatContext,
        tools: list[FunctionTool] | None = None,
        conversational_graph: Any | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[LLMResponse]:
        """Stream chat responses over the WSS Responses API.

        When the previous turn on this socket completed and the chat context
        has only grown since, only the newly added items (minus those the
        server already holds) are sent, chained via ``previous_response_id``.
        Otherwise the full context is sent.

        A single fallback attempt with the full, unchained context is made
        when:
        - the send fails;
        - the server reports ``previous_response_not_found`` or
          ``websocket_connection_limit_reached``;
        - the socket drops before anything has been yielded for this turn.

        A drop after output was already yielded is not retried, because that
        would duplicate output. It is logged and the turn ends.

        Args:
            messages: Conversation history for this turn.
            tools: Tools the model may call.
            conversational_graph: Optional graph that forces structured JSON
                output and drives incremental graph streaming.
            **kwargs: Extra fields merged into the ``response.create`` event.

        Yields:
            ``LLMResponse`` chunks: text deltas, function calls, one usage
            summary once the response finishes (``response.completed`` or
            ``response.incomplete``), and in graph mode the parsed graph
            response.

        Raises:
            RuntimeError: on a non-recoverable server ``error`` event or a
                ``response.failed`` event.
        """

        all_items = list(messages.items)

        # Chain onto the previous response only when this socket produced it
        # and the chat context has only grown since then.
        can_chain = (
            self._previous_response_id is not None
            and self._ws is not None
            and not self._ws.closed
            and self._last_seen_items_count > 0
            and len(all_items) >= self._last_seen_items_count
        )

        if can_chain:
            send_objs = self._slice_incremental_items(all_items)
            input_items = _chat_items_to_responses_input(send_objs)
            previous_response_id = self._previous_response_id
        else:
            input_items = _chat_items_to_responses_input(all_items)
            previous_response_id = None

        payload = self._build_responses_payload(
            input_items=input_items,
            previous_response_id=previous_response_id,
            tools=tools,
            conversational_graph=conversational_graph,
            extra=kwargs,
        )

        def _use_full_context() -> None:
            """Turn the pending request into an unchained, full-history one."""
            payload["input"] = _chat_items_to_responses_input(all_items)
            payload.pop("previous_response_id", None)

        def _function_call_response(fc: dict, usage: dict) -> LLMResponse:
            """Mark *fc* as dispatched and wrap it as a function-call chunk."""
            args_str = fc.get("arguments") or ""
            try:
                args = json.loads(args_str) if args_str else {}
            except json.JSONDecodeError:
                self.emit(
                    "error",
                    f"Failed to parse tool call arguments: {args_str}",
                )
                args = {}
            fc["dispatched"] = True
            return LLMResponse(
                content="",
                role=ChatRole.ASSISTANT,
                metadata={
                    "function_call": {
                        "name": fc.get("name", ""),
                        "arguments": args,
                        "id": fc.get("call_id", ""),
                    },
                    "usage": usage,
                },
            )

        fallback_done = False

        while True:
            if self._cancelled:
                return

            try:
                ws = await self._ensure_ws()
                logger.info(
                    "[openai-wss] sending request | chained=%s items=%d",
                    bool(payload.get("previous_response_id")),
                    # Report what is actually sent (full context after a fallback).
                    len(payload["input"]),
                )
                await ws.send_str(json.dumps(payload))
            except Exception as e:
                if fallback_done:
                    if not self._cancelled:
                        self.emit("error", e)
                    raise
                fallback_done = True
                await self._close_ws()
                _use_full_context()
                continue

            usage_metadata: dict = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "prompt_cached_tokens": 0,
                "reasoning_tokens": 0,
                "request_id": None,
                "model": self.model,
            }
            current_content = ""
            streaming_state = {
                "in_response": False,
                "response_start_index": -1,
                "yielded_content_length": 0,
            }
            # item_id -> {"call_id", "name", "arguments"[, "dispatched"]}
            function_calls: dict[str, dict] = {}
            response_id_this_turn: Optional[str] = None
            need_retry = False
            completed = False
            # Once any chunk reaches the caller, a retry would duplicate output.
            yielded_any = False

            try:
                async for ws_msg in ws:
                    if self._cancelled:
                        break

                    if ws_msg.type != aiohttp.WSMsgType.TEXT:
                        # aiohttp's iterator itself stops on CLOSE/CLOSING/CLOSED,
                        # so ERROR is what normally shows up here. A lost socket
                        # is handled once, after the loop.
                        if ws_msg.type in (
                            aiohttp.WSMsgType.CLOSED,
                            aiohttp.WSMsgType.CLOSE,
                            aiohttp.WSMsgType.CLOSING,
                            aiohttp.WSMsgType.ERROR,
                        ):
                            break
                        continue

                    try:
                        event = json.loads(ws_msg.data)
                    except json.JSONDecodeError:
                        logger.debug("Skipping non-JSON WSS frame: %r", ws_msg.data)
                        continue

                    etype = event.get("type")

                    # After a cancel, skip leftovers of the old response until the
                    # new one announces itself (errors are never skipped).
                    if self._drain_until_response_created and etype != "error":
                        if etype == "response.created":
                            self._drain_until_response_created = False
                        else:
                            logger.debug("[openai-wss] draining stale event: %s", etype)
                            continue

                    if etype == "error":
                        err = event.get("error") or {}
                        code = err.get("code")
                        recoverable = code in (
                            "previous_response_not_found",
                            "websocket_connection_limit_reached",
                        )
                        if recoverable and not fallback_done:
                            fallback_done = True
                            need_retry = True
                            if code == "websocket_connection_limit_reached":
                                await self._close_ws()
                            else:
                                self._previous_response_id = None
                                self._last_seen_items_count = 0
                                self._drain_until_response_created = True
                            _use_full_context()
                            break
                        if recoverable and fallback_done:
                            logger.warning(
                                "[openai-wss] recoverable error after retry (%s) — closing WS for clean slate",
                                code,
                            )
                            await self._close_ws()
                            self._drain_until_response_created = False
                            return
                        raise RuntimeError(
                            f"OpenAI WSS error: {err.get('message') or event}"
                        )

                    elif etype == "response.created":
                        response_id_this_turn = (event.get("response") or {}).get("id")
                        logger.info("[openai-wss] response.created received id=%s", response_id_this_turn)

                    elif etype == "response.output_text.delta":
                        delta = event.get("delta", "")
                        if not delta:
                            continue
                        current_content += delta
                        if conversational_graph:
                            for content_chunk in conversational_graph.stream_conversational_graph_response(
                                current_content, streaming_state
                            ):
                                yielded_any = True
                                yield LLMResponse(
                                    content=content_chunk,
                                    role=ChatRole.ASSISTANT,
                                    metadata={"usage": usage_metadata},
                                )
                        else:
                            yielded_any = True
                            yield LLMResponse(
                                content=delta,
                                role=ChatRole.ASSISTANT,
                                metadata={"usage": usage_metadata},
                            )

                    elif etype == "response.output_item.added":
                        item = event.get("item") or {}
                        if item.get("type") == "function_call":
                            iid = item.get("id", "")
                            function_calls[iid] = {
                                "call_id": item.get("call_id", "") or "",
                                "name": item.get("name", "") or "",
                                "arguments": item.get("arguments", "") or "",
                            }

                    elif etype == "response.function_call_arguments.delta":
                        iid = event.get("item_id")
                        delta = event.get("delta", "")
                        if iid and iid in function_calls and delta:
                            function_calls[iid]["arguments"] += delta

                    elif etype == "response.output_item.done":
                        item = event.get("item") or {}
                        if item.get("type") == "function_call":
                            iid = item.get("id", "")
                            existing = function_calls.get(iid, {})
                            fc_entry = {
                                "call_id": item.get("call_id") or existing.get("call_id", ""),
                                "name": item.get("name") or existing.get("name", ""),
                                "arguments": item.get("arguments") or existing.get("arguments", ""),
                                "dispatched": existing.get("dispatched", False),
                            }
                            function_calls[iid] = fc_entry
                            if not fc_entry["dispatched"]:
                                yielded_any = True
                                yield _function_call_response(fc_entry, usage_metadata)

                    elif etype in ("response.completed", "response.incomplete"):
                        # Both are terminal. ``incomplete`` (e.g. max_output_tokens
                        # reached) carries partial output that is still valid.
                        completed = True
                        resp = event.get("response") or {}
                        if etype == "response.incomplete":
                            logger.warning(
                                "[openai-wss] response incomplete: %s",
                                resp.get("incomplete_details"),
                            )
                        response_id_this_turn = resp.get("id") or response_id_this_turn
                        usage = resp.get("usage") or {}
                        usage_metadata["prompt_tokens"] = (
                            usage.get("input_tokens") or usage.get("prompt_tokens") or 0
                        )
                        usage_metadata["completion_tokens"] = (
                            usage.get("output_tokens") or usage.get("completion_tokens") or 0
                        )
                        usage_metadata["total_tokens"] = usage.get("total_tokens") or 0
                        usage_metadata["request_id"] = response_id_this_turn
                        usage_metadata["model"] = resp.get("model") or self.model
                        in_details = usage.get("input_tokens_details") or {}
                        usage_metadata["prompt_cached_tokens"] = (
                            in_details.get("cached_tokens") or 0
                        )
                        out_details = usage.get("output_tokens_details") or {}
                        usage_metadata["reasoning_tokens"] = (
                            out_details.get("reasoning_tokens") or 0
                        )

                        yield LLMResponse(
                            content="",
                            role=ChatRole.ASSISTANT,
                            metadata={"usage": usage_metadata},
                        )

                        # Calls whose output_item.done never arrived.
                        for fc in function_calls.values():
                            if fc.get("dispatched"):
                                continue
                            yield _function_call_response(fc, usage_metadata)

                        if current_content and conversational_graph:
                            try:
                                parsed_json = json.loads(current_content.strip())
                                yield LLMResponse(
                                    content="",
                                    role=ChatRole.ASSISTANT,
                                    metadata={
                                        "usage": usage_metadata,
                                        "graph_response": parsed_json,
                                    },
                                )
                            except json.JSONDecodeError:
                                yield LLMResponse(
                                    content=current_content,
                                    role=ChatRole.ASSISTANT,
                                    metadata={"usage": usage_metadata},
                                )

                        # Update chain tracking only once the response finished.
                        if response_id_this_turn:
                            self._previous_response_id = response_id_this_turn
                            self._last_seen_items_count = len(all_items)
                        break

                    elif etype == "response.failed":
                        # Terminal failure: without this branch the loop would
                        # wait forever on a live (heartbeated) socket.
                        resp = event.get("response") or {}
                        failure = resp.get("error") or {}
                        raise RuntimeError(
                            f"OpenAI WSS response failed: {failure.get('message') or event}"
                        )

                if not (completed or need_retry or self._cancelled):
                    # The socket closed or errored before a terminal event.
                    await self._close_ws()
                    if fallback_done or yielded_any:
                        logger.warning(
                            "[openai-wss] connection lost mid-response (retried=%s, partial_output=%s)",
                            fallback_done,
                            yielded_any,
                        )
                    else:
                        fallback_done = True
                        need_retry = True
                        _use_full_context()
            except Exception as e:
                if not self._cancelled:
                    self.emit("error", e)
                raise

            if need_retry and not self._cancelled:
                continue
            return

    async def cancel_current_generation(self) -> None:
        """Abort the response that is currently streaming, if any.

        Instead of tearing the WebSocket down, this sends a
        ``response.cancel`` event over it. The next ``chat()`` can then reuse
        the connection and skip a fresh TLS + WebSocket handshake. Chaining
        state is cleared, and the next read loop is told to discard events
        until a new ``response.created`` arrives.

        If there is no open socket, nothing is sent. If sending fails, the
        socket is closed via ``_close_ws()``.
        """
        # Stop the in-flight stream and forget the response chain.
        self._cancelled = True
        self._previous_response_id = None
        self._last_seen_items_count = 0
        self._drain_until_response_created = True

        ws = self._ws
        if ws is not None and not ws.closed:
            cancel_event = json.dumps({"type": "response.cancel"})
            try:
                await ws.send_str(cancel_event)
            except Exception:
                await self._close_ws()

    async def aclose(self) -> None:
        """Cleanup resources. Closes the underlying HTTP client (if owned) and any
        WSS connection / session that was opened for streaming mode.

        Steps, in order:
        - cancel any in-flight generation;
        - cancel and await a pending prewarm task;
        - close the streaming WebSocket and its session;
        - close the OpenAI client when ``self._owns_client`` is set.

        The base-class ``aclose()`` always runs last, even if one of the
        earlier steps raises.
        """
        try:
            await self.cancel_current_generation()
            prewarm = self._prewarm_task
            if prewarm is not None and not prewarm.done():
                prewarm.cancel()
                try:
                    await prewarm
                except (asyncio.CancelledError, Exception):
                    # The task was just cancelled; its outcome is irrelevant here.
                    pass
            await self._close_ws()
            session = self._ws_session
            if session is not None and not session.closed:
                try:
                    await session.close()
                except Exception:
                    pass  # best effort during shutdown
                self._ws_session = None
            if self._owns_client and self._client:
                await self._client.close()
        finally:
            # Previously skipped whenever an earlier step (e.g. client.close) raised.
            await super().aclose()

Base class for LLM implementations.

Initialize the OpenAI LLM plugin.

Args

api_key
OpenAI API key. Falls back to OPENAI_API_KEY env var.
model
Chat model name. Defaults to "gpt-4o-mini".
base_url
Override the default OpenAI API base URL.
temperature
Sampling temperature. Defaults to 0.7.
tool_choice
Controls which (if any) tool is called. Defaults to "auto".
max_completion_tokens
Maximum tokens in the completion.
top_p
Nucleus sampling probability mass.
frequency_penalty
Penalise repeated tokens by frequency.
presence_penalty
Penalise tokens that have already appeared.
seed
Seed for deterministic sampling.
organization
OpenAI organisation ID.
project
OpenAI project ID.
parallel_tool_calls
Allow the model to call multiple tools in one turn.
timeout
Custom httpx.Timeout for the underlying HTTP client.
extra_headers
Additional HTTP headers forwarded to every API call.
extra_query
Additional query-string parameters forwarded to every API call.
extra_body
Additional JSON body fields forwarded to every API call.
client
Optional pre-built openai.AsyncOpenAI instance to use instead of creating a new one. Useful for sharing a client across instances or for testing. When provided, api_key, base_url, organization, project, timeout, and max_retries are ignored.
max_retries
Number of automatic retries on transient errors. Defaults to 0.
reasoning_effort
Controls reasoning depth for reasoning models. Supported values: "none", "low", "medium", "high". Defaults to None (uses the model's default). Only applied for reasoning / GPT-5 models.
verbosity
Controls output verbosity for reasoning / GPT-5 models. Supported values: "low", "medium", "high". Defaults to None.
streaming
When True, use OpenAI's WebSocket Responses API (wss://api.openai.com/v1/responses) instead of the standard HTTP Chat Completions endpoint. The connection is reused across turns, and each turn is chained to the previous one via previous_response_id to reduce per-turn latency. Defaults to False (HTTP mode).
store
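Only used when streaming=True. Controls whether responses are persisted server-side. Defaults to False (compatible with Zero Data Retention). If a chained request fails because the server no longer has the previous response (previous_response_not_found), the request is retried once with the full conversation context.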
wss_url
Override the WebSocket Responses URL. Defaults to OpenAI's public endpoint.

Ancestors

  • videosdk.agents.llm.llm.LLM
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Static methods

def azure(*,
model: str = 'gpt-4o-mini',
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | None = None,
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_completion_tokens: int | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
parallel_tool_calls: bool | None = None,
timeout: httpx.Timeout | None = None,
extra_headers: dict | None = None,
extra_query: dict | None = None,
extra_body: dict | None = None,
client: openai.AsyncAzureOpenAI | None = None,
max_retries: int = 0,
reasoning_effort: "Literal['none', 'low', 'medium', 'high'] | None" = 'none',
verbosity: "Literal['low', 'medium', 'high'] | None" = 'low') ‑> OpenAILLM
Expand source code
@staticmethod
def azure(
    *,
    model: str = "gpt-4o-mini",
    azure_endpoint: str | None = None,
    azure_deployment: str | None = None,
    api_version: str | None = None,
    api_key: str | None = None,
    azure_ad_token: str | None = None,
    organization: str | None = None,
    project: str | None = None,
    base_url: str | None = None,
    temperature: float = 0.7,
    tool_choice: ToolChoice = "auto",
    max_completion_tokens: int | None = None,
    top_p: float | None = None,
    frequency_penalty: float | None = None,
    presence_penalty: float | None = None,
    seed: int | None = None,
    parallel_tool_calls: bool | None = None,
    timeout: httpx.Timeout | None = None,
    extra_headers: dict | None = None,
    extra_query: dict | None = None,
    extra_body: dict | None = None,
    client: openai.AsyncAzureOpenAI | None = None,
    max_retries: int = 0,
    reasoning_effort: Literal["none", "low", "medium", "high"] | None = "none",
    verbosity: Literal["low", "medium", "high"] | None = "low",
) -> "OpenAILLM":
    """
    Create a new instance of Azure OpenAI LLM.

    Automatically infers the following from environment variables when not provided:
    - ``api_key`` from ``AZURE_OPENAI_API_KEY``
    - ``organization`` from ``OPENAI_ORG_ID``
    - ``project`` from ``OPENAI_PROJECT_ID``
    - ``azure_ad_token`` from ``AZURE_OPENAI_AD_TOKEN``
    - ``api_version`` from ``OPENAI_API_VERSION``
    - ``azure_endpoint`` from ``AZURE_OPENAI_ENDPOINT`` (only when ``base_url`` is not given)
    - ``azure_deployment`` from ``AZURE_OPENAI_DEPLOYMENT`` (falls back to ``model``)

    ``base_url`` and ``azure_endpoint`` are mutually exclusive, as in
    ``openai.AsyncAzureOpenAI``. Pass ``base_url`` to target a full URL
    directly; otherwise an endpoint is required.

    Pass ``client`` to supply a pre-built ``openai.AsyncAzureOpenAI`` instance.
    When ``client`` is provided, connection/credential params (including
    ``timeout`` and ``max_retries``) are ignored.

    Raises:
        ValueError: if both ``base_url`` and ``azure_endpoint`` are given, if
            neither ``base_url`` nor an endpoint is available, or if neither
            an API key nor an Azure AD token is available.
    """
    # Model/generation settings shared by both construction paths.
    llm_kwargs = dict(
        model=model,
        temperature=temperature,
        tool_choice=tool_choice,
        max_completion_tokens=max_completion_tokens,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        seed=seed,
        parallel_tool_calls=parallel_tool_calls,
        extra_headers=extra_headers,
        extra_query=extra_query,
        extra_body=extra_body,
        reasoning_effort=reasoning_effort,
        verbosity=verbosity,
    )

    if client is not None:
        return OpenAILLM(client=client, **llm_kwargs)

    azure_deployment = azure_deployment or os.getenv("AZURE_OPENAI_DEPLOYMENT")
    api_version = api_version or os.getenv("OPENAI_API_VERSION")
    api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
    azure_ad_token = azure_ad_token or os.getenv("AZURE_OPENAI_AD_TOKEN")
    organization = organization or os.getenv("OPENAI_ORG_ID")
    project = project or os.getenv("OPENAI_PROJECT_ID")

    if not azure_deployment:
        azure_deployment = model

    # AsyncAzureOpenAI rejects base_url together with azure_endpoint, so the
    # endpoint (and its env fallback) only applies when base_url is absent.
    if base_url is None:
        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError(
                "Azure endpoint must be provided either through azure_endpoint parameter "
                "or AZURE_OPENAI_ENDPOINT environment variable"
            )
    elif azure_endpoint:
        raise ValueError("base_url and azure_endpoint are mutually exclusive")
    else:
        azure_endpoint = None  # never forward an empty endpoint alongside base_url

    if not api_key and not azure_ad_token:
        raise ValueError("Either API key or Azure AD token must be provided")

    _timeout = timeout or httpx.Timeout(connect=15.0, read=10.0, write=5.0, pool=5.0)
    azure_client = openai.AsyncAzureOpenAI(
        max_retries=max_retries,
        azure_endpoint=azure_endpoint,
        azure_deployment=azure_deployment,
        api_version=api_version,
        api_key=api_key,
        azure_ad_token=azure_ad_token,
        organization=organization,
        project=project,
        base_url=base_url,
        timeout=_timeout,
    )

    instance = OpenAILLM(client=azure_client, **llm_kwargs)
    # This factory created azure_client and nothing else references it, so the
    # instance must own it for aclose() to close it.
    # NOTE(review): assumes OpenAILLM.__init__ marks an injected client as
    # caller-owned (_owns_client False) — confirm against __init__.
    instance._owns_client = True
    return instance

Create a new instance of Azure OpenAI LLM.

Automatically infers the following from environment variables when not provided:

- api_key from AZURE_OPENAI_API_KEY
- organization from OPENAI_ORG_ID
- project from OPENAI_PROJECT_ID
- azure_ad_token from AZURE_OPENAI_AD_TOKEN
- api_version from OPENAI_API_VERSION
- azure_endpoint from AZURE_OPENAI_ENDPOINT
- azure_deployment from AZURE_OPENAI_DEPLOYMENT (falls back to model)

Pass client to supply a pre-built openai.AsyncAzureOpenAI instance. When client is provided, connection/credential params are ignored.

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Cleanup resources. Closes the underlying HTTP client (if owned) and any
    WSS connection / session that was opened for streaming mode.

    Steps, in order:
    - cancel any in-flight generation;
    - cancel and await a pending prewarm task;
    - close the streaming WebSocket and its session;
    - close the OpenAI client when ``self._owns_client`` is set.

    The base-class ``aclose()`` always runs last, even if one of the
    earlier steps raises.
    """
    try:
        await self.cancel_current_generation()
        prewarm = self._prewarm_task
        if prewarm is not None and not prewarm.done():
            prewarm.cancel()
            try:
                await prewarm
            except (asyncio.CancelledError, Exception):
                # The task was just cancelled; its outcome is irrelevant here.
                pass
        await self._close_ws()
        session = self._ws_session
        if session is not None and not session.closed:
            try:
                await session.close()
            except Exception:
                pass  # best effort during shutdown
            self._ws_session = None
        if self._owns_client and self._client:
            await self._client.close()
    finally:
        # Previously skipped whenever an earlier step (e.g. client.close) raised.
        await super().aclose()

Cleanup resources. Closes the underlying HTTP client (if owned) and any WSS connection / session that was opened for streaming mode.

async def cancel_current_generation(self) ‑> None
Expand source code
async def cancel_current_generation(self) -> None:
    """Abort the response that is currently streaming, if any.

    Instead of tearing the WebSocket down, this sends a ``response.cancel``
    event over it. The next ``chat()`` can then reuse the connection and skip
    a fresh TLS + WebSocket handshake. Chaining state is cleared, and the next
    read loop is told to discard events until a new ``response.created``
    arrives.

    If there is no open socket, nothing is sent. If sending fails, the socket
    is closed via ``_close_ws()``.
    """
    # Stop the in-flight stream and forget the response chain.
    self._cancelled = True
    self._previous_response_id = None
    self._last_seen_items_count = 0
    self._drain_until_response_created = True

    ws = self._ws
    if ws is not None and not ws.closed:
        cancel_event = json.dumps({"type": "response.cancel"})
        try:
            await ws.send_str(cancel_event)
        except Exception:
            await self._close_ws()

Cancel the in-flight response.

Sends a response.cancel event over the existing WebSocket instead of closing it, so the next chat() call reuses the same connection and avoids paying another TLS+WS handshake.

Falls back to closing the WS if the cancel event can't be delivered (e.g. the connection is already broken).

async def chat(self,
messages: ChatContext,
tools: list[FunctionTool] | None = None,
conversational_graph: Any | None = None,
**kwargs: Any) ‑> AsyncIterator[videosdk.agents.llm.llm.LLMResponse]
Expand source code
async def chat(
    self,
    messages: ChatContext,
    tools: list[FunctionTool] | None = None,
    conversational_graph: Any | None = None,
    **kwargs: Any
) -> AsyncIterator[LLMResponse]:
    """
    Stream one chat turn, dispatching on the ``streaming`` flag.

    When ``streaming`` is true, chunks come from the WSS Responses path
    (``_chat_websocket``); otherwise from the HTTP path (``_chat_http``).
    Any earlier cancellation is cleared before streaming starts.
    """
    self._cancelled = False
    stream = self._chat_websocket if self.streaming else self._chat_http
    async for response in stream(
        messages, tools=tools, conversational_graph=conversational_graph, **kwargs
    ):
        yield response

Stream chat completions. Routes between the existing HTTP path and the WSS Responses path based on the streaming flag.

async def prewarm(self, *, instructions: str | None = None, tools: list[FunctionTool] | None = None) ‑> None
Expand source code
async def prewarm(
    self,
    *,
    instructions: str | None = None,
    tools: list[FunctionTool] | None = None,
) -> None:
    """Eagerly establish the WSS connection (and optionally prime request state).

    Call this once after constructing ``OpenAILLM(streaming=True)`` to avoid
    paying the TLS + WebSocket handshake on the first ``chat()`` call. When
    ``instructions`` and/or ``tools`` are provided, also sends a warmup
    ``response.create`` with ``generate: false`` so the server pre-builds
    request state for the first real turn (per OpenAI's WSS docs). The
    returned response id is used as ``previous_response_id`` on the first
    real turn for further latency reduction.

    The warmup id is only kept when the server confirms the warmup with
    ``response.completed``. If the socket closes, errors, reports
    ``response.failed`` / ``response.incomplete``, or the read raises before
    that, ``previous_response_id`` is cleared. Otherwise the first real turn
    would reference a response the server never finished.

    No-op when ``streaming=False``.
    """
    if not self.streaming:
        return

    # An earlier prewarm may still be in flight; let it finish so the two
    # do not interleave reads on the same socket. Its failure is not fatal.
    if self._prewarm_task is not None and not self._prewarm_task.done():
        try:
            await self._prewarm_task
        except Exception as e:
            logger.debug("Previous OpenAI WSS prewarm task failed: %s", e)

    ws = await self._ensure_ws()

    if instructions is None and not tools:
        return

    warmup_input: list[dict] = []
    if instructions:
        warmup_input.append(
            {
                "type": "message",
                "role": "system",
                "content": [{"type": "input_text", "text": instructions}],
            }
        )

    payload = self._build_responses_payload(
        input_items=warmup_input,
        previous_response_id=None,
        tools=tools,
        conversational_graph=None,
        extra={"generate": False},
    )

    try:
        await ws.send_str(json.dumps(payload))
    except Exception as e:
        logger.warning("OpenAI WSS warmup send failed: %s", e)
        return

    completed = False
    try:
        async for msg in ws:
            if msg.type != aiohttp.WSMsgType.TEXT:
                if msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                    aiohttp.WSMsgType.ERROR,
                ):
                    break
                continue
            try:
                event = json.loads(msg.data)
            except json.JSONDecodeError:
                continue
            etype = event.get("type")
            if etype == "response.created":
                self._previous_response_id = (
                    (event.get("response") or {}).get("id")
                )
            elif etype == "response.completed":
                resp = event.get("response") or {}
                self._previous_response_id = (
                    resp.get("id") or self._previous_response_id
                )
                completed = True
                break
            elif etype in ("response.failed", "response.incomplete"):
                logger.warning("OpenAI WSS warmup did not complete: %s", etype)
                break
            elif etype == "error":
                err = event.get("error") or {}
                logger.warning("OpenAI WSS warmup error: %s", err.get("message") or err)
                break
    except Exception as e:
        logger.warning("OpenAI WSS warmup read failed: %s", e)

    # Only a confirmed warmup may seed previous_response_id for the next turn.
    if not completed:
        self._previous_response_id = None

Eagerly establish the WSS connection (and optionally prime request state).

Call this once after constructing OpenAILLM(streaming=True) to avoid paying the TLS + WebSocket handshake on the first chat() call. When instructions and/or tools are provided, also sends a warmup response.create with generate: false so the server pre-builds request state for the first real turn (per OpenAI's WSS docs). The returned response id is used as previous_response_id on the first real turn for further latency reduction.

No-op when streaming=False.

class OpenAIRealtime (*,
api_key: str | None = None,
model: str,
config: OpenAIRealtimeConfig | None = None,
base_url: str | None = None)
Expand source code
class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
    """OpenAI's realtime model implementation."""

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str,
        config: OpenAIRealtimeConfig | None = None,
        base_url: str | None = None,
    ) -> None:
        """
        Initialize OpenAI realtime model.

        Args:
            api_key: OpenAI API key. Falls back to the OPENAI_API_KEY
                environment variable when not provided.
            model: Identifier of a model served by the OpenAI Realtime API.
                It is passed to ``process_base_url`` to build the connect
                URL, which is how the GA API selects the model.
            config: Optional OpenAIRealtimeConfig; a default instance is used
                when omitted. This class reads its ``modalities``, ``voice``,
                ``turn_detection``, ``input_audio_transcription`` and
                ``tool_choice`` settings.
            base_url: Base URL for the OpenAI API. Falls back to the module's
                OPENAI_BASE_URL constant when not provided.

        Raises:
            ValueError: If no API key is provided and none is found in the
                OPENAI_API_KEY environment variable. The same message is
                emitted on the ``error`` channel before raising.
        """
        super().__init__()
        self.model = model
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.base_url = base_url or OPENAI_BASE_URL
        if not self.api_key:
            missing_key = (
                "OpenAI API key must be provided or set in OPENAI_API_KEY environment variable"
            )
            self.emit("error", missing_key)
            raise ValueError(missing_key)
        self._http_session: Optional[aiohttp.ClientSession] = None
        self._session: Optional[OpenAISession] = None
        self._closing = False
        self._instructions: Optional[str] = None
        self._tools: Optional[List[FunctionTool]] = []
        self._formatted_tools: Optional[List[Dict[str, Any]]] = None
        self.config: OpenAIRealtimeConfig = config or OpenAIRealtimeConfig()
        self.input_sample_rate = 48000
        # GA Realtime API: audio/pcm is fixed at 24kHz. This drives both the
        # resample target for outgoing user audio and the rate declared in
        # session.audio.input.format — they must stay equal.
        self.target_sample_rate = 24000
        self._agent_speaking = False
        self._active_response_id: Optional[str] = None
        # Running buffers for the current response's streamed text and audio
        # transcript; the delta handlers append to them. Created up front so
        # they exist before the first delta arrives.
        self._current_text_response = ""
        self._current_audio_transcript = ""

    def set_agent(self, agent: Agent) -> None:
        """Bind ``agent`` and cache its instructions and tools.

        The tools are converted once into the session tool schema. The result
        is stored both as ``tools_formatted`` and as ``_formatted_tools``; the
        latter is what the initial session.update sends.
        """
        self._agent = agent
        self._instructions = agent.instructions
        self._tools = agent.tools
        session_tools = self._format_tools_for_session(self._tools)
        self.tools_formatted = session_tools
        self._formatted_tools = session_tools

    async def connect(self) -> None:
        """Open the Realtime WebSocket, start its I/O loops and configure the session.

        Connection failures are logged and emitted on the ``error`` channel,
        then re-raised.

        Raises:
            aiohttp.WSServerHandshakeError: The server rejected the upgrade
                (e.g. bad API key, wrong URL, unknown model).
            Exception: Any other failure while connecting or sending the
                initial session.update.
        """
        # GA Realtime API — do NOT send "OpenAI-Beta: realtime=v1". That header
        # opts into the retired Beta API, which the server now rejects with
        # "The Realtime Beta API is no longer supported."
        headers = {
            "Agent": "VideoSDK Agents",
            "Authorization": f"Bearer {self.api_key}",
        }
        url = self.process_base_url(self.base_url, self.model)

        if "audio" in self.config.modalities:
            self.reframe_audio_track(self.target_sample_rate)

        def report(detail: str) -> None:
            logger.error(detail)
            self.emit("error", detail)

        try:
            self._session = await self._create_session(url, headers)
            await self._handle_websocket(self._session)
            await self.send_first_session_update()
        except aiohttp.WSServerHandshakeError as exc:
            # Rejections happen before the socket opens, so the receive loop
            # (and _handle_error) never see them; surface them here instead.
            detail = (
                f"OpenAI Realtime connection rejected (HTTP {exc.status}): {exc.message}"
            )
            if exc.status in (401, 403):
                detail += " — verify OPENAI_API_KEY is set and valid."
            report(detail)
            raise
        except Exception as exc:
            report(f"OpenAI Realtime connection failed: {exc}")
            raise

    async def handle_audio_input(self, audio_data: bytes) -> None:
        """Downmix, resample and forward one chunk of user audio to the server.

        Nothing is sent when there is no session, the model is closing, audio
        is not an enabled modality, or the current utterance is
        non-interruptible.
        """
        if not self._session or self._closing or "audio" not in self.config.modalities:
            return
        if self.current_utterance and not self.current_utterance.is_interruptible:
            logger.info("Interruption is disabled for the current utterance. Not processing audio input.")
            return

        # The WebRTC source is expected to deliver 48 kHz int16 frames with
        # stereo samples interleaved (L,R,L,R...). They are averaged to mono
        # BEFORE resampling. Otherwise the buffer is twice the mono length and,
        # declared at GA's fixed 24 kHz, the server hears half-speed audio.
        samples = np.frombuffer(audio_data, dtype=np.int16)
        looks_interleaved = samples.size >= 2 and samples.size % 2 == 0
        mono = (
            samples.reshape(-1, 2).astype(np.float32).mean(axis=1)
            if looks_interleaved
            else samples.astype(np.float32)
        )
        target_len = int(len(mono) * self.target_sample_rate / self.input_sample_rate)
        resampled = signal.resample(mono, target_len)
        pcm = np.clip(resampled, -32767, 32767).astype(np.int16).tobytes()

        await self.send_event(
            {
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(pcm).decode("utf-8"),
            }
        )

    async def _ensure_http_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, creating a fresh one when needed.

        A new session is created when none exists yet or when the previous one
        has been closed (``aclose`` closes it). A closed ClientSession cannot
        open new connections.
        """
        if self._http_session is None or self._http_session.closed:
            self._http_session = aiohttp.ClientSession()
        return self._http_session

    async def _create_session(self, url: str, headers: dict) -> OpenAISession:
        """Open the WebSocket at ``url`` and wrap it in a fresh OpenAISession.

        The returned session has an empty outbound queue and no tasks yet; the
        send/receive loops are attached separately. ``_closing`` is cleared so
        those loops are allowed to run.
        """
        http = await self._ensure_http_session()
        socket = await http.ws_connect(
            url,
            headers=headers,
            autoping=True,
            heartbeat=10,
            autoclose=False,
            timeout=30,
        )
        outbound: asyncio.Queue = asyncio.Queue()
        self._closing = False
        return OpenAISession(ws=socket, msg_queue=outbound, tasks=[])


    async def send_message(self, message: str) -> None:
        """Ask the model to say ``message`` back verbatim.

        Inserts an assistant-role item carrying the instruction and the message,
        then requests a new response.
        """
        # Spaces around ``message`` keep it from running into the surrounding
        # instruction text, so the model can tell where the message starts/ends.
        prompt = (
            "Repeat the user's exact message back to them: "
            f"{message} "
            "DO NOT ADD ANYTHING ELSE"
        )
        await self.send_event(
            {
                "type": "conversation.item.create",
                "item": {
                    "type": "message",
                    "role": "assistant",
                    "content": [
                        {
                            # GA: assistant/output message content is "output_text"
                            # (the Beta API's "text" is rejected).
                            "type": "output_text",
                            "text": prompt,
                        }
                    ],
                },
            }
        )
        await self.create_response()

    async def handle_video_input(self, video_data: av.VideoFrame) -> None:
        """Send one video frame to the model as a user ``input_image`` item.

        The frame is JPEG-encoded and converted with ``bytes_to_base64_url``.
        Empty frames, and frames that encode to fewer than 100 bytes, are
        dropped. Any failure is emitted on the ``error`` channel rather than
        raised.
        """
        if not self._session or self._closing:
            return

        try:
            if not video_data or not video_data.planes:
                return

            jpeg = encode_image(video_data, DEFAULT_IMAGE_ENCODE_OPTIONS)
            if not jpeg or len(jpeg) < 100:
                logger.warning("Invalid JPEG data generated")
                return

            image_part = {"type": "input_image", "image_url": self.bytes_to_base64_url(jpeg)}
            await self.send_event(
                {
                    "type": "conversation.item.create",
                    "item": {"type": "message", "role": "user", "content": [image_part]},
                }
            )
        except Exception as e:
            self.emit("error", f"Video processing error: {str(e)}")

    async def send_message_with_frames(
        self, message: Optional[str], frames: list[av.VideoFrame]
    ) -> None:
        """Send optional text plus video frames as one user item, then respond.

        Frames that fail to encode, or encode to fewer than 100 bytes, are
        skipped (and logged). Nothing is sent when neither text nor any image
        remains.
        """
        content: list = [{"type": "input_text", "text": message}] if message else []

        for frame in frames:
            try:
                jpeg = encode_image(frame, DEFAULT_IMAGE_ENCODE_OPTIONS)
                if jpeg and len(jpeg) >= 100:
                    content.append(
                        {"type": "input_image", "image_url": self.bytes_to_base64_url(jpeg)}
                    )
                else:
                    logger.warning("Invalid JPEG data generated")
            except Exception as e:
                logger.error(f"Error processing frame: {e}")

        # Every entry is either input_text or input_image, so an empty list is
        # the only "nothing to send" case.
        if not content:
            logger.warning("No content to send.")
            return

        await self.send_event(
            {
                "type": "conversation.item.create",
                "item": {"type": "message", "role": "user", "content": content},
            }
        )
        await self.create_response()
    
    async def create_response(self) -> None:
        """Ask the server to generate a new response using the agent instructions.

        Raises:
            RuntimeError: If there is no active WebSocket session; the same
                message is emitted on the ``error`` channel first.
        """
        if not self._session:
            reason = "No active WebSocket session"
            self.emit("error", reason)
            raise RuntimeError(reason)

        event_id = str(uuid.uuid4())
        client_event_id = str(uuid.uuid4())
        await self.send_event(
            {
                "type": "response.create",
                "event_id": event_id,
                "response": {
                    "instructions": self._instructions,
                    "metadata": {"client_event_id": client_event_id},
                },
            }
        )

    async def _handle_websocket(self, session: OpenAISession) -> None:
        """Spawn the send and receive loops for ``session`` and record them on it."""
        sender = asyncio.create_task(self._send_loop(session), name="send_loop")
        receiver = asyncio.create_task(self._receive_loop(session), name="receive_loop")
        session.tasks.extend([sender, receiver])

    async def _send_loop(self, session: OpenAISession) -> None:
        """Drain ``session.msg_queue`` onto the WebSocket until closing.

        Dict messages are sent as JSON, anything else as text. Unexpected
        failures are logged and emitted on the ``error`` channel, matching the
        receive loop, instead of escaping as an unretrieved task exception.
        Whatever stops the loop, the session is torn down in ``finally``.
        """
        try:
            while not self._closing:
                outgoing = await session.msg_queue.get()
                if isinstance(outgoing, dict):
                    await session.ws.send_json(outgoing)
                else:
                    await session.ws.send_str(str(outgoing))
        except asyncio.CancelledError:
            pass
        except ConnectionError as e:
            # The WebSocket was closed underneath us. A mid-session close here
            # is a symptom, not just teardown noise, so log it.
            logger.warning("OpenAI Realtime send loop stopped — connection closed: %s", e)
        except Exception as e:
            # e.g. an event that is not JSON-serialisable or another client error.
            logger.error("OpenAI Realtime send loop crashed: %s", e, exc_info=True)
            self.emit("error", f"WebSocket send error: {str(e)}")
        finally:
            await self._cleanup_session(session)

    async def _receive_loop(self, session: OpenAISession) -> None:
        """Read server events from the WebSocket and dispatch them.

        Runs until the model starts closing, the server closes the socket, or
        a socket error arrives. Close and error conditions are logged and
        emitted on the ``error`` channel. A frame that is not a valid JSON
        object is logged and skipped rather than tearing down the whole
        session. The session is always cleaned up on exit.
        """
        try:
            while not self._closing:
                msg = await session.ws.receive()

                if msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    # Server (or aiohttp's heartbeat) closed the socket. Log the
                    # close code so the cause is visible — 1000 clean, 1006
                    # abnormal/heartbeat, 1011 server error, 4xxx app-specific.
                    logger.error(
                        "OpenAI Realtime WebSocket closed by server "
                        "(msg_type=%s close_code=%s reason=%s)",
                        msg.type.name, session.ws.close_code, msg.extra,
                    )
                    self.emit(
                        "error",
                        f"OpenAI Realtime WebSocket closed: "
                        f"code={session.ws.close_code} reason={msg.extra}",
                    )
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error("OpenAI Realtime WebSocket error: %s", msg.data)
                    self.emit("error", f"WebSocket error: {msg.data}")
                    break
                elif msg.type == aiohttp.WSMsgType.TEXT:
                    try:
                        event = json.loads(msg.data)
                    except json.JSONDecodeError as e:
                        logger.warning("OpenAI Realtime: skipping undecodable frame: %s", e)
                        continue
                    if not isinstance(event, dict):
                        logger.warning("OpenAI Realtime: skipping non-object frame: %r", event)
                        continue
                    await self._handle_message(event)
        except Exception as e:
            logger.error("OpenAI Realtime receive loop crashed: %s", e, exc_info=True)
            self.emit("error", f"WebSocket receive error: {str(e)}")
        finally:
            await self._cleanup_session(session)

    async def _handle_message(self, data: dict) -> None:
        """Route one decoded server event to its handler.

        Both the Beta and the GA names are accepted for the text, audio and
        transcript events: ``response.text.*`` / ``response.output_text.*``,
        ``response.audio.*`` / ``response.output_audio.*``, and
        ``response.audio_transcript.delta`` /
        ``response.output_audio_transcript.delta``. Unknown event types are
        ignored. Handler exceptions are emitted on the ``error`` channel rather
        than propagated into the receive loop.
        """
        # Bound before the try so the error report can always reference it,
        # even when ``data`` itself is malformed.
        event_type = None
        try:
            event_type = data.get("type")

            if event_type == "input_audio_buffer.speech_started":
                await self._handle_speech_started(data)

            elif event_type == "input_audio_buffer.speech_stopped":
                await self._handle_speech_stopped(data)

            elif event_type == "response.created":
                await self._handle_response_created(data)

            elif event_type == "response.output_item.added":
                await self._handle_output_item_added(data)

            elif event_type == "response.content_part.added":
                await self._handle_content_part_added(data)

            elif event_type in ("response.text.delta", "response.output_text.delta"):
                await self._handle_text_delta(data)

            elif event_type in ("response.audio.delta", "response.output_audio.delta"):
                await self._handle_audio_delta(data)

            elif event_type in (
                "response.audio_transcript.delta",
                "response.output_audio_transcript.delta",
            ):
                await self._handle_audio_transcript_delta(data)

            elif event_type == "response.done":
                await self._handle_response_done(data)

            elif event_type == "error":
                await self._handle_error(data)

            elif event_type == "response.function_call_arguments.delta":
                await self._handle_function_call_arguments_delta(data)

            elif event_type == "response.function_call_arguments.done":
                await self._handle_function_call_arguments_done(data)

            elif event_type == "response.output_item.done":
                await self._handle_output_item_done(data)

            elif event_type == "conversation.item.input_audio_transcription.completed":
                await self._handle_input_audio_transcription_completed(data)

            elif event_type in ("response.text.done", "response.output_text.done"):
                # GA renamed the text "done" event like its delta; accept both.
                await self._handle_text_done(data)

        except Exception as e:
            self.emit("error", f"Error handling event {event_type}: {str(e)}")

    async def _handle_speech_started(self, data: dict) -> None:
        """React to the server detecting the start of user speech.

        In audio mode this emits ``user_speech_started`` and barges in on the
        agent (``interrupt()`` plus flushing the audio track). If the current
        utterance is non-interruptible it returns early, which also skips the
        metrics calls at the end.
        """
        audio_mode = "audio" in self.config.modalities
        if audio_mode:
            self.emit("user_speech_started", {"type": "done"})
            logger.info("Interrupting on speech start.>>")
            utterance = self.current_utterance
            if utterance and not utterance.is_interruptible:
                logger.info("Interruption is disabled for the current utterance. Not interrupting on speech start.")
                return
            await self.interrupt()
            if self.audio_track:
                self.audio_track.interrupt()

        metrics_collector.on_user_speech_start()
        metrics_collector.start_turn()

    async def _handle_speech_stopped(self, data: dict) -> None:
        """Record the end of user speech in metrics, then emit ``user_speech_ended``."""
        metrics_collector.on_user_speech_end()
        self.emit("user_speech_ended", {})

    async def _handle_response_created(self, data: dict) -> None:
        """Remember the id of the response the server just started.

        ``interrupt`` only sends ``response.cancel`` while this id is set. A
        missing or null ``response`` object yields ``None`` instead of raising
        (which would otherwise leave a stale id from an earlier response).
        """
        self._active_response_id = (data.get("response") or {}).get("id")

    async def _handle_output_item_added(self, data: dict) -> None:
        """Ignore ``response.output_item.added``.

        Output items are acted on once they finish, in the
        ``response.output_item.done`` handler.
        """
        return None

    async def _handle_output_item_done(self, data: dict) -> None:
        """Execute a completed function call and feed its result back to the model.

        The handler acts on a finished ``function_call`` item whose name matches
        one of the agent's tools:

        - The tool is awaited with the item's decoded JSON arguments; an empty
          or missing argument string means no arguments.
        - ``realtime_model_function_executed`` is emitted.
        - A ``function_call_output`` item is sent, followed by
          ``response.create``, so the model continues the turn.

        String results are passed through verbatim and other results are
        JSON-encoded. The model receives the same text that listeners see.

        If decoding the arguments or running the tool fails, the error is
        emitted (as before) and also returned to the model as the call's
        output. The model therefore does not stall on an unanswered call.
        Other items are ignored.
        """
        try:
            item = data.get("item", {})
            if item.get("type") != "function_call" or item.get("status") != "completed":
                return

            name = item.get("name")
            if not name or not self._tools:
                return

            tool = next(
                (candidate for candidate in self._tools if get_tool_info(candidate).name == name),
                None,
            )
            if tool is None:
                return

            call_id = item.get("call_id")
            raw_arguments = item.get("arguments", "{}")
            try:
                arguments = json.loads(raw_arguments or "{}")
                metrics_collector.add_function_tool_call(tool_name=name)
                result = await tool(**arguments)
                output = result if isinstance(result, str) else json.dumps(result)
                is_error = False
            except Exception as e:
                output = str(e)
                is_error = True

            self.emit(
                "realtime_model_function_executed",
                {
                    "name": name,
                    "arguments": raw_arguments,
                    "call_id": call_id,
                    "output": output,
                    "is_error": is_error,
                },
            )
            if is_error:
                self.emit("error", f"Error executing function {name}: {output}")
                model_output = json.dumps({"error": output})
            else:
                model_output = output

            await self.send_event(
                {
                    "type": "conversation.item.create",
                    "item": {
                        "type": "function_call_output",
                        "call_id": call_id,
                        "output": model_output,
                    },
                }
            )
            await self.send_event(
                {
                    "type": "response.create",
                    "event_id": str(uuid.uuid4()),
                    "response": {
                        "instructions": self._instructions,
                        "metadata": {"client_event_id": str(uuid.uuid4())},
                    },
                }
            )
        except Exception as e:
            self.emit("error", f"Error handling output item done: {e}")

    async def _handle_content_part_added(self, data: dict) -> None:
        """Ignore ``response.content_part.added``.

        Content itself arrives through the text/audio/transcript delta events.
        """
        return None

    async def _handle_text_delta(self, data: dict) -> None:
        """Accumulate a streamed text chunk (text-only mode) and re-emit it.

        The first non-empty chunk while the agent is not speaking marks the
        start of agent speech. Every call emits ``realtime_model_text_delta``
        carrying the new chunk and the text accumulated so far.
        """
        chunk = data.get("delta", "")

        if not hasattr(self, "_current_text_response"):
            self._current_text_response = ""

        if chunk and not self._agent_speaking:
            metrics_collector.on_agent_speech_start()
            self._agent_speaking = True
            self.emit("agent_speech_started", {})

        self._current_text_response += chunk
        self.emit(
            "realtime_model_text_delta",
            {"role": "assistant", "delta": chunk, "text": self._current_text_response},
        )

    async def _handle_audio_delta(self, data: dict) -> None:
        """Decode a base64 audio chunk and queue it on the output audio track.

        The handler does nothing unless audio is an enabled modality. The
        first chunk while the agent is not speaking marks the start of agent
        speech.

        The write to the track runs as a background task so the receive loop
        is not blocked. A strong reference to each task is held until it
        finishes, because the event loop only keeps weak references to tasks
        and an unreferenced task may be garbage-collected mid-execution.
        """
        if "audio" not in self.config.modalities:
            return

        try:
            if not self._agent_speaking:
                metrics_collector.on_agent_speech_start()
                self._agent_speaking = True
                self.emit("agent_speech_started", {})
            pcm = base64.b64decode(data.get("delta"))
            if pcm and self.audio_track and self.loop:
                pending = getattr(self, "_audio_write_tasks", None)
                if pending is None:
                    pending = self._audio_write_tasks = set()
                task = asyncio.create_task(self.audio_track.add_new_bytes(pcm))
                pending.add(task)
                task.add_done_callback(pending.discard)
        except Exception as e:
            self.emit("error", f"Error handling audio delta: {e}")
            logger.exception("Error handling audio delta")

    async def interrupt(self) -> None:
        """Barge in on the agent: cancel the active response and flush audio.

        ``response.cancel`` is only sent while a response is in flight,
        because GA rejects a stray cancel with "no active response found".
        Nothing happens at all when the current utterance is
        non-interruptible. Otherwise the audio track is flushed, and, if the
        agent was speaking, synthesis is marked complete.
        """
        if self._session and not self._closing:
            utterance = self.current_utterance
            if utterance and not utterance.is_interruptible:
                logger.info("Interruption is disabled for the current utterance. Not interrupting OpenAI realtime session.")
                return
            if self._active_response_id:
                await self.send_event({"type": "response.cancel", "event_id": str(uuid.uuid4())})
                self._active_response_id = None
                metrics_collector.on_interrupted()

        if self.audio_track:
            self.audio_track.interrupt()
        if not self._agent_speaking:
            return
        if self.audio_track:
            self.audio_track.mark_synthesis_complete()
        self._agent_speaking = False

    async def _handle_audio_transcript_delta(self, data: dict) -> None:
        """Append one chunk of the agent's spoken-audio transcript to the running buffer."""
        chunk = data.get("delta", "")
        if not hasattr(self, "_current_audio_transcript"):
            self._current_audio_transcript = ""
        self._current_audio_transcript += chunk

    async def _handle_input_audio_transcription_completed(self, data: dict) -> None:
        """Publish the final transcript of the user's speech.

        Empty transcripts are ignored. Exceptions raised while emitting the
        ``realtime_model_transcription`` event are suppressed.
        """
        transcript = data.get("transcript", "")
        if not transcript:
            return

        metrics_collector.set_user_transcript(transcript)
        payload = {"role": "user", "text": transcript, "is_final": True}
        try:
            self.emit("realtime_model_transcription", payload)
        except Exception:
            pass

    async def _handle_response_done(self, data: dict) -> None:
        """Finish a response: record usage, publish the agent transcript, reset state.

        If an audio transcript was accumulated, it is:

        - stored in metrics;
        - emitted as ``llm_text_output``;
        - broadcast as ``text_response``;
        - emitted as a final agent ``realtime_model_transcription``.

        The transcript buffer is then cleared. Afterwards the active response
        id is cleared, the audio track (if any) is told synthesis is complete,
        and the speaking flag is reset.
        """
        usage_metadata = self.get_realtime_tokens(data)
        metrics_collector.set_realtime_usage(usage_metadata)

        transcript = getattr(self, "_current_audio_transcript", "")
        if transcript:
            metrics_collector.set_agent_response(transcript)
            self.emit("llm_text_output", {"text": transcript})
            global_event_emitter.emit(
                "text_response",
                {"text": transcript, "type": "done"},
            )
            try:
                self.emit(
                    "realtime_model_transcription",
                    {"role": "agent", "text": transcript, "is_final": True},
                )
            except Exception as e:
                logger.debug("realtime_model_transcription listener failed: %s", e)
            self._current_audio_transcript = ""

        self._active_response_id = None
        # audio_track may be absent (e.g. text-only sessions never reframe an
        # audio track); guard it like the other handlers do so the speaking
        # flag below is always reset.
        if self.audio_track:
            self.audio_track.mark_synthesis_complete()
        # self.emit("agent_speech_ended", {})
        # metrics_collector.on_agent_speech_end()
        # metrics_collector.schedule_turn_complete(timeout=1.0)
        self._agent_speaking = False

    async def _handle_function_call_arguments_delta(self, data: dict) -> None:
        """Ignore streamed function-call argument chunks.

        Function calls are executed from the completed output item instead.
        """
        return None

    async def _handle_function_call_arguments_done(self, data: dict) -> None:
        """Ignore the function-call-arguments-done event.

        The full call, arguments included, is handled when the output item
        completes.
        """
        return None

    async def _handle_error(self, data: dict) -> None:
        """Log and re-emit an ``error`` event sent by the OpenAI Realtime API.

        The error text prefers the payload's ``message``, then its ``code``,
        then the whole payload. A non-dict payload is stringified. When there
        is no ``error`` key, the entire event is used as the payload.
        """
        error = data.get("error", data)
        if not isinstance(error, dict):
            message = str(error)
        else:
            message = error.get("message") or error.get("code") or str(error)
        report = f"OpenAI Realtime API error: {message}"
        logger.error(report)
        self.emit("error", report)

    async def _cleanup_session(self, session: OpenAISession) -> None:
        """Tear down ``session``: stop its I/O loops and close its WebSocket.

        Runs once; it is a no-op after ``_closing`` has been set. It is
        normally called from the ``finally`` of the send or receive loop
        itself, so the calling task is skipped when cancelling. Cancelling and
        then awaiting the current task would only throw a CancelledError back
        into this cleanup.
        """
        if self._closing:
            return

        logger.info(
            "OpenAI Realtime session teardown — closing send/receive loops "
            "(ws_closed=%s)", session.ws.closed,
        )
        self._closing = True

        current = asyncio.current_task()
        for task in session.tasks:
            if task is current or task.done():
                continue
            task.cancel()
            try:
                # Bounded wait so a loop stuck in I/O cannot stall teardown.
                await asyncio.wait_for(task, timeout=1.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass

        if not session.ws.closed:
            try:
                await session.ws.close()
            except Exception as e:
                # Best effort: the socket is being abandoned either way.
                logger.debug("Error closing OpenAI Realtime WebSocket: %s", e)

    async def send_event(self, event: Dict[str, Any]) -> None:
        """Queue ``event`` for the send loop.

        The event is silently dropped when there is no session or the model
        is closing.
        """
        if not self._session or self._closing:
            return
        await self._session.msg_queue.put(event)

    async def aclose(self) -> None:
        """Release every resource held by the model; repeated calls are no-ops.

        ``_cleanup_session`` checks ``_closing`` itself, so the WebSocket
        session is torn down BEFORE ``_closing`` is set here. Otherwise that
        teardown would be skipped, leaving the send/receive loops and the
        socket running.

        The HTTP session and the base-class resources are released even when
        the connection had already dropped. In that case the loops have set
        ``_closing`` on their own, and the early return on ``_closing`` used
        to leak them.
        """
        if getattr(self, "_aclosed", False):
            return
        self._aclosed = True

        if self._session:
            await self._cleanup_session(self._session)
        self._closing = True

        if self._http_session and not self._http_session.closed:
            await self._http_session.close()

        await super().aclose()

    async def send_first_session_update(self) -> None:
        """Send the initial session.update using the GA Realtime API schema.

        The GA ``session`` object differs from the retired Beta schema:

        - ``modalities`` becomes ``output_modalities``.
        - The flat ``voice`` / ``input_audio_format`` / ``output_audio_format``
          / ``turn_detection`` / ``input_audio_transcription`` fields move
          under ``audio.input`` / ``audio.output``.
        - Audio formats are objects (``{"type": "audio/pcm", "rate": N}``).
        - ``session.type`` is ``"realtime"``; the model is set via the connect
          URL, not the session object.

        The ``audio`` section is only included when audio is an enabled
        modality. Does nothing without an active session.
        """
        if not self._session:
            return

        wants_audio = "audio" in self.config.modalities

        session: Dict[str, Any] = {
            "type": "realtime",
            "instructions": self.instructions_with_context(self._instructions),
            "output_modalities": ["audio" if wants_audio else "text"],
            "tools": self._formatted_tools or [],
            "tool_choice": self.config.tool_choice,
        }

        if wants_audio:
            dump_options = {"by_alias": True, "exclude_unset": True, "exclude_defaults": True}
            audio_input: Dict[str, Any] = {
                "format": {"type": "audio/pcm", "rate": self.target_sample_rate},
            }
            if self.config.turn_detection:
                audio_input["turn_detection"] = self.config.turn_detection.model_dump(
                    **dump_options
                )
            if self.config.input_audio_transcription:
                audio_input["transcription"] = (
                    self.config.input_audio_transcription.model_dump(**dump_options)
                )
            session["audio"] = {
                "input": audio_input,
                "output": {
                    "format": {"type": "audio/pcm", "rate": 24000},
                    "voice": self.config.voice,
                },
            }

        await self.send_event({"type": "session.update", "session": session})


    def process_base_url(self, url: str, model: str) -> str:
        """Turn an HTTP(S)/WS(S) base URL into the Realtime WebSocket endpoint.

        - ``http``/``https`` schemes become ``ws``/``wss``;
        - an empty, ``/v1`` or ``/openai`` path (trailing slashes ignored)
          gets ``/realtime`` appended; any other path is kept verbatim;
        - ``model=<model>`` is added to the query unless a ``model`` is
          already present.

        The params and fragment components are dropped.
        """
        ws_url = url.replace("http", "ws", 1) if url.startswith("http") else url
        parts = urlparse(ws_url)
        params = parse_qs(parts.query)

        trimmed = parts.path.rstrip("/")
        if trimmed in ("", "/v1", "/openai"):
            endpoint_path = trimmed + "/realtime"
        else:
            endpoint_path = parts.path

        params.setdefault("model", [model])

        return urlunparse(
            (parts.scheme, parts.netloc, endpoint_path, "", urlencode(params, doseq=True), "")
        )

    def _format_tools_for_session(
        self, tools: List[FunctionTool]
    ) -> List[Dict[str, Any]]:
        """Build the OpenAI tool schemas sent in ``session.update``.

        Entries that are not function tools are skipped. A tool whose schema
        cannot be built is reported through an ``"error"`` event and left
        out, so one bad tool does not prevent the rest from being registered.
        """
        schemas = []
        for candidate in tools:
            if not is_function_tool(candidate):
                continue
            try:
                schemas.append(build_openai_schema(candidate))
            except Exception as e:
                self.emit("error", f"Failed to format tool {candidate}: {e}")
        return schemas

    async def send_text_message(self, message: str) -> None:
        """Add *message* to the conversation as user text, then request a reply.

        Raises:
            RuntimeError: if no WebSocket session is active. An ``"error"``
                event is emitted before raising.
        """
        if not self._session:
            problem = "No active WebSocket session"
            self.emit("error", problem)
            raise RuntimeError(problem)

        user_item = {
            "type": "message",
            "role": "user",
            "content": [{"type": "input_text", "text": message}],
        }
        await self.send_event({"type": "conversation.item.create", "item": user_item})
        await self.create_response()


    def bytes_to_base64_url(self, image_bytes: bytes, fmt: str = "jpeg") -> str:
        """Encode *image_bytes* as a ``data:image/<fmt>;base64,...`` URL.

        ``fmt`` is case-insensitive. The common ``"jpg"`` alias is mapped to
        the registered MIME subtype ``"jpeg"``, because ``image/jpg`` is not
        a valid media type.
        """
        subtype = fmt.lower()
        if subtype == "jpg":
            subtype = "jpeg"
        encoded = base64.b64encode(image_bytes).decode("ascii")
        return f"data:image/{subtype};base64,{encoded}"

    def get_realtime_tokens(self, event: dict) -> dict:
        """
        Flatten the token usage of an OpenAI Realtime ``response.done`` event
        into a single-level dictionary used for pricing.

        Parameters:
            event (dict): Full Realtime event payload.

        Returns:
            dict: Single-level dictionary with token counts. Missing sections,
            and sections or counts explicitly sent as ``null``, are reported
            as 0 instead of raising.
        """

        def _section(parent: dict, key: str) -> dict:
            # ``parent.get(key, {})`` yields None when the key is present with
            # a null value, which would crash the following ``.get`` call.
            return parent.get(key) or {}

        def _count(parent: dict, key: str) -> int:
            return parent.get(key) or 0

        usage = _section(_section(event, "response"), "usage")
        input_details = _section(usage, "input_token_details")
        cached_details = _section(input_details, "cached_tokens_details")
        output_details = _section(usage, "output_token_details")

        return {
            "total_tokens": _count(usage, "total_tokens"),
            "input_tokens": _count(usage, "input_tokens"),
            "output_tokens": _count(usage, "output_tokens"),

            "input_text_tokens": _count(input_details, "text_tokens"),
            "input_audio_tokens": _count(input_details, "audio_tokens"),
            "input_image_tokens": _count(input_details, "image_tokens"),
            "input_cached_tokens": _count(input_details, "cached_tokens"),

            "cached_text_tokens": _count(cached_details, "text_tokens"),
            "cached_audio_tokens": _count(cached_details, "audio_tokens"),
            "cached_image_tokens": _count(cached_details, "image_tokens"),

            "output_text_tokens": _count(output_details, "text_tokens"),
            "output_audio_tokens": _count(output_details, "audio_tokens"),
            "output_image_tokens": _count(output_details, "image_tokens"),
        }

OpenAI's realtime model implementation.

Initialize OpenAI realtime model.

Args

api_key
OpenAI API key. If not provided, will attempt to read from OPENAI_API_KEY env var
model
The OpenAI Realtime model identifier to use (e.g. 'gpt-realtime'); it is sent as the model query parameter of the WebSocket connect URL
config
Optional configuration object for customizing model behavior. Contains settings for: - voice: Voice ID to use for audio output - temperature: Sampling temperature for responses - turn_detection: Settings for detecting user speech turns - input_audio_transcription: Settings for audio transcription - tool_choice: How tools should be selected ('auto', 'required' or 'none') - modalities: List of enabled modalities ('text', 'audio')
base_url
Base URL for OpenAI API. Defaults to 'https://api.openai.com/v1'

Raises

ValueError
If no API key is provided and none found in environment variables

Ancestors

  • videosdk.agents.realtime_base_model.RealtimeBaseModel
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic
  • abc.ABC

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Release every resource held by the model.

    This tears down the realtime session, closes the HTTP client session
    and runs the base-class cleanup. Safe to call more than once; only the
    first call does any work.

    ``_cleanup_session`` returns immediately when ``self._closing`` is set,
    and it sets that flag itself. It therefore has to run *before*
    ``_closing`` is set here, otherwise the send/receive tasks and the
    WebSocket are left open. Idempotency is tracked with a separate
    ``_aclosed`` flag so that the HTTP-session and base-class cleanup still
    run when ``_cleanup_session`` was already invoked earlier.
    """
    if getattr(self, "_aclosed", False):
        return
    self._aclosed = True

    if self._session:
        # No-op if this session was already torn down; sets _closing.
        await self._cleanup_session(self._session)
    # Make send_event() stop queueing even when there was no session.
    self._closing = True

    if self._http_session and not self._http_session.closed:
        await self._http_session.close()

    await super().aclose()

Clean up all resources.

def bytes_to_base64_url(self, image_bytes: bytes, fmt: str = 'jpeg') ‑> str
Expand source code
def bytes_to_base64_url(self, image_bytes: bytes, fmt: str = "jpeg") -> str:
    """Encode *image_bytes* as a ``data:image/<fmt>;base64,...`` URL.

    ``fmt`` is case-insensitive. The common ``"jpg"`` alias is mapped to the
    registered MIME subtype ``"jpeg"``, because ``image/jpg`` is not a valid
    media type.
    """
    subtype = fmt.lower()
    if subtype == "jpg":
        subtype = "jpeg"
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return f"data:image/{subtype};base64,{encoded}"
async def connect(self) ‑> None
Expand source code
async def connect(self) -> None:
    """Open the Realtime WebSocket and send the initial ``session.update``.

    Any connection failure is logged and emitted on the ``"error"`` channel,
    then re-raised to the caller.
    """
    # GA Realtime API — do NOT send "OpenAI-Beta: realtime=v1". That header
    # opts into the retired Beta API, which the server now rejects with
    # "The Realtime Beta API is no longer supported."
    headers = {
        "Agent": "VideoSDK Agents",
        "Authorization": f"Bearer {self.api_key}",
    }
    url = self.process_base_url(self.base_url, self.model)

    if "audio" in self.config.modalities:
        self.reframe_audio_track(self.target_sample_rate)

    try:
        self._session = await self._create_session(url, headers)
        await self._handle_websocket(self._session)
        await self.send_first_session_update()
    except aiohttp.WSServerHandshakeError as e:
        # Bad/expired API key, wrong URL, or rejected model fail here —
        # before the WebSocket opens, so the receive loop (and
        # _handle_error) never run. Surface it on the error channel.
        auth_hint = (
            " — verify OPENAI_API_KEY is set and valid."
            if e.status in (401, 403)
            else ""
        )
        message = (
            f"OpenAI Realtime connection rejected (HTTP {e.status}): {e.message}"
            f"{auth_hint}"
        )
        logger.error(message)
        self.emit("error", message)
        raise
    except Exception as e:
        message = f"OpenAI Realtime connection failed: {e}"
        logger.error(message)
        self.emit("error", message)
        raise
async def create_response(self) ‑> None
Expand source code
async def create_response(self) -> None:
    """Ask the OpenAI Realtime API to generate a response.

    Per the Realtime API reference, response-level ``instructions`` replace
    the session instructions for that response. They are therefore built
    with ``instructions_with_context``, the same helper used for the session
    instructions in ``send_first_session_update``. Otherwise whatever that
    helper adds would be missing from manually triggered responses.

    Raises:
        RuntimeError: if no WebSocket session is active. An ``"error"``
            event is emitted before raising.
    """
    if not self._session:
        self.emit("error", "No active WebSocket session")
        raise RuntimeError("No active WebSocket session")

    response_event = {
        "type": "response.create",
        "event_id": str(uuid.uuid4()),
        "response": {
            "instructions": self.instructions_with_context(self._instructions),
            "metadata": {"client_event_id": str(uuid.uuid4())},
        },
    }

    await self.send_event(response_event)

Create a response to the OpenAI realtime API

def get_realtime_tokens(self, event: dict) ‑> dict
Expand source code
def get_realtime_tokens(self, event: dict) -> dict:
    """
    Flatten the token usage of an OpenAI Realtime ``response.done`` event
    into a single-level dictionary used for pricing.

    Parameters:
        event (dict): Full Realtime event payload.

    Returns:
        dict: Single-level dictionary with token counts. Missing sections,
        and sections or counts explicitly sent as ``null``, are reported as
        0 instead of raising.
    """

    def _section(parent: dict, key: str) -> dict:
        # ``parent.get(key, {})`` yields None when the key is present with a
        # null value, which would crash the following ``.get`` call.
        return parent.get(key) or {}

    def _count(parent: dict, key: str) -> int:
        return parent.get(key) or 0

    usage = _section(_section(event, "response"), "usage")
    input_details = _section(usage, "input_token_details")
    cached_details = _section(input_details, "cached_tokens_details")
    output_details = _section(usage, "output_token_details")

    return {
        "total_tokens": _count(usage, "total_tokens"),
        "input_tokens": _count(usage, "input_tokens"),
        "output_tokens": _count(usage, "output_tokens"),

        "input_text_tokens": _count(input_details, "text_tokens"),
        "input_audio_tokens": _count(input_details, "audio_tokens"),
        "input_image_tokens": _count(input_details, "image_tokens"),
        "input_cached_tokens": _count(input_details, "cached_tokens"),

        "cached_text_tokens": _count(cached_details, "text_tokens"),
        "cached_audio_tokens": _count(cached_details, "audio_tokens"),
        "cached_image_tokens": _count(cached_details, "image_tokens"),

        "output_text_tokens": _count(output_details, "text_tokens"),
        "output_audio_tokens": _count(output_details, "audio_tokens"),
        "output_image_tokens": _count(output_details, "image_tokens"),
    }

Extract and flatten all token details needed for pricing from an OpenAI Realtime response.done event into a single-level dictionary.

Parameters

event (dict): Full Realtime event payload

Returns

dict
Single-level dictionary with token counts
async def handle_audio_input(self, audio_data: bytes) ‑> None
Expand source code
async def handle_audio_input(self, audio_data: bytes) -> None:
    """Forward a chunk of user audio to the Realtime input buffer.

    The chunk is mixed down to mono and resampled from ``input_sample_rate``
    to ``target_sample_rate``. It is then sent as base64-encoded int16 PCM
    in an ``input_audio_buffer.append`` event.

    Nothing is sent when:
    - there is no session, or the model is closing;
    - audio is not an enabled modality;
    - the current utterance is non-interruptible;
    - the chunk is too short to yield at least one output sample.
    """
    if not self._session or self._closing or "audio" not in self.config.modalities:
        return
    if self.current_utterance and not self.current_utterance.is_interruptible:
        logger.info("Interruption is disabled for the current utterance. Not processing audio input.")
        return

    # WebRTC source (aiortc) delivers 48 kHz s16 stereo-interleaved
    # frames flattened to bytes — _input_stream's frame.to_ndarray()[0]
    # is one row of L,R,L,R samples, NOT mono. Mix channels to mono
    # BEFORE resampling: without this, the buffer is twice the true
    # mono length, and once we declare GA's required rate=24000 the
    # server reads it at half real-time speed → both the transcription
    # model and the realtime LLM hear slowed-down speech and hallucinate
    # random-language tokens. (Gemini Live papers over this by declaring
    # rate=48000; GA OpenAI cannot — audio/pcm is fixed at 24 kHz.)
    # NOTE(review): every even-length buffer is treated as interleaved
    # stereo, so a genuinely mono source with an even sample count would be
    # halved — confirm the input track is always stereo.
    raw = np.frombuffer(audio_data, dtype=np.int16)
    if raw.size >= 2 and raw.size % 2 == 0:
        mono = raw.reshape(-1, 2).astype(np.float32).mean(axis=1)
    else:
        mono = raw.astype(np.float32)

    target_len = int(len(mono) * self.target_sample_rate / self.input_sample_rate)
    if target_len <= 0:
        # signal.resample cannot FFT an empty input or produce a zero-length
        # output (it raises), and such a chunk carries no audio anyway.
        return

    resampled = signal.resample(mono, target_len)
    pcm_bytes = np.clip(resampled, -32767, 32767).astype(np.int16).tobytes()
    audio_event = {
        "type": "input_audio_buffer.append",
        "audio": base64.b64encode(pcm_bytes).decode("utf-8"),
    }
    await self.send_event(audio_event)

Handle incoming audio data from the user

async def handle_video_input(self, video_data: av.VideoFrame) ‑> None
Expand source code
async def handle_video_input(self, video_data: av.VideoFrame) -> None:
    """Send a single video frame to the conversation as a JPEG ``input_image``.

    The frame is ignored when:
    - there is no session, or the model is closing;
    - the frame is empty or has no planes;
    - JPEG encoding yields fewer than 100 bytes.

    Any exception is reported on the ``"error"`` channel instead of
    propagating.
    """
    if self._closing or not self._session:
        return

    try:
        if not video_data or not video_data.planes:
            return

        jpeg = encode_image(video_data, DEFAULT_IMAGE_ENCODE_OPTIONS)
        if not jpeg or len(jpeg) < 100:
            logger.warning("Invalid JPEG data generated")
            return

        image_part = {"type": "input_image", "image_url": self.bytes_to_base64_url(jpeg)}
        await self.send_event(
            {
                "type": "conversation.item.create",
                "item": {"type": "message", "role": "user", "content": [image_part]},
            }
        )
    except Exception as e:
        self.emit("error", f"Video processing error: {str(e)}")
async def interrupt(self) ‑> None
Expand source code
async def interrupt(self) -> None:
    """Cancel the in-flight response and flush the agent's queued audio.

    If a session is active and the current utterance is non-interruptible,
    the call returns without doing anything. Otherwise:
    - ``response.cancel`` is sent, but only while a response is in flight;
    - the local audio track is interrupted;
    - if the agent was speaking, its synthesis is marked complete.
    """
    session_live = self._session and not self._closing
    if session_live:
        utterance = self.current_utterance
        if utterance and not utterance.is_interruptible:
            logger.info("Interruption is disabled for the current utterance. Not interrupting OpenAI realtime session.")
            return
        # GA rejects a stray response.cancel with "no active response found",
        # so only cancel when a response is actually in flight.
        if self._active_response_id:
            await self.send_event({"type": "response.cancel", "event_id": str(uuid.uuid4())})
            self._active_response_id = None
            metrics_collector.on_interrupted()

    if self.audio_track:
        self.audio_track.interrupt()
    if not self._agent_speaking:
        return
    if self.audio_track:
        self.audio_track.mark_synthesis_complete()
    self._agent_speaking = False

Interrupt the current response and flush audio

def process_base_url(self, url: str, model: str) ‑> str
Expand source code
def process_base_url(self, url: str, model: str) -> str:
    """Turn an HTTP(S)/WS(S) base URL into the Realtime WebSocket endpoint.

    - ``http``/``https`` schemes become ``ws``/``wss``;
    - an empty, ``/v1`` or ``/openai`` path (trailing slashes ignored) gets
      ``/realtime`` appended; any other path is kept verbatim;
    - ``model=<model>`` is added to the query unless a ``model`` is already
      present.

    The params and fragment components are dropped.
    """
    ws_url = url.replace("http", "ws", 1) if url.startswith("http") else url
    parts = urlparse(ws_url)
    params = parse_qs(parts.query)

    trimmed = parts.path.rstrip("/")
    if trimmed in ("", "/v1", "/openai"):
        endpoint_path = trimmed + "/realtime"
    else:
        endpoint_path = parts.path

    params.setdefault("model", [model])

    return urlunparse(
        (parts.scheme, parts.netloc, endpoint_path, "", urlencode(params, doseq=True), "")
    )
async def send_event(self, event: Dict[str, Any]) ‑> None
Expand source code
async def send_event(self, event: Dict[str, Any]) -> None:
    """Queue *event* for delivery over the WebSocket.

    The event is dropped without error when there is no active session or the
    model is shutting down.
    """
    if self._closing or not self._session:
        return
    await self._session.msg_queue.put(event)

Send an event to the WebSocket

async def send_first_session_update(self) ‑> None
Expand source code
async def send_first_session_update(self) -> None:
    """Push the opening ``session.update`` using the GA Realtime schema.

    Compared with the retired Beta schema, this payload reflects that:
    - ``modalities`` is now ``output_modalities``;
    - voice, audio formats, turn detection and input transcription are
      nested under ``audio.input`` / ``audio.output`` instead of being
      top-level fields;
    - an audio format is an object such as ``{"type": "audio/pcm", "rate": N}``;
    - ``session.type`` is ``"realtime"`` and the model is chosen by the
      connect URL, not by this object.

    Does nothing when there is no active session.
    """
    if not self._session:
        return

    cfg = self.config
    wants_audio = "audio" in cfg.modalities
    dump_opts = dict(by_alias=True, exclude_unset=True, exclude_defaults=True)

    payload: Dict[str, Any] = dict(
        type="realtime",
        instructions=self.instructions_with_context(self._instructions),
        output_modalities=["audio" if wants_audio else "text"],
        tools=self._formatted_tools or [],
        tool_choice=cfg.tool_choice,
    )

    if wants_audio:
        input_cfg: Dict[str, Any] = {
            "format": {"type": "audio/pcm", "rate": self.target_sample_rate},
        }
        if cfg.turn_detection:
            input_cfg["turn_detection"] = cfg.turn_detection.model_dump(**dump_opts)
        if cfg.input_audio_transcription:
            input_cfg["transcription"] = cfg.input_audio_transcription.model_dump(
                **dump_opts
            )
        payload["audio"] = {
            "input": input_cfg,
            "output": {
                "format": {"type": "audio/pcm", "rate": 24000},
                "voice": cfg.voice,
            },
        }

    await self.send_event({"type": "session.update", "session": payload})

Send the initial session.update using the GA Realtime API schema.

The GA session object differs from the retired Beta schema: - modalities → output_modalities - flat voice / input_audio_format / output_audio_format / turn_detection / input_audio_transcription are nested under audio.input / audio.output - audio formats are objects ({"type": "audio/pcm", "rate": N}) - session.type is "realtime"; the model is set via the connect URL, not the session object.

async def send_message(self, message: str) ‑> None
Expand source code
async def send_message(self, message: str) -> None:
    """Inject *message* as an assistant item and request a response that repeats it.

    The instruction text is separated from *message* by spaces. Without the
    separators the pieces are fused, e.g.
    ``"...back to them:helloDO NOT ADD ANYTHING ELSE"``.
    """
    prompt = (
        "Repeat the user's exact message back to them: "
        + message
        + " DO NOT ADD ANYTHING ELSE"
    )
    await self.send_event(
        {
            "type": "conversation.item.create",
            "item": {
                "type": "message",
                "role": "assistant",
                "content": [
                    {
                        # GA: assistant/output message content is "output_text"
                        # (the Beta API's "text" is rejected).
                        "type": "output_text",
                        "text": prompt,
                    }
                ],
            },
        }
    )
    await self.create_response()

Send a message to the OpenAI realtime API

async def send_message_with_frames(self, message: Optional[str], frames: list[av.VideoFrame]) ‑> None
Expand source code
async def send_message_with_frames(
    self, message: Optional[str], frames: list[av.VideoFrame]
) -> None:
    """Send one user item combining optional text and JPEG frames, then request a reply.

    A frame is skipped, with a log line, when it fails to encode or encodes
    to fewer than 100 bytes. If no usable content remains, a warning is
    logged and nothing is sent.
    """
    parts = [{"type": "input_text", "text": message}] if message else []

    for frame in frames:
        try:
            jpeg = encode_image(frame, DEFAULT_IMAGE_ENCODE_OPTIONS)
            if not jpeg or len(jpeg) < 100:
                logger.warning("Invalid JPEG data generated")
                continue
            parts.append({"type": "input_image", "image_url": self.bytes_to_base64_url(jpeg)})
        except Exception as e:
            logger.error(f"Error processing frame: {e}")

    usable = any(part.get("type") in ("input_image", "input_text") for part in parts)
    if not usable:
        logger.warning("No content to send.")
        return

    await self.send_event(
        {
            "type": "conversation.item.create",
            "item": {"type": "message", "role": "user", "content": parts},
        }
    )
    await self.create_response()
async def send_text_message(self, message: str) ‑> None
Expand source code
async def send_text_message(self, message: str) -> None:
    """Add *message* to the conversation as user text, then request a reply.

    Raises:
        RuntimeError: if no WebSocket session is active. An ``"error"``
            event is emitted before raising.
    """
    if not self._session:
        problem = "No active WebSocket session"
        self.emit("error", problem)
        raise RuntimeError(problem)

    user_item = {
        "type": "message",
        "role": "user",
        "content": [{"type": "input_text", "text": message}],
    }
    await self.send_event({"type": "conversation.item.create", "item": user_item})
    await self.create_response()

Send a text message to the OpenAI realtime API

def set_agent(self, agent: Agent) ‑> None
Expand source code
def set_agent(self, agent: Agent) -> None:
    """Bind *agent*: copy its instructions and tools, and pre-format the tools.

    The formatted tool list is stored under both ``tools_formatted`` and
    ``_formatted_tools``; ``send_first_session_update`` reads the latter.
    """
    self._agent = agent
    self._instructions = agent.instructions
    self._tools = agent.tools
    formatted = self._format_tools_for_session(self._tools)
    self.tools_formatted = formatted
    self._formatted_tools = formatted
class OpenAIRealtimeConfig (voice: str = 'alloy',
temperature: float = 0.8,
turn_detection: TurnDetection | None = <factory>,
input_audio_transcription: InputAudioTranscription | None = <factory>,
tool_choice: ToolChoice | None = 'auto',
modalities: list[str] = <factory>)
Expand source code
@dataclass
class OpenAIRealtimeConfig:
    """Configuration for the OpenAI realtime API

    Args:
        voice: Voice ID for audio output. Default is 'alloy'
        temperature: Controls randomness in response generation. Higher values (e.g. 0.8) make output more random,
                    lower values make it more deterministic. Default is 0.8
        turn_detection: Configuration for detecting user speech turns. Contains settings for:
                       - type: Detection type ('server_vad')
                       - threshold: Voice activity detection threshold (0.0-1.0)
                       - prefix_padding_ms: Padding before speech start (ms)
                       - silence_duration_ms: Silence duration to mark end (ms)
                       - create_response: Whether to generate response on turn
                       - interrupt_response: Whether to allow interruption
        input_audio_transcription: Configuration for audio transcription. Contains:
                                 - model: Model to use for transcription
        tool_choice: How tools should be selected ('auto', 'required' or 'none'). Default is 'auto'
        modalities: List of enabled response types ["text", "audio"]. Default includes both.
                    Audio output is enabled only when "audio" is present
                    (see ``is_text_only_mode``).
    """

    voice: str = DEFAULT_VOICE
    # NOTE(review): the GA session.update built by send_first_session_update
    # does not include temperature — confirm whether it is used elsewhere.
    temperature: float = DEFAULT_TEMPERATURE
    # NOTE(review): these factories return the same module-level
    # DEFAULT_TURN_DETECTION / DEFAULT_INPUT_AUDIO_TRANSCRIPTION objects for
    # every config. If those objects are mutable, editing one config's value
    # in place also changes the default seen by every later config; consider
    # returning a copy.
    turn_detection: TurnDetection | None = field(
        default_factory=lambda: DEFAULT_TURN_DETECTION
    )
    input_audio_transcription: InputAudioTranscription | None = field(
        default_factory=lambda: DEFAULT_INPUT_AUDIO_TRANSCRIPTION
    )
    tool_choice: ToolChoice | None = DEFAULT_TOOL_CHOICE
    # The factory builds a fresh list per instance, so configs never share it.
    modalities: list[str] = field(default_factory=lambda: ["text", "audio"])

    @property
    def is_text_only_mode(self) -> bool:
        """Check if configured for text-only responses (no audio)"""
        return "audio" not in self.modalities

Configuration for the OpenAI realtime API

Args

voice
Voice ID for audio output. Default is 'alloy'
temperature
Controls randomness in response generation. Higher values (e.g. 0.8) make output more random, lower values make it more deterministic. Default is 0.8
turn_detection
Configuration for detecting user speech turns. Contains settings for: - type: Detection type ('server_vad') - threshold: Voice activity detection threshold (0.0-1.0) - prefix_padding_ms: Padding before speech start (ms) - silence_duration_ms: Silence duration to mark end (ms) - create_response: Whether to generate response on turn - interrupt_response: Whether to allow interruption
input_audio_transcription
Configuration for audio transcription. Contains: - model: Model to use for transcription
tool_choice
How tools should be selected ('auto', 'required' or 'none'). Default is 'auto'
modalities
List of enabled response types ["text", "audio"]. Default includes both

Instance variables

var input_audio_transcription : openai.types.beta.realtime.session.InputAudioTranscription | None
prop is_text_only_mode : bool
Expand source code
@property
def is_text_only_mode(self) -> bool:
    """True when ``"audio"`` is absent from ``modalities`` (text-only responses)."""
    has_audio = "audio" in self.modalities
    return not has_audio

Check if configured for text-only responses (no audio)

var modalities : list[str]
var temperature : float
var tool_choice : Literal['auto', 'required', 'none'] | None
var turn_detection : openai.types.beta.realtime.session.TurnDetection | None
var voice : str
class OpenAISTT (*,
api_key: str | None = None,
model: str = 'gpt-4o-mini-transcribe',
base_url: str | None = None,
prompt: str | None = None,
language: str = 'en',
turn_detection: dict | None = None,
enable_streaming: bool = True,
silence_threshold: float = 0.01,
silence_duration: float = 0.8)
Expand source code
class OpenAISTT(BaseSTT):
    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = "gpt-4o-mini-transcribe",
        base_url: str | None = None,
        prompt: str | None = None,
        language: str = "en",
        turn_detection: dict | None = None,
        enable_streaming: bool = True,
        silence_threshold: float = 0.01,
        silence_duration: float = 0.8,
    ) -> None:
        """Initialize the OpenAI STT plugin.

        Args:
            api_key (Optional[str], optional): OpenAI API key. Falls back to the
                OPENAI_API_KEY environment variable. Defaults to None.
            model (str): The transcription model to use. Defaults to "gpt-4o-mini-transcribe".
            base_url (Optional[str], optional): The base URL for the OpenAI API. Defaults to None.
            prompt (Optional[str], optional): Optional prompt passed to the transcription model.
                Defaults to None.
            language (str): The language to use for transcription. Defaults to "en".
            turn_detection (dict | None): Turn-detection config sent in the realtime
                ``transcription_session.update`` message. When None, a ``server_vad``
                config is used.
            enable_streaming (bool): When True, stream audio over the realtime WebSocket
                API. When False, buffer audio locally and transcribe it over HTTP once the
                built-in amplitude VAD detects the end of speech. Defaults to True.
            silence_threshold (float): Non-streaming mode only. Peak amplitude, as a
                fraction of int16 full scale, below which a chunk counts as silence.
                Defaults to 0.01.
            silence_duration (float): Non-streaming mode only. Approximate seconds of
                silence after speech before the buffered audio is transcribed.
                Defaults to 0.8.

        Raises:
            ValueError: If no API key is given and OPENAI_API_KEY is unset.
        """
        super().__init__()
        
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("OpenAI API key must be provided either through api_key parameter or OPENAI_API_KEY environment variable")
        
        self.model = model
        self.language = language
        self.prompt = prompt
        self.turn_detection = turn_detection or {
            "type": "server_vad",
            "threshold": 0.5,
            "prefix_padding_ms": 300,
            "silence_duration_ms": 500,
        }
        self.enable_streaming = enable_streaming
        
        # Custom VAD parameters for non-streaming mode.
        # NOTE: despite the name, this is an int16 amplitude threshold, not a byte count.
        self.silence_threshold_bytes = int(silence_threshold * 32767)
        self.silence_duration_frames = int(silence_duration * 48000)  # input_sample_rate
        
        self.client = openai.AsyncClient(
            max_retries=0,
            api_key=self.api_key,
            base_url=base_url or None,
            http_client=httpx.AsyncClient(
                timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
                follow_redirects=True,
                limits=httpx.Limits(
                    max_connections=50,
                    max_keepalive_connections=50,
                    keepalive_expiry=120,
                ),
            ),
        )
        
        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._ws_task: Optional[asyncio.Task] = None
        self._current_text = ""
        self._last_interim_at = 0
        self.input_sample_rate = 48000
        self.target_sample_rate = 16000
        self._audio_buffer = bytearray()
        
        # Custom VAD state for non-streaming mode
        self._is_speaking = False
        self._silence_frames = 0
        
    @staticmethod
    def azure(
        *,
        model: str = "gpt-4o-mini-transcribe",
        language: str = "en",
        prompt: str | None = None,
        turn_detection: dict | None = None,
        azure_endpoint: str | None = None,
        azure_deployment: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        azure_ad_token: str | None = None,
        organization: str | None = None,
        project: str | None = None,
        base_url: str | None = None,
        enable_streaming: bool = False,
        timeout: httpx.Timeout | None = None,
    ) -> "OpenAISTT":
        """
        Create a new instance of Azure OpenAI STT.

        This automatically infers the following arguments from their corresponding environment variables if they are not provided:
        - `api_key` from `AZURE_OPENAI_API_KEY`
        - `organization` from `OPENAI_ORG_ID`
        - `project` from `OPENAI_PROJECT_ID`
        - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
        - `api_version` from `OPENAI_API_VERSION`
        - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
        - `azure_deployment` from `AZURE_OPENAI_DEPLOYMENT` (if not provided, uses `model` as deployment name)

        Raises:
            ValueError: If no endpoint is available, or neither an API key nor an
                Azure AD token is available.
        """
        
        # Get values from environment variables if not provided
        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        azure_deployment = azure_deployment or os.getenv("AZURE_OPENAI_DEPLOYMENT")
        api_version = api_version or os.getenv("OPENAI_API_VERSION")
        api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
        azure_ad_token = azure_ad_token or os.getenv("AZURE_OPENAI_AD_TOKEN")
        organization = organization or os.getenv("OPENAI_ORG_ID")
        project = project or os.getenv("OPENAI_PROJECT_ID")
        
        # If azure_deployment is not provided, use model as the deployment name
        if not azure_deployment:
            azure_deployment = model
        
        if not azure_endpoint:
            raise ValueError("Azure endpoint must be provided either through azure_endpoint parameter or AZURE_OPENAI_ENDPOINT environment variable")
        
        if not api_key and not azure_ad_token:
            raise ValueError("Either API key or Azure AD token must be provided")
        
        azure_client = openai.AsyncAzureOpenAI(
            max_retries=0,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            api_version=api_version,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            organization=organization,
            project=project,
            base_url=base_url,
            timeout=timeout
            if timeout
            else httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
        )
        
        instance = OpenAISTT(
            # Forward the Azure credential. The constructor raises ValueError when it
            # receives no key and OPENAI_API_KEY is unset. Without this, azure() only
            # worked when an unrelated OpenAI key happened to be in the environment.
            api_key=api_key or azure_ad_token,
            model=model,
            language=language,
            prompt=prompt,
            turn_detection=turn_detection,
            enable_streaming=enable_streaming,
        )
        # NOTE(review): the default client built by __init__ is replaced here without
        # being closed; its httpx client is only reclaimed by garbage collection.
        instance.client = azure_client
        return instance
        
    async def process_audio(
        self,
        audio_frames: bytes,
        language: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Process audio frames and send them to OpenAI based on the enabled mode.

        In non-streaming mode the frames go to the local VAD/buffering path. In
        streaming mode they are resampled from ``input_sample_rate`` to
        ``target_sample_rate``, base64-encoded, and appended to the realtime input
        buffer over the WebSocket, connecting first if needed.

        Args:
            audio_frames: Raw int16 PCM bytes.
            language: Accepted for interface compatibility; ``self.language`` is what
                is sent to OpenAI.
            **kwargs: Ignored.
        """
        
        if not self.enable_streaming:
            await self._transcribe_non_streaming(audio_frames)
            return
        
        # Reconnect when there is no socket or the server already closed it.
        # Sending on a closed socket would fail and drop this frame.
        if not self._ws or self._ws.closed:
            if self._ws_task:
                # Stop the listener bound to the previous socket before starting a new one.
                self._ws_task.cancel()
                self._ws_task = None
            await self._connect_ws()
            self._ws_task = asyncio.create_task(self._listen_for_responses())
            
        try:
            samples = np.frombuffer(audio_frames, dtype=np.int16)
            target_len = int(len(samples) * self.target_sample_rate / self.input_sample_rate)
            resampled = signal.resample(samples, target_len)
            # FFT resampling can overshoot the int16 range. Clip so astype() cannot wrap around.
            pcm_bytes = np.clip(resampled, -32768, 32767).astype(np.int16).tobytes()
            message = {
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(pcm_bytes).decode("utf-8"),
            }
            await self._ws.send_json(message)
        except Exception as e:
            print(f"Error in process_audio: {str(e)}")
            self.emit("error", str(e))
            if self._ws:
                await self._ws.close()
                self._ws = None
                if self._ws_task:
                    self._ws_task.cancel()
                    self._ws_task = None

    async def flush(self) -> None:
        """Force OpenAI to finalize the current audio buffer immediately.

        Non-streaming mode transcribes any buffered audio over HTTP. Streaming mode
        sends ``input_audio_buffer.commit`` on an open WebSocket. Send errors are
        printed, not raised.
        """
        if not self.enable_streaming:
            if self._audio_buffer:
                await self._process_audio_buffer()
            return

        if not self._ws or self._ws.closed:
            return
        try:
            await self._ws.send_json({"type": "input_audio_buffer.commit"})
        except Exception as e:
            print(f"Error flushing OpenAI STT: {str(e)}")

    async def _transcribe_non_streaming(self, audio_frames: bytes) -> None:
        """HTTP-based transcription using OpenAI audio/transcriptions API with custom VAD.

        Speech audio is buffered until enough consecutive silence follows it, then the
        buffer is sent for transcription. While idle, only the most recent silent chunk
        is kept as a short lead-in, so long silences cannot grow the buffer without bound.
        """
        if not audio_frames:
            return
        
        is_silent_chunk = self._is_silent(audio_frames)
        
        if not is_silent_chunk:
            if not self._is_speaking:
                self._is_speaking = True
                global_event_emitter.emit("speech_started")
            self._silence_frames = 0
            self._audio_buffer.extend(audio_frames)
        elif self._is_speaking:
            self._audio_buffer.extend(audio_frames)
            # Approximate frame count: assumes 4 bytes per frame
            # (e.g. 16-bit stereo) — TODO confirm the input channel layout.
            self._silence_frames += len(audio_frames) // 4
            if self._silence_frames > self.silence_duration_frames:
                global_event_emitter.emit("speech_stopped")
                await self._process_audio_buffer()
                self._is_speaking = False
                self._silence_frames = 0
        else:
            # Idle and silent: keep just this chunk as a lead-in for the next utterance.
            self._audio_buffer[:] = audio_frames

    def _is_silent(self, audio_chunk: bytes) -> bool:
        """Simple VAD: return True if the chunk's peak amplitude is below the threshold.

        Samples are widened to int32 before ``abs`` because abs(-32768) does not fit
        in int16 and would wrap back to a negative value. An empty chunk counts as silent.
        """
        audio_data = np.frombuffer(audio_chunk, dtype=np.int16)
        if audio_data.size == 0:
            return True
        return np.max(np.abs(audio_data.astype(np.int32))) < self.silence_threshold_bytes

    async def _process_audio_buffer(self) -> None:
        """Transcribe the accumulated audio buffer over HTTP and emit a FINAL response.

        The buffer is cleared before the request, so a failed request drops that audio.
        Errors are printed and emitted as an "error" event instead of being raised.
        """
        if not self._audio_buffer:
            return
            
        audio_data = bytes(self._audio_buffer)
        self._audio_buffer.clear()
        
        wav_bytes = self._audio_frames_to_wav_bytes(audio_data)
        
        try:
            resp = await self.client.audio.transcriptions.create(
                file=("audio.wav", wav_bytes, "audio/wav"),
                model=self.model,
                # Omit the field entirely rather than sending an empty/None language.
                language=self.language or openai.NOT_GIVEN,
                prompt=self.prompt or openai.NOT_GIVEN,
            )
            text = getattr(resp, "text", "")
            if text and self._transcript_callback:
                await self._transcript_callback(STTResponse(
                    event_type=SpeechEventType.FINAL,
                    data=SpeechData(text=text, language=self.language),
                    metadata={"model": self.model}
                ))
        except Exception as e:
            print(f"OpenAI transcription error: {str(e)}")
            self.emit("error", str(e))

    def _audio_frames_to_wav_bytes(self, audio_frames: bytes) -> bytes:
        """Resample int16 PCM to ``target_sample_rate`` and wrap it as a mono 16-bit WAV."""
        pcm = np.frombuffer(audio_frames, dtype=np.int16)
        resampled = signal.resample(pcm, int(len(pcm) * self.target_sample_rate / self.input_sample_rate))
        # Clip FFT-resampling overshoot before narrowing to int16.
        resampled = np.clip(resampled, -32768, 32767).astype(np.int16)
        
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(1)  # Mono
            wf.setsampwidth(2)  # 16-bit PCM
            wf.setframerate(self.target_sample_rate)
            wf.writeframes(resampled.tobytes())
        
        return buf.getvalue()

    async def _listen_for_responses(self) -> None:
        """Background task that reads one WebSocket and forwards STT responses.

        The task binds to the socket that was current when it started. On exit it
        closes that socket and clears ``self._ws`` only if ``self._ws`` still points
        at it, so a reconnect that already installed a new socket is left untouched.
        """
        ws = self._ws
        if not ws:
            return
            
        try:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = msg.json()
                    responses = self._handle_ws_message(data)
                    for response in responses:
                        if self._transcript_callback:
                            await self._transcript_callback(response)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    error = f"WebSocket error: {ws.exception()}"
                    print(error)
                    self.emit("error", error)
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    print("WebSocket connection closed")
                    break
        except Exception as e:
            error = f"Error in WebSocket listener: {str(e)}"
            print(error)
            self.emit("error", error)
        finally:
            await ws.close()
            if self._ws is ws:
                self._ws = None
                
    async def _connect_ws(self) -> None:
        """Establish a WebSocket connection with OpenAI's Realtime API and configure it.

        Expects ``transcription_session.created`` as the first message, then sends the
        session config and expects ``transcription_session.updated``. Any other reply
        raises. On failure the socket is closed and the exception is re-raised.
        """
        
        if not self._session:
            self._session = aiohttp.ClientSession()
            
        config = {
            "type": "transcription_session.update",
            "session": {
                "input_audio_format": "pcm16",
                "input_audio_transcription": {
                    "model": self.model,
                    "prompt": self.prompt or "",
                    "language": self.language if self.language else None,
                },
                "turn_detection": self.turn_detection,
                "input_audio_noise_reduction": {
                    "type": "near_field"
                },
                "include": ["item.input_audio_transcription.logprobs"]
            }
        }
        
        query_params = {
            "intent": "transcription",
        }
        headers = {
            "User-Agent": "VideoSDK",
            "Authorization": f"Bearer {self.api_key}",
            "OpenAI-Beta": "realtime=v1",
        }
        
        base_url = str(self.client.base_url).rstrip('/')
        ws_url = f"{base_url}/realtime?{urlencode(query_params)}"
        # http -> ws and https -> wss (only the scheme prefix is replaced).
        if ws_url.startswith("http"):
            ws_url = ws_url.replace("http", "ws", 1)

        try:
            self._ws = await self._session.ws_connect(ws_url, headers=headers)
            
            initial_response = await self._ws.receive_json()
            
            if initial_response.get("type") != "transcription_session.created":
                raise Exception(f"Expected session creation, got: {initial_response}")
            
            await self._ws.send_json(config)
            
            update_response = await self._ws.receive_json()
            
            if update_response.get("type") != "transcription_session.updated":
                raise Exception(f"Configuration update failed: {update_response}")
            
        except Exception as e:
            print(f"Error connecting to WebSocket: {str(e)}")
            if self._ws:
                await self._ws.close()
                self._ws = None
            raise
        
    def _handle_ws_message(self, msg: dict) -> list[STTResponse]:
        """Translate one realtime WebSocket event into zero or more STT responses.

        - transcription ``delta``: accumulates text and emits an INTERIM response at
          most every 0.5 s.
        - transcription ``completed``: emits a FINAL response with logprob-based
          confidence and resets the interim accumulator.
        - ``speech_started`` / ``speech_stopped``: re-emitted on the global emitter.

        Errors are printed and an empty/partial list is returned.
        """
        responses = []
        
        try:
            msg_type = msg.get("type")
            if msg_type == "conversation.item.input_audio_transcription.delta":
                delta = msg.get("delta", "")
                if delta:
                    self._current_text += delta
                    current_time = asyncio.get_event_loop().time()
                    
                    if current_time - self._last_interim_at > 0.5:
                        responses.append(STTResponse(
                            event_type=SpeechEventType.INTERIM,
                            data=SpeechData(
                                text=self._current_text,
                                language=self.language,
                            ),
                            metadata={"model": self.model}
                        ))
                        self._last_interim_at = current_time
                        
            elif msg_type == "conversation.item.input_audio_transcription.completed":
                transcript = msg.get("transcript", "")
                metrics_data = self.extract_tokens_and_avg_probability(msg)
                # The item is finished whether or not it produced text. Reset so its
                # deltas never leak into the next utterance's interim text.
                self._current_text = ""
                if transcript:
                    responses.append(STTResponse(
                        event_type=SpeechEventType.FINAL,
                        data=SpeechData(
                            text=transcript,
                            language=self.language,
                            confidence=metrics_data.get("confidence", 1.0)
                        ),
                        metadata={"model": self.model, "metrics": metrics_data}
                    ))
            
            elif msg_type == "input_audio_buffer.speech_started":
                global_event_emitter.emit("speech_started")
            
            elif msg_type == "input_audio_buffer.speech_stopped":
                global_event_emitter.emit("speech_stopped")
                
        except Exception as e:
            print(f"Error handling WebSocket message: {str(e)}")
        
        return responses

    def extract_tokens_and_avg_probability(self, event: Dict) -> Dict[str, Any]:
        """
        Extracts:
        - Input tokens
        - Output tokens
        - Total tokens
        - Average token probability from logprobs

        Missing or null ``usage`` / ``logprobs`` fields are treated as empty, so the
        counts fall back to 0 and the confidence to 0.0.

        Args:
            event (dict): OpenAI transcription completed payload

        Returns:
            dict with keys "input_tokens", "output_tokens", "total_tokens", and
            "confidence" (mean of exp(logprob) over tokens that carry one).
        """
        # `or` also covers explicit JSON nulls, which .get(key, default) would pass through.
        usage = event.get("usage") or {}

        input_tokens = usage.get("input_tokens", 0)
        output_tokens = usage.get("output_tokens", 0)
        total_tokens = usage.get("total_tokens", 0)

        logprobs = event.get("logprobs") or []

        total_prob = 0.0
        count = 0

        for token_info in logprobs:
            lp = token_info.get("logprob")
            if lp is None:
                continue

            total_prob += math.exp(lp)
            count += 1

        avg_probability = total_prob / count if count > 0 else 0.0

        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": total_tokens,
            "confidence": avg_probability,
        }


    async def aclose(self) -> None:
        """Cleanup resources: buffer, listener task, WebSocket, session, and client."""
        self._audio_buffer.clear()
        
        if self._ws_task:
            self._ws_task.cancel()
            try:
                await self._ws_task
            except asyncio.CancelledError:
                pass
            self._ws_task = None
            
        if self._ws:
            await self._ws.close()
            self._ws = None
            
        if self._session:
            await self._session.close()
            self._session = None
            
        await self.client.close()
        await super().aclose()

    async def _ensure_ws_connection(self):
        """Ensure WebSocket is connected, reconnect if necessary"""
        if not self._ws or self._ws.closed:
            await self._connect_ws()

Base class for Speech-to-Text implementations

Initialize the OpenAI STT plugin.

Args

api_key : Optional[str], optional
OpenAI API key. Defaults to None.
model : str
The model to use for the STT plugin. Defaults to "gpt-4o-mini-transcribe".
base_url : Optional[str], optional
The base URL for the OpenAI API. Defaults to None.
prompt : Optional[str], optional
The prompt for the STT plugin. Defaults to None.
language : str
The language to use for the STT plugin. Defaults to "en".
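turn_detection : dict | None
The turn detection for the STT plugin. Defaults to None, in which case a server_vad configuration is used.
enable_streaming : bool
When True (default), stream audio over the realtime WebSocket API. When False, buffer audio and transcribe it over HTTP after local silence detection.
silence_threshold : float
Non-streaming mode only: peak amplitude, as a fraction of int16 full scale, below which a chunk counts as silence. Defaults to 0.01.
silence_duration : float
Non-streaming mode only: approximate seconds of silence after speech before the buffered audio is transcribed. Defaults to 0.8.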

Ancestors

  • videosdk.agents.stt.stt.STT
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Static methods

def azure(*,
model: str = 'gpt-4o-mini-transcribe',
language: str = 'en',
prompt: str | None = None,
turn_detection: dict | None = None,
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | None = None,
enable_streaming: bool = False,
timeout: httpx.Timeout | None = None) ‑> OpenAISTT
Expand source code
@staticmethod
def azure(
    *,
    model: str = "gpt-4o-mini-transcribe",
    language: str = "en",
    prompt: str | None = None,
    turn_detection: dict | None = None,
    azure_endpoint: str | None = None,
    azure_deployment: str | None = None,
    api_version: str | None = None,
    api_key: str | None = None,
    azure_ad_token: str | None = None,
    organization: str | None = None,
    project: str | None = None,
    base_url: str | None = None,
    enable_streaming: bool = False,
    timeout: httpx.Timeout | None = None,
) -> "OpenAISTT":
    """
    Create a new instance of Azure OpenAI STT.

    This automatically infers the following arguments from their corresponding environment variables if they are not provided:
    - `api_key` from `AZURE_OPENAI_API_KEY`
    - `organization` from `OPENAI_ORG_ID`
    - `project` from `OPENAI_PROJECT_ID`
    - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
    - `api_version` from `OPENAI_API_VERSION`
    - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
    - `azure_deployment` from `AZURE_OPENAI_DEPLOYMENT` (if not provided, uses `model` as deployment name)

    Raises:
        ValueError: If no endpoint is available, or neither an API key nor an
            Azure AD token is available.
    """
    
    # Get values from environment variables if not provided
    azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
    azure_deployment = azure_deployment or os.getenv("AZURE_OPENAI_DEPLOYMENT")
    api_version = api_version or os.getenv("OPENAI_API_VERSION")
    api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
    azure_ad_token = azure_ad_token or os.getenv("AZURE_OPENAI_AD_TOKEN")
    organization = organization or os.getenv("OPENAI_ORG_ID")
    project = project or os.getenv("OPENAI_PROJECT_ID")
    
    # If azure_deployment is not provided, use model as the deployment name
    if not azure_deployment:
        azure_deployment = model
    
    if not azure_endpoint:
        raise ValueError("Azure endpoint must be provided either through azure_endpoint parameter or AZURE_OPENAI_ENDPOINT environment variable")
    
    if not api_key and not azure_ad_token:
        raise ValueError("Either API key or Azure AD token must be provided")
    
    azure_client = openai.AsyncAzureOpenAI(
        max_retries=0,
        azure_endpoint=azure_endpoint,
        azure_deployment=azure_deployment,
        api_version=api_version,
        api_key=api_key,
        azure_ad_token=azure_ad_token,
        organization=organization,
        project=project,
        base_url=base_url,
        timeout=timeout
        if timeout
        else httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
    )
    
    instance = OpenAISTT(
        # Forward the Azure credential. The constructor raises ValueError when it
        # receives no key and OPENAI_API_KEY is unset.
        api_key=api_key or azure_ad_token,
        model=model,
        language=language,
        prompt=prompt,
        turn_detection=turn_detection,
        enable_streaming=enable_streaming,
    )
    # NOTE(review): the default client built by __init__ is replaced without being closed.
    instance.client = azure_client
    return instance

Create a new instance of Azure OpenAI STT.

This automatically infers the following arguments from their corresponding environment variables if they are not provided: api_key from AZURE_OPENAI_API_KEY; organization from OPENAI_ORG_ID; project from OPENAI_PROJECT_ID; azure_ad_token from AZURE_OPENAI_AD_TOKEN; api_version from OPENAI_API_VERSION; azure_endpoint from AZURE_OPENAI_ENDPOINT; azure_deployment from AZURE_OPENAI_DEPLOYMENT (if not provided, model is used as the deployment name).

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Release every resource held by the plugin.

    Order: drop buffered audio, cancel and await the listener task, close the
    WebSocket, close the aiohttp session, close the OpenAI client, then defer
    to the base class.
    """
    self._audio_buffer.clear()

    listener = self._ws_task
    if listener:
        listener.cancel()
        try:
            await listener
        except asyncio.CancelledError:
            pass
        self._ws_task = None

    # WebSocket first, then the session that owns it.
    for attr_name in ("_ws", "_session"):
        resource = getattr(self, attr_name)
        if resource:
            await resource.close()
            setattr(self, attr_name, None)

    await self.client.close()
    await super().aclose()

Cleanup resources

def extract_tokens_and_avg_probability(self, event: Dict) ‑> Dict[str, Any]
Expand source code
def extract_tokens_and_avg_probability(self, event: Dict) -> Dict[str, Any]:
    """
    Extracts:
    - Input tokens
    - Output tokens
    - Total tokens
    - Average token probability from logprobs

    Missing or null ``usage`` / ``logprobs`` fields are treated as empty, so the
    counts fall back to 0 and the confidence to 0.0.

    Args:
        event (dict): OpenAI transcription completed payload

    Returns:
        dict with keys "input_tokens", "output_tokens", "total_tokens", and
        "confidence" (mean of exp(logprob) over tokens that carry one).
    """
    # `or` also covers explicit JSON nulls, which .get(key, default) would pass through.
    usage = event.get("usage") or {}

    input_tokens = usage.get("input_tokens", 0)
    output_tokens = usage.get("output_tokens", 0)
    total_tokens = usage.get("total_tokens", 0)

    logprobs = event.get("logprobs") or []

    total_prob = 0.0
    count = 0

    for token_info in logprobs:
        lp = token_info.get("logprob")
        if lp is None:
            continue

        total_prob += math.exp(lp)
        count += 1

    avg_probability = total_prob / count if count > 0 else 0.0

    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
        "confidence": avg_probability,
    }

Extracts the input token count, output token count, total token count, and the average token probability (computed from logprobs).

Args

event : dict
OpenAI transcription completed payload

Returns

dict with token usage + avg probability

async def flush(self) ‑> None
Expand source code
async def flush(self) -> None:
    """Finalize whatever audio is pending right now.

    In streaming mode this sends an ``input_audio_buffer.commit`` event on the
    open WebSocket; a missing or closed socket makes it a no-op, and send errors
    are printed rather than raised. In non-streaming mode any locally buffered
    audio is transcribed.
    """
    if self.enable_streaming:
        ws = self._ws
        if not ws or ws.closed:
            return
        try:
            await ws.send_json({"type": "input_audio_buffer.commit"})
        except Exception as exc:
            print(f"Error flushing OpenAI STT: {str(exc)}")
        return

    if self._audio_buffer:
        await self._process_audio_buffer()

Force OpenAI to finalize the current audio buffer immediately.

async def process_audio(self, audio_frames: bytes, language: Optional[str] = None, **kwargs: Any) ‑> None
Expand source code
async def process_audio(
    self,
    audio_frames: bytes,
    language: Optional[str] = None,
    **kwargs: Any
) -> None:
    """Process audio frames and send them to OpenAI based on the enabled mode.

    In non-streaming mode the frames go to the local VAD/buffering path. In
    streaming mode they are resampled from ``input_sample_rate`` to
    ``target_sample_rate``, base64-encoded, and appended to the realtime input
    buffer over the WebSocket, connecting first if needed.

    Args:
        audio_frames: Raw int16 PCM bytes.
        language: Accepted for interface compatibility; ``self.language`` is what
            is sent to OpenAI.
        **kwargs: Ignored.
    """
    
    if not self.enable_streaming:
        await self._transcribe_non_streaming(audio_frames)
        return
    
    # Reconnect when there is no socket or the server already closed it.
    # Sending on a closed socket would fail and drop this frame.
    if not self._ws or self._ws.closed:
        if self._ws_task:
            # Stop the listener bound to the previous socket before starting a new one.
            self._ws_task.cancel()
            self._ws_task = None
        await self._connect_ws()
        self._ws_task = asyncio.create_task(self._listen_for_responses())
        
    try:
        samples = np.frombuffer(audio_frames, dtype=np.int16)
        target_len = int(len(samples) * self.target_sample_rate / self.input_sample_rate)
        resampled = signal.resample(samples, target_len)
        # FFT resampling can overshoot the int16 range. Clip so astype() cannot wrap around.
        pcm_bytes = np.clip(resampled, -32768, 32767).astype(np.int16).tobytes()
        message = {
            "type": "input_audio_buffer.append",
            "audio": base64.b64encode(pcm_bytes).decode("utf-8"),
        }
        await self._ws.send_json(message)
    except Exception as e:
        print(f"Error in process_audio: {str(e)}")
        self.emit("error", str(e))
        if self._ws:
            await self._ws.close()
            self._ws = None
            if self._ws_task:
                self._ws_task.cancel()
                self._ws_task = None

Process audio frames and send to OpenAI based on enabled mode

class OpenAITTS (*,
api_key: str | None = None,
model: str = 'gpt-4o-mini-tts',
voice: str | dict[str, str] = 'ash',
speed: float = 1.0,
instructions: str | None = None,
language: str | None = None,
base_url: str | None = None,
response_format: str = 'pcm',
chunked_synthesis: bool = False)
Expand source code
class OpenAITTS(TTS):
    """Text-to-speech plugin backed by OpenAI's (or Azure OpenAI's) speech API.

    Audio streamed back from ``audio.speech`` is cut into 20 ms frames and
    forwarded to ``self.audio_track``.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        voice: str | dict[str, str] = DEFAULT_VOICE,
        speed: float = 1.0,
        instructions: str | None = None,
        language: str | None = None,
        base_url: str | None = None,
        response_format: str = "pcm",
        chunked_synthesis: bool = False,
        client: openai.AsyncOpenAI | None = None,
    ) -> None:
        """Initialize the OpenAI TTS plugin.

        Args:
            api_key (Optional[str], optional): OpenAI API key. Falls back to the
                OPENAI_API_KEY environment variable. Defaults to None.
            model (str): The model to use for the TTS plugin. Defaults to "gpt-4o-mini-tts".
                Built-in options: "gpt-4o-mini-tts" (recommended, supports instructions),
                "tts-1" (low latency), "tts-1-hd" (higher quality).
            voice (str | dict): Built-in voice name (e.g. "marin", "cedar", "ash", "coral")
                or a custom voice reference dict {"id": "voice_xxx"}. Defaults to "ash".
                For best quality with gpt-4o-mini-tts, use "marin" or "cedar".
            speed (float): The speed to use for the TTS plugin. Defaults to 1.0.
            instructions (Optional[str], optional): Natural-language style control
                ("Speak in a cheerful tone", accent hints, etc.). Only honored by
                gpt-4o-mini-tts; ignored by tts-1 / tts-1-hd. Defaults to None.
            language (Optional[str], optional): ISO language hint (e.g. "hi", "mr", "fr").
                Useful for non-English input or with custom voices. Defaults to None.
            base_url (Optional[str], optional): Custom base URL for the OpenAI API. Defaults to None.
            response_format (str): The response format to use for the TTS plugin. Defaults to "pcm".
            chunked_synthesis (bool): When ``True``, dispatch one POST per ``FlushMarker``
                boundary received from the upstream pipeline. When ``False`` (default),
                the entire LLM stream is accumulated into a single POST — better for
                prosody continuity and request economics. Set ``True`` only for very
                long utterances (>30s) where sub-sentence TTFB matters more than
                cross-sentence prosody. Defaults to False.
            client (Optional[openai.AsyncOpenAI], optional): Pre-built client to use
                instead of creating one (e.g. the ``openai.AsyncAzureOpenAI`` built by
                :meth:`azure`). When given, no API key is required and ``base_url``
                is ignored. Defaults to None.

        Raises:
            ValueError: If no ``client`` is given and no API key is available.
        """
        super().__init__(sample_rate=OPENAI_TTS_SAMPLE_RATE, num_channels=OPENAI_TTS_CHANNELS)

        self.model = model
        self.voice = voice
        self.speed = speed
        self.instructions = instructions
        self.language = language
        self.audio_track = None
        self.loop = None
        self.response_format = response_format
        self.chunked_synthesis = chunked_synthesis
        self._first_chunk_sent = False
        self._current_synthesis_task: asyncio.Task | None = None
        self._interrupted = False
        # Strong references to scheduled ``audio_track.add_new_bytes`` tasks.
        # The event loop only keeps weak references to tasks, so an
        # unreferenced task can be garbage-collected before it finishes.
        self._audio_frame_tasks: set[asyncio.Task] = set()

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")

        if client is not None:
            # Caller-supplied client (e.g. Azure): use it as-is.
            self._client = client
            return

        if not self.api_key:
            raise ValueError(
                "OpenAI API key must be provided either through api_key parameter or OPENAI_API_KEY environment variable")

        self._client = openai.AsyncClient(
            max_retries=0,
            api_key=self.api_key,
            base_url=base_url or None,
            http_client=httpx.AsyncClient(
                timeout=httpx.Timeout(
                    connect=15.0, read=5.0, write=5.0, pool=5.0),
                follow_redirects=True,
                limits=httpx.Limits(
                    max_connections=50,
                    max_keepalive_connections=50,
                    keepalive_expiry=120,
                ),
            ),
        )

    @staticmethod
    def azure(
        *,
        model: str = DEFAULT_MODEL,
        voice: str | dict[str, str] = DEFAULT_VOICE,
        speed: float = 1.0,
        instructions: str | None = None,
        language: str | None = None,
        azure_endpoint: str | None = None,
        azure_deployment: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        azure_ad_token: str | None = None,
        organization: str | None = None,
        project: str | None = None,
        base_url: str | None = None,
        response_format: str = "pcm",
        chunked_synthesis: bool = False,
        timeout: httpx.Timeout | None = None,
    ) -> "OpenAITTS":
        """
        Create a new instance of Azure OpenAI TTS.

        This automatically infers the following arguments from their corresponding environment variables if they are not provided:
        - `api_key` from `AZURE_OPENAI_API_KEY`
        - `organization` from `OPENAI_ORG_ID`
        - `project` from `OPENAI_PROJECT_ID`
        - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
        - `api_version` from `OPENAI_API_VERSION`
        - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
        - `azure_deployment` from `AZURE_OPENAI_DEPLOYMENT` (if not provided, uses `model` as deployment name)

        Raises:
            ValueError: If no Azure endpoint is available, or if neither an API
                key nor an Azure AD token is available.
        """
        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        azure_deployment = azure_deployment or os.getenv("AZURE_OPENAI_DEPLOYMENT")
        api_version = api_version or os.getenv("OPENAI_API_VERSION")
        api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
        azure_ad_token = azure_ad_token or os.getenv("AZURE_OPENAI_AD_TOKEN")
        organization = organization or os.getenv("OPENAI_ORG_ID")
        project = project or os.getenv("OPENAI_PROJECT_ID")

        if not azure_deployment:
            azure_deployment = model

        if not azure_endpoint:
            raise ValueError("Azure endpoint must be provided either through azure_endpoint parameter or AZURE_OPENAI_ENDPOINT environment variable")

        if not api_key and not azure_ad_token:
            raise ValueError("Either API key or Azure AD token must be provided")

        azure_client = openai.AsyncAzureOpenAI(
            max_retries=0,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            api_version=api_version,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            organization=organization,
            project=project,
            base_url=base_url,
            timeout=timeout
            if timeout
            else httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
        )

        # Inject the Azure client so the constructor neither requires
        # OPENAI_API_KEY nor builds a default client that would be discarded.
        return OpenAITTS(
            api_key=api_key,
            model=model,
            voice=voice,
            speed=speed,
            instructions=instructions,
            language=language,
            response_format=response_format,
            chunked_synthesis=chunked_synthesis,
            client=azure_client,
        )

    def reset_first_audio_tracking(self) -> None:
        """Reset the first audio tracking state for next TTS task"""
        self._first_chunk_sent = False

    async def synthesize(
        self,
        text: AsyncIterator[Union[str, FlushMarker]] | str,
        voice_id: Optional[str | dict[str, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Convert text to speech using OpenAI's TTS API and stream to audio track.

        Whitespace-only stream chunks are kept, because they can be the only
        separator between two words. Segments that are entirely whitespace
        are still skipped.

        Args:
            text: Text to convert to speech, or async iterator yielding ``str``
                chunks and ``FlushMarker`` segment boundaries.
            voice_id: Optional voice override
            **kwargs: Additional provider-specific arguments

        Raises:
            RuntimeError: If ``audio_track`` or ``loop`` has not been set.
        """
        try:
            if not self.audio_track or not self.loop:
                self.emit("error", "Audio track or event loop not set")
                raise RuntimeError("Audio track or event loop not set")

            self._interrupted = False

            if isinstance(text, str):
                await self._synthesize_segment(text, voice_id, **kwargs)
                return

            if self.chunked_synthesis:
                # One request per FlushMarker-delimited segment.
                buf: list[str] = []
                async for chunk in text:
                    if self._interrupted:
                        break
                    if isinstance(chunk, FlushMarker):
                        segment = "".join(buf)
                        buf = []
                        if segment.strip():
                            await self._synthesize_segment(segment, voice_id, **kwargs)
                        continue
                    if chunk:
                        buf.append(chunk)
                tail = "".join(buf)
                if tail.strip() and not self._interrupted:
                    await self._synthesize_segment(tail, voice_id, **kwargs)
                return

            # Default: accumulate the whole stream into a single request.
            parts: list[str] = []
            async for chunk in text:
                if self._interrupted:
                    break
                if isinstance(chunk, FlushMarker):
                    continue
                if chunk:
                    parts.append(chunk)
            combined_text = "".join(parts)
            if combined_text.strip() and not self._interrupted:
                await self._synthesize_segment(combined_text, voice_id, **kwargs)

        except Exception as e:
            self.emit("error", f"TTS synthesis failed: {str(e)}")
            raise

    async def _fire_first_audio_callback(self) -> None:
        """Invoke the first-audio callback if it has not fired since the last reset."""
        if not self._first_chunk_sent and self._first_audio_callback:
            self._first_chunk_sent = True
            await self._first_audio_callback()

    def _queue_audio_frame(self, frame: bytes) -> None:
        """Schedule ``frame`` for the audio track, holding a reference to the task."""
        task = asyncio.create_task(self.audio_track.add_new_bytes(frame))
        self._audio_frame_tasks.add(task)
        task.add_done_callback(self._audio_frame_tasks.discard)

    async def _synthesize_segment(
        self,
        text: str,
        voice_id: Optional[str | dict[str, str]] = None,
        **kwargs: Any,
    ) -> None:
        """Synthesize a single text segment.

        Streams audio frames to the audio track as they arrive from OpenAI's
        chunked HTTP response. A leftover buffer is kept between iterations so
        partial bytes are not silence-padded mid-stream; padding only applies
        to the final frame at end-of-response.

        NOTE(review): fixed-size framing assumes raw 16-bit PCM
        (``response_format="pcm"``); other formats would be sliced blindly.
        """
        if not text.strip() or self._interrupted:
            return

        # 20 ms of 16-bit audio at OPENAI_TTS_SAMPLE_RATE / OPENAI_TTS_CHANNELS
        # (960 bytes for 24 kHz mono).
        frame_size = int(
            OPENAI_TTS_SAMPLE_RATE * OPENAI_TTS_CHANNELS * 2 * 20 / 1000
        )
        leftover = bytearray()

        try:
            async with self._client.audio.speech.with_streaming_response.create(
                model=self.model,
                voice=voice_id or self.voice,
                input=text,
                speed=self.speed,
                response_format=self.response_format,
                **({"instructions": self.instructions} if self.instructions else {}),
                **({"extra_body": {"language": self.language}} if self.language else {}),
            ) as response:
                async for chunk in response.iter_bytes():
                    if self._interrupted:
                        break
                    if not chunk:
                        continue
                    leftover.extend(chunk)

                    # Emit complete 20ms frames as soon as they're available.
                    while len(leftover) >= frame_size and not self._interrupted:
                        frame = bytes(leftover[:frame_size])
                        del leftover[:frame_size]
                        await self._fire_first_audio_callback()
                        self._queue_audio_frame(frame)
                        await asyncio.sleep(0.001)

            # End of stream: zero-pad the final partial frame and emit.
            if leftover and not self._interrupted:
                frame = bytes(leftover) + b"\x00" * (frame_size - len(leftover))
                await self._fire_first_audio_callback()
                self._queue_audio_frame(frame)

        except Exception as e:
            if not self._interrupted:
                self.emit("error", f"Segment synthesis failed: {str(e)}")
                raise

    async def _stream_audio_chunks(self, audio_bytes: bytes) -> None:
        """Stream audio data in 20 ms chunks for smooth playback.

        The last chunk is zero-padded to a full frame.
        """
        chunk_size = int(OPENAI_TTS_SAMPLE_RATE *
                         OPENAI_TTS_CHANNELS * 2 * 20 / 1000)

        for start in range(0, len(audio_bytes), chunk_size):
            # Slices taken from range() are never empty; only the last may be short.
            chunk = audio_bytes[start:start + chunk_size]
            chunk += b"\x00" * (chunk_size - len(chunk))
            await self._fire_first_audio_callback()
            self._queue_audio_frame(chunk)
            await asyncio.sleep(0.001)

    async def aclose(self) -> None:
        """Cleanup resources"""
        await self._client.close()
        await super().aclose()

    async def interrupt(self) -> None:
        """Interrupt TTS synthesis"""
        self._interrupted = True
        if self._current_synthesis_task:
            self._current_synthesis_task.cancel()
        if self.audio_track:
            self.audio_track.interrupt()

OpenAI implementation of the Text-to-Speech interface (see Ancestors: videosdk.agents.tts.tts.TTS).

Initialize the OpenAI TTS plugin.

Args

api_key : Optional[str], optional
OpenAI API key. Defaults to None.
model : str
The model to use for the TTS plugin. Defaults to "gpt-4o-mini-tts". Built-in options: "gpt-4o-mini-tts" (recommended, supports instructions), "tts-1" (low latency), "tts-1-hd" (higher quality).
voice : str | dict
Built-in voice name (e.g. "marin", "cedar", "ash", "coral") or a custom voice reference dict {"id": "voice_xxx"}. Defaults to "ash". For best quality with gpt-4o-mini-tts, use "marin" or "cedar".
speed : float
The speed to use for the TTS plugin. Defaults to 1.0.
instructions : Optional[str], optional
Natural-language style control ("Speak in a cheerful tone", accent hints, etc.). Only honored by gpt-4o-mini-tts; ignored by tts-1 / tts-1-hd. Defaults to None.
language : Optional[str], optional
ISO language hint (e.g. "hi", "mr", "fr"). Useful for non-English input or with custom voices. Defaults to None.
base_url : Optional[str], optional
Custom base URL for the OpenAI API. Defaults to None.
response_format : str
The response format to use for the TTS plugin. Defaults to "pcm".
chunked_synthesis : bool
When True, dispatch one POST per FlushMarker boundary received from the upstream pipeline. When False (default), the entire LLM stream is accumulated into a single POST — better for prosody continuity and request economics. Set True only for very long utterances (>30s) where sub-sentence TTFB matters more than cross-sentence prosody. Defaults to False.

Ancestors

  • videosdk.agents.tts.tts.TTS
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Static methods

def azure(*,
model: str = 'gpt-4o-mini-tts',
voice: str | dict[str, str] = 'ash',
speed: float = 1.0,
instructions: str | None = None,
language: str | None = None,
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | None = None,
response_format: str = 'pcm',
chunked_synthesis: bool = False,
timeout: httpx.Timeout | None = None) ‑> OpenAITTS
Expand source code
@staticmethod
def azure(
    *,
    model: str = DEFAULT_MODEL,
    voice: str | dict[str, str] = DEFAULT_VOICE,
    speed: float = 1.0,
    instructions: str | None = None,
    language: str | None = None,
    azure_endpoint: str | None = None,
    azure_deployment: str | None = None,
    api_version: str | None = None,
    api_key: str | None = None,
    azure_ad_token: str | None = None,
    organization: str | None = None,
    project: str | None = None,
    base_url: str | None = None,
    response_format: str = "pcm",
    chunked_synthesis: bool = False,
    timeout: httpx.Timeout | None = None,
) -> "OpenAITTS":
    """
    Create a new instance of Azure OpenAI TTS.

    This automatically infers the following arguments from their corresponding environment variables if they are not provided:
    - `api_key` from `AZURE_OPENAI_API_KEY`
    - `organization` from `OPENAI_ORG_ID`
    - `project` from `OPENAI_PROJECT_ID`
    - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
    - `api_version` from `OPENAI_API_VERSION`
    - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
    - `azure_deployment` from `AZURE_OPENAI_DEPLOYMENT` (if not provided, uses `model` as deployment name)

    Raises:
        ValueError: If no Azure endpoint is available, or if neither an API
            key nor an Azure AD token is available.
    """
    
    # Explicit arguments win; otherwise fall back to the environment.
    azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
    azure_deployment = azure_deployment or os.getenv("AZURE_OPENAI_DEPLOYMENT")
    api_version = api_version or os.getenv("OPENAI_API_VERSION")
    api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
    azure_ad_token = azure_ad_token or os.getenv("AZURE_OPENAI_AD_TOKEN")
    organization = organization or os.getenv("OPENAI_ORG_ID")
    project = project or os.getenv("OPENAI_PROJECT_ID")
    
    # Default the deployment name to the model name.
    if not azure_deployment:
        azure_deployment = model
    
    if not azure_endpoint:
        raise ValueError("Azure endpoint must be provided either through azure_endpoint parameter or AZURE_OPENAI_ENDPOINT environment variable")
    
    # Either credential is sufficient for the Azure client.
    if not api_key and not azure_ad_token:
        raise ValueError("Either API key or Azure AD token must be provided")
    
    azure_client = openai.AsyncAzureOpenAI(
        max_retries=0,  # retries disabled; failures surface immediately
        azure_endpoint=azure_endpoint,
        azure_deployment=azure_deployment,
        api_version=api_version,
        api_key=api_key,
        azure_ad_token=azure_ad_token,
        organization=organization,
        project=project,
        base_url=base_url,
        timeout=timeout
        if timeout
        else httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
    )
    
    # NOTE(review): no api_key or client is forwarded to the constructor here.
    # If the constructor requires an OpenAI key when none is passed (falling
    # back to OPENAI_API_KEY), Azure-only setups fail at this call, and the
    # default client it builds is simply discarded below — confirm.
    instance = OpenAITTS(
        model=model,
        voice=voice,
        speed=speed,
        instructions=instructions,
        language=language,
        response_format=response_format,
        chunked_synthesis=chunked_synthesis,
    )
    # Swap in the Azure client so requests are sent to the Azure endpoint.
    instance._client = azure_client
    return instance

Create a new instance of Azure OpenAI TTS.

This automatically infers the following arguments from their corresponding environment variables if they are not provided:

- api_key from AZURE_OPENAI_API_KEY
- organization from OPENAI_ORG_ID
- project from OPENAI_PROJECT_ID
- azure_ad_token from AZURE_OPENAI_AD_TOKEN
- api_version from OPENAI_API_VERSION
- azure_endpoint from AZURE_OPENAI_ENDPOINT
- azure_deployment from AZURE_OPENAI_DEPLOYMENT (if not provided, uses model as the deployment name)

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Release resources: close the OpenAI client, then run base-class teardown."""
    client = self._client
    await client.close()
    await super().aclose()

Cleanup resources

async def interrupt(self) ‑> None
Expand source code
async def interrupt(self) -> None:
    """Stop in-progress synthesis.

    Sets the interruption flag, cancels the tracked synthesis task when one
    is set, and interrupts the attached audio track when there is one.
    """
    self._interrupted = True
    pending = self._current_synthesis_task
    if pending:
        pending.cancel()
    track = self.audio_track
    if track:
        track.interrupt()

Interrupt TTS synthesis

def reset_first_audio_tracking(self) ‑> None
Expand source code
def reset_first_audio_tracking(self) -> None:
    """Re-arm first-audio detection for the next TTS task.

    Clearing the flag lets the next emitted audio frame invoke the
    first-audio callback again.
    """
    already_sent = False
    self._first_chunk_sent = already_sent

Reset the first audio tracking state for next TTS task

async def synthesize(self,
text: AsyncIterator[Union[str, FlushMarker]] | str,
voice_id: Optional[str | dict[str, str]] = None,
**kwargs: Any) ‑> None
Expand source code
async def synthesize(
    self,
    text: AsyncIterator[Union[str, FlushMarker]] | str,
    voice_id: Optional[str | dict[str, str]] = None,
    **kwargs: Any,
) -> None:
    """
    Convert text to speech using OpenAI's TTS API and stream to audio track.

    Whitespace-only stream chunks are kept, because they can be the only
    separator between two words. Segments that are entirely whitespace are
    still skipped.

    Args:
        text: Text to convert to speech, or async iterator yielding ``str``
            chunks and ``FlushMarker`` segment boundaries.
        voice_id: Optional voice override
        **kwargs: Additional provider-specific arguments

    Raises:
        RuntimeError: If ``audio_track`` or ``loop`` has not been set.
    """
    try:
        if not self.audio_track or not self.loop:
            self.emit("error", "Audio track or event loop not set")
            raise RuntimeError("Audio track or event loop not set")

        self._interrupted = False

        if isinstance(text, str):
            await self._synthesize_segment(text, voice_id, **kwargs)
            return

        if self.chunked_synthesis:
            # One request per FlushMarker-delimited segment.
            buf: list[str] = []
            async for chunk in text:
                if self._interrupted:
                    break
                if isinstance(chunk, FlushMarker):
                    segment = "".join(buf)
                    buf = []
                    if segment.strip():
                        await self._synthesize_segment(segment, voice_id, **kwargs)
                    continue
                if chunk:
                    buf.append(chunk)
            tail = "".join(buf)
            if tail.strip() and not self._interrupted:
                await self._synthesize_segment(tail, voice_id, **kwargs)
            return

        # Default: accumulate the whole stream into a single request.
        parts: list[str] = []
        async for chunk in text:
            if self._interrupted:
                break
            if isinstance(chunk, FlushMarker):
                continue
            if chunk:
                parts.append(chunk)
        combined_text = "".join(parts)
        if combined_text.strip() and not self._interrupted:
            await self._synthesize_segment(combined_text, voice_id, **kwargs)

    except Exception as e:
        self.emit("error", f"TTS synthesis failed: {str(e)}")
        raise

Convert text to speech using OpenAI's TTS API and stream to audio track.

Args

text
Text to convert to speech, or async iterator yielding str chunks and FlushMarker segment boundaries.
voice_id
Optional voice override
**kwargs
Additional provider-specific arguments