Package videosdk.plugins.xai

Sub-modules

videosdk.plugins.xai.llm
videosdk.plugins.xai.stt
videosdk.plugins.xai.tts
videosdk.plugins.xai.xai_realtime

Classes

class XAILLM (*,
api_key: str | None = None,
model: str = 'grok-4-1-fast-non-reasoning',
base_url: str = 'https://api.x.ai/v1',
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_completion_tokens: int | None = None,
tools: List[Union[FunctionTool, Dict[str, Any]]] | None = None)
Expand source code
class XAILLM(LLM):
    """
    LLM plugin for the xAI (Grok) API.

    Talks to xAI through the OpenAI-compatible chat-completions endpoint and
    supports streamed text, client-side function calling and, optionally,
    JSON-schema constrained output for a conversational graph.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = "grok-4-1-fast-non-reasoning",
        base_url: str = "https://api.x.ai/v1",
        temperature: float = 0.7,
        tool_choice: ToolChoice = "auto",
        max_completion_tokens: int | None = None,
        tools: List[Union[FunctionTool, Dict[str, Any]]] | None = None,
    ) -> None:
        """Initialize the xAI LLM plugin.

        Args:
            api_key (Optional[str], optional): xAI API key. Defaults to XAI_API_KEY env var.
            model (str): The model to use. Defaults to "grok-4-1-fast-non-reasoning".
            base_url (str): The base URL for the xAI API. Defaults to "https://api.x.ai/v1".
            temperature (float): The temperature to use. Defaults to 0.7.
            tool_choice (ToolChoice): The tool choice to use. Defaults to "auto".
            max_completion_tokens (Optional[int], optional): Sent as ``max_tokens``
                when set; omitted from the request when None.
            tools (Optional[List], optional): List of FunctionTools (or raw tool
                dicts) to be available to the LLM on every call.

        Raises:
            ValueError: If no API key is passed and XAI_API_KEY is not set.
        """
        super().__init__()
        self.api_key = api_key or os.getenv("XAI_API_KEY")
        if not self.api_key:
            raise ValueError("xAI API key must be provided either through api_key parameter or XAI_API_KEY environment variable")

        self.model = model
        self.temperature = temperature
        self.tool_choice = tool_choice
        self.max_completion_tokens = max_completion_tokens
        self.tools = tools or []
        # Set by cancel_current_generation(); chat() checks it between chunks.
        self._cancelled = False

        # SDK-level retries are disabled (max_retries=0).
        self._client = openai.AsyncOpenAI(
            api_key=self.api_key,
            base_url=base_url,
            max_retries=0,
            http_client=httpx.AsyncClient(
                timeout=httpx.Timeout(connect=15.0, read=60.0, write=5.0, pool=5.0),
                follow_redirects=True,
                limits=httpx.Limits(
                    max_connections=50,
                    max_keepalive_connections=50,
                    keepalive_expiry=120,
                ),
            ),
        )

    async def chat(
        self,
        messages: ChatContext,
        tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
        conversational_graph: Any | None = None,
        **kwargs: Any
    ) -> AsyncIterator[LLMResponse]:
        """
        Stream a chat completion from xAI via OpenAI SDK compatibility.

        Args:
            messages: Conversation context; converted with ``to_openai_messages()``.
            tools: Extra tools for this call, appended to the tools given at
                construction time.
            conversational_graph: Optional graph object. When truthy, a strict
                JSON schema from ``_get_graph_schema()`` is requested and content
                is streamed through ``stream_conversational_graph_response``.
            **kwargs: Extra request parameters; they override the defaults
                assembled here.

        Yields:
            LLMResponse: Text chunks as they arrive, then one response per
            accumulated tool call (carried in ``metadata``), and finally, in
            conversational-graph mode, the parsed JSON document as ``metadata``.

        Raises:
            Exception: Any request/stream error is re-raised after an "error"
                event is emitted (the event is skipped if generation was cancelled).
        """
        self._cancelled = False

        completion_params = {
            "model": self.model,
            "messages": messages.to_openai_messages(),
            "temperature": self.temperature,
            "stream": True,
        }
        # Only send a token limit when one is configured instead of an explicit null.
        if self.max_completion_tokens is not None:
            completion_params["max_tokens"] = self.max_completion_tokens

        if conversational_graph:
            completion_params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "conversational_graph_response",
                    "strict": True,
                    "schema": conversational_graph._get_graph_schema()
                }
            }

        combined_tools = (self.tools or []) + (tools or [])

        if combined_tools:
            formatted_tools = []
            for tool in combined_tools:
                if is_function_tool(tool):
                    try:
                        tool_schema = build_openai_schema(tool)
                    except Exception as e:
                        self.emit("error", f"Failed to format tool {tool}: {e}")
                        continue
                    if "function" not in tool_schema:
                        # Flat schema: wrap it in the chat-completions tool envelope.
                        inner_tool = {k: v for k, v in tool_schema.items() if k != "type"}
                        formatted_tools.append({
                            "type": "function",
                            "function": inner_tool
                        })
                    else:
                        formatted_tools.append(tool_schema)
                elif isinstance(tool, dict):
                    # Raw dict tools are passed through untouched.
                    formatted_tools.append(tool)

            if formatted_tools:
                completion_params["tools"] = formatted_tools
                completion_params["tool_choice"] = self.tool_choice

        completion_params.update(kwargs)

        response_stream = None
        try:
            response_stream = await self._client.chat.completions.create(**completion_params)

            current_content = ""
            # Tool-call fragments keyed by the index reported in each delta.
            current_tool_calls = {}
            streaming_state = {
                "in_response": False,
                "response_start_index": -1,
                "yielded_content_length": 0
            }

            async for chunk in response_stream:
                if self._cancelled:
                    break

                if not chunk.choices:
                    continue

                delta = chunk.choices[0].delta

                for tool_call in delta.tool_calls or []:
                    entry = current_tool_calls.setdefault(
                        tool_call.index, {"id": "", "name": "", "arguments": ""}
                    )
                    # The id may arrive on any fragment; keep the first one seen.
                    if tool_call.id and not entry["id"]:
                        entry["id"] = tool_call.id
                    function = tool_call.function
                    if function is not None:
                        if function.name:
                            entry["name"] += function.name
                        if function.arguments:
                            entry["arguments"] += function.arguments

                if delta.content is not None:
                    current_content += delta.content
                    if conversational_graph:
                        for content_chunk in conversational_graph.stream_conversational_graph_response(current_content, streaming_state):
                            yield LLMResponse(content=content_chunk, role=ChatRole.ASSISTANT)
                    else:
                        yield LLMResponse(content=delta.content, role=ChatRole.ASSISTANT)

            if current_tool_calls and not self._cancelled:
                for idx in sorted(current_tool_calls):
                    tool_data = current_tool_calls[idx]
                    # A tool without parameters may stream an empty arguments
                    # string; treat that as "no arguments" rather than a parse error.
                    args_str = tool_data["arguments"].strip() or "{}"
                    try:
                        parsed_args = json.loads(args_str)
                    except json.JSONDecodeError:
                        self.emit("error", f"Failed to parse function arguments for tool {tool_data['name']}")
                        continue

                    yield LLMResponse(
                        content="",
                        role=ChatRole.ASSISTANT,
                        metadata={
                            "function_call": {
                                "name": tool_data["name"],
                                "arguments": parsed_args
                            },
                            "tool_call_id": tool_data["id"]
                        }
                    )

            if current_content and not self._cancelled and conversational_graph:
                try:
                    parsed_json = json.loads(current_content.strip())
                except json.JSONDecodeError:
                    # Best effort: an unparsable final document is simply not emitted.
                    parsed_json = None
                else:
                    yield LLMResponse(
                        content="",
                        role=ChatRole.ASSISTANT,
                        metadata=parsed_json
                    )

        except Exception as e:
            if not self._cancelled:
                self.emit("error", e)
            raise
        finally:
            # Release the HTTP response if the stream was abandoned early
            # (cancellation, error, or the consumer stopped iterating).
            close_stream = getattr(response_stream, "close", None)
            if close_stream is not None:
                await close_stream()

    async def cancel_current_generation(self) -> None:
        """Flag the in-progress generation as cancelled; chat() stops at the next chunk."""
        self._cancelled = True

    async def aclose(self) -> None:
        """Cleanup resources: cancel generation, close the API client, then the base class."""
        await self.cancel_current_generation()
        if self._client:
            await self._client.close()
        await super().aclose()

LLM plugin for the xAI (Grok) API. Supports Grok-4 and reasoning models with standard client-side function calling.

Initialize the xAI LLM plugin.

Args

api_key : Optional[str], optional
xAI API key. Defaults to XAI_API_KEY env var.
model : str
The model to use (e.g., "grok-4", "grok-4-1-fast").
base_url : str
The base URL for the xAI API. Defaults to "https://api.x.ai/v1".
temperature : float
The temperature to use. Defaults to 0.7.
tool_choice : ToolChoice
The tool choice to use. Defaults to "auto".
max_completion_tokens : Optional[int], optional
Maximum number of tokens to generate; sent to the API as max_tokens. Defaults to None.
tools : Optional[List], optional
List of FunctionTools to be available to the LLM.

Ancestors

  • videosdk.agents.llm.llm.LLM
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Stop any in-flight generation, close the API client, then run base-class cleanup."""
    await self.cancel_current_generation()
    client = self._client
    if client:
        await client.close()
    await super().aclose()

Cleanup resources

async def cancel_current_generation(self) ‑> None
Expand source code
async def cancel_current_generation(self) -> None:
    """Mark the current generation as cancelled by setting the ``_cancelled`` flag."""
    self._cancelled = True

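Cancel the current LLM generation if active. XAILLM implements this by setting an internal cancellation flag that chat() checks between streamed chunks.

Raises

NotImplementedError
Raised only by the base-class default; XAILLM overrides this method and does not raise it.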
async def chat(self,
messages: ChatContext,
tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
conversational_graph: Any | None = None,
**kwargs: Any) ‑> AsyncIterator[videosdk.agents.llm.llm.LLMResponse]
Expand source code
async def chat(
    self,
    messages: ChatContext,
    tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
    conversational_graph: Any | None = None,
    **kwargs: Any
) -> AsyncIterator[LLMResponse]:
    """
    Stream a chat completion from xAI via OpenAI SDK compatibility.

    Args:
        messages: Conversation context; converted with ``to_openai_messages()``.
        tools: Extra tools for this call, appended to ``self.tools``.
        conversational_graph: Optional graph object. When truthy, a strict
            JSON schema from ``_get_graph_schema()`` is requested and content
            is streamed through ``stream_conversational_graph_response``.
        **kwargs: Extra request parameters; they override the defaults
            assembled here.

    Yields:
        LLMResponse: Text chunks as they arrive, then one response per
        accumulated tool call (carried in ``metadata``), and finally, in
        conversational-graph mode, the parsed JSON document as ``metadata``.

    Raises:
        Exception: Any request/stream error is re-raised after an "error"
            event is emitted (the event is skipped if generation was cancelled).
    """
    self._cancelled = False

    completion_params = {
        "model": self.model,
        "messages": messages.to_openai_messages(),
        "temperature": self.temperature,
        "stream": True,
    }
    # Only send a token limit when one is configured instead of an explicit null.
    if self.max_completion_tokens is not None:
        completion_params["max_tokens"] = self.max_completion_tokens

    if conversational_graph:
        completion_params["response_format"] = {
            "type": "json_schema",
            "json_schema": {
                "name": "conversational_graph_response",
                "strict": True,
                "schema": conversational_graph._get_graph_schema()
            }
        }

    combined_tools = (self.tools or []) + (tools or [])

    if combined_tools:
        formatted_tools = []
        for tool in combined_tools:
            if is_function_tool(tool):
                try:
                    tool_schema = build_openai_schema(tool)
                except Exception as e:
                    self.emit("error", f"Failed to format tool {tool}: {e}")
                    continue
                if "function" not in tool_schema:
                    # Flat schema: wrap it in the chat-completions tool envelope.
                    inner_tool = {k: v for k, v in tool_schema.items() if k != "type"}
                    formatted_tools.append({
                        "type": "function",
                        "function": inner_tool
                    })
                else:
                    formatted_tools.append(tool_schema)
            elif isinstance(tool, dict):
                # Raw dict tools are passed through untouched.
                formatted_tools.append(tool)

        if formatted_tools:
            completion_params["tools"] = formatted_tools
            completion_params["tool_choice"] = self.tool_choice

    completion_params.update(kwargs)

    response_stream = None
    try:
        response_stream = await self._client.chat.completions.create(**completion_params)

        current_content = ""
        # Tool-call fragments keyed by the index reported in each delta.
        current_tool_calls = {}
        streaming_state = {
            "in_response": False,
            "response_start_index": -1,
            "yielded_content_length": 0
        }

        async for chunk in response_stream:
            if self._cancelled:
                break

            if not chunk.choices:
                continue

            delta = chunk.choices[0].delta

            for tool_call in delta.tool_calls or []:
                entry = current_tool_calls.setdefault(
                    tool_call.index, {"id": "", "name": "", "arguments": ""}
                )
                # The id may arrive on any fragment; keep the first one seen.
                if tool_call.id and not entry["id"]:
                    entry["id"] = tool_call.id
                function = tool_call.function
                if function is not None:
                    if function.name:
                        entry["name"] += function.name
                    if function.arguments:
                        entry["arguments"] += function.arguments

            if delta.content is not None:
                current_content += delta.content
                if conversational_graph:
                    for content_chunk in conversational_graph.stream_conversational_graph_response(current_content, streaming_state):
                        yield LLMResponse(content=content_chunk, role=ChatRole.ASSISTANT)
                else:
                    yield LLMResponse(content=delta.content, role=ChatRole.ASSISTANT)

        if current_tool_calls and not self._cancelled:
            for idx in sorted(current_tool_calls):
                tool_data = current_tool_calls[idx]
                # A tool without parameters may stream an empty arguments
                # string; treat that as "no arguments" rather than a parse error.
                args_str = tool_data["arguments"].strip() or "{}"
                try:
                    parsed_args = json.loads(args_str)
                except json.JSONDecodeError:
                    self.emit("error", f"Failed to parse function arguments for tool {tool_data['name']}")
                    continue

                yield LLMResponse(
                    content="",
                    role=ChatRole.ASSISTANT,
                    metadata={
                        "function_call": {
                            "name": tool_data["name"],
                            "arguments": parsed_args
                        },
                        "tool_call_id": tool_data["id"]
                    }
                )

        if current_content and not self._cancelled and conversational_graph:
            try:
                parsed_json = json.loads(current_content.strip())
            except json.JSONDecodeError:
                # Best effort: an unparsable final document is simply not emitted.
                parsed_json = None
            else:
                yield LLMResponse(
                    content="",
                    role=ChatRole.ASSISTANT,
                    metadata=parsed_json
                )

    except Exception as e:
        if not self._cancelled:
            self.emit("error", e)
        raise
    finally:
        # Release the HTTP response if the stream was abandoned early
        # (cancellation, error, or the consumer stopped iterating).
        close_stream = getattr(response_stream, "close", None)
        if close_stream is not None:
            await close_stream()

Implement chat functionality using xAI's API via OpenAI SDK compatibility.

class XAIRealtime (*,
api_key: str | None = None,
model: str = 'grok-4-1-fast-non-reasoning',
config: XAIRealtimeConfig | None = None,
base_url: str | None = None)
Expand source code
class XAIRealtime(RealtimeBaseModel[XAIEventTypes]):
    """xAI's Grok realtime model implementation.

    Exchanges JSON events with xAI over a WebSocket: one task drains an
    outgoing message queue, another receives server events and dispatches
    them to the ``_handle_*`` methods.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = "grok-4-1-fast-non-reasoning",
        config: XAIRealtimeConfig | None = None,
        base_url: str | None = None,
    ) -> None:
        """Initialize the realtime model.

        Args:
            api_key: xAI API key; falls back to the ``XAI_API_KEY`` environment variable.
            model: Model name. NOTE(review): accepted but not stored or used
                anywhere in this class; confirm whether that is intended.
            config: Realtime configuration; a default ``XAIRealtimeConfig()`` is
                used when omitted.
            base_url: WebSocket endpoint; defaults to ``XAI_BASE_URL``.

        Raises:
            ValueError: If no API key is available (an "error" event is emitted first).
        """
        super().__init__()
        self.api_key = api_key or os.getenv("XAI_API_KEY")
        self.base_url = base_url or XAI_BASE_URL

        if not self.api_key:
            self.emit("error", "XAI_API_KEY is required")
            raise ValueError("XAI_API_KEY is required")

        self.config: XAIRealtimeConfig = config or XAIRealtimeConfig()

        self._http_session: Optional[aiohttp.ClientSession] = None
        self._session: Optional[XAISession] = None
        self._closing = False
        self._instructions: str = "You are a helpful assistant."
        self._tools: List[FunctionTool] = []
        self._formatted_tools: List[Dict[str, Any]] = []

        self.input_sample_rate = INPUT_SAMPLE_RATE
        self.target_sample_rate = DEFAULT_SAMPLE_RATE
        self._agent_speaking = False
        self._current_response_id: str | None = None
        self._is_configured = False
        self._session_ready = asyncio.Event()
        self._has_unprocessed_tool_outputs = False
        self._generated_text_in_current_response = False
        # Created up front so handlers need no hasattr() checks.
        self._current_transcript: str = ""
        self._audio_log_counter: int = 0
        # Strong references to fire-and-forget tasks so they cannot be
        # garbage-collected before they finish.
        self._background_tasks: set = set()

    def set_agent(self, agent: Agent) -> None:
        """Adopt the agent's instructions (if any) and tools, and pre-format the tools."""
        self._agent = agent
        if agent.instructions:
            self._instructions = agent.instructions
        self._tools = agent.tools
        self._formatted_tools = self._format_tools_for_session(self._tools)

    async def connect(self) -> None:
        """Establish WebSocket connection to xAI.

        Starts the send/receive tasks and waits up to 10 seconds for the
        session to be reported as updated; a timeout is only logged.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        self._session = await self._create_session(self.base_url, headers)
        await self._handle_websocket(self._session)

        self.reframe_audio_track(self.target_sample_rate)

        try:
            await asyncio.wait_for(self._session_ready.wait(), timeout=10.0)
            logger.info("xAI session configuration complete")
        except asyncio.TimeoutError:
            logger.warning("Timeout waiting for xAI session configuration")

    async def _create_session(self, url: str, headers: dict) -> XAISession:
        """Open the WebSocket (creating the HTTP session on first use) and wrap it in an XAISession.

        Raises:
            Exception: Connection errors are re-raised after an "error" event is emitted.
        """
        if not self._http_session:
            self._http_session = aiohttp.ClientSession()

        try:
            ws = await self._http_session.ws_connect(
                url,
                headers=headers,
                autoping=True,
                heartbeat=10,
                timeout=30,
            )
        except Exception as e:
            self.emit("error", f"Connection failed: {e}")
            raise

        msg_queue: asyncio.Queue = asyncio.Queue()
        tasks: list[asyncio.Task] = []
        self._closing = False

        return XAISession(ws=ws, msg_queue=msg_queue, tasks=tasks)

    async def _send_initial_config(self) -> None:
        """Send session.update to configure voice and audio"""
        if not self._session:
            return

        tools_config = []

        if self._formatted_tools:
            tools_config.extend(self._formatted_tools)

        if self.config.enable_web_search:
            tools_config.append({"type": "web_search"})

        # Listing allowed handles implies X search even if the flag is off.
        if self.config.enable_x_search or self.config.allowed_x_handles:
            x_search_config = {"type": "x_search"}
            if self.config.allowed_x_handles:
                logger.info(f"Allowed xAI handles: {self.config.allowed_x_handles}")
                x_search_config["allowed_x_handles"] = self.config.allowed_x_handles
            tools_config.append(x_search_config)

        if self.config.collection_id:
            tools_config.append({
                "type": "file_search",
                "vector_store_ids": [self.config.collection_id],
                "max_num_results": self.config.max_num_results,
            })

        session_update = {
            "type": "session.update",
            "session": {
                "instructions": self.instructions_with_context(self._instructions),
                "voice": self.config.voice,
                "audio": {
                    "input": {
                        "format": {
                            "type": "audio/pcm",
                            "rate": self.target_sample_rate
                        }
                    },
                    "output": {
                        "format": {
                            "type": "audio/pcm",
                            "rate": self.target_sample_rate
                        }
                    }
                },
                "turn_detection": {
                    "type": "server_vad"
                },
                "tools": tools_config if tools_config else None
            }
        }

        await self.send_event(session_update)

    async def handle_audio_input(self, audio_data: bytes) -> None:
        """Process incoming audio: resample from the input rate to the target rate and send to xAI.

        Audio is dropped while closing, when "audio" is not an enabled
        modality, or while a non-interruptible utterance is playing.
        Processing errors are logged, not raised.
        """
        if not self._session or self._closing:
            return

        if "audio" not in self.config.modalities:
            return

        if self.current_utterance and not self.current_utterance.is_interruptible:
            return

        try:
            raw_audio = np.frombuffer(audio_data, dtype=np.int16)

            # Heuristic: large even-length buffers are assumed to be interleaved
            # stereo and are averaged down to mono -- TODO confirm with the caller.
            if len(raw_audio) >= 1920 and len(raw_audio) % 2 == 0:
                raw_audio = (raw_audio.reshape(-1, 2).astype(np.int32).mean(axis=1)).astype(np.int16)

            if self.input_sample_rate != self.target_sample_rate:
                num_samples = int(len(raw_audio) * self.target_sample_rate / self.input_sample_rate)
                float_audio = raw_audio.astype(np.float32)
                resampled = signal.resample(float_audio, num_samples)
                # Resampling can overshoot the int16 range; clip so the cast
                # saturates instead of wrapping around.
                resampled_audio = np.clip(resampled, -32768, 32767).astype(np.int16)
            else:
                resampled_audio = raw_audio

            base64_audio = base64.b64encode(resampled_audio.tobytes()).decode("utf-8")

            # Log a level sample every 100th chunk only.
            self._audio_log_counter += 1
            if self._audio_log_counter % 100 == 0:
                rms = np.sqrt(np.mean(resampled_audio.astype(np.float32)**2))
                logger.info(f"xAI Audio: Sent chunk {self._audio_log_counter}, samples={len(resampled_audio)}, RMS={rms:.2f}")

            await self.send_event({
                "type": "input_audio_buffer.append",
                "audio": base64_audio
            })
        except Exception as e:
            logger.error(f"Error processing audio input: {e}")

    async def handle_video_input(self, video_data: Any) -> None:
        """xAI Voice API currently does not document direct video stream support in this endpoint."""
        pass

    async def send_message(self, message: str) -> None:
        """Send text message to trigger audio response"""
        await self.send_event({"type": "input_audio_buffer.commit"})
        await self.send_event({
            "type": "conversation.item.create",
            "item": {
                "type": "message",
                "role": "user",
                "content": [{
                    "type": "input_text",
                    "text": message
                }],
            }
        })
        await self.create_response()

    async def create_response(self) -> None:
        """Trigger a response from the model"""
        await self.send_event({
            "type": "response.create"
        })

    async def send_text_message(self, message: str) -> None:
        """Send text message (same as send_message for xAI flow)"""
        await self.send_message(message)

    async def interrupt(self) -> None:
        """Interrupt current generation.

        Does nothing without an open session or while a non-interruptible
        utterance is active.
        """
        if self._session and not self._closing:
            if self.current_utterance and not self.current_utterance.is_interruptible:
                return

            if self.audio_track:
                self.audio_track.interrupt()

            if self._agent_speaking:
                if self.audio_track:
                    self.audio_track.mark_synthesis_complete()
                self._agent_speaking = False

            metrics_collector.on_interrupted()

    async def _handle_websocket(self, session: XAISession) -> None:
        """Start the send and receive loops as tasks owned by the session."""
        session.tasks.extend([
            asyncio.create_task(self._send_loop(session), name="xai_send"),
            asyncio.create_task(self._receive_loop(session), name="xai_recv"),
        ])

    async def _send_loop(self, session: XAISession) -> None:
        """Forward queued messages to the WebSocket until closing or cancelled."""
        try:
            while not self._closing:
                msg = await session.msg_queue.get()
                if isinstance(msg, dict):
                    logger.debug(f"Sending xAI event: {msg.get('type')}")
                    await session.ws.send_json(msg)
                else:
                    await session.ws.send_str(str(msg))
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"xAI Send loop error: {e}")
            self.emit("error", f"Send loop error: {e}")

    async def _receive_loop(self, session: XAISession) -> None:
        """Read WebSocket messages and dispatch text frames; always closes the model on exit."""
        try:
            while not self._closing:
                msg = await session.ws.receive()
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    await self._handle_event(data)
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("xAI WebSocket closed")
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"xAI WebSocket error: {session.ws.exception()}")
                    break
        except Exception as e:
            logger.error(f"xAI Receive loop error: {e}")
            self.emit("error", f"Receive loop error: {e}")
        finally:
            logger.info("xAI Receive loop finished, closing session")
            await self.aclose()

    async def _handle_event(self, data: dict) -> None:
        """Dispatch one server event by its "type"; handler errors are logged, not raised."""
        event_type = data.get("type")
        try:
            if event_type == "conversation.created":
                # The session configuration is sent once, on the first conversation.created.
                if not self._is_configured:
                    await self._send_initial_config()
                    self._is_configured = True
            elif event_type == "input_audio_buffer.speech_started":
                await self._handle_speech_started()
            elif event_type == "input_audio_buffer.speech_stopped":
                await self._handle_speech_stopped()
            elif event_type == "session.updated":
                logger.info("xAI Session updated successfully")
                self._session_ready.set()
            elif event_type == "response.created":
                logger.info(f"Response created: {data.get('response', {}).get('id')}")
                self._generated_text_in_current_response = False
            elif event_type == "response.output_item.added":
                logger.info(f"Output item added: {data.get('item', {}).get('id')}")
            elif event_type == "response.output_audio.delta":
                await self._handle_audio_delta(data)
            elif event_type == "response.output_audio_transcript.delta":
                await self._handle_transcript_delta(data)
            elif event_type == "response.output_audio_transcript.done":
                await self._handle_transcript_done(data)
            elif event_type == "conversation.item.input_audio_transcription.completed":
                await self._handle_input_audio_transcription_completed(data)
            elif event_type == "response.function_call_arguments.done":
                await self._handle_function_call(data)
            elif event_type == "response.done":
                await self._handle_response_done()
            elif event_type == "error":
                logger.error(f"xAI Error: {data}")

        except Exception as e:
            logger.error(f"Error handling event {event_type}: {e}")
            traceback.print_exc()

    async def _handle_speech_started(self) -> None:
        """Record the start of user speech and interrupt the agent when allowed."""
        logger.info("xAI User speech started")
        self.emit("user_speech_started", {"type": "done"})
        metrics_collector.on_user_speech_start()
        metrics_collector.start_turn()

        if self.current_utterance and not self.current_utterance.is_interruptible:
            return

        await self.interrupt()

    async def _handle_speech_stopped(self) -> None:
        """Record the end of user speech."""
        logger.info("xAI User speech stopped")
        metrics_collector.on_user_speech_end()
        self.emit("user_speech_ended", {})

    async def _handle_audio_delta(self, data: dict) -> None:
        """Decode a base64 audio delta and hand it to the audio track."""
        delta = data.get("delta")
        if not delta:
            return

        if not self._agent_speaking:
            metrics_collector.on_agent_speech_start()
            self._agent_speaking = True
            self.emit("agent_speech_started", {})

        if self.audio_track and self.loop:
            audio_bytes = base64.b64decode(delta)
            task = asyncio.create_task(self.audio_track.add_new_bytes(audio_bytes))
            # Hold a reference until the task completes, then drop it.
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

    async def _handle_transcript_delta(self, data: dict) -> None:
        """Append a transcript fragment to the transcript of the current response."""
        delta = data.get("delta", "")
        if delta:
            self._generated_text_in_current_response = True
            self._current_transcript += delta

    async def _handle_transcript_done(self, data: dict) -> None:
        """No-op; the transcript is published when the response is done."""
        pass

    async def _handle_input_audio_transcription_completed(self, data: dict) -> None:
        """Handle input audio transcription completion for user transcript"""
        transcript = data.get("transcript", "")
        if transcript:
            logger.info(f"xAI User transcript: {transcript}")
            metrics_collector.set_user_transcript(transcript)
            try:
                self.emit(
                    "realtime_model_transcription",
                    {"role": "user", "text": transcript, "is_final": True},
                )
            except Exception:
                # Best effort: listener failures are ignored here.
                pass

    async def _handle_response_done(self) -> None:
        """Publish the finished transcript, mark speech as ended and, if needed, follow up on tool outputs."""
        if self._current_transcript:
            logger.info(f"xAI Agent response: {self._current_transcript}")
            metrics_collector.set_agent_response(self._current_transcript)

            try:
                self.emit(
                    "realtime_model_transcription",
                    {"role": "assistant", "text": self._current_transcript, "is_final": True},
                )
            except Exception:
                # Best effort: listener failures are ignored here.
                pass

            self.emit("llm_text_output", {"text": self._current_transcript})

            global_event_emitter.emit(
                "text_response",
                {"text": self._current_transcript, "type": "done"},
            )
            self._current_transcript = ""

        logger.info("xAI Agent speech ended")
        if self.audio_track:
            self.audio_track.mark_synthesis_complete()
        self._agent_speaking = False

        # A response that only produced tool calls (no text) needs a second
        # response so the model can use the tool outputs.
        if self._has_unprocessed_tool_outputs and not self._generated_text_in_current_response:
            logger.info("xAI: Triggering follow-up response for tool outputs")
            self._has_unprocessed_tool_outputs = False
            await self.create_response()
        else:
            self._has_unprocessed_tool_outputs = False

    async def _handle_function_call(self, data: dict) -> None:
        """Handle tool execution flow for xAI"""
        name = data.get("name")
        call_id = data.get("call_id")
        args_str = data.get("arguments")

        if not name or not args_str:
            return

        try:
            arguments = json.loads(args_str)
            logger.info(f"Executing tool: {name} with args: {arguments}")
            metrics_collector.add_function_tool_call(tool_name=name)
            result = None
            found = False
            for tool in self._tools:
                info = get_tool_info(tool)
                if info.name == name:
                    result = await tool(**arguments)
                    found = True
                    break

            if not found:
                logger.warning(f"Tool {name} not found")
                result = {"error": "Tool not found"}

            self.emit(
                "realtime_model_function_executed",
                {
                    "name": name,
                    "arguments": args_str,
                    "call_id": call_id,
                    "output": result if isinstance(result, str) else json.dumps(result),
                    "is_error": not found,
                },
            )

            await self.send_event({
                "type": "conversation.item.create",
                "item": {
                    "type": "function_call_output",
                    "call_id": call_id,
                    "output": json.dumps(result)
                }
            })

            # Only a real tool result schedules a follow-up response.
            if found:
                self._has_unprocessed_tool_outputs = True

        except Exception as e:
            self.emit(
                "realtime_model_function_executed",
                {
                    "name": name,
                    "arguments": args_str,
                    "call_id": call_id,
                    "output": str(e),
                    "is_error": True,
                },
            )
            logger.error(f"Error executing function {name}: {e}")

    async def send_event(self, event: Dict[str, Any]) -> None:
        """Queue an event for the send loop; dropped when there is no session or while closing."""
        if self._session and not self._closing:
            await self._session.msg_queue.put(event)

    def _format_tools_for_session(self, tools: List[FunctionTool]) -> List[Dict[str, Any]]:
        """Format tools using OpenAI schema builder (xAI is compatible)"""
        formatted = []
        for tool in tools:
            if is_function_tool(tool):
                try:
                    schema = build_openai_schema(tool)
                    formatted.append(schema)
                except Exception as e:
                    logger.error(f"Failed to format tool {tool}: {e}")
        return formatted

    async def aclose(self) -> None:
        """Cleanup resources.

        Safe to call more than once; later calls return immediately. Also
        called from the receive loop's ``finally`` block.
        """
        if self._closing:
            return

        self._closing = True

        if self._session:
            # Never cancel the task running this method: the receive loop calls
            # aclose() itself, and cancelling it would abort the cleanup below.
            current = asyncio.current_task()
            for task in self._session.tasks:
                if task is not current and not task.done():
                    task.cancel()

            if not self._session.ws.closed:
                await self._session.ws.close()

        if self._http_session and not self._http_session.closed:
            await self._http_session.close()

        await super().aclose()

xAI's Grok realtime model implementation

Initialize the realtime model

Ancestors

  • videosdk.agents.realtime_base_model.RealtimeBaseModel
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic
  • abc.ABC

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Tear down the session tasks, the WebSocket and the HTTP session.

    Only the first call does any work; later calls return immediately.
    """
    if self._closing:
        return
    self._closing = True

    session = self._session
    if session:
        # Cancel whatever session tasks are still running.
        for task in session.tasks:
            if task.done():
                continue
            task.cancel()

        websocket = session.ws
        if not websocket.closed:
            await websocket.close()

    http_session = self._http_session
    if http_session and not http_session.closed:
        await http_session.close()

    await super().aclose()

Cleanup resources

async def connect(self) ‑> None
Expand source code
async def connect(self) -> None:
    """Open the WebSocket session to xAI and wait for it to be configured.

    Waits up to 10 seconds for ``_session_ready`` to be set. A timeout is
    only logged as a warning; it is not raised to the caller.
    """
    auth_headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json",
    }

    self._session = await self._create_session(self.base_url, auth_headers)
    await self._handle_websocket(self._session)

    self.reframe_audio_track(self.target_sample_rate)

    try:
        await asyncio.wait_for(self._session_ready.wait(), timeout=10.0)
    except asyncio.TimeoutError:
        logger.warning("Timeout waiting for xAI session configuration")
    else:
        logger.info("xAI session configuration complete")

Establish WebSocket connection to xAI

async def create_response(self) ‑> None
Expand source code
async def create_response(self) -> None:
    """Ask the model to produce a response by sending ``response.create``."""
    request = {"type": "response.create"}
    await self.send_event(request)

Trigger a response from the model

async def handle_audio_input(self, audio_data: bytes) ‑> None
Expand source code
async def handle_audio_input(self, audio_data: bytes) -> None:
    """Process incoming audio: downmix, resample to the target rate and send to xAI.

    The chunk is dropped without error when there is no active session, the
    model is closing, "audio" is not an enabled modality, or the current
    utterance is marked non-interruptible. Any processing failure is logged
    and swallowed so that one bad chunk does not break the audio stream.

    Args:
        audio_data: Raw signed 16-bit PCM bytes (typically 48 kHz input that
            is resampled to ``target_sample_rate``, usually 24 kHz).
    """
    if not self._session or self._closing:
        return

    if "audio" not in self.config.modalities:
        return

    if self.current_utterance and not self.current_utterance.is_interruptible:
        return

    try:
        raw_audio = np.frombuffer(audio_data, dtype=np.int16)
        if raw_audio.size == 0:
            # Nothing to forward; an empty append would carry no audio.
            return

        # Heuristic: a chunk of at least 1920 samples with an even count is
        # treated as interleaved stereo and averaged down to mono.
        # NOTE(review): confirm against the frame sizes the caller delivers.
        if len(raw_audio) >= 1920 and len(raw_audio) % 2 == 0:
            raw_audio = (raw_audio.reshape(-1, 2).astype(np.int32).mean(axis=1)).astype(np.int16)

        if self.input_sample_rate != self.target_sample_rate:
            num_samples = int(len(raw_audio) * self.target_sample_rate / self.input_sample_rate)
            if num_samples < 1:
                # Chunk is too short to yield even one output sample.
                return
            float_audio = raw_audio.astype(np.float32)
            resampled = signal.resample(float_audio, num_samples)
            # FFT resampling rings around sharp edges and can overshoot the
            # int16 range; clip first so loud samples saturate instead of
            # wrapping around to the opposite sign.
            int16_range = np.iinfo(np.int16)
            resampled_audio = np.clip(resampled, int16_range.min, int16_range.max).astype(np.int16)
        else:
            resampled_audio = raw_audio

        base64_audio = base64.b64encode(resampled_audio.tobytes()).decode("utf-8")

        # Lazily created counter; a level/size summary is logged every 100 chunks.
        self._audio_log_counter = getattr(self, "_audio_log_counter", 0) + 1
        if self._audio_log_counter % 100 == 0:
            rms = np.sqrt(np.mean(resampled_audio.astype(np.float32)**2))
            logger.info(f"xAI Audio: Sent chunk {self._audio_log_counter}, samples={len(resampled_audio)}, RMS={rms:.2f}")

        await self.send_event({
            "type": "input_audio_buffer.append",
            "audio": base64_audio
        })
    except Exception as e:
        logger.error(f"Error processing audio input: {e}")

Process incoming audio: Resample 48k -> target (usually 24k) and send to xAI

async def handle_video_input(self, video_data: Any) ‑> None
Expand source code
async def handle_video_input(self, video_data: Any) -> None:
    """Accept and discard video input.

    Nothing is forwarded: the xAI Voice API currently does not document
    direct video stream support in this endpoint.
    """
    return None

xAI Voice API currently does not document direct video stream support in this endpoint.

async def interrupt(self) ‑> None
Expand source code
async def interrupt(self) -> None:
    """Stop the current generation's audio playback.

    Does nothing when there is no session, the model is closing, or the
    current utterance is marked non-interruptible. Otherwise the audio track
    is interrupted, the speaking state is cleared, and the interruption is
    reported to the metrics collector.
    """
    if not self._session or self._closing:
        return

    utterance = self.current_utterance
    if utterance and not utterance.is_interruptible:
        return

    playback = self.audio_track
    if playback:
        playback.interrupt()

    if self._agent_speaking:
        speaking_track = self.audio_track
        if speaking_track:
            speaking_track.mark_synthesis_complete()
        self._agent_speaking = False

    metrics_collector.on_interrupted()

Interrupt current generation

async def send_event(self, event: Dict[str, Any]) ‑> None
Expand source code
async def send_event(self, event: Dict[str, Any]) -> None:
    """Put an event on the active session's message queue.

    The event is dropped without error when there is no session or the
    model is already closing.
    """
    session = self._session
    if not session or self._closing:
        return
    await session.msg_queue.put(event)
async def send_message(self, message: str) ‑> None
Expand source code
async def send_message(self, message: str) -> None:
    """Send a user text message and request a response.

    Three steps, in order: commit the input audio buffer, create a user
    conversation item holding the text, then trigger a response.
    """
    text_part = {
        "type": "input_text",
        "text": message,
    }
    user_item = {
        "type": "message",
        "role": "user",
        "content": [text_part],
    }

    await self.send_event({"type": "input_audio_buffer.commit"})
    await self.send_event({"type": "conversation.item.create", "item": user_item})
    await self.create_response()

Send text message to trigger audio response

async def send_text_message(self, message: str) ‑> None
Expand source code
async def send_text_message(self, message: str) -> None:
    """Send a text message by delegating to ``send_message``.

    The xAI flow has no separate text-only path, so both entry points
    behave the same.
    """
    text = message
    await self.send_message(text)

Send text message (same as send_message for xAI flow)

def set_agent(self, agent: Agent) ‑> None
Expand source code
def set_agent(self, agent: Agent) -> None:
    """Attach an agent and adopt its instructions and tools.

    The agent's instructions replace the current ones only when they are
    non-empty. Its tools are always stored and re-formatted for the session.
    """
    self._agent = agent

    agent_instructions = agent.instructions
    if agent_instructions:
        self._instructions = agent_instructions

    agent_tools = agent.tools
    self._tools = agent_tools
    self._formatted_tools = self._format_tools_for_session(agent_tools)
class XAIRealtimeConfig (voice: XAIVoice = 'Ara',
instructions: str | None = None,
turn_detection: XAITurnDetection | None = <factory>,
modalities: List[str] = <factory>,
enable_web_search: bool = False,
enable_x_search: bool = False,
allowed_x_handles: List[str] | None = None,
collection_id: str | None = None,
max_num_results: int = 10)
Expand source code
@dataclass
class XAIRealtimeConfig:
    """Configuration for the xAI (Grok) Realtime API

    Args:
        voice: The voice identifier. Options: 'Ara', 'Rex', 'Sal', 'Eve', 'Leo'. Default: 'Ara'
        instructions: System instructions for the agent.
        turn_detection: Configuration for server-side VAD. Defaults to a new
            ``XAITurnDetection()``; may be None.
        modalities: Enabled modalities. Defaults to ``["text", "audio"]``.
        enable_web_search: Off by default. Presumably turns on xAI's built-in
            web search tool — TODO confirm against the session setup code.
        enable_x_search: Off by default. Presumably turns on xAI's built-in
            X search tool — TODO confirm against the session setup code.
        allowed_x_handles: Optional list of X handles. Presumably restricts
            X search to these accounts — TODO confirm.
        collection_id: Optional collection identifier. NOTE(review): its use
            is not visible here; confirm where it is consumed.
        max_num_results: Defaults to 10. Presumably caps the number of search
            results — TODO confirm which tool reads it.

    Note:
        This dataclass has no ``tools`` field. Standard function tools are
        handled via the Agent class.
    """
    voice: XAIVoice = DEFAULT_XAI_VOICE
    instructions: str | None = None
    # default_factory gives every config its own XAITurnDetection / list instance.
    turn_detection: XAITurnDetection | None = field(default_factory=XAITurnDetection)
    modalities: List[str] = field(default_factory=lambda: ["text", "audio"])
    enable_web_search: bool = False
    enable_x_search: bool = False
    allowed_x_handles: List[str] | None = None
    collection_id: str | None = None
    max_num_results: int = 10

Configuration for the xAI (Grok) Realtime API

Args

voice
The voice identifier. Options: 'Ara', 'Rex', 'Sal', 'Eve', 'Leo'. Default: 'Ara'
instructions
System instructions for the agent.
turn_detection
Configuration for server-side VAD.
tools
List of specific xAI tools (e.g., web_search, x_search). Standard function tools are handled via the Agent class.

Instance variables

var allowed_x_handles : List[str] | None
var collection_id : str | None
var instructions : str | None
var max_num_results : int
var modalities : List[str]
var turn_detection : XAITurnDetection | None
var voice : Literal['Ara', 'Rex', 'Sal', 'Eve', 'Leo']
class XAISTT (*,
api_key: str | None = None,
sample_rate: int = 48000,
encoding: "Literal['pcm', 'mulaw', 'alaw']" = 'pcm',
interim_results: bool = True,
endpointing: int = 50,
language: str | None = 'en',
diarize: bool = False,
multichannel: bool = False,
channels: int = 1,
base_url: str = 'wss://api.x.ai/v1/stt')
Expand source code
class XAISTT(BaseSTT):
    """Streaming speech-to-text over xAI's WebSocket STT endpoint.

    The WebSocket is opened lazily on the first ``process_audio`` call. A
    background task reads transcript events and passes the mapped responses
    to ``self._transcript_callback``.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        sample_rate: int = 48000,
        encoding: Literal["pcm", "mulaw", "alaw"] = "pcm",
        interim_results: bool = True,
        endpointing: int = 50,
        language: str | None = "en",
        diarize: bool = False,
        multichannel: bool = False,
        channels: int = 1,
        base_url: str = XAI_STT_BASE_URL,
    ) -> None:
        """Initialize the xAI STT plugin.

        Args:
            api_key: xAI API key. Falls back to XAI_API_KEY env var.
            sample_rate: Audio sample rate in Hz. xAI accepts 8000/16000/22050/24000/44100/48000.
                Defaults to 48000 to match the framework's native input rate.
            encoding: Raw audio encoding. One of "pcm" (signed 16-bit LE), "mulaw", "alaw".
            interim_results: Emit partial transcripts (is_final=false) as they arrive.
            endpointing: Silence duration (ms) before xAI fires utterance-final. Range 0–5000.
                Kept low (50ms default) because the framework's VAD is the primary turn
                detector; flush() injects synthetic silence to force utterance-final, and
                a low threshold means flush latency is short.
            language: BCP-47 language code (e.g. "en", "fr"). Pass None to skip the param.
                xAI transcribes any supported language regardless of this — the value only
                enables Inverse Text Normalization (numbers, currencies in written form).
            diarize: When true, each word in the response includes a `speaker` field.
            multichannel: When true, transcribes each input channel independently. Requires
                interleaved multi-channel audio. When false, input is downmixed to mono.
            channels: Number of input channels (only relevant with multichannel=True).
            base_url: WebSocket endpoint URL.

        Raises:
            ValueError: If no API key is available, or if sample_rate, encoding
                or endpointing is outside the supported values.
        """
        super().__init__()

        self.api_key = api_key or os.getenv("XAI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "xAI API key must be provided either through the api_key parameter "
                "or the XAI_API_KEY environment variable"
            )

        if sample_rate not in SUPPORTED_SAMPLE_RATES:
            raise ValueError(
                f"sample_rate must be one of {sorted(SUPPORTED_SAMPLE_RATES)}, got {sample_rate}"
            )
        if encoding not in SUPPORTED_ENCODINGS:
            raise ValueError(
                f"encoding must be one of {sorted(SUPPORTED_ENCODINGS)}, got {encoding}"
            )
        if not 0 <= endpointing <= 5000:
            raise ValueError(f"endpointing must be in [0, 5000], got {endpointing}")

        self.sample_rate = sample_rate
        self.encoding = encoding
        self.interim_results = interim_results
        self.endpointing = endpointing
        self.language = language
        self.diarize = diarize
        self.multichannel = multichannel
        self.channels = channels
        self.base_url = base_url

        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._ws_task: Optional[asyncio.Task] = None
        # Set once the server has sent "transcript.created".
        self._server_ready: asyncio.Event = asyncio.Event()
        self._closed = False

    async def process_audio(
        self,
        audio_frames: bytes,
        language: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Process audio frames and stream them to xAI's WebSocket STT API.

        On the first call (or after the connection was dropped) this connects,
        starts the listener task and waits up to 5 seconds for the server's
        "transcript.created" event before sending anything.

        Args:
            audio_frames: Raw audio bytes in the configured encoding.
            language: Not used by this implementation; the language given to
                the constructor applies.
            **kwargs: Not used by this implementation.
        """
        if self._closed:
            return

        if not self._ws:
            await self._connect_ws()
            self._ws_task = asyncio.create_task(self._listen_for_responses())
            try:
                await asyncio.wait_for(self._server_ready.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                self.emit("error", "Timed out waiting for xAI transcript.created")
                return

        try:
            audio_bytes = audio_frames
            # In mono PCM mode, buffers made of whole 4-byte groups are treated
            # as interleaved 16-bit stereo and averaged down to one channel.
            # NOTE(review): assumes the framework always delivers stereo
            # frames here — confirm against the caller.
            if (
                not self.multichannel
                and self.encoding == "pcm"
                and len(audio_frames) % 4 == 0
            ):
                audio_data = np.frombuffer(audio_frames, dtype=np.int16)
                # A byte length divisible by 4 already guarantees an even
                # sample count, so only the empty buffer needs excluding.
                if audio_data.size > 0:
                    audio_data = (
                        audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16)
                    )
                    audio_bytes = audio_data.tobytes()

            await self._ws.send_bytes(audio_bytes)
        except Exception as e:
            logger.error(f"Error sending audio to xAI STT: {e}")
            self.emit("error", str(e))
            await self._reset_connection()

    async def _connect_ws(self) -> None:
        """Open the WebSocket connection to xAI's STT endpoint.

        Raises:
            Exception: Whatever ``ws_connect`` raised, after logging it.
        """
        if not self._session:
            self._session = aiohttp.ClientSession()

        params: list[tuple[str, str]] = [
            ("sample_rate", str(self.sample_rate)),
            ("encoding", self.encoding),
            ("interim_results", str(self.interim_results).lower()),
            ("endpointing", str(self.endpointing)),
            ("diarize", str(self.diarize).lower()),
            ("multichannel", str(self.multichannel).lower()),
            # Outside multichannel mode the stream is always declared as mono.
            ("channels", str(1 if not self.multichannel else self.channels)),
        ]
        if self.language:
            params.append(("language", self.language))

        ws_url = f"{self.base_url}?{urlencode(params)}"
        headers = {"Authorization": f"Bearer {self.api_key}"}

        # Fresh event per connection so an earlier "ready" state cannot leak in.
        self._server_ready = asyncio.Event()

        try:
            self._ws = await self._session.ws_connect(
                ws_url, headers=headers, heartbeat=30.0
            )
        except Exception as e:
            logger.error(f"Error connecting to xAI STT WebSocket: {e}")
            raise

    async def _listen_for_responses(self) -> None:
        """Background task that reads transcript events from the WebSocket.

        Whenever the loop ends (close, error or cancellation) the socket is
        closed and forgotten, so the next ``process_audio`` call reconnects.
        """
        if not self._ws:
            return

        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    try:
                        data = msg.json()
                    except Exception as e:
                        logger.error(f"Failed to parse xAI STT message: {e}")
                        continue

                    event_type = data.get("type")
                    if event_type == "transcript.created":
                        self._server_ready.set()
                    elif event_type in ("transcript.partial", "transcript.done"):
                        # Both carry transcript text and share one mapping.
                        response = self._handle_partial(data)
                        if response and self._transcript_callback:
                            await self._transcript_callback(response)
                    elif event_type == "error":
                        message = data.get("message", "unknown error")
                        logger.error(f"xAI STT error event: {message}")
                        self.emit("error", message)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    err = self._ws.exception()
                    logger.error(f"xAI STT WebSocket error: {err}")
                    self.emit("error", f"WebSocket error: {err}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    break
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error(f"Error in xAI STT listener: {e}")
            self.emit("error", f"Error in WebSocket listener: {e}")
        finally:
            if self._ws and not self._ws.closed:
                await self._ws.close()
            self._ws = None
            self._server_ready.clear()

    def _handle_partial(self, event: dict) -> Optional[STTResponse]:
        """Map an xAI transcript event to an STTResponse.

        Returns:
            None when the event has no text. Otherwise a FINAL response for
            "transcript.done" events or when both ``is_final`` and
            ``speech_final`` are true, and an INTERIM response in every
            other case.
        """
        text = event.get("text", "")
        if not text:
            return None

        is_final = bool(event.get("is_final", False))
        speech_final = bool(event.get("speech_final", False))
        event_is_done = event.get("type") == "transcript.done"
        event_type = (
            SpeechEventType.FINAL
            if (is_final and speech_final) or event_is_done
            else SpeechEventType.INTERIM
        )

        words = event.get("words") or []
        start = event.get("start", 0.0) or 0.0
        duration = event.get("duration", 0.0) or 0.0
        if words:
            # Prefer word-level timings, falling back to the segment's own.
            start_time = float(words[0].get("start", start))
            end_time = float(words[-1].get("end", start + duration))
        else:
            start_time = float(start)
            end_time = float(start) + float(duration)

        return STTResponse(
            event_type=event_type,
            data=SpeechData(
                text=text,
                language=self.language,
                # The event's confidence, if any, is not read; 0.0 is a placeholder.
                confidence=0.0,
                start_time=start_time,
                end_time=end_time,
                duration=float(duration),
            ),
            metadata={
                "is_final": is_final,
                "speech_final": speech_final,
                "channel_index": event.get("channel_index"),
            },
        )

    async def flush(self) -> None:
        """Force xAI to emit the current utterance-final transcript.

        We inject a short chunk of silence bytes — once the configured
        endpointing window elapses, xAI fires `transcript.partial` with
        speech_final=true, which we map to SpeechEventType.FINAL.

        The silence always covers a whole number of audio frames, so the
        16-bit PCM stream stays sample-aligned for the audio that follows.
        Does nothing when closed, disconnected, or before the server is ready.
        """
        if self._closed or not self._ws or self._ws.closed:
            return
        if not self._server_ready.is_set():
            return

        try:
            silence_ms = max(self.endpointing + 50, 100)
            bytes_per_sample = 2 if self.encoding == "pcm" else 1
            channel_count = self.channels if self.multichannel else 1
            # Round down to whole frames BEFORE converting to bytes. Truncating
            # the byte count directly can give an odd length for 16-bit PCM
            # (e.g. 22050 Hz * 150 ms = 6615 bytes), which would shift every
            # later sample by one byte.
            n_frames = int(self.sample_rate * silence_ms / 1000)
            n_bytes = n_frames * bytes_per_sample * channel_count
            silence = _SILENCE_BYTE[self.encoding] * n_bytes
            await self._ws.send_bytes(silence)
        except Exception as e:
            logger.warning(f"xAI STT flush failed: {e}")

    async def _cancel_listener(self) -> None:
        """Cancel the listener task, wait for it to finish, and forget it."""
        if self._ws_task:
            self._ws_task.cancel()
            try:
                await self._ws_task
            except (asyncio.CancelledError, Exception):
                # Best effort: the listener's outcome no longer matters.
                pass
            self._ws_task = None

    async def _reset_connection(self) -> None:
        """Drop the current WebSocket and listener so the next call reconnects."""
        if self._ws:
            try:
                await self._ws.close()
            except Exception:
                # Best effort: the socket is being discarded either way.
                pass
            self._ws = None
        await self._cancel_listener()
        self._server_ready.clear()

    async def aclose(self) -> None:
        """Cleanup resources.

        Sends "audio.done" on a best-effort basis, stops the listener, then
        closes the WebSocket and the HTTP session.
        """
        self._closed = True

        if self._ws and not self._ws.closed:
            try:
                await self._ws.send_str(json.dumps({"type": "audio.done"}))
            except Exception as e:
                # Best effort: shutdown continues even if the notice is lost.
                logger.debug(f"xAI STT could not send audio.done: {e}")

        await self._cancel_listener()

        if self._ws and not self._ws.closed:
            await self._ws.close()
        self._ws = None

        if self._session:
            await self._session.close()
            self._session = None

        await super().aclose()

Base class for Speech-to-Text implementations

Initialize the xAI STT plugin.

Args

api_key
xAI API key. Falls back to XAI_API_KEY env var.
sample_rate
Audio sample rate in Hz. xAI accepts 8000/16000/22050/24000/44100/48000. Defaults to 48000 to match the framework's native input rate.
encoding
Raw audio encoding. One of "pcm" (signed 16-bit LE), "mulaw", "alaw".
interim_results
Emit partial transcripts (is_final=false) as they arrive.
endpointing
Silence duration (ms) before xAI fires utterance-final. Range 0–5000. Kept low (50ms default) because the framework's VAD is the primary turn detector; flush() injects synthetic silence to force utterance-final, and a low threshold means flush latency is short.
language
BCP-47 language code (e.g. "en", "fr"). Pass None to skip the param. xAI transcribes any supported language regardless of this — the value only enables Inverse Text Normalization (numbers, currencies in written form).
diarize
When true, each word in the response includes a speaker field.
multichannel
When true, transcribes each input channel independently. Requires interleaved multi-channel audio. When false, input is downmixed to mono.
channels
Number of input channels (only relevant with multichannel=True).
base_url
WebSocket endpoint URL.

Ancestors

  • videosdk.agents.stt.stt.STT
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Shut the plugin down and release its network resources.

    Sends "audio.done" on a best-effort basis, stops the listener task, then
    closes the WebSocket and the HTTP session.
    """
    self._closed = True

    socket = self._ws
    if socket and not socket.closed:
        try:
            await socket.send_str(json.dumps({"type": "audio.done"}))
        except Exception:
            pass

    listener = self._ws_task
    if listener:
        listener.cancel()
        try:
            await listener
        except (asyncio.CancelledError, Exception):
            pass
        self._ws_task = None

    # Re-read the socket: the attribute may have changed while awaiting above.
    socket = self._ws
    if socket and not socket.closed:
        await socket.close()
    self._ws = None

    http_session = self._session
    if http_session:
        await http_session.close()
        self._session = None

    await super().aclose()

Cleanup resources.

async def flush(self) ‑> None
Expand source code
async def flush(self) -> None:
    """Force xAI to emit the current utterance-final transcript.

    We inject a short chunk of silence bytes — once the configured
    endpointing window elapses, xAI fires `transcript.partial` with
    speech_final=true, which we map to SpeechEventType.FINAL.

    The silence always covers a whole number of audio frames, so the 16-bit
    PCM stream stays sample-aligned for the audio that follows. Does nothing
    when closed, disconnected, or before the server is ready.
    """
    if self._closed or not self._ws or self._ws.closed:
        return
    if not self._server_ready.is_set():
        return

    try:
        silence_ms = max(self.endpointing + 50, 100)
        bytes_per_sample = 2 if self.encoding == "pcm" else 1
        channel_count = self.channels if self.multichannel else 1
        # Round down to whole frames BEFORE converting to bytes. Truncating
        # the byte count directly can give an odd length for 16-bit PCM
        # (e.g. 22050 Hz * 150 ms = 6615 bytes), which would shift every
        # later sample by one byte.
        n_frames = int(self.sample_rate * silence_ms / 1000)
        n_bytes = n_frames * bytes_per_sample * channel_count
        silence = _SILENCE_BYTE[self.encoding] * n_bytes
        await self._ws.send_bytes(silence)
    except Exception as e:
        logger.warning(f"xAI STT flush failed: {e}")

Force xAI to emit the current utterance-final transcript.

We inject a short chunk of silence bytes — once the configured endpointing window elapses, xAI fires transcript.partial with speech_final=true, which we map to SpeechEventType.FINAL.

async def process_audio(self, audio_frames: bytes, language: Optional[str] = None, **kwargs: Any) ‑> None
Expand source code
async def process_audio(
    self,
    audio_frames: bytes,
    language: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Stream one buffer of audio to xAI's WebSocket STT API.

    Connects first when no socket is open: starts the listener task and
    waits up to 5 seconds for the server to become ready. In mono PCM mode,
    buffers made of whole 4-byte groups are averaged pairwise down to one
    channel before sending. A send failure is logged, emitted as an "error"
    event, and the connection is reset.
    """
    if self._closed:
        return

    if not self._ws:
        await self._connect_ws()
        self._ws_task = asyncio.create_task(self._listen_for_responses())
        try:
            await asyncio.wait_for(self._server_ready.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            self.emit("error", "Timed out waiting for xAI transcript.created")
            return

    try:
        payload = audio_frames
        wants_downmix = (
            not self.multichannel
            and self.encoding == "pcm"
            and len(audio_frames) % 4 == 0
        )
        if wants_downmix:
            samples = np.frombuffer(audio_frames, dtype=np.int16)
            if samples.size > 0 and samples.size % 2 == 0:
                pairs = samples.reshape(-1, 2)
                mono = pairs.mean(axis=1).astype(np.int16)
                payload = mono.tobytes()

        await self._ws.send_bytes(payload)
    except Exception as exc:
        logger.error(f"Error sending audio to xAI STT: {exc}")
        self.emit("error", str(exc))
        await self._reset_connection()

Process audio frames and stream them to xAI's WebSocket STT API.

class XAITTS (*,
api_key: str | None = None,
voice: str = 'eve',
language: str = 'en',
codec: "Literal['pcm', 'mulaw']" = 'pcm',
sample_rate: int = 24000,
optimize_streaming_latency: int = 0,
text_normalization: bool = False,
base_url: str = 'wss://api.x.ai/v1/tts',
max_connection_age_sec: float = 300.0)
Expand source code
class XAITTS(TTS):
    """Streaming text-to-speech client for xAI's bidirectional WebSocket API.

    One WebSocket is kept open and reused across ``synthesize()`` calls. Voice,
    language, codec and the other synthesis options are sent as URL query
    parameters when the socket is opened, so changing any of them requires a
    new connection. A background task reads ``audio.delta`` / ``audio.done`` /
    ``error`` events and pushes decoded audio into ``audio_track``.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        voice: str = "eve",
        language: str = "en",
        codec: Literal["pcm", "mulaw"] = "pcm",
        sample_rate: int = 24000,
        optimize_streaming_latency: int = 0,
        text_normalization: bool = False,
        base_url: str = XAI_TTS_BASE_URL,
        max_connection_age_sec: float = DEFAULT_CONNECTION_MAX_AGE_SEC,
    ) -> None:
        """Initialize the xAI TTS plugin.

        Args:
            api_key: xAI API key. Falls back to XAI_API_KEY env var.
            voice: Voice ID — one of "eve", "ara", "rex", "sal", "leo". Case-insensitive.
            language: BCP-47 language code (e.g. "en", "fr", "pt-BR") or "auto" for
                automatic language detection. Required by xAI.
            codec: Output codec. Restricted to "pcm" (signed 16-bit LE, default) or
                "mulaw" — both are raw, byte-streamable formats compatible with the
                framework's audio_track. mp3/wav/alaw are not exposed because they
                require a decoder before bytes can be played.
            sample_rate: Output sample rate in Hz. One of 8000/16000/22050/24000/44100/48000.
                Defaults to 24000 (xAI's recommended rate).
            optimize_streaming_latency: 0 (default, best quality) or 1 (lower
                time-to-first-audio with minor quality tradeoff).
            text_normalization: When true, xAI normalizes written-form text
                (numbers, abbreviations, symbols) into spoken-form before synthesis.
            base_url: WebSocket endpoint URL.
            max_connection_age_sec: Maximum age, in seconds, of a reused
                WebSocket. An open socket older than this is closed and
                re-opened before the next synthesis or prewarm.

        Raises:
            ValueError: If ``sample_rate``, ``codec``, ``optimize_streaming_latency``
                or ``voice`` is not a supported value, or if no API key is available.

        Speech tags:
            xAI supports inline expression tags ([pause], [long-pause], [laugh],
            [sigh], [breath], etc.) and wrapping style tags (<whisper>...</whisper>,
            <soft>, <loud>, <slow>, <fast>, <higher-pitch>, <lower-pitch>,
            <emphasis>, <singing>, <sing-song>, <laugh-speak>, <build-intensity>,
            <decrease-intensity>) directly inside the `text` you pass to
            `synthesize()`. No separate parameter is needed — the tags are sent
            verbatim as part of each text.delta message and parsed server-side.

            Example::

                await tts.synthesize(
                    "So I walked in and [pause] there it was. [laugh] Incredible!"
                )
                await tts.synthesize(
                    "I need to tell you something. "
                    "<whisper>It is a secret.</whisper> Pretty cool, right?"
                )

            Caveat for streaming input: when synthesize() receives an
            AsyncIterator[str] (e.g. LLM tokens), a single tag can be split across
            two chunks ("[pa", "use]") which xAI will not recognize. Tags only work
            reliably when an entire tag arrives within one text chunk.
        """
        if sample_rate not in SUPPORTED_SAMPLE_RATES:
            raise ValueError(
                f"sample_rate must be one of {sorted(SUPPORTED_SAMPLE_RATES)}, got {sample_rate}"
            )
        if codec not in SUPPORTED_CODECS:
            raise ValueError(
                f"codec must be one of {sorted(SUPPORTED_CODECS)} for raw PCM-compatible "
                f"output (got {codec}). mp3/wav/alaw are not supported because they "
                f"produce framed audio that the audio_track cannot consume directly."
            )
        if optimize_streaming_latency not in (0, 1):
            raise ValueError("optimize_streaming_latency must be 0 or 1")

        super().__init__(
            sample_rate=sample_rate,
            num_channels=XAI_TTS_NUM_CHANNELS,
            word_timestamps=False,
        )

        self._api_key = api_key or os.getenv("XAI_API_KEY")
        if not self._api_key:
            raise ValueError(
                "xAI API key must be provided either through the api_key parameter "
                "or the XAI_API_KEY environment variable"
            )

        voice_lower = voice.lower()
        if voice_lower not in SUPPORTED_VOICES:
            raise ValueError(
                f"voice must be one of {sorted(SUPPORTED_VOICES)}, got {voice}"
            )

        self._voice = voice_lower
        self.language = language
        self.codec = codec
        self.optimize_streaming_latency = optimize_streaming_latency
        self.text_normalization = text_normalization
        self.base_url = base_url
        self._max_connection_age_sec = max_connection_age_sec

        self._ws_session: Optional[aiohttp.ClientSession] = None
        self._ws_connection: Optional[aiohttp.ClientWebSocketResponse] = None
        self._ws_connect_time: float = 0.0
        # Serializes connect / reconnect of the shared socket.
        self._connection_lock = asyncio.Lock()
        # Serializes whole synthesize() calls: one utterance at a time.
        self._synthesis_lock = asyncio.Lock()
        self._receive_task: Optional[asyncio.Task] = None
        # Completion signal for the utterance currently being synthesized.
        self._current_done_future: Optional[asyncio.Future[None]] = None
        self._first_chunk_sent = False
        self._interrupted = False
        self._closed = False

    def reset_first_audio_tracking(self) -> None:
        """Reset the first-audio-byte tracking state for the next synthesis turn."""
        self._first_chunk_sent = False

    async def prewarm(self) -> None:
        """Pre-establish the xAI WebSocket so the first ``synthesize()`` call
        does not pay the TLS + auth + upgrade cost. Safe to call repeatedly.

        Connection failures are logged as warnings and not re-raised.
        """
        try:
            await self._ensure_ws_connection()
        except Exception as e:
            logger.warning(f"xAI TTS prewarm failed (non-fatal): {e}")

    async def synthesize(
        self,
        text: AsyncIterator[Union[str, FlushMarker]] | str,
        voice_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Synthesize text to speech via xAI's bidirectional WebSocket API.

        ``FlushMarker`` segment markers in the input stream are silently dropped
        — xAI has no per-sentence flush primitive, and the server segments
        naturally on ``text.done``.

        Args:
            text: A complete string, or an async stream of text chunks.
            voice_id: Optional voice override. When it differs from the voice
                currently in use, it becomes the new voice and the socket is
                re-opened, because the voice is a connect-time URL parameter.
            **kwargs: Accepted for interface compatibility; not read here.

        Raises:
            Exception: Any failure while connecting, sending or receiving is
                emitted as an ``"error"`` event and re-raised.
            asyncio.CancelledError: If the utterance is interrupted or the
                plugin is closed while this call is waiting for completion.
        """
        # Stays None until this call installs its own completion future, so the
        # cleanup below never touches a future owned by another call.
        done_future: Optional[asyncio.Future[None]] = None
        try:
            if not self.audio_track or not self.loop:
                self.emit("error", "Audio track or event loop not set")
                return

            requested_voice: Optional[str] = None
            if voice_id:
                requested_voice = voice_id.lower()
                if requested_voice not in SUPPORTED_VOICES:
                    self.emit(
                        "error",
                        f"voice_id must be one of {sorted(SUPPORTED_VOICES)}, got {voice_id}",
                    )
                    return

            async with self._synthesis_lock:
                self._interrupted = False
                self._first_chunk_sent = False

                # Apply the voice under the lock so an in-flight utterance is
                # not affected. Closing the live socket makes
                # _ensure_ws_connection() reconnect with the new voice.
                if requested_voice and requested_voice != self._voice:
                    self._voice = requested_voice
                    live = self._ws_connection
                    if live is not None and not live.closed:
                        await live.close()

                await self._ensure_ws_connection()
                if not self._ws_connection:
                    raise RuntimeError("WebSocket connection is not available.")

                done_future = asyncio.get_running_loop().create_future()
                self._current_done_future = done_future

                async def _string_iterator(s: str) -> AsyncIterator[str]:
                    yield s

                text_iterator = (
                    _string_iterator(text) if isinstance(text, str) else text
                )

                send_task = asyncio.create_task(
                    self._send_task(text_iterator, done_future)
                )

                try:
                    await done_future
                finally:
                    # Let the sender finish so it is never left running
                    # detached from this call.
                    if not send_task.done():
                        try:
                            await send_task
                        except Exception:
                            pass

        except Exception as e:
            self.emit("error", f"TTS synthesis failed: {e}")
            raise
        finally:
            # Only clear the slot if it still holds *our* future. A call that
            # returned early or was cancelled while waiting for the lock must
            # not wipe the future of the synthesis that is still running.
            if done_future is not None and self._current_done_future is done_future:
                self._current_done_future = None

    async def _send_task(
        self,
        text_iterator: AsyncIterator[Union[str, FlushMarker]],
        done_future: asyncio.Future[None],
    ) -> None:
        """Send text.delta messages, then text.done at end of utterance.

        Failures are reported through ``done_future`` rather than raised. If
        nothing was sent, ``done_future`` is resolved here because the server
        will never emit ``audio.done``.
        """
        has_sent = False
        try:
            async for chunk in text_iterator:
                if self._interrupted:
                    break
                if isinstance(chunk, FlushMarker):
                    # xAI has no per-sentence flush primitive — the server
                    # segments naturally on ``text.done`` at end-of-utterance.
                    continue
                if not chunk or not chunk.strip():
                    continue
                if not self._ws_connection or self._ws_connection.closed:
                    # Surface the lost socket instead of silently dropping
                    # text; otherwise the caller could wait forever.
                    raise ConnectionError(
                        "xAI TTS WebSocket closed while sending text"
                    )
                payload = {"type": "text.delta", "delta": chunk}
                await self._ws_connection.send_str(json.dumps(payload))
                has_sent = True
        except Exception as e:
            if not done_future.done():
                done_future.set_exception(e)
            return
        finally:
            if (
                has_sent
                and not self._interrupted
                and self._ws_connection
                and not self._ws_connection.closed
            ):
                try:
                    await self._ws_connection.send_str(
                        json.dumps({"type": "text.done"})
                    )
                except Exception as e:
                    if not done_future.done():
                        done_future.set_exception(e)

        if not has_sent and not done_future.done():
            done_future.set_result(None)

    async def _receive_loop(self) -> None:
        """Long-running task: read audio.delta / audio.done / error frames.

        If the socket closes while an utterance is still pending, the pending
        future is failed so ``synthesize()`` does not wait forever.
        """
        try:
            while self._ws_connection and not self._ws_connection.closed:
                msg = await self._ws_connection.receive()
                if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
                    break
                if msg.type == aiohttp.WSMsgType.ERROR:
                    err = self._ws_connection.exception()
                    self._fail_pending(RuntimeError(f"xAI TTS WebSocket error: {err}"))
                    break
                if msg.type != aiohttp.WSMsgType.TEXT:
                    continue

                try:
                    data = json.loads(msg.data)
                except Exception as e:
                    logger.error(f"Failed to parse xAI TTS message: {e}")
                    continue

                event_type = data.get("type")
                if event_type == "audio.delta":
                    delta = data.get("delta")
                    if delta:
                        try:
                            await self._stream_audio(base64.b64decode(delta))
                        except Exception as e:
                            logger.error(f"Failed to decode/stream audio: {e}")
                elif event_type == "audio.done":
                    future = self._current_done_future
                    if future and not future.done():
                        future.set_result(None)
                elif event_type == "error":
                    message = data.get("message", "unknown error")
                    self._fail_pending(RuntimeError(f"xAI TTS error: {message}"))

            # Reached only when the socket went away without cancellation. No
            # audio.done can arrive now; this is a no-op if nothing is pending.
            self._fail_pending(
                RuntimeError("xAI TTS WebSocket closed before audio.done")
            )
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self._fail_pending(e)

    def _fail_pending(self, exc: BaseException) -> None:
        """Fail the current utterance's future with ``exc`` if it is unresolved."""
        future = self._current_done_future
        if future and not future.done():
            future.set_exception(exc)

    async def _stream_audio(self, audio_chunk: bytes) -> None:
        """Push a chunk of raw audio bytes into the framework's audio_track."""
        if self._interrupted or not audio_chunk:
            return

        # Drop late audio that belongs to a cancelled synthesis: if the active
        # done_future has already resolved (cancelled or completed), the frame
        # is for a stale context and would bleed into the next turn.
        future = self._current_done_future
        if future is not None and future.done():
            return

        if not self._first_chunk_sent:
            self._first_chunk_sent = True
            if self._first_audio_callback:
                await self._first_audio_callback()

        if self.audio_track:
            await self.audio_track.add_new_bytes(audio_chunk)

    async def _ensure_ws_connection(self) -> None:
        """Open or re-open the WebSocket connection if needed.

        An open socket younger than ``max_connection_age_sec`` is reused.
        Otherwise the old receive task, socket and session are torn down and a
        new connection is made with the current voice and options.

        Raises:
            Exception: Any connect failure is emitted as an ``"error"`` event
                and re-raised.
        """
        async with self._connection_lock:
            loop = asyncio.get_running_loop()

            if self._ws_connection and not self._ws_connection.closed:
                age = loop.time() - self._ws_connect_time
                if age < self._max_connection_age_sec:
                    return
                logger.info(f"Refreshing xAI WebSocket (age={age:.1f}s)")

            if self._receive_task and not self._receive_task.done():
                self._receive_task.cancel()
                try:
                    await self._receive_task
                except (asyncio.CancelledError, Exception):
                    pass
            self._receive_task = None

            if self._ws_connection:
                try:
                    await self._ws_connection.close()
                except Exception:
                    pass
                self._ws_connection = None

            if self._ws_session:
                try:
                    await self._ws_session.close()
                except Exception:
                    pass
                self._ws_session = None

            try:
                self._ws_session = aiohttp.ClientSession()

                params = [
                    ("voice", self._voice),
                    ("language", self.language),
                    ("codec", self.codec),
                    ("sample_rate", str(self.sample_rate)),
                    ("optimize_streaming_latency", str(self.optimize_streaming_latency)),
                    ("text_normalization", str(self.text_normalization).lower()),
                ]
                ws_url = f"{self.base_url}?{urlencode(params)}"
                headers = {"Authorization": f"Bearer {self._api_key}"}

                self._ws_connection = await asyncio.wait_for(
                    self._ws_session.ws_connect(
                        ws_url, headers=headers, heartbeat=30.0
                    ),
                    timeout=5.0,
                )
                # Stamp after the handshake so the age reflects the socket's
                # real lifetime rather than including connect latency.
                self._ws_connect_time = loop.time()
                self._receive_task = asyncio.create_task(self._receive_loop())
            except aiohttp.WSServerHandshakeError as e:
                self.emit(
                    "error",
                    f"xAI TTS WebSocket handshake failed (status {e.status}): {e.message}",
                )
                raise
            except Exception as e:
                self.emit("error", f"Failed to establish xAI TTS WebSocket: {e}")
                raise
            finally:
                # A failed or cancelled connect must not leave the freshly
                # created session open.
                if self._ws_connection is None and self._ws_session is not None:
                    try:
                        await self._ws_session.close()
                    except Exception:
                        pass
                    self._ws_session = None

    async def interrupt(self) -> None:
        """Stop emitting audio for the current synthesis. Keeps the WebSocket
        open so the next turn does not pay reconnect cost; in-flight audio
        frames received after this point are dropped via the done-future
        filter in :meth:`_stream_audio`."""
        self._interrupted = True

        if self.audio_track:
            self.audio_track.interrupt()

        future = self._current_done_future
        if future and not future.done():
            future.cancel()

    async def aclose(self) -> None:
        """Gracefully clean up all resources.

        Cancels any utterance still awaiting completion, stops the receive
        task, then closes the WebSocket and its session. The session is closed
        even if closing the WebSocket raises.
        """
        await super().aclose()
        self._interrupted = True
        self._closed = True

        # Release a synthesize() call that is still waiting: once the receive
        # task is cancelled nothing else would ever resolve its future.
        pending = self._current_done_future
        if pending and not pending.done():
            pending.cancel()

        if self._receive_task and not self._receive_task.done():
            self._receive_task.cancel()
            try:
                await self._receive_task
            except (asyncio.CancelledError, Exception):
                pass
        self._receive_task = None

        try:
            if self._ws_connection and not self._ws_connection.closed:
                await self._ws_connection.close()
        finally:
            self._ws_connection = None
            try:
                if self._ws_session and not self._ws_session.closed:
                    await self._ws_session.close()
            finally:
                self._ws_session = None

Base class for Text-to-Speech implementations

Initialize the xAI TTS plugin.

Args

api_key
xAI API key. Falls back to XAI_API_KEY env var.
voice
Voice ID — one of "eve", "ara", "rex", "sal", "leo". Case-insensitive.
language
BCP-47 language code (e.g. "en", "fr", "pt-BR") or "auto" for automatic language detection. Required by xAI.
codec
Output codec. Restricted to "pcm" (signed 16-bit LE, default) or "mulaw" — both are raw, byte-streamable formats compatible with the framework's audio_track. mp3/wav/alaw are not exposed because they require a decoder before bytes can be played.
sample_rate
Output sample rate in Hz. One of 8000/16000/22050/24000/44100/48000. Defaults to 24000 (xAI's recommended rate).
optimize_streaming_latency
0 (default, best quality) or 1 (lower time-to-first-audio with minor quality tradeoff).
text_normalization
When true, xAI normalizes written-form text (numbers, abbreviations, symbols) into spoken-form before synthesis.
base_url
WebSocket endpoint URL.

Speech tags: xAI supports inline expression tags ([pause], [long-pause], [laugh], [sigh], [breath], etc.) and wrapping style tags (<whisper>...</whisper>, <soft>, <loud>, <slow>, <fast>, <higher-pitch>, <lower-pitch>, <emphasis>, <singing>, <sing-song>, <laugh-speak>, <build-intensity>, <decrease-intensity>) directly inside the text you pass to synthesize(). No separate parameter is needed — the tags are sent verbatim as part of each text.delta message and parsed server-side.

Example::

    await tts.synthesize(
        "So I walked in and [pause] there it was. [laugh] Incredible!"
    )
    await tts.synthesize(
        "I need to tell you something. "
        "<whisper>It is a secret.</whisper> Pretty cool, right?"
    )

Caveat for streaming input: when synthesize() receives an
AsyncIterator[str] (e.g. LLM tokens), a single tag can be split across
two chunks ("[pa", "use]") which xAI will not recognize. Tags only work
reliably when an entire tag arrives within one text chunk.

Ancestors

  • videosdk.agents.tts.tts.TTS
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Gracefully clean up all resources.

    Cancels any utterance still awaiting completion, stops the receive task,
    then closes the WebSocket and its session. The session is closed even if
    closing the WebSocket raises.
    """
    await super().aclose()
    self._interrupted = True
    self._closed = True

    # Release a synthesize() call that is still waiting: once the receive
    # task is cancelled nothing else would ever resolve its future.
    pending = self._current_done_future
    if pending and not pending.done():
        pending.cancel()

    if self._receive_task and not self._receive_task.done():
        self._receive_task.cancel()
        try:
            await self._receive_task
        except (asyncio.CancelledError, Exception):
            pass
    self._receive_task = None

    try:
        if self._ws_connection and not self._ws_connection.closed:
            await self._ws_connection.close()
    finally:
        self._ws_connection = None
        try:
            if self._ws_session and not self._ws_session.closed:
                await self._ws_session.close()
        finally:
            self._ws_session = None

Gracefully clean up all resources.

async def interrupt(self) ‑> None
Expand source code
async def interrupt(self) -> None:
    """Silence the utterance currently being synthesized.

    Marks the instance as interrupted, interrupts the audio track when one is
    attached, and cancels the pending completion future if it is still
    unresolved. The WebSocket itself is not touched here.
    """
    self._interrupted = True

    track = self.audio_track
    if track:
        track.interrupt()

    pending = self._current_done_future
    if not pending or pending.done():
        return
    pending.cancel()

Stop emitting audio for the current synthesis. Keeps the WebSocket open so the next turn does not pay reconnect cost; in-flight audio frames received after this point are dropped via the done-future filter in `_stream_audio()`.

async def prewarm(self) ‑> None
Expand source code
async def prewarm(self) -> None:
    """Open the xAI WebSocket ahead of the first ``synthesize()`` call.

    Delegates to ``_ensure_ws_connection()``. Any failure is logged as a
    warning and not re-raised, so the caller can continue regardless.
    """
    try:
        await self._ensure_ws_connection()
    except Exception as err:
        logger.warning(f"xAI TTS prewarm failed (non-fatal): {err}")
    return None

Pre-establish the xAI WebSocket so the first synthesize() call does not pay the TLS + auth + upgrade cost. Safe to call repeatedly.

def reset_first_audio_tracking(self) ‑> None
Expand source code
def reset_first_audio_tracking(self) -> None:
    """Clear the first-chunk flag.

    With the flag cleared, the next audio chunk that is streamed counts as
    the first one again.
    """
    self._first_chunk_sent = False
    return None

Reset the first-audio-byte tracking state for the next synthesis turn.

async def synthesize(self,
text: AsyncIterator[Union[str, FlushMarker]] | str,
voice_id: Optional[str] = None,
**kwargs: Any) ‑> None
Expand source code
async def synthesize(
    self,
    text: AsyncIterator[Union[str, FlushMarker]] | str,
    voice_id: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Synthesize text to speech via xAI's bidirectional WebSocket API.

    ``FlushMarker`` segment markers in the input stream are silently dropped
    — xAI has no per-sentence flush primitive, and the server segments
    naturally on ``text.done``.

    Args:
        text: A complete string, or an async stream of text chunks.
        voice_id: Optional voice override. When it differs from the voice
            currently in use, it becomes the new voice and the open socket is
            closed so the next connection check re-opens it.
            NOTE(review): assumes the voice is fixed per connection — confirm
            against the connect code.
        **kwargs: Accepted for interface compatibility; not read here.

    Raises:
        Exception: Any failure while connecting, sending or receiving is
            emitted as an ``"error"`` event and re-raised.
    """
    # Stays None until this call installs its own completion future, so the
    # cleanup below never touches a future owned by another call.
    done_future: Optional[asyncio.Future[None]] = None
    try:
        if not self.audio_track or not self.loop:
            self.emit("error", "Audio track or event loop not set")
            return

        requested_voice: Optional[str] = None
        if voice_id:
            requested_voice = voice_id.lower()
            if requested_voice not in SUPPORTED_VOICES:
                self.emit(
                    "error",
                    f"voice_id must be one of {sorted(SUPPORTED_VOICES)}, got {voice_id}",
                )
                return

        async with self._synthesis_lock:
            self._interrupted = False
            self._first_chunk_sent = False

            # Apply the voice under the lock so an in-flight utterance is not
            # affected. Closing the live socket makes the connection check
            # below open a fresh one.
            if requested_voice and requested_voice != self._voice:
                self._voice = requested_voice
                live = self._ws_connection
                if live is not None and not live.closed:
                    await live.close()

            await self._ensure_ws_connection()
            if not self._ws_connection:
                raise RuntimeError("WebSocket connection is not available.")

            done_future = asyncio.get_running_loop().create_future()
            self._current_done_future = done_future

            async def _string_iterator(s: str) -> AsyncIterator[str]:
                yield s

            text_iterator = (
                _string_iterator(text) if isinstance(text, str) else text
            )

            send_task = asyncio.create_task(
                self._send_task(text_iterator, done_future)
            )

            try:
                await done_future
            finally:
                # Let the sender finish so it is never left running detached
                # from this call.
                if not send_task.done():
                    try:
                        await send_task
                    except Exception:
                        pass

    except Exception as e:
        self.emit("error", f"TTS synthesis failed: {e}")
        raise
    finally:
        # Only clear the slot if it still holds *our* future. A call that
        # returned early or was cancelled while waiting for the lock must not
        # wipe the future of the synthesis that is still running.
        if done_future is not None and self._current_done_future is done_future:
            self._current_done_future = None

Synthesize text to speech via xAI's bidirectional WebSocket API.

FlushMarker segment markers in the input stream are silently dropped — xAI has no per-sentence flush primitive, and the server segments naturally on text.done.

class XAITurnDetection (type: "Literal['server_vad'] | None" = 'server_vad',
threshold: float = 0.5,
prefix_padding_ms: int = 300,
silence_duration_ms: int = 200)
Expand source code
@dataclass
class XAITurnDetection:
    """Plain settings container for turn detection.

    NOTE(review): the field meanings below are inferred from the field names
    only; nothing in this section shows how they are consumed. Confirm against
    the code that reads this dataclass and the xAI API documentation.
    """

    # Detection mode. "server_vad" is the only accepted literal; None is
    # presumably "turn detection disabled" — TODO confirm.
    type: Literal["server_vad"] | None = "server_vad"
    # Presumably a voice-activity sensitivity threshold — TODO confirm range.
    threshold: float = 0.5
    # Milliseconds, per the name; presumably audio kept before detected speech.
    prefix_padding_ms: int = 300
    # Milliseconds, per the name; presumably silence required to end a turn.
    silence_duration_ms: int = 200

XAITurnDetection(type: "Literal['server_vad'] | None" = 'server_vad', threshold: 'float' = 0.5, prefix_padding_ms: 'int' = 300, silence_duration_ms: 'int' = 200)

Instance variables

var prefix_padding_ms : int
var silence_duration_ms : int
var threshold : float
var type : Literal['server_vad'] | None