Package videosdk.plugins.sarvamai
Sub-modules
videosdk.plugins.sarvamai.llm
videosdk.plugins.sarvamai.stt
videosdk.plugins.sarvamai.tts
Classes
class SarvamAILLM (*,
api_key: str | None = None,
model: str = 'sarvam-m',
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_completion_tokens: int | None = None)
class SarvamAILLM(LLM):
    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        temperature: float = 0.7,
        tool_choice: ToolChoice = "auto",
        max_completion_tokens: int | None = None,
    ) -> None:
        """Initialize the SarvamAI LLM plugin.

        Args:
            api_key (Optional[str], optional): SarvamAI API key. Defaults to None.
            model (str): The model to use for the LLM plugin. Defaults to "sarvam-m".
            temperature (float): The temperature to use for the LLM plugin. Defaults to 0.7.
            tool_choice (ToolChoice): The tool choice to use for the LLM plugin. Defaults to "auto".
            max_completion_tokens (Optional[int], optional): The maximum completion tokens to use for the LLM plugin. Defaults to None.
        """
        super().__init__()
        self.api_key = api_key or os.getenv("SARVAMAI_API_KEY")
        if not self.api_key:
            raise ValueError("Sarvam AI API key must be provided either through api_key parameter or SARVAMAI_API_KEY environment variable")
        self.model = model
        self.temperature = temperature
        self.tool_choice = tool_choice
        self.max_completion_tokens = max_completion_tokens
        self._cancelled = False
        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=15.0, read=30.0, write=5.0, pool=5.0),
            follow_redirects=True,
        )

    async def chat(
        self,
        messages: ChatContext,
        tools: list[FunctionTool] | None = None,
        **kwargs: Any
    ) -> AsyncIterator[LLMResponse]:
        self._cancelled = False

        def _extract_text_content(content: Union[str, List[ChatContent]]) -> str:
            if isinstance(content, str):
                return content
            text_parts = [part for part in content if isinstance(part, str)]
            return "\n".join(text_parts)

        # Pull a leading system message out of the history, if present.
        system_prompt = None
        message_items = list(messages.items)
        if (
            message_items
            and isinstance(message_items[0], ChatMessage)
            and message_items[0].role == ChatRole.SYSTEM
        ):
            system_prompt = {
                "role": "system",
                "content": _extract_text_content(message_items.pop(0).content),
            }

        # Normalize the history: skip non-messages and empty turns, merge
        # consecutive user turns, and keep only the latest turn when the
        # same role repeats.
        cleaned_messages = []
        last_role = None
        for msg in message_items:
            if not isinstance(msg, ChatMessage):
                continue
            current_role_str = msg.role.value
            if not cleaned_messages and current_role_str == 'assistant':
                continue
            text_content = _extract_text_content(msg.content)
            if not text_content.strip():
                continue
            if last_role == 'user' and current_role_str == 'user':
                cleaned_messages[-1]['content'] += ' ' + text_content
                continue
            if last_role == current_role_str:
                cleaned_messages.pop()
            cleaned_messages.append({"role": current_role_str, "content": text_content})
            last_role = current_role_str

        final_messages = [system_prompt] + cleaned_messages if system_prompt else cleaned_messages

        try:
            payload = {
                "model": self.model,
                "messages": final_messages,
                "temperature": self.temperature,
                "stream": True,
            }
            if self.max_completion_tokens:
                payload['max_tokens'] = self.max_completion_tokens  # API field is "max_tokens"
            payload.update(kwargs)
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}"
            }
            async with self._client.stream("POST", SARVAM_CHAT_COMPLETION_URL, json=payload, headers=headers) as response:
                response.raise_for_status()
                current_content = ""
                async for line in response.aiter_lines():
                    if self._cancelled:
                        break
                    if not line.startswith("data:"):
                        continue
                    data_str = line[len("data:"):].strip()
                    if not data_str:
                        continue
                    if data_str == "[DONE]":
                        break
                    chunk = json.loads(data_str)
                    delta = chunk.get("choices", [{}])[0].get("delta", {})
                    if "content" in delta and delta["content"] is not None:
                        content_chunk = delta["content"]
                        current_content += content_chunk
                        # Yield the accumulated text so far, not just the new delta.
                        yield LLMResponse(content=current_content, role=ChatRole.ASSISTANT)
        except httpx.HTTPStatusError as e:
            if not self._cancelled:
                error_message = f"Sarvam AI API error: {e.response.status_code}"
                try:
                    error_body = await e.response.aread()
                    error_text = error_body.decode()
                    error_message += f" - {error_text}"
                except Exception:
                    pass
                self.emit("error", Exception(error_message))
            raise
        except Exception as e:
            if not self._cancelled:
                traceback.print_exc()
                self.emit("error", e)
            raise

    async def cancel_current_generation(self) -> None:
        self._cancelled = True

    async def aclose(self) -> None:
        await self.cancel_current_generation()
        if self._client:
            await self._client.aclose()
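For reference, chat consumes an OpenAI-style server-sent-event stream: each data: line carries a JSON chunk whose choices[0].delta.content holds the next text fragment, and data: [DONE] terminates the stream. A minimal sketch of the parsing step, using an illustrative payload rather than a captured API response:

import json

# Illustrative SSE line; only choices[0].delta.content is read by the
# plugin, so any other fields in the chunk are assumptions here.
line = 'data: {"choices": [{"delta": {"content": "Namaste"}}]}'
data_str = line[len("data:"):].strip()
delta = json.loads(data_str).get("choices", [{}])[0].get("delta", {})
print(delta.get("content"))  # -> Namaste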
Base class for LLM implementations.
Initialize the SarvamAI LLM plugin.
Args
api_key : Optional[str], optional
    SarvamAI API key. Defaults to None.
model : str
    The model to use for the LLM plugin. Defaults to "sarvam-m".
temperature : float
    The temperature to use for the LLM plugin. Defaults to 0.7.
tool_choice : ToolChoice
    The tool choice to use for the LLM plugin. Defaults to "auto".
max_completion_tokens : Optional[int], optional
    The maximum completion tokens to use for the LLM plugin. Defaults to None.
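A minimal construction sketch, assuming the key is supplied via the SARVAMAI_API_KEY environment variable (otherwise pass api_key explicitly; the constructor raises ValueError if neither is set):

from videosdk.plugins.sarvamai import SarvamAILLM

llm = SarvamAILLM(
    model="sarvam-m",
    temperature=0.7,
    max_completion_tokens=512,  # sent to the API as "max_tokens"
)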
Ancestors
- videosdk.agents.llm.llm.LLM
- videosdk.agents.event_emitter.EventEmitter
- typing.Generic
Methods
async def aclose(self) ‑> None
async def aclose(self) -> None:
    await self.cancel_current_generation()
    if self._client:
        await self._client.aclose()
Cleanup resources.
async def cancel_current_generation(self) ‑> None
async def cancel_current_generation(self) -> None:
    self._cancelled = True
Cancel the current LLM generation if active.
Raises
NotImplementedError
    This method must be implemented by subclasses.
async def chat(self,
messages: ChatContext,
tools: list[FunctionTool] | None = None,
**kwargs: Any) ‑> AsyncIterator[videosdk.agents.llm.llm.LLMResponse]
async def chat(
    self,
    messages: ChatContext,
    tools: list[FunctionTool] | None = None,
    **kwargs: Any
) -> AsyncIterator[LLMResponse]:
    self._cancelled = False

    def _extract_text_content(content: Union[str, List[ChatContent]]) -> str:
        if isinstance(content, str):
            return content
        text_parts = [part for part in content if isinstance(part, str)]
        return "\n".join(text_parts)

    system_prompt = None
    message_items = list(messages.items)
    if (
        message_items
        and isinstance(message_items[0], ChatMessage)
        and message_items[0].role == ChatRole.SYSTEM
    ):
        system_prompt = {
            "role": "system",
            "content": _extract_text_content(message_items.pop(0).content),
        }

    cleaned_messages = []
    last_role = None
    for msg in message_items:
        if not isinstance(msg, ChatMessage):
            continue
        current_role_str = msg.role.value
        if not cleaned_messages and current_role_str == 'assistant':
            continue
        text_content = _extract_text_content(msg.content)
        if not text_content.strip():
            continue
        if last_role == 'user' and current_role_str == 'user':
            cleaned_messages[-1]['content'] += ' ' + text_content
            continue
        if last_role == current_role_str:
            cleaned_messages.pop()
        cleaned_messages.append({"role": current_role_str, "content": text_content})
        last_role = current_role_str

    final_messages = [system_prompt] + cleaned_messages if system_prompt else cleaned_messages

    try:
        payload = {
            "model": self.model,
            "messages": final_messages,
            "temperature": self.temperature,
            "stream": True,
        }
        if self.max_completion_tokens:
            payload['max_tokens'] = self.max_completion_tokens
        payload.update(kwargs)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        async with self._client.stream("POST", SARVAM_CHAT_COMPLETION_URL, json=payload, headers=headers) as response:
            response.raise_for_status()
            current_content = ""
            async for line in response.aiter_lines():
                if self._cancelled:
                    break
                if not line.startswith("data:"):
                    continue
                data_str = line[len("data:"):].strip()
                if not data_str:
                    continue
                if data_str == "[DONE]":
                    break
                chunk = json.loads(data_str)
                delta = chunk.get("choices", [{}])[0].get("delta", {})
                if "content" in delta and delta["content"] is not None:
                    content_chunk = delta["content"]
                    current_content += content_chunk
                    yield LLMResponse(content=current_content, role=ChatRole.ASSISTANT)
    except httpx.HTTPStatusError as e:
        if not self._cancelled:
            error_message = f"Sarvam AI API error: {e.response.status_code}"
            try:
                error_body = await e.response.aread()
                error_text = error_body.decode()
                error_message += f" - {error_text}"
            except Exception:
                pass
            self.emit("error", Exception(error_message))
        raise
    except Exception as e:
        if not self._cancelled:
            traceback.print_exc()
            self.emit("error", e)
        raise
Main method to interact with the LLM.
Args
messages : ChatContext
    The conversation context containing message history.
tools : list[FunctionTool] | None, optional
    List of available function tools for the LLM to use.
**kwargs : Any
    Additional arguments specific to the LLM provider implementation.
Returns
AsyncIterator[LLMResponse]
    An async iterator yielding LLMResponse objects as they're generated.
Raises
NotImplementedError
    This method must be implemented by subclasses.
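A consumption sketch under stated assumptions: the ChatContext/ChatMessage import path and constructors shown here are illustrative (the real builders live in videosdk.agents), and each yielded LLMResponse carries the accumulated text so far rather than just the newest delta:

import asyncio
from videosdk.plugins.sarvamai import SarvamAILLM
from videosdk.agents.llm.chat_context import ChatContext, ChatMessage, ChatRole  # import path is an assumption

async def main() -> None:
    llm = SarvamAILLM()  # key taken from SARVAMAI_API_KEY
    ctx = ChatContext()  # assumes a default constructor; adapt to your agent's context builder
    ctx.items.append(ChatMessage(role=ChatRole.USER, content="Describe Pune in one sentence."))
    async for response in llm.chat(ctx):
        print(response.content)  # cumulative assistant text
    await llm.aclose()

asyncio.run(main())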
class SarvamAISTT (*,
api_key: str | None = None,
model: str = 'saarika:v2',
language: str = 'en-IN',
input_sample_rate: int = 48000,
output_sample_rate: int = 16000,
silence_threshold: float = 0.01,
silence_duration: float = 0.8)
class SarvamAISTT(STT):
    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        language: str = "en-IN",
        input_sample_rate: int = 48000,
        output_sample_rate: int = 16000,
        silence_threshold: float = 0.01,
        silence_duration: float = 0.8,
    ) -> None:
        """Initialize the SarvamAI STT plugin.

        Args:
            api_key (Optional[str], optional): SarvamAI API key. Defaults to None.
            model (str): The model to use for the STT plugin. Defaults to "saarika:v2".
            language (str): The language to use for the STT plugin. Defaults to "en-IN".
            input_sample_rate (int): The input sample rate for the STT plugin. Defaults to 48000.
            output_sample_rate (int): The output sample rate for the STT plugin. Defaults to 16000.
            silence_threshold (float): The silence threshold for the STT plugin. Defaults to 0.01.
            silence_duration (float): The silence duration for the STT plugin. Defaults to 0.8.
        """
        super().__init__()
        if not SCIPY_AVAILABLE:
            raise ImportError("scipy is not installed. Please install it with 'pip install scipy'")
        self.api_key = api_key or os.getenv("SARVAMAI_API_KEY")
        if not self.api_key:
            raise ValueError("Sarvam AI API key must be provided either through api_key parameter or SARVAMAI_API_KEY environment variable")
        self.model = model
        self.language = language
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        # Map the normalized threshold onto int16 amplitude and the silence
        # window onto a frame count at the input sample rate.
        self.silence_threshold_bytes = int(silence_threshold * 32767)
        self.silence_duration_frames = int(silence_duration * self.input_sample_rate)
        self._http_client = httpx.AsyncClient(timeout=httpx.Timeout(connect=15.0, read=30.0, write=5.0, pool=5.0))
        self._audio_buffer = bytearray()
        self._is_speaking = False
        self._silence_frames = 0
        self._lock = asyncio.Lock()

    async def process_audio(self, audio_frames: bytes, **kwargs: Any) -> None:
        async with self._lock:
            is_silent_chunk = self._is_silent(audio_frames)
            if not is_silent_chunk:
                if not self._is_speaking:
                    self._is_speaking = True
                    global_event_emitter.emit("speech_started")
                self._audio_buffer.extend(audio_frames)
                self._silence_frames = 0
            else:
                if self._is_speaking:
                    # Bytes to frames, assuming 4 bytes per frame (16-bit stereo).
                    self._silence_frames += len(audio_frames) // 4
                    if self._silence_frames > self.silence_duration_frames:
                        global_event_emitter.emit("speech_stopped")
                        asyncio.create_task(self._transcribe_buffer())
                        self._is_speaking = False
                        self._silence_frames = 0

    def _is_silent(self, audio_chunk: bytes) -> bool:
        """Simple VAD: check if the max amplitude is below a threshold."""
        audio_data = np.frombuffer(audio_chunk, dtype=np.int16)
        return np.max(np.abs(audio_data)) < self.silence_threshold_bytes

    async def _transcribe_buffer(self):
        async with self._lock:
            if not self._audio_buffer:
                return
            # Swap the buffer out so new audio can accumulate separately.
            audio_to_send = self._audio_buffer
            self._audio_buffer = bytearray()
            try:
                resampled_audio = self._resample_audio(audio_to_send)
                wav_audio = self._create_wav_in_memory(resampled_audio)
                headers = {"api-subscription-key": self.api_key}
                data = {"model": self.model, "language_code": self.language}
                files = {'file': ('audio.wav', wav_audio, 'audio/wav')}
                response = await self._http_client.post(SARVAM_STT_API_URL, headers=headers, data=data, files=files)
                response.raise_for_status()
                response_data = response.json()
                transcript = response_data.get("transcript", "")
                if transcript and self._transcript_callback:
                    event = STTResponse(
                        event_type=SpeechEventType.FINAL,
                        data=SpeechData(text=transcript, language=self.language, confidence=1.0)
                    )
                    await self._transcript_callback(event)
            except httpx.HTTPStatusError as e:
                self.emit("error", f"Sarvam STT API error: {e.response.status_code} - {e.response.text}")
            except Exception as e:
                self.emit("error", f"Error during transcription: {e}")

    def _resample_audio(self, audio_bytes: bytes) -> bytes:
        audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
        resampled_data = signal.resample(audio_data, int(len(audio_data) * self.output_sample_rate / self.input_sample_rate))
        return resampled_data.astype(np.int16).tobytes()

    def _create_wav_in_memory(self, pcm_data: bytes) -> io.BytesIO:
        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, 'wb') as wf:
            wf.setnchannels(2)
            wf.setsampwidth(2)
            wf.setframerate(self.output_sample_rate)
            wf.writeframes(pcm_data)
        wav_buffer.seek(0)
        return wav_buffer

    async def aclose(self) -> None:
        if self._is_speaking and self._audio_buffer:
            await self._transcribe_buffer()
            await asyncio.sleep(1)
        if self._http_client:
            await self._http_client.aclose()
Base class for Speech-to-Text implementations
Initialize the SarvamAI STT plugin.
Args
api_key : Optional[str], optional
    SarvamAI API key. Defaults to None.
model : str
    The model to use for the STT plugin. Defaults to "saarika:v2".
language : str
    The language code to use for transcription. Defaults to "en-IN".
input_sample_rate : int
    Sample rate of incoming audio, in Hz. Defaults to 48000.
output_sample_rate : int
    Sample rate, in Hz, to which audio is resampled before upload. Defaults to 16000.
silence_threshold : float
    Peak-amplitude threshold, as a fraction of int16 full scale, below which a chunk counts as silence. Defaults to 0.01.
silence_duration : float
    Seconds of continuous silence after which the buffered utterance is sent for transcription. Defaults to 0.8.
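The two silence parameters are converted to raw thresholds at construction time; with the defaults, the arithmetic works out as follows (this mirrors the __init__ code above):

silence_threshold = 0.01      # fraction of int16 full scale
silence_duration = 0.8        # seconds
input_sample_rate = 48000     # Hz

silence_threshold_bytes = int(silence_threshold * 32767)             # 327
silence_duration_frames = int(silence_duration * input_sample_rate)  # 38400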
Ancestors
- videosdk.agents.stt.stt.STT
- videosdk.agents.event_emitter.EventEmitter
- typing.Generic
Methods
async def aclose(self) ‑> None
async def aclose(self) -> None:
    if self._is_speaking and self._audio_buffer:
        await self._transcribe_buffer()
        await asyncio.sleep(1)
    if self._http_client:
        await self._http_client.aclose()
Cleanup resources
async def process_audio(self, audio_frames: bytes, **kwargs: Any) ‑> None
async def process_audio(self, audio_frames: bytes, **kwargs: Any) -> None:
    async with self._lock:
        is_silent_chunk = self._is_silent(audio_frames)
        if not is_silent_chunk:
            if not self._is_speaking:
                self._is_speaking = True
                global_event_emitter.emit("speech_started")
            self._audio_buffer.extend(audio_frames)
            self._silence_frames = 0
        else:
            if self._is_speaking:
                self._silence_frames += len(audio_frames) // 4
                if self._silence_frames > self.silence_duration_frames:
                    global_event_emitter.emit("speech_stopped")
                    asyncio.create_task(self._transcribe_buffer())
                    self._is_speaking = False
                    self._silence_frames = 0
Process audio frames and convert to text
Args
audio_frames
    Raw PCM audio bytes to process
language
    Optional language code for recognition
**kwargs
    Additional provider-specific arguments
Returns
None. Final transcripts are delivered asynchronously through the registered transcript callback.
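A feeding sketch, assuming an async source of raw 16-bit PCM chunks (the source here is hypothetical); speech_started/speech_stopped fire on the global event emitter, and final transcripts arrive through the transcript callback the agent framework registers:

import asyncio
from videosdk.plugins.sarvamai import SarvamAISTT

async def run_stt(pcm_chunks) -> None:
    # pcm_chunks: hypothetical async iterator of raw 16-bit PCM byte chunks
    stt = SarvamAISTT(language="en-IN")
    async for chunk in pcm_chunks:
        # Buffers speech; after ~0.8 s of silence the utterance is
        # transcribed in a background task.
        await stt.process_audio(chunk)
    await stt.aclose()  # flushes any in-flight utterance before closing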
class SarvamAITTS (*,
api_key: str | None = None,
model: str = 'bulbul:v2',
speaker: str = 'anushka',
target_language_code: str = 'en-IN',
pitch: float = 0.0,
pace: float = 1.0,
loudness: float = 1.2,
enable_preprocessing: bool = True)
class SarvamAITTS(TTS):
    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        speaker: str = DEFAULT_SPEAKER,
        target_language_code: str = DEFAULT_TARGET_LANGUAGE,
        pitch: float = 0.0,
        pace: float = 1.0,
        loudness: float = 1.2,
        enable_preprocessing: bool = True,
    ) -> None:
        """Initialize the SarvamAI TTS plugin.

        Args:
            api_key (Optional[str], optional): SarvamAI API key. Defaults to None.
            model (str): The model to use for the TTS plugin. Defaults to "bulbul:v2".
            speaker (str): The speaker to use for the TTS plugin. Defaults to "anushka".
            target_language_code (str): The target language code to use for the TTS plugin. Defaults to "en-IN".
            pitch (float): The pitch to use for the TTS plugin. Defaults to 0.0.
            pace (float): The pace to use for the TTS plugin. Defaults to 1.0.
            loudness (float): The loudness to use for the TTS plugin. Defaults to 1.2.
            enable_preprocessing (bool): Whether to enable preprocessing for the TTS plugin. Defaults to True.
        """
        super().__init__(
            sample_rate=SARVAMAI_SAMPLE_RATE,
            num_channels=SARVAMAI_CHANNELS
        )
        self.model = model
        self.speaker = speaker
        self.target_language_code = target_language_code
        self.pitch = pitch
        self.pace = pace
        self.loudness = loudness
        self.enable_preprocessing = enable_preprocessing
        self.audio_track = None
        self.loop = None
        self._first_chunk_sent = False
        self.api_key = api_key or os.getenv("SARVAMAI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Sarvam AI API key required. Provide either:\n"
                "1. api_key parameter, OR\n"
                "2. SARVAMAI_API_KEY environment variable"
            )
        self._http_client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=15.0, read=30.0, write=5.0, pool=5.0),
            follow_redirects=True,
        )

    def reset_first_audio_tracking(self) -> None:
        """Reset the first audio tracking state for next TTS task"""
        self._first_chunk_sent = False

    async def synthesize(
        self,
        text: AsyncIterator[str] | str,
        **kwargs: Any,
    ) -> None:
        try:
            if not self.audio_track or not self.loop:
                self.emit("error", "Audio track or loop not initialized")
                return
            if isinstance(text, AsyncIterator):
                async for segment in segment_text(text):
                    await self._synthesize_audio(segment)
            else:
                await self._synthesize_audio(text)
        except Exception as e:
            self.emit("error", f"Sarvam AI TTS synthesis failed: {str(e)}")

    async def _synthesize_audio(self, text: str) -> None:
        try:
            payload = {
                "inputs": [text],
                "target_language_code": self.target_language_code,
                "speaker": self.speaker,
                "pitch": self.pitch,
                "pace": self.pace,
                "loudness": self.loudness,
                "speech_sample_rate": SARVAMAI_SAMPLE_RATE,
                "enable_preprocessing": self.enable_preprocessing,
                "model": self.model,
            }
            headers = {
                "Accept": "application/json",
                "Content-Type": "application/json",
                "api-subscription-key": self.api_key,
            }
            response = await self._http_client.post(
                SARVAMAI_TTS_ENDPOINT, headers=headers, json=payload
            )
            response.raise_for_status()
            response_data = response.json()
            if "audios" not in response_data or not response_data["audios"]:
                self.emit("error", "No audio data found in response from Sarvam AI")
                return
            audio_content = response_data["audios"][0]
            if not audio_content:
                self.emit("error", "No audio content received from Sarvam AI")
                return
            audio_bytes = base64.b64decode(audio_content)
            if not audio_bytes:
                self.emit("error", "Decoded audio bytes are empty")
                return
            await self._stream_audio_chunks(audio_bytes)
        except httpx.HTTPStatusError as e:
            self.emit(
                "error",
                f"Sarvam AI TTS HTTP error: {e.response.status_code} - {e.response.text}",
            )
            raise

    async def _stream_audio_chunks(self, audio_bytes: bytes) -> None:
        # 20 ms of 16-bit PCM per chunk.
        chunk_size = int(SARVAMAI_SAMPLE_RATE * SARVAMAI_CHANNELS * 2 * 20 / 1000)
        audio_data = self._remove_wav_header(audio_bytes)
        for i in range(0, len(audio_data), chunk_size):
            chunk = audio_data[i: i + chunk_size]
            if len(chunk) < chunk_size and len(chunk) > 0:
                padding_needed = chunk_size - len(chunk)
                chunk += b"\x00" * padding_needed
            if len(chunk) == chunk_size:
                if not self._first_chunk_sent and self._first_audio_callback:
                    self._first_chunk_sent = True
                    asyncio.create_task(self._first_audio_callback())
                asyncio.create_task(self.audio_track.add_new_bytes(chunk))
                await asyncio.sleep(0.001)

    def _remove_wav_header(self, audio_bytes: bytes) -> bytes:
        if audio_bytes.startswith(b"RIFF"):
            # Skip past the "data" sub-chunk tag (4 bytes) and its size field (4 bytes).
            data_pos = audio_bytes.find(b"data")
            if data_pos != -1:
                return audio_bytes[data_pos + 8:]
        return audio_bytes

    async def aclose(self) -> None:
        if self._http_client:
            await self._http_client.aclose()
        await super().aclose()

    async def interrupt(self) -> None:
        if self.audio_track:
            self.audio_track.interrupt()
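Note the pacing in _stream_audio_chunks: it slices the decoded PCM into 20 ms frames, so with 16-bit samples the chunk size is sample_rate * channels * 2 * 0.020 bytes. At 24000 Hz mono, for example, that would be 960 bytes per chunk; the actual SARVAMAI_SAMPLE_RATE and SARVAMAI_CHANNELS values are module constants not shown in this listing.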
Base class for Text-to-Speech implementations
Initialize the SarvamAI TTS plugin.
Args
api_key : Optional[str], optional
    SarvamAI API key. Defaults to None.
model : str
    The model to use for the TTS plugin. Defaults to "bulbul:v2".
speaker : str
    The speaker voice to use. Defaults to "anushka".
target_language_code : str
    The target language code for synthesis. Defaults to "en-IN".
pitch : float
    Pitch adjustment for the voice. Defaults to 0.0.
pace : float
    Speaking-pace multiplier. Defaults to 1.0.
loudness : float
    Loudness multiplier. Defaults to 1.2.
enable_preprocessing : bool
    Whether to enable text preprocessing for the TTS plugin. Defaults to True.
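A construction sketch with lightly tuned prosody, again assuming SARVAMAI_API_KEY is set in the environment:

from videosdk.plugins.sarvamai import SarvamAITTS

tts = SarvamAITTS(
    speaker="anushka",
    target_language_code="hi-IN",  # any Sarvam-supported locale code
    pace=1.1,       # slightly faster than default
    loudness=1.2,
)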
Ancestors
- videosdk.agents.tts.tts.TTS
- videosdk.agents.event_emitter.EventEmitter
- typing.Generic
Methods
async def aclose(self) ‑> None
async def aclose(self) -> None:
    if self._http_client:
        await self._http_client.aclose()
    await super().aclose()
Cleanup resources
async def interrupt(self) ‑> None
async def interrupt(self) -> None:
    if self.audio_track:
        self.audio_track.interrupt()
Interrupt the TTS process
def reset_first_audio_tracking(self) ‑> None
def reset_first_audio_tracking(self) -> None:
    """Reset the first audio tracking state for next TTS task"""
    self._first_chunk_sent = False
Reset the first audio tracking state for next TTS task
async def synthesize(self, text: AsyncIterator[str] | str, **kwargs: Any) ‑> None
async def synthesize(
    self,
    text: AsyncIterator[str] | str,
    **kwargs: Any,
) -> None:
    try:
        if not self.audio_track or not self.loop:
            self.emit("error", "Audio track or loop not initialized")
            return
        if isinstance(text, AsyncIterator):
            async for segment in segment_text(text):
                await self._synthesize_audio(segment)
        else:
            await self._synthesize_audio(text)
    except Exception as e:
        self.emit("error", f"Sarvam AI TTS synthesis failed: {str(e)}")
Convert text to speech
Args
text
    Text to convert to speech (either a string or an async iterator of strings)
voice_id
    Optional voice identifier
**kwargs
    Additional provider-specific arguments
Returns
None