Package videosdk.plugins.sarvamai

Sub-modules

videosdk.plugins.sarvamai.llm
videosdk.plugins.sarvamai.stt
videosdk.plugins.sarvamai.tts
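
A typical voice-agent setup pulls all three plugins from the package root. A minimal sketch (the surrounding agent/pipeline wiring is omitted; each plugin reads SARVAMAI_API_KEY from the environment when no key is passed):

from videosdk.plugins.sarvamai import SarvamAILLM, SarvamAISTT, SarvamAITTS

llm = SarvamAILLM()   # chat completions (default model "sarvam-m")
stt = SarvamAISTT()   # streaming speech-to-text (default model "saarika:v2.5")
tts = SarvamAITTS()   # text-to-speech (default model "bulbul:v2")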

Classes

class SarvamAILLM (*,
api_key: str | None = None,
model: str = 'sarvam-m',
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_completion_tokens: int | None = None)
Expand source code
class SarvamAILLM(LLM):
    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        temperature: float = 0.7,
        tool_choice: ToolChoice = "auto",
        max_completion_tokens: int | None = None,
    ) -> None:
        """Initialize the SarvamAI LLM plugin.

        Args:
            api_key (Optional[str], optional): Sarvam AI API key. Falls back to the SARVAMAI_API_KEY environment variable.
            model (str): The model to use for the LLM plugin. Defaults to "sarvam-m".
            temperature (float): The temperature to use for the LLM plugin. Defaults to 0.7.
            tool_choice (ToolChoice): The tool choice to use for the LLM plugin. Defaults to "auto".
            max_completion_tokens (Optional[int], optional): The maximum completion tokens to use for the LLM plugin. Defaults to None.
        """
        super().__init__()
        self.api_key = api_key or os.getenv("SARVAMAI_API_KEY")
        if not self.api_key:
            raise ValueError("Sarvam AI API key must be provided either through api_key parameter or SARVAMAI_API_KEY environment variable")
        
        self.model = model
        self.temperature = temperature
        self.tool_choice = tool_choice
        self.max_completion_tokens = max_completion_tokens
        self._cancelled = False
        
        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=15.0, read=30.0, write=5.0, pool=5.0),
            follow_redirects=True,
        )

    async def chat(
        self,
        messages: ChatContext,
        tools: list[FunctionTool] | None = None,
        **kwargs: Any
    ) -> AsyncIterator[LLMResponse]:
        self._cancelled = False
        
        def _extract_text_content(content: Union[str, List[ChatContent]]) -> str:
            if isinstance(content, str):
                return content
            text_parts = [part for part in content if isinstance(part, str)]
            return "\n".join(text_parts)

        system_prompt = None
        message_items = list(messages.items)
        if (
            message_items
            and isinstance(message_items[0], ChatMessage)
            and message_items[0].role == ChatRole.SYSTEM
        ):
            system_prompt = {
                "role": "system",
                "content": _extract_text_content(message_items.pop(0).content),
            }

        # Normalize history to strictly alternating roles: merge consecutive
        # user turns and keep only the latest of consecutive same-role turns.
        cleaned_messages = []
        last_role = None
        for msg in message_items:
            if not isinstance(msg, ChatMessage):
                continue

            current_role_str = msg.role.value
            
            if not cleaned_messages and current_role_str == 'assistant':
                continue

            text_content = _extract_text_content(msg.content)
            if not text_content.strip():
                continue

            if last_role == 'user' and current_role_str == 'user':
                cleaned_messages[-1]['content'] += ' ' + text_content
                continue
            
            if last_role == current_role_str:
                cleaned_messages.pop()

            cleaned_messages.append({"role": current_role_str, "content": text_content})
            last_role = current_role_str

        final_messages = [system_prompt] + cleaned_messages if system_prompt else cleaned_messages
        
        try:
            payload = {
                "model": self.model,
                "messages": final_messages,
                "temperature": self.temperature,
                "stream": True,
            }

            if self.max_completion_tokens:
                payload['max_tokens'] = self.max_completion_tokens
            
            payload.update(kwargs)
            
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}"
            }

            async with self._client.stream("POST", SARVAM_CHAT_COMPLETION_URL, json=payload, headers=headers) as response:
                response.raise_for_status()
                
                current_content = ""
                async for line in response.aiter_lines():
                    if self._cancelled:
                        break
                        
                    if not line.startswith("data:"):
                        continue
                    data_str = line[len("data:"):].strip()
                    if not data_str:
                        continue
                    if data_str == "[DONE]":
                        break
                    
                    chunk = json.loads(data_str)
                    delta = chunk.get("choices", [{}])[0].get("delta", {})
                    if "content" in delta and delta["content"] is not None:
                        content_chunk = delta["content"]
                        current_content += content_chunk
                        yield LLMResponse(content=current_content, role=ChatRole.ASSISTANT)

        except httpx.HTTPStatusError as e:
            if not self._cancelled:
                error_message = f"Sarvam AI API error: {e.response.status_code}"
                try:
                    error_body = await e.response.aread()
                    error_text = error_body.decode()
                    error_message += f" - {error_text}"
                except Exception:
                    pass
                self.emit("error", Exception(error_message))
            raise
        except Exception as e:
            if not self._cancelled:
                traceback.print_exc()
                self.emit("error", e)
            raise

    async def cancel_current_generation(self) -> None:
        self._cancelled = True

    async def aclose(self) -> None:
        await self.cancel_current_generation()
        if self._client:
            await self._client.aclose()
        await super().aclose()

Base class for LLM implementations.

Initialize the SarvamAI LLM plugin.

Args

api_key : Optional[str], optional
Sarvam AI API key. Falls back to the SARVAMAI_API_KEY environment variable.
model : str
The model to use for the LLM plugin. Defaults to "sarvam-m".
temperature : float
The temperature to use for the LLM plugin. Defaults to 0.7.
tool_choice : ToolChoice
The tool choice to use for the LLM plugin. Defaults to "auto".
max_completion_tokens : Optional[int], optional
The maximum completion tokens to use for the LLM plugin. Defaults to None.
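
A minimal construction sketch (the key may instead come from the SARVAMAI_API_KEY environment variable, as the source above shows; the key value here is a placeholder):

from videosdk.plugins.sarvamai import SarvamAILLM

llm = SarvamAILLM(
    api_key="sk-...",            # placeholder
    model="sarvam-m",
    temperature=0.4,
    max_completion_tokens=512,
)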

Ancestors

  • videosdk.agents.llm.llm.LLM
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    await self.cancel_current_generation()
    if self._client:
        await self._client.aclose()
    await super().aclose()

Cleanup resources.

async def cancel_current_generation(self) ‑> None
Expand source code
async def cancel_current_generation(self) -> None:
    self._cancelled = True

Cancel the current LLM generation if active. This implementation sets an internal cancellation flag that the streaming loop checks on each chunk; no exception is raised.

async def chat(self,
messages: ChatContext,
tools: list[FunctionTool] | None = None,
**kwargs: Any) ‑> AsyncIterator[videosdk.agents.llm.llm.LLMResponse]
Expand source code
async def chat(
    self,
    messages: ChatContext,
    tools: list[FunctionTool] | None = None,
    **kwargs: Any
) -> AsyncIterator[LLMResponse]:
    self._cancelled = False
    
    def _extract_text_content(content: Union[str, List[ChatContent]]) -> str:
        if isinstance(content, str):
            return content
        text_parts = [part for part in content if isinstance(part, str)]
        return "\n".join(text_parts)

    system_prompt = None
    message_items = list(messages.items)
    if (
        message_items
        and isinstance(message_items[0], ChatMessage)
        and message_items[0].role == ChatRole.SYSTEM
    ):
        system_prompt = {
            "role": "system",
            "content": _extract_text_content(message_items.pop(0).content),
        }

    # Normalize history to strictly alternating roles: merge consecutive
    # user turns and keep only the latest of consecutive same-role turns.
    cleaned_messages = []
    last_role = None
    for msg in message_items:
        if not isinstance(msg, ChatMessage):
            continue

        current_role_str = msg.role.value
        
        if not cleaned_messages and current_role_str == 'assistant':
            continue

        text_content = _extract_text_content(msg.content)
        if not text_content.strip():
            continue

        if last_role == 'user' and current_role_str == 'user':
            cleaned_messages[-1]['content'] += ' ' + text_content
            continue
        
        if last_role == current_role_str:
            cleaned_messages.pop()

        cleaned_messages.append({"role": current_role_str, "content": text_content})
        last_role = current_role_str

    final_messages = [system_prompt] + cleaned_messages if system_prompt else cleaned_messages
    
    try:
        payload = {
            "model": self.model,
            "messages": final_messages,
            "temperature": self.temperature,
            "stream": True,
        }

        if self.max_completion_tokens:
            payload['max_tokens'] = self.max_completion_tokens
        
        payload.update(kwargs)
        
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        async with self._client.stream("POST", SARVAM_CHAT_COMPLETION_URL, json=payload, headers=headers) as response:
            response.raise_for_status()
            
            current_content = ""
            async for line in response.aiter_lines():
                if self._cancelled:
                    break
                    
                if not line.startswith("data:"):
                    continue
                data_str = line[len("data:"):].strip()
                if not data_str:
                    continue
                if data_str == "[DONE]":
                    break
                
                chunk = json.loads(data_str)
                delta = chunk.get("choices", [{}])[0].get("delta", {})
                if "content" in delta and delta["content"] is not None:
                    content_chunk = delta["content"]
                    current_content += content_chunk
                    yield LLMResponse(content=current_content, role=ChatRole.ASSISTANT)

    except httpx.HTTPStatusError as e:
        if not self._cancelled:
            error_message = f"Sarvam AI API error: {e.response.status_code}"
            try:
                error_body = await e.response.aread()
                error_text = error_body.decode()
                error_message += f" - {error_text}"
            except Exception:
                pass
            self.emit("error", Exception(error_message))
        raise
    except Exception as e:
        if not self._cancelled:
            traceback.print_exc()
            self.emit("error", e)
        raise

Main method to interact with the LLM.

Args

messages : ChatContext
The conversation context containing message history.
tools : list[FunctionTool] | None, optional
List of available function tools for the LLM to use.
**kwargs : Any
Additional arguments specific to the LLM provider implementation.

Returns

AsyncIterator[LLMResponse]
An async iterator yielding LLMResponse objects as they're generated.

Raises

httpx.HTTPStatusError
If the Sarvam AI API returns an error status; other exceptions are re-raised after an "error" event is emitted.
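
Because chat() yields cumulative text (each LLMResponse.content holds everything generated so far), a consumer that wants only the new delta can diff against what it has already seen. A minimal sketch, assuming chat_ctx is an already-populated ChatContext from the agent session:

printed = ""
async for resp in llm.chat(messages=chat_ctx):
    print(resp.content[len(printed):], end="", flush=True)  # print only the new delta
    printed = resp.content
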
class SarvamAISTT (*,
api_key: str | None = None,
model: str = 'saarika:v2.5',
language: str = 'en-IN',
input_sample_rate: int = 48000,
output_sample_rate: int = 16000)
Expand source code
class SarvamAISTT(STT):
    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        language: str = "en-IN",
        input_sample_rate: int = 48000,
        output_sample_rate: int = 16000,
    ) -> None:
        """Initialize the SarvamAI STT plugin with WebSocket support.

        Args:
            api_key: Sarvam AI API key. Falls back to the SARVAMAI_API_KEY environment variable.
            model: The model to use (default: "saarika:v2.5")
            language: The language code (default: "en-IN")
            input_sample_rate: Input sample rate (default: 48000)
            output_sample_rate: Output sample rate (default: 16000)
        """
        super().__init__()
        if not SCIPY_AVAILABLE:
            raise ImportError("scipy is not installed. Please install it with 'pip install scipy'")

        self.api_key = api_key or os.getenv("SARVAMAI_API_KEY")
        if not self.api_key:
            raise ValueError("Sarvam AI API key must be provided either through api_key parameter or SARVAMAI_API_KEY environment variable")

        self.model = model
        self.language = language
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate

        # WebSocket related
        self._session: aiohttp.ClientSession | None = None
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._ws_task: asyncio.Task | None = None
        self._is_speaking = False
        self._lock = asyncio.Lock()

    async def _ensure_websocket(self) -> aiohttp.ClientWebSocketResponse:
        """Ensure WebSocket connection is established."""
        if self._ws is None or self._ws.closed:
            await self._connect_websocket()
        
        if self._ws is None:
            raise RuntimeError("Failed to establish WebSocket connection")
        
        return self._ws

    async def _connect_websocket(self) -> None:
        """Connect to Sarvam WebSocket API."""
        if self._session is None:
            self._session = aiohttp.ClientSession()

        ws_url = f"{SARVAM_STT_STREAMING_URL}?language-code={self.language}&model={self.model}&vad_signals=true"
        
        headers = {"api-subscription-key": self.api_key}
        
        self._ws = await self._session.ws_connect(ws_url, headers=headers)
        
        self._ws_task = asyncio.create_task(self._process_messages())


    async def process_audio(self, audio_frames: bytes, **kwargs: Any) -> None:
        """Process audio frames and send to WebSocket."""
        try:
            resampled_audio = self._resample_audio(audio_frames)
            
            audio_array = np.frombuffer(resampled_audio, dtype=np.int16)
            
            base64_audio = base64.b64encode(audio_array.tobytes()).decode('utf-8')
            
            audio_message = {
                "audio": {
                    "data": base64_audio,
                    "encoding": "audio/wav",
                    "sample_rate": self.output_sample_rate,
                }
            }
            
            ws = await self._ensure_websocket()
            await ws.send_str(json.dumps(audio_message))
            
        except Exception as e:
            logger.error(f"[SarvamAISTT] Error processing audio: {e}")

    async def _process_messages(self) -> None:
        """Process incoming WebSocket messages."""
        if self._ws is None:
            return
            
        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    await self._handle_message(data)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"[SarvamAISTT] WebSocket error: {self._ws.exception()}")
                    break
        except Exception as e:
            logger.error(f"[SarvamAISTT] Error in message processing: {e}")

    async def _handle_message(self, data: dict) -> None:
        """Handle different message types from Sarvam API."""
        msg_type = data.get("type")
        
        if msg_type == "data":
            transcript_data = data.get("data", {})
            transcript_text = transcript_data.get("transcript", "")
            language = transcript_data.get("language_code", self.language)
            
            if transcript_text and self._transcript_callback:
                event = STTResponse(
                    event_type=SpeechEventType.FINAL,
                    data=SpeechData(
                        text=transcript_text,
                        language=language,
                        confidence=1.0
                    )
                )
                await self._transcript_callback(event)
                
        elif msg_type == "events":
            event_data = data.get("data", {})
            signal_type = event_data.get("signal_type")
            
            if signal_type == "START_SPEECH":
                if not self._is_speaking:
                    self._is_speaking = True
                    global_event_emitter.emit("speech_started")
                    
                    
            elif signal_type == "END_SPEECH":
                if self._is_speaking:    
                    flush_message = {"type": "flush"}
                    await self._ws.send_str(json.dumps(flush_message))
                    self._is_speaking = False
                    global_event_emitter.emit("speech_stopped")
                    
        elif msg_type == "error":
            error_info = data.get("error", "Unknown error")
            logger.error(f"[SarvamAISTT] API error: {error_info}")

    def _resample_audio(self, audio_bytes: bytes) -> bytes:
        """Resample audio from input sample rate to output sample rate and convert to mono."""
        try:
            if not audio_bytes:
                return b''

            raw_audio = np.frombuffer(audio_bytes, dtype=np.int16)
            if raw_audio.size == 0:
                return b''

            # Heuristic: an even sample count is treated as interleaved stereo and averaged to mono.
            if raw_audio.size % 2 == 0:
                stereo_audio = raw_audio.reshape(-1, 2)
                mono_audio = stereo_audio.astype(np.float32).mean(axis=1)
            else:
                mono_audio = raw_audio.astype(np.float32)

            if self.input_sample_rate != self.output_sample_rate:
                output_length = int(len(mono_audio) * self.output_sample_rate / self.input_sample_rate)
                resampled_data = signal.resample(mono_audio, output_length)
            else:
                resampled_data = mono_audio

            resampled_data = np.clip(resampled_data, -32767, 32767)
            return resampled_data.astype(np.int16).tobytes()

        except Exception as e:
            logger.error(f"Error resampling audio: {e}")
            return b''

    async def aclose(self) -> None:
        """Close WebSocket connection and cleanup."""
        if self._ws_task and not self._ws_task.done():
            self._ws_task.cancel()
            try:
                await self._ws_task
            except asyncio.CancelledError:
                pass
        
        if self._ws and not self._ws.closed:
            await self._ws.close()
        
        if self._session and not self._session.closed:
            await self._session.close()
        
        await super().aclose()

Base class for Speech-to-Text implementations.

Initialize the SarvamAI STT plugin with WebSocket support.

Args

api_key
Sarvam AI API key. Falls back to the SARVAMAI_API_KEY environment variable.
model
The model to use (default: "saarika:v2.5")
language
The language code (default: "en-IN")
input_sample_rate
Input sample rate (default: 48000)
output_sample_rate
Output sample rate (default: 16000)
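
A minimal audio-path sketch (the transcript callback is wired up by the surrounding agent framework; the frame source here is hypothetical):

from videosdk.plugins.sarvamai import SarvamAISTT

async def transcribe(frames):
    # frames: 48 kHz, 16-bit interleaved-stereo PCM chunks
    stt = SarvamAISTT(language="en-IN")   # key read from SARVAMAI_API_KEY
    try:
        for frame in frames:
            await stt.process_audio(frame)
    finally:
        await stt.aclose()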

Ancestors

  • videosdk.agents.stt.stt.STT
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Close WebSocket connection and cleanup."""
    if self._ws_task and not self._ws_task.done():
        self._ws_task.cancel()
        try:
            await self._ws_task
        except asyncio.CancelledError:
            pass
    
    if self._ws and not self._ws.closed:
        await self._ws.close()
    
    if self._session and not self._session.closed:
        await self._session.close()
    
    await super().aclose()

Close WebSocket connection and cleanup.

async def process_audio(self, audio_frames: bytes, **kwargs: Any) ‑> None
Expand source code
async def process_audio(self, audio_frames: bytes, **kwargs: Any) -> None:
    """Process audio frames and send to WebSocket."""
    try:
        resampled_audio = self._resample_audio(audio_frames)
        
        audio_array = np.frombuffer(resampled_audio, dtype=np.int16)
        
        base64_audio = base64.b64encode(audio_array.tobytes()).decode('utf-8')
        
        audio_message = {
            "audio": {
                "data": base64_audio,
                "encoding": "audio/wav",
                "sample_rate": self.output_sample_rate,
            }
        }
        
        ws = await self._ensure_websocket()
        await ws.send_str(json.dumps(audio_message))
        
    except Exception as e:
        logger.error(f"[SarvamAISTT] Error processing audio: {e}")

Process audio frames and send to WebSocket.
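
For reference, the resampling performed inside process_audio is equivalent to this standalone sketch (assuming the default 48 kHz stereo input and 16 kHz mono output):

import numpy as np
from scipy import signal

def to_mono_16k(pcm: bytes) -> bytes:
    raw = np.frombuffer(pcm, dtype=np.int16)
    mono = raw.reshape(-1, 2).astype(np.float32).mean(axis=1)  # interleaved stereo -> mono
    out_len = int(len(mono) * 16000 / 48000)                   # 3:1 decimation
    out = signal.resample(mono, out_len)
    return np.clip(out, -32767, 32767).astype(np.int16).tobytes()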

class SarvamAITTS (*,
api_key: str | None = None,
model: str = 'bulbul:v2',
language: str = 'en-IN',
speaker: str = 'anushka',
enable_streaming: bool = True,
sample_rate: int = 24000,
output_audio_codec: str = 'linear16')
Expand source code
class SarvamAITTS(TTS):
    """
    A unified Sarvam.ai Text-to-Speech (TTS) plugin that supports both real-time
    streaming via WebSockets and batch synthesis via HTTP. This version is optimized
    for robust, long-running sessions and responsive non-streaming playback.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        language: str = DEFAULT_LANGUAGE,
        speaker: str = DEFAULT_SPEAKER,
        enable_streaming: bool = True,
        sample_rate: int = SARVAM_SAMPLE_RATE,
        output_audio_codec: str = "linear16",
    ) -> None:
        """
        Initializes the SarvamAITTS plugin.

        Args:
            api_key (Optional[str]): The Sarvam.ai API key. If not provided, it will
                be read from the SARVAMAI_API_KEY environment variable.
            model (str): The TTS model to use.
            language (str): The target language code (e.g., "en-IN").
            speaker (str): The desired speaker for the voice.
            enable_streaming (bool): If True, uses WebSockets for low-latency streaming.
                If False, uses HTTP for batch synthesis.
            sample_rate (int): The audio sample rate.
            output_audio_codec (str): The desired output audio codec.
        """
        super().__init__(sample_rate=sample_rate, num_channels=SARVAM_CHANNELS)

        self.api_key = api_key or os.getenv("SARVAMAI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Sarvam AI API key required. Provide either:\n"
                "1. api_key parameter, OR\n"
                "2. SARVAMAI_API_KEY environment variable"
            )

        self.model = model
        self.language = language
        self.speaker = speaker
        self.enable_streaming = enable_streaming
        self.output_audio_codec = output_audio_codec
        self.base_url_ws = SARVAM_TTS_URL_STREAMING
        self.base_url_http = SARVAM_TTS_URL_HTTP

        self._ws_session: aiohttp.ClientSession | None = None
        self._ws_connection: aiohttp.ClientWebSocketResponse | None = None
        self._receive_task: asyncio.Task | None = None
        self._connection_lock = asyncio.Lock()

        self._http_client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=15.0, read=30.0, write=5.0, pool=5.0),
            follow_redirects=True,
        )

        self._interrupted = False
        self._first_chunk_sent = False
        self.ws_count = 0

    def reset_first_audio_tracking(self) -> None:
        """Resets tracking for the first audio chunk latency."""
        self._first_chunk_sent = False

    async def synthesize(
        self,
        text: AsyncIterator[str] | str,
        *,
        language: Optional[str] = None,
        speaker: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Main entry point for synthesizing audio. Routes to either streaming
        or batch (HTTP) mode. The HTTP mode uses a smart buffering strategy
        to balance responsiveness and audio quality.
        """
        try:
            if not self.audio_track or not self.loop:
                logger.error("error", "Audio track or event loop not initialized")
                return

            self.language = language or self.language
            self.speaker = speaker or self.speaker
            self._interrupted = False
            self.reset_first_audio_tracking()

            if self.enable_streaming:
                await self._stream_synthesis(text)
            else:
                if isinstance(text, str):
                    if text.strip():
                        await self._http_synthesis(text)
                else:
                    chunk_buffer = []
                    HTTP_CHUNK_BUFFER_SIZE = 4
                    LLM_PAUSE_TIMEOUT = 1.0 

                    text_iterator = text.__aiter__()
                    while not self._interrupted:
                        try:
                            chunk = await asyncio.wait_for(text_iterator.__anext__(), timeout=LLM_PAUSE_TIMEOUT)
                            
                            if chunk and chunk.strip():
                                chunk_buffer.append(chunk)
                            
                            if len(chunk_buffer) >= HTTP_CHUNK_BUFFER_SIZE:
                                combined_text = "".join(chunk_buffer)
                                await self._http_synthesis(combined_text)
                                chunk_buffer.clear()

                        except asyncio.TimeoutError:
                            if chunk_buffer:
                                combined_text = "".join(chunk_buffer)
                                await self._http_synthesis(combined_text)
                                chunk_buffer.clear()
                        
                        except StopAsyncIteration:
                            if chunk_buffer:
                                combined_text = "".join(chunk_buffer)
                                await self._http_synthesis(combined_text)
                            break

        except Exception as e:
            logger.error("error", f"Sarvam TTS synthesis failed: {e}")

    async def _stream_synthesis(self, text: AsyncIterator[str] | str) -> None:
        """
        Manages the WebSocket synthesis workflow, reusing the persistent
        connection when it is open and reconnecting otherwise.
        """
        try:
            await self._ensure_ws_connection()

            if isinstance(text, str):
                async def _str_iter():
                    yield text
                text_iter = _str_iter()
            else:
                text_iter = text

            await self._send_text_chunks(text_iter)
        except Exception as e:
            logger.error("error", f"WebSocket streaming failed: {e}. Trying HTTP fallback.")
            try:
                full_text = ""
                if isinstance(text, str):
                    full_text = text
                else:
                    async for chunk in text:
                        full_text += chunk
                
                if full_text.strip():
                    await self._http_synthesis(full_text.strip())
            except Exception as http_e:
                logger.error("error", f"HTTP fallback also failed: {http_e}")

    async def _ensure_ws_connection(self) -> None:
        """Establishes and maintains a persistent WebSocket connection."""
        async with self._connection_lock:
            if self._ws_connection and not self._ws_connection.closed:
                return
            try:
                self._ws_session = aiohttp.ClientSession()
                headers = {"Api-Subscription-Key": self.api_key}
                self._ws_connection = await asyncio.wait_for(
                    self._ws_session.ws_connect(
                        self.base_url_ws, headers=headers, heartbeat=20
                    ),
                    timeout=10.0,
                )
                self._receive_task = asyncio.create_task(self._recv_loop())
                await self._send_initial_config()
                self.ws_count += 1
                logger.info(f"WebSocket connections opened so far: {self.ws_count}")
            except Exception as e:
                logger.error(f"Failed to connect to WebSocket: {e}")
                raise

    async def _send_initial_config(self) -> None:
        """Sends the initial configuration message to the WebSocket server."""
        config_payload = {
            "type": "config",
            "data": {
                "target_language_code": self.language,
                "speaker": self.speaker,
                "speech_sample_rate": str(self.sample_rate),
                "output_audio_codec": self.output_audio_codec,
            },
        }
        if self._ws_connection:
            await self._ws_connection.send_str(json.dumps(config_payload))

    async def _send_text_chunks(self, text_iterator: AsyncIterator[str]):
        """Sends text to the WebSocket, chunking it by word count or time."""
        if not self._ws_connection:
            raise ConnectionError("WebSocket is not connected.")
        try:
            buffer = []
            MIN_WORDS, MAX_DELAY = 4, 1.0
            last_send_time = asyncio.get_event_loop().time()

            async for text_chunk in text_iterator:
                if self._interrupted:
                    break

                words = re.findall(r"\b[\w'-]+\b", text_chunk)
                if not words:
                    continue

                buffer.extend(words)
                now = asyncio.get_event_loop().time()

                if len(buffer) >= MIN_WORDS or (now - last_send_time > MAX_DELAY):
                    combined_text = " ".join(buffer).strip()
                    if combined_text:
                        payload = {"type": "text", "data": {"text": combined_text}}
                        await self._ws_connection.send_str(json.dumps(payload))
                    buffer.clear()
                    last_send_time = now

            if buffer and not self._interrupted:
                combined_text = " ".join(buffer).strip()
                if combined_text:
                    payload = {"type": "text", "data": {"text": combined_text}}
                    await self._ws_connection.send_str(json.dumps(payload))
                    if not self._first_chunk_sent and hasattr(self, '_first_audio_callback') and self._first_audio_callback:
                        self._first_chunk_sent = True
                        asyncio.create_task(self._first_audio_callback())

            if not self._interrupted:
                await self._ws_connection.send_str(json.dumps({"type": "flush"}))
        except Exception as e:
            logger.error("error", f"Failed to send text chunks via WebSocket: {e}")

    async def _recv_loop(self):
        """Continuously listens for and processes incoming WebSocket messages."""
        try:
            while self._ws_connection and not self._ws_connection.closed:
                msg = await self._ws_connection.receive()
                if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
                    break
                if msg.type != aiohttp.WSMsgType.TEXT:
                    continue

                data = json.loads(msg.data)
                msg_type = data.get("type")

                if msg_type == "audio":

                    if not self._first_chunk_sent and hasattr(self, '_first_audio_callback') and self._first_audio_callback:
                        self._first_chunk_sent = True
                        asyncio.create_task(self._first_audio_callback())
                    
                    await self._handle_audio_data(data.get("data"))
                
                elif msg_type == "event" and data.get("data", {}).get("event_type") == "final":
                    logger.error("done", "TTS completed")
                
                elif msg_type == "error":
                    error_msg = data.get("data", {}).get("message", "Unknown WS error")
                    logger.error("error", f"Sarvam WebSocket error: {error_msg}")
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error("error", f"WebSocket receive loop error: {e}")

    async def _handle_audio_data(self, audio_data: Optional[dict[str, Any]]):
        """Processes audio data received from the WebSocket."""
        if not audio_data or self._interrupted:
            return

        audio_b64 = audio_data.get("audio")
        if not audio_b64:
            return

        try:
            audio_bytes = base64.b64decode(audio_b64)
            if not self.audio_track:
                return
            await self.audio_track.add_new_bytes(audio_bytes)
        except Exception as e:
            logger.error(f"Failed to process WebSocket audio: {e}")


    async def _reinitialize_http_client(self):
        """Safely closes the current httpx client and creates a new one."""
        logger.info("Re-initializing HTTP client.")
        if self._http_client and not self._http_client.is_closed:
            await self._http_client.aclose()
        self._http_client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=15.0, read=30.0, write=5.0, pool=5.0),
            follow_redirects=True,
        )

    async def _http_synthesis(self, text: str) -> None:
        """Performs TTS synthesis using HTTP with a retry for connection errors."""
        payload = { "text": text, "target_language_code": self.language, "speaker": self.speaker, "speech_sample_rate": str(self.sample_rate), "model": self.model, "output_audio_codec": self.output_audio_codec }
        headers = { "Content-Type": "application/json", "api-subscription-key": self.api_key }
        max_attempts = 2
        for attempt in range(max_attempts):
            try:
                if self._http_client.is_closed:
                    await self._reinitialize_http_client()
                response = await self._http_client.post(self.base_url_http, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()
                if not data.get("audios"):
                    logger.error("error", f"No audio data in HTTP response: {data}")
                    return
                audio_b64 = data["audios"][0]
                audio_bytes = base64.b64decode(audio_b64)
                if not self._first_chunk_sent and self._first_audio_callback:
                    self._first_chunk_sent = True
                    await self._first_audio_callback()

                await self._stream_http_audio(audio_bytes)
                return
            except httpx.HTTPStatusError as e:
                logger.error(f"HTTP error: {e.response.status_code} - {e.response.text}")
                raise e
            except (httpx.NetworkError, httpx.ConnectError, httpx.ReadTimeout) as e:
                logger.warning(f"HTTP connection error on attempt {attempt + 1}: {e}")
                if attempt < max_attempts - 1:
                    await self._reinitialize_http_client()
                    continue
                else:
                    logger.error("error", f"HTTP synthesis failed after {max_attempts} connection attempts.")
                    raise e
            except Exception as e:
                logger.error("error", f"An unexpected HTTP synthesis error occurred: {e}")
                raise e

    async def _stream_http_audio(self, audio_bytes: bytes) -> None:
        """
        Streams decoded HTTP audio bytes to the audio track by sending two
        20ms chunks at a time (a 40ms block) to ensure real-time playback.
        """
        single_chunk_size = int(self.sample_rate * self.num_channels * 2 * 20 / 1000)
        
        block_size = single_chunk_size * 2
        
        raw_audio = self._remove_wav_header(audio_bytes)

        for i in range(0, len(raw_audio), block_size):
            if self._interrupted:
                break
            
            block = raw_audio[i : i + block_size]

            if 0 < len(block) < block_size:
                # Zero-pad a short final block so playback receives fixed-size frames.
                block += b"\x00" * (block_size - len(block))

            if self.audio_track:
                asyncio.create_task(self.audio_track.add_new_bytes(block))

    def _remove_wav_header(self, audio_bytes: bytes) -> bytes:
        """Removes the WAV header if present."""
        if audio_bytes.startswith(b"RIFF"):
            data_pos = audio_bytes.find(b"data")
            if data_pos != -1:
                return audio_bytes[data_pos + 8:]
        return audio_bytes

    async def interrupt(self) -> None:
        """Interrupts any ongoing TTS synthesis."""
        self._interrupted = True
        if self.audio_track:
            self.audio_track.interrupt()
        
    async def _close_ws_resources(self) -> None:
        """Helper to clean up all WebSocket-related resources."""
        if self._receive_task and not self._receive_task.done():
            self._receive_task.cancel()
        if self._ws_connection and not self._ws_connection.closed:
            await self._ws_connection.close()
        if self._ws_session and not self._ws_session.closed:
            await self._ws_session.close()
        self._receive_task = self._ws_connection = self._ws_session = None

    async def aclose(self) -> None:
        """Gracefully closes all connections and cleans up resources."""
        self._interrupted = True
        await self._close_ws_resources()
        if self._http_client and not self._http_client.is_closed:
            await self._http_client.aclose()
        await super().aclose()

A unified Sarvam.ai Text-to-Speech (TTS) plugin that supports both real-time streaming via WebSockets and batch synthesis via HTTP. This version is optimized for robust, long-running sessions and responsive non-streaming playback.

Initializes the SarvamAITTS plugin.

Args

api_key : Optional[str]
The Sarvam.ai API key. If not provided, it will be read from the SARVAMAI_API_KEY environment variable.
model : str
The TTS model to use.
language : str
The target language code (e.g., "en-IN").
speaker : str
The desired speaker for the voice.
enable_streaming : bool
If True, uses WebSockets for low-latency streaming. If False, uses HTTP for batch synthesis.
sample_rate : int
The audio sample rate.
output_audio_codec : str
The desired output audio codec.
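
A minimal synthesis sketch, to be run inside the agent's async context (synthesize() requires that the pipeline has already attached an audio track and event loop to the plugin):

from videosdk.plugins.sarvamai import SarvamAITTS

tts = SarvamAITTS(
    language="hi-IN",
    speaker="anushka",
    enable_streaming=True,   # WebSocket streaming; set False for batch HTTP
)
await tts.synthesize("Namaste! How can I help you today?")
await tts.aclose()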

Ancestors

  • videosdk.agents.tts.tts.TTS
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Gracefully closes all connections and cleans up resources."""
    self._interrupted = True
    await self._close_ws_resources()
    if self._http_client and not self._http_client.is_closed:
        await self._http_client.aclose()
    await super().aclose()

Gracefully closes all connections and cleans up resources.

async def interrupt(self) ‑> None
Expand source code
async def interrupt(self) -> None:
    """Interrupts any ongoing TTS synthesis."""
    self._interrupted = True
    if self.audio_track:
        self.audio_track.interrupt()

Interrupts any ongoing TTS synthesis.
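
A typical barge-in pattern pairs interrupt() with the STT plugin's speech events (a hypothetical wiring sketch; it assumes global_event_emitter, seen in the STT source above, exposes an on() subscription method):

import asyncio

def on_speech_started():
    asyncio.create_task(tts.interrupt())   # stop playback when the user starts talking

global_event_emitter.on("speech_started", on_speech_started)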

def reset_first_audio_tracking(self) ‑> None
Expand source code
def reset_first_audio_tracking(self) -> None:
    """Resets tracking for the first audio chunk latency."""
    self._first_chunk_sent = False

Resets tracking for the first audio chunk latency.

async def synthesize(self,
text: AsyncIterator[str] | str,
*,
language: Optional[str] = None,
speaker: Optional[str] = None,
**kwargs: Any) ‑> None
Expand source code
async def synthesize(
    self,
    text: AsyncIterator[str] | str,
    *,
    language: Optional[str] = None,
    speaker: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """
    Main entry point for synthesizing audio. Routes to either streaming
    or batch (HTTP) mode. The HTTP mode uses a smart buffering strategy
    to balance responsiveness and audio quality.
    """
    try:
        if not self.audio_track or not self.loop:
            logger.error("error", "Audio track or event loop not initialized")
            return

        self.language = language or self.language
        self.speaker = speaker or self.speaker
        self._interrupted = False
        self.reset_first_audio_tracking()

        if self.enable_streaming:
            await self._stream_synthesis(text)
        else:
            if isinstance(text, str):
                if text.strip():
                    await self._http_synthesis(text)
            else:
                chunk_buffer = []
                HTTP_CHUNK_BUFFER_SIZE = 4
                LLM_PAUSE_TIMEOUT = 1.0 

                text_iterator = text.__aiter__()
                while not self._interrupted:
                    try:
                        chunk = await asyncio.wait_for(text_iterator.__anext__(), timeout=LLM_PAUSE_TIMEOUT)
                        
                        if chunk and chunk.strip():
                            chunk_buffer.append(chunk)
                        
                        if len(chunk_buffer) >= HTTP_CHUNK_BUFFER_SIZE:
                            combined_text = "".join(chunk_buffer)
                            await self._http_synthesis(combined_text)
                            chunk_buffer.clear()

                    except asyncio.TimeoutError:
                        if chunk_buffer:
                            combined_text = "".join(chunk_buffer)
                            await self._http_synthesis(combined_text)
                            chunk_buffer.clear()
                    
                    except StopAsyncIteration:
                        if chunk_buffer:
                            combined_text = "".join(chunk_buffer)
                            await self._http_synthesis(combined_text)
                        break

    except Exception as e:
        logger.error("error", f"Sarvam TTS synthesis failed: {e}")

Main entry point for synthesizing audio. Routes to either streaming or batch (HTTP) mode. The HTTP mode uses a smart buffering strategy to balance responsiveness and audio quality.
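
For the HTTP path, _stream_http_audio paces playback in 40 ms blocks. At the default 24 kHz sample rate, and assuming 16-bit mono output, the block size works out to:

bytes_per_20ms = int(24000 * 1 * 2 * 20 / 1000)   # 960 bytes
block_40ms = bytes_per_20ms * 2                   # 1920 bytes per 40 ms block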