Files
alfred/alfred/agent/llm/ollama.py
2025-12-24 07:50:09 +01:00

194 lines
6.5 KiB
Python

"""Ollama LLM client with robust error handling."""
import logging
import os
from typing import Any
import requests
from requests.exceptions import HTTPError, RequestException, Timeout
from ..config import settings
from .exceptions import LLMAPIError, LLMConfigurationError
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class OllamaClient:
"""
Client for interacting with Ollama API.
Ollama runs locally and provides an OpenAI-compatible API.
Example:
>>> client = OllamaClient(model="llama3.2")
>>> messages = [{"role": "user", "content": "Hello!"}]
>>> response = client.complete(messages)
>>> print(response)
"""
def __init__(
self,
base_url: str | None = None,
model: str | None = None,
timeout: int | None = None,
temperature: float | None = None,
):
"""
Initialize Ollama client.
Args:
base_url: Ollama API base URL (defaults to http://localhost:11434)
model: Model name to use (e.g., "llama3.2", "mistral", "codellama")
timeout: Request timeout in seconds (defaults to settings)
temperature: Temperature for generation (defaults to settings)
Raises:
LLMConfigurationError: If configuration is invalid
"""
self.base_url = base_url or os.getenv(
"OLLAMA_BASE_URL", "http://localhost:11434"
)
self.model = model or os.getenv("OLLAMA_MODEL", "llama3.2")
self.timeout = timeout or settings.request_timeout
self.temperature = (
temperature if temperature is not None else settings.temperature
)
if not self.base_url:
raise LLMConfigurationError(
"Ollama base URL is required. Set OLLAMA_BASE_URL environment variable."
)
if not self.model:
raise LLMConfigurationError(
"Ollama model is required. Set OLLAMA_MODEL environment variable."
)
logger.info(f"Ollama client initialized with model: {self.model}")
def complete( # noqa: PLR0912
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
"""
Generate a completion from the LLM.
Args:
messages: List of message dicts with 'role' and 'content' keys
tools: Optional list of tool specifications (OpenAI format)
Returns:
OpenAI-compatible message dict with 'role', 'content', and optionally 'tool_calls'
Raises:
LLMAPIError: If API request fails
ValueError: If messages format is invalid
"""
# Validate messages format
if not messages:
raise ValueError("Messages list cannot be empty")
for msg in messages:
if not isinstance(msg, dict):
raise ValueError(f"Each message must be a dict, got {type(msg)}")
if "role" not in msg:
raise ValueError(f"Message must have 'role' key, got {msg.keys()}")
# Allow system, user, assistant, and tool roles
if msg["role"] not in ("system", "user", "assistant", "tool"):
raise ValueError(f"Invalid role: {msg['role']}")
# Content is optional for tool messages (they may have tool_call_id instead)
if msg["role"] != "tool" and "content" not in msg:
raise ValueError(
f"Non-tool message must have 'content' key, got {msg.keys()}"
)
url = f"{self.base_url}/api/chat"
payload = {
"model": self.model,
"messages": messages,
"stream": False,
"options": {
"temperature": self.temperature,
},
}
# Add tools if provided
if tools:
payload["tools"] = tools
try:
logger.debug(
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
)
response = requests.post(url, json=payload, timeout=self.timeout)
response.raise_for_status()
data = response.json()
# Validate response structure
if "message" not in data:
raise LLMAPIError("Invalid API response: missing 'message'")
# Return the full message dict (OpenAI format)
message = data["message"]
logger.debug(f"Received response: {message.get('content', '')[:100]}...")
return message
except Timeout as e:
logger.error(f"Request timeout after {self.timeout}s: {e}")
raise LLMAPIError(f"Request timeout after {self.timeout} seconds") from e
except HTTPError as e:
logger.error(f"HTTP error from Ollama API: {e}")
if e.response is not None:
try:
error_data = e.response.json()
error_msg = error_data.get("error", str(e))
except Exception:
error_msg = str(e)
raise LLMAPIError(f"Ollama API error: {error_msg}") from e
raise LLMAPIError(f"HTTP error: {e}") from e
except RequestException as e:
logger.error(f"Request failed: {e}")
raise LLMAPIError(f"Failed to connect to Ollama API: {e}") from e
except (KeyError, IndexError, TypeError) as e:
logger.error(f"Failed to parse API response: {e}")
raise LLMAPIError(f"Invalid API response format: {e}") from e
def list_models(self) -> list[str]:
"""
List available models in Ollama.
Returns:
List of model names
"""
url = f"{self.base_url}/api/tags"
try:
response = requests.get(url, timeout=self.timeout)
response.raise_for_status()
data = response.json()
models = [model["name"] for model in data.get("models", [])]
logger.info(f"Found {len(models)} models: {models}")
return models
except Exception as e:
logger.error(f"Failed to list models: {e}")
return []
def is_available(self) -> bool:
"""
Check if Ollama is running and accessible.
Returns:
True if Ollama is available, False otherwise
"""
try:
url = f"{self.base_url}/api/tags"
response = requests.get(url, timeout=5)
return response.status_code == 200
except Exception:
return False