Package videosdk.plugins.sarvamai

Sub-modules

videosdk.plugins.sarvamai.llm
videosdk.plugins.sarvamai.stt
videosdk.plugins.sarvamai.tts

Classes

class SarvamAILLM (*,
api_key: str | None = None,
model: str = 'sarvam-m',
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_completion_tokens: int | None = None)
Expand source code
class SarvamAILLM(LLM):
    """Chat-completion LLM backed by the Sarvam AI streaming chat API."""

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        temperature: float = 0.7,
        tool_choice: ToolChoice = "auto",
        max_completion_tokens: int | None = None,
    ) -> None:
        """Initialize the SarvamAI LLM plugin.

        Args:
            api_key (Optional[str], optional): SarvamAI API key. Defaults to None.
            model (str): The model to use for the LLM plugin. Defaults to "sarvam-m".
            temperature (float): The temperature to use for the LLM plugin. Defaults to 0.7.
            tool_choice (ToolChoice): The tool choice to use for the LLM plugin. Defaults to "auto".
            max_completion_tokens (Optional[int], optional): The maximum completion tokens to use for the LLM plugin. Defaults to None.

        Raises:
            ValueError: If no API key is given and SARVAMAI_API_KEY is unset.
        """
        super().__init__()
        self.api_key = api_key or os.getenv("SARVAMAI_API_KEY")
        if not self.api_key:
            raise ValueError("Sarvam AI API key must be provided either through api_key parameter or SARVAMAI_API_KEY environment variable")

        self.model = model
        self.temperature = temperature
        self.tool_choice = tool_choice
        self.max_completion_tokens = max_completion_tokens
        self._cancelled = False

        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=15.0, read=30.0, write=5.0, pool=5.0),
            follow_redirects=True,
        )

    async def chat(
        self,
        messages: ChatContext,
        tools: list[FunctionTool] | None = None,
        **kwargs: Any
    ) -> AsyncIterator[LLMResponse]:
        """Stream a chat completion from Sarvam AI.

        Args:
            messages: Conversation context. A leading system message becomes the
                system prompt; only the plain-string parts of ``ChatMessage`` items
                are sent.
            tools: Accepted for interface compatibility; not forwarded to the API.
            **kwargs: Extra fields merged into the request payload (they override
                the defaults built here).

        Yields:
            LLMResponse: ``content`` is the assistant text accumulated so far.

        Raises:
            httpx.HTTPStatusError: The API returned an error status. Unless the
                generation was cancelled, an ``"error"`` event carrying the status
                code and response body is emitted first.
            Exception: Any other failure, re-raised after an ``"error"`` event
                (unless cancelled).
        """
        self._cancelled = False

        def _extract_text_content(content: Union[str, List[ChatContent]]) -> str:
            # Only plain-string parts are kept; any other content part is dropped.
            if isinstance(content, str):
                return content
            text_parts = [part for part in content if isinstance(part, str)]
            return "\n".join(text_parts)

        system_prompt = None
        message_items = list(messages.items)
        if (
            message_items
            and isinstance(message_items[0], ChatMessage)
            and message_items[0].role == ChatRole.SYSTEM
        ):
            system_prompt = {
                "role": "system",
                "content": _extract_text_content(message_items.pop(0).content),
            }

        # Normalize the history: skip non-message items, leading assistant turns and
        # empty messages; merge consecutive user turns; for any other repeated role
        # keep only the most recent message.
        cleaned_messages = []
        last_role = None
        for msg in message_items:
            if not isinstance(msg, ChatMessage):
                continue

            current_role_str = msg.role.value

            if not cleaned_messages and current_role_str == 'assistant':
                continue

            text_content = _extract_text_content(msg.content)
            if not text_content.strip():
                continue

            if last_role == 'user' and current_role_str == 'user':
                cleaned_messages[-1]['content'] += ' ' + text_content
                continue

            if last_role == current_role_str:
                cleaned_messages.pop()

            cleaned_messages.append({"role": current_role_str, "content": text_content})
            last_role = current_role_str

        final_messages = [system_prompt] + cleaned_messages if system_prompt else cleaned_messages

        try:
            payload = {
                "model": self.model,
                "messages": final_messages,
                "temperature": self.temperature,
                "stream": True,
            }

            if self.max_completion_tokens:
                payload['max_tokens'] = self.max_completion_tokens

            payload.update(kwargs)

            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}"
            }

            async with self._client.stream("POST", SARVAM_CHAT_COMPLETION_URL, json=payload, headers=headers) as response:
                if response.is_error:
                    # The stream is closed as soon as this ``async with`` exits, so the
                    # error body must be read here for the handler below to report it.
                    await response.aread()
                response.raise_for_status()

                current_content = ""
                async for line in response.aiter_lines():
                    if self._cancelled:
                        break

                    if not line.startswith("data:"):
                        continue
                    data_str = line[len("data:"):].strip()
                    if not data_str:
                        continue
                    if data_str == "[DONE]":
                        break

                    chunk = json.loads(data_str)
                    # Tolerate chunks with an empty/null ``choices`` list or ``delta``.
                    choices = chunk.get("choices") or [{}]
                    delta = choices[0].get("delta") or {}
                    content_chunk = delta.get("content")
                    if content_chunk:
                        current_content += content_chunk
                        yield LLMResponse(content=current_content, role=ChatRole.ASSISTANT)

        except httpx.HTTPStatusError as e:
            if not self._cancelled:
                error_message = f"Sarvam AI API error: {e.response.status_code}"
                try:
                    error_text = e.response.text
                except httpx.ResponseNotRead:
                    # Body was not captured; report the status code alone.
                    error_text = ""
                if error_text:
                    error_message += f" - {error_text}"
                self.emit("error", Exception(error_message))
            raise
        except Exception as e:
            if not self._cancelled:
                traceback.print_exc()
                self.emit("error", e)
            raise

    async def cancel_current_generation(self) -> None:
        """Flag the active stream as cancelled; ``chat`` stops before the next line."""
        self._cancelled = True

    async def aclose(self) -> None:
        """Cancel any active generation and close the HTTP client."""
        await self.cancel_current_generation()
        if self._client:
            await self._client.aclose()

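Sarvam AI chat-completion LLM plugin (a subclass of the `LLM` base class). Requests are sent with `stream: True`, and the response is yielded incrementally.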

Initialize the SarvamAI LLM plugin.

Args

api_key : Optional[str], optional
SarvamAI API key. Defaults to None.
model : str
The model to use for the LLM plugin. Defaults to "sarvam-m".
temperature : float
The temperature to use for the LLM plugin. Defaults to 0.7.
tool_choice : ToolChoice
The tool choice to use for the LLM plugin. Defaults to "auto".
max_completion_tokens : Optional[int], optional
The maximum completion tokens to use for the LLM plugin. Defaults to None.

Ancestors

  • videosdk.agents.llm.llm.LLM
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Release resources: stop any active generation, then close the HTTP client."""
    await self.cancel_current_generation()
    client = self._client
    if not client:
        return
    await client.aclose()

Cleanup resources.

async def cancel_current_generation(self) ‑> None
Expand source code
async def cancel_current_generation(self) -> None:
    """Request cancellation of the in-flight ``chat`` stream.

    Only a flag is set here; the streaming loop in ``chat`` checks it before
    handling each received line and stops when it is set.
    """
    self._cancelled = True

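Cancel the current LLM generation if active.

This implementation sets an internal flag that `chat` checks before handling each streamed line; it does not raise.

Raises

NotImplementedError
Raised only by the base-class version; subclasses (such as this one) must override this method.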
async def chat(self,
messages: ChatContext,
tools: list[FunctionTool] | None = None,
**kwargs: Any) ‑> AsyncIterator[videosdk.agents.llm.llm.LLMResponse]
Expand source code
async def chat(
    self,
    messages: ChatContext,
    tools: list[FunctionTool] | None = None,
    **kwargs: Any
) -> AsyncIterator[LLMResponse]:
    """Stream a chat completion from Sarvam AI.

    Args:
        messages: Conversation context. A leading system message becomes the
            system prompt; only the plain-string parts of ``ChatMessage`` items
            are sent.
        tools: Accepted for interface compatibility; not forwarded to the API.
        **kwargs: Extra fields merged into the request payload (they override
            the defaults built here).

    Yields:
        LLMResponse: ``content`` is the assistant text accumulated so far.

    Raises:
        httpx.HTTPStatusError: The API returned an error status. Unless the
            generation was cancelled, an ``"error"`` event carrying the status
            code and response body is emitted first.
        Exception: Any other failure, re-raised after an ``"error"`` event
            (unless cancelled).
    """
    self._cancelled = False

    def _extract_text_content(content: Union[str, List[ChatContent]]) -> str:
        # Only plain-string parts are kept; any other content part is dropped.
        if isinstance(content, str):
            return content
        text_parts = [part for part in content if isinstance(part, str)]
        return "\n".join(text_parts)

    system_prompt = None
    message_items = list(messages.items)
    if (
        message_items
        and isinstance(message_items[0], ChatMessage)
        and message_items[0].role == ChatRole.SYSTEM
    ):
        system_prompt = {
            "role": "system",
            "content": _extract_text_content(message_items.pop(0).content),
        }

    # Normalize the history: skip non-message items, leading assistant turns and
    # empty messages; merge consecutive user turns; for any other repeated role
    # keep only the most recent message.
    cleaned_messages = []
    last_role = None
    for msg in message_items:
        if not isinstance(msg, ChatMessage):
            continue

        current_role_str = msg.role.value

        if not cleaned_messages and current_role_str == 'assistant':
            continue

        text_content = _extract_text_content(msg.content)
        if not text_content.strip():
            continue

        if last_role == 'user' and current_role_str == 'user':
            cleaned_messages[-1]['content'] += ' ' + text_content
            continue

        if last_role == current_role_str:
            cleaned_messages.pop()

        cleaned_messages.append({"role": current_role_str, "content": text_content})
        last_role = current_role_str

    final_messages = [system_prompt] + cleaned_messages if system_prompt else cleaned_messages

    try:
        payload = {
            "model": self.model,
            "messages": final_messages,
            "temperature": self.temperature,
            "stream": True,
        }

        if self.max_completion_tokens:
            payload['max_tokens'] = self.max_completion_tokens

        payload.update(kwargs)

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        async with self._client.stream("POST", SARVAM_CHAT_COMPLETION_URL, json=payload, headers=headers) as response:
            if response.is_error:
                # The stream is closed as soon as this ``async with`` exits, so the
                # error body must be read here for the handler below to report it.
                await response.aread()
            response.raise_for_status()

            current_content = ""
            async for line in response.aiter_lines():
                if self._cancelled:
                    break

                if not line.startswith("data:"):
                    continue
                data_str = line[len("data:"):].strip()
                if not data_str:
                    continue
                if data_str == "[DONE]":
                    break

                chunk = json.loads(data_str)
                # Tolerate chunks with an empty/null ``choices`` list or ``delta``.
                choices = chunk.get("choices") or [{}]
                delta = choices[0].get("delta") or {}
                content_chunk = delta.get("content")
                if content_chunk:
                    current_content += content_chunk
                    yield LLMResponse(content=current_content, role=ChatRole.ASSISTANT)

    except httpx.HTTPStatusError as e:
        if not self._cancelled:
            error_message = f"Sarvam AI API error: {e.response.status_code}"
            try:
                error_text = e.response.text
            except httpx.ResponseNotRead:
                # Body was not captured; report the status code alone.
                error_text = ""
            if error_text:
                error_message += f" - {error_text}"
            self.emit("error", Exception(error_message))
        raise
    except Exception as e:
        if not self._cancelled:
            traceback.print_exc()
            self.emit("error", e)
        raise

Main method to interact with the LLM.

Args

messages : ChatContext
The conversation context containing message history.
tools : list[FunctionTool] | None, optional
List of available function tools. Accepted for interface compatibility; this implementation does not forward tools to the Sarvam AI API.
**kwargs : Any
Additional arguments specific to the LLM provider implementation.

Returns

AsyncIterator[LLMResponse]
An async iterator yielding LLMResponse objects as they're generated.

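Raises

httpx.HTTPStatusError
If the Sarvam AI API returns an error status. Unless the generation was cancelled, an "error" event is emitted first.
Exception
Any other failure while streaming is re-raised after an "error" event is emitted (unless the generation was cancelled).
NotImplementedError
Raised only by the base-class version; subclasses (such as this one) must override this method.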
class SarvamAISTT (*,
api_key: str | None = None,
model: str = 'saarika:v2',
language: str = 'en-IN',
input_sample_rate: int = 48000,
output_sample_rate: int = 16000,
silence_threshold: float = 0.01,
silence_duration: float = 0.8)
Expand source code
class SarvamAISTT(STT):
    """Speech-to-text via the Sarvam AI REST API with a simple peak-amplitude VAD.

    Non-silent PCM chunks are buffered. Once enough trailing silence has been seen,
    the utterance is resampled, wrapped in a WAV container and transcribed in a
    background task. The transcript is delivered through the transcript callback.
    """

    # Incoming PCM is handled as interleaved int16 stereo (4 bytes per frame), the
    # same layout the original frame counting and WAV header assumed.
    # NOTE(review): assumes the agent pipeline always delivers 2-channel audio — confirm.
    _NUM_CHANNELS = 2
    _SAMPLE_WIDTH = 2  # bytes per sample (int16)

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        language: str = "en-IN",
        input_sample_rate: int = 48000,
        output_sample_rate: int = 16000,
        silence_threshold: float = 0.01,
        silence_duration: float = 0.8,
    ) -> None:
        """Initialize the SarvamAI STT plugin.

        Args:
            api_key (Optional[str], optional): SarvamAI API key. Defaults to None.
            model (str): The model to use for the STT plugin. Defaults to "saarika:v2".
            language (str): The language to use for the STT plugin. Defaults to "en-IN".
            input_sample_rate (int): The input sample rate for the STT plugin. Defaults to 48000.
            output_sample_rate (int): The output sample rate for the STT plugin. Defaults to 16000.
            silence_threshold (float): The silence threshold for the STT plugin. Defaults to 0.01.
            silence_duration (float): The silence duration for the STT plugin. Defaults to 0.8.

        Raises:
            ImportError: If scipy is not installed.
            ValueError: If no API key is given and SARVAMAI_API_KEY is unset.
        """
        super().__init__()
        if not SCIPY_AVAILABLE:
            raise ImportError("scipy is not installed. Please install it with 'pip install scipy'")

        self.api_key = api_key or os.getenv("SARVAMAI_API_KEY")
        if not self.api_key:
            raise ValueError("Sarvam AI API key must be provided either through api_key parameter or SARVAMAI_API_KEY environment variable")

        self.model = model
        self.language = language
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        # Peak int16 amplitude below which a chunk counts as silence.
        self.silence_threshold_bytes = int(silence_threshold * 32767)
        # Trailing silence, in input-rate frames, that ends an utterance.
        self.silence_duration_frames = int(silence_duration * self.input_sample_rate)

        self._http_client = httpx.AsyncClient(timeout=httpx.Timeout(connect=15.0, read=30.0, write=5.0, pool=5.0))
        self._audio_buffer = bytearray()
        self._is_speaking = False
        self._silence_frames = 0
        self._lock = asyncio.Lock()
        # The event loop keeps only weak references to tasks, so hold strong ones
        # until each background transcription finishes.
        self._background_tasks: set[asyncio.Task] = set()

    async def process_audio(self, audio_frames: bytes, **kwargs: Any) -> None:
        """Feed one chunk of raw PCM through the VAD.

        Non-silent chunks are appended to the utterance buffer. While speaking,
        silent chunks only advance the silence counter. When that counter exceeds
        ``silence_duration_frames``, the utterance is detached and transcribed in
        the background.
        """
        async with self._lock:
            if not self._is_silent(audio_frames):
                if not self._is_speaking:
                    self._is_speaking = True
                    global_event_emitter.emit("speech_started")
                self._audio_buffer.extend(audio_frames)
                self._silence_frames = 0
                return

            if not self._is_speaking:
                return

            frame_bytes = self._NUM_CHANNELS * self._SAMPLE_WIDTH
            self._silence_frames += len(audio_frames) // frame_bytes
            if self._silence_frames > self.silence_duration_frames:
                global_event_emitter.emit("speech_stopped")
                # Detach the utterance now, while holding the lock, so audio of the
                # next utterance cannot be appended before the background task runs.
                utterance = bytes(self._audio_buffer)
                self._audio_buffer = bytearray()
                self._is_speaking = False
                self._silence_frames = 0
                self._spawn_transcription(utterance)

    def _spawn_transcription(self, audio: bytes) -> None:
        """Transcribe ``audio`` in a background task that is kept referenced until done."""
        task = asyncio.create_task(self._transcribe_audio(audio))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _is_silent(self, audio_chunk: bytes) -> bool:
        """Simple VAD: True when the chunk's peak absolute amplitude is below the threshold.

        Empty chunks count as silent; a trailing odd byte is ignored.
        """
        n_samples = len(audio_chunk) // self._SAMPLE_WIDTH
        if n_samples == 0:
            return True
        audio_data = np.frombuffer(audio_chunk, dtype=np.int16, count=n_samples)
        # Widen before abs(): abs(int16(-32768)) overflows back to -32768.
        peak = int(np.max(np.abs(audio_data.astype(np.int32))))
        return peak < self.silence_threshold_bytes

    async def _transcribe_buffer(self):
        """Detach whatever is currently buffered and transcribe it (used on close)."""
        async with self._lock:
            if not self._audio_buffer:
                return
            audio_to_send = bytes(self._audio_buffer)
            self._audio_buffer = bytearray()
        await self._transcribe_audio(audio_to_send)

    async def _transcribe_audio(self, audio: bytes) -> None:
        """Send one utterance to the Sarvam STT API and forward a non-empty transcript.

        Failures are reported through the ``"error"`` event rather than raised.
        """
        if not audio:
            return
        try:
            resampled_audio = self._resample_audio(audio)
            wav_audio = self._create_wav_in_memory(resampled_audio)

            headers = {"api-subscription-key": self.api_key}
            data = {"model": self.model, "language_code": self.language}
            files = {'file': ('audio.wav', wav_audio, 'audio/wav')}

            response = await self._http_client.post(SARVAM_STT_API_URL, headers=headers, data=data, files=files)
            response.raise_for_status()

            response_data = response.json()
            transcript = response_data.get("transcript", "")

            if transcript and self._transcript_callback:
                event = STTResponse(
                    event_type=SpeechEventType.FINAL,
                    data=SpeechData(text=transcript, language=self.language, confidence=1.0)
                )
                await self._transcript_callback(event)
        except httpx.HTTPStatusError as e:
            self.emit("error", f"Sarvam STT API error: {e.response.status_code} - {e.response.text}")
        except Exception as e:
            self.emit("error", f"Error during transcription: {e}")

    def _resample_audio(self, audio_bytes: bytes) -> bytes:
        """Resample interleaved int16 PCM from the input to the output rate, per channel.

        Resampling the interleaved sample stream as a single signal would mix the
        channels together, so samples are reshaped to (frames, channels) first.
        """
        frame_bytes = self._NUM_CHANNELS * self._SAMPLE_WIDTH
        n_frames = len(audio_bytes) // frame_bytes
        target_frames = int(n_frames * self.output_sample_rate / self.input_sample_rate)
        if n_frames == 0 or target_frames == 0:
            return b""
        samples = np.frombuffer(audio_bytes, dtype=np.int16, count=n_frames * self._NUM_CHANNELS)
        frames = samples.reshape(n_frames, self._NUM_CHANNELS)
        resampled = signal.resample(frames, target_frames, axis=0)
        # resample() returns floats that can overshoot the int16 range; round and
        # clip instead of letting astype() wrap around.
        return np.clip(np.rint(resampled), -32768, 32767).astype(np.int16).tobytes()

    def _create_wav_in_memory(self, pcm_data: bytes) -> io.BytesIO:
        """Wrap interleaved int16 PCM at the output sample rate in an in-memory WAV file."""
        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, 'wb') as wf:
            wf.setnchannels(self._NUM_CHANNELS)
            wf.setsampwidth(self._SAMPLE_WIDTH)
            wf.setframerate(self.output_sample_rate)
            wf.writeframes(pcm_data)
        wav_buffer.seek(0)
        return wav_buffer

    async def aclose(self) -> None:
        """Flush an in-progress utterance, wait for background transcriptions, close the client."""
        if self._is_speaking and self._audio_buffer:
            await self._transcribe_buffer()

        if self._background_tasks:
            # Wait for in-flight transcriptions instead of sleeping for a fixed time;
            # they report their own failures via the "error" event.
            await asyncio.gather(*list(self._background_tasks), return_exceptions=True)

        if self._http_client:
            await self._http_client.aclose()

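Sarvam AI speech-to-text plugin (a subclass of the `STT` base class). Incoming audio is segmented with a simple amplitude-based voice-activity check. Each completed utterance is sent to the Sarvam AI STT REST API.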

Initialize the SarvamAI STT plugin.

Args

api_key : Optional[str], optional
SarvamAI API key. Defaults to None.
model : str
The model to use for the STT plugin. Defaults to "saarika:v2".
language : str
The language to use for the STT plugin. Defaults to "en-IN".
input_sample_rate : int
The input sample rate for the STT plugin. Defaults to 48000.
output_sample_rate : int
The output sample rate for the STT plugin. Defaults to 16000.
silence_threshold : float
The silence threshold for the STT plugin. Defaults to 0.01.
silence_duration : float
The silence duration for the STT plugin. Defaults to 0.8.

Ancestors

  • videosdk.agents.stt.stt.STT
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Flush an in-progress utterance (if any), then close the HTTP client."""
    has_pending_speech = self._is_speaking and self._audio_buffer
    if has_pending_speech:
        await self._transcribe_buffer()
        # Fixed grace period after the final flush.
        await asyncio.sleep(1)

    client = self._http_client
    if client:
        await client.aclose()

Cleanup resources

async def process_audio(self, audio_frames: bytes, **kwargs: Any) ‑> None
Expand source code
async def process_audio(self, audio_frames: bytes, **kwargs: Any) -> None:
    """Run one chunk of PCM through the amplitude-based VAD.

    Speech chunks are buffered. While speaking, silence only advances a counter.
    Once the counter passes ``silence_duration_frames``, a background
    transcription of the buffer is scheduled.
    """
    async with self._lock:
        if self._is_silent(audio_frames):
            if not self._is_speaking:
                return
            # Counts 4 bytes per frame — presumably 16-bit stereo, matching the
            # 2-channel WAV written before upload; TODO confirm input layout.
            self._silence_frames += len(audio_frames) // 4
            if self._silence_frames <= self.silence_duration_frames:
                return
            global_event_emitter.emit("speech_stopped")
            asyncio.create_task(self._transcribe_buffer())
            self._is_speaking = False
            self._silence_frames = 0
            return

        if not self._is_speaking:
            self._is_speaking = True
            global_event_emitter.emit("speech_started")
        self._audio_buffer.extend(audio_frames)
        self._silence_frames = 0

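Process a chunk of audio and convert completed utterances to text.

Args

audio_frames
Raw 16-bit PCM bytes for one chunk of audio. This implementation counts 4 bytes per frame, i.e. it treats the input as two interleaved channels.
**kwargs
Additional provider-specific arguments (accepted but not used by this implementation).

Returns

None. Transcripts are not returned. Each completed utterance is transcribed in a background task, and the result is delivered as an STTResponse through the registered transcript callback.

Note: the base-class documentation also lists an optional `language` argument and an AsyncIterator return value. Neither applies here; the recognition language is fixed by the constructor's `language` argument.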

class SarvamAITTS (*,
api_key: str | None = None,
model: str = 'bulbul:v2',
speaker: str = 'anushka',
target_language_code: str = 'en-IN',
pitch: float = 0.0,
pace: float = 1.0,
loudness: float = 1.2,
enable_preprocessing: bool = True)
Expand source code
class SarvamAITTS(TTS):
    """Text-to-speech via the Sarvam AI REST API.

    Each text segment is synthesized with one request. The base64 audio in the
    reply is stripped of its WAV header and pushed to ``audio_track`` in 20 ms
    chunks.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        speaker: str = DEFAULT_SPEAKER,
        target_language_code: str = DEFAULT_TARGET_LANGUAGE,
        pitch: float = 0.0,
        pace: float = 1.0,
        loudness: float = 1.2,
        enable_preprocessing: bool = True,
    ) -> None:
        """Initialize the SarvamAI TTS plugin.

        Args:
            api_key (Optional[str], optional): SarvamAI API key. Defaults to None.
            model (str): The model to use for the TTS plugin. Defaults to "bulbul:v2".
            speaker (str): The speaker to use for the TTS plugin. Defaults to "anushka".
            target_language_code (str): The target language code to use for the TTS plugin. Defaults to "en-IN".
            pitch (float): The pitch to use for the TTS plugin. Defaults to 0.0.
            pace (float): The pace to use for the TTS plugin. Defaults to 1.0.
            loudness (float): The loudness to use for the TTS plugin. Defaults to 1.2.
            enable_preprocessing (bool): Whether to enable preprocessing for the TTS plugin. Defaults to True.

        Raises:
            ValueError: If no API key is given and SARVAMAI_API_KEY is unset.
        """
        super().__init__(
            sample_rate=SARVAMAI_SAMPLE_RATE, num_channels=SARVAMAI_CHANNELS
        )

        self.model = model
        self.speaker = speaker
        self.target_language_code = target_language_code
        self.pitch = pitch
        self.pace = pace
        self.loudness = loudness
        self.enable_preprocessing = enable_preprocessing
        self.audio_track = None
        self.loop = None

        self._first_chunk_sent = False
        # The event loop keeps only weak references to tasks; hold strong ones
        # until each fire-and-forget task finishes.
        self._background_tasks: set[asyncio.Task] = set()

        self.api_key = api_key or os.getenv("SARVAMAI_API_KEY")

        if not self.api_key:
            raise ValueError(
                "Sarvam AI API key required. Provide either:\n"
                "1. api_key parameter, OR\n"
                "2. SARVAMAI_API_KEY environment variable"
            )

        self._http_client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=15.0, read=30.0,
                                  write=5.0, pool=5.0),
            follow_redirects=True,
        )

    def reset_first_audio_tracking(self) -> None:
        """Reset the first audio tracking state for next TTS task"""
        self._first_chunk_sent = False

    async def synthesize(
        self,
        text: AsyncIterator[str] | str,
        **kwargs: Any,
    ) -> None:
        """Synthesize ``text`` and push the audio to ``self.audio_track``.

        Args:
            text: A complete string, or an async iterator of text that is split
                with ``segment_text`` and synthesized one segment at a time.
            **kwargs: Accepted for interface compatibility; not used.

        Errors are reported through the ``"error"`` event rather than raised.
        """
        try:
            if not self.audio_track or not self.loop:
                self.emit("error", "Audio track or loop not initialized")
                return

            if isinstance(text, AsyncIterator):
                async for segment in segment_text(text):
                    # Whitespace-only segments have nothing to speak; skip the request.
                    if segment.strip():
                        await self._synthesize_audio(segment)
            elif text.strip():
                await self._synthesize_audio(text)

        except Exception as e:
            self.emit("error", f"Sarvam AI TTS synthesis failed: {e}")

    async def _synthesize_audio(self, text: str) -> None:
        """Request speech for one text segment and stream the decoded audio.

        Raises:
            httpx.HTTPStatusError: Re-raised after emitting an ``"error"`` event.
        """
        try:
            payload = {
                "inputs": [text],
                "target_language_code": self.target_language_code,
                "speaker": self.speaker,
                "pitch": self.pitch,
                "pace": self.pace,
                "loudness": self.loudness,
                "speech_sample_rate": SARVAMAI_SAMPLE_RATE,
                "enable_preprocessing": self.enable_preprocessing,
                "model": self.model,
            }

            headers = {
                "Accept": "application/json",
                "Content-Type": "application/json",
                "api-subscription-key": self.api_key,
            }

            response = await self._http_client.post(
                SARVAMAI_TTS_ENDPOINT, headers=headers, json=payload
            )
            response.raise_for_status()

            response_data = response.json()
            if "audios" not in response_data or not response_data["audios"]:
                self.emit(
                    "error", "No audio data found in response from Sarvam AI")
                return

            audio_content = response_data["audios"][0]
            if not audio_content:
                self.emit("error", "No audio content received from Sarvam AI")
                return

            audio_bytes = base64.b64decode(audio_content)

            if not audio_bytes:
                self.emit("error", "Decoded audio bytes are empty")
                return

            await self._stream_audio_chunks(audio_bytes)

        except httpx.HTTPStatusError as e:
            self.emit(
                "error",
                f"Sarvam AI TTS HTTP error: {e.response.status_code} - {e.response.text}",
            )
            raise

    def _create_tracked_task(self, coro) -> asyncio.Task:
        """Schedule ``coro`` and keep a strong reference to the task until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _stream_audio_chunks(self, audio_bytes: bytes) -> None:
        """Push PCM to ``audio_track`` in fixed 20 ms chunks, zero-padding the last one."""
        # Bytes per 20 ms: rate * channels * 2 bytes/sample (presumably int16) * 0.020 s.
        chunk_size = int(SARVAMAI_SAMPLE_RATE *
                         SARVAMAI_CHANNELS * 2 * 20 / 1000)

        audio_data = self._remove_wav_header(audio_bytes)

        for i in range(0, len(audio_data), chunk_size):
            chunk = audio_data[i: i + chunk_size]
            if len(chunk) < chunk_size:
                chunk += b"\x00" * (chunk_size - len(chunk))

            if not self._first_chunk_sent and self._first_audio_callback:
                self._first_chunk_sent = True
                self._create_tracked_task(self._first_audio_callback())

            self._create_tracked_task(self.audio_track.add_new_bytes(chunk))
            await asyncio.sleep(0.001)

    def _remove_wav_header(self, audio_bytes: bytes) -> bytes:
        """Return the ``data`` chunk payload of a RIFF/WAVE blob, else the input unchanged."""
        if not audio_bytes.startswith(b"RIFF"):
            return audio_bytes

        # Walk the chunk list (4-byte id, 4-byte little-endian size, payload padded
        # to an even length) rather than searching for b"data", which can also occur
        # inside an earlier chunk's payload.
        offset = 12  # skip "RIFF", the RIFF size and the "WAVE" form type
        total = len(audio_bytes)
        while offset + 8 <= total:
            chunk_id = audio_bytes[offset:offset + 4]
            chunk_size = int.from_bytes(audio_bytes[offset + 4:offset + 8], "little")
            body_start = offset + 8
            if chunk_id == b"data":
                end = body_start + chunk_size
                # Tolerate a zero/placeholder or overlong size by taking the rest.
                if chunk_size == 0 or end > total:
                    end = total
                return audio_bytes[body_start:end]
            offset = body_start + chunk_size + (chunk_size & 1)

        # Malformed chunk list: fall back to the plain byte search.
        data_pos = audio_bytes.find(b"data")
        if data_pos != -1:
            return audio_bytes[data_pos + 8:]
        return audio_bytes

    async def aclose(self) -> None:
        """Close the HTTP client, then run the base-class cleanup."""
        if self._http_client:
            await self._http_client.aclose()
        await super().aclose()

    async def interrupt(self) -> None:
        """Interrupt playback on the attached audio track, if any."""
        if self.audio_track:
            self.audio_track.interrupt()

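Sarvam AI text-to-speech plugin (a subclass of the `TTS` base class). Text is sent to the Sarvam AI TTS REST API, and the returned audio is streamed to the attached audio track in 20 ms chunks.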

Initialize the SarvamAI TTS plugin.

Args

api_key : Optional[str], optional
SarvamAI API key. Defaults to None.
model : str
The model to use for the TTS plugin. Defaults to "bulbul:v2".
speaker : str
The speaker to use for the TTS plugin. Defaults to "anushka".
target_language_code : str
The target language code to use for the TTS plugin. Defaults to "en-IN".
pitch : float
The pitch to use for the TTS plugin. Defaults to 0.0.
pace : float
The pace to use for the TTS plugin. Defaults to 1.0.
loudness : float
The loudness to use for the TTS plugin. Defaults to 1.2.
enable_preprocessing : bool
Whether to enable preprocessing for the TTS plugin. Defaults to True.

Ancestors

  • videosdk.agents.tts.tts.TTS
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Shut down: close the HTTP client when present, then defer to the base class."""
    client = self._http_client
    if client:
        await client.aclose()
    await super().aclose()

Cleanup resources

async def interrupt(self) ‑> None
Expand source code
async def interrupt(self) -> None:
    """Stop playback by interrupting the attached audio track, when there is one."""
    track = self.audio_track
    if not track:
        return
    track.interrupt()

Interrupt the TTS process

def reset_first_audio_tracking(self) ‑> None
Expand source code
def reset_first_audio_tracking(self) -> None:
    """Re-arm first-audio detection for the next TTS task.

    Clearing the flag lets the next streamed audio chunk trigger the
    first-audio callback again.
    """
    self._first_chunk_sent = False

Reset the first audio tracking state for next TTS task

async def synthesize(self, text: AsyncIterator[str] | str, **kwargs: Any) ‑> None
Expand source code
async def synthesize(
    self,
    text: AsyncIterator[str] | str,
    **kwargs: Any,
) -> None:
    try:
        if not self.audio_track or not self.loop:
            self.emit("error", "Audio track or loop not initialized")
            return

        if isinstance(text, AsyncIterator):
            async for segment in segment_text(text):
                await self._synthesize_audio(segment)
        else:
            await self._synthesize_audio(text)

    except Exception as e:
        self.emit("error", f"Sarvam AI TTS synthesis failed: {str(e)}")

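Convert text to speech

Args

text
Text to convert to speech: either a string, or an async iterator of strings that is split into segments and synthesized one segment at a time.
voice_id
Not accepted by this implementation; the voice is chosen with the constructor's `speaker` argument.
**kwargs
Additional provider-specific arguments (accepted but not used by this implementation).

Returns

None. Audio is pushed to the attached audio track. Failures are reported through the "error" event instead of being raised.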