Creating a Custom LLM Driver
LarAgent allows you to integrate with various LLM providers by creating custom drivers. This guide walks you through creating a custom LLM driver for a new provider, similar to the existing OpenAI driver but tailored to your specific LLM provider.
Understanding the LLM Driver Architecture
The LarAgent framework uses a driver-based architecture for LLM integrations:
- LlmDriver Interface (LarAgent\Core\Contracts\LlmDriver): Defines the contract all LLM drivers must implement
- Abstract LlmDriver (LarAgent\Core\Abstractions\LlmDriver): Provides common functionality for all drivers
- Concrete Drivers: Implement provider-specific logic (e.g., OpenAiDriver)
Creating Your Custom Driver
The code below is a simplified example of a custom driver implementation. It is not a complete implementation and is intended for educational purposes only.
Check the real drivers here.
Step 1: Create the Driver Class
First, create a new directory for your provider, then create your driver class:Copy
<?php
namespace App\Drivers\YourProvider;
use LarAgent\Core\Abstractions\LlmDriver;
use LarAgent\Core\Contracts\LlmDriver as LlmDriverInterface;
use LarAgent\Core\Contracts\Tool as ToolInterface;
use LarAgent\Core\Contracts\ToolCall as ToolCallInterface;
use LarAgent\Messages\AssistantMessage;
use LarAgent\Messages\StreamedAssistantMessage;
use LarAgent\Messages\ToolCallMessage;
use LarAgent\ToolCall;
use YourProvider\Client as YourProviderClient;
/**
 * Skeleton of a custom LLM driver for "YourProvider".
 *
 * The imports above cover every class referenced by the methods this guide
 * adds to the class (tool formatting, streaming, and the provider SDK client).
 */
class YourProviderDriver extends LlmDriver implements LlmDriverInterface
{
    /**
     * Provider SDK client, or null when no API key was supplied.
     */
    protected mixed $client;

    /**
     * @param  array  $settings  Driver settings; `api_key` is used to build the client
     */
    public function __construct(array $settings = [])
    {
        parent::__construct($settings);

        // Initialize your provider's client/SDK
        $this->client = $this->initializeClient($settings);
    }

    /**
     * Initialize the client for your LLM provider
     *
     * @param  array  $settings  Driver settings
     * @return mixed The initialized client, or null when `api_key` is missing or empty
     */
    protected function initializeClient(array $settings): mixed
    {
        // Example implementation:
        $apiKey = $settings['api_key'] ?? null;
        if (!$apiKey) {
            // No key: leave the client unset instead of failing during construction
            return null;
        }

        // Return your initialized client
        // This will depend on your provider's SDK
        return new YourProviderClient($apiKey);
    }

    // Implement required methods...
}
Step 2: Implement Required Methods
2.1 Send Message Method
This is the core method for sending messages to the LLM and receiving responses:Copy
/**
 * Send the conversation to the provider and wrap the reply in a LarAgent message.
 *
 * @param  array  $messages  Array of messages to send
 * @param  array  $options  Configuration options
 * @return AssistantMessage A ToolCallMessage when the reply contains tool calls,
 *                          otherwise a plain AssistantMessage with the text content
 *
 * @throws \Exception When no client has been initialized
 */
public function sendMessage(array $messages, array $options = []): AssistantMessage
{
    if (empty($this->client)) {
        throw new \Exception('API key is required to use the YourProvider driver.');
    }

    // Build the request payload and call the provider
    $response = $this->client->createCompletion($this->preparePayload($messages, $options));
    $this->lastResponse = $response;

    // Plain text reply: nothing to extract besides content and metadata
    if (! $this->isToolCallResponse($response)) {
        return new AssistantMessage(
            $this->extractContent($response),
            $this->getResponseMetadata($response)
        );
    }

    // Tool-call reply: extract the calls and build the assistant message that carries them
    $toolCalls = $this->extractToolCalls($response);

    return new ToolCallMessage(
        $toolCalls,
        $this->toolCallsToMessage($toolCalls),
        $this->getResponseMetadata($response)
    );
}
2.2 Tool Result to Message Method
This method formats tool results for the LLM:Copy
/**
 * Convert a tool result to a message format for the LLM
 *
 * The tool call's JSON arguments are decoded, the result is stored under the
 * tool's name, and the combined array is JSON-encoded as the message content.
 *
 * @param  ToolCallInterface  $toolCall  The tool call
 * @param  mixed  $result  The result from the tool
 * @return array The formatted message
 */
