Pular para conteúdo

Custom Providers

O arandu usa protocolos Python para injeção de dependência. Você pode usar qualquer backend de LLM ou embedding implementando duas interfaces simples - sem herança necessária.

Providers Incluídos

O SDK inclui dois providers built-in:

Provider Instalação LLM Embeddings
OpenAI pip install arandu[openai] ✅ GPT-4o, GPT-4o-mini, etc. ✅ text-embedding-3-small, etc.
Anthropic pip install arandu[anthropic] ✅ Claude Sonnet, Opus, Haiku ❌ Use OpenAI pra embeddings
# OpenAI (LLM + embeddings num provider só)
from arandu.providers.openai import OpenAIProvider
provider = OpenAIProvider(api_key="sk-...")
memory = MemoryClient(database_url="...", llm=provider, embeddings=provider)

# Anthropic (Claude pra LLM, OpenAI pra embeddings)
from arandu.providers.anthropic import AnthropicProvider
from arandu.providers.openai import OpenAIProvider
llm = AnthropicProvider(api_key="sk-ant-...")
embeddings = OpenAIProvider(api_key="sk-...")
memory = MemoryClient(database_url="...", llm=llm, embeddings=embeddings)

Providers Compatíveis com OpenAI

O OpenAIProvider funciona com qualquer API que siga o formato de chat completions da OpenAI. Basta definir base_url para apontar ao endpoint do provider:

from arandu.providers.openai import OpenAIProvider

# DeepSeek
llm = OpenAIProvider(api_key="sk-deepseek-...", model="deepseek-chat", base_url="https://api.deepseek.com/v1")

# Groq
llm = OpenAIProvider(api_key="gsk_...", model="llama-3.3-70b-versatile", base_url="https://api.groq.com/openai/v1")

# Together AI
llm = OpenAIProvider(api_key="tog_...", model="meta-llama/Llama-3.3-70B-Instruct-Turbo", base_url="https://api.together.xyz/v1")

# Fireworks AI
llm = OpenAIProvider(api_key="fw_...", model="accounts/fireworks/models/llama-v3p3-70b-instruct", base_url="https://api.fireworks.ai/inference/v1")

# Ollama (local)
llm = OpenAIProvider(api_key="ollama", model="llama3.1", base_url="http://localhost:11434/v1")

Isso cobre apenas chamadas de LLM. Embeddings ainda precisam de OpenAI ou de um EmbeddingProvider customizado, já que a maioria desses providers não oferece API de embeddings.

Se os providers incluídos atendem seu caso de uso, não precisa ler o resto desta página.


Os Protocolos

Se você precisa de um provider diferente (Ollama, LiteLLM, Groq, etc.), implemente os protocolos:

LLMProvider

from arandu.protocols import LLMResult, TokenUsage

class LLMProvider(Protocol):
    async def complete(
        self,
        messages: list[dict],
        temperature: float = 0,
        response_format: dict | None = None,
        max_tokens: int | None = None,
    ) -> LLMResult: ...
Parâmetro Descrição
messages Lista de dicts com chaves "role" e "content" (formato OpenAI)
temperature Temperatura de sampling (0 = determinístico)
response_format Especificação de formato opcional (ex: {"type": "json_object"})
max_tokens Máximo opcional de tokens para a resposta
Retorna LLMResult(text="...", usage=TokenUsage(...))

Suporte a modo JSON

O pipeline depende de respostas JSON (response_format={"type": "json_object"}). Se seu backend não suporta nativamente, appende uma instrução no system prompt.

EmbeddingProvider

class EmbeddingProvider(Protocol):
    async def embed(self, texts: list[str]) -> list[list[float]]: ...
    async def embed_one(self, text: str) -> list[float] | None: ...
Método Descrição
embed(texts) Gera embeddings para um batch de textos. Retorna um vetor por input.
embed_one(text) Gera embedding para um único texto. Retorna None se vazio/inválido.

Dimensões de embedding

