254 lines
8.2 KiB
Python
254 lines
8.2 KiB
Python
"""FastAPI application for the media library agent."""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
import uuid
|
|
from typing import Any
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from pydantic import BaseModel, Field, validator
|
|
|
|
from agent.agent import Agent
|
|
from agent.config import settings
|
|
from agent.llm.deepseek import DeepSeekClient
|
|
from agent.llm.exceptions import LLMAPIError, LLMConfigurationError
|
|
from agent.llm.ollama import OllamaClient
|
|
from infrastructure.persistence import get_memory, init_memory
|
|
|
|
# Configure root logging once at import time; module loggers inherit this.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Agent Media API",
    description="AI agent for managing a local media library",
    version="0.2.0",
)

# Initialize memory context at startup
init_memory(storage_dir="memory_data")
logger.info("Memory context initialized")

# Select the LLM backend from the LLM_PROVIDER environment variable:
# "ollama" -> OllamaClient, anything else (default "deepseek") -> DeepSeekClient.
llm_provider = os.getenv("LLM_PROVIDER", "deepseek").lower()

try:
    if llm_provider == "ollama":
        logger.info("Using Ollama LLM")
        llm = OllamaClient()
    else:
        logger.info("Using DeepSeek LLM")
        llm = DeepSeekClient()
except LLMConfigurationError as e:
    # Misconfiguration is fatal: log with traceback and re-raise so the
    # process fails fast at startup instead of serving a broken agent.
    logger.exception("Failed to initialize LLM: %s", e)
    raise

