Cleaned and improved
This commit is contained in:
16
.env.example
Normal file
16
.env.example
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# DeepSeek LLM Configuration
|
||||||
|
DEEPSEEK_API_KEY=your_deepseek_api_key_here
|
||||||
|
DEEPSEEK_BASE_URL=https://api.deepseek.com
|
||||||
|
DEEPSEEK_MODEL=deepseek-chat
|
||||||
|
TEMPERATURE=0.2
|
||||||
|
|
||||||
|
# TMDB API Configuration
|
||||||
|
TMDB_API_KEY=your_tmdb_api_key_here
|
||||||
|
TMDB_BASE_URL=https://api.themoviedb.org/3
|
||||||
|
|
||||||
|
# Storage Configuration
|
||||||
|
MEMORY_FILE=memory.json
|
||||||
|
|
||||||
|
# Security Configuration
|
||||||
|
MAX_TOOL_ITERATIONS=5
|
||||||
|
REQUEST_TIMEOUT=30
|
||||||
43
.gitignore
vendored
43
.gitignore
vendored
@@ -0,0 +1,43 @@
|
|||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# Qodo
|
||||||
|
./qodo
|
||||||
|
|
||||||
|
# Memory and state files
|
||||||
|
memory.json
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|||||||
178
agent/agent.py
178
agent/agent.py
@@ -4,39 +4,17 @@ import json
|
|||||||
|
|
||||||
from .llm import DeepSeekClient
|
from .llm import DeepSeekClient
|
||||||
from .memory import Memory
|
from .memory import Memory
|
||||||
from .tools import make_tools, Tool
|
from .registry import make_tools, Tool
|
||||||
|
from .prompts import PromptBuilder
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
def __init__(self, llm: DeepSeekClient, memory: Memory):
|
def __init__(self, llm: DeepSeekClient, memory: Memory, max_tool_iterations: int = 5):
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.memory = memory
|
self.memory = memory
|
||||||
self.tools: Dict[str, Tool] = make_tools(memory)
|
self.tools: Dict[str, Tool] = make_tools(memory)
|
||||||
|
self.prompt_builder = PromptBuilder(self.tools)
|
||||||
|
self.max_tool_iterations = max_tool_iterations
|
||||||
|
|
||||||
def _build_system_prompt(self) -> str:
|
|
||||||
ctx = {"project_root": self.memory.get_project_root()}
|
|
||||||
tools_desc = "\n".join(
|
|
||||||
f"- {t.name}: {t.description}" for t in self.tools.values()
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
"Tu es un agent IA qui aide un développeur senior à gérer son projet local.\n"
|
|
||||||
"Tu peux demander des informations de base (comme le chemin du projet)\n"
|
|
||||||
"et tu peux utiliser des outils pour interagir avec le système de fichiers.\n\n"
|
|
||||||
"Contexte utilisateur (JSON):\n"
|
|
||||||
f"{json.dumps(ctx, ensure_ascii=False)}\n\n"
|
|
||||||
"Règles IMPORTANTES pour les outils:\n"
|
|
||||||
"1. Si tu ne connais pas la valeur d'un argument (par exemple project_root), "
|
|
||||||
"TU NE DOIS PAS mettre null ou une valeur inventée.\n"
|
|
||||||
" À la place, tu dois poser une question à l'utilisateur pour obtenir l'information.\n"
|
|
||||||
"2. Ne propose set_user_profile QUE lorsque l'utilisateur a donné un chemin de projet explicite.\n"
|
|
||||||
"3. Quand tu veux utiliser un outil, réponds STRICTEMENT avec un JSON de la forme :\n"
|
|
||||||
'{ "thought": "...", "action": { "name": "tool_name", "args": { ... } } }\n'
|
|
||||||
" - Pas de texte avant/après.\n"
|
|
||||||
" - Les args doivent être COMPLETS et NON nuls.\n"
|
|
||||||
"4. Quand JE (le système) te fournis un object JSON contenant 'tool_result' et 'intent', "
|
|
||||||
"tu dois ALORS répondre à l'utilisateur en TEXTE NATUREL, et NE PAS renvoyer de JSON d'action.\n\n"
|
|
||||||
"Tools disponibles:\n"
|
|
||||||
f"{tools_desc}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_intent(self, text: str) -> Dict[str, Any] | None:
|
def _parse_intent(self, text: str) -> Dict[str, Any] | None:
|
||||||
try:
|
try:
|
||||||
@@ -62,18 +40,6 @@ class Agent:
|
|||||||
name: str = action["name"]
|
name: str = action["name"]
|
||||||
args: Dict[str, Any] = action.get("args", {}) or {}
|
args: Dict[str, Any] = action.get("args", {}) or {}
|
||||||
|
|
||||||
if name == "set_user_profile":
|
|
||||||
project_root = args.get("project_root")
|
|
||||||
if not project_root:
|
|
||||||
return {
|
|
||||||
"error": "missing_project_root",
|
|
||||||
"message": (
|
|
||||||
"Le modèle a demandé set_user_profile sans project_root. "
|
|
||||||
"Tu dois d'abord demander à l'utilisateur de fournir un chemin "
|
|
||||||
"de projet valide (ex: /home/francois/mon_projet)."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
tool = self.tools.get(name)
|
tool = self.tools.get(name)
|
||||||
if not tool:
|
if not tool:
|
||||||
return {"error": "unknown_tool", "tool": name}
|
return {"error": "unknown_tool", "tool": name}
|
||||||
@@ -87,95 +53,77 @@ class Agent:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def step(self, user_input: str) -> str:
|
def step(self, user_input: str) -> str:
|
||||||
|
"""
|
||||||
|
Execute one agent step with iterative tool execution:
|
||||||
|
- Build system prompt
|
||||||
|
- Query LLM
|
||||||
|
- Loop: If JSON intent -> execute tool, add result to conversation, query LLM again
|
||||||
|
- Continue until LLM responds with text (no tool call) or max iterations reached
|
||||||
|
- Return final text response
|
||||||
|
"""
|
||||||
print("Starting a new step...")
|
print("Starting a new step...")
|
||||||
"""
|
|
||||||
Un 'tour' d'agent :
|
|
||||||
- construit le prompt system
|
|
||||||
- interroge DeepSeek
|
|
||||||
- si JSON d'intent -> exécute tool, refait un appel, renvoie réponse finale
|
|
||||||
- sinon -> renvoie texte brut
|
|
||||||
"""
|
|
||||||
print("User input:", user_input)
|
print("User input:", user_input)
|
||||||
root = self.memory.data.get("project_root")
|
|
||||||
print("Current project_root in memory:", root)
|
|
||||||
|
|
||||||
# Unified system prompt that always allows tools
|
print("Current memory state:", self.memory.data)
|
||||||
tools_desc = "\n".join(
|
|
||||||
f"- {t.name}: {t.description}\n Paramètres: {json.dumps(t.parameters, ensure_ascii=False)}"
|
|
||||||
for t in self.tools.values()
|
|
||||||
)
|
|
||||||
|
|
||||||
if root is None:
|
# Build system prompt using PromptBuilder
|
||||||
print("No project_root set - asking user and allowing tool use")
|
system_prompt = self.prompt_builder.build_system_prompt(self.memory.data)
|
||||||
system_prompt = (
|
|
||||||
"Tu es un agent IA qui aide un développeur à gérer son projet local.\n\n"
|
|
||||||
"CONTEXTE ACTUEL:\n"
|
|
||||||
f"- project_root: {root} (NON DÉFINI)\n\n"
|
|
||||||
"RÈGLES IMPORTANTES:\n"
|
|
||||||
"1. Le project_root n'est pas encore défini. Tu DOIS d'abord demander à l'utilisateur "
|
|
||||||
"le chemin absolu de son projet (ex: /home/user/mon_projet).\n"
|
|
||||||
"2. Quand l'utilisateur te donne un chemin, tu DOIS immédiatement utiliser l'outil "
|
|
||||||
"'set_project_root' pour le sauvegarder.\n"
|
|
||||||
"3. Pour utiliser un outil, réponds STRICTEMENT avec ce format JSON:\n"
|
|
||||||
' { "thought": "explication", "action": { "name": "nom_outil", "args": { "arg": "valeur" } } }\n'
|
|
||||||
"4. Si tu réponds en texte (pas d'outil), réponds normalement en français.\n"
|
|
||||||
"5. Quand le système te renvoie un tool_result, réponds à l'utilisateur en TEXTE NATUREL.\n\n"
|
|
||||||
"OUTILS DISPONIBLES:\n"
|
|
||||||
f"{tools_desc}\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print("Project_root is set - normal operation mode")
|
|
||||||
system_prompt = (
|
|
||||||
"Tu es un agent IA qui aide un développeur à gérer son projet local.\n\n"
|
|
||||||
"CONTEXTE ACTUEL:\n"
|
|
||||||
f"- project_root: {root}\n\n"
|
|
||||||
"RÈGLES IMPORTANTES:\n"
|
|
||||||
"1. Le project_root est défini. Tu peux utiliser les outils disponibles.\n"
|
|
||||||
"2. Pour utiliser un outil, réponds STRICTEMENT avec ce format JSON:\n"
|
|
||||||
' { "thought": "explication", "action": { "name": "nom_outil", "args": { "param": "valeur" } } }\n'
|
|
||||||
" EXEMPLE pour lister le dossier 'src':\n"
|
|
||||||
' { "thought": "L\'utilisateur veut voir le contenu de src", "action": { "name": "list_directory", "args": { "path": "src" } } }\n'
|
|
||||||
" EXEMPLE pour lister la racine du projet:\n"
|
|
||||||
' { "thought": "L\'utilisateur veut voir la racine", "action": { "name": "list_directory", "args": { "path": "." } } }\n'
|
|
||||||
"3. Si tu réponds en texte (pas d'outil), réponds normalement en français.\n"
|
|
||||||
"4. Quand le système te renvoie un tool_result, réponds à l'utilisateur en TEXTE NATUREL.\n"
|
|
||||||
"5. IMPORTANT: Extrais le chemin demandé par l'utilisateur et passe-le comme argument 'path'.\n\n"
|
|
||||||
"OUTILS DISPONIBLES:\n"
|
|
||||||
f"{tools_desc}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Initialize conversation with user input
|
||||||
messages: List[Dict[str, Any]] = [
|
messages: List[Dict[str, Any]] = [
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": user_input},
|
{"role": "user", "content": user_input},
|
||||||
]
|
]
|
||||||
|
|
||||||
raw_first = self.llm.complete(messages)
|
# Tool execution loop
|
||||||
intent = self._parse_intent(raw_first)
|
iteration = 0
|
||||||
print("raw_first:", raw_first)
|
while iteration < self.max_tool_iterations:
|
||||||
print("Intent:", intent)
|
print(f"\n--- Iteration {iteration + 1} ---")
|
||||||
if not intent:
|
|
||||||
# Réponse texte simple
|
|
||||||
#self.memory.append_history("user", user_input)
|
|
||||||
#self.memory.append_history("assistant", raw_first)
|
|
||||||
return raw_first
|
|
||||||
|
|
||||||
# Exécuter l'action
|
# Get LLM response
|
||||||
tool_result = self._execute_action(intent)
|
llm_response = self.llm.complete(messages)
|
||||||
|
print("LLM response:", llm_response)
|
||||||
|
|
||||||
followup_messages: List[Dict[str, Any]] = [
|
# Try to parse as tool intent
|
||||||
{"role": "system", "content": system_prompt},
|
intent = self._parse_intent(llm_response)
|
||||||
{"role": "user", "content": user_input},
|
|
||||||
{
|
if not intent:
|
||||||
|
# No tool call - this is the final text response
|
||||||
|
print("No tool intent detected, returning final response")
|
||||||
|
# Save to history
|
||||||
|
self.memory.append_history("user", user_input)
|
||||||
|
self.memory.append_history("assistant", llm_response)
|
||||||
|
return llm_response
|
||||||
|
|
||||||
|
# Tool call detected - execute it
|
||||||
|
print("Intent detected:", intent)
|
||||||
|
tool_result = self._execute_action(intent)
|
||||||
|
print("Tool result:", tool_result)
|
||||||
|
|
||||||
|
# Add assistant's tool call and result to conversation
|
||||||
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
"content": json.dumps(intent, ensure_ascii=False)
|
||||||
|
})
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
"content": json.dumps(
|
"content": json.dumps(
|
||||||
{"tool_result": tool_result, "intent": intent},
|
{"tool_result": tool_result},
|
||||||
ensure_ascii=False,
|
ensure_ascii=False
|
||||||
),
|
)
|
||||||
},
|
})
|
||||||
]
|
|
||||||
|
|
||||||
raw_second = self.llm.complete(followup_messages)
|
iteration += 1
|
||||||
|
|
||||||
#self.memory.append_history("user", user_input)
|
# Max iterations reached - ask LLM for final response
|
||||||
#self.memory.append_history("assistant", raw_second)
|
print(f"\n--- Max iterations ({self.max_tool_iterations}) reached, requesting final response ---")
|
||||||
return raw_second
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": "Merci pour ces résultats. Peux-tu maintenant me donner une réponse finale en texte naturel ?"
|
||||||
|
})
|
||||||
|
|
||||||
|
final_response = self.llm.complete(messages)
|
||||||
|
# Save to history
|
||||||
|
self.memory.append_history("user", user_input)
|
||||||
|
self.memory.append_history("assistant", final_response)
|
||||||
|
return final_response
|
||||||
|
|||||||
20
agent/api/__init__.py
Normal file
20
agent/api/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""API clients module."""
|
||||||
|
from .themoviedb import (
|
||||||
|
TMDBClient,
|
||||||
|
tmdb_client,
|
||||||
|
TMDBError,
|
||||||
|
TMDBConfigurationError,
|
||||||
|
TMDBAPIError,
|
||||||
|
TMDBNotFoundError,
|
||||||
|
MediaResult
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'TMDBClient',
|
||||||
|
'tmdb_client',
|
||||||
|
'TMDBError',
|
||||||
|
'TMDBConfigurationError',
|
||||||
|
'TMDBAPIError',
|
||||||
|
'TMDBNotFoundError',
|
||||||
|
'MediaResult'
|
||||||
|
]
|
||||||
317
agent/api/themoviedb.py
Normal file
317
agent/api/themoviedb.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
"""TMDB (The Movie Database) API client."""
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
|
import requests
|
||||||
|
from requests.exceptions import RequestException, Timeout, HTTPError
|
||||||
|
|
||||||
|
from ..config import Settings, settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TMDBError(Exception):
|
||||||
|
"""Base exception for TMDB-related errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TMDBConfigurationError(TMDBError):
|
||||||
|
"""Raised when TMDB API is not properly configured."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TMDBAPIError(TMDBError):
|
||||||
|
"""Raised when TMDB API returns an error."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TMDBNotFoundError(TMDBError):
|
||||||
|
"""Raised when media is not found."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MediaResult:
|
||||||
|
"""Represents a media search result from TMDB."""
|
||||||
|
tmdb_id: int
|
||||||
|
title: str
|
||||||
|
media_type: str # 'movie' or 'tv'
|
||||||
|
imdb_id: Optional[str] = None
|
||||||
|
overview: Optional[str] = None
|
||||||
|
release_date: Optional[str] = None
|
||||||
|
poster_path: Optional[str] = None
|
||||||
|
vote_average: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TMDBClient:
|
||||||
|
"""
|
||||||
|
Client for interacting with The Movie Database (TMDB) API.
|
||||||
|
|
||||||
|
This client provides methods to search for movies and TV shows,
|
||||||
|
retrieve their details, and get external IDs (like IMDb).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> client = TMDBClient()
|
||||||
|
>>> result = client.search_media("Inception")
|
||||||
|
>>> print(result.imdb_id)
|
||||||
|
'tt1375666'
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
config: Optional[Settings] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize TMDB client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: TMDB API key (defaults to settings)
|
||||||
|
base_url: TMDB API base URL (defaults to settings)
|
||||||
|
timeout: Request timeout in seconds (defaults to settings)
|
||||||
|
config: Optional Settings instance (for testing)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TMDBConfigurationError: If API key is missing
|
||||||
|
"""
|
||||||
|
cfg = config or settings
|
||||||
|
|
||||||
|
self.api_key = api_key or cfg.tmdb_api_key
|
||||||
|
self.base_url = base_url or cfg.tmdb_base_url
|
||||||
|
self.timeout = timeout or cfg.request_timeout
|
||||||
|
|
||||||
|
if not self.api_key:
|
||||||
|
raise TMDBConfigurationError(
|
||||||
|
"TMDB API key is required. Set TMDB_API_KEY environment variable."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.base_url:
|
||||||
|
raise TMDBConfigurationError(
|
||||||
|
"TMDB base URL is required. Set TMDB_BASE_URL environment variable."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("TMDB client initialized")
|
||||||
|
|
||||||
|
def _make_request(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
params: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Make a request to TMDB API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
endpoint: API endpoint (e.g., '/search/multi')
|
||||||
|
params: Query parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON response as dict
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TMDBAPIError: If request fails
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}{endpoint}"
|
||||||
|
|
||||||
|
# Add API key to params
|
||||||
|
request_params = params or {}
|
||||||
|
request_params['api_key'] = self.api_key
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug(f"TMDB request: {endpoint}")
|
||||||
|
response = requests.get(url, params=request_params, timeout=self.timeout)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
except Timeout as e:
|
||||||
|
logger.error(f"TMDB API timeout: {e}")
|
||||||
|
raise TMDBAPIError(f"Request timeout after {self.timeout} seconds") from e
|
||||||
|
|
||||||
|
except HTTPError as e:
|
||||||
|
logger.error(f"TMDB API HTTP error: {e}")
|
||||||
|
if e.response is not None:
|
||||||
|
status_code = e.response.status_code
|
||||||
|
if status_code == 401:
|
||||||
|
raise TMDBAPIError("Invalid TMDB API key") from e
|
||||||
|
elif status_code == 404:
|
||||||
|
raise TMDBNotFoundError("Resource not found") from e
|
||||||
|
else:
|
||||||
|
raise TMDBAPIError(f"HTTP {status_code}: {e}") from e
|
||||||
|
raise TMDBAPIError(f"HTTP error: {e}") from e
|
||||||
|
|
||||||
|
except RequestException as e:
|
||||||
|
logger.error(f"TMDB API request failed: {e}")
|
||||||
|
raise TMDBAPIError(f"Failed to connect to TMDB API: {e}") from e
|
||||||
|
|
||||||
|
def search_multi(self, query: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Search for movies and TV shows.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query (movie or TV show title)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TMDBAPIError: If request fails
|
||||||
|
TMDBNotFoundError: If no results found
|
||||||
|
"""
|
||||||
|
if not query or not isinstance(query, str):
|
||||||
|
raise ValueError("Query must be a non-empty string")
|
||||||
|
|
||||||
|
if len(query) > 500:
|
||||||
|
raise ValueError("Query is too long (max 500 characters)")
|
||||||
|
|
||||||
|
data = self._make_request('/search/multi', {'query': query})
|
||||||
|
|
||||||
|
results = data.get('results', [])
|
||||||
|
if not results:
|
||||||
|
raise TMDBNotFoundError(f"No results found for '{query}'")
|
||||||
|
|
||||||
|
logger.info(f"Found {len(results)} results for '{query}'")
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_external_ids(self, media_type: str, tmdb_id: int) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get external IDs (IMDb, TVDB, etc.) for a media item.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
media_type: Type of media ('movie' or 'tv')
|
||||||
|
tmdb_id: TMDB ID of the media
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with external IDs
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TMDBAPIError: If request fails
|
||||||
|
"""
|
||||||
|
if media_type not in ('movie', 'tv'):
|
||||||
|
raise ValueError(f"Invalid media_type: {media_type}. Must be 'movie' or 'tv'")
|
||||||
|
|
||||||
|
endpoint = f"/{media_type}/{tmdb_id}/external_ids"
|
||||||
|
return self._make_request(endpoint)
|
||||||
|
|
||||||
|
def search_media(self, title: str) -> MediaResult:
|
||||||
|
"""
|
||||||
|
Search for a media item and return detailed information including IMDb ID.
|
||||||
|
|
||||||
|
This is a convenience method that combines search and external ID lookup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
title: Title of the movie or TV show
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MediaResult with all available information
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TMDBAPIError: If request fails
|
||||||
|
TMDBNotFoundError: If media not found
|
||||||
|
"""
|
||||||
|
# Search for media
|
||||||
|
results = self.search_multi(title)
|
||||||
|
|
||||||
|
# Get the first (most relevant) result
|
||||||
|
top_result = results[0]
|
||||||
|
|
||||||
|
# Validate result structure
|
||||||
|
if 'id' not in top_result or 'media_type' not in top_result:
|
||||||
|
raise TMDBAPIError("Invalid TMDB response structure")
|
||||||
|
|
||||||
|
tmdb_id = top_result['id']
|
||||||
|
media_type = top_result['media_type']
|
||||||
|
|
||||||
|
# Skip if not movie or TV show
|
||||||
|
if media_type not in ('movie', 'tv'):
|
||||||
|
logger.warning(f"Skipping result of type: {media_type}")
|
||||||
|
if len(results) > 1:
|
||||||
|
# Try next result
|
||||||
|
return self._parse_result(results[1])
|
||||||
|
raise TMDBNotFoundError(f"No movie or TV show found for '{title}'")
|
||||||
|
|
||||||
|
return self._parse_result(top_result)
|
||||||
|
|
||||||
|
def _parse_result(self, result: Dict[str, Any]) -> MediaResult:
|
||||||
|
"""
|
||||||
|
Parse a TMDB result into a MediaResult object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Raw TMDB result dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MediaResult object
|
||||||
|
"""
|
||||||
|
tmdb_id = result['id']
|
||||||
|
media_type = result['media_type']
|
||||||
|
title = result.get('title') or result.get('name', 'Unknown')
|
||||||
|
|
||||||
|
# Get external IDs (including IMDb)
|
||||||
|
try:
|
||||||
|
external_ids = self.get_external_ids(media_type, tmdb_id)
|
||||||
|
imdb_id = external_ids.get('imdb_id')
|
||||||
|
except TMDBAPIError as e:
|
||||||
|
logger.warning(f"Failed to get external IDs: {e}")
|
||||||
|
imdb_id = None
|
||||||
|
|
||||||
|
# Extract other useful information
|
||||||
|
overview = result.get('overview')
|
||||||
|
release_date = result.get('release_date') or result.get('first_air_date')
|
||||||
|
poster_path = result.get('poster_path')
|
||||||
|
vote_average = result.get('vote_average')
|
||||||
|
|
||||||
|
logger.info(f"Found: {title} (Type: {media_type}, TMDB ID: {tmdb_id}, IMDb: {imdb_id})")
|
||||||
|
|
||||||
|
return MediaResult(
|
||||||
|
tmdb_id=tmdb_id,
|
||||||
|
title=title,
|
||||||
|
media_type=media_type,
|
||||||
|
imdb_id=imdb_id,
|
||||||
|
overview=overview,
|
||||||
|
release_date=release_date,
|
||||||
|
poster_path=poster_path,
|
||||||
|
vote_average=vote_average
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_movie_details(self, movie_id: int) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get detailed information about a movie.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
movie_id: TMDB movie ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with movie details
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TMDBAPIError: If request fails
|
||||||
|
"""
|
||||||
|
return self._make_request(f'/movie/{movie_id}')
|
||||||
|
|
||||||
|
def get_tv_details(self, tv_id: int) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get detailed information about a TV show.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tv_id: TMDB TV show ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with TV show details
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TMDBAPIError: If request fails
|
||||||
|
"""
|
||||||
|
return self._make_request(f'/tv/{tv_id}')
|
||||||
|
|
||||||
|
def is_configured(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if TMDB client is properly configured.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if configured, False otherwise
|
||||||
|
"""
|
||||||
|
return bool(self.api_key and self.base_url)
|
||||||
|
|
||||||
|
|
||||||
|
# Global TMDB client instance (singleton)
|
||||||
|
tmdb_client = TMDBClient()
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
# agent/commands.py
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Callable, Dict, List
|
|
||||||
import os
|
|
||||||
|
|
||||||
from .memory import Memory
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Command:
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
needs_project_root: bool
|
|
||||||
handler: Callable[[List[str], Memory], str]
|
|
||||||
|
|
||||||
|
|
||||||
class CommandRegistry:
|
|
||||||
def __init__(self, memory: Memory):
|
|
||||||
self.memory = memory
|
|
||||||
self.commands: Dict[str, Command] = {}
|
|
||||||
self._register_defaults()
|
|
||||||
|
|
||||||
def _register_defaults(self):
|
|
||||||
# /setroot <path>
|
|
||||||
def cmd_setroot(args: List[str], mem: Memory) -> str:
|
|
||||||
if not args:
|
|
||||||
return "Usage: `/setroot <chemin_absolu_vers_projet>`"
|
|
||||||
|
|
||||||
path = args[0]
|
|
||||||
if not os.path.isdir(path):
|
|
||||||
return f"Le chemin `{path}` n'existe pas ou n'est pas un dossier."
|
|
||||||
|
|
||||||
mem.set_project_root(path)
|
|
||||||
return f"✅ Chemin du projet défini sur `{path}`."
|
|
||||||
|
|
||||||
self.register(
|
|
||||||
Command(
|
|
||||||
name="setroot",
|
|
||||||
description="Définit le chemin racine du projet.",
|
|
||||||
needs_project_root=False,
|
|
||||||
handler=cmd_setroot,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# /scan [rel_path]
|
|
||||||
def cmd_scan(args: List[str], mem: Memory) -> str:
|
|
||||||
root = mem.get_project_root()
|
|
||||||
rel = args[0] if args else "."
|
|
||||||
full = os.path.abspath(os.path.join(root, rel))
|
|
||||||
|
|
||||||
# sécurité basique
|
|
||||||
if not full.startswith(os.path.abspath(root)):
|
|
||||||
return "❌ Tu ne peux pas sortir du project_root."
|
|
||||||
|
|
||||||
if not os.path.isdir(full):
|
|
||||||
return f"❌ `{rel}` n'est pas un dossier dans le projet."
|
|
||||||
|
|
||||||
entries = sorted(os.listdir(full))
|
|
||||||
if not entries:
|
|
||||||
return f"📁 `{rel}` est vide."
|
|
||||||
|
|
||||||
lines = [f"📁 Scan de `{rel}` (dans `{root}`):"]
|
|
||||||
for e in entries:
|
|
||||||
p = os.path.join(full, e)
|
|
||||||
if os.path.isdir(p):
|
|
||||||
lines.append(f" 📂 {e}/")
|
|
||||||
else:
|
|
||||||
lines.append(f" 📄 {e}")
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
self.register(
|
|
||||||
Command(
|
|
||||||
name="scan",
|
|
||||||
description="Liste les fichiers/dossiers du projet.",
|
|
||||||
needs_project_root=True,
|
|
||||||
handler=cmd_scan,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def register(self, cmd: Command):
|
|
||||||
self.commands[cmd.name] = cmd
|
|
||||||
|
|
||||||
def handle(self, raw_input: str) -> str:
|
|
||||||
"""
|
|
||||||
Parse et exécute une commande du type `/scan src` ou `/setroot /chemin`.
|
|
||||||
"""
|
|
||||||
text = raw_input.strip()
|
|
||||||
if not text.startswith("/"):
|
|
||||||
return "Internal error: not a command."
|
|
||||||
|
|
||||||
parts = text[1:].split()
|
|
||||||
if not parts:
|
|
||||||
return "Commande vide."
|
|
||||||
|
|
||||||
name = parts[0]
|
|
||||||
args = parts[1:]
|
|
||||||
|
|
||||||
cmd = self.commands.get(name)
|
|
||||||
if not cmd:
|
|
||||||
return f"Commande inconnue: `/{name}`.\nCommandes disponibles: {', '.join('/'+n for n in self.commands.keys())}"
|
|
||||||
|
|
||||||
# dépendance project_root
|
|
||||||
if cmd.needs_project_root and not self.memory.get_project_root():
|
|
||||||
return (
|
|
||||||
"❗ Aucun `project_root` défini pour l'instant.\n"
|
|
||||||
"Commence par le définir avec:\n"
|
|
||||||
"`/setroot /chemin/vers/ton/projet`"
|
|
||||||
)
|
|
||||||
|
|
||||||
return cmd.handler(args, self.memory)
|
|
||||||
@@ -1,15 +1,78 @@
|
|||||||
from dataclasses import dataclass
|
"""Configuration management with validation."""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Load environment variables from .env file
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigurationError(Exception):
|
||||||
|
"""Raised when configuration is invalid."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Settings:
|
class Settings:
|
||||||
deepseek_api_key: str = os.getenv("DEEPSEEK_API_KEY", "")
|
"""Application settings loaded from environment variables."""
|
||||||
deepseek_base_url: str = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
|
|
||||||
model: str = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
|
|
||||||
temperature: float = float(os.getenv("TEMPERATURE", "0.2"))
|
|
||||||
memory_file: str = os.getenv("MEMORY_FILE", "memory.json")
|
|
||||||
|
|
||||||
|
# LLM Configuration
|
||||||
|
deepseek_api_key: str = field(default_factory=lambda: os.getenv("DEEPSEEK_API_KEY", ""))
|
||||||
|
deepseek_base_url: str = field(default_factory=lambda: os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com"))
|
||||||
|
model: str = field(default_factory=lambda: os.getenv("DEEPSEEK_MODEL", "deepseek-chat"))
|
||||||
|
temperature: float = field(default_factory=lambda: float(os.getenv("TEMPERATURE", "0.2")))
|
||||||
|
|
||||||
|
# TMDB Configuration
|
||||||
|
tmdb_api_key: str = field(default_factory=lambda: os.getenv("TMDB_API_KEY", ""))
|
||||||
|
tmdb_base_url: str = field(default_factory=lambda: os.getenv("TMDB_BASE_URL", "https://api.themoviedb.org/3"))
|
||||||
|
|
||||||
|
# Storage Configuration
|
||||||
|
memory_file: str = field(default_factory=lambda: os.getenv("MEMORY_FILE", "memory.json"))
|
||||||
|
|
||||||
|
# Security Configuration
|
||||||
|
max_tool_iterations: int = field(default_factory=lambda: int(os.getenv("MAX_TOOL_ITERATIONS", "5")))
|
||||||
|
request_timeout: int = field(default_factory=lambda: int(os.getenv("REQUEST_TIMEOUT", "30")))
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Validate settings after initialization."""
|
||||||
|
self._validate()
|
||||||
|
|
||||||
|
def _validate(self) -> None:
|
||||||
|
"""Validate configuration values."""
|
||||||
|
# Validate temperature
|
||||||
|
if not 0.0 <= self.temperature <= 2.0:
|
||||||
|
raise ConfigurationError(f"Temperature must be between 0.0 and 2.0, got {self.temperature}")
|
||||||
|
|
||||||
|
# Validate max_tool_iterations
|
||||||
|
if self.max_tool_iterations < 1 or self.max_tool_iterations > 20:
|
||||||
|
raise ConfigurationError(f"max_tool_iterations must be between 1 and 20, got {self.max_tool_iterations}")
|
||||||
|
|
||||||
|
# Validate request_timeout
|
||||||
|
if self.request_timeout < 1 or self.request_timeout > 300:
|
||||||
|
raise ConfigurationError(f"request_timeout must be between 1 and 300 seconds, got {self.request_timeout}")
|
||||||
|
|
||||||
|
# Validate URLs
|
||||||
|
if not self.deepseek_base_url.startswith(("http://", "https://")):
|
||||||
|
raise ConfigurationError(f"Invalid deepseek_base_url: {self.deepseek_base_url}")
|
||||||
|
|
||||||
|
if not self.tmdb_base_url.startswith(("http://", "https://")):
|
||||||
|
raise ConfigurationError(f"Invalid tmdb_base_url: {self.tmdb_base_url}")
|
||||||
|
|
||||||
|
# Validate memory file path
|
||||||
|
memory_path = Path(self.memory_file)
|
||||||
|
if memory_path.exists() and not memory_path.is_file():
|
||||||
|
raise ConfigurationError(f"memory_file exists but is not a file: {self.memory_file}")
|
||||||
|
|
||||||
|
def is_deepseek_configured(self) -> bool:
|
||||||
|
"""Check if DeepSeek API is properly configured."""
|
||||||
|
return bool(self.deepseek_api_key and self.deepseek_base_url)
|
||||||
|
|
||||||
|
def is_tmdb_configured(self) -> bool:
|
||||||
|
"""Check if TMDB API is properly configured."""
|
||||||
|
return bool(self.tmdb_api_key and self.tmdb_base_url)
|
||||||
|
|
||||||
|
|
||||||
|
# Global settings instance
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
41
agent/llm.py
41
agent/llm.py
@@ -1,41 +0,0 @@
|
|||||||
# agent/llm.py
|
|
||||||
from typing import List, Dict, Any
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from .config import settings
|
|
||||||
|
|
||||||
class DeepSeekClient:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api_key: str | None = None,
|
|
||||||
base_url: str | None = None,
|
|
||||||
model: str | None = None,
|
|
||||||
):
|
|
||||||
self.api_key = api_key or settings.deepseek_api_key
|
|
||||||
self.base_url = base_url or settings.deepseek_base_url
|
|
||||||
self.model = model or settings.model
|
|
||||||
|
|
||||||
def complete(self, messages: List[Dict[str, Any]]) -> str:
|
|
||||||
"""
|
|
||||||
messages: liste de dicts {role: 'system'|'user'|'assistant', content: str}
|
|
||||||
Retourne content (str) du premier choix.
|
|
||||||
"""
|
|
||||||
if not self.api_key:
|
|
||||||
return "Erreur côté agent : DEEPSEEK_API_KEY manquant dans l'environnement backend."
|
|
||||||
|
|
||||||
url = f"{self.base_url}/v1/chat/completions"
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
payload = {
|
|
||||||
"model": self.model,
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": settings.temperature,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = requests.post(url, headers=headers, json=payload, timeout=60)
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = resp.json()
|
|
||||||
print("DeepSeek response:", data)
|
|
||||||
return data["choices"][0]["message"]["content"]
|
|
||||||
2
agent/llm/__init__.py
Normal file
2
agent/llm/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
"""LLM client module."""
|
||||||
|
from .deepseek import DeepSeekClient
|
||||||
150
agent/llm/deepseek.py
Normal file
150
agent/llm/deepseek.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""DeepSeek LLM client with robust error handling."""
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import logging
|
||||||
|
import requests
|
||||||
|
from requests.exceptions import RequestException, Timeout, HTTPError
|
||||||
|
|
||||||
|
from ..config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMError(Exception):
|
||||||
|
"""Base exception for LLM-related errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LLMConfigurationError(LLMError):
|
||||||
|
"""Raised when LLM is not properly configured."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LLMAPIError(LLMError):
|
||||||
|
"""Raised when LLM API returns an error."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekClient:
|
||||||
|
"""Client for interacting with DeepSeek API."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize DeepSeek client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API key for authentication (defaults to settings)
|
||||||
|
base_url: Base URL for API (defaults to settings)
|
||||||
|
model: Model name to use (defaults to settings)
|
||||||
|
timeout: Request timeout in seconds (defaults to settings)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LLMConfigurationError: If API key is missing
|
||||||
|
"""
|
||||||
|
self.api_key = api_key or settings.deepseek_api_key
|
||||||
|
self.base_url = base_url or settings.deepseek_base_url
|
||||||
|
self.model = model or settings.model
|
||||||
|
self.timeout = timeout or settings.request_timeout
|
||||||
|
|
||||||
|
if not self.api_key:
|
||||||
|
raise LLMConfigurationError(
|
||||||
|
"DeepSeek API key is required. Set DEEPSEEK_API_KEY environment variable."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.base_url:
|
||||||
|
raise LLMConfigurationError(
|
||||||
|
"DeepSeek base URL is required. Set DEEPSEEK_BASE_URL environment variable."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"DeepSeek client initialized with model: {self.model}")
|
||||||
|
|
||||||
|
def complete(self, messages: List[Dict[str, Any]]) -> str:
|
||||||
|
"""
|
||||||
|
Generate a completion from the LLM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dicts with 'role' and 'content' keys
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generated text response
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LLMAPIError: If API request fails
|
||||||
|
ValueError: If messages format is invalid
|
||||||
|
"""
|
||||||
|
# Validate messages format
|
||||||
|
if not messages:
|
||||||
|
raise ValueError("Messages list cannot be empty")
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
if not isinstance(msg, dict):
|
||||||
|
raise ValueError(f"Each message must be a dict, got {type(msg)}")
|
||||||
|
if "role" not in msg or "content" not in msg:
|
||||||
|
raise ValueError(f"Each message must have 'role' and 'content' keys, got {msg.keys()}")
|
||||||
|
if msg["role"] not in ("system", "user", "assistant"):
|
||||||
|
raise ValueError(f"Invalid role: {msg['role']}")
|
||||||
|
|
||||||
|
url = f"{self.base_url}/v1/chat/completions"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": settings.temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug(f"Sending request to {url} with {len(messages)} messages")
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=self.timeout
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Validate response structure
|
||||||
|
if "choices" not in data or not data["choices"]:
|
||||||
|
raise LLMAPIError("Invalid API response: missing 'choices'")
|
||||||
|
|
||||||
|
if "message" not in data["choices"][0]:
|
||||||
|
raise LLMAPIError("Invalid API response: missing 'message' in choice")
|
||||||
|
|
||||||
|
if "content" not in data["choices"][0]["message"]:
|
||||||
|
raise LLMAPIError("Invalid API response: missing 'content' in message")
|
||||||
|
|
||||||
|
content = data["choices"][0]["message"]["content"]
|
||||||
|
logger.debug(f"Received response with {len(content)} characters")
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
except Timeout as e:
|
||||||
|
logger.error(f"Request timeout after {self.timeout}s: {e}")
|
||||||
|
raise LLMAPIError(f"Request timeout after {self.timeout} seconds") from e
|
||||||
|
|
||||||
|
except HTTPError as e:
|
||||||
|
logger.error(f"HTTP error from DeepSeek API: {e}")
|
||||||
|
if e.response is not None:
|
||||||
|
try:
|
||||||
|
error_data = e.response.json()
|
||||||
|
error_msg = error_data.get("error", {}).get("message", str(e))
|
||||||
|
except Exception:
|
||||||
|
error_msg = str(e)
|
||||||
|
raise LLMAPIError(f"DeepSeek API error: {error_msg}") from e
|
||||||
|
raise LLMAPIError(f"HTTP error: {e}") from e
|
||||||
|
|
||||||
|
except RequestException as e:
|
||||||
|
logger.error(f"Request failed: {e}")
|
||||||
|
raise LLMAPIError(f"Failed to connect to DeepSeek API: {e}") from e
|
||||||
|
|
||||||
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
|
logger.error(f"Failed to parse API response: {e}")
|
||||||
|
raise LLMAPIError(f"Invalid API response format: {e}") from e
|
||||||
@@ -4,21 +4,37 @@ from typing import Any, Dict
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from .config import settings
|
from .config import settings
|
||||||
|
from .parameters import validate_parameter, get_parameter_schema
|
||||||
|
|
||||||
|
|
||||||
class Memory:
|
class Memory:
|
||||||
|
"""
|
||||||
|
Generic memory storage for agent state.
|
||||||
|
|
||||||
|
Provides a simple key-value store that persists to JSON.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, path: str = "memory.json"):
|
def __init__(self, path: str = "memory.json"):
|
||||||
print("init memory")
|
|
||||||
self.file = Path(path)
|
self.file = Path(path)
|
||||||
self.data: Dict[str, Any] = {}
|
self.data: Dict[str, Any] = {}
|
||||||
self.load()
|
self.load()
|
||||||
|
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
|
"""Load memory from file or initialize with defaults."""
|
||||||
if self.file.exists():
|
if self.file.exists():
|
||||||
self.data = json.loads(self.file.read_text(encoding="utf-8"))
|
try:
|
||||||
|
self.data = json.loads(self.file.read_text(encoding="utf-8"))
|
||||||
|
except (json.JSONDecodeError, IOError) as e:
|
||||||
|
print(f"Warning: Could not load memory file: {e}")
|
||||||
|
self.data = {
|
||||||
|
"config": {},
|
||||||
|
"tv_shows": [],
|
||||||
|
"history": [],
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
self.data = {
|
self.data = {
|
||||||
"project_root": None,
|
"config": {},
|
||||||
|
"tv_shows": [],
|
||||||
"history": [],
|
"history": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -28,11 +44,43 @@ class Memory:
|
|||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_project_root(self) -> str | None:
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
"""Ce qu'on injecte dans le prompt pour le LLM."""
|
"""Get a value from memory by key."""
|
||||||
return self.data.get("project_root")
|
return self.data.get(key, default)
|
||||||
|
|
||||||
def set_project_root(self, path: str) -> None:
|
def set(self, key: str, value: Any) -> None:
|
||||||
print('Setting project root in memory to:', path)
|
"""
|
||||||
self.data["project_root"] = path
|
Set a value in memory and save.
|
||||||
|
|
||||||
|
Validates the value against the parameter schema if one exists.
|
||||||
|
"""
|
||||||
|
# Validate if schema exists
|
||||||
|
is_valid, error_msg = validate_parameter(key, value)
|
||||||
|
if not is_valid:
|
||||||
|
print(f'Validation failed for {key}: {error_msg}')
|
||||||
|
raise ValueError(f"Invalid value for {key}: {error_msg}")
|
||||||
|
|
||||||
|
print(f'Setting {key} in memory to: {value}')
|
||||||
|
self.data[key] = value
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
def has(self, key: str) -> bool:
|
||||||
|
"""Check if a key exists and has a non-None value."""
|
||||||
|
return key in self.data and self.data[key] is not None
|
||||||
|
|
||||||
|
def append_history(self, role: str, content: str) -> None:
|
||||||
|
"""
|
||||||
|
Append a message to conversation history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
role: Message role ('user' or 'assistant')
|
||||||
|
content: Message content
|
||||||
|
"""
|
||||||
|
if "history" not in self.data:
|
||||||
|
self.data["history"] = []
|
||||||
|
|
||||||
|
self.data["history"].append({
|
||||||
|
"role": role,
|
||||||
|
"content": content
|
||||||
|
})
|
||||||
self.save()
|
self.save()
|
||||||
|
|||||||
2
agent/models/__init__.py
Normal file
2
agent/models/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
"""Models module."""
|
||||||
|
from .tv_show import TVShow, ShowStatus, validate_tv_shows_structure
|
||||||
58
agent/models/tv_show.py
Normal file
58
agent/models/tv_show.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""TV Show models and validation."""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ShowStatus(Enum):
|
||||||
|
"""Status of a TV show - whether it's still airing or has ended."""
|
||||||
|
ONGOING = "ongoing"
|
||||||
|
ENDED = "ended"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TVShow:
|
||||||
|
"""Represents a TV show."""
|
||||||
|
imdb_id: str
|
||||||
|
title: str
|
||||||
|
seasons_count: int
|
||||||
|
status: ShowStatus # ongoing or ended
|
||||||
|
|
||||||
|
|
||||||
|
def validate_tv_shows_structure(tv_shows: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Validate the structure of the tv_shows parameter.
|
||||||
|
|
||||||
|
Expected structure: list of TV show objects
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"imdb_id": str,
|
||||||
|
"title": str,
|
||||||
|
"seasons_count": int,
|
||||||
|
"status": str # "ongoing" or "ended"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
if not isinstance(tv_shows, list):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for show in tv_shows:
|
||||||
|
if not isinstance(show, dict):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check required fields
|
||||||
|
required_fields = {"imdb_id", "title", "seasons_count", "status"}
|
||||||
|
if not all(field in show for field in required_fields):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Validate field types
|
||||||
|
if not isinstance(show["imdb_id"], str):
|
||||||
|
return False
|
||||||
|
if not isinstance(show["title"], str):
|
||||||
|
return False
|
||||||
|
if not isinstance(show["seasons_count"], int):
|
||||||
|
return False
|
||||||
|
if show["status"] not in ["ongoing", "ended"]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
100
agent/parameters.py
Normal file
100
agent/parameters.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# agent/parameters.py
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional, Callable
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ParameterSchema:
|
||||||
|
"""Describes a required parameter for the agent."""
|
||||||
|
key: str
|
||||||
|
description: str
|
||||||
|
why_needed: str # Explanation for the AI
|
||||||
|
type: str # "string", "number", "object", etc.
|
||||||
|
validator: Optional[Callable[[Any], bool]] = None
|
||||||
|
default: Any = None
|
||||||
|
required: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
# Define all required parameters
|
||||||
|
REQUIRED_PARAMETERS = [
|
||||||
|
ParameterSchema(
|
||||||
|
key="config",
|
||||||
|
description="Configuration object containing all folder paths",
|
||||||
|
why_needed=(
|
||||||
|
"This contains the paths to all important folders:\n"
|
||||||
|
"- download_folder: Where downloaded files arrive before being organized\n"
|
||||||
|
"- tvshow_folder: Where TV show files are organized and stored\n"
|
||||||
|
"- movie_folder: Where movie files are organized and stored\n"
|
||||||
|
"- torrent_folder: Where .torrent files are saved for the torrent client"
|
||||||
|
),
|
||||||
|
type="object",
|
||||||
|
validator=lambda x: isinstance(x, dict),
|
||||||
|
required=True,
|
||||||
|
default={}
|
||||||
|
),
|
||||||
|
ParameterSchema(
|
||||||
|
key="tv_shows",
|
||||||
|
description="List of TV shows the user is following",
|
||||||
|
why_needed=(
|
||||||
|
"This tracks which TV shows you're following. "
|
||||||
|
"Each show includes: IMDB ID, title, number of seasons, and status (ongoing or ended)."
|
||||||
|
),
|
||||||
|
type="array",
|
||||||
|
validator=lambda x: isinstance(x, list),
|
||||||
|
required=False,
|
||||||
|
default=[]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameter_schema(key: str) -> Optional[ParameterSchema]:
|
||||||
|
"""Get schema for a specific parameter."""
|
||||||
|
for param in REQUIRED_PARAMETERS:
|
||||||
|
if param.key == key:
|
||||||
|
return param
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_missing_required_parameters(memory_data: dict) -> list[ParameterSchema]:
|
||||||
|
"""Get list of required parameters that are missing or None."""
|
||||||
|
missing = []
|
||||||
|
for param in REQUIRED_PARAMETERS:
|
||||||
|
if param.required:
|
||||||
|
value = memory_data.get(param.key)
|
||||||
|
if value is None:
|
||||||
|
missing.append(param)
|
||||||
|
return missing
|
||||||
|
|
||||||
|
|
||||||
|
def format_parameters_for_prompt() -> str:
|
||||||
|
"""Format parameter descriptions for the AI system prompt."""
|
||||||
|
lines = ["REQUIRED PARAMETERS:"]
|
||||||
|
for param in REQUIRED_PARAMETERS:
|
||||||
|
status = "REQUIRED" if param.required else "OPTIONAL"
|
||||||
|
lines.append(f"\n- {param.key} ({status}):")
|
||||||
|
lines.append(f" Description: {param.description}")
|
||||||
|
lines.append(f" Why needed: {param.why_needed}")
|
||||||
|
lines.append(f" Type: {param.type}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_parameter(key: str, value: Any) -> tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Validate a parameter value against its schema.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(is_valid, error_message)
|
||||||
|
"""
|
||||||
|
schema = get_parameter_schema(key)
|
||||||
|
if not schema:
|
||||||
|
return True, None # Unknown parameters are allowed
|
||||||
|
|
||||||
|
if schema.validator:
|
||||||
|
try:
|
||||||
|
if not schema.validator(value):
|
||||||
|
return False, f"Validation failed for {key}"
|
||||||
|
except Exception as e:
|
||||||
|
return False, f"Validation error for {key}: {str(e)}"
|
||||||
|
|
||||||
|
return True, None
|
||||||
88
agent/prompts.py
Normal file
88
agent/prompts.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
# agent/prompts.py
|
||||||
|
from typing import Dict, Any
|
||||||
|
import json
|
||||||
|
|
||||||
|
from .registry import Tool
|
||||||
|
from .parameters import format_parameters_for_prompt, get_missing_required_parameters
|
||||||
|
|
||||||
|
|
||||||
|
class PromptBuilder:
|
||||||
|
"""Handles construction of system prompts for the agent."""
|
||||||
|
|
||||||
|
def __init__(self, tools: Dict[str, Tool]):
|
||||||
|
self.tools = tools
|
||||||
|
|
||||||
|
def _format_tools_description(self) -> str:
|
||||||
|
"""Format tools with their descriptions and parameters."""
|
||||||
|
return "\n".join(
|
||||||
|
f"- {tool.name}: {tool.description}\n"
|
||||||
|
f" Parameters: {json.dumps(tool.parameters, ensure_ascii=False)}"
|
||||||
|
for tool in self.tools.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_context(self, memory_data: dict) -> Dict[str, Any]:
|
||||||
|
"""Build the context object with current state from memory."""
|
||||||
|
return memory_data
|
||||||
|
|
||||||
|
def build_system_prompt(self, memory_data: dict) -> str:
|
||||||
|
"""
|
||||||
|
Build the system prompt with context provided as JSON.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_data: The full memory data dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The complete system prompt string
|
||||||
|
"""
|
||||||
|
context = self._build_context(memory_data)
|
||||||
|
tools_desc = self._format_tools_description()
|
||||||
|
params_desc = format_parameters_for_prompt()
|
||||||
|
|
||||||
|
# Check for missing required parameters
|
||||||
|
missing_params = get_missing_required_parameters(memory_data)
|
||||||
|
missing_info = ""
|
||||||
|
if missing_params:
|
||||||
|
missing_info = "\n\n⚠️ MISSING REQUIRED PARAMETERS:\n"
|
||||||
|
for param in missing_params:
|
||||||
|
missing_info += f"- {param.key}: {param.description}\n"
|
||||||
|
missing_info += f" Why needed: {param.why_needed}\n"
|
||||||
|
|
||||||
|
return (
|
||||||
|
"You are an AI agent helping a user manage their local media library.\n\n"
|
||||||
|
f"{params_desc}\n\n"
|
||||||
|
"CURRENT CONTEXT (JSON):\n"
|
||||||
|
f"{json.dumps(context, indent=2, ensure_ascii=False)}\n"
|
||||||
|
f"{missing_info}\n"
|
||||||
|
"IMPORTANT RULES:\n"
|
||||||
|
"1. Check the REQUIRED PARAMETERS section above to understand what information you need.\n"
|
||||||
|
"2. If any required parameter is missing (shown in MISSING REQUIRED PARAMETERS), "
|
||||||
|
"you MUST ask the user for it and explain WHY you need it based on the parameter description.\n"
|
||||||
|
"3. To use a tool, respond STRICTLY with this JSON format:\n"
|
||||||
|
' { "thought": "explanation", "action": { "name": "tool_name", "args": { "arg": "value" } } }\n'
|
||||||
|
" - No text before or after the JSON\n"
|
||||||
|
" - All args must be complete and non-null\n"
|
||||||
|
"4. You can use MULTIPLE TOOLS IN SEQUENCE:\n"
|
||||||
|
" - After executing a tool, you will receive its result\n"
|
||||||
|
" - You can then decide to use another tool based on the result\n"
|
||||||
|
" - Or provide a final text response to the user\n"
|
||||||
|
" - Continue using tools until you have all the information needed\n"
|
||||||
|
"5. If you respond with text (not using a tool), respond normally in French.\n"
|
||||||
|
"6. When you have all the information needed, provide a final response in NATURAL TEXT (not JSON).\n"
|
||||||
|
"7. Extract the relevant information from the user's request and pass it as tool arguments.\n"
|
||||||
|
"\n"
|
||||||
|
"EXAMPLES:\n"
|
||||||
|
" To set the download folder:\n"
|
||||||
|
' { "thought": "User provided download path", "action": { "name": "set_path", "args": { "path_type": "download_folder", "path_value": "/home/user/downloads" } } }\n'
|
||||||
|
"\n"
|
||||||
|
" To set the TV show folder:\n"
|
||||||
|
' { "thought": "User provided TV show path", "action": { "name": "set_path", "args": { "path_type": "tvshow_folder", "path_value": "/home/user/media/tvshows" } } }\n'
|
||||||
|
"\n"
|
||||||
|
" To list the download folder:\n"
|
||||||
|
' { "thought": "User wants to see downloads", "action": { "name": "list_folder", "args": { "folder_type": "download", "path": "." } } }\n'
|
||||||
|
"\n"
|
||||||
|
" To list a subfolder in TV shows:\n"
|
||||||
|
' { "thought": "User wants to see a specific show", "action": { "name": "list_folder", "args": { "folder_type": "tvshow", "path": "Game.of.Thrones" } } }\n'
|
||||||
|
"\n"
|
||||||
|
"AVAILABLE TOOLS:\n"
|
||||||
|
f"{tools_desc}\n"
|
||||||
|
)
|
||||||
93
agent/registry.py
Normal file
93
agent/registry.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Tool registry and definitions."""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Any, Dict
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from .memory import Memory
|
||||||
|
from .tools.filesystem import set_path_for_folder, list_folder
|
||||||
|
from .tools.api import find_media_imdb_id
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Tool:
|
||||||
|
"""Represents a tool that can be used by the agent."""
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
func: Callable[..., Dict[str, Any]]
|
||||||
|
parameters: Dict[str, Any] # JSON Schema des paramètres
|
||||||
|
|
||||||
|
|
||||||
|
def make_tools(memory: Memory) -> Dict[str, Tool]:
|
||||||
|
"""
|
||||||
|
Create all available tools with memory bound to them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory: Memory instance to be used by the tools
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping tool names to Tool instances
|
||||||
|
"""
|
||||||
|
# Create partial functions with memory pre-bound for filesystem tools
|
||||||
|
set_path_func = partial(set_path_for_folder, memory)
|
||||||
|
list_folder_func = partial(list_folder, memory)
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
Tool(
|
||||||
|
name="set_path_for_folder",
|
||||||
|
description="Sets a path in the configuration (download_folder, tvshow_folder, movie_folder, or torrent_folder).",
|
||||||
|
func=set_path_func,
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"folder_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Name of folder to set",
|
||||||
|
"enum": ["download", "tvshow", "movie", "torrent"]
|
||||||
|
},
|
||||||
|
"path_value": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Absolute path to the folder (e.g., /home/user/downloads)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["folder_name", "path_value"]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
name="list_folder",
|
||||||
|
description="Lists the contents of a specified folder (download, tvshow, movie, or torrent).",
|
||||||
|
func=list_folder_func,
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"folder_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Type of folder to list: 'download', 'tvshow', 'movie', or 'torrent'",
|
||||||
|
"enum": ["download", "tvshow", "movie", "torrent"]
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Relative path within the folder (default: '.' for root)",
|
||||||
|
"default": "."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["folder_type"]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
name="find_media_imdb_id",
|
||||||
|
description="Finds the IMDb ID for a given media title using TMDB API.",
|
||||||
|
func=find_media_imdb_id,
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"media_title": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Title of the media to find the IMDb ID for"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["media_title"]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return {t.name: t for t in tools}
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
# agent/tools.py
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Callable, Any, Dict
|
|
||||||
import os
|
|
||||||
|
|
||||||
from .memory import Memory
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Tool:
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
func: Callable[..., Dict[str, Any]]
|
|
||||||
parameters: Dict[str, Any] # JSON Schema des paramètres
|
|
||||||
|
|
||||||
def make_tools(memory: Memory) -> dict[str, Tool]:
|
|
||||||
def set_project_root(project_root: str) -> Dict[str, Any]:
|
|
||||||
if not os.path.isdir(project_root):
|
|
||||||
return {"error": "invalid_path", "message": f"Le chemin {project_root} n'est pas un dossier valide."}
|
|
||||||
|
|
||||||
print(f"Setting project root to: {project_root}")
|
|
||||||
print("Memory before:", memory.data)
|
|
||||||
memory.set_project_root(project_root)
|
|
||||||
print("Memory after:", memory.data)
|
|
||||||
return {"status": "ok", "project_root": project_root}
|
|
||||||
|
|
||||||
def list_directory(path: str) -> Dict[str, Any]:
|
|
||||||
print("Proper tool used")
|
|
||||||
if not memory.data.get("project_root"):
|
|
||||||
return {"error": "no_project_root", "message": "Project root not set."}
|
|
||||||
|
|
||||||
root = memory.data.get("project_root")
|
|
||||||
full_path = os.path.abspath(os.path.join(root, path))
|
|
||||||
root_abs = os.path.abspath(root)
|
|
||||||
if not full_path.startswith(root_abs):
|
|
||||||
return {"error": "forbidden", "message": "Path outside project_root."}
|
|
||||||
|
|
||||||
try:
|
|
||||||
entries = os.listdir(full_path)
|
|
||||||
return {"path": path, "entries": entries}
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": "os_error", "message": str(e)}
|
|
||||||
|
|
||||||
tools = [
|
|
||||||
Tool(
|
|
||||||
name="set_project_root",
|
|
||||||
description="Enregistre le path du dossier racine du projet.",
|
|
||||||
func=set_project_root,
|
|
||||||
parameters={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"project_root": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Chemin absolu du dossier racine du projet (ex: /home/user/mon_projet)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["project_root"]
|
|
||||||
}
|
|
||||||
),
|
|
||||||
Tool(
|
|
||||||
name="list_directory",
|
|
||||||
description="Liste le contenu d'un dossier relatif au project_root.",
|
|
||||||
func=list_directory,
|
|
||||||
parameters={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Chemin relatif du dossier à lister (ex: 'src' ou '.' pour la racine)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["path"]
|
|
||||||
}
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
return {t.name: t for t in tools}
|
|
||||||
3
agent/tools/__init__.py
Normal file
3
agent/tools/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Tools module - filesystem and API tools."""
|
||||||
|
from .filesystem import FolderName, set_path_for_folder, list_folder
|
||||||
|
from .api import find_media_imdb_id
|
||||||
90
agent/tools/api.py
Normal file
90
agent/tools/api.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""API tools for interacting with external services."""
|
||||||
|
from typing import Dict, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from ..api import tmdb_client, TMDBError, TMDBNotFoundError, TMDBAPIError, TMDBConfigurationError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def find_media_imdb_id(media_title: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Find the IMDb ID for a given media title using TMDB API.
|
||||||
|
|
||||||
|
This is a wrapper around the TMDB client that returns a standardized
|
||||||
|
dict format for compatibility with the agent's tool system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
media_title: Title of the media to search for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with IMDb ID or error information:
|
||||||
|
- Success: {"status": "ok", "imdb_id": str, "title": str, ...}
|
||||||
|
- Error: {"error": str, "message": str}
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> result = find_media_imdb_id("Inception")
|
||||||
|
>>> print(result)
|
||||||
|
{'status': 'ok', 'imdb_id': 'tt1375666', 'title': 'Inception', ...}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Use the TMDB client to search for media
|
||||||
|
result = tmdb_client.search_media(media_title)
|
||||||
|
|
||||||
|
# Check if IMDb ID was found
|
||||||
|
if result.imdb_id:
|
||||||
|
logger.info(f"IMDb ID found for '{media_title}': {result.imdb_id}")
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"imdb_id": result.imdb_id,
|
||||||
|
"title": result.title,
|
||||||
|
"media_type": result.media_type,
|
||||||
|
"tmdb_id": result.tmdb_id,
|
||||||
|
"overview": result.overview,
|
||||||
|
"release_date": result.release_date,
|
||||||
|
"vote_average": result.vote_average
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
logger.warning(f"No IMDb ID available for '{media_title}'")
|
||||||
|
return {
|
||||||
|
"error": "no_imdb_id",
|
||||||
|
"message": f"No IMDb ID available for '{result.title}'",
|
||||||
|
"title": result.title,
|
||||||
|
"media_type": result.media_type,
|
||||||
|
"tmdb_id": result.tmdb_id
|
||||||
|
}
|
||||||
|
|
||||||
|
except TMDBNotFoundError as e:
|
||||||
|
logger.info(f"Media not found: {e}")
|
||||||
|
return {
|
||||||
|
"error": "not_found",
|
||||||
|
"message": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
except TMDBConfigurationError as e:
|
||||||
|
logger.error(f"TMDB configuration error: {e}")
|
||||||
|
return {
|
||||||
|
"error": "configuration_error",
|
||||||
|
"message": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
except TMDBAPIError as e:
|
||||||
|
logger.error(f"TMDB API error: {e}")
|
||||||
|
return {
|
||||||
|
"error": "api_error",
|
||||||
|
"message": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Validation error: {e}")
|
||||||
|
return {
|
||||||
|
"error": "validation_failed",
|
||||||
|
"message": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"error": "internal_error",
|
||||||
|
"message": "An unexpected error occurred"
|
||||||
|
}
|
||||||
257
agent/tools/filesystem.py
Normal file
257
agent/tools/filesystem.py
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
"""Filesystem tools for managing folders and files with security."""
|
||||||
|
from typing import Dict, Any
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from ..memory import Memory
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FolderName(Enum):
|
||||||
|
"""Types of folders that can be managed."""
|
||||||
|
DOWNLOAD = "download"
|
||||||
|
TVSHOW = "tvshow"
|
||||||
|
MOVIE = "movie"
|
||||||
|
TORRENT = "torrent"
|
||||||
|
|
||||||
|
|
||||||
|
class FilesystemError(Exception):
|
||||||
|
"""Base exception for filesystem operations."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PathTraversalError(FilesystemError):
|
||||||
|
"""Raised when path traversal attack is detected."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_folder_name(folder_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Validate folder name against allowed values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_name: Name to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if valid
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If folder name is invalid
|
||||||
|
"""
|
||||||
|
valid_names = [fn.value for fn in FolderName]
|
||||||
|
if folder_name not in valid_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid folder_name '{folder_name}'. Must be one of: {', '.join(valid_names)}"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_path(path: str) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize path to prevent path traversal attacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Path to sanitize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized path
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PathTraversalError: If path contains dangerous patterns
|
||||||
|
"""
|
||||||
|
# Normalize path
|
||||||
|
normalized = os.path.normpath(path)
|
||||||
|
|
||||||
|
# Check for absolute paths
|
||||||
|
if os.path.isabs(normalized):
|
||||||
|
raise PathTraversalError("Absolute paths are not allowed")
|
||||||
|
|
||||||
|
# Check for parent directory references
|
||||||
|
if normalized.startswith("..") or "/.." in normalized or "\\.." in normalized:
|
||||||
|
raise PathTraversalError("Parent directory references are not allowed")
|
||||||
|
|
||||||
|
# Check for null bytes
|
||||||
|
if "\x00" in normalized:
|
||||||
|
raise PathTraversalError("Null bytes in path are not allowed")
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _is_safe_path(base_path: Path, target_path: Path) -> bool:
|
||||||
|
"""
|
||||||
|
Check if target path is within base path (prevents path traversal).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_path: Base directory path
|
||||||
|
target_path: Target path to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if safe, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Resolve both paths to absolute paths
|
||||||
|
base_resolved = base_path.resolve()
|
||||||
|
target_resolved = target_path.resolve()
|
||||||
|
|
||||||
|
# Check if target is relative to base
|
||||||
|
target_resolved.relative_to(base_resolved)
|
||||||
|
return True
|
||||||
|
except (ValueError, OSError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def set_path_for_folder(memory: Memory, folder_name: str, path_value: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Set a path in the config with validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory: Memory instance to store the configuration
|
||||||
|
folder_name: Name of folder to set (download, tvshow, movie, torrent)
|
||||||
|
path_value: Absolute path to the folder
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with status or error information
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate folder name
|
||||||
|
_validate_folder_name(folder_name)
|
||||||
|
|
||||||
|
# Convert to Path object for better handling
|
||||||
|
path_obj = Path(path_value).resolve()
|
||||||
|
|
||||||
|
# Validate path exists and is a directory
|
||||||
|
if not path_obj.exists():
|
||||||
|
logger.warning(f"Path does not exist: {path_value}")
|
||||||
|
return {
|
||||||
|
"error": "invalid_path",
|
||||||
|
"message": f"Path does not exist: {path_value}"
|
||||||
|
}
|
||||||
|
|
||||||
|
if not path_obj.is_dir():
|
||||||
|
logger.warning(f"Path is not a directory: {path_value}")
|
||||||
|
return {
|
||||||
|
"error": "invalid_path",
|
||||||
|
"message": f"Path is not a directory: {path_value}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if path is readable
|
||||||
|
if not os.access(path_obj, os.R_OK):
|
||||||
|
logger.warning(f"Path is not readable: {path_value}")
|
||||||
|
return {
|
||||||
|
"error": "permission_denied",
|
||||||
|
"message": f"Path is not readable: {path_value}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Store in memory
|
||||||
|
config = memory.get("config", {})
|
||||||
|
config[f"{folder_name}_folder"] = str(path_obj)
|
||||||
|
memory.set("config", config)
|
||||||
|
|
||||||
|
logger.info(f"Set {folder_name}_folder to: {path_obj}")
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"folder_name": folder_name,
|
||||||
|
"path": str(path_obj)
|
||||||
|
}
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Validation error: {e}")
|
||||||
|
return {"error": "validation_failed", "message": str(e)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error setting path: {e}", exc_info=True)
|
||||||
|
return {"error": "internal_error", "message": "Failed to set path"}
|
||||||
|
|
||||||
|
|
||||||
|
def list_folder(memory: Memory, folder_type: str, path: str = ".") -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
List contents of a folder with security checks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory: Memory instance to retrieve the configuration
|
||||||
|
folder_type: Type of folder to list (download, tvshow, movie, torrent)
|
||||||
|
path: Relative path within the folder (default: ".")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with folder contents or error information
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate folder type
|
||||||
|
_validate_folder_name(folder_type)
|
||||||
|
|
||||||
|
# Sanitize the path
|
||||||
|
safe_path = _sanitize_path(path)
|
||||||
|
|
||||||
|
# Get root folder from config
|
||||||
|
folder_key = f"{folder_type}_folder"
|
||||||
|
config = memory.get("config", {})
|
||||||
|
|
||||||
|
if folder_key not in config or not config[folder_key]:
|
||||||
|
logger.warning(f"Folder not configured: {folder_type}")
|
||||||
|
return {
|
||||||
|
"error": "folder_not_set",
|
||||||
|
"message": f"{folder_type.capitalize()} folder not set in config."
|
||||||
|
}
|
||||||
|
|
||||||
|
root = Path(config[folder_key])
|
||||||
|
target = root / safe_path
|
||||||
|
|
||||||
|
# Security check: ensure target is within root
|
||||||
|
if not _is_safe_path(root, target):
|
||||||
|
logger.warning(f"Path traversal attempt detected: {path}")
|
||||||
|
return {
|
||||||
|
"error": "forbidden",
|
||||||
|
"message": "Access denied: path outside allowed directory"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if target exists
|
||||||
|
if not target.exists():
|
||||||
|
logger.warning(f"Path does not exist: {target}")
|
||||||
|
return {
|
||||||
|
"error": "not_found",
|
||||||
|
"message": f"Path does not exist: {safe_path}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if target is a directory
|
||||||
|
if not target.is_dir():
|
||||||
|
logger.warning(f"Path is not a directory: {target}")
|
||||||
|
return {
|
||||||
|
"error": "not_a_directory",
|
||||||
|
"message": f"Path is not a directory: {safe_path}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# List directory contents
|
||||||
|
try:
|
||||||
|
entries = [entry.name for entry in target.iterdir()]
|
||||||
|
logger.debug(f"Listed {len(entries)} entries in {target}")
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"folder_type": folder_type,
|
||||||
|
"path": safe_path,
|
||||||
|
"entries": sorted(entries),
|
||||||
|
"count": len(entries)
|
||||||
|
}
|
||||||
|
except PermissionError:
|
||||||
|
logger.warning(f"Permission denied accessing: {target}")
|
||||||
|
return {
|
||||||
|
"error": "permission_denied",
|
||||||
|
"message": f"Permission denied accessing: {safe_path}"
|
||||||
|
}
|
||||||
|
|
||||||
|
except PathTraversalError as e:
|
||||||
|
logger.warning(f"Path traversal attempt: {e}")
|
||||||
|
return {
|
||||||
|
"error": "forbidden",
|
||||||
|
"message": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Validation error: {e}")
|
||||||
|
return {"error": "validation_failed", "message": str(e)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error listing folder: {e}", exc_info=True)
|
||||||
|
return {"error": "internal_error", "message": "Failed to list folder"}
|
||||||
12
app.py
12
app.py
@@ -7,10 +7,9 @@ from typing import Any, Dict
|
|||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from agent.llm import DeepSeekClient
|
from agent.llm.deepseek import DeepSeekClient
|
||||||
from agent.memory import Memory
|
from agent.memory import Memory
|
||||||
from agent.agent import Agent
|
from agent.agent import Agent
|
||||||
from agent.commands import CommandRegistry
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="LibreChat Agent Backend",
|
title="LibreChat Agent Backend",
|
||||||
@@ -20,7 +19,6 @@ app = FastAPI(
|
|||||||
llm = DeepSeekClient()
|
llm = DeepSeekClient()
|
||||||
memory = Memory()
|
memory = Memory()
|
||||||
agent = Agent(llm=llm, memory=memory)
|
agent = Agent(llm=llm, memory=memory)
|
||||||
commands_registry = CommandRegistry(memory=memory)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_last_user_content(messages: list[Dict[str, Any]]) -> str:
|
def extract_last_user_content(messages: list[Dict[str, Any]]) -> str:
|
||||||
@@ -42,12 +40,8 @@ async def chat_completions(request: Request):
|
|||||||
user_input = extract_last_user_content(messages)
|
user_input = extract_last_user_content(messages)
|
||||||
print("Received chat completion request, stream =", stream, "input:", user_input)
|
print("Received chat completion request, stream =", stream, "input:", user_input)
|
||||||
|
|
||||||
# 🔹 1) Si c'est une commande, on ne fait PAS intervenir le LLM
|
# Process user input through the agent
|
||||||
if user_input.strip().startswith("/"):
|
answer = agent.step(user_input)
|
||||||
answer = commands_registry.handle(user_input)
|
|
||||||
else:
|
|
||||||
# 🔹 2) Sinon, logique agent + LLM comme avant
|
|
||||||
answer = agent.step(user_input)
|
|
||||||
|
|
||||||
# Ensuite = même logique de réponse (non-stream ou stream)
|
# Ensuite = même logique de réponse (non-stream ou stream)
|
||||||
created_ts = int(time.time())
|
created_ts = int(time.time())
|
||||||
|
|||||||
Reference in New Issue
Block a user