O embedding_dimensions padrão é 1536 (OpenAI text-embedding-3-small). Se seu provider usa dimensões diferentes, defina MemoryConfig(embedding_dimensions=...).


Exemplo: Provider de Modelo Local

Para rodar com modelos locais (ex: via Ollama):

import httpx
from arandu.protocols import LLMResult, TokenUsage


class OllamaProvider:
    """LLM + Embedding provider usando um servidor Ollama local."""

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "llama3.1",
        embedding_model: str = "nomic-embed-text",
    ) -> None:
        self._base_url = base_url
        self._model = model
        self._embedding_model = embedding_model
        self._client = httpx.AsyncClient(timeout=60.0)

    # -- LLMProvider --

    async def complete(
        self,
        messages: list[dict],
        temperature: float = 0,
        response_format: dict | None = None,
        max_tokens: int | None = None,
    ) -> LLMResult:
        payload: dict = {
            "model": self._model,
            "messages": messages,
            "stream": False,
            "options": {"temperature": temperature},
        }
        if response_format and response_format.get("type") == "json_object":
            payload["format"] = "json"

        response = await self._client.post(
            f"{self._base_url}/api/chat",
            json=payload,
        )
        response.raise_for_status()
        text = response.json()["message"]["content"]
        return LLMResult(text=text, usage=None)  # Ollama doesn't report usage

    # -- EmbeddingProvider --

    async def embed(self, texts: list[str]) -> list[list[float]]:
        results = []
        for text in texts:
            if not text.strip():
                continue
            response = await self._client.post(
                f"{self._base_url}/api/embed",
                json={"model": self._embedding_model, "input": text},
            )
            response.raise_for_status()
            results.append(response.json()["embeddings"][0])
        return results

    async def embed_one(self, text: str) -> list[float] | None:
        if not text or not text.strip():
            return None
        results = await self.embed([text])
        return results[0] if results else None

Dimensões de embedding

Quando usar modelos locais, configure as dimensões:

config = MemoryConfig(
    embedding_dimensions=768,  # nomic-embed-text usa 768 dims
)

Testando Seu Provider

Verifique se seu provider funciona antes de ir pra produção:

import asyncio
from arandu import MemoryClient, MemoryConfig


async def test_provider():
    provider = YourProvider(...)
    memory = MemoryClient(
        database_url="postgresql+psycopg://memory:memory@localhost/memory",
        llm=provider,
        embeddings=provider,
    )
    await memory.initialize()

    try:
        # Testar write
        result = await memory.write(
            agent_id="test",
            message="Testing the provider. My name is Alice and I work at Acme.",
        )
        assert len(result.facts_added) > 0, "No facts extracted — check LLM responses"
        assert len(result.entities_resolved) > 0, "No entities resolved"
        print(f"Write OK: {len(result.facts_added)} facts, {len(result.entities_resolved)} entities")

        # Testar retrieve
        context = await memory.retrieve(agent_id="test", query="who is Alice?")
        assert len(context.facts) > 0, "No facts retrieved — check embeddings"
        print(f"Retrieve OK: {len(context.facts)} facts found")
        print(f"Context: {context.context}")
    finally:
        await memory.close()


asyncio.run(test_provider())

Requisitos Importantes

  1. LLMResult - complete() retorna LLMResult(text=..., usage=...), não str. Se seu backend não reporta usage, passe usage=None.

  2. Modo JSON - O pipeline envia response_format={"type": "json_object"} frequentemente. Seu provider deve retornar JSON válido quando isso é definido.

  3. Async - Ambos os protocolos são async. Se o SDK do seu backend é síncrono, encapsule com asyncio.to_thread().

  4. Tratamento de vazio/erro - embed_one retorna None para input vazio. embed retorna [] para input vazio.

  5. Timeout - Adicione timeouts ao seu provider. O SDK define timeouts do lado dele, mas timeouts no provider adicionam segurança.

  6. Dimensões de embedding - Defina MemoryConfig(embedding_dimensions=N) pra corresponder ao seu provider.