Creating a Custom LLM Driver

LarAgent allows you to integrate with various LLM providers by creating custom drivers. This guide walks you through creating a custom LLM driver for a new provider, similar to the existing OpenAI driver but tailored to your provider.

Understanding the LLM Driver Architecture

The LarAgent framework uses a driver-based architecture for LLM integrations:

  • LlmDriver Interface (LarAgent\Core\Contracts\LlmDriver): Defines the contract all LLM drivers must implement
  • Abstract LlmDriver (LarAgent\Core\Abstractions\LlmDriver): Provides common functionality for all drivers
  • Concrete Drivers: Implement provider-specific logic (e.g., OpenAiDriver)
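
To make the three layers concrete, here is a minimal, self-contained sketch of the same pattern. All names (DriverContract, BaseDriver, EchoDriver) are hypothetical stand-ins chosen for illustration, and the method signatures are simplified; they are not LarAgent's real classes.

```php
<?php

// Hypothetical, simplified stand-ins for the three layers; not LarAgent's real classes.
interface DriverContract
{
    public function sendMessage(array $messages, array $options = []): string;
}

abstract class BaseDriver implements DriverContract
{
    public function __construct(protected array $settings = []) {}

    // Shared behaviour that every concrete driver inherits.
    public function getSettings(): array
    {
        return $this->settings;
    }
}

final class EchoDriver extends BaseDriver
{
    // Provider-specific logic: this toy driver echoes the last message back.
    public function sendMessage(array $messages, array $options = []): string
    {
        $last = end($messages);

        return 'echo: ' . ($last['content'] ?? '');
    }
}

$driver = new EchoDriver(['model' => 'demo']);
echo $driver->sendMessage([['role' => 'user', 'content' => 'Hello']]), PHP_EOL;
```

Code that depends only on the contract can swap one concrete driver for another without changes, which is what makes a new provider pluggable.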

Creating Your Custom Driver

The code below is a simplified example of a custom driver implementation. It is not complete and is intended for educational purposes only. For complete implementations, check the real drivers that ship with LarAgent.

Step 1: Create the Driver Class

First, create a new directory for your provider, then create your driver class:

<?php

namespace App\Drivers\YourProvider;

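use LarAgent\Core\Abstractions\LlmDriver;
use LarAgent\Core\Contracts\LlmDriver as LlmDriverInterface;
use LarAgent\Core\Contracts\Tool as ToolInterface;
use LarAgent\Core\Contracts\ToolCall as ToolCallInterface;
use LarAgent\Messages\AssistantMessage;
use LarAgent\Messages\StreamedAssistantMessage;
use LarAgent\Messages\ToolCallMessage;
use LarAgent\ToolCall;
use YourProvider\Client as YourProviderClient; // Placeholder for your provider's SDK client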

class YourProviderDriver extends LlmDriver implements LlmDriverInterface
{
    protected mixed $client;
    
    public function __construct(array $settings = [])
    {
        parent::__construct($settings);
        
        // Initialize your provider's client/SDK
        // Example:
        $this->client = $this->initializeClient($settings);
    }
    
    /**
     * Initialize the client for your LLM provider
     */
    protected function initializeClient(array $settings): mixed
    {
        // Example implementation:
        $apiKey = $settings['api_key'] ?? null;
        if (!$apiKey) {
            return null;
        }
        
        // Return your initialized client
        // This will depend on your provider's SDK
        return new YourProviderClient($apiKey);
    }
    
    // Implement required methods...
}

Step 2: Implement Required Methods

2.1 Send Message Method

This is the core method for sending messages to the LLM and receiving responses:

/**
 * Send a message to the LLM and receive a response.
 *
 * @param  array  $messages  Array of messages to send
 * @param  array  $options  Configuration options
 * @return AssistantMessage The response from the LLM
 *
 * @throws \Exception
 */
public function sendMessage(array $messages, array $options = []): AssistantMessage
{
    if (empty($this->client)) {
        throw new \Exception('API key is required to use the YourProvider driver.');
    }

    // Prepare the payload with common settings
    $payload = $this->preparePayload($messages, $options);

    // Make an API call to your provider
    $response = $this->client->createCompletion($payload);
    $this->lastResponse = $response;

    // Handle the response based on your provider's response format
    // For example, if your provider supports tool calls:
    if ($this->isToolCallResponse($response)) {
        // Extract tool calls from the response
        $toolCalls = $this->extractToolCalls($response);
        
        // Build tool calls message
        $message = $this->toolCallsToMessage($toolCalls);
        
        return new ToolCallMessage($toolCalls, $message, $this->getResponseMetadata($response));
    }
    
    // For regular text responses:
    $content = $this->extractContent($response);
    
    return new AssistantMessage($content, $this->getResponseMetadata($response));
}

2.2 Tool Result to Message Method

This method formats tool results for the LLM:

/**
 * Convert a tool result to a message format for the LLM
 *
 * @param  ToolCallInterface  $toolCall  The tool call
 * @param  mixed  $result  The result from the tool
 * @return array The formatted message
 */
public function toolResultToMessage(ToolCallInterface $toolCall, mixed $result): array
{
    // Format depends on your provider's expected format
    // Example for OpenAI-compatible format:
    $content = json_decode($toolCall->getArguments(), true);
    $content[$toolCall->getToolName()] = $result;

    return [
        'role' => 'tool',
        'content' => json_encode($content),
        'tool_call_id' => $toolCall->getId(),
    ];
}
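
To see what this produces, the following standalone sketch applies the same logic to plain values instead of a ToolCall object. The function signature and the sample values are assumptions made for illustration; the fallback to an empty array for non-JSON arguments is an addition that the method above does not have.

```php
<?php

// Standalone version of the logic above; plain values replace the ToolCall object.
function toolResultToMessage(string $id, string $toolName, string $jsonArguments, mixed $result): array
{
    // Fall back to an empty array if the arguments do not decode to an array.
    $decoded = json_decode($jsonArguments, true);
    $content = is_array($decoded) ? $decoded : [];
    $content[$toolName] = $result;

    return [
        'role' => 'tool',
        'content' => json_encode($content),
        'tool_call_id' => $id,
    ];
}

$message = toolResultToMessage('call_1', 'get_weather', '{"city":"Paris"}', 'Sunny');
echo $message['content'], PHP_EOL; // {"city":"Paris","get_weather":"Sunny"}
```

The result is stored under the tool's name next to the original arguments, so the LLM sees both what it asked for and what came back.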

2.3 Tool Calls to Message Method

This method formats tool calls for the LLM:

/**
 * Convert tool calls to a message format for the LLM
 *
 * @param  array  $toolCalls  Array of tool calls
 * @return array The formatted message
 */
public function toolCallsToMessage(array $toolCalls): array
{
    $toolCallsArray = [];
    foreach ($toolCalls as $tc) {
        $toolCallsArray[] = $this->toolCallToContent($tc);
    }

    // Format depends on your provider's expected format
    // Example for OpenAI-compatible format:
    return [
        'role' => 'assistant',
        'tool_calls' => $toolCallsArray,
    ];
}
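
The example above calls toolCallToContent() without showing it. If the abstract driver you extend does not already supply it, a minimal sketch for an OpenAI-compatible format could look like the following. DemoToolCall is a hypothetical stand-in for LarAgent's ToolCall so the snippet runs on its own; the array shape mirrors the one used in the complete example later in this guide.

```php
<?php

// Minimal stand-in for LarAgent's ToolCall, exposing only the getters used here.
final class DemoToolCall
{
    public function __construct(
        private string $id,
        private string $toolName,
        private string $arguments,
    ) {}

    public function getId(): string
    {
        return $this->id;
    }

    public function getToolName(): string
    {
        return $this->toolName;
    }

    public function getArguments(): string
    {
        return $this->arguments;
    }
}

// Convert one tool call into an OpenAI-style entry for the "tool_calls" array.
function toolCallToContent(DemoToolCall $toolCall): array
{
    return [
        'id' => $toolCall->getId(),
        'type' => 'function',
        'function' => [
            'name' => $toolCall->getToolName(),
            'arguments' => $toolCall->getArguments(), // JSON string, passed through as-is
        ],
    ];
}

$entry = toolCallToContent(new DemoToolCall('call_1', 'get_weather', '{"city":"Paris"}'));
echo json_encode($entry), PHP_EOL;
```

In a real driver this would be a method on the driver class that takes a ToolCallInterface instead of the stand-in.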

2.4 Format Tool for Payload Method

This method formats a tool definition for your provider's API payload:

/**
 * Format a tool definition for your provider's API payload
 */
public function formatToolForPayload(ToolInterface $tool): array
{
    // Override the default implementation if your provider has a different format
    // Example for a provider with a different tool format:
    return [
        'name' => $tool->getName(),
        'description' => $tool->getDescription(),
        'parameters' => $tool->getProperties(),
        'required_params' => $tool->getRequired(),
    ];
}

2.5 Streamed Message Method

For providers that support streaming:

/**
 * Send a message to the LLM and receive a streamed response.
 *
 * @param  array  $messages  Array of messages to send
 * @param  array  $options  Configuration options
 * @param  callable|null  $callback  Optional callback function to process each chunk
 * @return \Generator A generator that yields chunks of the response
 *
 * @throws \Exception
 */
public function sendMessageStreamed(array $messages, array $options = [], ?callable $callback = null): \Generator
{
    if (empty($this->client)) {
        throw new \Exception('API key is required to use the YourProvider driver.');
    }

    // Prepare the payload with common settings
    $payload = $this->preparePayload($messages, $options);
    
    // Add stream-specific options
    $payload['stream'] = true;

    // Create a streamed response
    $stream = $this->client->createCompletionStream($payload);
    
    // Initialize a streamed message
    $streamedMessage = new StreamedAssistantMessage;
    
    // Process the stream according to your provider's format
    foreach ($stream as $chunk) {
        $this->lastResponse = $chunk;
        
        // Process the chunk and update the message
        // This will depend on your provider's streaming format
        $this->processStreamChunk($chunk, $streamedMessage);
        
        // Execute callback if provided
        if ($callback) {
            $callback($streamedMessage);
        }
        
        // Yield the updated message
        yield $streamedMessage;
    }
}
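
The sketch below shows how a caller might consume such a generator. DemoStreamedMessage and streamDemo are hypothetical stand-ins for StreamedAssistantMessage and the method above, with string chunks in place of a real provider stream, so the snippet runs without any SDK.

```php
<?php

// Hypothetical stand-in for StreamedAssistantMessage: accumulates text chunks.
final class DemoStreamedMessage
{
    private string $content = '';

    public function appendContent(string $chunk): void
    {
        $this->content .= $chunk;
    }

    public function getContent(): string
    {
        return $this->content;
    }
}

// Mirrors the loop above: update one message per chunk, run the callback, then yield.
function streamDemo(iterable $chunks, ?callable $callback = null): \Generator
{
    $message = new DemoStreamedMessage();

    foreach ($chunks as $chunk) {
        $message->appendContent($chunk);

        if ($callback) {
            $callback($message);
        }

        yield $message;
    }
}

$seen = [];
foreach (streamDemo(['Hel', 'lo', '!']) as $message) {
    // Each iteration sees the text accumulated so far, not just the newest chunk.
    $seen[] = $message->getContent();
}

print_r($seen); // Hel, Hello, Hello!
```

Because PHP generators are lazy, nothing in the method body runs until the caller starts iterating. In the method above, that includes the missing-client check, so its exception surfaces on the first iteration and not at the call site.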

Step 3: Implement Helper Methods

These helper methods build the request payload used by sendMessage and sendMessageStreamed:


/**
 * Prepare the payload for API request with common settings
 */
protected function preparePayload(array $messages, array $options = []): array
{
    // Fall back to the model from the driver settings if none is provided via options
    if (empty($options['model'])) {
        $options['model'] = $this->getSettings()['model'] ?? 'default-model';
    }

    $this->setConfig($options);

    $payload = array_merge($this->getConfig(), [
        'messages' => $this->formatMessages($messages),
    ]);

    // Set the response format if structured output is enabled
    if ($this->structuredOutputEnabled()) {
        $payload['response_format'] = $this->formatResponseSchema($this->getResponseSchema());
    }

    // Add tools to payload if any are registered
    if (!empty($this->tools)) {
        $payload['tools'] = array_map(
            fn($tool) => $this->formatToolForPayload($tool),
            $this->getRegisteredTools()
        );
    }

    return $payload;
}

/**
 * Format messages for your provider's expected format
 */
protected function formatMessages(array $messages): array
{
    // Transform LarAgent message format to your provider's format if needed
    // Return the formatted messages
    return $messages;
}

/**
 * Format the response schema for your provider
 */
protected function formatResponseSchema(array $schema): array
{
    // Transform the schema to your provider's expected format
    return [
        'type' => 'json_schema',
        'schema' => $schema,
    ];
}

Step 4: Implement Provider-Specific Methods

These are methods specific to your provider’s API:

/**
 * Check if a response contains tool calls
 */
protected function isToolCallResponse($response): bool
{
    // Implement based on your provider's response format
    // Example:
    return !empty($response->tool_calls); // empty() already covers the unset case
}

/**
 * Extract tool calls from a response
 */
protected function extractToolCalls($response): array
{
    // Implement based on your provider's response format
    $toolCalls = [];
    
    foreach ($response->tool_calls as $tc) {
        $toolCalls[] = new ToolCall(
            $tc->id ?? 'tool_call_' . uniqid(),
            $tc->name ?? $tc->function->name ?? '',
            $tc->arguments ?? $tc->function->arguments ?? '{}'
        );
    }
    
    return $toolCalls;
}

/**
 * Extract content from a response
 */
protected function extractContent($response): string
{
    // Implement based on your provider's response format
    // Example:
    return $response->choices[0]->message->content ?? '';
}

/**
 * Get metadata from a response
 */
protected function getResponseMetadata($response): array
{
    // Extract usage information or other metadata
    // Example:
    return [
        'usage' => [
            'prompt_tokens' => $response->usage->prompt_tokens ?? 0,
            'completion_tokens' => $response->usage->completion_tokens ?? 0,
            'total_tokens' => $response->usage->total_tokens ?? 0,
        ],
    ];
}

/**
 * Process a stream chunk
 */
protected function processStreamChunk($chunk, StreamedAssistantMessage $message): void
{
    // Implement based on your provider's streaming format
    // Example:
    if (isset($chunk->content)) {
        $message->appendContent($chunk->content);
    }
    
    if (isset($chunk->usage)) {
        $message->setUsage([
            'prompt_tokens' => $chunk->usage->prompt_tokens ?? 0,
            'completion_tokens' => $chunk->usage->completion_tokens ?? 0,
            'total_tokens' => $chunk->usage->total_tokens ?? 0,
        ]);
        $message->setComplete(true);
    }
}
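
The chained `??` fallbacks in extractToolCalls let one extractor accept both flat and nested tool-call shapes. The sketch below isolates that logic in a hypothetical normalizeToolCall function; it returns plain arrays where the method above builds ToolCall objects, and the sample objects are invented for illustration.

```php
<?php

// Normalize one raw tool call into [id, name, arguments], accepting flat or nested shapes.
function normalizeToolCall(object $tc): array
{
    return [
        $tc->id ?? 'tool_call_' . uniqid(),
        $tc->name ?? $tc->function->name ?? '',
        $tc->arguments ?? $tc->function->arguments ?? '{}',
    ];
}

// Nested (OpenAI-style) shape: name and arguments live under "function".
$nested = (object) [
    'id' => 'call_1',
    'function' => (object) ['name' => 'get_weather', 'arguments' => '{"city":"Paris"}'],
];

// Flat shape with no arguments at all.
$flat = (object) ['id' => 'call_2', 'name' => 'get_time'];

print_r(normalizeToolCall($nested));
print_r(normalizeToolCall($flat));
```

The null coalescing operator suppresses notices for the whole property chain, so a missing `function` object falls through to the default and does not raise an error.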

Testing Your Driver

Create tests for your driver to ensure it works correctly:

<?php

namespace Tests\Drivers\YourProvider;

use App\Drivers\YourProvider\YourProviderDriver;
use LarAgent\Messages\AssistantMessage;
use LarAgent\Messages\ToolCallMessage;
use PHPUnit\Framework\TestCase;

class YourProviderDriverTest extends TestCase
{
    protected YourProviderDriver $driver;
    
    protected function setUp(): void
    {
        $this->driver = new YourProviderDriver([
            'api_key' => 'test_key',
            'model' => 'test_model',
        ]);
    }
    
    public function testSendMessage(): void
    {
        // Mock your provider's client response
        // (mockClientResponse() is a helper you implement for your client)
        $this->mockClientResponse();
        
        $messages = [
            ['role' => 'system', 'content' => 'You are a helpful assistant.'],
            ['role' => 'user', 'content' => 'Hello!'],
        ];
        
        $response = $this->driver->sendMessage($messages);
        
        $this->assertInstanceOf(AssistantMessage::class, $response);
        $this->assertEquals('Hello! How can I help you today?', $response->getContent());
    }
    
    public function testSendMessageWithToolCalls(): void
    {
        // Mock your provider's client response for tool calls
        // (mockClientToolCallResponse() is a helper you implement for your client)
        $this->mockClientToolCallResponse();
        
        $messages = [
            ['role' => 'user', 'content' => 'What\'s the weather?'],
        ];
        
        $response = $this->driver->sendMessage($messages);
        
        $this->assertInstanceOf(ToolCallMessage::class, $response);
        $this->assertEquals('get_weather', $response->getToolCalls()[0]->getToolName());
    }
    
    // Add more tests for other methods
}
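
The test above leaves the mock helpers undefined. One possible approach, sketched here under assumptions, is a hand-written fake client that returns a canned response and records the payloads it receives. FakeProviderClient and the response shape are hypothetical; only createCompletion() matches the client method used earlier in this guide.

```php
<?php

// Hypothetical fake client: returns a canned response and records payloads it receives.
final class FakeProviderClient
{
    public array $payloads = [];

    public function __construct(private object $response) {}

    public function createCompletion(array $payload): object
    {
        $this->payloads[] = $payload;

        return $this->response;
    }
}

// Canned response shaped like the OpenAI-style objects used in this guide.
$canned = json_decode(json_encode([
    'choices' => [
        ['message' => ['content' => 'Hello! How can I help you today?']],
    ],
    'usage' => ['prompt_tokens' => 5, 'completion_tokens' => 9, 'total_tokens' => 14],
]));

$client = new FakeProviderClient($canned);
$response = $client->createCompletion(['model' => 'test_model', 'messages' => []]);

// The same expression extractContent() uses in Step 4.
echo $response->choices[0]->message->content ?? '', PHP_EOL;
```

To wire this into the driver under test, a test-only subclass could override initializeClient() from Step 1 so that it returns the fake. This keeps the tests offline and lets you assert on the recorded payloads.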

Registering Your Driver

To make your driver available in the LarAgent framework, you’ll need to register it:

In Laravel

Add your driver to the configuration file:

// config/laragent.php
return [
    // ...
    'providers' => [
        'your-provider' => [
            'label' => 'your-provider-name',
            'driver' => \App\Drivers\YourProvider\YourProviderDriver::class,
            'api_key' => env('YOUR_PROVIDER_API_KEY'),
            'model' => 'your-default-model',
            // Other provider-specific settings
        ],
    ],
];

In Agent Class

namespace App\AiAgents;

use LarAgent\Agent;

class YourAgent extends Agent
{
    protected $driver = \App\Drivers\YourProvider\YourProviderDriver::class;
    // ...
}

Best Practices

  1. Error Handling: Implement robust error handling for API calls
  2. Rate Limiting: Consider implementing rate limiting or retry logic
  3. Logging: Add logging for debugging purposes
  4. Configuration: Make your driver configurable with sensible defaults
  5. Documentation: Document your driver’s capabilities and limitations
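
As an illustration of the retry logic mentioned in points 1 and 2, here is a minimal sketch of exponential backoff. retryWithBackoff is a hypothetical helper, not part of LarAgent. A real driver would likely retry only on transient failures such as timeouts or rate-limit responses, not on every exception as this sketch does.

```php
<?php

/**
 * Call $operation, retrying on any exception with exponential backoff.
 * Hypothetical helper for illustration; not part of LarAgent.
 */
function retryWithBackoff(callable $operation, int $maxAttempts = 3, int $baseDelayMs = 200): mixed
{
    $attempt = 0;

    while (true) {
        $attempt++;

        try {
            return $operation();
        } catch (\Throwable $e) {
            if ($attempt >= $maxAttempts) {
                throw $e; // Out of attempts: surface the last error.
            }

            // The delay doubles after each failure: base, 2x base, 4x base, ...
            usleep($baseDelayMs * 1000 * (2 ** ($attempt - 1)));
        }
    }
}

// Usage: the closure fails twice, then succeeds on the third attempt.
$calls = 0;
$result = retryWithBackoff(function () use (&$calls) {
    $calls++;
    if ($calls < 3) {
        throw new \RuntimeException('Temporary failure');
    }

    return 'ok';
}, maxAttempts: 3, baseDelayMs: 1);

echo $result, ' after ', $calls, ' calls', PHP_EOL; // ok after 3 calls
```

Inside sendMessage, the API call could be wrapped as `retryWithBackoff(fn () => $this->client->createCompletion($payload))`, keeping the retry policy separate from the provider-specific logic.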

Complete Example Implementation

Here’s a simplified example of a complete driver implementation:

<?php

namespace App\Drivers\YourProvider;

use LarAgent\Core\Abstractions\LlmDriver;
use LarAgent\Core\Contracts\LlmDriver as LlmDriverInterface;
use LarAgent\Core\Contracts\Tool as ToolInterface;
use LarAgent\Core\Contracts\ToolCall as ToolCallInterface;
use LarAgent\Messages\AssistantMessage;
use LarAgent\Messages\StreamedAssistantMessage;
use LarAgent\Messages\ToolCallMessage;
use LarAgent\ToolCall;
use YourProvider\Client as YourProviderClient;

class YourProviderDriver extends LlmDriver implements LlmDriverInterface
{
    protected mixed $client;
    
    public function __construct(array $settings = [])
    {
        parent::__construct($settings);
        
        $apiKey = $settings['api_key'] ?? null;
        $this->client = $apiKey ? new YourProviderClient($apiKey) : null;
    }
    
    public function sendMessage(array $messages, array $options = []): AssistantMessage
    {
        if (empty($this->client)) {
            throw new \Exception('API key is required to use the YourProvider driver.');
        }

        $payload = $this->preparePayload($messages, $options);
        $response = $this->client->createCompletion($payload);
        $this->lastResponse = $response;
        
        if (!empty($response->tool_calls)) {
            $toolCalls = [];
            foreach ($response->tool_calls as $tc) {
                $toolCalls[] = new ToolCall(
                    $tc->id ?? 'tool_call_' . uniqid(),
                    $tc->function->name ?? '',
                    $tc->function->arguments ?? '{}'
                );
            }
            
            $message = $this->toolCallsToMessage($toolCalls);
            return new ToolCallMessage($toolCalls, $message, ['usage' => $response->usage]);
        }
        
        $content = $response->choices[0]->message->content ?? '';
        return new AssistantMessage($content, ['usage' => $response->usage]);
    }
    
    public function sendMessageStreamed(array $messages, array $options = [], ?callable $callback = null): \Generator
    {
        if (empty($this->client)) {
            throw new \Exception('API key is required to use the YourProvider driver.');
        }

        $payload = $this->preparePayload($messages, $options);
        $payload['stream'] = true;
        
        $stream = $this->client->createCompletionStream($payload);
        $streamedMessage = new StreamedAssistantMessage;
        
        foreach ($stream as $chunk) {
            $this->lastResponse = $chunk;
            
            if (isset($chunk->content)) {
                $streamedMessage->appendContent($chunk->content);
            }
            
            if (isset($chunk->usage)) {
                $streamedMessage->setUsage([
                    'prompt_tokens' => $chunk->usage->prompt_tokens ?? 0,
                    'completion_tokens' => $chunk->usage->completion_tokens ?? 0,
                    'total_tokens' => $chunk->usage->total_tokens ?? 0,
                ]);
                $streamedMessage->setComplete(true);
            }
            
            if ($callback) {
                $callback($streamedMessage);
            }
            
            yield $streamedMessage;
        }
    }
    
    public function toolResultToMessage(ToolCallInterface $toolCall, mixed $result): array
    {
        $content = json_decode($toolCall->getArguments(), true);
        $content[$toolCall->getToolName()] = $result;

        return [
            'role' => 'tool',
            'content' => json_encode($content),
            'tool_call_id' => $toolCall->getId(),
        ];
    }
    
    public function toolCallsToMessage(array $toolCalls): array
    {
        $toolCallsArray = [];
        foreach ($toolCalls as $tc) {
            $toolCallsArray[] = [
                'id' => $tc->getId(),
                'type' => 'function',
                'function' => [
                    'name' => $tc->getToolName(),
                    'arguments' => $tc->getArguments(),
                ],
            ];
        }

        return [
            'role' => 'assistant',
            'tool_calls' => $toolCallsArray,
        ];
    }
    
    protected function preparePayload(array $messages, array $options = []): array
    {
        if (empty($options['model'])) {
            $options['model'] = $this->getSettings()['model'] ?? 'default-model';
        }

        $this->setConfig($options);

        $payload = array_merge($this->getConfig(), [
            'messages' => $messages,
        ]);

        if ($this->structuredOutputEnabled()) {
            $payload['response_format'] = [
                'type' => 'json_schema',
                'schema' => $this->getResponseSchema(),
            ];
        }

        if (!empty($this->tools)) {
            foreach ($this->getRegisteredTools() as $tool) {
                $payload['tools'][] = $this->formatToolForPayload($tool);
            }
        }

        return $payload;
    }
}

Conclusion

By following this guide, you can create a custom LLM driver for any provider and integrate it with LarAgent, using your preferred LLM provider while staying compatible with the existing driver architecture. For more details, see the real drivers that ship with LarAgent.