Module videosdk.plugins.sarvamai.stt
Classes
class SarvamAISTT (*,
api_key: str | None = None,
model: str = 'saarika:v2.5',
language: str = 'en-IN',
input_sample_rate: int = 48000,
output_sample_rate: int = 16000)
class SarvamAISTT(STT):
    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        language: str = "en-IN",
        input_sample_rate: int = 48000,
        output_sample_rate: int = 16000,
    ) -> None:
        """Initialize the SarvamAI STT plugin with WebSocket support.

        Args:
            api_key: SarvamAI API key
            model: The model to use (default: "saarika:v2.5")
            language: The language code (default: "en-IN")
            input_sample_rate: Input sample rate (default: 48000)
            output_sample_rate: Output sample rate (default: 16000)
        """
        super().__init__()

        if not SCIPY_AVAILABLE:
            raise ImportError("scipy is not installed. Please install it with 'pip install scipy'")

        self.api_key = api_key or os.getenv("SARVAMAI_API_KEY")
        if not self.api_key:
            raise ValueError("Sarvam AI API key must be provided either through api_key parameter or SARVAMAI_API_KEY environment variable")

        self.model = model
        self.language = language
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate

        # WebSocket related
        self._session: aiohttp.ClientSession | None = None
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._ws_task: asyncio.Task | None = None
        self._is_speaking = False
        self._lock = asyncio.Lock()

    async def _ensure_websocket(self) -> aiohttp.ClientWebSocketResponse:
        """Ensure WebSocket connection is established."""
        if self._ws is None or self._ws.closed:
            await self._connect_websocket()
        if self._ws is None:
            raise RuntimeError("Failed to establish WebSocket connection")
        return self._ws

    async def _connect_websocket(self) -> None:
        """Connect to Sarvam WebSocket API."""
        if self._session is None:
            self._session = aiohttp.ClientSession()

        ws_url = f"{SARVAM_STT_STREAMING_URL}?language-code={self.language}&model={self.model}&vad_signals=true"
        headers = {"api-subscription-key": self.api_key}

        self._ws = await self._session.ws_connect(ws_url, headers=headers)
        self._ws_task = asyncio.create_task(self._process_messages())

    async def process_audio(self, audio_frames: bytes, **kwargs: Any) -> None:
        """Process audio frames and send to WebSocket."""
        try:
            resampled_audio = self._resample_audio(audio_frames)
            audio_array = np.frombuffer(resampled_audio, dtype=np.int16)
            base64_audio = base64.b64encode(audio_array.tobytes()).decode('utf-8')

            audio_message = {
                "audio": {
                    "data": base64_audio,
                    "encoding": "audio/wav",
                    "sample_rate": self.output_sample_rate,
                }
            }

            ws = await self._ensure_websocket()
            await ws.send_str(json.dumps(audio_message))
        except Exception as e:
            logger.error(f"[SarvamAISTT] Error processing audio: {e}")

    async def _process_messages(self) -> None:
        """Process incoming WebSocket messages."""
        if self._ws is None:
            return
        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    await self._handle_message(data)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"[SarvamAISTT] WebSocket error: {self._ws.exception()}")
                    break
        except Exception as e:
            logger.error(f"[SarvamAISTT] Error in message processing: {e}")

    async def _handle_message(self, data: dict) -> None:
        """Handle different message types from Sarvam API."""
        msg_type = data.get("type")

        if msg_type == "data":
            transcript_data = data.get("data", {})
            transcript_text = transcript_data.get("transcript", "")
            language = transcript_data.get("language_code", self.language)

            if transcript_text and self._transcript_callback:
                event = STTResponse(
                    event_type=SpeechEventType.FINAL,
                    data=SpeechData(
                        text=transcript_text,
                        language=language,
                        confidence=1.0
                    )
                )
                await self._transcript_callback(event)

        elif msg_type == "events":
            event_data = data.get("data", {})
            signal_type = event_data.get("signal_type")

            if signal_type == "START_SPEECH":
                if not self._is_speaking:
                    self._is_speaking = True
                    global_event_emitter.emit("speech_started")
            elif signal_type == "END_SPEECH":
                if self._is_speaking:
                    flush_message = {"type": "flush"}
                    await self._ws.send_str(json.dumps(flush_message))
                    self._is_speaking = False
                    global_event_emitter.emit("speech_stopped")

        elif msg_type == "error":
            error_info = data.get("error", "Unknown error")
            logger.error(f"[SarvamAISTT] API error: {error_info}")

    def _resample_audio(self, audio_bytes: bytes) -> bytes:
        """Resample audio from input sample rate to output sample rate and convert to mono."""
        try:
            if not audio_bytes:
                return b''

            raw_audio = np.frombuffer(audio_bytes, dtype=np.int16)
            if raw_audio.size == 0:
                return b''

            if raw_audio.size % 2 == 0:
                stereo_audio = raw_audio.reshape(-1, 2)
                mono_audio = stereo_audio.astype(np.float32).mean(axis=1)
            else:
                mono_audio = raw_audio.astype(np.float32)

            if self.input_sample_rate != self.output_sample_rate:
                output_length = int(len(mono_audio) * self.output_sample_rate / self.input_sample_rate)
                resampled_data = signal.resample(mono_audio, output_length)
            else:
                resampled_data = mono_audio

            resampled_data = np.clip(resampled_data, -32767, 32767)
            return resampled_data.astype(np.int16).tobytes()
        except Exception as e:
            logger.error(f"Error resampling audio: {e}")
            return b''

    async def aclose(self) -> None:
        """Close WebSocket connection and cleanup."""
        if self._ws_task and not self._ws_task.done():
            self._ws_task.cancel()
            try:
                await self._ws_task
            except asyncio.CancelledError:
                pass

        if self._ws and not self._ws.closed:
            await self._ws.close()

        if self._session and not self._session.closed:
            await self._session.close()

        await super().aclose()

Base class for Speech-to-Text implementations
Initialize the SarvamAI STT plugin with WebSocket support.
Args
api_key : SarvamAI API key; if omitted, it is read from the SARVAMAI_API_KEY environment variable
model : The model to use (default: "saarika:v2.5")
language : The language code (default: "en-IN")
input_sample_rate : Sample rate of the audio passed to process_audio, in Hz (default: 48000)
output_sample_rate : Sample rate of the resampled audio sent to the Sarvam API, in Hz (default: 16000)
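For reference, a minimal instantiation sketch follows; it is not part of the generated documentation. The key shown is a placeholder and the Hindi language code is an assumed example.

from videosdk.plugins.sarvamai.stt import SarvamAISTT

stt = SarvamAISTT(
    api_key="your-sarvam-api-key",   # placeholder; omit to fall back to SARVAMAI_API_KEY
    model="saarika:v2.5",
    language="hi-IN",                # assumed example; check Sarvam's docs for supported codes
    input_sample_rate=48000,         # rate of the PCM you pass to process_audio
    output_sample_rate=16000,        # rate of the resampled audio sent to the API
)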
Ancestors
- videosdk.agents.stt.stt.STT
- videosdk.agents.event_emitter.EventEmitter
- typing.Generic
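To make the streaming flow in the source above concrete, here is a minimal usage sketch rather than an official example. It assumes the inherited STT base class delivers transcripts through the _transcript_callback coroutine referenced in _handle_message; assigning that attribute directly is for illustration only, and the actual registration API may differ in your agents version.

import asyncio

from videosdk.plugins.sarvamai.stt import SarvamAISTT


async def main() -> None:
    stt = SarvamAISTT(language="en-IN")          # reads SARVAMAI_API_KEY from the environment

    async def on_transcript(event) -> None:
        # Final transcripts arrive as STTResponse objects (see _handle_message above).
        print("transcript:", event.data.text)

    # Assumption: the base STT class stores the delivery coroutine in
    # `_transcript_callback`; direct assignment here is purely illustrative.
    stt._transcript_callback = on_transcript

    silence = b"\x00\x00" * 1920                 # 20 ms of silent 16-bit stereo PCM at 48 kHz
    try:
        await stt.process_audio(silence)         # resampled to 16 kHz mono and streamed
        await asyncio.sleep(2)                   # give the WebSocket reader time to deliver results
    finally:
        await stt.aclose()


asyncio.run(main())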
Methods
async def aclose(self) -> None
async def aclose(self) -> None:
    """Close WebSocket connection and cleanup."""
    if self._ws_task and not self._ws_task.done():
        self._ws_task.cancel()
        try:
            await self._ws_task
        except asyncio.CancelledError:
            pass

    if self._ws and not self._ws.closed:
        await self._ws.close()

    if self._session and not self._session.closed:
        await self._session.close()

    await super().aclose()

Close WebSocket connection and cleanup.
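A hedged cleanup sketch, assuming the key is available via SARVAMAI_API_KEY: because the WebSocket and aiohttp session are created lazily and torn down only in aclose(), the close is placed in a finally block.

import asyncio

from videosdk.plugins.sarvamai.stt import SarvamAISTT


async def transcribe_once(frame: bytes) -> None:
    stt = SarvamAISTT()                      # SARVAMAI_API_KEY must be set in the environment
    try:
        await stt.process_audio(frame)       # opens the WebSocket lazily on first use
        await asyncio.sleep(1)               # let the background reader task drain responses
    finally:
        await stt.aclose()                   # cancels the reader task, closes the WebSocket and session


asyncio.run(transcribe_once(b"\x00\x00" * 4800))  # 50 ms of silent 16-bit stereo PCM at 48 kHz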
async def process_audio(self, audio_frames: bytes, **kwargs: Any) -> None
async def process_audio(self, audio_frames: bytes, **kwargs: Any) -> None:
    """Process audio frames and send to WebSocket."""
    try:
        resampled_audio = self._resample_audio(audio_frames)
        audio_array = np.frombuffer(resampled_audio, dtype=np.int16)
        base64_audio = base64.b64encode(audio_array.tobytes()).decode('utf-8')

        audio_message = {
            "audio": {
                "data": base64_audio,
                "encoding": "audio/wav",
                "sample_rate": self.output_sample_rate,
            }
        }

        ws = await self._ensure_websocket()
        await ws.send_str(json.dumps(audio_message))
    except Exception as e:
        logger.error(f"[SarvamAISTT] Error processing audio: {e}")

Process audio frames and send to WebSocket.
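A hedged streaming sketch; the file name and format are assumptions. Raw 16-bit stereo PCM recorded at 48 kHz is fed in roughly 20 ms chunks, and process_audio resamples each chunk to 16 kHz mono, base64-encodes it, and sends it over the WebSocket.

import asyncio
from pathlib import Path

from videosdk.plugins.sarvamai.stt import SarvamAISTT


async def stream_pcm_file(path: str) -> None:
    # Hypothetical input: headerless 16-bit stereo PCM at 48 kHz.
    stt = SarvamAISTT(input_sample_rate=48000, output_sample_rate=16000)
    chunk_size = 3840                              # ~20 ms of 16-bit stereo audio at 48 kHz
    data = Path(path).read_bytes()
    try:
        for start in range(0, len(data), chunk_size):
            await stt.process_audio(data[start:start + chunk_size])
        await asyncio.sleep(2)                     # wait for trailing transcripts
    finally:
        await stt.aclose()


asyncio.run(stream_pcm_file("speech_48k_stereo.pcm"))  # hypothetical file name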