Package videosdk.plugins.xai
Sub-modules
videosdk.plugins.xai.llm
videosdk.plugins.xai.stt
videosdk.plugins.xai.tts
videosdk.plugins.xai.xai_realtime
Classes
class XAILLM (*,
api_key: str | None = None,
model: str = 'grok-4-1-fast-non-reasoning',
base_url: str = 'https://api.x.ai/v1',
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_completion_tokens: int | None = None,
tools: List[Union[FunctionTool, Dict[str, Any]]] | None = None)-
Expand source code
class XAILLM(LLM): """ LLM Plugin for xAI (Grok) API. Supports Grok-4, and reasoning models with standard client-side function calling. """ def __init__( self, *, api_key: str | None = None, model: str = "grok-4-1-fast-non-reasoning", base_url: str = "https://api.x.ai/v1", temperature: float = 0.7, tool_choice: ToolChoice = "auto", max_completion_tokens: int | None = None, tools: List[Union[FunctionTool, Dict[str, Any]]] | None = None, ) -> None: """Initialize the xAI LLM plugin. Args: api_key (Optional[str], optional): xAI API key. Defaults to XAI_API_KEY env var. model (str): The model to use (e.g., "grok-4", "grok-4-1-fast"). base_url (str): The base URL for the xAI API. Defaults to "https://api.x.ai/v1". temperature (float): The temperature to use. Defaults to 0.7. tool_choice (ToolChoice): The tool choice to use. Defaults to "auto". max_completion_tokens (Optional[int], optional): Max tokens. tools (Optional[List], optional): List of FunctionTools to be available to the LLM. """ super().__init__() self.api_key = api_key or os.getenv("XAI_API_KEY") if not self.api_key: raise ValueError("xAI API key must be provided either through api_key parameter or XAI_API_KEY environment variable") self.model = model self.temperature = temperature self.tool_choice = tool_choice self.max_completion_tokens = max_completion_tokens self.tools = tools or [] self._cancelled = False self._client = openai.AsyncOpenAI( api_key=self.api_key, base_url=base_url, max_retries=0, http_client=httpx.AsyncClient( timeout=httpx.Timeout(connect=15.0, read=60.0, write=5.0, pool=5.0), follow_redirects=True, limits=httpx.Limits( max_connections=50, max_keepalive_connections=50, keepalive_expiry=120, ), ), ) async def chat( self, messages: ChatContext, tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None, conversational_graph: Any | None = None, **kwargs: Any ) -> AsyncIterator[LLMResponse]: """ Implement chat functionality using xAI's API via OpenAI SDK compatibility. 
""" self._cancelled = False openai_messages = messages.to_openai_messages() completion_params = { "model": self.model, "messages": openai_messages, "temperature": self.temperature, "stream": True, "max_tokens": self.max_completion_tokens, } if conversational_graph: completion_params["response_format"] = { "type": "json_schema", "json_schema": { "name": "conversational_graph_response", "strict": True, "schema": conversational_graph._get_graph_schema() } } combined_tools = (self.tools or []) + (tools or []) if combined_tools: formatted_tools = [] for tool in combined_tools: if is_function_tool(tool): try: tool_schema = build_openai_schema(tool) if "function" not in tool_schema: inner_tool = {k: v for k, v in tool_schema.items() if k != "type"} formatted_tools.append({ "type": "function", "function": inner_tool }) else: formatted_tools.append(tool_schema) except Exception as e: self.emit("error", f"Failed to format tool {tool}: {e}") continue elif isinstance(tool, dict): formatted_tools.append(tool) if formatted_tools: completion_params["tools"] = formatted_tools completion_params["tool_choice"] = self.tool_choice completion_params.update(kwargs) try: response_stream = await self._client.chat.completions.create(**completion_params) current_content = "" current_tool_calls = {} streaming_state = { "in_response": False, "response_start_index": -1, "yielded_content_length": 0 } async for chunk in response_stream: if self._cancelled: break if not chunk.choices: continue delta = chunk.choices[0].delta if delta.tool_calls: for tool_call in delta.tool_calls: idx = tool_call.index if idx not in current_tool_calls: current_tool_calls[idx] = { "id": tool_call.id or "", "name": tool_call.function.name or "", "arguments": tool_call.function.arguments or "" } else: if tool_call.function.name: current_tool_calls[idx]["name"] += tool_call.function.name if tool_call.function.arguments: current_tool_calls[idx]["arguments"] += tool_call.function.arguments if delta.content is not None: 
current_content += delta.content if conversational_graph: for content_chunk in conversational_graph.stream_conversational_graph_response(current_content, streaming_state): yield LLMResponse(content=content_chunk, role=ChatRole.ASSISTANT) else: yield LLMResponse(content=delta.content, role=ChatRole.ASSISTANT) if current_tool_calls and not self._cancelled: for idx in sorted(current_tool_calls.keys()): tool_data = current_tool_calls[idx] try: args_str = tool_data["arguments"] parsed_args = json.loads(args_str) yield LLMResponse( content="", role=ChatRole.ASSISTANT, metadata={ "function_call": { "name": tool_data["name"], "arguments": parsed_args }, "tool_call_id": tool_data["id"] } ) except json.JSONDecodeError: self.emit("error", f"Failed to parse function arguments for tool {tool_data['name']}") if current_content and not self._cancelled and conversational_graph: try: parsed_json = json.loads(current_content.strip()) yield LLMResponse( content="", role=ChatRole.ASSISTANT, metadata=parsed_json ) except json.JSONDecodeError: pass except Exception as e: if not self._cancelled: self.emit("error", e) raise async def cancel_current_generation(self) -> None: self._cancelled = True async def aclose(self) -> None: """Cleanup resources""" await self.cancel_current_generation() if self._client: await self._client.close() await super().aclose()LLM Plugin for xAI (Grok) API. Supports Grok-4, and reasoning models with standard client-side function calling.
Initialize the xAI LLM plugin.
Args
api_key:Optional[str], optional- xAI API key. Defaults to XAI_API_KEY env var.
model:str- The model to use (e.g., "grok-4", "grok-4-1-fast"). Defaults to "grok-4-1-fast-non-reasoning".
base_url:str- The base URL for the xAI API. Defaults to "https://api.x.ai/v1".
temperature:float- The temperature to use. Defaults to 0.7.
tool_choice:ToolChoice- The tool choice to use. Defaults to "auto".
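max_completion_tokens:Optional[int], optional- Maximum number of tokens to generate per completion; the value is sent to the API as the max_tokens request parameter. Defaults to None.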
tools:Optional[List], optional- List of FunctionTools to be available to the LLM.
Ancestors
- videosdk.agents.llm.llm.LLM
- videosdk.agents.event_emitter.EventEmitter
- typing.Generic
Methods
async def aclose(self) ‑> None-
Expand source code
async def aclose(self) -> None: """Cleanup resources""" await self.cancel_current_generation() if self._client: await self._client.close() await super().aclose()Cleanup resources
async def cancel_current_generation(self) ‑> None-
Expand source code
async def cancel_current_generation(self) -> None: self._cancelled = TrueCancel the current LLM generation if active.
Raises
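NotImplementedError- This method must be implemented by subclasses. (Note: this text appears to be inherited from the base LLM class docstring; the XAILLM implementation shown above only sets the internal _cancelled flag and does not raise.)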
async def chat(self,
messages: ChatContext,
tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
conversational_graph: Any | None = None,
**kwargs: Any) ‑> AsyncIterator[videosdk.agents.llm.llm.LLMResponse]-
Expand source code
async def chat( self, messages: ChatContext, tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None, conversational_graph: Any | None = None, **kwargs: Any ) -> AsyncIterator[LLMResponse]: """ Implement chat functionality using xAI's API via OpenAI SDK compatibility. """ self._cancelled = False openai_messages = messages.to_openai_messages() completion_params = { "model": self.model, "messages": openai_messages, "temperature": self.temperature, "stream": True, "max_tokens": self.max_completion_tokens, } if conversational_graph: completion_params["response_format"] = { "type": "json_schema", "json_schema": { "name": "conversational_graph_response", "strict": True, "schema": conversational_graph._get_graph_schema() } } combined_tools = (self.tools or []) + (tools or []) if combined_tools: formatted_tools = [] for tool in combined_tools: if is_function_tool(tool): try: tool_schema = build_openai_schema(tool) if "function" not in tool_schema: inner_tool = {k: v for k, v in tool_schema.items() if k != "type"} formatted_tools.append({ "type": "function", "function": inner_tool }) else: formatted_tools.append(tool_schema) except Exception as e: self.emit("error", f"Failed to format tool {tool}: {e}") continue elif isinstance(tool, dict): formatted_tools.append(tool) if formatted_tools: completion_params["tools"] = formatted_tools completion_params["tool_choice"] = self.tool_choice completion_params.update(kwargs) try: response_stream = await self._client.chat.completions.create(**completion_params) current_content = "" current_tool_calls = {} streaming_state = { "in_response": False, "response_start_index": -1, "yielded_content_length": 0 } async for chunk in response_stream: if self._cancelled: break if not chunk.choices: continue delta = chunk.choices[0].delta if delta.tool_calls: for tool_call in delta.tool_calls: idx = tool_call.index if idx not in current_tool_calls: current_tool_calls[idx] = { "id": tool_call.id or "", "name": tool_call.function.name or "", 
"arguments": tool_call.function.arguments or "" } else: if tool_call.function.name: current_tool_calls[idx]["name"] += tool_call.function.name if tool_call.function.arguments: current_tool_calls[idx]["arguments"] += tool_call.function.arguments if delta.content is not None: current_content += delta.content if conversational_graph: for content_chunk in conversational_graph.stream_conversational_graph_response(current_content, streaming_state): yield LLMResponse(content=content_chunk, role=ChatRole.ASSISTANT) else: yield LLMResponse(content=delta.content, role=ChatRole.ASSISTANT) if current_tool_calls and not self._cancelled: for idx in sorted(current_tool_calls.keys()): tool_data = current_tool_calls[idx] try: args_str = tool_data["arguments"] parsed_args = json.loads(args_str) yield LLMResponse( content="", role=ChatRole.ASSISTANT, metadata={ "function_call": { "name": tool_data["name"], "arguments": parsed_args }, "tool_call_id": tool_data["id"] } ) except json.JSONDecodeError: self.emit("error", f"Failed to parse function arguments for tool {tool_data['name']}") if current_content and not self._cancelled and conversational_graph: try: parsed_json = json.loads(current_content.strip()) yield LLMResponse( content="", role=ChatRole.ASSISTANT, metadata=parsed_json ) except json.JSONDecodeError: pass except Exception as e: if not self._cancelled: self.emit("error", e) raiseImplement chat functionality using xAI's API via OpenAI SDK compatibility.
class XAIRealtime (*,
api_key: str | None = None,
model: str = 'grok-4-1-fast-non-reasoning',
config: XAIRealtimeConfig | None = None,
base_url: str | None = None)-
Expand source code
class XAIRealtime(RealtimeBaseModel[XAIEventTypes]): """xAI's Grok realtime model implementation""" def __init__( self, *, api_key: str | None = None, model: str = "grok-4-1-fast-non-reasoning", config: XAIRealtimeConfig | None = None, base_url: str | None = None, ) -> None: super().__init__() self.api_key = api_key or os.getenv("XAI_API_KEY") self.base_url = base_url or XAI_BASE_URL if not self.api_key: self.emit("error", "XAI_API_KEY is required") raise ValueError("XAI_API_KEY is required") self.config: XAIRealtimeConfig = config or XAIRealtimeConfig() self._http_session: Optional[aiohttp.ClientSession] = None self._session: Optional[XAISession] = None self._closing = False self._instructions: str = "You are a helpful assistant." self._tools: List[FunctionTool] = [] self._formatted_tools: List[Dict[str, Any]] = [] self.input_sample_rate = INPUT_SAMPLE_RATE self.target_sample_rate = DEFAULT_SAMPLE_RATE self._agent_speaking = False self._current_response_id: str | None = None self._is_configured = False self._session_ready = asyncio.Event() self._has_unprocessed_tool_outputs = False self._generated_text_in_current_response = False def set_agent(self, agent: Agent) -> None: self._agent = agent if agent.instructions: self._instructions = agent.instructions self._tools = agent.tools self._formatted_tools = self._format_tools_for_session(self._tools) async def connect(self) -> None: """Establish WebSocket connection to xAI""" headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } self._session = await self._create_session(self.base_url, headers) await self._handle_websocket(self._session) self.reframe_audio_track(self.target_sample_rate) try: await asyncio.wait_for(self._session_ready.wait(), timeout=10.0) logger.info("xAI session configuration complete") except asyncio.TimeoutError: logger.warning("Timeout waiting for xAI session configuration") async def _create_session(self, url: str, headers: dict) -> XAISession: if not 
self._http_session: self._http_session = aiohttp.ClientSession() try: ws = await self._http_session.ws_connect( url, headers=headers, autoping=True, heartbeat=10, timeout=30, ) except Exception as e: self.emit("error", f"Connection failed: {e}") raise msg_queue: asyncio.Queue = asyncio.Queue() tasks: list[asyncio.Task] = [] self._closing = False return XAISession(ws=ws, msg_queue=msg_queue, tasks=tasks) async def _send_initial_config(self) -> None: """Send session.update to configure voice and audio""" if not self._session: return tools_config = [] if self._formatted_tools: tools_config.extend(self._formatted_tools) if self.config.enable_web_search: tools_config.append({"type": "web_search"}) if self.config.enable_x_search or self.config.allowed_x_handles: x_search_config = {"type": "x_search"} if self.config.allowed_x_handles: logger.info(f"Allowed xAI handles: {self.config.allowed_x_handles}") x_search_config["allowed_x_handles"] = self.config.allowed_x_handles tools_config.append(x_search_config) if self.config.collection_id: tools_config.append({ "type": "file_search", "vector_store_ids": [self.config.collection_id], "max_num_results": self.config.max_num_results, }) session_update = { "type": "session.update", "session": { "instructions": self.instructions_with_context(self._instructions), "voice": self.config.voice, "audio": { "input": { "format": { "type": "audio/pcm", "rate": self.target_sample_rate } }, "output": { "format": { "type": "audio/pcm", "rate": self.target_sample_rate } } }, "turn_detection": { "type": "server_vad" }, "tools": tools_config if tools_config else None } } await self.send_event(session_update) async def handle_audio_input(self, audio_data: bytes) -> None: """Process incoming audio: Resample 48k -> target (usually 24k) and send to xAI""" if not self._session or self._closing: return if "audio" not in self.config.modalities: return if self.current_utterance and not self.current_utterance.is_interruptible: return try: raw_audio = 
np.frombuffer(audio_data, dtype=np.int16) if len(raw_audio) >= 1920 and len(raw_audio) % 2 == 0: raw_audio = (raw_audio.reshape(-1, 2).astype(np.int32).mean(axis=1)).astype(np.int16) if self.input_sample_rate != self.target_sample_rate: num_samples = int(len(raw_audio) * self.target_sample_rate / self.input_sample_rate) float_audio = raw_audio.astype(np.float32) resampled_audio = signal.resample(float_audio, num_samples).astype(np.int16) else: resampled_audio = raw_audio base64_audio = base64.b64encode(resampled_audio.tobytes()).decode("utf-8") if not hasattr(self, "_audio_log_counter"): self._audio_log_counter = 0 self._audio_log_counter += 1 if self._audio_log_counter % 100 == 0: rms = np.sqrt(np.mean(resampled_audio.astype(np.float32)**2)) logger.info(f"xAI Audio: Sent chunk {self._audio_log_counter}, samples={len(resampled_audio)}, RMS={rms:.2f}") await self.send_event({ "type": "input_audio_buffer.append", "audio": base64_audio }) except Exception as e: logger.error(f"Error processing audio input: {e}") async def handle_video_input(self, video_data: Any) -> None: """xAI Voice API currently does not document direct video stream support in this endpoint.""" pass async def send_message(self, message: str) -> None: """Send text message to trigger audio response""" await self.send_event({"type": "input_audio_buffer.commit"}) await self.send_event({ "type": "conversation.item.create", "item": { "type": "message", "role": "user", "content": [{ "type": "input_text", "text": message }], } }) await self.create_response() async def create_response(self) -> None: """Trigger a response from the model""" await self.send_event({ "type": "response.create" }) async def send_text_message(self, message: str) -> None: """Send text message (same as send_message for xAI flow)""" await self.send_message(message) async def interrupt(self) -> None: """Interrupt current generation""" if self._session and not self._closing: if self.current_utterance and not 
self.current_utterance.is_interruptible: return if self.audio_track: self.audio_track.interrupt() if self._agent_speaking: if self.audio_track: self.audio_track.mark_synthesis_complete() self._agent_speaking = False metrics_collector.on_interrupted() async def _handle_websocket(self, session: XAISession) -> None: session.tasks.extend([ asyncio.create_task(self._send_loop(session), name="xai_send"), asyncio.create_task(self._receive_loop(session), name="xai_recv"), ]) async def _send_loop(self, session: XAISession) -> None: try: while not self._closing: msg = await session.msg_queue.get() if isinstance(msg, dict): logger.debug(f"Sending xAI event: {msg.get('type')}") await session.ws.send_json(msg) else: await session.ws.send_str(str(msg)) except asyncio.CancelledError: pass except Exception as e: logger.error(f"xAI Send loop error: {e}") self.emit("error", f"Send loop error: {e}") async def _receive_loop(self, session: XAISession) -> None: try: while not self._closing: msg = await session.ws.receive() if msg.type == aiohttp.WSMsgType.TEXT: data = json.loads(msg.data) await self._handle_event(data) elif msg.type == aiohttp.WSMsgType.CLOSED: logger.info("xAI WebSocket closed") break elif msg.type == aiohttp.WSMsgType.ERROR: logger.error(f"xAI WebSocket error: {session.ws.exception()}") break except Exception as e: logger.error(f"xAI Receive loop error: {e}") self.emit("error", f"Receive loop error: {e}") finally: logger.info("xAI Receive loop finished, closing session") await self.aclose() async def _handle_event(self, data: dict) -> None: event_type = data.get("type") try: if event_type == "conversation.created": if not self._is_configured: await self._send_initial_config() self._is_configured = True elif event_type == "input_audio_buffer.speech_started": await self._handle_speech_started() elif event_type == "input_audio_buffer.speech_stopped": await self._handle_speech_stopped() elif event_type == "session.updated": logger.info("xAI Session updated successfully") 
self._session_ready.set() elif event_type == "response.created": logger.info(f"Response created: {data.get('response', {}).get('id')}") self._generated_text_in_current_response = False elif event_type == "response.output_item.added": logger.info(f"Output item added: {data.get('item', {}).get('id')}") elif event_type == "response.output_audio.delta": await self._handle_audio_delta(data) elif event_type == "response.output_audio_transcript.delta": await self._handle_transcript_delta(data) elif event_type == "response.output_audio_transcript.done": await self._handle_transcript_done(data) elif event_type == "conversation.item.input_audio_transcription.completed": await self._handle_input_audio_transcription_completed(data) elif event_type == "response.function_call_arguments.done": await self._handle_function_call(data) elif event_type == "response.done": await self._handle_response_done() elif event_type == "error": logger.error(f"xAI Error: {data}") except Exception as e: logger.error(f"Error handling event {event_type}: {e}") traceback.print_exc() async def _handle_speech_started(self) -> None: logger.info("xAI User speech started") self.emit("user_speech_started", {"type": "done"}) metrics_collector.on_user_speech_start() metrics_collector.start_turn() if self.current_utterance and not self.current_utterance.is_interruptible: return await self.interrupt() async def _handle_speech_stopped(self) -> None: logger.info("xAI User speech stopped") metrics_collector.on_user_speech_end() self.emit("user_speech_ended", {}) async def _handle_audio_delta(self, data: dict) -> None: delta = data.get("delta") if not delta: return if not self._agent_speaking: metrics_collector.on_agent_speech_start() self._agent_speaking = True self.emit("agent_speech_started", {}) if self.audio_track and self.loop: audio_bytes = base64.b64decode(delta) asyncio.create_task(self.audio_track.add_new_bytes(audio_bytes)) async def _handle_transcript_delta(self, data: dict) -> None: delta = 
data.get("delta", "") if delta: self._generated_text_in_current_response = True if not hasattr(self, "_current_transcript"): self._current_transcript = "" self._current_transcript += delta async def _handle_transcript_done(self, data: dict) -> None: pass async def _handle_input_audio_transcription_completed(self, data: dict) -> None: """Handle input audio transcription completion for user transcript""" transcript = data.get("transcript", "") if transcript: logger.info(f"xAI User transcript: {transcript}") metrics_collector.set_user_transcript(transcript) try: self.emit( "realtime_model_transcription", {"role": "user", "text": transcript, "is_final": True}, ) except Exception: pass async def _handle_response_done(self) -> None: if hasattr(self, "_current_transcript") and self._current_transcript: logger.info(f"xAI Agent response: {self._current_transcript}") metrics_collector.set_agent_response(self._current_transcript) try: self.emit( "realtime_model_transcription", {"role": "assistant", "text": self._current_transcript, "is_final": True}, ) except Exception: pass self.emit("llm_text_output", {"text": self._current_transcript}) global_event_emitter.emit( "text_response", {"text": self._current_transcript, "type": "done"}, ) self._current_transcript = "" logger.info("xAI Agent speech ended") if self.audio_track: self.audio_track.mark_synthesis_complete() self._agent_speaking = False if self._has_unprocessed_tool_outputs and not self._generated_text_in_current_response: logger.info("xAI: Triggering follow-up response for tool outputs") self._has_unprocessed_tool_outputs = False await self.create_response() else: self._has_unprocessed_tool_outputs = False async def _handle_function_call(self, data: dict) -> None: """Handle tool execution flow for xAI""" name = data.get("name") call_id = data.get("call_id") args_str = data.get("arguments") if not name or not args_str: return try: arguments = json.loads(args_str) logger.info(f"Executing tool: {name} with args: 
{arguments}") metrics_collector.add_function_tool_call(tool_name=name) result = None found = False for tool in self._tools: info = get_tool_info(tool) if info.name == name: result = await tool(**arguments) found = True break if not found: logger.warning(f"Tool {name} not found") result = {"error": "Tool not found"} self.emit( "realtime_model_function_executed", { "name": name, "arguments": args_str, "call_id": call_id, "output": result if isinstance(result, str) else json.dumps(result), "is_error": not found, }, ) await self.send_event({ "type": "conversation.item.create", "item": { "type": "function_call_output", "call_id": call_id, "output": json.dumps(result) } }) if found: self._has_unprocessed_tool_outputs = True except Exception as e: self.emit( "realtime_model_function_executed", { "name": name, "arguments": args_str, "call_id": call_id, "output": str(e), "is_error": True, }, ) logger.error(f"Error executing function {name}: {e}") async def send_event(self, event: Dict[str, Any]) -> None: if self._session and not self._closing: await self._session.msg_queue.put(event) def _format_tools_for_session(self, tools: List[FunctionTool]) -> List[Dict[str, Any]]: """Format tools using OpenAI schema builder (xAI is compatible)""" formatted = [] for tool in tools: if is_function_tool(tool): try: schema = build_openai_schema(tool) formatted.append(schema) except Exception as e: logger.error(f"Failed to format tool {tool}: {e}") return formatted async def aclose(self) -> None: """Cleanup resources""" if self._closing: return self._closing = True if self._session: for task in self._session.tasks: if not task.done(): task.cancel() if not self._session.ws.closed: await self._session.ws.close() if self._http_session and not self._http_session.closed: await self._http_session.close() await super().aclose()xAI's Grok realtime model implementation
Initialize the realtime model
Ancestors
- videosdk.agents.realtime_base_model.RealtimeBaseModel
- videosdk.agents.event_emitter.EventEmitter
- typing.Generic
- abc.ABC
Methods
async def aclose(self) ‑> None-
Expand source code
async def aclose(self) -> None: """Cleanup resources""" if self._closing: return self._closing = True if self._session: for task in self._session.tasks: if not task.done(): task.cancel() if not self._session.ws.closed: await self._session.ws.close() if self._http_session and not self._http_session.closed: await self._http_session.close() await super().aclose()Cleanup resources
async def connect(self) ‑> None-
Expand source code
async def connect(self) -> None: """Establish WebSocket connection to xAI""" headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } self._session = await self._create_session(self.base_url, headers) await self._handle_websocket(self._session) self.reframe_audio_track(self.target_sample_rate) try: await asyncio.wait_for(self._session_ready.wait(), timeout=10.0) logger.info("xAI session configuration complete") except asyncio.TimeoutError: logger.warning("Timeout waiting for xAI session configuration")Establish WebSocket connection to xAI
async def create_response(self) ‑> None-
Expand source code
async def create_response(self) -> None: """Trigger a response from the model""" await self.send_event({ "type": "response.create" })Trigger a response from the model
async def handle_audio_input(self, audio_data: bytes) ‑> None-
Expand source code
async def handle_audio_input(self, audio_data: bytes) -> None: """Process incoming audio: Resample 48k -> target (usually 24k) and send to xAI""" if not self._session or self._closing: return if "audio" not in self.config.modalities: return if self.current_utterance and not self.current_utterance.is_interruptible: return try: raw_audio = np.frombuffer(audio_data, dtype=np.int16) if len(raw_audio) >= 1920 and len(raw_audio) % 2 == 0: raw_audio = (raw_audio.reshape(-1, 2).astype(np.int32).mean(axis=1)).astype(np.int16) if self.input_sample_rate != self.target_sample_rate: num_samples = int(len(raw_audio) * self.target_sample_rate / self.input_sample_rate) float_audio = raw_audio.astype(np.float32) resampled_audio = signal.resample(float_audio, num_samples).astype(np.int16) else: resampled_audio = raw_audio base64_audio = base64.b64encode(resampled_audio.tobytes()).decode("utf-8") if not hasattr(self, "_audio_log_counter"): self._audio_log_counter = 0 self._audio_log_counter += 1 if self._audio_log_counter % 100 == 0: rms = np.sqrt(np.mean(resampled_audio.astype(np.float32)**2)) logger.info(f"xAI Audio: Sent chunk {self._audio_log_counter}, samples={len(resampled_audio)}, RMS={rms:.2f}") await self.send_event({ "type": "input_audio_buffer.append", "audio": base64_audio }) except Exception as e: logger.error(f"Error processing audio input: {e}")Process incoming audio: Resample 48k -> target (usually 24k) and send to xAI
async def handle_video_input(self, video_data: Any) ‑> None-
Expand source code
async def handle_video_input(self, video_data: Any) -> None: """xAI Voice API currently does not document direct video stream support in this endpoint.""" passxAI Voice API currently does not document direct video stream support in this endpoint.
async def interrupt(self) ‑> None-
Expand source code
async def interrupt(self) -> None: """Interrupt current generation""" if self._session and not self._closing: if self.current_utterance and not self.current_utterance.is_interruptible: return if self.audio_track: self.audio_track.interrupt() if self._agent_speaking: if self.audio_track: self.audio_track.mark_synthesis_complete() self._agent_speaking = False metrics_collector.on_interrupted()Interrupt current generation
async def send_event(self, event: Dict[str, Any]) ‑> None-
Expand source code
async def send_event(self, event: Dict[str, Any]) -> None: if self._session and not self._closing: await self._session.msg_queue.put(event) async def send_message(self, message: str) ‑> None-
Expand source code
async def send_message(self, message: str) -> None: """Send text message to trigger audio response""" await self.send_event({"type": "input_audio_buffer.commit"}) await self.send_event({ "type": "conversation.item.create", "item": { "type": "message", "role": "user", "content": [{ "type": "input_text", "text": message }], } }) await self.create_response()Send text message to trigger audio response
async def send_text_message(self, message: str) ‑> None-
Expand source code
async def send_text_message(self, message: str) -> None: """Send text message (same as send_message for xAI flow)""" await self.send_message(message)Send text message (same as send_message for xAI flow)
def set_agent(self, agent: Agent) ‑> None-
Expand source code
def set_agent(self, agent: Agent) -> None: self._agent = agent if agent.instructions: self._instructions = agent.instructions self._tools = agent.tools self._formatted_tools = self._format_tools_for_session(self._tools)
class XAIRealtimeConfig (voice: XAIVoice = 'Ara',
instructions: str | None = None,
turn_detection: XAITurnDetection | None = <factory>,
modalities: List[str] = <factory>,
enable_web_search: bool = False,
enable_x_search: bool = False,
allowed_x_handles: List[str] | None = None,
collection_id: str | None = None,
max_num_results: int = 10)-
Expand source code
@dataclass class XAIRealtimeConfig: """Configuration for the xAI (Grok) Realtime API Args: voice: The voice identifier. Options: 'Ara', 'Rex', 'Sal', 'Eve', 'Leo'. Default: 'Ara' instructions: System instructions for the agent. turn_detection: Configuration for server-side VAD. tools: List of specific xAI tools (e.g., web_search, x_search). Standard function tools are handled via the Agent class. """ voice: XAIVoice = DEFAULT_XAI_VOICE instructions: str | None = None turn_detection: XAITurnDetection | None = field(default_factory=XAITurnDetection) modalities: List[str] = field(default_factory=lambda: ["text", "audio"]) enable_web_search: bool = False enable_x_search: bool = False allowed_x_handles: List[str] | None = None collection_id: str | None = None max_num_results: int = 10Configuration for the xAI (Grok) Realtime API
Args
voice- The voice identifier. Options: 'Ara', 'Rex', 'Sal', 'Eve', 'Leo'. Default: 'Ara'
instructions- System instructions for the agent.
turn_detection- Configuration for server-side VAD.
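tools- List of specific xAI tools (e.g., web_search, x_search). Note: this dataclass has no tools field; the built-in xAI tools are enabled through enable_web_search, enable_x_search / allowed_x_handles, and collection_id / max_num_results. Standard function tools are handled via the Agent class.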
Instance variables
var allowed_x_handles : List[str] | None
var collection_id : str | None
var enable_web_search : bool
var enable_x_search : bool
var instructions : str | None
var max_num_results : int
var modalities : List[str]
var turn_detection : XAITurnDetection | None
var voice : Literal['Ara', 'Rex', 'Sal', 'Eve', 'Leo']
class XAISTT (*,
api_key: str | None = None,
sample_rate: int = 48000,
encoding: "Literal['pcm', 'mulaw', 'alaw']" = 'pcm',
interim_results: bool = True,
endpointing: int = 50,
language: str | None = 'en',
diarize: bool = False,
multichannel: bool = False,
channels: int = 1,
base_url: str = 'wss://api.x.ai/v1/stt')-
Expand source code
class XAISTT(BaseSTT):
    """Streaming speech-to-text plugin for xAI's WebSocket STT endpoint.

    Audio handed to :meth:`process_audio` is forwarded over a lazily opened
    WebSocket. A background listener task reads transcript events, maps them
    to ``STTResponse`` objects and passes them to the transcript callback.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        sample_rate: int = 48000,
        encoding: Literal["pcm", "mulaw", "alaw"] = "pcm",
        interim_results: bool = True,
        endpointing: int = 50,
        language: str | None = "en",
        diarize: bool = False,
        multichannel: bool = False,
        channels: int = 1,
        base_url: str = XAI_STT_BASE_URL,
    ) -> None:
        """Initialize the xAI STT plugin.

        Args:
            api_key: xAI API key. Falls back to XAI_API_KEY env var.
            sample_rate: Audio sample rate in Hz. xAI accepts
                8000/16000/22050/24000/44100/48000. Defaults to 48000 to match
                the framework's native input rate.
            encoding: Raw audio encoding. One of "pcm" (signed 16-bit LE),
                "mulaw", "alaw".
            interim_results: Emit partial transcripts (is_final=false) as they
                arrive.
            endpointing: Silence duration (ms) before xAI fires utterance-final.
                Range 0–5000. Kept low (50ms default) because the framework's
                VAD is the primary turn detector; flush() injects synthetic
                silence to force utterance-final, and a low threshold means
                flush latency is short.
            language: BCP-47 language code (e.g. "en", "fr"). Pass None to skip
                the param. xAI transcribes any supported language regardless of
                this — the value only enables Inverse Text Normalization
                (numbers, currencies in written form).
            diarize: When true, each word in the response includes a `speaker`
                field.
            multichannel: When true, transcribes each input channel
                independently. Requires interleaved multi-channel audio. When
                false, input is downmixed to mono.
            channels: Number of input channels (only relevant with
                multichannel=True).
            base_url: WebSocket endpoint URL.

        Raises:
            ValueError: If no API key is available, or if ``sample_rate``,
                ``encoding`` or ``endpointing`` is outside the supported set
                or range.
        """
        super().__init__()

        self.api_key = api_key or os.getenv("XAI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "xAI API key must be provided either through the api_key parameter "
                "or the XAI_API_KEY environment variable"
            )
        if sample_rate not in SUPPORTED_SAMPLE_RATES:
            raise ValueError(
                f"sample_rate must be one of {sorted(SUPPORTED_SAMPLE_RATES)}, got {sample_rate}"
            )
        if encoding not in SUPPORTED_ENCODINGS:
            raise ValueError(
                f"encoding must be one of {sorted(SUPPORTED_ENCODINGS)}, got {encoding}"
            )
        if not 0 <= endpointing <= 5000:
            raise ValueError(f"endpointing must be in [0, 5000], got {endpointing}")

        self.sample_rate = sample_rate
        self.encoding = encoding
        self.interim_results = interim_results
        self.endpointing = endpointing
        self.language = language
        self.diarize = diarize
        self.multichannel = multichannel
        self.channels = channels
        self.base_url = base_url

        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._ws_task: Optional[asyncio.Task] = None
        # Set by the listener when the server sends "transcript.created".
        self._server_ready: asyncio.Event = asyncio.Event()
        self._closed = False

    async def process_audio(
        self,
        audio_frames: bytes,
        language: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Process audio frames and stream them to xAI's WebSocket STT API.

        The connection is opened on first use; audio is only sent once the
        server has announced readiness with ``transcript.created``.

        Args:
            audio_frames: Raw audio bytes in the configured encoding.
            language: Accepted for interface compatibility; not used here (the
                language is fixed when the connection is opened).
            **kwargs: Ignored.

        Raises:
            Exception: Whatever ``_connect_ws`` raises when the WebSocket
                cannot be opened. Send failures are not raised; they are
                emitted as ``"error"`` events and the connection is reset.
        """
        if self._closed:
            return

        if not self._ws:
            await self._connect_ws()
            self._ws_task = asyncio.create_task(self._listen_for_responses())
            try:
                await asyncio.wait_for(self._server_ready.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                self.emit("error", "Timed out waiting for xAI transcript.created")
                # Tear the half-open socket down. Otherwise self._ws stays set
                # and every later call would skip this readiness wait and
                # stream audio into a session that never became ready.
                await self._reset_connection()
                return

        try:
            audio_bytes = audio_frames
            if (
                not self.multichannel
                and self.encoding == "pcm"
                and len(audio_frames) % 4 == 0
            ):
                # Average adjacent int16 samples pairwise.
                # NOTE(review): this assumes the incoming PCM is interleaved
                # stereo - confirm against the framework's audio input format.
                audio_data = np.frombuffer(audio_frames, dtype=np.int16)
                if audio_data.size > 0 and audio_data.size % 2 == 0:
                    audio_data = (
                        audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16)
                    )
                audio_bytes = audio_data.tobytes()

            await self._ws.send_bytes(audio_bytes)
        except Exception as e:
            logger.error(f"Error sending audio to xAI STT: {e}")
            self.emit("error", str(e))
            await self._reset_connection()

    async def _connect_ws(self) -> None:
        """Open the WebSocket connection to xAI's STT endpoint.

        Raises:
            Exception: Re-raises any connection failure after logging it.
        """
        if not self._session:
            self._session = aiohttp.ClientSession()

        params: list[tuple[str, str]] = [
            ("sample_rate", str(self.sample_rate)),
            ("encoding", self.encoding),
            ("interim_results", str(self.interim_results).lower()),
            ("endpointing", str(self.endpointing)),
            ("diarize", str(self.diarize).lower()),
            ("multichannel", str(self.multichannel).lower()),
            # Without multichannel the audio is sent as a single channel.
            ("channels", str(1 if not self.multichannel else self.channels)),
        ]
        if self.language:
            params.append(("language", self.language))

        ws_url = f"{self.base_url}?{urlencode(params)}"
        headers = {"Authorization": f"Bearer {self.api_key}"}

        # Fresh event per connection so a stale "ready" cannot leak over.
        self._server_ready = asyncio.Event()
        try:
            self._ws = await self._session.ws_connect(
                ws_url, headers=headers, heartbeat=30.0
            )
        except Exception as e:
            logger.error(f"Error connecting to xAI STT WebSocket: {e}")
            raise

    async def _listen_for_responses(self) -> None:
        """Background task that reads transcript events from the WebSocket."""
        if not self._ws:
            return

        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    try:
                        data = msg.json()
                    except Exception as e:
                        logger.error(f"Failed to parse xAI STT message: {e}")
                        continue

                    event_type = data.get("type")
                    if event_type == "transcript.created":
                        self._server_ready.set()
                    elif event_type in ("transcript.partial", "transcript.done"):
                        # Both event kinds share one mapping; _handle_partial
                        # promotes "transcript.done" to a FINAL event.
                        response = self._handle_partial(data)
                        if response and self._transcript_callback:
                            await self._transcript_callback(response)
                    elif event_type == "error":
                        message = data.get("message", "unknown error")
                        logger.error(f"xAI STT error event: {message}")
                        self.emit("error", message)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    err = self._ws.exception()
                    logger.error(f"xAI STT WebSocket error: {err}")
                    self.emit("error", f"WebSocket error: {err}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    break
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error(f"Error in xAI STT listener: {e}")
            self.emit("error", f"Error in WebSocket listener: {e}")
        finally:
            # Leave the plugin in a state where the next process_audio() call
            # opens a new connection.
            if self._ws and not self._ws.closed:
                await self._ws.close()
            self._ws = None
            self._server_ready.clear()

    def _handle_partial(self, event: dict) -> Optional[STTResponse]:
        """Map an xAI transcript event to an STTResponse.

        Args:
            event: Decoded JSON payload of a transcript event.

        Returns:
            The mapped response, or ``None`` when the event carries no text.
        """
        text = event.get("text", "")
        if not text:
            return None

        is_final = bool(event.get("is_final", False))
        speech_final = bool(event.get("speech_final", False))
        event_is_done = event.get("type") == "transcript.done"
        event_type = (
            SpeechEventType.FINAL
            if (is_final and speech_final) or event_is_done
            else SpeechEventType.INTERIM
        )

        words = event.get("words") or []
        start = event.get("start", 0.0) or 0.0
        duration = event.get("duration", 0.0) or 0.0
        if words:
            # Prefer word-level timings, falling back to the event-level span.
            start_time = float(words[0].get("start", start))
            end_time = float(words[-1].get("end", start + duration))
        else:
            start_time = float(start)
            end_time = float(start) + float(duration)

        return STTResponse(
            event_type=event_type,
            data=SpeechData(
                text=text,
                language=self.language,
                confidence=0.0,
                start_time=start_time,
                end_time=end_time,
                duration=float(duration),
            ),
            metadata={
                "is_final": is_final,
                "speech_final": speech_final,
                "channel_index": event.get("channel_index"),
            },
        )

    async def flush(self) -> None:
        """Force xAI to emit the current utterance-final transcript.

        We inject a short chunk of silence bytes — once the configured
        endpointing window elapses, xAI fires `transcript.partial` with
        speech_final=true, which we map to SpeechEventType.FINAL.

        Does nothing when the plugin is closed, the socket is not open, or the
        server has not signalled readiness. Send failures are logged, not
        raised.
        """
        if self._closed or not self._ws or self._ws.closed:
            return
        if not self._server_ready.is_set():
            return

        try:
            silence_ms = max(self.endpointing + 50, 100)
            bytes_per_sample = 2 if self.encoding == "pcm" else 1
            channel_count = self.channels if self.multichannel else 1
            # Size the buffer in whole frames. Truncating the byte count
            # directly can yield an odd length (e.g. 22050 Hz * 115 ms),
            # which would split a 16-bit PCM sample in half.
            n_frames = self.sample_rate * silence_ms // 1000
            n_bytes = n_frames * bytes_per_sample * channel_count
            silence = _SILENCE_BYTE[self.encoding] * n_bytes
            await self._ws.send_bytes(silence)
        except Exception as e:
            logger.warning(f"xAI STT flush failed: {e}")

    async def _reset_connection(self) -> None:
        """Close the socket and stop the listener so the next call reconnects."""
        if self._ws:
            try:
                await self._ws.close()
            except Exception:
                # Best effort: the socket is being discarded either way.
                pass
            self._ws = None
        if self._ws_task:
            self._ws_task.cancel()
            try:
                await self._ws_task
            except (asyncio.CancelledError, Exception):
                pass
            self._ws_task = None
        self._server_ready.clear()

    async def aclose(self) -> None:
        """Cleanup resources."""
        self._closed = True

        if self._ws and not self._ws.closed:
            try:
                # Best-effort end-of-stream notice before shutting down.
                await self._ws.send_str(json.dumps({"type": "audio.done"}))
            except Exception:
                pass

        if self._ws_task:
            self._ws_task.cancel()
            try:
                await self._ws_task
            except (asyncio.CancelledError, Exception):
                pass
            self._ws_task = None

        if self._ws and not self._ws.closed:
            await self._ws.close()
        self._ws = None

        if self._session:
            await self._session.close()
            self._session = None

        await super().aclose()
Initialize the xAI STT plugin.
Args
api_key- xAI API key. Falls back to XAI_API_KEY env var.
sample_rate- Audio sample rate in Hz. xAI accepts 8000/16000/22050/24000/44100/48000. Defaults to 48000 to match the framework's native input rate.
encoding- Raw audio encoding. One of "pcm" (signed 16-bit LE), "mulaw", "alaw".
interim_results- Emit partial transcripts (is_final=false) as they arrive.
endpointing- Silence duration (ms) before xAI fires utterance-final. Range 0–5000. Kept low (50ms default) because the framework's VAD is the primary turn detector; flush() injects synthetic silence to force utterance-final, and a low threshold means flush latency is short.
language- BCP-47 language code (e.g. "en", "fr"). Pass None to skip the param. xAI transcribes any supported language regardless of this — the value only enables Inverse Text Normalization (numbers, currencies in written form).
diarize- When true, each word in the response includes a `speaker` field.
multichannel- When true, transcribes each input channel independently. Requires interleaved multi-channel audio. When false, input is downmixed to mono.
channels- Number of input channels (only relevant with multichannel=True).
base_url- WebSocket endpoint URL.
Ancestors
- videosdk.agents.stt.stt.STT
- videosdk.agents.event_emitter.EventEmitter
- typing.Generic
Methods
async def aclose(self) ‑> None-
Expand source code
async def aclose(self) -> None:
    """Shut the plugin down and release the socket, listener and session.

    Marks the instance closed, sends a best-effort ``audio.done`` notice,
    stops the listener task, closes the WebSocket and HTTP session, and
    finally delegates to the base class.
    """
    self._closed = True

    # Best-effort end-of-stream notice; a failed send must not block cleanup.
    socket = self._ws
    if socket and not socket.closed:
        try:
            await socket.send_str(json.dumps({"type": "audio.done"}))
        except Exception:
            pass

    listener = self._ws_task
    if listener:
        listener.cancel()
        try:
            await listener
        except (asyncio.CancelledError, Exception):
            pass
        self._ws_task = None

    # Re-read the attribute: it may have changed while the listener was
    # being awaited.
    if self._ws and not self._ws.closed:
        await self._ws.close()
    self._ws = None

    if self._session:
        await self._session.close()
        self._session = None

    await super().aclose()
async def flush(self) ‑> None-
Expand source code
async def flush(self) -> None:
    """Force xAI to emit the current utterance-final transcript.

    We inject a short chunk of silence bytes — once the configured endpointing
    window elapses, xAI fires `transcript.partial` with speech_final=true,
    which we map to SpeechEventType.FINAL.

    Does nothing when the plugin is closed, the socket is not open, or the
    server has not signalled readiness. Send failures are logged, not raised.
    """
    if self._closed or not self._ws or self._ws.closed:
        return
    if not self._server_ready.is_set():
        return

    try:
        silence_ms = max(self.endpointing + 50, 100)
        bytes_per_sample = 2 if self.encoding == "pcm" else 1
        channel_count = self.channels if self.multichannel else 1
        # Size the buffer in whole frames. Truncating the byte count directly
        # can yield an odd length (e.g. 22050 Hz * 115 ms), which would split
        # a 16-bit PCM sample in half.
        n_frames = self.sample_rate * silence_ms // 1000
        n_bytes = n_frames * bytes_per_sample * channel_count
        silence = _SILENCE_BYTE[self.encoding] * n_bytes
        await self._ws.send_bytes(silence)
    except Exception as e:
        logger.warning(f"xAI STT flush failed: {e}")
We inject a short chunk of silence bytes — once the configured endpointing window elapses, xAI fires `transcript.partial` with speech_final=true, which we map to SpeechEventType.FINAL.
async def process_audio(self, audio_frames: bytes, language: Optional[str] = None, **kwargs: Any) ‑> None-
Expand source code
async def process_audio(
    self,
    audio_frames: bytes,
    language: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Process audio frames and stream them to xAI's WebSocket STT API.

    The connection is opened on first use; audio is only sent once the server
    has announced readiness with ``transcript.created``.

    Args:
        audio_frames: Raw audio bytes in the configured encoding.
        language: Accepted for interface compatibility; not used here.
        **kwargs: Ignored.

    Raises:
        Exception: Whatever ``_connect_ws`` raises when the WebSocket cannot
            be opened. Send failures are not raised; they are emitted as
            ``"error"`` events and the connection is reset.
    """
    if self._closed:
        return

    if not self._ws:
        await self._connect_ws()
        self._ws_task = asyncio.create_task(self._listen_for_responses())
        try:
            await asyncio.wait_for(self._server_ready.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            self.emit("error", "Timed out waiting for xAI transcript.created")
            # Tear the half-open socket down. Otherwise self._ws stays set and
            # every later call would skip this readiness wait and stream audio
            # into a session that never became ready.
            await self._reset_connection()
            return

    try:
        audio_bytes = audio_frames
        if (
            not self.multichannel
            and self.encoding == "pcm"
            and len(audio_frames) % 4 == 0
        ):
            # Average adjacent int16 samples pairwise.
            # NOTE(review): this assumes the incoming PCM is interleaved
            # stereo - confirm against the framework's audio input format.
            audio_data = np.frombuffer(audio_frames, dtype=np.int16)
            if audio_data.size > 0 and audio_data.size % 2 == 0:
                audio_data = (
                    audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16)
                )
            audio_bytes = audio_data.tobytes()

        await self._ws.send_bytes(audio_bytes)
    except Exception as e:
        logger.error(f"Error sending audio to xAI STT: {e}")
        self.emit("error", str(e))
        await self._reset_connection()
class XAITTS (*,
api_key: str | None = None,
voice: str = 'eve',
language: str = 'en',
codec: "Literal['pcm', 'mulaw']" = 'pcm',
sample_rate: int = 24000,
optimize_streaming_latency: int = 0,
text_normalization: bool = False,
base_url: str = 'wss://api.x.ai/v1/tts',
max_connection_age_sec: float = 300.0)-
Expand source code
class XAITTS(TTS):
    """Streaming text-to-speech plugin for xAI's bidirectional WebSocket API.

    Text is sent as ``text.delta`` messages followed by ``text.done``; a
    long-running receive task decodes ``audio.delta`` frames and pushes the
    raw audio into the framework's ``audio_track``.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        voice: str = "eve",
        language: str = "en",
        codec: Literal["pcm", "mulaw"] = "pcm",
        sample_rate: int = 24000,
        optimize_streaming_latency: int = 0,
        text_normalization: bool = False,
        base_url: str = XAI_TTS_BASE_URL,
        max_connection_age_sec: float = DEFAULT_CONNECTION_MAX_AGE_SEC,
    ) -> None:
        """Initialize the xAI TTS plugin.

        Args:
            api_key: xAI API key. Falls back to XAI_API_KEY env var.
            voice: Voice ID — one of "eve", "ara", "rex", "sal", "leo".
                Case-insensitive.
            language: BCP-47 language code (e.g. "en", "fr", "pt-BR") or "auto"
                for automatic language detection. Required by xAI.
            codec: Output codec. Restricted to "pcm" (signed 16-bit LE, default)
                or "mulaw" — both are raw, byte-streamable formats compatible
                with the framework's audio_track. mp3/wav/alaw are not exposed
                because they require a decoder before bytes can be played.
            sample_rate: Output sample rate in Hz. One of
                8000/16000/22050/24000/44100/48000. Defaults to 24000 (xAI's
                recommended rate).
            optimize_streaming_latency: 0 (default, best quality) or 1 (lower
                time-to-first-audio with minor quality tradeoff).
            text_normalization: When true, xAI normalizes written-form text
                (numbers, abbreviations, symbols) into spoken-form before
                synthesis.
            base_url: WebSocket endpoint URL.
            max_connection_age_sec: Maximum age, in seconds, of an open
                WebSocket before it is torn down and re-opened on next use.

        Raises:
            ValueError: If ``sample_rate``, ``codec``,
                ``optimize_streaming_latency`` or ``voice`` is unsupported, or
                if no API key is available.

        Speech tags:
            xAI supports inline expression tags ([pause], [long-pause],
            [laugh], [sigh], [breath], etc.) and wrapping style tags
            (<whisper>...</whisper>, <soft>, <loud>, <slow>, <fast>,
            <higher-pitch>, <lower-pitch>, <emphasis>, <singing>,
            <sing-song>, <laugh-speak>, <build-intensity>,
            <decrease-intensity>) directly inside the `text` you pass to
            `synthesize()`. No separate parameter is needed — the tags are
            sent verbatim as part of each text.delta message and parsed
            server-side.

            Example::

                await tts.synthesize(
                    "So I walked in and [pause] there it was. [laugh] Incredible!"
                )
                await tts.synthesize(
                    "I need to tell you something. "
                    "<whisper>It is a secret.</whisper> Pretty cool, right?"
                )

            Caveat for streaming input: when synthesize() receives an
            AsyncIterator[str] (e.g. LLM tokens), a single tag can be split
            across two chunks ("[pa", "use]") which xAI will not recognize.
            Tags only work reliably when an entire tag arrives within one
            text chunk.
        """
        if sample_rate not in SUPPORTED_SAMPLE_RATES:
            raise ValueError(
                f"sample_rate must be one of {sorted(SUPPORTED_SAMPLE_RATES)}, got {sample_rate}"
            )
        if codec not in SUPPORTED_CODECS:
            raise ValueError(
                f"codec must be one of {sorted(SUPPORTED_CODECS)} for raw PCM-compatible "
                f"output (got {codec}). mp3/wav/alaw are not supported because they "
                f"produce framed audio that the audio_track cannot consume directly."
            )
        if optimize_streaming_latency not in (0, 1):
            raise ValueError("optimize_streaming_latency must be 0 or 1")

        super().__init__(
            sample_rate=sample_rate,
            num_channels=XAI_TTS_NUM_CHANNELS,
            word_timestamps=False,
        )

        self._api_key = api_key or os.getenv("XAI_API_KEY")
        if not self._api_key:
            raise ValueError(
                "xAI API key must be provided either through the api_key parameter "
                "or the XAI_API_KEY environment variable"
            )

        voice_lower = voice.lower()
        if voice_lower not in SUPPORTED_VOICES:
            raise ValueError(
                f"voice must be one of {sorted(SUPPORTED_VOICES)}, got {voice}"
            )

        self._voice = voice_lower
        # Voice the live socket was opened with. The voice is only sent as a
        # connect-time query parameter, so a change requires a reconnect.
        self._connected_voice: Optional[str] = None
        self.language = language
        self.codec = codec
        self.optimize_streaming_latency = optimize_streaming_latency
        self.text_normalization = text_normalization
        self.base_url = base_url
        self._max_connection_age_sec = max_connection_age_sec

        self._ws_session: Optional[aiohttp.ClientSession] = None
        self._ws_connection: Optional[aiohttp.ClientWebSocketResponse] = None
        self._ws_connect_time: float = 0.0
        self._connection_lock = asyncio.Lock()
        self._synthesis_lock = asyncio.Lock()
        self._receive_task: Optional[asyncio.Task] = None
        # Resolved by the receive loop when "audio.done" arrives.
        self._current_done_future: Optional[asyncio.Future[None]] = None
        self._first_chunk_sent = False
        self._interrupted = False
        self._closed = False

    def reset_first_audio_tracking(self) -> None:
        """Reset the first-audio-byte tracking state for the next synthesis turn."""
        self._first_chunk_sent = False

    async def prewarm(self) -> None:
        """Pre-establish the xAI WebSocket so the first ``synthesize()`` call
        does not pay the TLS + auth + upgrade cost. Safe to call repeatedly."""
        try:
            await self._ensure_ws_connection()
        except Exception as e:
            logger.warning(f"xAI TTS prewarm failed (non-fatal): {e}")

    async def synthesize(
        self,
        text: AsyncIterator[Union[str, FlushMarker]] | str,
        voice_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Synthesize text to speech via xAI's bidirectional WebSocket API.

        ``FlushMarker`` segment markers in the input stream are silently
        dropped — xAI has no per-sentence flush primitive, and the server
        segments naturally on ``text.done``.

        Args:
            text: A complete string, or an async stream of text chunks.
            voice_id: Optional voice override (case-insensitive). It replaces
                the configured voice for this and later calls.
            **kwargs: Ignored.

        Raises:
            Exception: Connection, send and server errors are emitted as an
                ``"error"`` event and re-raised.
            asyncio.CancelledError: If :meth:`interrupt` cancels the
                in-flight synthesis.
        """
        # Created only once the lock is held; stays None on early returns so
        # the cleanup below never clears another call's future.
        done_future: Optional[asyncio.Future[None]] = None
        try:
            if not self.audio_track or not self.loop:
                self.emit("error", "Audio track or event loop not set")
                return

            voice_lower: Optional[str] = None
            if voice_id:
                voice_lower = voice_id.lower()
                if voice_lower not in SUPPORTED_VOICES:
                    self.emit(
                        "error",
                        f"voice_id must be one of {sorted(SUPPORTED_VOICES)}, got {voice_id}",
                    )
                    return

            async with self._synthesis_lock:
                # Apply the override only while holding the lock, so a queued
                # call cannot change the voice under a running synthesis.
                if voice_lower is not None:
                    self._voice = voice_lower
                self._interrupted = False
                self._first_chunk_sent = False

                await self._ensure_ws_connection()
                if not self._ws_connection:
                    raise RuntimeError("WebSocket connection is not available.")

                done_future = asyncio.get_running_loop().create_future()
                self._current_done_future = done_future

                async def _string_iterator(s: str) -> AsyncIterator[str]:
                    yield s

                text_iterator = (
                    _string_iterator(text) if isinstance(text, str) else text
                )
                send_task = asyncio.create_task(
                    self._send_task(text_iterator, done_future)
                )
                try:
                    await done_future
                finally:
                    # Let the sender finish before the lock is released.
                    if not send_task.done():
                        try:
                            await send_task
                        except Exception:
                            pass
        except Exception as e:
            self.emit("error", f"TTS synthesis failed: {e}")
            raise
        finally:
            # Clear only our own future. An unconditional reset would wipe the
            # future of a synthesis still running under the lock, and that
            # call would then never see its "audio.done".
            if done_future is not None and self._current_done_future is done_future:
                self._current_done_future = None

    async def _send_task(
        self,
        text_iterator: AsyncIterator[Union[str, FlushMarker]],
        done_future: asyncio.Future[None],
    ) -> None:
        """Send text.delta messages, then text.done at end of utterance."""
        has_sent = False
        try:
            async for chunk in text_iterator:
                if self._interrupted:
                    break
                if isinstance(chunk, FlushMarker):
                    # xAI has no per-sentence flush primitive — the server
                    # segments naturally on ``text.done`` at end-of-utterance.
                    continue
                if not chunk or not chunk.strip():
                    continue
                if not self._ws_connection or self._ws_connection.closed:
                    break
                payload = {"type": "text.delta", "delta": chunk}
                await self._ws_connection.send_str(json.dumps(payload))
                has_sent = True
        except Exception as e:
            if not done_future.done():
                done_future.set_exception(e)
            return
        finally:
            if (
                has_sent
                and not self._interrupted
                and self._ws_connection
                and not self._ws_connection.closed
            ):
                try:
                    await self._ws_connection.send_str(
                        json.dumps({"type": "text.done"})
                    )
                except Exception as e:
                    if not done_future.done():
                        done_future.set_exception(e)
            # Nothing was sent, so no "audio.done" will arrive: finish now.
            if not has_sent and not done_future.done():
                done_future.set_result(None)

    async def _receive_loop(self) -> None:
        """Long-running task: read audio.delta / audio.done / error frames."""
        try:
            while self._ws_connection and not self._ws_connection.closed:
                msg = await self._ws_connection.receive()
                if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
                    break
                if msg.type == aiohttp.WSMsgType.ERROR:
                    err = self._ws_connection.exception()
                    self._fail_pending(RuntimeError(f"xAI TTS WebSocket error: {err}"))
                    break
                if msg.type != aiohttp.WSMsgType.TEXT:
                    continue

                try:
                    data = json.loads(msg.data)
                except Exception as e:
                    logger.error(f"Failed to parse xAI TTS message: {e}")
                    continue

                event_type = data.get("type")
                if event_type == "audio.delta":
                    delta = data.get("delta")
                    if delta:
                        try:
                            await self._stream_audio(base64.b64decode(delta))
                        except Exception as e:
                            logger.error(f"Failed to decode/stream audio: {e}")
                elif event_type == "audio.done":
                    future = self._current_done_future
                    if future and not future.done():
                        future.set_result(None)
                elif event_type == "error":
                    message = data.get("message", "unknown error")
                    self._fail_pending(RuntimeError(f"xAI TTS error: {message}"))

            # The socket went away without this task being cancelled. Fail any
            # synthesis still waiting for "audio.done"; otherwise synthesize()
            # would block forever while holding the synthesis lock. No-op when
            # nothing is pending or the future is already resolved.
            self._fail_pending(
                RuntimeError("xAI TTS WebSocket closed before audio.done")
            )
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self._fail_pending(e)

    def _fail_pending(self, exc: BaseException) -> None:
        """Fail the in-flight synthesis future with ``exc``, if one is pending."""
        future = self._current_done_future
        if future and not future.done():
            future.set_exception(exc)

    async def _stream_audio(self, audio_chunk: bytes) -> None:
        """Push a chunk of raw audio bytes into the framework's audio_track."""
        if self._interrupted or not audio_chunk:
            return

        # Drop late audio that belongs to a cancelled synthesis: if the active
        # done_future has already resolved (cancelled or completed), the frame
        # is for a stale context and would bleed into the next turn.
        future = self._current_done_future
        if future is not None and future.done():
            return

        if not self._first_chunk_sent:
            self._first_chunk_sent = True
            if self._first_audio_callback:
                await self._first_audio_callback()

        if self.audio_track:
            await self.audio_track.add_new_bytes(audio_chunk)

    async def _ensure_ws_connection(self) -> None:
        """Open or re-open the WebSocket connection if needed.

        A live connection is reused unless it is older than the configured
        maximum age or was opened for a different voice.

        Raises:
            Exception: Connection failures are emitted as an ``"error"`` event
                and re-raised.
        """
        async with self._connection_lock:
            loop = asyncio.get_running_loop()
            if self._ws_connection and not self._ws_connection.closed:
                age = loop.time() - self._ws_connect_time
                voice_changed = self._connected_voice != self._voice
                if age < self._max_connection_age_sec and not voice_changed:
                    return
                if voice_changed:
                    logger.info(
                        "Reconnecting xAI WebSocket for voice change "
                        f"({self._connected_voice} -> {self._voice})"
                    )
                else:
                    logger.info(f"Refreshing xAI WebSocket (age={age:.1f}s)")

            # Tear down whatever is left of the previous connection.
            if self._receive_task and not self._receive_task.done():
                self._receive_task.cancel()
                try:
                    await self._receive_task
                except (asyncio.CancelledError, Exception):
                    pass
            self._receive_task = None

            if self._ws_connection:
                try:
                    await self._ws_connection.close()
                except Exception:
                    pass
            self._ws_connection = None

            if self._ws_session:
                try:
                    await self._ws_session.close()
                except Exception:
                    pass
            self._ws_session = None

            try:
                self._ws_session = aiohttp.ClientSession()
                voice = self._voice
                params = [
                    ("voice", voice),
                    ("language", self.language),
                    ("codec", self.codec),
                    ("sample_rate", str(self.sample_rate)),
                    ("optimize_streaming_latency", str(self.optimize_streaming_latency)),
                    ("text_normalization", str(self.text_normalization).lower()),
                ]
                ws_url = f"{self.base_url}?{urlencode(params)}"
                headers = {"Authorization": f"Bearer {self._api_key}"}

                self._ws_connection = await asyncio.wait_for(
                    self._ws_session.ws_connect(
                        ws_url, headers=headers, heartbeat=30.0
                    ),
                    timeout=5.0,
                )
                # Stamp after the handshake so the age reflects the live socket.
                self._ws_connect_time = loop.time()
                self._connected_voice = voice
                self._receive_task = asyncio.create_task(self._receive_loop())
            except aiohttp.WSServerHandshakeError as e:
                self.emit(
                    "error",
                    f"xAI TTS WebSocket handshake failed (status {e.status}): {e.message}",
                )
                raise
            except Exception as e:
                self.emit("error", f"Failed to establish xAI TTS WebSocket: {e}")
                raise

    async def interrupt(self) -> None:
        """Stop emitting audio for the current synthesis.

        Keeps the WebSocket open so the next turn does not pay reconnect cost;
        in-flight audio frames received after this point are dropped via the
        done-future filter in :meth:`_stream_audio`."""
        self._interrupted = True
        if self.audio_track:
            self.audio_track.interrupt()
        future = self._current_done_future
        if future and not future.done():
            future.cancel()

    async def aclose(self) -> None:
        """Gracefully clean up all resources."""
        await super().aclose()
        self._interrupted = True
        self._closed = True

        if self._receive_task and not self._receive_task.done():
            self._receive_task.cancel()
            try:
                await self._receive_task
            except (asyncio.CancelledError, Exception):
                pass
        self._receive_task = None

        if self._ws_connection and not self._ws_connection.closed:
            await self._ws_connection.close()
        self._ws_connection = None

        if self._ws_session and not self._ws_session.closed:
            await self._ws_session.close()
        self._ws_session = None
Initialize the xAI TTS plugin.
Args
api_key- xAI API key. Falls back to XAI_API_KEY env var.
voice- Voice ID — one of "eve", "ara", "rex", "sal", "leo". Case-insensitive.
language- BCP-47 language code (e.g. "en", "fr", "pt-BR") or "auto" for automatic language detection. Required by xAI.
codec- Output codec. Restricted to "pcm" (signed 16-bit LE, default) or "mulaw" — both are raw, byte-streamable formats compatible with the framework's audio_track. mp3/wav/alaw are not exposed because they require a decoder before bytes can be played.
sample_rate- Output sample rate in Hz. One of 8000/16000/22050/24000/44100/48000. Defaults to 24000 (xAI's recommended rate).
optimize_streaming_latency- 0 (default, best quality) or 1 (lower time-to-first-audio with minor quality tradeoff).
text_normalization- When true, xAI normalizes written-form text (numbers, abbreviations, symbols) into spoken-form before synthesis.
base_url- WebSocket endpoint URL.
Speech tags: xAI supports inline expression tags ([pause], [long-pause], [laugh], [sigh], [breath], etc.) and wrapping style tags (<whisper>...</whisper>, <soft>, <loud>, <slow>, <fast>, <higher-pitch>, <lower-pitch>, <emphasis>, <singing>, <sing-song>, <laugh-speak>, <build-intensity>, <decrease-intensity>) directly inside the `text` you pass to `synthesize()`. No separate parameter is needed — the tags are sent verbatim as part of each text.delta message and parsed server-side.
Example::
    await tts.synthesize(
        "So I walked in and [pause] there it was. [laugh] Incredible!"
    )
    await tts.synthesize(
        "I need to tell you something. "
        "<whisper>It is a secret.</whisper> Pretty cool, right?"
    )
Caveat for streaming input: when synthesize() receives an AsyncIterator[str] (e.g. LLM tokens), a single tag can be split across two chunks ("[pa", "use]") which xAI will not recognize. Tags only work reliably when an entire tag arrives within one text chunk.
Ancestors
- videosdk.agents.tts.tts.TTS
- videosdk.agents.event_emitter.EventEmitter
- typing.Generic
Methods
async def aclose(self) ‑> None-
Expand source code
async def aclose(self) -> None:
    """Shut the plugin down: stop the receive task, then close the WebSocket
    and its HTTP session. The base-class cleanup runs first."""
    await super().aclose()
    self._interrupted = True
    self._closed = True

    receiver = self._receive_task
    if receiver and not receiver.done():
        receiver.cancel()
        try:
            await receiver
        except (asyncio.CancelledError, Exception):
            # Cancellation is expected here; anything else is moot at shutdown.
            pass
    self._receive_task = None

    connection = self._ws_connection
    if connection and not connection.closed:
        await connection.close()
    self._ws_connection = None

    session = self._ws_session
    if session and not session.closed:
        await session.close()
    self._ws_session = None
async def interrupt(self) ‑> None-
Expand source code
async def interrupt(self) -> None:
    """Halt audio output for the synthesis currently in progress.

    The WebSocket is deliberately left open so the next turn avoids a
    reconnect. Audio frames that still arrive afterwards are discarded by the
    done-future filter in :meth:`_stream_audio`.
    """
    self._interrupted = True

    track = self.audio_track
    if track:
        track.interrupt()

    # Cancel the pending future, if any, to end the in-flight synthesis wait.
    pending = self._current_done_future
    if not pending or pending.done():
        return
    pending.cancel()
`_stream_audio`.
async def prewarm(self) ‑> None-
Expand source code
async def prewarm(self) -> None:
    """Open the xAI WebSocket ahead of time.

    Doing this early means the first ``synthesize()`` call does not pay the
    TLS + auth + upgrade cost. Failures are logged as warnings and never
    raised, so the method is safe to call repeatedly.
    """
    try:
        await self._ensure_ws_connection()
    except Exception as exc:
        # Non-fatal by design: synthesize() will try to connect again.
        logger.warning(f"xAI TTS prewarm failed (non-fatal): {exc}")
`synthesize()` call does not pay the TLS + auth + upgrade cost. Safe to call repeatedly.
def reset_first_audio_tracking(self) ‑> None-
Expand source code
def reset_first_audio_tracking(self) -> None:
    """Clear the "first audio chunk already sent" flag.

    After this call the next audio chunk is again treated as the first one of
    its synthesis turn.
    """
    # The flag is read by _stream_audio to decide whether to fire the
    # first-audio callback.
    self._first_chunk_sent = False
async def synthesize(self,
text: AsyncIterator[Union[str, FlushMarker]] | str,
voice_id: Optional[str] = None,
**kwargs: Any) ‑> None-
Expand source code
async def synthesize(
    self,
    text: AsyncIterator[Union[str, FlushMarker]] | str,
    voice_id: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Synthesize text to speech via xAI's bidirectional WebSocket API.

    ``FlushMarker`` segment markers in the input stream are silently dropped —
    xAI has no per-sentence flush primitive, and the server segments naturally
    on ``text.done``.

    Args:
        text: A complete string, or an async stream of text chunks.
        voice_id: Optional voice override (case-insensitive). It replaces the
            configured voice for this and later calls.
        **kwargs: Ignored.

    Raises:
        Exception: Connection, send and server errors are emitted as an
            ``"error"`` event and re-raised.
        asyncio.CancelledError: If :meth:`interrupt` cancels the in-flight
            synthesis.
    """
    # Created only once the lock is held; stays None on early returns so the
    # cleanup below never clears another call's future.
    done_future: Optional[asyncio.Future[None]] = None
    try:
        if not self.audio_track or not self.loop:
            self.emit("error", "Audio track or event loop not set")
            return

        voice_lower: Optional[str] = None
        if voice_id:
            voice_lower = voice_id.lower()
            if voice_lower not in SUPPORTED_VOICES:
                self.emit(
                    "error",
                    f"voice_id must be one of {sorted(SUPPORTED_VOICES)}, got {voice_id}",
                )
                return

        async with self._synthesis_lock:
            # Apply the override only while holding the lock, so a queued call
            # cannot change the voice under a running synthesis.
            if voice_lower is not None:
                self._voice = voice_lower
            self._interrupted = False
            self._first_chunk_sent = False

            await self._ensure_ws_connection()
            if not self._ws_connection:
                raise RuntimeError("WebSocket connection is not available.")

            done_future = asyncio.get_running_loop().create_future()
            self._current_done_future = done_future

            async def _string_iterator(s: str) -> AsyncIterator[str]:
                yield s

            text_iterator = (
                _string_iterator(text) if isinstance(text, str) else text
            )
            send_task = asyncio.create_task(
                self._send_task(text_iterator, done_future)
            )
            try:
                await done_future
            finally:
                # Let the sender finish before the lock is released.
                if not send_task.done():
                    try:
                        await send_task
                    except Exception:
                        pass
    except Exception as e:
        self.emit("error", f"TTS synthesis failed: {e}")
        raise
    finally:
        # Clear only our own future. An unconditional reset would wipe the
        # future of a synthesis still running under the lock, and that call
        # would then never see its "audio.done".
        if done_future is not None and self._current_done_future is done_future:
            self._current_done_future = None
`FlushMarker` segment markers in the input stream are silently dropped — xAI has no per-sentence flush primitive, and the server segments naturally on `text.done`.
class XAITurnDetection (type: "Literal['server_vad'] | None" = 'server_vad',
threshold: float = 0.5,
prefix_padding_ms: int = 300,
silence_duration_ms: int = 200)-
Expand source code
@dataclass
class XAITurnDetection:
    """Turn-detection settings for the xAI Realtime API.

    NOTE(review): the field meanings below are inferred from the field names
    only - confirm against the xAI Realtime API documentation and the code
    that serializes this object.
    """

    # Detection mode: "server_vad" or None.
    # NOTE(review): None presumably disables server-side detection - confirm.
    type: Literal["server_vad"] | None = "server_vad"
    # NOTE(review): presumably the VAD activation threshold - confirm range.
    threshold: float = 0.5
    # NOTE(review): presumably audio kept before detected speech, in ms.
    prefix_padding_ms: int = 300
    # NOTE(review): presumably the silence that ends a turn, in ms.
    silence_duration_ms: int = 200
Instance variables
var prefix_padding_ms : int
var silence_duration_ms : int
var threshold : float
var type : Literal['server_vad'] | None