Module videosdk.plugins.xai.llm

Classes

class XAILLM (*,
api_key: str | None = None,
model: str = 'grok-4-1-fast-non-reasoning',
base_url: str = 'https://api.x.ai/v1',
temperature: float = 0.7,
tool_choice: ToolChoice = 'auto',
max_completion_tokens: int | None = None,
tools: List[Union[FunctionTool, Dict[str, Any]]] | None = None)
Expand source code
class XAILLM(LLM):
    """
    LLM Plugin for xAI (Grok) API.
    Supports Grok-4, and reasoning models with standard client-side function calling.
    """
    
    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = "grok-4-1-fast-non-reasoning", 
        base_url: str = "https://api.x.ai/v1",
        temperature: float = 0.7,
        tool_choice: ToolChoice = "auto",
        max_completion_tokens: int | None = None,
        tools: List[Union[FunctionTool, Dict[str, Any]]] | None = None,
    ) -> None:
        """Initialize the xAI LLM plugin.

        Args:
            api_key (Optional[str], optional): xAI API key. Defaults to XAI_API_KEY env var.
            model (str): The model to use (e.g., "grok-4", "grok-4-1-fast").
            base_url (str): The base URL for the xAI API. Defaults to "https://api.x.ai/v1".
            temperature (float): The temperature to use. Defaults to 0.7.
            tool_choice (ToolChoice): The tool choice to use. Defaults to "auto".
            max_completion_tokens (Optional[int], optional): Max tokens.
            tools (Optional[List], optional): List of FunctionTools to be available to the LLM.
        """
        super().__init__()
        self.api_key = api_key or os.getenv("XAI_API_KEY")
        if not self.api_key:
            raise ValueError("xAI API key must be provided either through api_key parameter or XAI_API_KEY environment variable")
        
        self.model = model
        self.temperature = temperature
        self.tool_choice = tool_choice
        self.max_completion_tokens = max_completion_tokens
        self.tools = tools or []
        self._cancelled = False
        
        self._client = openai.AsyncOpenAI(
            api_key=self.api_key,
            base_url=base_url,
            max_retries=0,
            http_client=httpx.AsyncClient(
                timeout=httpx.Timeout(connect=15.0, read=60.0, write=5.0, pool=5.0),
                follow_redirects=True,
                limits=httpx.Limits(
                    max_connections=50,
                    max_keepalive_connections=50,
                    keepalive_expiry=120,
                ),
            ),
        )

    async def chat(
        self,
        messages: ChatContext,
        tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
        conversational_graph: Any | None = None,
        **kwargs: Any
    ) -> AsyncIterator[LLMResponse]:
        """
        Implement chat functionality using xAI's API via OpenAI SDK compatibility.
        """
        self._cancelled = False

        openai_messages = messages.to_openai_messages()

        completion_params = {
            "model": self.model,
            "messages": openai_messages,
            "temperature": self.temperature,
            "stream": True,
            "max_tokens": self.max_completion_tokens,
        }
        
        if conversational_graph:
            completion_params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "conversational_graph_response",
                    "strict": True,
                    "schema": conversational_graph._get_graph_schema()
                }
            }

        combined_tools = (self.tools or []) + (tools or [])
        
        if combined_tools:
            formatted_tools = []
            for tool in combined_tools:
                if is_function_tool(tool):
                    try:
                        tool_schema = build_openai_schema(tool)
                        if "function" not in tool_schema:
                            inner_tool = {k: v for k, v in tool_schema.items() if k != "type"}
                            formatted_tools.append({
                                "type": "function",
                                "function": inner_tool
                            })
                        else:
                            formatted_tools.append(tool_schema)
                    except Exception as e:
                        self.emit("error", f"Failed to format tool {tool}: {e}")
                        continue
                elif isinstance(tool, dict):
                    formatted_tools.append(tool)
            
            if formatted_tools:
                completion_params["tools"] = formatted_tools
                completion_params["tool_choice"] = self.tool_choice

        completion_params.update(kwargs)

        try:
            response_stream = await self._client.chat.completions.create(**completion_params)
            
            current_content = ""
            current_tool_calls = {} 
            streaming_state = {
                "in_response": False,
                "response_start_index": -1,
                "yielded_content_length": 0
            }

            async for chunk in response_stream:
                if self._cancelled:
                    break
                
                if not chunk.choices:
                    continue
                    
                delta = chunk.choices[0].delta
                
                if delta.tool_calls:
                    for tool_call in delta.tool_calls:
                        idx = tool_call.index
                        if idx not in current_tool_calls:
                            current_tool_calls[idx] = {
                                "id": tool_call.id or "",
                                "name": tool_call.function.name or "",
                                "arguments": tool_call.function.arguments or ""
                            }
                        else:
                            if tool_call.function.name:
                                current_tool_calls[idx]["name"] += tool_call.function.name
                            if tool_call.function.arguments:
                                current_tool_calls[idx]["arguments"] += tool_call.function.arguments

                if delta.content is not None:
                    current_content += delta.content   
                    if conversational_graph:                     
                        for content_chunk in conversational_graph.stream_conversational_graph_response(current_content, streaming_state):
                            yield LLMResponse(content=content_chunk, role=ChatRole.ASSISTANT)
                    else:
                        yield LLMResponse(content=delta.content, role=ChatRole.ASSISTANT)

            if current_tool_calls and not self._cancelled:
                for idx in sorted(current_tool_calls.keys()):
                    tool_data = current_tool_calls[idx]
                    try:
                        args_str = tool_data["arguments"]
                        parsed_args = json.loads(args_str) 
                        
                        yield LLMResponse(
                            content="",
                            role=ChatRole.ASSISTANT,
                            metadata={
                                "function_call": {
                                    "name": tool_data["name"],
                                    "arguments": parsed_args
                                },
                                "tool_call_id": tool_data["id"]
                            }
                        )
                    except json.JSONDecodeError:
                        self.emit("error", f"Failed to parse function arguments for tool {tool_data['name']}")

            if current_content and not self._cancelled and conversational_graph:
                try:
                    parsed_json = json.loads(current_content.strip())
                    yield LLMResponse(
                        content="",
                        role=ChatRole.ASSISTANT,
                        metadata=parsed_json
                    )
                except json.JSONDecodeError:
                     pass

        except Exception as e:
            if not self._cancelled:
                self.emit("error", e)
            raise

    async def cancel_current_generation(self) -> None:
        self._cancelled = True

    async def aclose(self) -> None:
        """Cleanup resources"""
        await self.cancel_current_generation()
        if self._client:
            await self._client.close()
        await super().aclose()

