Package videosdk.plugins.xai

Sub-modules

videosdk.plugins.xai.llm
videosdk.plugins.xai.xai_realtime

Classes

class XAILLM (*,
api_key: str | None = None,
model: str = 'grok-4-1-fast-non-reasoning',
base_url: str = 'https://api.x.ai/v1',
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_completion_tokens: int | None = None,
tools: List[Union[FunctionTool, Dict[str, Any]]] | None = None)
class XAILLM(LLM):
    """
    LLM Plugin for xAI (Grok) API.
    Supports Grok-4 and reasoning models with standard client-side function calling.
    """
    
    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = "grok-4-1-fast-non-reasoning", 
        base_url: str = "https://api.x.ai/v1",
        temperature: float = 0.7,
        tool_choice: ToolChoice = "auto",
        max_completion_tokens: int | None = None,
        tools: List[Union[FunctionTool, Dict[str, Any]]] | None = None,
    ) -> None:
        """Initialize the xAI LLM plugin.

        Args:
            api_key (Optional[str], optional): xAI API key. Defaults to XAI_API_KEY env var.
            model (str): The model to use (e.g., "grok-4", "grok-4-1-fast").
            base_url (str): The base URL for the xAI API. Defaults to "https://api.x.ai/v1".
            temperature (float): The temperature to use. Defaults to 0.7.
            tool_choice (ToolChoice): The tool choice to use. Defaults to "auto".
            max_completion_tokens (Optional[int], optional): Maximum number of tokens to generate in the completion.
            tools (Optional[List], optional): List of FunctionTools to be available to the LLM.
        """
        super().__init__()
        self.api_key = api_key or os.getenv("XAI_API_KEY")
        if not self.api_key:
            raise ValueError("xAI API key must be provided either through api_key parameter or XAI_API_KEY environment variable")
        
        self.model = model
        self.temperature = temperature
        self.tool_choice = tool_choice
        self.max_completion_tokens = max_completion_tokens
        self.tools = tools or []
        self._cancelled = False
        
        self._client = openai.AsyncOpenAI(
            api_key=self.api_key,
            base_url=base_url,
            max_retries=0,
            http_client=httpx.AsyncClient(
                timeout=httpx.Timeout(connect=15.0, read=60.0, write=5.0, pool=5.0),
                follow_redirects=True,
                limits=httpx.Limits(
                    max_connections=50,
                    max_keepalive_connections=50,
                    keepalive_expiry=120,
                ),
            ),
        )

    async def chat(
        self,
        messages: ChatContext,
        tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
        conversational_graph: Any | None = None,
        **kwargs: Any
    ) -> AsyncIterator[LLMResponse]:
        """
        Implement chat functionality using xAI's API via OpenAI SDK compatibility.
        """
        self._cancelled = False
        
        def _format_content(content: Union[str, List[ChatContent]]):
            if isinstance(content, str):
                return content

            formatted_parts = []
            for part in content:
                if isinstance(part, str):
                    formatted_parts.append({"type": "text", "text": part})
                elif isinstance(part, ImageContent):
                    image_url_data = {"url": part.to_data_url()}
                    if part.inference_detail != "auto":
                        image_url_data["detail"] = part.inference_detail
                    formatted_parts.append(
                        {
                            "type": "image_url",
                            "image_url": image_url_data,
                        }
                    )
            return formatted_parts

            
        openai_messages = []
        for msg in messages.items:
            if msg is None:
                continue

            if isinstance(msg, ChatMessage):
                openai_messages.append({
                    "role": msg.role.value,
                    "content": _format_content(msg.content),
                    **({"name": msg.name} if hasattr(msg, "name") else {}),
                })
            elif isinstance(msg, FunctionCall):
                openai_messages.append({
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [{
                        "id": getattr(msg, "call_id", getattr(msg, "id", "call_unknown")),
                        "type": "function",
                        "function": {
                            "name": msg.name,
                            "arguments": msg.arguments
                        }
                    }]
                })
            elif isinstance(msg, FunctionCallOutput):
                openai_messages.append({
                    "role": "tool",
                    "tool_call_id": getattr(msg, "call_id", getattr(msg, "id", "call_unknown")),
                    "content": msg.output,
                })

        completion_params = {
            "model": self.model,
            "messages": openai_messages,
            "temperature": self.temperature,
            "stream": True,
            "max_tokens": self.max_completion_tokens,
        }
        
        if conversational_graph:
            completion_params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "conversational_graph_response",
                    "strict": True,
                    "schema": conversational_graph_schema
                }
            }

        combined_tools = (self.tools or []) + (tools or [])
        
        if combined_tools:
            formatted_tools = []
            for tool in combined_tools:
                if is_function_tool(tool):
                    try:
                        tool_schema = build_openai_schema(tool)
                        if "function" not in tool_schema:
                            inner_tool = {k: v for k, v in tool_schema.items() if k != "type"}
                            formatted_tools.append({
                                "type": "function",
                                "function": inner_tool
                            })
                        else:
                            formatted_tools.append(tool_schema)
                    except Exception as e:
                        self.emit("error", f"Failed to format tool {tool}: {e}")
                        continue
                elif isinstance(tool, dict):
                    formatted_tools.append(tool)
            
            if formatted_tools:
                completion_params["tools"] = formatted_tools
                completion_params["tool_choice"] = self.tool_choice

        completion_params.update(kwargs)

        try:
            response_stream = await self._client.chat.completions.create(**completion_params)
            
            current_content = ""
            current_tool_calls = {} 
            streaming_state = {
                "in_response": False,
                "response_start_index": -1,
                "yielded_content_length": 0
            }

            async for chunk in response_stream:
                if self._cancelled:
                    break
                
                if not chunk.choices:
                    continue
                    
                delta = chunk.choices[0].delta
                
                if delta.tool_calls:
                    for tool_call in delta.tool_calls:
                        idx = tool_call.index
                        if idx not in current_tool_calls:
                            current_tool_calls[idx] = {
                                "id": tool_call.id or "",
                                "name": tool_call.function.name or "",
                                "arguments": tool_call.function.arguments or ""
                            }
                        else:
                            if tool_call.function.name:
                                current_tool_calls[idx]["name"] += tool_call.function.name
                            if tool_call.function.arguments:
                                current_tool_calls[idx]["arguments"] += tool_call.function.arguments

                if delta.content is not None:
                    current_content += delta.content   
                    if conversational_graph:                     
                        for content_chunk in conversational_graph.stream_conversational_graph_response(current_content, streaming_state):
                            yield LLMResponse(content=content_chunk, role=ChatRole.ASSISTANT)
                    else:
                        yield LLMResponse(content=delta.content, role=ChatRole.ASSISTANT)

            if current_tool_calls and not self._cancelled:
                for idx in sorted(current_tool_calls.keys()):
                    tool_data = current_tool_calls[idx]
                    try:
                        args_str = tool_data["arguments"]
                        parsed_args = json.loads(args_str) 
                        
                        yield LLMResponse(
                            content="",
                            role=ChatRole.ASSISTANT,
                            metadata={
                                "function_call": {
                                    "name": tool_data["name"],
                                    "arguments": parsed_args
                                },
                                "tool_call_id": tool_data["id"]
                            }
                        )
                    except json.JSONDecodeError:
                        self.emit("error", f"Failed to parse function arguments for tool {tool_data['name']}")

            if current_content and not self._cancelled and conversational_graph:
                try:
                    parsed_json = json.loads(current_content.strip())
                    yield LLMResponse(
                        content="",
                        role=ChatRole.ASSISTANT,
                        metadata=parsed_json
                    )
                except json.JSONDecodeError:
                    pass

        except Exception as e:
            if not self._cancelled:
                self.emit("error", e)
            raise

    async def cancel_current_generation(self) -> None:
        self._cancelled = True

    async def aclose(self) -> None:
        """Cleanup resources"""
        await self.cancel_current_generation()
        if self._client:
            await self._client.close()
        await super().aclose()

LLM Plugin for xAI (Grok) API. Supports Grok-4 and reasoning models with standard client-side function calling.

Initialize the xAI LLM plugin.

Args

api_key : Optional[str], optional
xAI API key. Defaults to XAI_API_KEY env var.
model : str
The model to use (e.g., "grok-4", "grok-4-1-fast").
base_url : str
The base URL for the xAI API. Defaults to "https://api.x.ai/v1".
temperature : float
The temperature to use. Defaults to 0.7.
tool_choice : ToolChoice
The tool choice to use. Defaults to "auto".
max_completion_tokens : Optional[int], optional
Maximum number of tokens to generate in the completion.
tools : Optional[List], optional
List of FunctionTools to be available to the LLM.
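
For orientation, a minimal construction sketch; it assumes the package re-exports XAILLM at the top level and that XAI_API_KEY is set in the environment:

from videosdk.plugins.xai import XAILLM

# api_key can be passed explicitly instead of relying on XAI_API_KEY.
llm = XAILLM(
    model="grok-4-1-fast-non-reasoning",
    temperature=0.7,
    max_completion_tokens=512,
)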

Ancestors

  • videosdk.agents.llm.llm.LLM
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Methods

async def aclose(self) ‑> None
async def aclose(self) -> None:
    """Cleanup resources"""
    await self.cancel_current_generation()
    if self._client:
        await self._client.close()
    await super().aclose()

Cleanup resources

async def cancel_current_generation(self) ‑> None
async def cancel_current_generation(self) -> None:
    self._cancelled = True

Cancel the current LLM generation if active. In this implementation, cancellation sets an internal flag that the streaming chat() loop checks before processing each chunk, so an in-flight stream stops at the next chunk boundary.

async def chat(self,
messages: ChatContext,
tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
conversational_graph: Any | None = None,
**kwargs: Any) ‑> AsyncIterator[videosdk.agents.llm.llm.LLMResponse]
async def chat(
    self,
    messages: ChatContext,
    tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
    conversational_graph: Any | None = None,
    **kwargs: Any
) -> AsyncIterator[LLMResponse]:
    """
    Implement chat functionality using xAI's API via OpenAI SDK compatibility.
    """
    self._cancelled = False
    
    def _format_content(content: Union[str, List[ChatContent]]):
        if isinstance(content, str):
            return content

        formatted_parts = []
        for part in content:
            if isinstance(part, str):
                formatted_parts.append({"type": "text", "text": part})
            elif isinstance(part, ImageContent):
                image_url_data = {"url": part.to_data_url()}
                if part.inference_detail != "auto":
                    image_url_data["detail"] = part.inference_detail
                formatted_parts.append(
                    {
                        "type": "image_url",
                        "image_url": image_url_data,
                    }
                )
        return formatted_parts

        
    openai_messages = []
    for msg in messages.items:
        if msg is None:
            continue

        if isinstance(msg, ChatMessage):
            openai_messages.append({
                "role": msg.role.value,
                "content": _format_content(msg.content),
                **({"name": msg.name} if hasattr(msg, "name") else {}),
            })
        elif isinstance(msg, FunctionCall):
            openai_messages.append({
                "role": "assistant",
                "content": None,
                "tool_calls": [{
                    "id": getattr(msg, "call_id", getattr(msg, "id", "call_unknown")),
                    "type": "function",
                    "function": {
                        "name": msg.name,
                        "arguments": msg.arguments
                    }
                }]
            })
        elif isinstance(msg, FunctionCallOutput):
            openai_messages.append({
                "role": "tool",
                "tool_call_id": getattr(msg, "call_id", getattr(msg, "id", "call_unknown")),
                "content": msg.output,
            })

    completion_params = {
        "model": self.model,
        "messages": openai_messages,
        "temperature": self.temperature,
        "stream": True,
        "max_tokens": self.max_completion_tokens,
    }
    
    if conversational_graph:
        completion_params["response_format"] = {
            "type": "json_schema",
            "json_schema": {
                "name": "conversational_graph_response",
                "strict": True,
                "schema": conversational_graph_schema
            }
        }

    combined_tools = (self.tools or []) + (tools or [])
    
    if combined_tools:
        formatted_tools = []
        for tool in combined_tools:
            if is_function_tool(tool):
                try:
                    tool_schema = build_openai_schema(tool)
                    if "function" not in tool_schema:
                        inner_tool = {k: v for k, v in tool_schema.items() if k != "type"}
                        formatted_tools.append({
                            "type": "function",
                            "function": inner_tool
                        })
                    else:
                        formatted_tools.append(tool_schema)
                except Exception as e:
                    self.emit("error", f"Failed to format tool {tool}: {e}")
                    continue
            elif isinstance(tool, dict):
                formatted_tools.append(tool)
        
        if formatted_tools:
            completion_params["tools"] = formatted_tools
            completion_params["tool_choice"] = self.tool_choice

    completion_params.update(kwargs)

    try:
        response_stream = await self._client.chat.completions.create(**completion_params)
        
        current_content = ""
        current_tool_calls = {} 
        streaming_state = {
            "in_response": False,
            "response_start_index": -1,
            "yielded_content_length": 0
        }

        async for chunk in response_stream:
            if self._cancelled:
                break
            
            if not chunk.choices:
                continue
                
            delta = chunk.choices[0].delta
            
            if delta.tool_calls:
                for tool_call in delta.tool_calls:
                    idx = tool_call.index
                    if idx not in current_tool_calls:
                        current_tool_calls[idx] = {
                            "id": tool_call.id or "",
                            "name": tool_call.function.name or "",
                            "arguments": tool_call.function.arguments or ""
                        }
                    else:
                        if tool_call.function.name:
                            current_tool_calls[idx]["name"] += tool_call.function.name
                        if tool_call.function.arguments:
                            current_tool_calls[idx]["arguments"] += tool_call.function.arguments

            if delta.content is not None:
                current_content += delta.content   
                if conversational_graph:                     
                    for content_chunk in conversational_graph.stream_conversational_graph_response(current_content, streaming_state):
                        yield LLMResponse(content=content_chunk, role=ChatRole.ASSISTANT)
                else:
                    yield LLMResponse(content=delta.content, role=ChatRole.ASSISTANT)

        if current_tool_calls and not self._cancelled:
            for idx in sorted(current_tool_calls.keys()):
                tool_data = current_tool_calls[idx]
                try:
                    args_str = tool_data["arguments"]
                    parsed_args = json.loads(args_str) 
                    
                    yield LLMResponse(
                        content="",
                        role=ChatRole.ASSISTANT,
                        metadata={
                            "function_call": {
                                "name": tool_data["name"],
                                "arguments": parsed_args
                            },
                            "tool_call_id": tool_data["id"]
                        }
                    )
                except json.JSONDecodeError:
                    self.emit("error", f"Failed to parse function arguments for tool {tool_data['name']}")

        if current_content and not self._cancelled and conversational_graph:
            try:
                parsed_json = json.loads(current_content.strip())
                yield LLMResponse(
                    content="",
                    role=ChatRole.ASSISTANT,
                    metadata=parsed_json
                )
            except json.JSONDecodeError:
                pass

    except Exception as e:
        if not self._cancelled:
            self.emit("error", e)
        raise

Implement chat functionality using xAI's API via OpenAI SDK compatibility.
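
A hedged usage sketch for consuming the stream. The ChatContext/ChatRole import path and the add_message helper are assumptions about the videosdk-agents API, not verbatim from this module:

import asyncio
from videosdk.agents import ChatContext, ChatRole  # assumed import path
from videosdk.plugins.xai import XAILLM

async def main() -> None:
    llm = XAILLM()
    ctx = ChatContext()
    # add_message is assumed; any equivalent way of appending a user
    # ChatMessage to the context works the same.
    ctx.add_message(role=ChatRole.USER, content="Summarize RFC 2119 in one line.")
    async for resp in llm.chat(ctx):
        if getattr(resp, "metadata", None) and "function_call" in resp.metadata:
            print("tool call:", resp.metadata["function_call"])  # accumulated tool call
        else:
            print(resp.content, end="", flush=True)  # streamed text delta
    await llm.aclose()

asyncio.run(main())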

class XAIRealtime (*,
api_key: str | None = None,
model: str = 'grok-4-1-fast-non-reasoning',
config: XAIRealtimeConfig | None = None,
base_url: str | None = None)
class XAIRealtime(RealtimeBaseModel[XAIEventTypes]):
    """xAI's Grok realtime model implementation"""

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = "grok-4-1-fast-non-reasoning",
        config: XAIRealtimeConfig | None = None,
        base_url: str | None = None,
    ) -> None:
        super().__init__()
        self.api_key = api_key or os.getenv("XAI_API_KEY")
        self.base_url = base_url or XAI_BASE_URL
        
        if not self.api_key:
            self.emit("error", "XAI_API_KEY is required")
            raise ValueError("XAI_API_KEY is required")

        self.config: XAIRealtimeConfig = config or XAIRealtimeConfig()
        
        self._http_session: Optional[aiohttp.ClientSession] = None
        self._session: Optional[XAISession] = None
        self._closing = False
        self._instructions: str = "You are a helpful assistant."
        self._tools: List[FunctionTool] = []
        self._formatted_tools: List[Dict[str, Any]] = []
        
        self.loop = None
        self.audio_track: Optional[CustomAudioStreamTrack] = None
        self.input_sample_rate = INPUT_SAMPLE_RATE
        self.target_sample_rate = DEFAULT_SAMPLE_RATE
        self._agent_speaking = False
        self._current_response_id: str | None = None
        self._is_configured = False
        self._session_ready = asyncio.Event()
        self._has_unprocessed_tool_outputs = False
        self._generated_text_in_current_response = False

    def set_agent(self, agent: Agent) -> None:
        if agent.instructions:
            self._instructions = agent.instructions
        self._tools = agent.tools
        self._formatted_tools = self._format_tools_for_session(self._tools)

    async def connect(self) -> None:
        """Establish WebSocket connection to xAI"""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        self._session = await self._create_session(self.base_url, headers)
        await self._handle_websocket(self._session)
        
        if self.audio_track:
            from fractions import Fraction
            self.audio_track.sample_rate = self.target_sample_rate
            self.audio_track.time_base_fraction = Fraction(1, self.target_sample_rate)
            self.audio_track.samples = int(0.02 * self.target_sample_rate)
            self.audio_track.chunk_size = int(self.audio_track.samples * getattr(self.audio_track, "channels", 1) * getattr(self.audio_track, "sample_width", 2))
        
        try:
            await asyncio.wait_for(self._session_ready.wait(), timeout=10.0)
            logger.info("xAI session configuration complete")
        except asyncio.TimeoutError:
            logger.warning("Timeout waiting for xAI session configuration")

    async def _create_session(self, url: str, headers: dict) -> XAISession:
        if not self._http_session:
            self._http_session = aiohttp.ClientSession()
            
        try:
            ws = await self._http_session.ws_connect(
                url,
                headers=headers,
                autoping=True,
                heartbeat=10,
                timeout=30,
            )
        except Exception as e:
            self.emit("error", f"Connection failed: {e}")
            raise

        msg_queue: asyncio.Queue = asyncio.Queue()
        tasks: list[asyncio.Task] = []
        self._closing = False

        return XAISession(ws=ws, msg_queue=msg_queue, tasks=tasks)

    async def _send_initial_config(self) -> None:
        """Send session.update to configure voice and audio"""
        if not self._session:
            return

        tools_config = []
        
        if self._formatted_tools:
            tools_config.extend(self._formatted_tools)
            
        if self.config.enable_web_search:
            tools_config.append({"type": "web_search"})
        
        if self.config.enable_x_search or self.config.allowed_x_handles:
            x_search_config = {"type": "x_search"}
            if self.config.allowed_x_handles:
                logger.info(f"Allowed xAI handles: {self.config.allowed_x_handles}")
                x_search_config["allowed_x_handles"] = self.config.allowed_x_handles
            tools_config.append(x_search_config)

        if self.config.collection_id:
            tools_config.append({
                "type": "file_search",
                "vector_store_ids": [self.config.collection_id],
                "max_num_results": self.config.max_num_results,
            })

        session_update = {
            "type": "session.update",
            "session": {
                "instructions": self._instructions,
                "voice": self.config.voice,
                "audio": {
                    "input": {
                        "format": {
                            "type": "audio/pcm",
                            "rate": self.target_sample_rate
                        }
                    },
                    "output": {
                        "format": {
                            "type": "audio/pcm",
                            "rate": self.target_sample_rate
                        }
                    }
                },
                "turn_detection": {
                    "type": "server_vad"
                },
                "tools": tools_config if tools_config else None
            }
        }

        await self.send_event(session_update)

    async def handle_audio_input(self, audio_data: bytes) -> None:
        """Process incoming audio: Resample 48k -> target (usually 24k) and send to xAI"""
        if not self._session or self._closing:
            return

        if "audio" not in self.config.modalities:
            return

        if self.current_utterance and not self.current_utterance.is_interruptible:
            return

        try:
            raw_audio = np.frombuffer(audio_data, dtype=np.int16)
            
            if len(raw_audio) >= 1920 and len(raw_audio) % 2 == 0:
                raw_audio = (raw_audio.reshape(-1, 2).astype(np.int32).mean(axis=1)).astype(np.int16)
            
            if self.input_sample_rate != self.target_sample_rate:
                num_samples = int(len(raw_audio) * self.target_sample_rate / self.input_sample_rate)
                float_audio = raw_audio.astype(np.float32)
                resampled_audio = signal.resample(float_audio, num_samples).astype(np.int16)
            else:
                resampled_audio = raw_audio

            base64_audio = base64.b64encode(resampled_audio.tobytes()).decode("utf-8")
            
            if not hasattr(self, "_audio_log_counter"):
                self._audio_log_counter = 0
            self._audio_log_counter += 1
            if self._audio_log_counter % 100 == 0:
                rms = np.sqrt(np.mean(resampled_audio.astype(np.float32)**2))
                logger.info(f"xAI Audio: Sent chunk {self._audio_log_counter}, samples={len(resampled_audio)}, RMS={rms:.2f}")

            await self.send_event({
                "type": "input_audio_buffer.append",
                "audio": base64_audio
            })
        except Exception as e:
            logger.error(f"Error processing audio input: {e}")

    async def handle_video_input(self, video_data: Any) -> None:
        """xAI Voice API currently does not document direct video stream support in this endpoint."""
        pass

    async def send_message(self, message: str) -> None:
        """Send text message to trigger audio response"""
        await self.send_event({"type": "input_audio_buffer.commit"})
        await self.send_event({
            "type": "conversation.item.create",
            "item": {
                "type": "message",
                "role": "user",
                "content": [{
                    "type": "input_text", 
                    "text": message
                }],
            }
        })
        await self.create_response()

    async def create_response(self) -> None:
        """Trigger a response from the model"""
        await self.send_event({
            "type": "response.create"
        })

    async def send_text_message(self, message: str) -> None:
        """Send text message (same as send_message for xAI flow)"""
        await self.send_message(message)

    async def interrupt(self) -> None:
        """Interrupt current generation"""
        if self._session and not self._closing:
            if self.current_utterance and not self.current_utterance.is_interruptible:
                return
            await realtime_metrics_collector.set_interrupted()
            
        if self.audio_track:
            self.audio_track.interrupt()
        
        if self._agent_speaking:
            self.emit("agent_speech_ended", {})
            self._agent_speaking = False

    async def _handle_websocket(self, session: XAISession) -> None:
        session.tasks.extend([
            asyncio.create_task(self._send_loop(session), name="xai_send"),
            asyncio.create_task(self._receive_loop(session), name="xai_recv"),
        ])

    async def _send_loop(self, session: XAISession) -> None:
        try:
            while not self._closing:
                msg = await session.msg_queue.get()
                if isinstance(msg, dict):
                    logger.debug(f"Sending xAI event: {msg.get('type')}")
                    await session.ws.send_json(msg)
                else:
                    await session.ws.send_str(str(msg))
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"xAI Send loop error: {e}")
            self.emit("error", f"Send loop error: {e}")

    async def _receive_loop(self, session: XAISession) -> None:
        try:
            while not self._closing:
                msg = await session.ws.receive()
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    await self._handle_event(data)
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("xAI WebSocket closed")
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"xAI WebSocket error: {session.ws.exception()}")
                    break
        except Exception as e:
            logger.error(f"xAI Receive loop error: {e}")
            self.emit("error", f"Receive loop error: {e}")
        finally:
            logger.info("xAI Receive loop finished, closing session")
            await self.aclose()

    async def _handle_event(self, data: dict) -> None:
        event_type = data.get("type")
        try:
            if event_type == "conversation.created":
                if not self._is_configured:
                    await self._send_initial_config()
                    self._is_configured = True
            elif event_type == "input_audio_buffer.speech_started":
                await self._handle_speech_started()
            elif event_type == "input_audio_buffer.speech_stopped":
                await self._handle_speech_stopped()
            elif event_type == "session.updated":
                logger.info("xAI Session updated successfully")
                self._session_ready.set()
            elif event_type == "response.created":
                logger.info(f"Response created: {data.get('response', {}).get('id')}")
                self._generated_text_in_current_response = False
            elif event_type == "response.output_item.added":
                logger.info(f"Output item added: {data.get('item', {}).get('id')}")
            elif event_type == "response.output_audio.delta":
                await self._handle_audio_delta(data)
            elif event_type == "response.output_audio_transcript.delta":
                await self._handle_transcript_delta(data)
            elif event_type == "response.output_audio_transcript.done":
                await self._handle_transcript_done(data)
            elif event_type == "conversation.item.input_audio_transcription.completed":
                await self._handle_input_audio_transcription_completed(data)
            elif event_type == "response.function_call_arguments.done":
                await self._handle_function_call(data)
            elif event_type == "response.done":
                await self._handle_response_done()
            elif event_type == "error":
                logger.error(f"xAI Error: {data}")
                
        except Exception as e:
            logger.error(f"Error handling event {event_type}: {e}")
            traceback.print_exc()

    async def _handle_speech_started(self) -> None:
        logger.info("xAI User speech started")
        self.emit("user_speech_started", {"type": "done"})
        await realtime_metrics_collector.set_user_speech_start()
        
        if self.current_utterance and not self.current_utterance.is_interruptible:
            return
            
        await self.interrupt()

    async def _handle_speech_stopped(self) -> None:
        logger.info("xAI User speech stopped")
        await realtime_metrics_collector.set_user_speech_end()
        self.emit("user_speech_ended", {})

    async def _handle_audio_delta(self, data: dict) -> None:
        delta = data.get("delta")
        if not delta:
            return

        if not self._agent_speaking:
            await realtime_metrics_collector.set_agent_speech_start()
            self._agent_speaking = True
            self.emit("agent_speech_started", {})

        if self.audio_track and self.loop:
            audio_bytes = base64.b64decode(delta)
            asyncio.create_task(self.audio_track.add_new_bytes(audio_bytes))

    async def _handle_transcript_delta(self, data: dict) -> None:
        delta = data.get("delta", "")
        if delta:
            self._generated_text_in_current_response = True
            if not hasattr(self, "_current_transcript"):
                self._current_transcript = ""
            self._current_transcript += delta

    async def _handle_transcript_done(self, data: dict) -> None:
        pass

    async def _handle_input_audio_transcription_completed(self, data: dict) -> None:
        """Handle input audio transcription completion for user transcript"""
        transcript = data.get("transcript", "")
        if transcript:
            logger.info(f"xAI User transcript: {transcript}")
            await realtime_metrics_collector.set_user_transcript(transcript)
            try:
                self.emit(
                    "realtime_model_transcription",
                    {"role": "user", "text": transcript, "is_final": True},
                )
            except Exception:
                pass

    async def _handle_response_done(self) -> None:
        if hasattr(self, "_current_transcript") and self._current_transcript:
            logger.info(f"xAI Agent response: {self._current_transcript}")
            await realtime_metrics_collector.set_agent_response(self._current_transcript)
            global_event_emitter.emit(
                "text_response",
                {"text": self._current_transcript, "type": "done"},
            )
            self._current_transcript = ""

        logger.info("xAI Agent speech ended")
        self.emit("agent_speech_ended", {})
        await realtime_metrics_collector.set_agent_speech_end(timeout=1.0)
        self._agent_speaking = False

        if self._has_unprocessed_tool_outputs and not self._generated_text_in_current_response:
            logger.info("xAI: Triggering follow-up response for tool outputs")
            self._has_unprocessed_tool_outputs = False
            await self.create_response()
        else:
            self._has_unprocessed_tool_outputs = False

    async def _handle_function_call(self, data: dict) -> None:
        """Handle tool execution flow for xAI"""
        name = data.get("name")
        call_id = data.get("call_id")
        args_str = data.get("arguments")
        
        if not name or not args_str:
            return

        try:
            arguments = json.loads(args_str)
            logger.info(f"Executing tool: {name} with args: {arguments}")
            await realtime_metrics_collector.add_tool_call(name)
            result = None
            found = False
            for tool in self._tools:
                info = get_tool_info(tool)
                if info.name == name:
                    result = await tool(**arguments)
                    found = True
                    break
            
            if not found:
                logger.warning(f"Tool {name} not found")
                result = {"error": "Tool not found"}

            await self.send_event({
                "type": "conversation.item.create",
                "item": {
                    "type": "function_call_output",
                    "call_id": call_id,
                    "output": json.dumps(result)
                }
            })

            if found:
                self._has_unprocessed_tool_outputs = True

        except Exception as e:
            logger.error(f"Error executing function {name}: {e}")

    async def send_event(self, event: Dict[str, Any]) -> None:
        if self._session and not self._closing:
            await self._session.msg_queue.put(event)

    def _format_tools_for_session(self, tools: List[FunctionTool]) -> List[Dict[str, Any]]:
        """Format tools using OpenAI schema builder (xAI is compatible)"""
        formatted = []
        for tool in tools:
            if is_function_tool(tool):
                try:
                    schema = build_openai_schema(tool)
                    formatted.append(schema)
                except Exception as e:
                    logger.error(f"Failed to format tool {tool}: {e}")
        return formatted

    async def aclose(self) -> None:
        """Cleanup resources"""
        if self._closing:
            return
        
        self._closing = True
        
        if self._session:
            for task in self._session.tasks:
                if not task.done():
                    task.cancel()
            
            if not self._session.ws.closed:
                await self._session.ws.close()
                
        if self._http_session and not self._http_session.closed:
            await self._http_session.close()

        if self.audio_track and hasattr(self.audio_track, "cleanup"):
            await self.audio_track.cleanup()

xAI's Grok realtime model implementation

Initialize the realtime model
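
In normal use the agent pipeline drives this class (calling set_agent() and connect() for you), but a standalone sketch clarifies the lifecycle; the sleep is illustrative, not part of the API:

import asyncio
from videosdk.plugins.xai import XAIRealtime, XAIRealtimeConfig

async def main() -> None:
    model = XAIRealtime(config=XAIRealtimeConfig(voice="Ara", enable_web_search=True))
    await model.connect()  # opens the WebSocket and waits for session.updated
    await model.send_message("Say hello in one sentence.")
    await asyncio.sleep(5.0)  # give audio/transcript events time to arrive
    await model.aclose()

asyncio.run(main())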

Ancestors

  • videosdk.agents.realtime_base_model.RealtimeBaseModel
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic
  • abc.ABC

Methods

async def aclose(self) ‑> None
async def aclose(self) -> None:
    """Cleanup resources"""
    if self._closing:
        return
    
    self._closing = True
    
    if self._session:
        for task in self._session.tasks:
            if not task.done():
                task.cancel()
        
        if not self._session.ws.closed:
            await self._session.ws.close()
            
    if self._http_session and not self._http_session.closed:
        await self._http_session.close()

    if self.audio_track and hasattr(self.audio_track, "cleanup"):
        await self.audio_track.cleanup()

Cleanup resources

async def connect(self) ‑> None
async def connect(self) -> None:
    """Establish WebSocket connection to xAI"""
    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json",
    }

    self._session = await self._create_session(self.base_url, headers)
    await self._handle_websocket(self._session)
    
    if self.audio_track:
        from fractions import Fraction
        self.audio_track.sample_rate = self.target_sample_rate
        self.audio_track.time_base_fraction = Fraction(1, self.target_sample_rate)
        self.audio_track.samples = int(0.02 * self.target_sample_rate)
        self.audio_track.chunk_size = int(self.audio_track.samples * getattr(self.audio_track, "channels", 1) * getattr(self.audio_track, "sample_width", 2))
    
    try:
        await asyncio.wait_for(self._session_ready.wait(), timeout=10.0)
        logger.info("xAI session configuration complete")
    except asyncio.TimeoutError:
        logger.warning("Timeout waiting for xAI session configuration")

Establish WebSocket connection to xAI

async def create_response(self) ‑> None
async def create_response(self) -> None:
    """Trigger a response from the model"""
    await self.send_event({
        "type": "response.create"
    })

Trigger a response from the model

async def handle_audio_input(self, audio_data: bytes) ‑> None
async def handle_audio_input(self, audio_data: bytes) -> None:
    """Process incoming audio: Resample 48k -> target (usually 24k) and send to xAI"""
    if not self._session or self._closing:
        return

    if "audio" not in self.config.modalities:
        return

    if self.current_utterance and not self.current_utterance.is_interruptible:
        return

    try:
        raw_audio = np.frombuffer(audio_data, dtype=np.int16)
        
        if len(raw_audio) >= 1920 and len(raw_audio) % 2 == 0:
            raw_audio = (raw_audio.reshape(-1, 2).astype(np.int32).mean(axis=1)).astype(np.int16)
        
        if self.input_sample_rate != self.target_sample_rate:
            num_samples = int(len(raw_audio) * self.target_sample_rate / self.input_sample_rate)
            float_audio = raw_audio.astype(np.float32)
            resampled_audio = signal.resample(float_audio, num_samples).astype(np.int16)
        else:
            resampled_audio = raw_audio

        base64_audio = base64.b64encode(resampled_audio.tobytes()).decode("utf-8")
        
        if not hasattr(self, "_audio_log_counter"):
            self._audio_log_counter = 0
        self._audio_log_counter += 1
        if self._audio_log_counter % 100 == 0:
            rms = np.sqrt(np.mean(resampled_audio.astype(np.float32)**2))
            logger.info(f"xAI Audio: Sent chunk {self._audio_log_counter}, samples={len(resampled_audio)}, RMS={rms:.2f}")

        await self.send_event({
            "type": "input_audio_buffer.append",
            "audio": base64_audio
        })
    except Exception as e:
        logger.error(f"Error processing audio input: {e}")

Process incoming audio: Resample 48k -> target (usually 24k) and send to xAI
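
To make the arithmetic concrete, a standalone sketch of the same downmix-and-resample path (48 kHz stereo int16 in, 24 kHz mono int16 out) using the same numpy/scipy calls as the method above; the helper name is illustrative:

import numpy as np
from scipy import signal

def downmix_and_resample(pcm: bytes, in_rate: int = 48000, out_rate: int = 24000) -> bytes:
    samples = np.frombuffer(pcm, dtype=np.int16)
    # Interleaved stereo -> mono by averaging channel pairs (the plugin also
    # gates this on a minimum chunk length of 1920 samples).
    if len(samples) % 2 == 0:
        samples = samples.reshape(-1, 2).astype(np.int32).mean(axis=1).astype(np.int16)
    # Fourier-method resampling, matching scipy.signal.resample in the plugin.
    n_out = int(len(samples) * out_rate / in_rate)
    return signal.resample(samples.astype(np.float32), n_out).astype(np.int16).tobytes()

# 20 ms at 48 kHz stereo = 960 frames x 2 channels x 2 bytes = 3840 bytes in;
# out: 480 mono samples at 24 kHz = 960 bytes.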

async def handle_video_input(self, video_data: Any) ‑> None
async def handle_video_input(self, video_data: Any) -> None:
    """xAI Voice API currently does not document direct video stream support in this endpoint."""
    pass

xAI Voice API currently does not document direct video stream support in this endpoint.

async def interrupt(self) ‑> None
async def interrupt(self) -> None:
    """Interrupt current generation"""
    if self._session and not self._closing:
        if self.current_utterance and not self.current_utterance.is_interruptible:
            return
        await realtime_metrics_collector.set_interrupted()
        
    if self.audio_track:
        self.audio_track.interrupt()
    
    if self._agent_speaking:
        self.emit("agent_speech_ended", {})
        self._agent_speaking = False

Interrupt current generation

async def send_event(self, event: Dict[str, Any]) ‑> None
async def send_event(self, event: Dict[str, Any]) -> None:
    if self._session and not self._closing:
        await self._session.msg_queue.put(event)

async def send_message(self, message: str) ‑> None
async def send_message(self, message: str) -> None:
    """Send text message to trigger audio response"""
    await self.send_event({"type": "input_audio_buffer.commit"})
    await self.send_event({
        "type": "conversation.item.create",
        "item": {
            "type": "message",
            "role": "user",
            "content": [{
                "type": "input_text", 
                "text": message
            }],
        }
    })
    await self.create_response()

Send text message to trigger audio response

async def send_text_message(self, message: str) ‑> None
async def send_text_message(self, message: str) -> None:
    """Send text message (same as send_message for xAI flow)"""
    await self.send_message(message)

Send text message (same as send_message for xAI flow)

def set_agent(self, agent: Agent) ‑> None
def set_agent(self, agent: Agent) -> None:
    if agent.instructions:
        self._instructions = agent.instructions
    self._tools = agent.tools
    self._formatted_tools = self._format_tools_for_session(self._tools)

class XAIRealtimeConfig (voice: XAIVoice = 'Ara',
instructions: str | None = None,
turn_detection: XAITurnDetection | None = <factory>,
modalities: List[str] = <factory>,
enable_web_search: bool = False,
enable_x_search: bool = False,
allowed_x_handles: List[str] | None = None,
collection_id: str | None = None,
max_num_results: int = 10)
@dataclass
class XAIRealtimeConfig:
    """Configuration for the xAI (Grok) Realtime API
    
    Args:
        voice: The voice identifier. Options: 'Ara', 'Rex', 'Sal', 'Eve', 'Leo'. Default: 'Ara'
        instructions: System instructions for the agent.
        turn_detection: Configuration for server-side VAD.
        modalities: Enabled I/O modalities. Default: ["text", "audio"].
        enable_web_search: Adds xAI's built-in web_search tool to the session.
        enable_x_search: Adds the x_search tool. Standard function tools are
               handled via the Agent class.
        allowed_x_handles: Restricts x_search results to the given X handles.
        collection_id: Vector store ID; when set, enables the file_search tool.
        max_num_results: Maximum number of file_search results. Default: 10.
    """
    voice: XAIVoice = DEFAULT_XAI_VOICE
    instructions: str | None = None
    turn_detection: XAITurnDetection | None = field(default_factory=XAITurnDetection)
    modalities: List[str] = field(default_factory=lambda: ["text", "audio"])
    enable_web_search: bool = False
    enable_x_search: bool = False
    allowed_x_handles: List[str] | None = None
    collection_id: str | None = None
    max_num_results: int = 10

Configuration for the xAI (Grok) Realtime API

Args

voice
The voice identifier. Options: 'Ara', 'Rex', 'Sal', 'Eve', 'Leo'. Default: 'Ara'
instructions
System instructions for the agent.
turn_detection
Configuration for server-side VAD.
modalities
Enabled I/O modalities. Default: ["text", "audio"].
enable_web_search
Adds xAI's built-in web_search tool to the session.
enable_x_search
Adds the x_search tool. Standard function tools are handled via the Agent class.
allowed_x_handles
Restricts x_search results to the given X handles.
collection_id
Vector store ID; when set, enables the file_search tool.
max_num_results
Maximum number of file_search results. Default: 10.

Instance variables

var allowed_x_handles : List[str] | None
var collection_id : str | None
var instructions : str | None
var max_num_results : int
var modalities : List[str]
var turn_detection : XAITurnDetection | None
var voice : Literal['Ara', 'Rex', 'Sal', 'Eve', 'Leo']

class XAITurnDetection (type: "Literal['server_vad'] | None" = 'server_vad',
threshold: float = 0.5,
prefix_padding_ms: int = 300,
silence_duration_ms: int = 200)
@dataclass
class XAITurnDetection:
    type: Literal["server_vad"] | None = "server_vad"
    threshold: float = 0.5
    prefix_padding_ms: int = 300
    silence_duration_ms: int = 200

XAITurnDetection(type: "Literal['server_vad'] | None" = 'server_vad', threshold: 'float' = 0.5, prefix_padding_ms: 'int' = 300, silence_duration_ms: 'int' = 200)

Instance variables

var prefix_padding_ms : int
var silence_duration_ms : int
var threshold : float
var type : Literal['server_vad'] | None
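
An illustrative tuning of the fields above (values are examples, not recommendations). Note that the session.update payload shown earlier sends only the detection type, so the numeric fields may not be forwarded by the current plugin version:

from videosdk.plugins.xai import XAIRealtimeConfig, XAITurnDetection

# Require higher VAD confidence and a longer pause before ending the user's turn.
vad = XAITurnDetection(threshold=0.7, prefix_padding_ms=300, silence_duration_ms=500)
config = XAIRealtimeConfig(turn_detection=vad)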