# Initialize agent
agent = Agent(llm=llm, max_tool_iterations=settings.max_tool_iterations)
logger.info("Agent Media API initialized")
|
|
|
|
|
|
# Pydantic models for request validation
|
|
class ChatMessage(BaseModel):
    """One turn of the conversation, as sent by an OpenAI-style client."""

    role: str = Field(..., description="Role of the message sender")
    content: str | None = Field(None, description="Content of the message")

    @validator("content")
    def content_must_not_be_empty_for_user(cls, v, values):
        """Reject user-role messages whose content is missing or empty."""
        is_empty_user_message = not v and values.get("role") == "user"
        if is_empty_user_message:
            raise ValueError("User messages must have non-empty content")
        return v
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body for chat completions (OpenAI-compatible shape)."""

    model: str = Field(default="agent-media", description="Model to use")
    messages: list[ChatMessage] = Field(..., description="List of messages")
    stream: bool = Field(default=False, description="Whether to stream the response")
    temperature: float | None = Field(default=None, ge=0.0, le=2.0)
    max_tokens: int | None = Field(default=None, gt=0)

    @validator("messages")
    def messages_must_have_user_message(cls, v):
        """Ensure the conversation contains at least one user-role message."""
        has_user_message = any(msg.role == "user" for msg in v)
        if not has_user_message:
            raise ValueError("At least one user message is required")
        return v
|
|
|
|
|
|
def extract_last_user_content(messages: list[dict[str, Any]]) -> str:
    """
    Extract the last user message from the conversation.

    Args:
        messages: List of message dictionaries.

    Returns:
        Content of the last user message, or empty string.
    """
    last_user_message = next(
        (msg for msg in reversed(messages) if msg.get("role") == "user"),
        None,
    )
    if last_user_message is None:
        return ""
    # A user message may carry None content; normalize to empty string.
    return last_user_message.get("content") or ""
|
|
|
|
|
|
@app.get("/health")
|
|
async def health_check():
|
|
"""Health check endpoint."""
|
|
return {"status": "healthy", "version": "0.2.0"}
|
|
|
|
|
|
@app.get("/v1/models")
|
|
async def list_models():
|
|
"""List available models (OpenAI-compatible endpoint)."""
|
|
return {
|
|
"object": "list",
|
|
"data": [
|
|
{
|
|
"id": "agent-media",
|
|
"object": "model",
|
|
"created": int(time.time()),
|
|
"owned_by": "local",
|
|
}
|
|
],
|
|
}
|
|
|
|
|
|
@app.get("/memory/state")
|
|
async def get_memory_state():
|
|
"""Debug endpoint to view full memory state."""
|
|
memory = get_memory()
|
|
return memory.get_full_state()
|
|
|
|
|
|
@app.get("/memory/episodic/search-results")
|
|
async def get_search_results():
|
|
"""Debug endpoint to view last search results."""
|
|
memory = get_memory()
|
|
if memory.episodic.last_search_results:
|
|
return {
|
|
"status": "ok",
|
|
"query": memory.episodic.last_search_results.get("query"),
|
|
"type": memory.episodic.last_search_results.get("type"),
|
|
"timestamp": memory.episodic.last_search_results.get("timestamp"),
|
|
"result_count": len(memory.episodic.last_search_results.get("results", [])),
|
|
"results": memory.episodic.last_search_results.get("results", []),
|
|
}
|
|
return {"status": "empty", "message": "No search results in episodic memory"}
|
|
|
|
|
|
@app.post("/memory/clear-session")
|
|
async def clear_session():
|
|
"""Clear session memories (STM + Episodic)."""
|
|
memory = get_memory()
|
|
memory.clear_session()
|
|
return {"status": "ok", "message": "Session memories cleared"}
|
|
|
|
|
|
@app.post("/v1/chat/completions")
|
|
async def chat_completions(chat_request: ChatCompletionRequest):
|
|
"""
|
|
OpenAI-compatible chat completions endpoint.
|
|
|
|
Accepts messages and returns agent response.
|
|
Supports both streaming and non-streaming modes.
|
|
"""
|
|
# Convert Pydantic models to dicts for processing
|
|
messages_dict = [msg.dict() for msg in chat_request.messages]
|
|
|
|
user_input = extract_last_user_content(messages_dict)
|
|
|
|
logger.info(
|
|
f"Chat request - stream={chat_request.stream}, input_length={len(user_input)}"
|
|
)
|
|
|
|
created_ts = int(time.time())
|
|
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
|
|
|
if not chat_request.stream:
|
|
try:
|
|
answer = agent.step(user_input)
|
|
except LLMAPIError as e:
|
|
logger.error(f"LLM API error: {e}")
|
|
raise HTTPException(status_code=502, detail=f"LLM API error: {e}") from e
|
|
except Exception as e:
|
|
logger.error(f"Agent error: {e}", exc_info=True)
|
|
raise HTTPException(status_code=500, detail="Internal agent error") from e
|
|
|
|
return JSONResponse(
|
|
{
|
|
"id": completion_id,
|
|
"object": "chat.completion",
|
|
"created": created_ts,
|
|
"model": chat_request.model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"finish_reason": "stop",
|
|
"message": {"role": "assistant", "content": answer or ""},
|
|
}
|
|
],
|
|
"usage": {
|
|
"prompt_tokens": 0,
|
|
"completion_tokens": 0,
|
|
"total_tokens": 0,
|
|
},
|
|
}
|
|
)
|
|
|
|
async def event_generator():
|
|
try:
|
|
# Stream the agent execution
|
|
async for chunk in agent.step_streaming(
|
|
user_input, completion_id, created_ts, chat_request.model
|
|
):
|
|
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
|
yield "data: [DONE]\n\n"
|
|
except LLMAPIError as e:
|
|
logger.error(f"LLM API error: {e}")
|
|
error_chunk = {
|
|
"id": completion_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": created_ts,
|
|
"model": chat_request.model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"delta": {"role": "assistant", "content": f"Error: {e}"},
|
|
"finish_reason": "stop",
|
|
}
|
|
],
|
|
}
|
|
yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
|
|
yield "data: [DONE]\n\n"
|
|
except Exception as e:
|
|
logger.error(f"Agent error: {e}", exc_info=True)
|
|
error_chunk = {
|
|
"id": completion_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": created_ts,
|
|
"model": chat_request.model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"delta": {
|
|
"role": "assistant",
|
|
"content": "Internal agent error",
|
|
},
|
|
"finish_reason": "stop",
|
|
}
|
|
],
|
|
}
|
|
yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
|
|
yield "data: [DONE]\n\n"
|
|
|
|
return StreamingResponse(event_generator(), media_type="text/event-stream")
|