LLM Plugin for xAI (Grok) API. Supports Grok-4, and reasoning models with standard client-side function calling.

Initialize the xAI LLM plugin.

Args

api_key : Optional[str], optional
xAI API key. Defaults to XAI_API_KEY env var.
model : str
The model to use (e.g., "grok-4", "grok-4-1-fast").
base_url : str
The base URL for the xAI API. Defaults to "https://api.x.ai/v1".
temperature : float
The temperature to use. Defaults to 0.7.
tool_choice : ToolChoice
The tool choice to use. Defaults to "auto".
max_completion_tokens : Optional[int], optional
Max tokens.
tools : Optional[List], optional
List of FunctionTools to be available to the LLM.

Ancestors

  • videosdk.agents.llm.llm.LLM
  • videosdk.agents.event_emitter.EventEmitter
  • typing.Generic

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Cleanup resources"""
    await self.cancel_current_generation()
    if self._client:
        await self._client.close()
    await super().aclose()

Cleanup resources

async def cancel_current_generation(self) ‑> None
Expand source code
async def cancel_current_generation(self) -> None:
    self._cancelled = True

Cancel the current LLM generation if active.

Raises

NotImplementedError
This method must be implemented by subclasses.
async def chat(self,
messages: ChatContext,
tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
conversational_graph: Any | None = None,
**kwargs: Any) ‑> AsyncIterator[videosdk.agents.llm.llm.LLMResponse]
Expand source code
async def chat(
    self,
    messages: ChatContext,
    tools: list[Union[FunctionTool, Dict[str, Any]]] | None = None,
    conversational_graph: Any | None = None,
    **kwargs: Any
) -> AsyncIterator[LLMResponse]:
    """
    Implement chat functionality using xAI's API via OpenAI SDK compatibility.
    """
    self._cancelled = False

    openai_messages = messages.to_openai_messages()

    completion_params = {
        "model": self.model,
        "messages": openai_messages,
        "temperature": self.temperature,
        "stream": True,
        "max_tokens": self.max_completion_tokens,
    }
    
    if conversational_graph:
        completion_params["response_format"] = {
            "type": "json_schema",
            "json_schema": {
                "name": "conversational_graph_response",
                "strict": True,
                "schema": conversational_graph._get_graph_schema()
            }
        }

    combined_tools = (self.tools or []) + (tools or [])
    
    if combined_tools:
        formatted_tools = []
        for tool in combined_tools:
            if is_function_tool(tool):
                try:
                    tool_schema = build_openai_schema(tool)
                    if "function" not in tool_schema:
                        inner_tool = {k: v for k, v in tool_schema.items() if k != "type"}
                        formatted_tools.append({
                            "type": "function",
                            "function": inner_tool
                        })
                    else:
                        formatted_tools.append(tool_schema)
                except Exception as e:
                    self.emit("error", f"Failed to format tool {tool}: {e}")
                    continue
            elif isinstance(tool, dict):
                formatted_tools.append(tool)
        
        if formatted_tools:
            completion_params["tools"] = formatted_tools
            completion_params["tool_choice"] = self.tool_choice

    completion_params.update(kwargs)

    try:
        response_stream = await self._client.chat.completions.create(**completion_params)
        
        current_content = ""
        current_tool_calls = {} 
        streaming_state = {
            "in_response": False,
            "response_start_index": -1,
            "yielded_content_length": 0
        }

        async for chunk in response_stream:
            if self._cancelled:
                break
            
            if not chunk.choices:
                continue
                
            delta = chunk.choices[0].delta
            
            if delta.tool_calls:
                for tool_call in delta.tool_calls:
                    idx = tool_call.index
                    if idx not in current_tool_calls:
                        current_tool_calls[idx] = {
                            "id": tool_call.id or "",
                            "name": tool_call.function.name or "",
                            "arguments": tool_call.function.arguments or ""
                        }
                    else:
                        if tool_call.function.name:
                            current_tool_calls[idx]["name"] += tool_call.function.name
                        if tool_call.function.arguments:
                            current_tool_calls[idx]["arguments"] += tool_call.function.arguments

            if delta.content is not None:
                current_content += delta.content   
                if conversational_graph:                     
                    for content_chunk in conversational_graph.stream_conversational_graph_response(current_content, streaming_state):
                        yield LLMResponse(content=content_chunk, role=ChatRole.ASSISTANT)
                else:
                    yield LLMResponse(content=delta.content, role=ChatRole.ASSISTANT)

        if current_tool_calls and not self._cancelled:
            for idx in sorted(current_tool_calls.keys()):
                tool_data = current_tool_calls[idx]
                try:
                    args_str = tool_data["arguments"]
                    parsed_args = json.loads(args_str) 
                    
                    yield LLMResponse(
                        content="",
                        role=ChatRole.ASSISTANT,
                        metadata={
                            "function_call": {
                                "name": tool_data["name"],
                                "arguments": parsed_args
                            },
                            "tool_call_id": tool_data["id"]
                        }
                    )
                except json.JSONDecodeError:
                    self.emit("error", f"Failed to parse function arguments for tool {tool_data['name']}")

        if current_content and not self._cancelled and conversational_graph:
            try:
                parsed_json = json.loads(current_content.strip())
                yield LLMResponse(
                    content="",
                    role=ChatRole.ASSISTANT,
                    metadata=parsed_json
                )
            except json.JSONDecodeError:
                 pass

    except Exception as e:
        if not self._cancelled:
            self.emit("error", e)
        raise

Implement chat functionality using xAI's API via OpenAI SDK compatibility.