public function toolResultToMessage(ToolCallInterface $toolCall, mixed $result): array
{
    // Format depends on your provider's expected format
    // Example for OpenAI-compatible format:
    $decoded = json_decode($toolCall->getArguments(), true);
    // Arguments can be empty, malformed, or decode to a scalar; start from an
    // empty array in those cases so the assignment below cannot fail.
    $content = is_array($decoded) ? $decoded : [];
    $content[$toolCall->getToolName()] = $result;
    return [
        'role' => 'tool',
        'content' => json_encode($content),
        'tool_call_id' => $toolCall->getId(),
    ];
}
2.3 Tool Calls to Message Method
This method formats tool calls for the LLM:Copy
/**
 * Build the assistant message that carries the given tool calls.
 *
 * Each tool call is converted with toolCallToContent().
 * NOTE(review): toolCallToContent() is not shown in this guide — confirm the
 * base driver provides it, or add it to your driver.
 *
 * @param  array  $toolCalls  Array of tool calls
 * @return array The formatted message
 */
public function toolCallsToMessage(array $toolCalls): array
{
    // array_values() keeps the result a 0-indexed list regardless of the input keys
    $formattedCalls = array_values(array_map(
        fn ($toolCall) => $this->toolCallToContent($toolCall),
        $toolCalls
    ));

    // Format depends on your provider's expected format
    // Example for OpenAI-compatible format:
    return [
        'role' => 'assistant',
        'tool_calls' => $formattedCalls,
    ];
}
2.4 Format Tool for Payload Method
This method formats a registered tool's definition for your provider's API payload:Copy
/**
 * Describe a tool definition in the shape your provider's API payload expects.
 *
 * Override the default implementation only if your provider uses a different
 * tool format; the layout below is an example of such a format.
 *
 * @param  ToolInterface  $tool  The tool to describe
 * @return array The tool's name, description, parameters and required parameters
 */
public function formatToolForPayload(ToolInterface $tool): array
{
    $name = $tool->getName();
    $description = $tool->getDescription();
    $parameters = $tool->getProperties();
    $required = $tool->getRequired();

    return [
        'name' => $name,
        'description' => $description,
        'parameters' => $parameters,
        'required_params' => $required,
    ];
}
2.5 Streamed Message Method
For providers that support streaming:Copy
/**
 * Stream a reply from the provider, yielding the message as it grows.
 *
 * This is a generator: nothing below runs until the caller starts iterating.
 * The same StreamedAssistantMessage instance is updated and yielded for every chunk.
 *
 * @param  array  $messages  Array of messages to send
 * @param  array  $options  Configuration options
 * @param  callable|null  $callback  Optional callback invoked with the message after each chunk
 * @return \Generator A generator that yields the updated message after each chunk
 *
 * @throws \Exception When no client has been initialized
 */
public function sendMessageStreamed(array $messages, array $options = [], ?callable $callback = null): \Generator
{
    if (empty($this->client)) {
        throw new \Exception('API key is required to use the YourProvider driver.');
    }

    // Common payload plus the stream flag
    $payload = $this->preparePayload($messages, $options);
    $payload['stream'] = true;

    $chunks = $this->client->createCompletionStream($payload);
    $message = new StreamedAssistantMessage();

    foreach ($chunks as $chunk) {
        $this->lastResponse = $chunk;

        // Fold the chunk into the accumulated message (provider-specific)
        $this->processStreamChunk($chunk, $message);

        if ($callback !== null) {
            $callback($message);
        }

        yield $message;
    }
}
Step 3: Implement Helper Methods
These methods help with the core functionality:Copy
/**
 * Prepare the payload for API request with common settings
 *
 * Stores the options as the driver config, then merges that config with the
 * formatted messages and, when applicable, the response format and tools.
 *
 * @param  array  $messages  Array of messages to send
 * @param  array  $options  Configuration options; `model` falls back to the driver settings
 * @return array The request payload
 */
