Module videosdk.plugins.xai.stt
Classes
class XAISTT (*,
api_key: str | None = None,
sample_rate: int = 48000,
encoding: "Literal['pcm', 'mulaw', 'alaw']" = 'pcm',
interim_results: bool = True,
endpointing: int = 50,
language: str | None = 'en',
diarize: bool = False,
multichannel: bool = False,
channels: int = 1,
base_url: str = 'wss://api.x.ai/v1/stt')-
Expand source code
class XAISTT(BaseSTT): def __init__( self, *, api_key: str | None = None, sample_rate: int = 48000, encoding: Literal["pcm", "mulaw", "alaw"] = "pcm", interim_results: bool = True, endpointing: int = 50, language: str | None = "en", diarize: bool = False, multichannel: bool = False, channels: int = 1, base_url: str = XAI_STT_BASE_URL, ) -> None: """Initialize the xAI STT plugin. Args: api_key: xAI API key. Falls back to XAI_API_KEY env var. sample_rate: Audio sample rate in Hz. xAI accepts 8000/16000/22050/24000/44100/48000. Defaults to 48000 to match the framework's native input rate. encoding: Raw audio encoding. One of "pcm" (signed 16-bit LE), "mulaw", "alaw". interim_results: Emit partial transcripts (is_final=false) as they arrive. endpointing: Silence duration (ms) before xAI fires utterance-final. Range 0–5000. Kept low (50ms default) because the framework's VAD is the primary turn detector; flush() injects synthetic silence to force utterance-final, and a low threshold means flush latency is short. language: BCP-47 language code (e.g. "en", "fr"). Pass None to skip the param. xAI transcribes any supported language regardless of this — the value only enables Inverse Text Normalization (numbers, currencies in written form). diarize: When true, each word in the response includes a `speaker` field. multichannel: When true, transcribes each input channel independently. Requires interleaved multi-channel audio. When false, input is downmixed to mono. channels: Number of input channels (only relevant with multichannel=True). base_url: WebSocket endpoint URL. """ super().__init__() self.api_key = api_key or os.getenv("XAI_API_KEY") if not self.api_key: raise ValueError( "xAI API key must be provided either through the api_key parameter " "or the XAI_API_KEY environment variable" ) if sample_rate not in SUPPORTED_SAMPLE_RATES: raise ValueError( f"sample_rate must be one of {sorted(SUPPORTED_SAMPLE_RATES)}, got {sample_rate}" ) if encoding not in SUPPORTED_ENCODINGS: raise ValueError( f"encoding must be one of {sorted(SUPPORTED_ENCODINGS)}, got {encoding}" ) if not 0 <= endpointing <= 5000: raise ValueError(f"endpointing must be in [0, 5000], got {endpointing}") self.sample_rate = sample_rate self.encoding = encoding self.interim_results = interim_results self.endpointing = endpointing self.language = language self.diarize = diarize self.multichannel = multichannel self.channels = channels self.base_url = base_url self._session: Optional[aiohttp.ClientSession] = None self._ws: Optional[aiohttp.ClientWebSocketResponse] = None self._ws_task: Optional[asyncio.Task] = None self._server_ready: asyncio.Event = asyncio.Event() self._closed = False async def process_audio( self, audio_frames: bytes, language: Optional[str] = None, **kwargs: Any, ) -> None: """Process audio frames and stream them to xAI's WebSocket STT API.""" if self._closed: return if not self._ws: await self._connect_ws() self._ws_task = asyncio.create_task(self._listen_for_responses()) try: await asyncio.wait_for(self._server_ready.wait(), timeout=5.0) except asyncio.TimeoutError: self.emit("error", "Timed out waiting for xAI transcript.created") return try: audio_bytes = audio_frames if ( not self.multichannel and self.encoding == "pcm" and len(audio_frames) % 4 == 0 ): audio_data = np.frombuffer(audio_frames, dtype=np.int16) if audio_data.size > 0 and audio_data.size % 2 == 0: audio_data = ( audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16) ) audio_bytes = audio_data.tobytes() await self._ws.send_bytes(audio_bytes) except Exception as e: logger.error(f"Error sending audio to xAI STT: {e}") self.emit("error", str(e)) await self._reset_connection() async def _connect_ws(self) -> None: """Open the WebSocket connection to xAI's STT endpoint.""" if not self._session: self._session = aiohttp.ClientSession() params: list[tuple[str, str]] = [ ("sample_rate", str(self.sample_rate)), ("encoding", self.encoding), ("interim_results", str(self.interim_results).lower()), ("endpointing", str(self.endpointing)), ("diarize", str(self.diarize).lower()), ("multichannel", str(self.multichannel).lower()), ("channels", str(1 if not self.multichannel else self.channels)), ] if self.language: params.append(("language", self.language)) ws_url = f"{self.base_url}?{urlencode(params)}" headers = {"Authorization": f"Bearer {self.api_key}"} self._server_ready = asyncio.Event() try: self._ws = await self._session.ws_connect( ws_url, headers=headers, heartbeat=30.0 ) except Exception as e: logger.error(f"Error connecting to xAI STT WebSocket: {e}") raise async def _listen_for_responses(self) -> None: """Background task that reads transcript events from the WebSocket.""" if not self._ws: return try: async for msg in self._ws: if msg.type == aiohttp.WSMsgType.TEXT: try: data = msg.json() except Exception as e: logger.error(f"Failed to parse xAI STT message: {e}") continue event_type = data.get("type") if event_type == "transcript.created": self._server_ready.set() elif event_type == "transcript.partial": response = self._handle_partial(data) if response and self._transcript_callback: await self._transcript_callback(response) elif event_type == "transcript.done": response = self._handle_partial(data) if response and self._transcript_callback: await self._transcript_callback(response) elif event_type == "error": message = data.get("message", "unknown error") logger.error(f"xAI STT error event: {message}") self.emit("error", message) elif msg.type == aiohttp.WSMsgType.ERROR: err = self._ws.exception() logger.error(f"xAI STT WebSocket error: {err}") self.emit("error", f"WebSocket error: {err}") break elif msg.type == aiohttp.WSMsgType.CLOSED: break except asyncio.CancelledError: raise except Exception as e: logger.error(f"Error in xAI STT listener: {e}") self.emit("error", f"Error in WebSocket listener: {e}") finally: if self._ws and not self._ws.closed: await self._ws.close() self._ws = None self._server_ready.clear() def _handle_partial(self, event: dict) -> Optional[STTResponse]: """Map an xAI transcript event to an STTResponse.""" text = event.get("text", "") if not text: return None is_final = bool(event.get("is_final", False)) speech_final = bool(event.get("speech_final", False)) event_is_done = event.get("type") == "transcript.done" event_type = ( SpeechEventType.FINAL if (is_final and speech_final) or event_is_done else SpeechEventType.INTERIM ) words = event.get("words") or [] start = event.get("start", 0.0) or 0.0 duration = event.get("duration", 0.0) or 0.0 if words: start_time = float(words[0].get("start", start)) end_time = float(words[-1].get("end", start + duration)) else: start_time = float(start) end_time = float(start) + float(duration) return STTResponse( event_type=event_type, data=SpeechData( text=text, language=self.language, confidence=0.0, start_time=start_time, end_time=end_time, duration=float(duration), ), metadata={ "is_final": is_final, "speech_final": speech_final, "channel_index": event.get("channel_index"), }, ) async def flush(self) -> None: """Force xAI to emit the current utterance-final transcript. We inject a short chunk of silence bytes — once the configured endpointing window elapses, xAI fires `transcript.partial` with speech_final=true, which we map to SpeechEventType.FINAL. """ if self._closed or not self._ws or self._ws.closed: return if not self._server_ready.is_set(): return try: silence_ms = max(self.endpointing + 50, 100) bytes_per_sample = 2 if self.encoding == "pcm" else 1 channel_count = self.channels if self.multichannel else 1 n_bytes = int( self.sample_rate * bytes_per_sample * channel_count * silence_ms / 1000 ) silence = _SILENCE_BYTE[self.encoding] * n_bytes await self._ws.send_bytes(silence) except Exception as e: logger.warning(f"xAI STT flush failed: {e}") async def _reset_connection(self) -> None: if self._ws: try: await self._ws.close() except Exception: pass self._ws = None if self._ws_task: self._ws_task.cancel() try: await self._ws_task except (asyncio.CancelledError, Exception): pass self._ws_task = None self._server_ready.clear() async def aclose(self) -> None: """Cleanup resources.""" self._closed = True if self._ws and not self._ws.closed: try: await self._ws.send_str(json.dumps({"type": "audio.done"})) except Exception: pass if self._ws_task: self._ws_task.cancel() try: await self._ws_task except (asyncio.CancelledError, Exception): pass self._ws_task = None if self._ws and not self._ws.closed: await self._ws.close() self._ws = None if self._session: await self._session.close() self._session = None await super().aclose()Base class for Speech-to-Text implementations
Initialize the xAI STT plugin.
Args
api_key- xAI API key. Falls back to XAI_API_KEY env var.
sample_rate- Audio sample rate in Hz. xAI accepts 8000/16000/22050/24000/44100/48000. Defaults to 48000 to match the framework's native input rate.
encoding- Raw audio encoding. One of "pcm" (signed 16-bit LE), "mulaw", "alaw".
interim_results- Emit partial transcripts (is_final=false) as they arrive.
endpointing- Silence duration (ms) before xAI fires utterance-final. Range 0–5000. Kept low (50ms default) because the framework's VAD is the primary turn detector; flush() injects synthetic silence to force utterance-final, and a low threshold means flush latency is short.
language- BCP-47 language code (e.g. "en", "fr"). Pass None to skip the param. xAI transcribes any supported language regardless of this — the value only enables Inverse Text Normalization (numbers, currencies in written form).
diarize- When true, each word in the response includes a
speakerfield. multichannel- When true, transcribes each input channel independently. Requires interleaved multi-channel audio. When false, input is downmixed to mono.
channels- Number of input channels (only relevant with multichannel=True).
base_url- WebSocket endpoint URL.
Ancestors
- videosdk.agents.stt.stt.STT
- videosdk.agents.event_emitter.EventEmitter
- typing.Generic
Methods
async def aclose(self) ‑> None-
Expand source code
async def aclose(self) -> None: """Cleanup resources.""" self._closed = True if self._ws and not self._ws.closed: try: await self._ws.send_str(json.dumps({"type": "audio.done"})) except Exception: pass if self._ws_task: self._ws_task.cancel() try: await self._ws_task except (asyncio.CancelledError, Exception): pass self._ws_task = None if self._ws and not self._ws.closed: await self._ws.close() self._ws = None if self._session: await self._session.close() self._session = None await super().aclose()Cleanup resources.
async def flush(self) ‑> None-
Expand source code
async def flush(self) -> None: """Force xAI to emit the current utterance-final transcript. We inject a short chunk of silence bytes — once the configured endpointing window elapses, xAI fires `transcript.partial` with speech_final=true, which we map to SpeechEventType.FINAL. """ if self._closed or not self._ws or self._ws.closed: return if not self._server_ready.is_set(): return try: silence_ms = max(self.endpointing + 50, 100) bytes_per_sample = 2 if self.encoding == "pcm" else 1 channel_count = self.channels if self.multichannel else 1 n_bytes = int( self.sample_rate * bytes_per_sample * channel_count * silence_ms / 1000 ) silence = _SILENCE_BYTE[self.encoding] * n_bytes await self._ws.send_bytes(silence) except Exception as e: logger.warning(f"xAI STT flush failed: {e}")Force xAI to emit the current utterance-final transcript.
We inject a short chunk of silence bytes — once the configured endpointing window elapses, xAI fires
transcript.partialwith speech_final=true, which we map to SpeechEventType.FINAL. async def process_audio(self, audio_frames: bytes, language: Optional[str] = None, **kwargs: Any) ‑> None-
Expand source code
async def process_audio( self, audio_frames: bytes, language: Optional[str] = None, **kwargs: Any, ) -> None: """Process audio frames and stream them to xAI's WebSocket STT API.""" if self._closed: return if not self._ws: await self._connect_ws() self._ws_task = asyncio.create_task(self._listen_for_responses()) try: await asyncio.wait_for(self._server_ready.wait(), timeout=5.0) except asyncio.TimeoutError: self.emit("error", "Timed out waiting for xAI transcript.created") return try: audio_bytes = audio_frames if ( not self.multichannel and self.encoding == "pcm" and len(audio_frames) % 4 == 0 ): audio_data = np.frombuffer(audio_frames, dtype=np.int16) if audio_data.size > 0 and audio_data.size % 2 == 0: audio_data = ( audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16) ) audio_bytes = audio_data.tobytes() await self._ws.send_bytes(audio_bytes) except Exception as e: logger.error(f"Error sending audio to xAI STT: {e}") self.emit("error", str(e)) await self._reset_connection()Process audio frames and stream them to xAI's WebSocket STT API.