Package videosdk.plugins.xai
Sub-modules
- videosdk.plugins.xai.llm
- videosdk.plugins.xai.xai_realtime
Classes
class XAILLM (*,
api_key: str | None = None,
model: str = 'grok-4-1-fast-non-reasoning',
base_url: str = 'https://api.x.ai/v1',
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_completion_tokens: int | None = None,
tools: List[Union[FunctionTool, Dict[str, Any]]] | None = None)
Expand source code
class XAILLM(LLM):
    """
    LLM Plugin for xAI (Grok) API.
    Supports Grok-4 and reasoning models with standard client-side function calling.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = "grok-4-1-fast-non-reasoning",
        base_url: str = "https://api.x.ai/v1",
        temperature: float = 0.7,
        tool_choice: ToolChoice = "auto",
        max_completion_tokens: int | None = None,
        tools: List[Union[FunctionTool, Dict[str, Any]]] | None = None,
    ) -> None:
        """Initialize the xAI LLM plugin.

        Args:
            api_key (Optional[str], optional): xAI API key. Defaults to XAI_API_KEY env var.
            model (str): The model to use (e.g., "grok-4", "grok-4-1-fast").
            base_url (str): The base URL for the xAI API. Defaults to "https://api.x.ai/v1".
            temperature (float): The temperature to use. Defaults to 0.7.
            tool_choice (ToolChoice): The tool choice to use. Defaults to "auto".
            max_completion_tokens (Optional[int], optional): Max tokens.
            tools (Optional[List], optional): List of FunctionTools to be available to the LLM.
        """
        super().__init__()
        self.api_key = api_key or os.getenv("XAI_API_KEY")
        if not self.api_key:
            raise ValueError("xAI API key must be provided either through api_key parameter or XAI_API_KEY environment variable")
        self.model = model
        self.temperature = temperature
        self.tool_choice = tool_choice
        self.max_completion_tokens = max_completion_tokens
        self.tools = tools or []
        self._cancelled = False
        self._client = openai.AsyncOpenAI(
            api_key=self.api_key,
            base_url=base_url,
            max_retries=0,
            http_client=httpx.AsyncClient(
                timeout=httpx.Timeout(connect=15.0, read=60.0, write=5.0, pool=5.0),
                follow_redirects=True,
                limits=httpx.Limits(
                    max_connections=50,
                    max_keepalive_connections=50,
                    keepalive_expiry=120,
                ),
            ),
        )

    async def chat(
        self,
        messages: ChatContext,
        tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
        conversational_graph: Any | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[LLMResponse]:
        """
        Implement chat functionality using xAI's API via OpenAI SDK compatibility.
        """
        self._cancelled = False

        def _format_content(content: Union[str, List[ChatContent]]):
            if isinstance(content, str):
                return content
            formatted_parts = []
            for part in content:
                if isinstance(part, str):
                    formatted_parts.append({"type": "text", "text": part})
                elif isinstance(part, ImageContent):
                    image_url_data = {"url": part.to_data_url()}
                    if part.inference_detail != "auto":
                        image_url_data["detail"] = part.inference_detail
                    formatted_parts.append(
                        {
                            "type": "image_url",
                            "image_url": image_url_data,
                        }
                    )
            return formatted_parts

        openai_messages = []
        for msg in messages.items:
            if msg is None:
                continue
            if isinstance(msg, ChatMessage):
                openai_messages.append({
                    "role": msg.role.value,
                    "content": _format_content(msg.content),
                    **({"name": msg.name} if hasattr(msg, "name") else {}),
                })
            elif isinstance(msg, FunctionCall):
                openai_messages.append({
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [{
                        "id": getattr(msg, "call_id", getattr(msg, "id", "call_unknown")),
                        "type": "function",
                        "function": {
                            "name": msg.name,
                            "arguments": msg.arguments,
                        },
                    }],
                })
            elif isinstance(msg, FunctionCallOutput):
                openai_messages.append({
                    "role": "tool",
                    "tool_call_id": getattr(msg, "call_id", getattr(msg, "id", "call_unknown")),
                    "content": msg.output,
                })

        completion_params = {
            "model": self.model,
            "messages": openai_messages,
            "temperature": self.temperature,
            "stream": True,
            "max_tokens": self.max_completion_tokens,
        }

        if conversational_graph:
            completion_params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "conversational_graph_response",
                    "strict": True,
                    "schema": conversational_graph_schema,
                },
            }

        combined_tools = (self.tools or []) + (tools or [])
        if combined_tools:
            formatted_tools = []
            for tool in combined_tools:
                if is_function_tool(tool):
                    try:
                        tool_schema = build_openai_schema(tool)
                        if "function" not in tool_schema:
                            inner_tool = {k: v for k, v in tool_schema.items() if k != "type"}
                            formatted_tools.append({
                                "type": "function",
                                "function": inner_tool,
                            })
                        else:
                            formatted_tools.append(tool_schema)
                    except Exception as e:
                        self.emit("error", f"Failed to format tool {tool}: {e}")
                        continue
                elif isinstance(tool, dict):
                    formatted_tools.append(tool)
            if formatted_tools:
                completion_params["tools"] = formatted_tools
                completion_params["tool_choice"] = self.tool_choice

        completion_params.update(kwargs)

        try:
            response_stream = await self._client.chat.completions.create(**completion_params)
            current_content = ""
            current_tool_calls = {}
            streaming_state = {
                "in_response": False,
                "response_start_index": -1,
                "yielded_content_length": 0,
            }
            async for chunk in response_stream:
                if self._cancelled:
                    break
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta.tool_calls:
                    for tool_call in delta.tool_calls:
                        idx = tool_call.index
                        if idx not in current_tool_calls:
                            current_tool_calls[idx] = {
                                "id": tool_call.id or "",
                                "name": tool_call.function.name or "",
                                "arguments": tool_call.function.arguments or "",
                            }
                        else:
                            if tool_call.function.name:
                                current_tool_calls[idx]["name"] += tool_call.function.name
                            if tool_call.function.arguments:
                                current_tool_calls[idx]["arguments"] += tool_call.function.arguments
                if delta.content is not None:
                    current_content += delta.content
                    if conversational_graph:
                        for content_chunk in conversational_graph.stream_conversational_graph_response(current_content, streaming_state):
                            yield LLMResponse(content=content_chunk, role=ChatRole.ASSISTANT)
                    else:
                        yield LLMResponse(content=delta.content, role=ChatRole.ASSISTANT)

            if current_tool_calls and not self._cancelled:
                for idx in sorted(current_tool_calls.keys()):
                    tool_data = current_tool_calls[idx]
                    try:
                        args_str = tool_data["arguments"]
                        parsed_args = json.loads(args_str)
                        yield LLMResponse(
                            content="",
                            role=ChatRole.ASSISTANT,
                            metadata={
                                "function_call": {
                                    "name": tool_data["name"],
                                    "arguments": parsed_args,
                                },
                                "tool_call_id": tool_data["id"],
                            },
                        )
                    except json.JSONDecodeError:
                        self.emit("error", f"Failed to parse function arguments for tool {tool_data['name']}")

            if current_content and not self._cancelled and conversational_graph:
                try:
                    parsed_json = json.loads(current_content.strip())
                    yield LLMResponse(
                        content="",
                        role=ChatRole.ASSISTANT,
                        metadata=parsed_json,
                    )
                except json.JSONDecodeError:
                    pass
        except Exception as e:
            if not self._cancelled:
                self.emit("error", e)
            raise

    async def cancel_current_generation(self) -> None:
        self._cancelled = True

    async def aclose(self) -> None:
        """Cleanup resources"""
        await self.cancel_current_generation()
        if self._client:
            await self._client.close()
        await super().aclose()

LLM Plugin for the xAI (Grok) API. Supports Grok-4 and reasoning models with standard client-side function calling.
Initialize the xAI LLM plugin.
Args
- api_key (Optional[str]): xAI API key. Defaults to the XAI_API_KEY environment variable.
- model (str): The model to use (e.g., "grok-4", "grok-4-1-fast"). Defaults to "grok-4-1-fast-non-reasoning".
- base_url (str): Base URL for the xAI API. Defaults to "https://api.x.ai/v1".
- temperature (float): Sampling temperature. Defaults to 0.7.
- tool_choice (ToolChoice): Tool-choice strategy passed to the API. Defaults to "auto".
- max_completion_tokens (Optional[int]): Maximum number of tokens to generate.
- tools (Optional[List]): FunctionTools (or raw tool schema dicts) made available to the LLM.
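Example (illustrative): a minimal construction sketch. It assumes XAILLM is importable from videosdk.plugins.xai (as this package doc suggests) and that XAI_API_KEY is set in the environment; wiring a ChatContext is left to the surrounding agent framework.

import asyncio

from videosdk.plugins.xai import XAILLM  # import path assumed

async def main() -> None:
    # api_key falls back to the XAI_API_KEY environment variable.
    llm = XAILLM(
        model="grok-4-1-fast-non-reasoning",
        temperature=0.7,
        max_completion_tokens=1024,
    )
    try:
        ...  # hand `llm` to your agent/pipeline; see chat() below for streaming
    finally:
        await llm.aclose()  # cancels any generation and closes the HTTP client

asyncio.run(main())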
Ancestors
- videosdk.agents.llm.llm.LLM
- videosdk.agents.event_emitter.EventEmitter
- typing.Generic
Methods
async def aclose(self) ‑> None

Cleanup resources: cancels any in-flight generation and closes the underlying HTTP client.
async def cancel_current_generation(self) ‑> None

Cancel the current LLM generation if active. This implementation sets an internal flag that the streaming loop in chat() checks on each chunk; it does not raise.
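Example (illustrative): flag-based cancellation from the caller's side. consume() is a hypothetical helper that drains the chat() stream; the flag is checked once per streamed chunk, so cancellation takes effect on the next chunk.

import asyncio

async def consume(stream) -> None:
    # Hypothetical driver: drain the async iterator returned by chat().
    async for _response in stream:
        pass

async def cancel_after(llm, chat_ctx, delay: float = 1.0) -> None:
    task = asyncio.create_task(consume(llm.chat(messages=chat_ctx)))
    await asyncio.sleep(delay)
    await llm.cancel_current_generation()  # sets the internal _cancelled flag
    await task  # chat()'s streaming loop exits at the next chunk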
async def chat(self,
messages: ChatContext,
tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
conversational_graph: Any | None = None,
**kwargs: Any) ‑> AsyncIterator[videosdk.agents.llm.llm.LLMResponse]
Implement chat functionality using xAI's API via the OpenAI SDK compatibility layer. Streams LLMResponse chunks as they arrive; accumulated tool calls are yielded after the stream completes, with their JSON arguments already parsed. (Source shown in the class definition above.)
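Example (illustrative): consuming the stream. This sketch assumes a populated ChatContext (its construction is not shown in this reference) and uses only the LLMResponse fields visible in the source: content, role, and a metadata dict carrying function_call entries.

async def run_chat(llm, chat_ctx) -> None:
    async for response in llm.chat(messages=chat_ctx):
        meta = getattr(response, "metadata", None) or {}
        if "function_call" in meta:
            call = meta["function_call"]
            # `arguments` has already been JSON-decoded by the plugin.
            print(f"tool call: {call['name']}({call['arguments']})")
        elif response.content:
            print(response.content, end="", flush=True)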
class XAIRealtime (*,
api_key: str | None = None,
model: str = 'grok-4-1-fast-non-reasoning',
config: XAIRealtimeConfig | None = None,
base_url: str | None = None)
Expand source code
class XAIRealtime(RealtimeBaseModel[XAIEventTypes]):
    """xAI's Grok realtime model implementation"""

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = "grok-4-1-fast-non-reasoning",
        config: XAIRealtimeConfig | None = None,
        base_url: str | None = None,
    ) -> None:
        super().__init__()
        self.api_key = api_key or os.getenv("XAI_API_KEY")
        self.base_url = base_url or XAI_BASE_URL
        if not self.api_key:
            self.emit("error", "XAI_API_KEY is required")
            raise ValueError("XAI_API_KEY is required")
        self.config: XAIRealtimeConfig = config or XAIRealtimeConfig()
        self._http_session: Optional[aiohttp.ClientSession] = None
        self._session: Optional[XAISession] = None
        self._closing = False
        self._instructions: str = "You are a helpful assistant."
        self._tools: List[FunctionTool] = []
        self._formatted_tools: List[Dict[str, Any]] = []
        self.loop = None
        self.audio_track: Optional[CustomAudioStreamTrack] = None
        self.input_sample_rate = INPUT_SAMPLE_RATE
        self.target_sample_rate = DEFAULT_SAMPLE_RATE
        self._agent_speaking = False
        self._current_response_id: str | None = None
        self._is_configured = False
        self._session_ready = asyncio.Event()
        self._has_unprocessed_tool_outputs = False
        self._generated_text_in_current_response = False

    def set_agent(self, agent: Agent) -> None:
        if agent.instructions:
            self._instructions = agent.instructions
        self._tools = agent.tools
        self._formatted_tools = self._format_tools_for_session(self._tools)

    async def connect(self) -> None:
        """Establish WebSocket connection to xAI"""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        self._session = await self._create_session(self.base_url, headers)
        await self._handle_websocket(self._session)
        if self.audio_track:
            from fractions import Fraction
            self.audio_track.sample_rate = self.target_sample_rate
            self.audio_track.time_base_fraction = Fraction(1, self.target_sample_rate)
            self.audio_track.samples = int(0.02 * self.target_sample_rate)
            self.audio_track.chunk_size = int(
                self.audio_track.samples
                * getattr(self.audio_track, "channels", 1)
                * getattr(self.audio_track, "sample_width", 2)
            )
        try:
            await asyncio.wait_for(self._session_ready.wait(), timeout=10.0)
            logger.info("xAI session configuration complete")
        except asyncio.TimeoutError:
            logger.warning("Timeout waiting for xAI session configuration")

    async def _create_session(self, url: str, headers: dict) -> XAISession:
        if not self._http_session:
            self._http_session = aiohttp.ClientSession()
        try:
            ws = await self._http_session.ws_connect(
                url,
                headers=headers,
                autoping=True,
                heartbeat=10,
                timeout=30,
            )
        except Exception as e:
            self.emit("error", f"Connection failed: {e}")
            raise
        msg_queue: asyncio.Queue = asyncio.Queue()
        tasks: list[asyncio.Task] = []
        self._closing = False
        return XAISession(ws=ws, msg_queue=msg_queue, tasks=tasks)

    async def _send_initial_config(self) -> None:
        """Send session.update to configure voice and audio"""
        if not self._session:
            return
        tools_config = []
        if self._formatted_tools:
            tools_config.extend(self._formatted_tools)
        if self.config.enable_web_search:
            tools_config.append({"type": "web_search"})
        if self.config.enable_x_search or self.config.allowed_x_handles:
            x_search_config = {"type": "x_search"}
            if self.config.allowed_x_handles:
                logger.info(f"Allowed xAI handles: {self.config.allowed_x_handles}")
                x_search_config["allowed_x_handles"] = self.config.allowed_x_handles
            tools_config.append(x_search_config)
        if self.config.collection_id:
            tools_config.append({
                "type": "file_search",
                "vector_store_ids": [self.config.collection_id],
                "max_num_results": self.config.max_num_results,
            })
        session_update = {
            "type": "session.update",
            "session": {
                "instructions": self._instructions,
                "voice": self.config.voice,
                "audio": {
                    "input": {"format": {"type": "audio/pcm", "rate": self.target_sample_rate}},
                    "output": {"format": {"type": "audio/pcm", "rate": self.target_sample_rate}},
                },
                "turn_detection": {"type": "server_vad"},
                "tools": tools_config if tools_config else None,
            },
        }
        await self.send_event(session_update)

    async def handle_audio_input(self, audio_data: bytes) -> None:
        """Process incoming audio: Resample 48k -> target (usually 24k) and send to xAI"""
        if not self._session or self._closing:
            return
        if "audio" not in self.config.modalities:
            return
        if self.current_utterance and not self.current_utterance.is_interruptible:
            return
        try:
            raw_audio = np.frombuffer(audio_data, dtype=np.int16)
            if len(raw_audio) >= 1920 and len(raw_audio) % 2 == 0:
                raw_audio = (raw_audio.reshape(-1, 2).astype(np.int32).mean(axis=1)).astype(np.int16)
            if self.input_sample_rate != self.target_sample_rate:
                num_samples = int(len(raw_audio) * self.target_sample_rate / self.input_sample_rate)
                float_audio = raw_audio.astype(np.float32)
                resampled_audio = signal.resample(float_audio, num_samples).astype(np.int16)
            else:
                resampled_audio = raw_audio
            base64_audio = base64.b64encode(resampled_audio.tobytes()).decode("utf-8")
            if not hasattr(self, "_audio_log_counter"):
                self._audio_log_counter = 0
            self._audio_log_counter += 1
            if self._audio_log_counter % 100 == 0:
                rms = np.sqrt(np.mean(resampled_audio.astype(np.float32) ** 2))
                logger.info(f"xAI Audio: Sent chunk {self._audio_log_counter}, samples={len(resampled_audio)}, RMS={rms:.2f}")
            await self.send_event({
                "type": "input_audio_buffer.append",
                "audio": base64_audio,
            })
        except Exception as e:
            logger.error(f"Error processing audio input: {e}")

    async def handle_video_input(self, video_data: Any) -> None:
        """xAI Voice API currently does not document direct video stream support in this endpoint."""
        pass

    async def send_message(self, message: str) -> None:
        """Send text message to trigger audio response"""
        await self.send_event({"type": "input_audio_buffer.commit"})
        await self.send_event({
            "type": "conversation.item.create",
            "item": {
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": message}],
            },
        })
        await self.create_response()

    async def create_response(self) -> None:
        """Trigger a response from the model"""
        await self.send_event({"type": "response.create"})

    async def send_text_message(self, message: str) -> None:
        """Send text message (same as send_message for xAI flow)"""
        await self.send_message(message)

    async def interrupt(self) -> None:
        """Interrupt current generation"""
        if self._session and not self._closing:
            if self.current_utterance and not self.current_utterance.is_interruptible:
                return
            await realtime_metrics_collector.set_interrupted()
            if self.audio_track:
                self.audio_track.interrupt()
            if self._agent_speaking:
                self.emit("agent_speech_ended", {})
                self._agent_speaking = False

    async def _handle_websocket(self, session: XAISession) -> None:
        session.tasks.extend([
            asyncio.create_task(self._send_loop(session), name="xai_send"),
            asyncio.create_task(self._receive_loop(session), name="xai_recv"),
        ])

    async def _send_loop(self, session: XAISession) -> None:
        try:
            while not self._closing:
                msg = await session.msg_queue.get()
                if isinstance(msg, dict):
                    logger.debug(f"Sending xAI event: {msg.get('type')}")
                    await session.ws.send_json(msg)
                else:
                    await session.ws.send_str(str(msg))
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"xAI Send loop error: {e}")
            self.emit("error", f"Send loop error: {e}")

    async def _receive_loop(self, session: XAISession) -> None:
        try:
            while not self._closing:
                msg = await session.ws.receive()
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    await self._handle_event(data)
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("xAI WebSocket closed")
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"xAI WebSocket error: {session.ws.exception()}")
                    break
        except Exception as e:
            logger.error(f"xAI Receive loop error: {e}")
            self.emit("error", f"Receive loop error: {e}")
        finally:
            logger.info("xAI Receive loop finished, closing session")
            await self.aclose()

    async def _handle_event(self, data: dict) -> None:
        event_type = data.get("type")
        try:
            if event_type == "conversation.created":
                if not self._is_configured:
                    await self._send_initial_config()
                    self._is_configured = True
            elif event_type == "input_audio_buffer.speech_started":
                await self._handle_speech_started()
            elif event_type == "input_audio_buffer.speech_stopped":
                await self._handle_speech_stopped()
            elif event_type == "session.updated":
                logger.info("xAI Session updated successfully")
                self._session_ready.set()
            elif event_type == "response.created":
                logger.info(f"Response created: {data.get('response', {}).get('id')}")
                self._generated_text_in_current_response = False
            elif event_type == "response.output_item.added":
                logger.info(f"Output item added: {data.get('item', {}).get('id')}")
            elif event_type == "response.output_audio.delta":
                await self._handle_audio_delta(data)
            elif event_type == "response.output_audio_transcript.delta":
                await self._handle_transcript_delta(data)
            elif event_type == "response.output_audio_transcript.done":
                await self._handle_transcript_done(data)
            elif event_type == "conversation.item.input_audio_transcription.completed":
                await self._handle_input_audio_transcription_completed(data)
            elif event_type == "response.function_call_arguments.done":
                await self._handle_function_call(data)
            elif event_type == "response.done":
                await self._handle_response_done()
            elif event_type == "error":
                logger.error(f"xAI Error: {data}")
        except Exception as e:
            logger.error(f"Error handling event {event_type}: {e}")
            traceback.print_exc()

    async def _handle_speech_started(self) -> None:
        logger.info("xAI User speech started")
        self.emit("user_speech_started", {"type": "done"})
        await realtime_metrics_collector.set_user_speech_start()
        if self.current_utterance and not self.current_utterance.is_interruptible:
            return
        await self.interrupt()

    async def _handle_speech_stopped(self) -> None:
        logger.info("xAI User speech stopped")
        await realtime_metrics_collector.set_user_speech_end()
        self.emit("user_speech_ended", {})

    async def _handle_audio_delta(self, data: dict) -> None:
        delta = data.get("delta")
        if not delta:
            return
        if not self._agent_speaking:
            await realtime_metrics_collector.set_agent_speech_start()
            self._agent_speaking = True
            self.emit("agent_speech_started", {})
        if self.audio_track and self.loop:
            audio_bytes = base64.b64decode(delta)
            asyncio.create_task(self.audio_track.add_new_bytes(audio_bytes))

    async def _handle_transcript_delta(self, data: dict) -> None:
        delta = data.get("delta", "")
        if delta:
            self._generated_text_in_current_response = True
            if not hasattr(self, "_current_transcript"):
                self._current_transcript = ""
            self._current_transcript += delta

    async def _handle_transcript_done(self, data: dict) -> None:
        pass

    async def _handle_input_audio_transcription_completed(self, data: dict) -> None:
        """Handle input audio transcription completion for user transcript"""
        transcript = data.get("transcript", "")
        if transcript:
            logger.info(f"xAI User transcript: {transcript}")
            await realtime_metrics_collector.set_user_transcript(transcript)
            try:
                self.emit(
                    "realtime_model_transcription",
                    {"role": "user", "text": transcript, "is_final": True},
                )
            except Exception:
                pass

    async def _handle_response_done(self) -> None:
        if hasattr(self, "_current_transcript") and self._current_transcript:
            logger.info(f"xAI Agent response: {self._current_transcript}")
            await realtime_metrics_collector.set_agent_response(self._current_transcript)
            global_event_emitter.emit(
                "text_response",
                {"text": self._current_transcript, "type": "done"},
            )
            self._current_transcript = ""
        logger.info("xAI Agent speech ended")
        self.emit("agent_speech_ended", {})
        await realtime_metrics_collector.set_agent_speech_end(timeout=1.0)
        self._agent_speaking = False
        if self._has_unprocessed_tool_outputs and not self._generated_text_in_current_response:
            logger.info("xAI: Triggering follow-up response for tool outputs")
            self._has_unprocessed_tool_outputs = False
            await self.create_response()
        else:
            self._has_unprocessed_tool_outputs = False

    async def _handle_function_call(self, data: dict) -> None:
        """Handle tool execution flow for xAI"""
        name = data.get("name")
        call_id = data.get("call_id")
        args_str = data.get("arguments")
        if not name or not args_str:
            return
        try:
            arguments = json.loads(args_str)
            logger.info(f"Executing tool: {name} with args: {arguments}")
            await realtime_metrics_collector.add_tool_call(name)
            result = None
            found = False
            for tool in self._tools:
                info = get_tool_info(tool)
                if info.name == name:
                    result = await tool(**arguments)
                    found = True
                    break
            if not found:
                logger.warning(f"Tool {name} not found")
                result = {"error": "Tool not found"}
            await self.send_event({
                "type": "conversation.item.create",
                "item": {
                    "type": "function_call_output",
                    "call_id": call_id,
                    "output": json.dumps(result),
                },
            })
            if found:
                self._has_unprocessed_tool_outputs = True
        except Exception as e:
            logger.error(f"Error executing function {name}: {e}")

    async def send_event(self, event: Dict[str, Any]) -> None:
        if self._session and not self._closing:
            await self._session.msg_queue.put(event)

    def _format_tools_for_session(self, tools: List[FunctionTool]) -> List[Dict[str, Any]]:
        """Format tools using OpenAI schema builder (xAI is compatible)"""
        formatted = []
        for tool in tools:
            if is_function_tool(tool):
                try:
                    schema = build_openai_schema(tool)
                    formatted.append(schema)
                except Exception as e:
                    logger.error(f"Failed to format tool {tool}: {e}")
        return formatted

    async def aclose(self) -> None:
        """Cleanup resources"""
        if self._closing:
            return
        self._closing = True
        if self._session:
            for task in self._session.tasks:
                if not task.done():
                    task.cancel()
            if not self._session.ws.closed:
                await self._session.ws.close()
        if self._http_session and not self._http_session.closed:
            await self._http_session.close()
        if hasattr(self.audio_track, "cleanup") and self.audio_track:
            await self.audio_track.cleanup()

xAI's Grok realtime model implementation.
Initialize the realtime model
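Example (illustrative): standalone wiring. In a real deployment the agent pipeline typically calls set_agent(), connect(), and the audio handlers for you; the import path and the manual flow below are assumptions.

import asyncio

from videosdk.plugins.xai import XAIRealtime, XAIRealtimeConfig  # paths assumed

async def main() -> None:
    model = XAIRealtime(
        model="grok-4-1-fast-non-reasoning",
        config=XAIRealtimeConfig(voice="Ara", enable_web_search=True),
    )
    # model.set_agent(agent)  # adopt instructions/tools before connecting
    await model.connect()  # opens the WebSocket and waits for session.updated
    await model.send_text_message("Hello!")  # commits the audio buffer, requests a response
    await asyncio.sleep(5)  # allow the audio response to stream back
    await model.aclose()

asyncio.run(main())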
Ancestors
- videosdk.agents.realtime_base_model.RealtimeBaseModel
- videosdk.agents.event_emitter.EventEmitter
- typing.Generic
- abc.ABC
Methods
async def aclose(self) ‑> None

Cleanup resources: cancels the send/receive tasks, closes the WebSocket and HTTP session, and cleans up the audio track if present.
async def connect(self) ‑> None

Establish the WebSocket connection to xAI, start the send/receive loops, configure the audio track for the target sample rate, and wait (up to 10 s) for session configuration to complete.
async def create_response(self) ‑> None

Trigger a response from the model by sending a response.create event.
async def handle_audio_input(self, audio_data: bytes) ‑> None

Process incoming audio: downmix stereo to mono if needed, resample 48 kHz input to the target rate (usually 24 kHz), and send it to xAI as a base64-encoded input_audio_buffer.append event.
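The preprocessing is easy to exercise in isolation. Below is a sketch mirroring the steps above (stereo downmix by channel averaging, then scipy resampling); the stereo heuristic is copied from the plugin source:

import numpy as np
from scipy import signal

def downmix_and_resample(pcm: bytes, src_rate: int = 48000, dst_rate: int = 24000) -> np.ndarray:
    """Mirror of handle_audio_input's preprocessing."""
    raw = np.frombuffer(pcm, dtype=np.int16)
    # Same heuristic as the plugin: treat large, even-length buffers as interleaved stereo.
    if len(raw) >= 1920 and len(raw) % 2 == 0:
        raw = raw.reshape(-1, 2).astype(np.int32).mean(axis=1).astype(np.int16)
    if src_rate != dst_rate:
        num_samples = int(len(raw) * dst_rate / src_rate)
        raw = signal.resample(raw.astype(np.float32), num_samples).astype(np.int16)
    return raw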
async def handle_video_input(self, video_data: Any) ‑> None

No-op: the xAI Voice API currently does not document direct video stream support on this endpoint.
async def interrupt(self) ‑> None

Interrupt the current generation: records the interruption in metrics, interrupts the audio track, and emits agent_speech_ended if the agent was speaking. Non-interruptible utterances are respected.
async def send_event(self, event: Dict[str, Any]) ‑> None

Queue an event on the session's message queue for delivery over the WebSocket (ignored if the session is closing).

async def send_message(self, message: str) ‑> None

Send a text message to trigger an audio response: commits the input audio buffer, creates a user conversation item, and requests a new response.
async def send_text_message(self, message: str) ‑> None

Send a text message (alias for send_message in the xAI flow).
def set_agent(self, agent: Agent) ‑> None

Adopt the agent's instructions and tools; the tools are pre-formatted for the session using the OpenAI-compatible schema builder.
class XAIRealtimeConfig (voice: XAIVoice = 'Ara',
instructions: str | None = None,
turn_detection: XAITurnDetection | None = <factory>,
modalities: List[str] = <factory>,
enable_web_search: bool = False,
enable_x_search: bool = False,
allowed_x_handles: List[str] | None = None,
collection_id: str | None = None,
max_num_results: int = 10)
Expand source code
@dataclass
class XAIRealtimeConfig:
    """Configuration for the xAI (Grok) Realtime API

    Args:
        voice: The voice identifier. Options: 'Ara', 'Rex', 'Sal', 'Eve', 'Leo'. Default: 'Ara'
        instructions: System instructions for the agent.
        turn_detection: Configuration for server-side VAD.
        tools: List of specific xAI tools (e.g., web_search, x_search).
               Standard function tools are handled via the Agent class.
    """
    voice: XAIVoice = DEFAULT_XAI_VOICE
    instructions: str | None = None
    turn_detection: XAITurnDetection | None = field(default_factory=XAITurnDetection)
    modalities: List[str] = field(default_factory=lambda: ["text", "audio"])
    enable_web_search: bool = False
    enable_x_search: bool = False
    allowed_x_handles: List[str] | None = None
    collection_id: str | None = None
    max_num_results: int = 10

Configuration for the xAI (Grok) Realtime API.
Args
- voice: The voice identifier. Options: 'Ara', 'Rex', 'Sal', 'Eve', 'Leo'. Default: 'Ara'.
- instructions: System instructions for the agent.
- turn_detection: Configuration for server-side VAD.
- modalities: Enabled modalities. Default: ["text", "audio"].
- enable_web_search: When True, adds the built-in web_search tool to the session. Default: False.
- enable_x_search: When True, adds the built-in x_search tool to the session. Default: False.
- allowed_x_handles: Restrict x_search to these X handles (setting this also enables x_search).
- collection_id: Vector-store collection ID; when set, enables the file_search tool.
- max_num_results: Maximum number of file_search results. Default: 10.

Note: standard function tools are supplied via the Agent (see XAIRealtime.set_agent), not through this config.
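Example (illustrative): combining the search and retrieval options. The collection ID is a placeholder; file_search is only added to the session when collection_id is set.

from videosdk.plugins.xai import XAIRealtimeConfig  # import path assumed

config = XAIRealtimeConfig(
    voice="Eve",
    enable_web_search=True,            # adds {"type": "web_search"} to the session tools
    enable_x_search=True,
    allowed_x_handles=["grok"],        # restricts x_search to these handles
    collection_id="my-collection-id",  # placeholder; enables file_search
    max_num_results=5,
)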
Instance variables
var allowed_x_handles : List[str] | None
var collection_id : str | None
var enable_web_search : bool
var enable_x_search : bool
var instructions : str | None
var max_num_results : int
var modalities : List[str]
var turn_detection : XAITurnDetection | None
var voice : Literal['Ara', 'Rex', 'Sal', 'Eve', 'Leo']
class XAITurnDetection (type: "Literal['server_vad'] | None" = 'server_vad',
threshold: float = 0.5,
prefix_padding_ms: int = 300,
silence_duration_ms: int = 200)
Expand source code
@dataclass
class XAITurnDetection:
    type: Literal["server_vad"] | None = "server_vad"
    threshold: float = 0.5
    prefix_padding_ms: int = 300
    silence_duration_ms: int = 200

Configuration for server-side voice activity detection (VAD) turn taking.
Instance variables
var prefix_padding_ms : int
var silence_duration_ms : int
var threshold : float
var type : Literal['server_vad'] | None
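Example (illustrative): constructing a custom detector. Note that the session.update payload in the source above sends only {"type": "server_vad"}, so treat the tuning fields as configuration that may not yet be forwarded to the API.

from videosdk.plugins.xai import XAIRealtimeConfig, XAITurnDetection  # paths assumed

config = XAIRealtimeConfig(
    turn_detection=XAITurnDetection(
        threshold=0.6,            # VAD sensitivity
        prefix_padding_ms=300,    # audio retained before detected speech
        silence_duration_ms=400,  # silence required to end the turn
    ),
)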