protected function preparePayload(array $messages, array $options = []): array
{
    // Fall back to the model from the driver settings when none is passed via options
    if (empty($options['model'])) {
        $options['model'] = $this->getSettings()['model'] ?? 'default-model';
    }
    $this->setConfig($options);
    $payload = array_merge($this->getConfig(), [
        'messages' => $this->formatMessages($messages),
    ]);
    // Set the response format if structured output is enabled
    if ($this->structuredOutputEnabled()) {
        $payload['response_format'] = $this->formatResponseSchema($this->getResponseSchema());
    }
    // Add tools to payload if any are registered
    if (!empty($this->tools)) {
        // array_map() preserves the input keys; array_values() guarantees a
        // 0-indexed list so 'tools' always serializes as a JSON array.
        $payload['tools'] = array_values(array_map(
            fn($tool) => $this->formatToolForPayload($tool),
            $this->getRegisteredTools()
        ));
    }
    return $payload;
}
/**
 * Format messages for your provider's expected format
 *
 * This example returns the messages unchanged; replace the body when your
 * provider expects a different message structure.
 *
 * @param  array  $messages  Messages in LarAgent format
 * @return array Messages in the provider's format
 */
protected function formatMessages(array $messages): array
{
    // Transform LarAgent message format to your provider's format if needed
    // Return the formatted messages
    return $messages;
}
/**
 * Wrap a response schema in the structure your provider expects.
 *
 * @param  array  $schema  The response schema
 * @return array The schema wrapped together with its `json_schema` type marker
 */
protected function formatResponseSchema(array $schema): array
{
    // Transform the schema to your provider's expected format
    $responseFormat = ['type' => 'json_schema'];
    $responseFormat['schema'] = $schema;

    return $responseFormat;
}
Step 4: Implement Provider-Specific Methods
These are methods specific to your provider’s API:Copy
/**
 * Tell whether a response carries at least one tool call.
 *
 * @param  mixed  $response  The provider response
 * @return bool True when `tool_calls` is set and non-empty
 */
protected function isToolCallResponse($response): bool
{
    // Implement based on your provider's response format
    // empty() already covers the "not set" case, so no separate isset() is needed
    return ! empty($response->tool_calls);
}
/**
 * Turn the provider's raw tool calls into ToolCall objects.
 *
 * Name and arguments are read from the call itself first, then from its
 * nested `function` object; a generated id is used when the call has none.
 *
 * @param  mixed  $response  The provider response; its `tool_calls` are iterated
 * @return array List of ToolCall instances
 */
protected function extractToolCalls($response): array
{
    // Implement based on your provider's response format
    $calls = [];

    foreach ($response->tool_calls as $rawCall) {
        $id = $rawCall->id ?? ('tool_call_' . uniqid());
        $name = $rawCall->name ?? $rawCall->function->name ?? '';
        $arguments = $rawCall->arguments ?? $rawCall->function->arguments ?? '{}';

        $calls[] = new ToolCall($id, $name, $arguments);
    }

    return $calls;
}
/**
 * Read the text content of the first choice in a response.
 *
 * @param  mixed  $response  The provider response
 * @return string The content, or an empty string when it is absent
 */
protected function extractContent($response): string
{
    // Implement based on your provider's response format
    // Example:
    $text = $response->choices[0]->message->content ?? null;

    return $text ?? '';
}
/**
 * Collect token usage from a response as message metadata.
 *
 * @param  mixed  $response  The provider response
 * @return array Metadata with a `usage` entry; counts missing from the response are reported as 0
 */
protected function getResponseMetadata($response): array
{
    // Extract usage information or other metadata
    $usage = $response->usage ?? null;

    $tokens = [];
    foreach (['prompt_tokens', 'completion_tokens', 'total_tokens'] as $field) {
        $tokens[$field] = $usage->{$field} ?? 0;
    }

    return ['usage' => $tokens];
}
/**
 * Fold one stream chunk into the accumulated streamed message.
 *
 * Appends the chunk's content when present. A chunk that carries usage data
 * also sets the usage on the message and marks the message complete.
 *
 * @param  mixed  $chunk  One chunk of the provider's stream
 * @param  StreamedAssistantMessage  $message  The message being built up
 */
