diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..231a69c --- /dev/null +++ b/.env.example @@ -0,0 +1,16 @@ +# DeepSeek LLM Configuration +DEEPSEEK_API_KEY=your_deepseek_api_key_here +DEEPSEEK_BASE_URL=https://api.deepseek.com +DEEPSEEK_MODEL=deepseek-chat +TEMPERATURE=0.2 + +# TMDB API Configuration +TMDB_API_KEY=your_tmdb_api_key_here +TMDB_BASE_URL=https://api.themoviedb.org/3 + +# Storage Configuration +MEMORY_FILE=memory.json + +# Security Configuration +MAX_TOOL_ITERATIONS=5 +REQUEST_TIMEOUT=30 diff --git a/.gitignore b/.gitignore index e69de29..6d3ac84 100644 --- a/.gitignore +++ b/.gitignore @@ -0,0 +1,43 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +venv/ +ENV/ +env/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Qodo +./qodo + +# Memory and state files +memory.json + +# OS +.DS_Store +Thumbs.db diff --git a/agent/agent.py b/agent/agent.py index b627872..e529dfe 100644 --- a/agent/agent.py +++ b/agent/agent.py @@ -4,39 +4,17 @@ import json from .llm import DeepSeekClient from .memory import Memory -from .tools import make_tools, Tool +from .registry import make_tools, Tool +from .prompts import PromptBuilder class Agent: - def __init__(self, llm: DeepSeekClient, memory: Memory): + def __init__(self, llm: DeepSeekClient, memory: Memory, max_tool_iterations: int = 5): self.llm = llm self.memory = memory self.tools: Dict[str, Tool] = make_tools(memory) + self.prompt_builder = PromptBuilder(self.tools) + self.max_tool_iterations = max_tool_iterations - def _build_system_prompt(self) -> str: - ctx = {"project_root": self.memory.get_project_root()} - tools_desc = "\n".join( - f"- {t.name}: {t.description}" for t in self.tools.values() - ) - return ( - "Tu es un agent IA qui aide un développeur senior à gérer son projet local.\n" - "Tu peux demander des informations de base (comme le chemin du projet)\n" - "et tu peux utiliser des outils pour interagir avec le système de fichiers.\n\n" - "Contexte utilisateur (JSON):\n" - f"{json.dumps(ctx, ensure_ascii=False)}\n\n" - "Règles IMPORTANTES pour les outils:\n" - "1. Si tu ne connais pas la valeur d'un argument (par exemple project_root), " - "TU NE DOIS PAS mettre null ou une valeur inventée.\n" - " À la place, tu dois poser une question à l'utilisateur pour obtenir l'information.\n" - "2. Ne propose set_user_profile QUE lorsque l'utilisateur a donné un chemin de projet explicite.\n" - "3. Quand tu veux utiliser un outil, réponds STRICTEMENT avec un JSON de la forme :\n" - '{ "thought": "...", "action": { "name": "tool_name", "args": { ... } } }\n' - " - Pas de texte avant/après.\n" - " - Les args doivent être COMPLETS et NON nuls.\n" - "4. Quand JE (le système) te fournis un object JSON contenant 'tool_result' et 'intent', " - "tu dois ALORS répondre à l'utilisateur en TEXTE NATUREL, et NE PAS renvoyer de JSON d'action.\n\n" - "Tools disponibles:\n" - f"{tools_desc}\n" - ) def _parse_intent(self, text: str) -> Dict[str, Any] | None: try: @@ -62,18 +40,6 @@ class Agent: name: str = action["name"] args: Dict[str, Any] = action.get("args", {}) or {} - if name == "set_user_profile": - project_root = args.get("project_root") - if not project_root: - return { - "error": "missing_project_root", - "message": ( - "Le modèle a demandé set_user_profile sans project_root. " - "Tu dois d'abord demander à l'utilisateur de fournir un chemin " - "de projet valide (ex: /home/francois/mon_projet)." - ), - } - tool = self.tools.get(name) if not tool: return {"error": "unknown_tool", "tool": name} @@ -87,95 +53,77 @@ class Agent: return result def step(self, user_input: str) -> str: + """ + Execute one agent step with iterative tool execution: + - Build system prompt + - Query LLM + - Loop: If JSON intent -> execute tool, add result to conversation, query LLM again + - Continue until LLM responds with text (no tool call) or max iterations reached + - Return final text response + """ print("Starting a new step...") - """ - Un 'tour' d'agent : - - construit le prompt system - - interroge DeepSeek - - si JSON d'intent -> exécute tool, refait un appel, renvoie réponse finale - - sinon -> renvoie texte brut - """ print("User input:", user_input) - root = self.memory.data.get("project_root") - print("Current project_root in memory:", root) - # Unified system prompt that always allows tools - tools_desc = "\n".join( - f"- {t.name}: {t.description}\n Paramètres: {json.dumps(t.parameters, ensure_ascii=False)}" - for t in self.tools.values() - ) + print("Current memory state:", self.memory.data) - if root is None: - print("No project_root set - asking user and allowing tool use") - system_prompt = ( - "Tu es un agent IA qui aide un développeur à gérer son projet local.\n\n" - "CONTEXTE ACTUEL:\n" - f"- project_root: {root} (NON DÉFINI)\n\n" - "RÈGLES IMPORTANTES:\n" - "1. Le project_root n'est pas encore défini. Tu DOIS d'abord demander à l'utilisateur " - "le chemin absolu de son projet (ex: /home/user/mon_projet).\n" - "2. Quand l'utilisateur te donne un chemin, tu DOIS immédiatement utiliser l'outil " - "'set_project_root' pour le sauvegarder.\n" - "3. Pour utiliser un outil, réponds STRICTEMENT avec ce format JSON:\n" - ' { "thought": "explication", "action": { "name": "nom_outil", "args": { "arg": "valeur" } } }\n' - "4. Si tu réponds en texte (pas d'outil), réponds normalement en français.\n" - "5. Quand le système te renvoie un tool_result, réponds à l'utilisateur en TEXTE NATUREL.\n\n" - "OUTILS DISPONIBLES:\n" - f"{tools_desc}\n" - ) - else: - print("Project_root is set - normal operation mode") - system_prompt = ( - "Tu es un agent IA qui aide un développeur à gérer son projet local.\n\n" - "CONTEXTE ACTUEL:\n" - f"- project_root: {root}\n\n" - "RÈGLES IMPORTANTES:\n" - "1. Le project_root est défini. Tu peux utiliser les outils disponibles.\n" - "2. Pour utiliser un outil, réponds STRICTEMENT avec ce format JSON:\n" - ' { "thought": "explication", "action": { "name": "nom_outil", "args": { "param": "valeur" } } }\n' - " EXEMPLE pour lister le dossier 'src':\n" - ' { "thought": "L\'utilisateur veut voir le contenu de src", "action": { "name": "list_directory", "args": { "path": "src" } } }\n' - " EXEMPLE pour lister la racine du projet:\n" - ' { "thought": "L\'utilisateur veut voir la racine", "action": { "name": "list_directory", "args": { "path": "." } } }\n' - "3. Si tu réponds en texte (pas d'outil), réponds normalement en français.\n" - "4. Quand le système te renvoie un tool_result, réponds à l'utilisateur en TEXTE NATUREL.\n" - "5. IMPORTANT: Extrais le chemin demandé par l'utilisateur et passe-le comme argument 'path'.\n\n" - "OUTILS DISPONIBLES:\n" - f"{tools_desc}\n" - ) + # Build system prompt using PromptBuilder + system_prompt = self.prompt_builder.build_system_prompt(self.memory.data) + # Initialize conversation with user input messages: List[Dict[str, Any]] = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_input}, ] - raw_first = self.llm.complete(messages) - intent = self._parse_intent(raw_first) - print("raw_first:", raw_first) - print("Intent:", intent) - if not intent: - # Réponse texte simple - #self.memory.append_history("user", user_input) - #self.memory.append_history("assistant", raw_first) - return raw_first + # Tool execution loop + iteration = 0 + while iteration < self.max_tool_iterations: + print(f"\n--- Iteration {iteration + 1} ---") - # Exécuter l'action - tool_result = self._execute_action(intent) + # Get LLM response + llm_response = self.llm.complete(messages) + print("LLM response:", llm_response) - followup_messages: List[Dict[str, Any]] = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_input}, - { + # Try to parse as tool intent + intent = self._parse_intent(llm_response) + + if not intent: + # No tool call - this is the final text response + print("No tool intent detected, returning final response") + # Save to history + self.memory.append_history("user", user_input) + self.memory.append_history("assistant", llm_response) + return llm_response + + # Tool call detected - execute it + print("Intent detected:", intent) + tool_result = self._execute_action(intent) + print("Tool result:", tool_result) + + # Add assistant's tool call and result to conversation + messages.append({ "role": "assistant", + "content": json.dumps(intent, ensure_ascii=False) + }) + messages.append({ + "role": "user", "content": json.dumps( - {"tool_result": tool_result, "intent": intent}, - ensure_ascii=False, - ), - }, - ] + {"tool_result": tool_result}, + ensure_ascii=False + ) + }) - raw_second = self.llm.complete(followup_messages) + iteration += 1 - #self.memory.append_history("user", user_input) - #self.memory.append_history("assistant", raw_second) - return raw_second + # Max iterations reached - ask LLM for final response + print(f"\n--- Max iterations ({self.max_tool_iterations}) reached, requesting final response ---") + messages.append({ + "role": "user", + "content": "Merci pour ces résultats. Peux-tu maintenant me donner une réponse finale en texte naturel ?" + }) + + final_response = self.llm.complete(messages) + # Save to history + self.memory.append_history("user", user_input) + self.memory.append_history("assistant", final_response) + return final_response diff --git a/agent/api/__init__.py b/agent/api/__init__.py new file mode 100644 index 0000000..a6f7e16 --- /dev/null +++ b/agent/api/__init__.py @@ -0,0 +1,20 @@ +"""API clients module.""" +from .themoviedb import ( + TMDBClient, + tmdb_client, + TMDBError, + TMDBConfigurationError, + TMDBAPIError, + TMDBNotFoundError, + MediaResult +) + +__all__ = [ + 'TMDBClient', + 'tmdb_client', + 'TMDBError', + 'TMDBConfigurationError', + 'TMDBAPIError', + 'TMDBNotFoundError', + 'MediaResult' +] diff --git a/agent/api/themoviedb.py b/agent/api/themoviedb.py new file mode 100644 index 0000000..afa3266 --- /dev/null +++ b/agent/api/themoviedb.py @@ -0,0 +1,317 @@ +"""TMDB (The Movie Database) API client.""" +from typing import Dict, Any, Optional, List +from dataclasses import dataclass +import logging +import requests +from requests.exceptions import RequestException, Timeout, HTTPError + +from ..config import Settings, settings + +logger = logging.getLogger(__name__) + + +class TMDBError(Exception): + """Base exception for TMDB-related errors.""" + pass + + +class TMDBConfigurationError(TMDBError): + """Raised when TMDB API is not properly configured.""" + pass + + +class TMDBAPIError(TMDBError): + """Raised when TMDB API returns an error.""" + pass + + +class TMDBNotFoundError(TMDBError): + """Raised when media is not found.""" + pass + + +@dataclass +class MediaResult: + """Represents a media search result from TMDB.""" + tmdb_id: int + title: str + media_type: str # 'movie' or 'tv' + imdb_id: Optional[str] = None + overview: Optional[str] = None + release_date: Optional[str] = None + poster_path: Optional[str] = None + vote_average: Optional[float] = None + + +class TMDBClient: + """ + Client for interacting with The Movie Database (TMDB) API. + + This client provides methods to search for movies and TV shows, + retrieve their details, and get external IDs (like IMDb). + + Example: + >>> client = TMDBClient() + >>> result = client.search_media("Inception") + >>> print(result.imdb_id) + 'tt1375666' + """ + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + timeout: Optional[int] = None, + config: Optional[Settings] = None + ): + """ + Initialize TMDB client. + + Args: + api_key: TMDB API key (defaults to settings) + base_url: TMDB API base URL (defaults to settings) + timeout: Request timeout in seconds (defaults to settings) + config: Optional Settings instance (for testing) + + Raises: + TMDBConfigurationError: If API key is missing + """ + cfg = config or settings + + self.api_key = api_key or cfg.tmdb_api_key + self.base_url = base_url or cfg.tmdb_base_url + self.timeout = timeout or cfg.request_timeout + + if not self.api_key: + raise TMDBConfigurationError( + "TMDB API key is required. Set TMDB_API_KEY environment variable." + ) + + if not self.base_url: + raise TMDBConfigurationError( + "TMDB base URL is required. Set TMDB_BASE_URL environment variable." + ) + + logger.info("TMDB client initialized") + + def _make_request( + self, + endpoint: str, + params: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Make a request to TMDB API. + + Args: + endpoint: API endpoint (e.g., '/search/multi') + params: Query parameters + + Returns: + JSON response as dict + + Raises: + TMDBAPIError: If request fails + """ + url = f"{self.base_url}{endpoint}" + + # Add API key to params + request_params = params or {} + request_params['api_key'] = self.api_key + + try: + logger.debug(f"TMDB request: {endpoint}") + response = requests.get(url, params=request_params, timeout=self.timeout) + response.raise_for_status() + return response.json() + + except Timeout as e: + logger.error(f"TMDB API timeout: {e}") + raise TMDBAPIError(f"Request timeout after {self.timeout} seconds") from e + + except HTTPError as e: + logger.error(f"TMDB API HTTP error: {e}") + if e.response is not None: + status_code = e.response.status_code + if status_code == 401: + raise TMDBAPIError("Invalid TMDB API key") from e + elif status_code == 404: + raise TMDBNotFoundError("Resource not found") from e + else: + raise TMDBAPIError(f"HTTP {status_code}: {e}") from e + raise TMDBAPIError(f"HTTP error: {e}") from e + + except RequestException as e: + logger.error(f"TMDB API request failed: {e}") + raise TMDBAPIError(f"Failed to connect to TMDB API: {e}") from e + + def search_multi(self, query: str) -> List[Dict[str, Any]]: + """ + Search for movies and TV shows. + + Args: + query: Search query (movie or TV show title) + + Returns: + List of search results + + Raises: + TMDBAPIError: If request fails + TMDBNotFoundError: If no results found + """ + if not query or not isinstance(query, str): + raise ValueError("Query must be a non-empty string") + + if len(query) > 500: + raise ValueError("Query is too long (max 500 characters)") + + data = self._make_request('/search/multi', {'query': query}) + + results = data.get('results', []) + if not results: + raise TMDBNotFoundError(f"No results found for '{query}'") + + logger.info(f"Found {len(results)} results for '{query}'") + return results + + def get_external_ids(self, media_type: str, tmdb_id: int) -> Dict[str, Any]: + """ + Get external IDs (IMDb, TVDB, etc.) for a media item. + + Args: + media_type: Type of media ('movie' or 'tv') + tmdb_id: TMDB ID of the media + + Returns: + Dict with external IDs + + Raises: + TMDBAPIError: If request fails + """ + if media_type not in ('movie', 'tv'): + raise ValueError(f"Invalid media_type: {media_type}. Must be 'movie' or 'tv'") + + endpoint = f"/{media_type}/{tmdb_id}/external_ids" + return self._make_request(endpoint) + + def search_media(self, title: str) -> MediaResult: + """ + Search for a media item and return detailed information including IMDb ID. + + This is a convenience method that combines search and external ID lookup. + + Args: + title: Title of the movie or TV show + + Returns: + MediaResult with all available information + + Raises: + TMDBAPIError: If request fails + TMDBNotFoundError: If media not found + """ + # Search for media + results = self.search_multi(title) + + # Get the first (most relevant) result + top_result = results[0] + + # Validate result structure + if 'id' not in top_result or 'media_type' not in top_result: + raise TMDBAPIError("Invalid TMDB response structure") + + tmdb_id = top_result['id'] + media_type = top_result['media_type'] + + # Skip if not movie or TV show + if media_type not in ('movie', 'tv'): + logger.warning(f"Skipping result of type: {media_type}") + if len(results) > 1: + # Try next result + return self._parse_result(results[1]) + raise TMDBNotFoundError(f"No movie or TV show found for '{title}'") + + return self._parse_result(top_result) + + def _parse_result(self, result: Dict[str, Any]) -> MediaResult: + """ + Parse a TMDB result into a MediaResult object. + + Args: + result: Raw TMDB result dict + + Returns: + MediaResult object + """ + tmdb_id = result['id'] + media_type = result['media_type'] + title = result.get('title') or result.get('name', 'Unknown') + + # Get external IDs (including IMDb) + try: + external_ids = self.get_external_ids(media_type, tmdb_id) + imdb_id = external_ids.get('imdb_id') + except TMDBAPIError as e: + logger.warning(f"Failed to get external IDs: {e}") + imdb_id = None + + # Extract other useful information + overview = result.get('overview') + release_date = result.get('release_date') or result.get('first_air_date') + poster_path = result.get('poster_path') + vote_average = result.get('vote_average') + + logger.info(f"Found: {title} (Type: {media_type}, TMDB ID: {tmdb_id}, IMDb: {imdb_id})") + + return MediaResult( + tmdb_id=tmdb_id, + title=title, + media_type=media_type, + imdb_id=imdb_id, + overview=overview, + release_date=release_date, + poster_path=poster_path, + vote_average=vote_average + ) + + def get_movie_details(self, movie_id: int) -> Dict[str, Any]: + """ + Get detailed information about a movie. + + Args: + movie_id: TMDB movie ID + + Returns: + Dict with movie details + + Raises: + TMDBAPIError: If request fails + """ + return self._make_request(f'/movie/{movie_id}') + + def get_tv_details(self, tv_id: int) -> Dict[str, Any]: + """ + Get detailed information about a TV show. + + Args: + tv_id: TMDB TV show ID + + Returns: + Dict with TV show details + + Raises: + TMDBAPIError: If request fails + """ + return self._make_request(f'/tv/{tv_id}') + + def is_configured(self) -> bool: + """ + Check if TMDB client is properly configured. + + Returns: + True if configured, False otherwise + """ + return bool(self.api_key and self.base_url) + + +# Global TMDB client instance (singleton) +tmdb_client = TMDBClient() diff --git a/agent/commands.py b/agent/commands.py deleted file mode 100644 index 633d82a..0000000 --- a/agent/commands.py +++ /dev/null @@ -1,109 +0,0 @@ -# agent/commands.py -from dataclasses import dataclass -from typing import Callable, Dict, List -import os - -from .memory import Memory - -@dataclass -class Command: - name: str - description: str - needs_project_root: bool - handler: Callable[[List[str], Memory], str] - - -class CommandRegistry: - def __init__(self, memory: Memory): - self.memory = memory - self.commands: Dict[str, Command] = {} - self._register_defaults() - - def _register_defaults(self): - # /setroot - def cmd_setroot(args: List[str], mem: Memory) -> str: - if not args: - return "Usage: `/setroot `" - - path = args[0] - if not os.path.isdir(path): - return f"Le chemin `{path}` n'existe pas ou n'est pas un dossier." - - mem.set_project_root(path) - return f"✅ Chemin du projet défini sur `{path}`." - - self.register( - Command( - name="setroot", - description="Définit le chemin racine du projet.", - needs_project_root=False, - handler=cmd_setroot, - ) - ) - - # /scan [rel_path] - def cmd_scan(args: List[str], mem: Memory) -> str: - root = mem.get_project_root() - rel = args[0] if args else "." - full = os.path.abspath(os.path.join(root, rel)) - - # sécurité basique - if not full.startswith(os.path.abspath(root)): - return "❌ Tu ne peux pas sortir du project_root." - - if not os.path.isdir(full): - return f"❌ `{rel}` n'est pas un dossier dans le projet." - - entries = sorted(os.listdir(full)) - if not entries: - return f"📁 `{rel}` est vide." - - lines = [f"📁 Scan de `{rel}` (dans `{root}`):"] - for e in entries: - p = os.path.join(full, e) - if os.path.isdir(p): - lines.append(f" 📂 {e}/") - else: - lines.append(f" 📄 {e}") - return "\n".join(lines) - - self.register( - Command( - name="scan", - description="Liste les fichiers/dossiers du projet.", - needs_project_root=True, - handler=cmd_scan, - ) - ) - - def register(self, cmd: Command): - self.commands[cmd.name] = cmd - - def handle(self, raw_input: str) -> str: - """ - Parse et exécute une commande du type `/scan src` ou `/setroot /chemin`. - """ - text = raw_input.strip() - if not text.startswith("/"): - return "Internal error: not a command." - - parts = text[1:].split() - if not parts: - return "Commande vide." - - name = parts[0] - args = parts[1:] - - cmd = self.commands.get(name) - if not cmd: - return f"Commande inconnue: `/{name}`.\nCommandes disponibles: {', '.join('/'+n for n in self.commands.keys())}" - - # dépendance project_root - if cmd.needs_project_root and not self.memory.get_project_root(): - return ( - "❗ Aucun `project_root` défini pour l'instant.\n" - "Commence par le définir avec:\n" - "`/setroot /chemin/vers/ton/projet`" - ) - - return cmd.handler(args, self.memory) diff --git a/agent/config.py b/agent/config.py index 3c4804f..69c137d 100644 --- a/agent/config.py +++ b/agent/config.py @@ -1,15 +1,78 @@ -from dataclasses import dataclass +"""Configuration management with validation.""" +from dataclasses import dataclass, field import os +from pathlib import Path +from typing import Optional from dotenv import load_dotenv +# Load environment variables from .env file load_dotenv() + +class ConfigurationError(Exception): + """Raised when configuration is invalid.""" + pass + + @dataclass class Settings: - deepseek_api_key: str = os.getenv("DEEPSEEK_API_KEY", "") - deepseek_base_url: str = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com") - model: str = os.getenv("DEEPSEEK_MODEL", "deepseek-chat") - temperature: float = float(os.getenv("TEMPERATURE", "0.2")) - memory_file: str = os.getenv("MEMORY_FILE", "memory.json") + """Application settings loaded from environment variables.""" + # LLM Configuration + deepseek_api_key: str = field(default_factory=lambda: os.getenv("DEEPSEEK_API_KEY", "")) + deepseek_base_url: str = field(default_factory=lambda: os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")) + model: str = field(default_factory=lambda: os.getenv("DEEPSEEK_MODEL", "deepseek-chat")) + temperature: float = field(default_factory=lambda: float(os.getenv("TEMPERATURE", "0.2"))) + + # TMDB Configuration + tmdb_api_key: str = field(default_factory=lambda: os.getenv("TMDB_API_KEY", "")) + tmdb_base_url: str = field(default_factory=lambda: os.getenv("TMDB_BASE_URL", "https://api.themoviedb.org/3")) + + # Storage Configuration + memory_file: str = field(default_factory=lambda: os.getenv("MEMORY_FILE", "memory.json")) + + # Security Configuration + max_tool_iterations: int = field(default_factory=lambda: int(os.getenv("MAX_TOOL_ITERATIONS", "5"))) + request_timeout: int = field(default_factory=lambda: int(os.getenv("REQUEST_TIMEOUT", "30"))) + + def __post_init__(self): + """Validate settings after initialization.""" + self._validate() + + def _validate(self) -> None: + """Validate configuration values.""" + # Validate temperature + if not 0.0 <= self.temperature <= 2.0: + raise ConfigurationError(f"Temperature must be between 0.0 and 2.0, got {self.temperature}") + + # Validate max_tool_iterations + if self.max_tool_iterations < 1 or self.max_tool_iterations > 20: + raise ConfigurationError(f"max_tool_iterations must be between 1 and 20, got {self.max_tool_iterations}") + + # Validate request_timeout + if self.request_timeout < 1 or self.request_timeout > 300: + raise ConfigurationError(f"request_timeout must be between 1 and 300 seconds, got {self.request_timeout}") + + # Validate URLs + if not self.deepseek_base_url.startswith(("http://", "https://")): + raise ConfigurationError(f"Invalid deepseek_base_url: {self.deepseek_base_url}") + + if not self.tmdb_base_url.startswith(("http://", "https://")): + raise ConfigurationError(f"Invalid tmdb_base_url: {self.tmdb_base_url}") + + # Validate memory file path + memory_path = Path(self.memory_file) + if memory_path.exists() and not memory_path.is_file(): + raise ConfigurationError(f"memory_file exists but is not a file: {self.memory_file}") + + def is_deepseek_configured(self) -> bool: + """Check if DeepSeek API is properly configured.""" + return bool(self.deepseek_api_key and self.deepseek_base_url) + + def is_tmdb_configured(self) -> bool: + """Check if TMDB API is properly configured.""" + return bool(self.tmdb_api_key and self.tmdb_base_url) + + +# Global settings instance settings = Settings() diff --git a/agent/llm.py b/agent/llm.py deleted file mode 100644 index 7767740..0000000 --- a/agent/llm.py +++ /dev/null @@ -1,41 +0,0 @@ -# agent/llm.py -from typing import List, Dict, Any -import requests - -from .config import settings - -class DeepSeekClient: - def __init__( - self, - api_key: str | None = None, - base_url: str | None = None, - model: str | None = None, - ): - self.api_key = api_key or settings.deepseek_api_key - self.base_url = base_url or settings.deepseek_base_url - self.model = model or settings.model - - def complete(self, messages: List[Dict[str, Any]]) -> str: - """ - messages: liste de dicts {role: 'system'|'user'|'assistant', content: str} - Retourne content (str) du premier choix. - """ - if not self.api_key: - return "Erreur côté agent : DEEPSEEK_API_KEY manquant dans l'environnement backend." - - url = f"{self.base_url}/v1/chat/completions" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - payload = { - "model": self.model, - "messages": messages, - "temperature": settings.temperature, - } - - resp = requests.post(url, headers=headers, json=payload, timeout=60) - resp.raise_for_status() - data = resp.json() - print("DeepSeek response:", data) - return data["choices"][0]["message"]["content"] diff --git a/agent/llm/__init__.py b/agent/llm/__init__.py new file mode 100644 index 0000000..ada85e1 --- /dev/null +++ b/agent/llm/__init__.py @@ -0,0 +1,2 @@ +"""LLM client module.""" +from .deepseek import DeepSeekClient diff --git a/agent/llm/deepseek.py b/agent/llm/deepseek.py new file mode 100644 index 0000000..bc0c375 --- /dev/null +++ b/agent/llm/deepseek.py @@ -0,0 +1,150 @@ +"""DeepSeek LLM client with robust error handling.""" +from typing import List, Dict, Any, Optional +import logging +import requests +from requests.exceptions import RequestException, Timeout, HTTPError + +from ..config import settings + +logger = logging.getLogger(__name__) + + +class LLMError(Exception): + """Base exception for LLM-related errors.""" + pass + + +class LLMConfigurationError(LLMError): + """Raised when LLM is not properly configured.""" + pass + + +class LLMAPIError(LLMError): + """Raised when LLM API returns an error.""" + pass + + +class DeepSeekClient: + """Client for interacting with DeepSeek API.""" + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model: Optional[str] = None, + timeout: Optional[int] = None, + ): + """ + Initialize DeepSeek client. + + Args: + api_key: API key for authentication (defaults to settings) + base_url: Base URL for API (defaults to settings) + model: Model name to use (defaults to settings) + timeout: Request timeout in seconds (defaults to settings) + + Raises: + LLMConfigurationError: If API key is missing + """ + self.api_key = api_key or settings.deepseek_api_key + self.base_url = base_url or settings.deepseek_base_url + self.model = model or settings.model + self.timeout = timeout or settings.request_timeout + + if not self.api_key: + raise LLMConfigurationError( + "DeepSeek API key is required. Set DEEPSEEK_API_KEY environment variable." + ) + + if not self.base_url: + raise LLMConfigurationError( + "DeepSeek base URL is required. Set DEEPSEEK_BASE_URL environment variable." + ) + + logger.info(f"DeepSeek client initialized with model: {self.model}") + + def complete(self, messages: List[Dict[str, Any]]) -> str: + """ + Generate a completion from the LLM. + + Args: + messages: List of message dicts with 'role' and 'content' keys + + Returns: + Generated text response + + Raises: + LLMAPIError: If API request fails + ValueError: If messages format is invalid + """ + # Validate messages format + if not messages: + raise ValueError("Messages list cannot be empty") + + for msg in messages: + if not isinstance(msg, dict): + raise ValueError(f"Each message must be a dict, got {type(msg)}") + if "role" not in msg or "content" not in msg: + raise ValueError(f"Each message must have 'role' and 'content' keys, got {msg.keys()}") + if msg["role"] not in ("system", "user", "assistant"): + raise ValueError(f"Invalid role: {msg['role']}") + + url = f"{self.base_url}/v1/chat/completions" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + payload = { + "model": self.model, + "messages": messages, + "temperature": settings.temperature, + } + + try: + logger.debug(f"Sending request to {url} with {len(messages)} messages") + response = requests.post( + url, + headers=headers, + json=payload, + timeout=self.timeout + ) + response.raise_for_status() + data = response.json() + + # Validate response structure + if "choices" not in data or not data["choices"]: + raise LLMAPIError("Invalid API response: missing 'choices'") + + if "message" not in data["choices"][0]: + raise LLMAPIError("Invalid API response: missing 'message' in choice") + + if "content" not in data["choices"][0]["message"]: + raise LLMAPIError("Invalid API response: missing 'content' in message") + + content = data["choices"][0]["message"]["content"] + logger.debug(f"Received response with {len(content)} characters") + + return content + + except Timeout as e: + logger.error(f"Request timeout after {self.timeout}s: {e}") + raise LLMAPIError(f"Request timeout after {self.timeout} seconds") from e + + except HTTPError as e: + logger.error(f"HTTP error from DeepSeek API: {e}") + if e.response is not None: + try: + error_data = e.response.json() + error_msg = error_data.get("error", {}).get("message", str(e)) + except Exception: + error_msg = str(e) + raise LLMAPIError(f"DeepSeek API error: {error_msg}") from e + raise LLMAPIError(f"HTTP error: {e}") from e + + except RequestException as e: + logger.error(f"Request failed: {e}") + raise LLMAPIError(f"Failed to connect to DeepSeek API: {e}") from e + + except (KeyError, IndexError, TypeError) as e: + logger.error(f"Failed to parse API response: {e}") + raise LLMAPIError(f"Invalid API response format: {e}") from e diff --git a/agent/memory.py b/agent/memory.py index 9fdc810..4b0d1e8 100644 --- a/agent/memory.py +++ b/agent/memory.py @@ -4,21 +4,37 @@ from typing import Any, Dict import json from .config import settings +from .parameters import validate_parameter, get_parameter_schema class Memory: + """ + Generic memory storage for agent state. + + Provides a simple key-value store that persists to JSON. + """ + def __init__(self, path: str = "memory.json"): - print("init memory") self.file = Path(path) self.data: Dict[str, Any] = {} self.load() def load(self) -> None: + """Load memory from file or initialize with defaults.""" if self.file.exists(): - self.data = json.loads(self.file.read_text(encoding="utf-8")) + try: + self.data = json.loads(self.file.read_text(encoding="utf-8")) + except (json.JSONDecodeError, IOError) as e: + print(f"Warning: Could not load memory file: {e}") + self.data = { + "config": {}, + "tv_shows": [], + "history": [], + } else: self.data = { - "project_root": None, + "config": {}, + "tv_shows": [], "history": [], } @@ -28,11 +44,43 @@ class Memory: encoding="utf-8", ) - def get_project_root(self) -> str | None: - """Ce qu'on injecte dans le prompt pour le LLM.""" - return self.data.get("project_root") + def get(self, key: str, default: Any = None) -> Any: + """Get a value from memory by key.""" + return self.data.get(key, default) - def set_project_root(self, path: str) -> None: - print('Setting project root in memory to:', path) - self.data["project_root"] = path + def set(self, key: str, value: Any) -> None: + """ + Set a value in memory and save. + + Validates the value against the parameter schema if one exists. + """ + # Validate if schema exists + is_valid, error_msg = validate_parameter(key, value) + if not is_valid: + print(f'Validation failed for {key}: {error_msg}') + raise ValueError(f"Invalid value for {key}: {error_msg}") + + print(f'Setting {key} in memory to: {value}') + self.data[key] = value + self.save() + + def has(self, key: str) -> bool: + """Check if a key exists and has a non-None value.""" + return key in self.data and self.data[key] is not None + + def append_history(self, role: str, content: str) -> None: + """ + Append a message to conversation history. + + Args: + role: Message role ('user' or 'assistant') + content: Message content + """ + if "history" not in self.data: + self.data["history"] = [] + + self.data["history"].append({ + "role": role, + "content": content + }) self.save() diff --git a/agent/models/__init__.py b/agent/models/__init__.py new file mode 100644 index 0000000..8bc9199 --- /dev/null +++ b/agent/models/__init__.py @@ -0,0 +1,2 @@ +"""Models module.""" +from .tv_show import TVShow, ShowStatus, validate_tv_shows_structure diff --git a/agent/models/tv_show.py b/agent/models/tv_show.py new file mode 100644 index 0000000..1e198f0 --- /dev/null +++ b/agent/models/tv_show.py @@ -0,0 +1,58 @@ +"""TV Show models and validation.""" +from dataclasses import dataclass +from enum import Enum +from typing import Any + + +class ShowStatus(Enum): + """Status of a TV show - whether it's still airing or has ended.""" + ONGOING = "ongoing" + ENDED = "ended" + + +@dataclass +class TVShow: + """Represents a TV show.""" + imdb_id: str + title: str + seasons_count: int + status: ShowStatus # ongoing or ended + + +def validate_tv_shows_structure(tv_shows: Any) -> bool: + """ + Validate the structure of the tv_shows parameter. + + Expected structure: list of TV show objects + [ + { + "imdb_id": str, + "title": str, + "seasons_count": int, + "status": str # "ongoing" or "ended" + } + ] + """ + if not isinstance(tv_shows, list): + return False + + for show in tv_shows: + if not isinstance(show, dict): + return False + + # Check required fields + required_fields = {"imdb_id", "title", "seasons_count", "status"} + if not all(field in show for field in required_fields): + return False + + # Validate field types + if not isinstance(show["imdb_id"], str): + return False + if not isinstance(show["title"], str): + return False + if not isinstance(show["seasons_count"], int): + return False + if show["status"] not in ["ongoing", "ended"]: + return False + + return True diff --git a/agent/parameters.py b/agent/parameters.py new file mode 100644 index 0000000..08d26f9 --- /dev/null +++ b/agent/parameters.py @@ -0,0 +1,100 @@ +# agent/parameters.py +from dataclasses import dataclass +from typing import Any, Optional, Callable +import os + + +@dataclass +class ParameterSchema: + """Describes a required parameter for the agent.""" + key: str + description: str + why_needed: str # Explanation for the AI + type: str # "string", "number", "object", etc. + validator: Optional[Callable[[Any], bool]] = None + default: Any = None + required: bool = True + + +# Define all required parameters +REQUIRED_PARAMETERS = [ + ParameterSchema( + key="config", + description="Configuration object containing all folder paths", + why_needed=( + "This contains the paths to all important folders:\n" + "- download_folder: Where downloaded files arrive before being organized\n" + "- tvshow_folder: Where TV show files are organized and stored\n" + "- movie_folder: Where movie files are organized and stored\n" + "- torrent_folder: Where .torrent files are saved for the torrent client" + ), + type="object", + validator=lambda x: isinstance(x, dict), + required=True, + default={} + ), + ParameterSchema( + key="tv_shows", + description="List of TV shows the user is following", + why_needed=( + "This tracks which TV shows you're following. " + "Each show includes: IMDB ID, title, number of seasons, and status (ongoing or ended)." + ), + type="array", + validator=lambda x: isinstance(x, list), + required=False, + default=[] + ), +] + + +def get_parameter_schema(key: str) -> Optional[ParameterSchema]: + """Get schema for a specific parameter.""" + for param in REQUIRED_PARAMETERS: + if param.key == key: + return param + return None + + +def get_missing_required_parameters(memory_data: dict) -> list[ParameterSchema]: + """Get list of required parameters that are missing or None.""" + missing = [] + for param in REQUIRED_PARAMETERS: + if param.required: + value = memory_data.get(param.key) + if value is None: + missing.append(param) + return missing + + +def format_parameters_for_prompt() -> str: + """Format parameter descriptions for the AI system prompt.""" + lines = ["REQUIRED PARAMETERS:"] + for param in REQUIRED_PARAMETERS: + status = "REQUIRED" if param.required else "OPTIONAL" + lines.append(f"\n- {param.key} ({status}):") + lines.append(f" Description: {param.description}") + lines.append(f" Why needed: {param.why_needed}") + lines.append(f" Type: {param.type}") + return "\n".join(lines) + + +def validate_parameter(key: str, value: Any) -> tuple[bool, Optional[str]]: + """ + Validate a parameter value against its schema. + + Returns: + (is_valid, error_message) + """ + schema = get_parameter_schema(key) + if not schema: + return True, None # Unknown parameters are allowed + + if schema.validator: + try: + if not schema.validator(value): + return False, f"Validation failed for {key}" + except Exception as e: + return False, f"Validation error for {key}: {str(e)}" + + return True, None diff --git a/agent/prompts.py b/agent/prompts.py new file mode 100644 index 0000000..ab223f9 --- /dev/null +++ b/agent/prompts.py @@ -0,0 +1,88 @@ +# agent/prompts.py +from typing import Dict, Any +import json + +from .registry import Tool +from .parameters import format_parameters_for_prompt, get_missing_required_parameters + + +class PromptBuilder: + """Handles construction of system prompts for the agent.""" + + def __init__(self, tools: Dict[str, Tool]): + self.tools = tools + + def _format_tools_description(self) -> str: + """Format tools with their descriptions and parameters.""" + return "\n".join( + f"- {tool.name}: {tool.description}\n" + f" Parameters: {json.dumps(tool.parameters, ensure_ascii=False)}" + for tool in self.tools.values() + ) + + def _build_context(self, memory_data: dict) -> Dict[str, Any]: + """Build the context object with current state from memory.""" + return memory_data + + def build_system_prompt(self, memory_data: dict) -> str: + """ + Build the system prompt with context provided as JSON. + + Args: + memory_data: The full memory data dictionary + + Returns: + The complete system prompt string + """ + context = self._build_context(memory_data) + tools_desc = self._format_tools_description() + params_desc = format_parameters_for_prompt() + + # Check for missing required parameters + missing_params = get_missing_required_parameters(memory_data) + missing_info = "" + if missing_params: + missing_info = "\n\n⚠️ MISSING REQUIRED PARAMETERS:\n" + for param in missing_params: + missing_info += f"- {param.key}: {param.description}\n" + missing_info += f" Why needed: {param.why_needed}\n" + + return ( + "You are an AI agent helping a user manage their local media library.\n\n" + f"{params_desc}\n\n" + "CURRENT CONTEXT (JSON):\n" + f"{json.dumps(context, indent=2, ensure_ascii=False)}\n" + f"{missing_info}\n" + "IMPORTANT RULES:\n" + "1. Check the REQUIRED PARAMETERS section above to understand what information you need.\n" + "2. If any required parameter is missing (shown in MISSING REQUIRED PARAMETERS), " + "you MUST ask the user for it and explain WHY you need it based on the parameter description.\n" + "3. To use a tool, respond STRICTLY with this JSON format:\n" + ' { "thought": "explanation", "action": { "name": "tool_name", "args": { "arg": "value" } } }\n' + " - No text before or after the JSON\n" + " - All args must be complete and non-null\n" + "4. You can use MULTIPLE TOOLS IN SEQUENCE:\n" + " - After executing a tool, you will receive its result\n" + " - You can then decide to use another tool based on the result\n" + " - Or provide a final text response to the user\n" + " - Continue using tools until you have all the information needed\n" + "5. If you respond with text (not using a tool), respond normally in French.\n" + "6. When you have all the information needed, provide a final response in NATURAL TEXT (not JSON).\n" + "7. Extract the relevant information from the user's request and pass it as tool arguments.\n" + "\n" + "EXAMPLES:\n" + " To set the download folder:\n" + ' { "thought": "User provided download path", "action": { "name": "set_path", "args": { "path_type": "download_folder", "path_value": "/home/user/downloads" } } }\n' + "\n" + " To set the TV show folder:\n" + ' { "thought": "User provided TV show path", "action": { "name": "set_path", "args": { "path_type": "tvshow_folder", "path_value": "/home/user/media/tvshows" } } }\n' + "\n" + " To list the download folder:\n" + ' { "thought": "User wants to see downloads", "action": { "name": "list_folder", "args": { "folder_type": "download", "path": "." } } }\n' + "\n" + " To list a subfolder in TV shows:\n" + ' { "thought": "User wants to see a specific show", "action": { "name": "list_folder", "args": { "folder_type": "tvshow", "path": "Game.of.Thrones" } } }\n' + "\n" + "AVAILABLE TOOLS:\n" + f"{tools_desc}\n" + ) diff --git a/agent/registry.py b/agent/registry.py new file mode 100644 index 0000000..8667b59 --- /dev/null +++ b/agent/registry.py @@ -0,0 +1,93 @@ +"""Tool registry and definitions.""" +from dataclasses import dataclass +from typing import Callable, Any, Dict +from functools import partial + +from .memory import Memory +from .tools.filesystem import set_path_for_folder, list_folder +from .tools.api import find_media_imdb_id + + +@dataclass +class Tool: + """Represents a tool that can be used by the agent.""" + name: str + description: str + func: Callable[..., Dict[str, Any]] + parameters: Dict[str, Any] # JSON Schema des paramètres + + +def make_tools(memory: Memory) -> Dict[str, Tool]: + """ + Create all available tools with memory bound to them. + + Args: + memory: Memory instance to be used by the tools + + Returns: + Dictionary mapping tool names to Tool instances + """ + # Create partial functions with memory pre-bound for filesystem tools + set_path_func = partial(set_path_for_folder, memory) + list_folder_func = partial(list_folder, memory) + + tools = [ + Tool( + name="set_path_for_folder", + description="Sets a path in the configuration (download_folder, tvshow_folder, movie_folder, or torrent_folder).", + func=set_path_func, + parameters={ + "type": "object", + "properties": { + "folder_name": { + "type": "string", + "description": "Name of folder to set", + "enum": ["download", "tvshow", "movie", "torrent"] + }, + "path_value": { + "type": "string", + "description": "Absolute path to the folder (e.g., /home/user/downloads)" + } + }, + "required": ["folder_name", "path_value"] + } + ), + Tool( + name="list_folder", + description="Lists the contents of a specified folder (download, tvshow, movie, or torrent).", + func=list_folder_func, + parameters={ + "type": "object", + "properties": { + "folder_type": { + "type": "string", + "description": "Type of folder to list: 'download', 'tvshow', 'movie', or 'torrent'", + "enum": ["download", "tvshow", "movie", "torrent"] + }, + "path": { + "type": "string", + "description": "Relative path within the folder (default: '.' for root)", + "default": "." + } + }, + "required": ["folder_type"] + } + ), + Tool( + name="find_media_imdb_id", + description="Finds the IMDb ID for a given media title using TMDB API.", + func=find_media_imdb_id, + parameters={ + "type": "object", + "properties": { + "media_title": { + "type": "string", + "description": "Title of the media to find the IMDb ID for" + }, + }, + "required": ["media_title"] + } + ), + ] + + return {t.name: t for t in tools} diff --git a/agent/tools.py b/agent/tools.py deleted file mode 100644 index b3ddd11..0000000 --- a/agent/tools.py +++ /dev/null @@ -1,76 +0,0 @@ -# agent/tools.py -from dataclasses import dataclass -from typing import Callable, Any, Dict -import os - -from .memory import Memory - -@dataclass -class Tool: - name: str - description: str - func: Callable[..., Dict[str, Any]] - parameters: Dict[str, Any] # JSON Schema des paramètres - -def make_tools(memory: Memory) -> dict[str, Tool]: - def set_project_root(project_root: str) -> Dict[str, Any]: - if not os.path.isdir(project_root): - return {"error": "invalid_path", "message": f"Le chemin {project_root} n'est pas un dossier valide."} - - print(f"Setting project root to: {project_root}") - print("Memory before:", memory.data) - memory.set_project_root(project_root) - print("Memory after:", memory.data) - return {"status": "ok", "project_root": project_root} - - def list_directory(path: str) -> Dict[str, Any]: - print("Proper tool used") - if not memory.data.get("project_root"): - return {"error": "no_project_root", "message": "Project root not set."} - - root = memory.data.get("project_root") - full_path = os.path.abspath(os.path.join(root, path)) - root_abs = os.path.abspath(root) - if not full_path.startswith(root_abs): - return {"error": "forbidden", "message": "Path outside project_root."} - - try: - entries = os.listdir(full_path) - return {"path": path, "entries": entries} - except Exception as e: - return {"error": "os_error", "message": str(e)} - - tools = [ - Tool( - name="set_project_root", - description="Enregistre le path du dossier racine du projet.", - func=set_project_root, - parameters={ - "type": "object", - "properties": { - "project_root": { - "type": "string", - "description": "Chemin absolu du dossier racine du projet (ex: /home/user/mon_projet)" - } - }, - "required": ["project_root"] - } - ), - Tool( - name="list_directory", - description="Liste le contenu d'un dossier relatif au project_root.", - func=list_directory, - parameters={ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "Chemin relatif du dossier à lister (ex: 'src' ou '.' pour la racine)" - } - }, - "required": ["path"] - } - ), - ] - - return {t.name: t for t in tools} diff --git a/agent/tools/__init__.py b/agent/tools/__init__.py new file mode 100644 index 0000000..7c757c4 --- /dev/null +++ b/agent/tools/__init__.py @@ -0,0 +1,3 @@ +"""Tools module - filesystem and API tools.""" +from .filesystem import FolderName, set_path_for_folder, list_folder +from .api import find_media_imdb_id diff --git a/agent/tools/api.py b/agent/tools/api.py new file mode 100644 index 0000000..e08473e --- /dev/null +++ b/agent/tools/api.py @@ -0,0 +1,90 @@ +"""API tools for interacting with external services.""" +from typing import Dict, Any +import logging + +from ..api import tmdb_client, TMDBError, TMDBNotFoundError, TMDBAPIError, TMDBConfigurationError + +logger = logging.getLogger(__name__) + + +def find_media_imdb_id(media_title: str) -> Dict[str, Any]: + """ + Find the IMDb ID for a given media title using TMDB API. + + This is a wrapper around the TMDB client that returns a standardized + dict format for compatibility with the agent's tool system. + + Args: + media_title: Title of the media to search for + + Returns: + Dict with IMDb ID or error information: + - Success: {"status": "ok", "imdb_id": str, "title": str, ...} + - Error: {"error": str, "message": str} + + Example: + >>> result = find_media_imdb_id("Inception") + >>> print(result) + {'status': 'ok', 'imdb_id': 'tt1375666', 'title': 'Inception', ...} + """ + try: + # Use the TMDB client to search for media + result = tmdb_client.search_media(media_title) + + # Check if IMDb ID was found + if result.imdb_id: + logger.info(f"IMDb ID found for '{media_title}': {result.imdb_id}") + return { + "status": "ok", + "imdb_id": result.imdb_id, + "title": result.title, + "media_type": result.media_type, + "tmdb_id": result.tmdb_id, + "overview": result.overview, + "release_date": result.release_date, + "vote_average": result.vote_average + } + else: + logger.warning(f"No IMDb ID available for '{media_title}'") + return { + "error": "no_imdb_id", + "message": f"No IMDb ID available for '{result.title}'", + "title": result.title, + "media_type": result.media_type, + "tmdb_id": result.tmdb_id + } + + except TMDBNotFoundError as e: + logger.info(f"Media not found: {e}") + return { + "error": "not_found", + "message": str(e) + } + + except TMDBConfigurationError as e: + logger.error(f"TMDB configuration error: {e}") + return { + "error": "configuration_error", + "message": str(e) + } + + except TMDBAPIError as e: + logger.error(f"TMDB API error: {e}") + return { + "error": "api_error", + "message": str(e) + } + + except ValueError as e: + logger.error(f"Validation error: {e}") + return { + "error": "validation_failed", + "message": str(e) + } + + except Exception as e: + logger.error(f"Unexpected error: {e}", exc_info=True) + return { + "error": "internal_error", + "message": "An unexpected error occurred" + } diff --git a/agent/tools/filesystem.py b/agent/tools/filesystem.py new file mode 100644 index 0000000..5cb4f69 --- /dev/null +++ b/agent/tools/filesystem.py @@ -0,0 +1,257 @@ +"""Filesystem tools for managing folders and files with security.""" +from typing import Dict, Any +from enum import Enum +from pathlib import Path +import logging +import os + +from ..memory import Memory + +logger = logging.getLogger(__name__) + + +class FolderName(Enum): + """Types of folders that can be managed.""" + DOWNLOAD = "download" + TVSHOW = "tvshow" + MOVIE = "movie" + TORRENT = "torrent" + + +class FilesystemError(Exception): + """Base exception for filesystem operations.""" + pass + + +class PathTraversalError(FilesystemError): + """Raised when path traversal attack is detected.""" + pass + + +def _validate_folder_name(folder_name: str) -> bool: + """ + Validate folder name against allowed values. + + Args: + folder_name: Name to validate + + Returns: + True if valid + + Raises: + ValueError: If folder name is invalid + """ + valid_names = [fn.value for fn in FolderName] + if folder_name not in valid_names: + raise ValueError( + f"Invalid folder_name '{folder_name}'. Must be one of: {', '.join(valid_names)}" + ) + return True + + +def _sanitize_path(path: str) -> str: + """ + Sanitize path to prevent path traversal attacks. + + Args: + path: Path to sanitize + + Returns: + Sanitized path + + Raises: + PathTraversalError: If path contains dangerous patterns + """ + # Normalize path + normalized = os.path.normpath(path) + + # Check for absolute paths + if os.path.isabs(normalized): + raise PathTraversalError("Absolute paths are not allowed") + + # Check for parent directory references + if normalized.startswith("..") or "/.." in normalized or "\\.." in normalized: + raise PathTraversalError("Parent directory references are not allowed") + + # Check for null bytes + if "\x00" in normalized: + raise PathTraversalError("Null bytes in path are not allowed") + + return normalized + + +def _is_safe_path(base_path: Path, target_path: Path) -> bool: + """ + Check if target path is within base path (prevents path traversal). + + Args: + base_path: Base directory path + target_path: Target path to check + + Returns: + True if safe, False otherwise + """ + try: + # Resolve both paths to absolute paths + base_resolved = base_path.resolve() + target_resolved = target_path.resolve() + + # Check if target is relative to base + target_resolved.relative_to(base_resolved) + return True + except (ValueError, OSError): + return False + + +def set_path_for_folder(memory: Memory, folder_name: str, path_value: str) -> Dict[str, Any]: + """ + Set a path in the config with validation. + + Args: + memory: Memory instance to store the configuration + folder_name: Name of folder to set (download, tvshow, movie, torrent) + path_value: Absolute path to the folder + + Returns: + Dict with status or error information + """ + try: + # Validate folder name + _validate_folder_name(folder_name) + + # Convert to Path object for better handling + path_obj = Path(path_value).resolve() + + # Validate path exists and is a directory + if not path_obj.exists(): + logger.warning(f"Path does not exist: {path_value}") + return { + "error": "invalid_path", + "message": f"Path does not exist: {path_value}" + } + + if not path_obj.is_dir(): + logger.warning(f"Path is not a directory: {path_value}") + return { + "error": "invalid_path", + "message": f"Path is not a directory: {path_value}" + } + + # Check if path is readable + if not os.access(path_obj, os.R_OK): + logger.warning(f"Path is not readable: {path_value}") + return { + "error": "permission_denied", + "message": f"Path is not readable: {path_value}" + } + + # Store in memory + config = memory.get("config", {}) + config[f"{folder_name}_folder"] = str(path_obj) + memory.set("config", config) + + logger.info(f"Set {folder_name}_folder to: {path_obj}") + return { + "status": "ok", + "folder_name": folder_name, + "path": str(path_obj) + } + + except ValueError as e: + logger.error(f"Validation error: {e}") + return {"error": "validation_failed", "message": str(e)} + + except Exception as e: + logger.error(f"Unexpected error setting path: {e}", exc_info=True) + return {"error": "internal_error", "message": "Failed to set path"} + + +def list_folder(memory: Memory, folder_type: str, path: str = ".") -> Dict[str, Any]: + """ + List contents of a folder with security checks. + + Args: + memory: Memory instance to retrieve the configuration + folder_type: Type of folder to list (download, tvshow, movie, torrent) + path: Relative path within the folder (default: ".") + + Returns: + Dict with folder contents or error information + """ + try: + # Validate folder type + _validate_folder_name(folder_type) + + # Sanitize the path + safe_path = _sanitize_path(path) + + # Get root folder from config + folder_key = f"{folder_type}_folder" + config = memory.get("config", {}) + + if folder_key not in config or not config[folder_key]: + logger.warning(f"Folder not configured: {folder_type}") + return { + "error": "folder_not_set", + "message": f"{folder_type.capitalize()} folder not set in config." + } + + root = Path(config[folder_key]) + target = root / safe_path + + # Security check: ensure target is within root + if not _is_safe_path(root, target): + logger.warning(f"Path traversal attempt detected: {path}") + return { + "error": "forbidden", + "message": "Access denied: path outside allowed directory" + } + + # Check if target exists + if not target.exists(): + logger.warning(f"Path does not exist: {target}") + return { + "error": "not_found", + "message": f"Path does not exist: {safe_path}" + } + + # Check if target is a directory + if not target.is_dir(): + logger.warning(f"Path is not a directory: {target}") + return { + "error": "not_a_directory", + "message": f"Path is not a directory: {safe_path}" + } + + # List directory contents + try: + entries = [entry.name for entry in target.iterdir()] + logger.debug(f"Listed {len(entries)} entries in {target}") + return { + "status": "ok", + "folder_type": folder_type, + "path": safe_path, + "entries": sorted(entries), + "count": len(entries) + } + except PermissionError: + logger.warning(f"Permission denied accessing: {target}") + return { + "error": "permission_denied", + "message": f"Permission denied accessing: {safe_path}" + } + + except PathTraversalError as e: + logger.warning(f"Path traversal attempt: {e}") + return { + "error": "forbidden", + "message": str(e) + } + + except ValueError as e: + logger.error(f"Validation error: {e}") + return {"error": "validation_failed", "message": str(e)} + + except Exception as e: + logger.error(f"Unexpected error listing folder: {e}", exc_info=True) + return {"error": "internal_error", "message": "Failed to list folder"} diff --git a/app.py b/app.py index 3193c9a..eed94e3 100644 --- a/app.py +++ b/app.py @@ -7,10 +7,9 @@ from typing import Any, Dict from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse -from agent.llm import DeepSeekClient +from agent.llm.deepseek import DeepSeekClient from agent.memory import Memory from agent.agent import Agent -from agent.commands import CommandRegistry app = FastAPI( title="LibreChat Agent Backend", @@ -20,7 +19,6 @@ app = FastAPI( llm = DeepSeekClient() memory = Memory() agent = Agent(llm=llm, memory=memory) -commands_registry = CommandRegistry(memory=memory) def extract_last_user_content(messages: list[Dict[str, Any]]) -> str: @@ -42,12 +40,8 @@ async def chat_completions(request: Request): user_input = extract_last_user_content(messages) print("Received chat completion request, stream =", stream, "input:", user_input) - # 🔹 1) Si c'est une commande, on ne fait PAS intervenir le LLM - if user_input.strip().startswith("/"): - answer = commands_registry.handle(user_input) - else: - # 🔹 2) Sinon, logique agent + LLM comme avant - answer = agent.step(user_input) + # Process user input through the agent + answer = agent.step(user_input) # Ensuite = même logique de réponse (non-stream ou stream) created_ts = int(time.time())