protected function processStreamChunk($chunk, StreamedAssistantMessage $message): void
{
    // Implement based on your provider's streaming format
    if (isset($chunk->content)) {
        $message->appendContent($chunk->content);
    }

    if (! isset($chunk->usage)) {
        return;
    }

    $usage = $chunk->usage;
    $message->setUsage([
        'prompt_tokens' => $usage->prompt_tokens ?? 0,
        'completion_tokens' => $usage->completion_tokens ?? 0,
        'total_tokens' => $usage->total_tokens ?? 0,
    ]);
    $message->setComplete(true);
}
Testing Your Driver
Create tests for your driver to ensure it works correctly:Copy
<?php
namespace Tests\Drivers\YourProvider;
use App\Drivers\YourProvider\YourProviderDriver;
use LarAgent\Messages\AssistantMessage;
use LarAgent\Messages\ToolCallMessage;
use PHPUnit\Framework\TestCase;
/**
 * Example tests for the custom driver.
 *
 * The driver is imported from the namespace it is created in (Step 1). The
 * provider client is replaced with an in-memory fake, so no request is sent.
 */
class YourProviderDriverTest extends TestCase
{
    protected YourProviderDriver $driver;

    protected function setUp(): void
    {
        parent::setUp();

        $this->driver = new YourProviderDriver([
            'api_key' => 'test_key',
            'model' => 'test_model',
        ]);
    }

    public function testSendMessage(): void
    {
        // Mock your provider's client response
        $this->mockClientResponse();

        $messages = [
            ['role' => 'system', 'content' => 'You are a helpful assistant.'],
            ['role' => 'user', 'content' => 'Hello!'],
        ];

        $response = $this->driver->sendMessage($messages);

        $this->assertInstanceOf(AssistantMessage::class, $response);
        $this->assertSame('Hello! How can I help you today?', $response->getContent());
    }

    public function testSendMessageWithToolCalls(): void
    {
        // Mock your provider's client response for tool calls
        $this->mockClientToolCallResponse();

        $messages = [
            ['role' => 'user', 'content' => 'What\'s the weather?'],
        ];

        $response = $this->driver->sendMessage($messages);

        $this->assertInstanceOf(ToolCallMessage::class, $response);
        $this->assertSame('get_weather', $response->getToolCalls()[0]->getToolName());
    }

    // Add more tests for other methods

    /**
     * Make the driver's client return a plain text completion.
     *
     * The response mirrors the shape read by the example driver
     * (`choices[0]->message->content` and `usage`); adapt it to your provider.
     */
    protected function mockClientResponse(): void
    {
        $this->replaceClient((object) [
            'choices' => [
                (object) [
                    'message' => (object) ['content' => 'Hello! How can I help you today?'],
                ],
            ],
            'usage' => (object) [
                'prompt_tokens' => 10,
                'completion_tokens' => 9,
                'total_tokens' => 19,
            ],
        ]);
    }

    /**
     * Make the driver's client return a completion that requests a tool call.
     *
     * The response mirrors the shape read by the example driver
     * (`tool_calls[]->function->name/arguments`); adapt it to your provider.
     */
    protected function mockClientToolCallResponse(): void
    {
        $this->replaceClient((object) [
            'tool_calls' => [
                (object) [
                    'id' => 'call_1',
                    'function' => (object) [
                        'name' => 'get_weather',
                        'arguments' => '{"city":"Paris"}',
                    ],
                ],
            ],
            'usage' => (object) [
                'prompt_tokens' => 12,
                'completion_tokens' => 5,
                'total_tokens' => 17,
            ],
        ]);
    }

    /**
     * Swap the driver's client for a fake whose createCompletion() returns $response.
     */
    private function replaceClient(object $response): void
    {
        $fakeClient = new class($response) {
            public function __construct(private object $response)
            {
            }

            public function createCompletion(array $payload): object
            {
                return $this->response;
            }
        };

        // Closure::call() runs the closure bound to the driver and in its class
        // scope, which allows assigning the protected $client property.
        (function () use ($fakeClient): void {
            $this->client = $fakeClient;
        })->call($this->driver);
    }
}
Registering Your Driver
To make your driver available in the LarAgent framework, you’ll need to register it:
In Laravel
Add your driver to the configuration file:Copy
// config/laragent.php
return [
    // ...
    'providers' => [
        // Key under which this provider's settings are registered
        'your-provider' => [
            // NOTE(review): presumably a display/identifier name for the provider — confirm against the LarAgent config reference
            'label' => 'your-provider-name',
            // Fully-qualified class name of the custom driver created in Step 1
            'driver' => \App\Drivers\YourProvider\YourProviderDriver::class,
            // API key read from the YOUR_PROVIDER_API_KEY environment variable
            'api_key' => env('YOUR_PROVIDER_API_KEY'),
            // Default model; presumably the value preparePayload() reads from the driver settings when no model is passed via options
            'model' => 'your-default-model',
            // Other provider-specific settings
        ],
    ],
];
In Agent Class
Copy
namespace App\AiAgents;
use LarAgent\Agent;
/**
 * Example agent that selects the custom driver directly on the agent class.
 */
class YourAgent extends Agent
{
    // Fully-qualified class name of the custom driver created in Step 1
    protected $driver = \App\Drivers\YourProvider\YourProviderDriver::class;
    // ...
}
Best Practices
- Error Handling: Implement robust error handling for API calls
- Rate Limiting: Consider implementing rate limiting or retry logic
- Logging: Add logging for debugging purposes
- Configuration: Make your driver configurable with sensible defaults
- Documentation: Document your driver’s capabilities and limitations
Complete Example Implementation
Here’s a simplified example of a complete driver implementation:Copy
<?php
namespace App\Drivers\YourProvider;
use LarAgent\Core\Abstractions\LlmDriver;
use LarAgent\Core\Contracts\LlmDriver as LlmDriverInterface;
use LarAgent\Core\Contracts\Tool as ToolInterface;
use LarAgent\Core\Contracts\ToolCall as ToolCallInterface;
use LarAgent\Messages\AssistantMessage;
use LarAgent\Messages\StreamedAssistantMessage;
use LarAgent\Messages\ToolCallMessage;
use LarAgent\ToolCall;
use YourProvider\Client as YourProviderClient;
/**
 * Simplified end-to-end driver for a hypothetical "YourProvider" SDK.
 *
 * Lives in the same namespace as the class created in Step 1 and referenced
 * by the registration snippets.
 */
class YourProviderDriver extends LlmDriver implements LlmDriverInterface
{
    /**
     * Provider SDK client, or null when no API key was supplied.
     */
    protected mixed $client;

    /**
     * @param  array  $settings  Driver settings; `api_key` is used to build the client
     */
    public function __construct(array $settings = [])
    {
        parent::__construct($settings);

        $apiKey = $settings['api_key'] ?? null;
        $this->client = $apiKey ? new YourProviderClient($apiKey) : null;
    }

    /**
     * Send a message to the LLM and receive a response.
     *
     * @param  array  $messages  Array of messages to send
     * @param  array  $options  Configuration options
     * @return AssistantMessage A ToolCallMessage when the reply contains tool calls,
     *                          otherwise a plain AssistantMessage
     *
     * @throws \Exception When no client has been initialized
     */
    public function sendMessage(array $messages, array $options = []): AssistantMessage
    {
        if (empty($this->client)) {
            throw new \Exception('API key is required to use the YourProvider driver.');
        }
        $payload = $this->preparePayload($messages, $options);
        $response = $this->client->createCompletion($payload);
        $this->lastResponse = $response;

        // empty() also covers the "not set" case
        if (! empty($response->tool_calls)) {
            $toolCalls = [];
            foreach ($response->tool_calls as $tc) {
                $toolCalls[] = new ToolCall(
                    $tc->id ?? ('tool_call_' . uniqid()),
                    $tc->function->name ?? '',
                    $tc->function->arguments ?? '{}'
                );
            }

            $message = $this->toolCallsToMessage($toolCalls);
            return new ToolCallMessage($toolCalls, $message, $this->getResponseMetadata($response));
        }

        $content = $response->choices[0]->message->content ?? '';
        return new AssistantMessage($content, $this->getResponseMetadata($response));
    }

    /**
     * Send a message to the LLM and receive a streamed response.
     *
     * This is a generator: nothing runs until the caller starts iterating.
     *
     * @param  array  $messages  Array of messages to send
     * @param  array  $options  Configuration options
     * @param  callable|null  $callback  Optional callback invoked with the message after each chunk
     * @return \Generator A generator that yields the updated message after each chunk
     *
     * @throws \Exception When no client has been initialized
     */
    public function sendMessageStreamed(array $messages, array $options = [], ?callable $callback = null): \Generator
    {
        if (empty($this->client)) {
            throw new \Exception('API key is required to use the YourProvider driver.');
        }
        $payload = $this->preparePayload($messages, $options);
        $payload['stream'] = true;

        $stream = $this->client->createCompletionStream($payload);
        $streamedMessage = new StreamedAssistantMessage;

        foreach ($stream as $chunk) {
            $this->lastResponse = $chunk;

            if (isset($chunk->content)) {
                $streamedMessage->appendContent($chunk->content);
            }

            if (isset($chunk->usage)) {
                // Individual counts may be absent from a chunk; report them as 0
                $streamedMessage->setUsage([
                    'prompt_tokens' => $chunk->usage->prompt_tokens ?? 0,
                    'completion_tokens' => $chunk->usage->completion_tokens ?? 0,
                    'total_tokens' => $chunk->usage->total_tokens ?? 0,
                ]);
                $streamedMessage->setComplete(true);
            }

            if ($callback) {
                $callback($streamedMessage);
            }

            yield $streamedMessage;
        }
    }

    /**
     * Convert a tool result to a message format for the LLM
     *
     * @param  ToolCallInterface  $toolCall  The tool call
     * @param  mixed  $result  The result from the tool
     * @return array The formatted message
     */
    public function toolResultToMessage(ToolCallInterface $toolCall, mixed $result): array
    {
        $decoded = json_decode($toolCall->getArguments(), true);
        // Arguments can be empty, malformed, or decode to a scalar; start from
        // an empty array in those cases so the assignment below cannot fail.
        $content = is_array($decoded) ? $decoded : [];
        $content[$toolCall->getToolName()] = $result;
        return [
            'role' => 'tool',
            'content' => json_encode($content),
            'tool_call_id' => $toolCall->getId(),
        ];
    }

    /**
     * Convert tool calls to a message format for the LLM
     *
     * @param  array  $toolCalls  Array of tool calls
     * @return array The formatted message
     */
    public function toolCallsToMessage(array $toolCalls): array
    {
        $toolCallsArray = [];
        foreach ($toolCalls as $tc) {
            $toolCallsArray[] = [
                'id' => $tc->getId(),
                'type' => 'function',
                'function' => [
                    'name' => $tc->getToolName(),
                    'arguments' => $tc->getArguments(),
                ],
            ];
        }
        return [
            'role' => 'assistant',
            'tool_calls' => $toolCallsArray,
        ];
    }

    /**
     * Prepare the payload for API request with common settings
     *
     * @param  array  $messages  Array of messages to send
     * @param  array  $options  Configuration options; `model` falls back to the driver settings
     * @return array The request payload
     */
    protected function preparePayload(array $messages, array $options = []): array
    {
        if (empty($options['model'])) {
            $options['model'] = $this->getSettings()['model'] ?? 'default-model';
        }
        $this->setConfig($options);
        $payload = array_merge($this->getConfig(), [
            'messages' => $messages,
        ]);
        if ($this->structuredOutputEnabled()) {
            $payload['response_format'] = [
                'type' => 'json_schema',
                'schema' => $this->getResponseSchema(),
            ];
        }
        if (!empty($this->tools)) {
            foreach ($this->getRegisteredTools() as $tool) {
                $payload['tools'][] = $this->formatToolForPayload($tool);
            }
        }
        return $payload;
    }

    /**
     * Get metadata from a response
     *
     * @param  mixed  $response  The provider response
     * @return array Metadata with a `usage` entry; counts missing from the response are reported as 0
     */
    protected function getResponseMetadata($response): array
    {
        return [
            'usage' => [
                'prompt_tokens' => $response->usage->prompt_tokens ?? 0,
                'completion_tokens' => $response->usage->completion_tokens ?? 0,
                'total_tokens' => $response->usage->total_tokens ?? 0,
            ],
        ];
    }
}

