From 4eae1d6d582f63ba22f943f0313643dda068361a Mon Sep 17 00:00:00 2001 From: Francwa Date: Sun, 7 Dec 2025 03:33:51 +0100 Subject: [PATCH] Formatting --- agent/agent.py | 124 +++++----- agent/llm/deepseek.py | 14 +- agent/llm/ollama.py | 14 +- agent/prompts.py | 40 ++-- agent/registry.py | 49 ++-- agent/tools/language.py | 20 +- infrastructure/persistence/memory.py | 8 +- pyproject.toml | 2 +- tests/conftest.py | 87 ++++--- tests/test_agent.py | 98 ++++---- tests/test_agent_edge_cases.py | 171 +++++++------- tests/test_api.py | 94 +++++--- tests/test_api_edge_cases.py | 325 +++++++++++++++------------ tests/test_config_critical.py | 62 +++-- tests/test_config_edge_cases.py | 32 ++- tests/test_memory.py | 31 ++- tests/test_memory_edge_cases.py | 18 +- tests/test_prompts_critical.py | 128 ++++++----- tests/test_prompts_edge_cases.py | 44 ++-- tests/test_registry_critical.py | 140 ++++++------ tests/test_registry_edge_cases.py | 11 +- tests/test_repositories.py | 89 ++++---- tests/test_tools_api.py | 197 +++++++++------- tests/test_tools_edge_cases.py | 38 +++- 24 files changed, 1003 insertions(+), 833 deletions(-) diff --git a/agent/agent.py b/agent/agent.py index a03303e..995579e 100644 --- a/agent/agent.py +++ b/agent/agent.py @@ -1,7 +1,8 @@ """Main agent for media library management.""" + import json import logging -from typing import Any, Dict, List, Optional +from typing import Any from infrastructure.persistence import get_memory @@ -15,157 +16,156 @@ logger = logging.getLogger(__name__) class Agent: """ AI agent for media library management. - + Uses OpenAI-compatible tool calling API. """ def __init__(self, llm, max_tool_iterations: int = 5): """ Initialize the agent. - + Args: llm: LLM client with complete() method max_tool_iterations: Maximum number of tool execution iterations """ self.llm = llm - self.tools: Dict[str, Tool] = make_tools() + self.tools: dict[str, Tool] = make_tools() self.prompt_builder = PromptBuilder(self.tools) self.max_tool_iterations = max_tool_iterations def step(self, user_input: str) -> str: """ Execute one agent step with the user input. - + This method: 1. Adds user message to memory 2. Builds prompt with history and context 3. Calls LLM, executing tools as needed 4. Returns final response - + Args: user_input: User's message - + Returns: Agent's final response """ memory = get_memory() - + # Add user message to history memory.stm.add_message("user", user_input) memory.save() - + # Build initial messages system_prompt = self.prompt_builder.build_system_prompt() - messages: List[Dict[str, Any]] = [ - {"role": "system", "content": system_prompt} - ] - + messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}] + # Add conversation history history = memory.stm.get_recent_history(settings.max_history_messages) messages.extend(history) - + # Add unread events if any unread_events = memory.episodic.get_unread_events() if unread_events: - events_text = "\n".join([ - f"- {e['type']}: {e['data']}" - for e in unread_events - ]) - messages.append({ - "role": "system", - "content": f"Background events:\n{events_text}" - }) - + events_text = "\n".join( + [f"- {e['type']}: {e['data']}" for e in unread_events] + ) + messages.append( + {"role": "system", "content": f"Background events:\n{events_text}"} + ) + # Get tools specification for OpenAI format tools_spec = self.prompt_builder.build_tools_spec() - + # Tool execution loop for iteration in range(self.max_tool_iterations): # Call LLM with tools llm_result = self.llm.complete(messages, tools=tools_spec) - + # Handle both tuple (response, usage) and dict response if isinstance(llm_result, tuple): response_message, usage = llm_result else: response_message = llm_result - + # Check if there are tool calls tool_calls = response_message.get("tool_calls") - + if not tool_calls: # No tool calls, this is the final response final_content = response_message.get("content", "") memory.stm.add_message("assistant", final_content) memory.save() return final_content - + # Add assistant message with tool calls to conversation messages.append(response_message) - + # Execute each tool call for tool_call in tool_calls: tool_result = self._execute_tool_call(tool_call) - + # Add tool result to messages - messages.append({ - "tool_call_id": tool_call.get("id"), - "role": "tool", - "name": tool_call.get("function", {}).get("name"), - "content": json.dumps(tool_result, ensure_ascii=False), - }) - + messages.append( + { + "tool_call_id": tool_call.get("id"), + "role": "tool", + "name": tool_call.get("function", {}).get("name"), + "content": json.dumps(tool_result, ensure_ascii=False), + } + ) + # Max iterations reached, force final response - messages.append({ - "role": "system", - "content": "Please provide a final response to the user without using any more tools." - }) - + messages.append( + { + "role": "system", + "content": "Please provide a final response to the user without using any more tools.", + } + ) + llm_result = self.llm.complete(messages) if isinstance(llm_result, tuple): final_message, usage = llm_result else: final_message = llm_result - - final_response = final_message.get("content", "I've completed the requested actions.") + + final_response = final_message.get( + "content", "I've completed the requested actions." + ) memory.stm.add_message("assistant", final_response) memory.save() return final_response - def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Dict[str, Any]: + def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]: """ Execute a single tool call. - + Args: tool_call: OpenAI-format tool call dict - + Returns: Result dictionary """ function = tool_call.get("function", {}) tool_name = function.get("name", "") - + try: args_str = function.get("arguments", "{}") args = json.loads(args_str) except json.JSONDecodeError as e: logger.error(f"Failed to parse tool arguments: {e}") - return { - "error": "bad_args", - "message": f"Invalid JSON arguments: {e}" - } - + return {"error": "bad_args", "message": f"Invalid JSON arguments: {e}"} + # Validate tool exists if tool_name not in self.tools: available = list(self.tools.keys()) return { "error": "unknown_tool", "message": f"Tool '{tool_name}' not found", - "available_tools": available + "available_tools": available, } - + tool = self.tools[tool_name] - + # Execute tool try: result = tool.func(**args) @@ -177,17 +177,9 @@ class Agent: # Bad arguments memory = get_memory() memory.episodic.add_error(tool_name, f"bad_args: {e}") - return { - "error": "bad_args", - "message": str(e), - "tool": tool_name - } + return {"error": "bad_args", "message": str(e), "tool": tool_name} except Exception as e: # Other errors memory = get_memory() memory.episodic.add_error(tool_name, str(e)) - return { - "error": "execution_failed", - "message": str(e), - "tool": tool_name - } + return {"error": "execution_failed", "message": str(e), "tool": tool_name} diff --git a/agent/llm/deepseek.py b/agent/llm/deepseek.py index e6332b6..36b86f8 100644 --- a/agent/llm/deepseek.py +++ b/agent/llm/deepseek.py @@ -51,7 +51,9 @@ class DeepSeekClient: logger.info(f"DeepSeek client initialized with model: {self.model}") - def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]: + def complete( + self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None + ) -> dict[str, Any]: """ Generate a completion from the LLM. @@ -80,7 +82,9 @@ class DeepSeekClient: raise ValueError(f"Invalid role: {msg['role']}") # Content is optional for tool messages (they may have tool_call_id instead) if msg["role"] != "tool" and "content" not in msg: - raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}") + raise ValueError( + f"Non-tool message must have 'content' key, got {msg.keys()}" + ) url = f"{self.base_url}/v1/chat/completions" headers = { @@ -92,13 +96,15 @@ class DeepSeekClient: "messages": messages, "temperature": settings.temperature, } - + # Add tools if provided if tools: payload["tools"] = tools try: - logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools") + logger.debug( + f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools" + ) response = requests.post( url, headers=headers, json=payload, timeout=self.timeout ) diff --git a/agent/llm/ollama.py b/agent/llm/ollama.py index 5077bb8..afd6a93 100644 --- a/agent/llm/ollama.py +++ b/agent/llm/ollama.py @@ -66,7 +66,9 @@ class OllamaClient: logger.info(f"Ollama client initialized with model: {self.model}") - def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]: + def complete( + self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None + ) -> dict[str, Any]: """ Generate a completion from the LLM. @@ -95,7 +97,9 @@ class OllamaClient: raise ValueError(f"Invalid role: {msg['role']}") # Content is optional for tool messages (they may have tool_call_id instead) if msg["role"] != "tool" and "content" not in msg: - raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}") + raise ValueError( + f"Non-tool message must have 'content' key, got {msg.keys()}" + ) url = f"{self.base_url}/api/chat" payload = { @@ -106,13 +110,15 @@ class OllamaClient: "temperature": self.temperature, }, } - + # Add tools if provided if tools: payload["tools"] = tools try: - logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools") + logger.debug( + f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools" + ) response = requests.post(url, json=payload, timeout=self.timeout) response.raise_for_status() data = response.json() diff --git a/agent/prompts.py b/agent/prompts.py index 28efa38..3ae658e 100644 --- a/agent/prompts.py +++ b/agent/prompts.py @@ -1,18 +1,20 @@ """Prompt builder for the agent system.""" -from typing import Dict, List, Any + import json +from typing import Any + +from infrastructure.persistence import get_memory from .registry import Tool -from infrastructure.persistence import get_memory class PromptBuilder: """Builds system prompts for the agent with memory context.""" - def __init__(self, tools: Dict[str, Tool]): + def __init__(self, tools: dict[str, Tool]): self.tools = tools - def build_tools_spec(self) -> List[Dict[str, Any]]: + def build_tools_spec(self) -> list[dict[str, Any]]: """Build the tool specification for the LLM API.""" tool_specs = [] for tool in self.tools.values(): @@ -44,11 +46,13 @@ class PromptBuilder: if memory.episodic.last_search_results: results = memory.episodic.last_search_results - result_list = results.get('results', []) - lines.append(f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)") + result_list = results.get("results", []) + lines.append( + f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)" + ) # Show first 5 results for i, result in enumerate(result_list[:5]): - name = result.get('name', 'Unknown') + name = result.get("name", "Unknown") lines.append(f" {i+1}. {name}") if len(result_list) > 5: lines.append(f" ... and {len(result_list) - 5} more") @@ -57,7 +61,7 @@ class PromptBuilder: question = memory.episodic.pending_question lines.append(f"\nPENDING QUESTION: {question.get('question')}") lines.append(f" Type: {question.get('type')}") - if question.get('options'): + if question.get("options"): lines.append(f" Options: {len(question.get('options'))}") if memory.episodic.active_downloads: @@ -68,10 +72,12 @@ class PromptBuilder: if memory.episodic.recent_errors: lines.append("\nRECENT ERRORS (up to 3):") for error in memory.episodic.recent_errors[-3:]: - lines.append(f" - Action '{error.get('action')}' failed: {error.get('error')}") + lines.append( + f" - Action '{error.get('action')}' failed: {error.get('error')}" + ) # Unread events - unread = [e for e in memory.episodic.background_events if not e.get('read')] + unread = [e for e in memory.episodic.background_events if not e.get("read")] if unread: lines.append(f"\nUNREAD EVENTS: {len(unread)}") for event in unread[:3]: @@ -86,8 +92,10 @@ class PromptBuilder: if memory.stm.current_workflow: workflow = memory.stm.current_workflow - lines.append(f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})") - if workflow.get('target'): + lines.append( + f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})" + ) + if workflow.get("target"): lines.append(f" Target: {workflow.get('target')}") if memory.stm.current_topic: @@ -97,7 +105,7 @@ class PromptBuilder: lines.append("EXTRACTED ENTITIES:") for key, value in memory.stm.extracted_entities.items(): lines.append(f" - {key}: {value}") - + if memory.stm.language: lines.append(f"CONVERSATION LANGUAGE: {memory.stm.language}") @@ -106,7 +114,7 @@ class PromptBuilder: def _format_config_context(self) -> str: """Format configuration context.""" memory = get_memory() - + lines = ["CURRENT CONFIGURATION:"] if memory.ltm.config: for key, value in memory.ltm.config.items(): @@ -118,10 +126,10 @@ class PromptBuilder: def build_system_prompt(self) -> str: """Build the complete system prompt.""" memory = get_memory() - + # Base instruction base = "You are a helpful AI assistant for managing a media library." - + # Language instruction language_instruction = ( "Your first task is to determine the user's language from their message " diff --git a/agent/registry.py b/agent/registry.py index 21d65fb..4f24471 100644 --- a/agent/registry.py +++ b/agent/registry.py @@ -1,8 +1,10 @@ """Tool registry - defines and registers all available tools for the agent.""" -from dataclasses import dataclass -from typing import Callable, Any, Dict -import logging + import inspect +import logging +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any logger = logging.getLogger(__name__) @@ -10,36 +12,37 @@ logger = logging.getLogger(__name__) @dataclass class Tool: """Represents a tool that can be used by the agent.""" + name: str description: str - func: Callable[..., Dict[str, Any]] - parameters: Dict[str, Any] + func: Callable[..., dict[str, Any]] + parameters: dict[str, Any] def _create_tool_from_function(func: Callable) -> Tool: """ Create a Tool object from a function. - + Args: func: Function to convert to a tool - + Returns: Tool object with metadata extracted from function """ sig = inspect.signature(func) doc = inspect.getdoc(func) - + # Extract description from docstring (first line) - description = doc.strip().split('\n')[0] if doc else func.__name__ - + description = doc.strip().split("\n")[0] if doc else func.__name__ + # Build JSON schema from function signature properties = {} required = [] - + for param_name, param in sig.parameters.items(): if param_name == "self": continue - + # Map Python types to JSON schema types param_type = "string" # default if param.annotation != inspect.Parameter.empty: @@ -51,22 +54,22 @@ def _create_tool_from_function(func: Callable) -> Tool: param_type = "number" elif param.annotation == bool: param_type = "boolean" - + properties[param_name] = { "type": param_type, - "description": f"Parameter {param_name}" + "description": f"Parameter {param_name}", } - + # Add to required if no default value if param.default == inspect.Parameter.empty: required.append(param_name) - + parameters = { "type": "object", "properties": properties, "required": required, } - + return Tool( name=func.__name__, description=description, @@ -75,18 +78,18 @@ def _create_tool_from_function(func: Callable) -> Tool: ) -def make_tools() -> Dict[str, Tool]: +def make_tools() -> dict[str, Tool]: """ Create and register all available tools. - + Returns: Dictionary mapping tool names to Tool objects """ # Import tools here to avoid circular dependencies - from .tools import filesystem as fs_tools from .tools import api as api_tools + from .tools import filesystem as fs_tools from .tools import language as lang_tools - + # List of all tool functions tool_functions = [ fs_tools.set_path_for_folder, @@ -98,12 +101,12 @@ def make_tools() -> Dict[str, Tool]: api_tools.get_torrent_by_index, lang_tools.set_language, ] - + # Create Tool objects from functions tools = {} for func in tool_functions: tool = _create_tool_from_function(func) tools[tool.name] = tool - + logger.info(f"Registered {len(tools)} tools: {list(tools.keys())}") return tools diff --git a/agent/tools/language.py b/agent/tools/language.py index a0c1cae..e7ea471 100644 --- a/agent/tools/language.py +++ b/agent/tools/language.py @@ -1,19 +1,20 @@ """Language management tools for the agent.""" + import logging -from typing import Dict, Any +from typing import Any from infrastructure.persistence import get_memory logger = logging.getLogger(__name__) -def set_language(language: str) -> Dict[str, Any]: +def set_language(language: str) -> dict[str, Any]: """ Set the conversation language. - + Args: language: Language code (e.g., 'en', 'fr', 'es', 'de') - + Returns: Status dictionary """ @@ -21,17 +22,14 @@ def set_language(language: str) -> Dict[str, Any]: memory = get_memory() memory.stm.set_language(language) memory.save() - + logger.info(f"Language set to: {language}") - + return { "status": "ok", "message": f"Language set to {language}", - "language": language + "language": language, } except Exception as e: logger.error(f"Failed to set language: {e}") - return { - "status": "error", - "error": str(e) - } + return {"status": "error", "error": str(e)} diff --git a/infrastructure/persistence/memory.py b/infrastructure/persistence/memory.py index f571ef1..f6fb1ae 100644 --- a/infrastructure/persistence/memory.py +++ b/infrastructure/persistence/memory.py @@ -359,9 +359,7 @@ class EpisodicMemory: """Get active downloads.""" return self.active_downloads - def add_error( - self, action: str, error: str, context: dict | None = None - ) -> None: + def add_error(self, action: str, error: str, context: dict | None = None) -> None: """Record a recent error.""" self.recent_errors.append( { @@ -408,9 +406,7 @@ class EpisodicMemory: """Get the pending question.""" return self.pending_question - def resolve_pending_question( - self, answer_index: int | None = None - ) -> dict | None: + def resolve_pending_question(self, answer_index: int | None = None) -> dict | None: """ Resolve the pending question and return the chosen option. diff --git a/pyproject.toml b/pyproject.toml index 312c977..32df1d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,4 +110,4 @@ select = [ "PL", "UP", ] -ignore = ["W503", "PLR0913", "PLR2004"] +ignore = ["PLR0913", "PLR2004"] diff --git a/tests/conftest.py b/tests/conftest.py index c38b960..28d85d7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,13 @@ """Pytest configuration and shared fixtures.""" -import pytest -import tempfile -import shutil -from pathlib import Path -from unittest.mock import Mock, MagicMock -from infrastructure.persistence import Memory, init_memory, set_memory, get_memory -from infrastructure.persistence.memory import ( - LongTermMemory, - ShortTermMemory, - EpisodicMemory, -) +import shutil +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, Mock + +import pytest + +from infrastructure.persistence import Memory, set_memory @pytest.fixture @@ -122,12 +119,11 @@ def memory_with_library(memory): def mock_llm(): """Create a mock LLM client that returns OpenAI-compatible format.""" llm = Mock() + # Return OpenAI-style message dict without tool calls def complete_func(messages, tools=None): - return { - "role": "assistant", - "content": "I found what you're looking for!" - } + return {"role": "assistant", "content": "I found what you're looking for!"} + llm.complete = Mock(side_effect=complete_func) return llm @@ -136,34 +132,33 @@ def mock_llm(): def mock_llm_with_tool_call(): """Create a mock LLM that returns a tool call then a response.""" llm = Mock() - + # First call returns a tool call, second returns final response def complete_side_effect(messages, tools=None): - if not hasattr(complete_side_effect, 'call_count'): + if not hasattr(complete_side_effect, "call_count"): complete_side_effect.call_count = 0 complete_side_effect.call_count += 1 - + if complete_side_effect.call_count == 1: # First call: return tool call return { "role": "assistant", "content": None, - "tool_calls": [{ - "id": "call_123", - "type": "function", - "function": { - "name": "find_torrent", - "arguments": '{"media_title": "Inception"}' + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "find_torrent", + "arguments": '{"media_title": "Inception"}', + }, } - }] + ], } else: # Second call: return final response - return { - "role": "assistant", - "content": "I found 3 torrents for Inception!" - } - + return {"role": "assistant", "content": "I found 3 torrents for Inception!"} + llm.complete = Mock(side_effect=complete_side_effect) return llm @@ -248,36 +243,36 @@ def mock_deepseek(): """ Mock DeepSeekClient for individual tests that need it. This prevents real API calls in tests that use this fixture. - + Usage: def test_something(mock_deepseek): # Your test code here """ import sys - from unittest.mock import Mock, MagicMock - + from unittest.mock import Mock + # Save the original module if it exists - original_module = sys.modules.get('agent.llm.deepseek') - + original_module = sys.modules.get("agent.llm.deepseek") + # Create a mock module for deepseek mock_deepseek_module = MagicMock() - + class MockDeepSeekClient: def __init__(self, *args, **kwargs): self.complete = Mock(return_value="Mocked LLM response") - + mock_deepseek_module.DeepSeekClient = MockDeepSeekClient - + # Inject the mock - sys.modules['agent.llm.deepseek'] = mock_deepseek_module - + sys.modules["agent.llm.deepseek"] = mock_deepseek_module + yield mock_deepseek_module - + # Restore the original module if original_module is not None: - sys.modules['agent.llm.deepseek'] = original_module - elif 'agent.llm.deepseek' in sys.modules: - del sys.modules['agent.llm.deepseek'] + sys.modules["agent.llm.deepseek"] = original_module + elif "agent.llm.deepseek" in sys.modules: + del sys.modules["agent.llm.deepseek"] @pytest.fixture @@ -287,8 +282,8 @@ def mock_agent_step(): Returns a context manager that patches app.agent.step. """ from unittest.mock import patch - + def _mock_step(return_value="Mocked agent response"): return patch("app.agent.step", return_value=return_value) - + return _mock_step diff --git a/tests/test_agent.py b/tests/test_agent.py index e5a0e63..294944b 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,6 +1,6 @@ """Tests for the Agent.""" -from unittest.mock import Mock, patch +from unittest.mock import Mock from agent.agent import Agent from infrastructure.persistence import get_memory @@ -55,8 +55,8 @@ class TestExecuteToolCall: "id": "call_123", "function": { "name": "list_folder", - "arguments": '{"folder_type": "download"}' - } + "arguments": '{"folder_type": "download"}', + }, } result = agent._execute_tool_call(tool_call) @@ -68,10 +68,7 @@ class TestExecuteToolCall: tool_call = { "id": "call_123", - "function": { - "name": "unknown_tool", - "arguments": '{}' - } + "function": {"name": "unknown_tool", "arguments": "{}"}, } result = agent._execute_tool_call(tool_call) @@ -84,10 +81,7 @@ class TestExecuteToolCall: tool_call = { "id": "call_123", - "function": { - "name": "set_path_for_folder", - "arguments": '{}' - } + "function": {"name": "set_path_for_folder", "arguments": "{}"}, } result = agent._execute_tool_call(tool_call) @@ -102,8 +96,8 @@ class TestExecuteToolCall: "id": "call_123", "function": { "name": "set_path_for_folder", - "arguments": '{"folder_name": 123}' # Wrong type - } + "arguments": '{"folder_name": 123}', # Wrong type + }, } result = agent._execute_tool_call(tool_call) @@ -116,10 +110,7 @@ class TestExecuteToolCall: tool_call = { "id": "call_123", - "function": { - "name": "list_folder", - "arguments": '{invalid json}' - } + "function": {"name": "list_folder", "arguments": "{invalid json}"}, } result = agent._execute_tool_call(tool_call) @@ -160,40 +151,39 @@ class TestStep: assert "found" in response.lower() or "torrent" in response.lower() assert mock_llm_with_tool_call.complete.call_count == 2 - + # CRITICAL: Verify tools were passed to LLM first_call_args = mock_llm_with_tool_call.complete.call_args_list[0] - assert first_call_args[1]['tools'] is not None, "Tools not passed to LLM!" - assert len(first_call_args[1]['tools']) > 0, "Tools list is empty!" + assert first_call_args[1]["tools"] is not None, "Tools not passed to LLM!" + assert len(first_call_args[1]["tools"]) > 0, "Tools list is empty!" def test_step_max_iterations(self, memory, mock_llm): """Should stop after max iterations.""" call_count = [0] - + def mock_complete(messages, tools=None): call_count[0] += 1 # CRITICAL: Verify tools are passed (except on forced final call) if call_count[0] <= 3: assert tools is not None, f"Tools not passed on call {call_count[0]}!" - + if call_count[0] <= 3: return { "role": "assistant", "content": None, - "tool_calls": [{ - "id": f"call_{call_count[0]}", - "function": { - "name": "list_folder", - "arguments": '{"folder_type": "download"}' + "tool_calls": [ + { + "id": f"call_{call_count[0]}", + "function": { + "name": "list_folder", + "arguments": '{"folder_type": "download"}', + }, } - }] + ], } else: - return { - "role": "assistant", - "content": "I couldn't complete the task." - } - + return {"role": "assistant", "content": "I couldn't complete the task."} + mock_llm.complete = Mock(side_effect=mock_complete) agent = Agent(llm=mock_llm, max_tool_iterations=3) @@ -241,49 +231,53 @@ class TestAgentIntegration: memory.ltm.set_config("movie_folder", str(real_folder["movies"])) call_count = [0] - + def mock_complete(messages, tools=None): call_count[0] += 1 # CRITICAL: Verify tools are passed on every call assert tools is not None, f"Tools not passed on call {call_count[0]}!" - + if call_count[0] == 1: return { "role": "assistant", "content": None, - "tool_calls": [{ - "id": "call_1", - "function": { - "name": "list_folder", - "arguments": '{"folder_type": "download"}' + "tool_calls": [ + { + "id": "call_1", + "function": { + "name": "list_folder", + "arguments": '{"folder_type": "download"}', + }, } - }] + ], } elif call_count[0] == 2: # CRITICAL: Verify tool result was sent back - tool_messages = [m for m in messages if m.get('role') == 'tool'] + tool_messages = [m for m in messages if m.get("role") == "tool"] assert len(tool_messages) > 0, "Tool result not sent back to LLM!" - + return { "role": "assistant", "content": None, - "tool_calls": [{ - "id": "call_2", - "function": { - "name": "list_folder", - "arguments": '{"folder_type": "movie"}' + "tool_calls": [ + { + "id": "call_2", + "function": { + "name": "list_folder", + "arguments": '{"folder_type": "movie"}', + }, } - }] + ], } else: return { "role": "assistant", - "content": "I listed both folders for you." + "content": "I listed both folders for you.", } - + mock_llm.complete = Mock(side_effect=mock_complete) agent = Agent(llm=mock_llm) - + response = agent.step("List my downloads and movies") assert call_count[0] == 3 diff --git a/tests/test_agent_edge_cases.py b/tests/test_agent_edge_cases.py index 93caba4..faca8e6 100644 --- a/tests/test_agent_edge_cases.py +++ b/tests/test_agent_edge_cases.py @@ -1,7 +1,9 @@ """Edge case tests for the Agent.""" -import pytest + from unittest.mock import Mock +import pytest + from agent.agent import Agent from infrastructure.persistence import get_memory @@ -15,19 +17,14 @@ class TestExecuteToolCallEdgeCases: # Mock a tool that returns None from agent.registry import Tool + agent.tools["test_tool"] = Tool( - name="test_tool", - description="Test", - func=lambda: None, - parameters={} + name="test_tool", description="Test", func=lambda: None, parameters={} ) tool_call = { "id": "call_123", - "function": { - "name": "test_tool", - "arguments": '{}' - } + "function": {"name": "test_tool", "arguments": "{}"}, } result = agent._execute_tool_call(tool_call) @@ -38,22 +35,17 @@ class TestExecuteToolCallEdgeCases: agent = Agent(llm=mock_llm) from agent.registry import Tool + def raise_interrupt(): raise KeyboardInterrupt() - + agent.tools["test_tool"] = Tool( - name="test_tool", - description="Test", - func=raise_interrupt, - parameters={} + name="test_tool", description="Test", func=raise_interrupt, parameters={} ) tool_call = { "id": "call_123", - "function": { - "name": "test_tool", - "arguments": '{}' - } + "function": {"name": "test_tool", "arguments": "{}"}, } with pytest.raises(KeyboardInterrupt): @@ -68,8 +60,8 @@ class TestExecuteToolCallEdgeCases: "id": "call_123", "function": { "name": "list_folder", - "arguments": '{"folder_type": "download", "extra_arg": "ignored"}' - } + "arguments": '{"folder_type": "download", "extra_arg": "ignored"}', + }, } result = agent._execute_tool_call(tool_call) @@ -84,8 +76,8 @@ class TestExecuteToolCallEdgeCases: "id": "call_123", "function": { "name": "get_torrent_by_index", - "arguments": '{"index": "not an int"}' - } + "arguments": '{"index": "not an int"}', + }, } result = agent._execute_tool_call(tool_call) @@ -115,12 +107,10 @@ class TestStepEdgeCases: def test_step_with_unicode_input(self, memory, mock_llm): """Should handle unicode input.""" + def mock_complete(messages, tools=None): - return { - "role": "assistant", - "content": "日本語の応答" - } - + return {"role": "assistant", "content": "日本語の応答"} + mock_llm.complete = Mock(side_effect=mock_complete) agent = Agent(llm=mock_llm) @@ -130,12 +120,10 @@ class TestStepEdgeCases: def test_step_llm_returns_empty(self, memory, mock_llm): """Should handle LLM returning empty string.""" + def mock_complete(messages, tools=None): - return { - "role": "assistant", - "content": "" - } - + return {"role": "assistant", "content": ""} + mock_llm.complete = Mock(side_effect=mock_complete) agent = Agent(llm=mock_llm) @@ -161,18 +149,17 @@ class TestStepEdgeCases: return { "role": "assistant", "content": None, - "tool_calls": [{ - "id": f"call_{call_count[0]}", - "function": { - "name": "list_folder", - "arguments": '{"folder_type": "download"}' + "tool_calls": [ + { + "id": f"call_{call_count[0]}", + "function": { + "name": "list_folder", + "arguments": '{"folder_type": "download"}', + }, } - }] + ], } - return { - "role": "assistant", - "content": "Done looping" - } + return {"role": "assistant", "content": "Done looping"} mock_llm.complete = Mock(side_effect=mock_complete) agent = Agent(llm=mock_llm, max_tool_iterations=3) @@ -212,11 +199,13 @@ class TestStepEdgeCases: def test_step_with_active_downloads(self, memory, mock_llm): """Should include active downloads in context.""" - memory.episodic.add_active_download({ - "task_id": "123", - "name": "Movie.mkv", - "progress": 50, - }) + memory.episodic.add_active_download( + { + "task_id": "123", + "name": "Movie.mkv", + "progress": 50, + } + ) agent = Agent(llm=mock_llm) response = agent.step("Hello") @@ -257,29 +246,28 @@ class TestAgentConcurrencyEdgeCases: memory.ltm.set_config("download_folder", str(real_folder["downloads"])) call_count = [0] - + def mock_complete(messages, tools=None): call_count[0] += 1 if call_count[0] == 1: return { "role": "assistant", "content": None, - "tool_calls": [{ - "id": "call_1", - "function": { - "name": "set_path_for_folder", - "arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}' + "tool_calls": [ + { + "id": "call_1", + "function": { + "name": "set_path_for_folder", + "arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}', + }, } - }] + ], } - return { - "role": "assistant", - "content": "Path set successfully." - } + return {"role": "assistant", "content": "Path set successfully."} mock_llm.complete = Mock(side_effect=mock_complete) agent = Agent(llm=mock_llm) - + response = agent.step("Set movie folder") mem = get_memory() @@ -292,29 +280,28 @@ class TestAgentErrorRecovery: def test_recovers_from_tool_error(self, memory, mock_llm): """Should recover from tool error and continue.""" call_count = [0] - + def mock_complete(messages, tools=None): call_count[0] += 1 if call_count[0] == 1: return { "role": "assistant", "content": None, - "tool_calls": [{ - "id": "call_1", - "function": { - "name": "list_folder", - "arguments": '{"folder_type": "download"}' + "tool_calls": [ + { + "id": "call_1", + "function": { + "name": "list_folder", + "arguments": '{"folder_type": "download"}', + }, } - }] + ], } - return { - "role": "assistant", - "content": "The folder is not configured." - } + return {"role": "assistant", "content": "The folder is not configured."} mock_llm.complete = Mock(side_effect=mock_complete) agent = Agent(llm=mock_llm) - + response = agent.step("List downloads") assert "not configured" in response.lower() or len(response) > 0 @@ -322,29 +309,28 @@ class TestAgentErrorRecovery: def test_error_tracked_in_memory(self, memory, mock_llm): """Should track errors in episodic memory.""" call_count = [0] - + def mock_complete(messages, tools=None): call_count[0] += 1 if call_count[0] == 1: return { "role": "assistant", "content": None, - "tool_calls": [{ - "id": "call_1", - "function": { - "name": "set_path_for_folder", - "arguments": '{}' # Missing required args + "tool_calls": [ + { + "id": "call_1", + "function": { + "name": "set_path_for_folder", + "arguments": "{}", # Missing required args + }, } - }] + ], } - return { - "role": "assistant", - "content": "Error occurred." - } + return {"role": "assistant", "content": "Error occurred."} mock_llm.complete = Mock(side_effect=mock_complete) agent = Agent(llm=mock_llm) - + agent.step("Set folder") mem = get_memory() @@ -360,18 +346,17 @@ class TestAgentErrorRecovery: return { "role": "assistant", "content": None, - "tool_calls": [{ - "id": f"call_{call_count[0]}", - "function": { - "name": "set_path_for_folder", - "arguments": '{}' # Missing required args - will error + "tool_calls": [ + { + "id": f"call_{call_count[0]}", + "function": { + "name": "set_path_for_folder", + "arguments": "{}", # Missing required args - will error + }, } - }] + ], } - return { - "role": "assistant", - "content": "All attempts failed." - } + return {"role": "assistant", "content": "All attempts failed."} mock_llm.complete = Mock(side_effect=mock_complete) agent = Agent(llm=mock_llm, max_tool_iterations=3) diff --git a/tests/test_api.py b/tests/test_api.py index af2e1ac..85ea1c1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,7 @@ """Tests for FastAPI endpoints.""" -import pytest -from unittest.mock import Mock, patch, MagicMock + +from unittest.mock import patch + from fastapi.testclient import TestClient @@ -10,6 +11,7 @@ class TestHealthEndpoint: def test_health_check(self, memory): """Should return healthy status.""" from app import app + client = TestClient(app) response = client.get("/health") @@ -24,6 +26,7 @@ class TestModelsEndpoint: def test_list_models(self, memory): """Should return model list.""" from app import app + client = TestClient(app) response = client.get("/v1/models") @@ -41,6 +44,7 @@ class TestMemoryEndpoints: def test_get_memory_state(self, memory): """Should return full memory state.""" from app import app + client = TestClient(app) response = client.get("/memory/state") @@ -54,6 +58,7 @@ class TestMemoryEndpoints: def test_get_search_results_empty(self, memory): """Should return empty when no search results.""" from app import app + client = TestClient(app) response = client.get("/memory/episodic/search-results") @@ -65,6 +70,7 @@ class TestMemoryEndpoints: def test_get_search_results_with_data(self, memory_with_search_results): """Should return search results when available.""" from app import app + client = TestClient(app) response = client.get("/memory/episodic/search-results") @@ -78,6 +84,7 @@ class TestMemoryEndpoints: def test_clear_session(self, memory_with_search_results): """Should clear session memories.""" from app import app + client = TestClient(app) response = client.post("/memory/clear-session") @@ -96,14 +103,18 @@ class TestChatCompletionsEndpoint: def test_chat_completion_success(self, memory): """Should return chat completion.""" from app import app + # Patch the agent's step method directly with patch("app.agent.step", return_value="Hello! How can I help?"): client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": "Hello"}], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": "Hello"}], + }, + ) assert response.status_code == 200 data = response.json() @@ -113,12 +124,16 @@ class TestChatCompletionsEndpoint: def test_chat_completion_no_user_message(self, memory): """Should return error if no user message.""" from app import app + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "system", "content": "You are helpful"}], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "system", "content": "You are helpful"}], + }, + ) assert response.status_code == 422 detail = response.json()["detail"] @@ -132,18 +147,23 @@ class TestChatCompletionsEndpoint: def test_chat_completion_empty_messages(self, memory): """Should return error for empty messages.""" from app import app + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [], + }, + ) assert response.status_code == 422 def test_chat_completion_invalid_json(self, memory): """Should return error for invalid JSON.""" from app import app + client = TestClient(app) response = client.post( @@ -157,14 +177,18 @@ class TestChatCompletionsEndpoint: def test_chat_completion_streaming(self, memory): """Should support streaming mode.""" from app import app + with patch("app.agent.step", return_value="Streaming response"): client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": "Hello"}], - "stream": True, - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + }, + ) assert response.status_code == 200 assert "text/event-stream" in response.headers["content-type"] @@ -172,17 +196,21 @@ class TestChatCompletionsEndpoint: def test_chat_completion_extracts_last_user_message(self, memory): """Should use last user message.""" from app import app + with patch("app.agent.step", return_value="Response") as mock_step: client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [ - {"role": "user", "content": "First message"}, - {"role": "assistant", "content": "Response"}, - {"role": "user", "content": "Second message"}, - ], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [ + {"role": "user", "content": "First message"}, + {"role": "assistant", "content": "Response"}, + {"role": "user", "content": "Second message"}, + ], + }, + ) assert response.status_code == 200 # Verify the agent received the last user message @@ -191,13 +219,17 @@ class TestChatCompletionsEndpoint: def test_chat_completion_response_format(self, memory): """Should return OpenAI-compatible format.""" from app import app + with patch("app.agent.step", return_value="Test response"): client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": "Test"}], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": "Test"}], + }, + ) data = response.json() assert "id" in data diff --git a/tests/test_api_edge_cases.py b/tests/test_api_edge_cases.py index 0fb3f47..562da55 100644 --- a/tests/test_api_edge_cases.py +++ b/tests/test_api_edge_cases.py @@ -1,7 +1,7 @@ """Edge case tests for FastAPI endpoints.""" -import pytest -import json -from unittest.mock import Mock, patch, MagicMock + +from unittest.mock import Mock, patch + from fastapi.testclient import TestClient @@ -10,43 +10,46 @@ class TestChatCompletionsEdgeCases: def test_very_long_message(self, memory): """Should handle very long user message.""" - from app import app, agent - + from app import agent, app + # Patch the agent's LLM directly mock_llm = Mock() - mock_llm.complete.return_value = { - "role": "assistant", - "content": "Response" - } + mock_llm.complete.return_value = {"role": "assistant", "content": "Response"} agent.llm = mock_llm - + client = TestClient(app) long_message = "x" * 100000 - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": long_message}], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": long_message}], + }, + ) assert response.status_code == 200 def test_unicode_message(self, memory): """Should handle unicode in message.""" - from app import app, agent - + from app import agent, app + mock_llm = Mock() mock_llm.complete.return_value = { "role": "assistant", - "content": "日本語の応答" + "content": "日本語の応答", } agent.llm = mock_llm - + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}], + }, + ) assert response.status_code == 200 content = response.json()["choices"][0]["message"]["content"] @@ -54,22 +57,22 @@ class TestChatCompletionsEdgeCases: def test_special_characters_in_message(self, memory): """Should handle special characters.""" - from app import app, agent - + from app import agent, app + mock_llm = Mock() - mock_llm.complete.return_value = { - "role": "assistant", - "content": "Response" - } + mock_llm.complete.return_value = {"role": "assistant", "content": "Response"} agent.llm = mock_llm - + client = TestClient(app) special_message = 'Test with "quotes" and \\backslash and \n newline' - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": special_message}], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": special_message}], + }, + ) assert response.status_code == 200 @@ -81,12 +84,16 @@ class TestChatCompletionsEdgeCases: mock_llm_class.return_value = mock_llm from app import app + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": ""}], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": ""}], + }, + ) # Empty content should be rejected assert response.status_code == 422 @@ -98,12 +105,16 @@ class TestChatCompletionsEdgeCases: mock_llm_class.return_value = mock_llm from app import app + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": None}], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": None}], + }, + ) assert response.status_code == 422 @@ -114,12 +125,16 @@ class TestChatCompletionsEdgeCases: mock_llm_class.return_value = mock_llm from app import app + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user"}], # No content - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user"}], # No content + }, + ) # May accept or reject depending on validation assert response.status_code in [200, 400, 422] @@ -131,12 +146,16 @@ class TestChatCompletionsEdgeCases: mock_llm_class.return_value = mock_llm from app import app + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"content": "Hello"}], # No role - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"content": "Hello"}], # No role + }, + ) # Should reject or accept depending on validation assert response.status_code in [200, 400, 422] @@ -149,27 +168,28 @@ class TestChatCompletionsEdgeCases: mock_llm_class.return_value = mock_llm from app import app + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "invalid_role", "content": "Hello"}], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "invalid_role", "content": "Hello"}], + }, + ) # Should reject or ignore invalid role assert response.status_code in [200, 400, 422] def test_many_messages(self, memory): """Should handle many messages in conversation.""" - from app import app, agent - + from app import agent, app + mock_llm = Mock() - mock_llm.complete.return_value = { - "role": "assistant", - "content": "Response" - } + mock_llm.complete.return_value = {"role": "assistant", "content": "Response"} agent.llm = mock_llm - + client = TestClient(app) messages = [] @@ -178,10 +198,13 @@ class TestChatCompletionsEdgeCases: messages.append({"role": "assistant", "content": f"Response {i}"}) messages.append({"role": "user", "content": "Final message"}) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": messages, - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": messages, + }, + ) assert response.status_code == 200 @@ -192,15 +215,19 @@ class TestChatCompletionsEdgeCases: mock_llm_class.return_value = mock_llm from app import app + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [ - {"role": "system", "content": "You are helpful"}, - {"role": "system", "content": "Be concise"}, - ], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [ + {"role": "system", "content": "You are helpful"}, + {"role": "system", "content": "Be concise"}, + ], + }, + ) assert response.status_code == 422 @@ -211,14 +238,18 @@ class TestChatCompletionsEdgeCases: mock_llm_class.return_value = mock_llm from app import app + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [ - {"role": "assistant", "content": "Hello"}, - ], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [ + {"role": "assistant", "content": "Hello"}, + ], + }, + ) assert response.status_code == 422 @@ -229,12 +260,16 @@ class TestChatCompletionsEdgeCases: mock_llm_class.return_value = mock_llm from app import app + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": "not an array", - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": "not an array", + }, + ) assert response.status_code == 422 # Pydantic validation error @@ -246,118 +281,128 @@ class TestChatCompletionsEdgeCases: mock_llm_class.return_value = mock_llm from app import app + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": ["not an object", 123, None], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": ["not an object", 123, None], + }, + ) assert response.status_code == 422 # Pydantic validation error def test_extra_fields_in_request(self, memory): """Should ignore extra fields in request.""" - from app import app, agent - + from app import agent, app + mock_llm = Mock() - mock_llm.complete.return_value = { - "role": "assistant", - "content": "Response" - } + mock_llm.complete.return_value = {"role": "assistant", "content": "Response"} agent.llm = mock_llm - + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": "Hello"}], - "extra_field": "should be ignored", - "temperature": 0.7, - "max_tokens": 100, - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": "Hello"}], + "extra_field": "should be ignored", + "temperature": 0.7, + "max_tokens": 100, + }, + ) assert response.status_code == 200 def test_streaming_with_tool_call(self, memory, real_folder): """Should handle streaming with tool execution.""" - from app import app, agent + from app import agent, app from infrastructure.persistence import get_memory - + mem = get_memory() mem.ltm.set_config("download_folder", str(real_folder["downloads"])) - + call_count = [0] + def mock_complete(messages, tools=None): call_count[0] += 1 if call_count[0] == 1: return { "role": "assistant", "content": None, - "tool_calls": [{ - "id": "call_1", - "function": { - "name": "list_folder", - "arguments": '{"folder_type": "download"}' + "tool_calls": [ + { + "id": "call_1", + "function": { + "name": "list_folder", + "arguments": '{"folder_type": "download"}', + }, } - }] + ], } - return { - "role": "assistant", - "content": "Listed the folder." - } - + return {"role": "assistant", "content": "Listed the folder."} + mock_llm = Mock() mock_llm.complete = Mock(side_effect=mock_complete) agent.llm = mock_llm - + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": "List downloads"}], - "stream": True, - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": "List downloads"}], + "stream": True, + }, + ) assert response.status_code == 200 def test_concurrent_requests_simulation(self, memory): """Should handle rapid sequential requests.""" - from app import app, agent - + from app import agent, app + mock_llm = Mock() - mock_llm.complete.return_value = { - "role": "assistant", - "content": "Response" - } + mock_llm.complete.return_value = {"role": "assistant", "content": "Response"} agent.llm = mock_llm - + client = TestClient(app) for i in range(10): - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": f"Request {i}"}], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": f"Request {i}"}], + }, + ) assert response.status_code == 200 def test_llm_returns_json_in_response(self, memory): """Should handle LLM returning JSON in text response.""" - from app import app, agent - + from app import agent, app + mock_llm = Mock() mock_llm.complete.return_value = { "role": "assistant", - "content": '{"result": "some data", "count": 5}' + "content": '{"result": "some data", "count": 5}', } agent.llm = mock_llm - + client = TestClient(app) - response = client.post("/v1/chat/completions", json={ - "model": "agent-media", - "messages": [{"role": "user", "content": "Give me JSON"}], - }) + response = client.post( + "/v1/chat/completions", + json={ + "model": "agent-media", + "messages": [{"role": "user", "content": "Give me JSON"}], + }, + ) assert response.status_code == 200 content = response.json()["choices"][0]["message"]["content"] @@ -425,6 +470,7 @@ class TestMemoryEndpointsEdgeCases: with patch("app.DeepSeekClient") as mock_llm: mock_llm.return_value = Mock() from app import app + client = TestClient(app) # Clear multiple times @@ -459,6 +505,7 @@ class TestHealthEndpointEdgeCases: with patch("app.DeepSeekClient") as mock_llm: mock_llm.return_value = Mock() from app import app + client = TestClient(app) response = client.get("/health") @@ -471,6 +518,7 @@ class TestHealthEndpointEdgeCases: with patch("app.DeepSeekClient") as mock_llm: mock_llm.return_value = Mock() from app import app + client = TestClient(app) response = client.get("/health?extra=param&another=value") @@ -486,6 +534,7 @@ class TestModelsEndpointEdgeCases: with patch("app.DeepSeekClient") as mock_llm: mock_llm.return_value = Mock() from app import app + client = TestClient(app) response = client.get("/v1/models") diff --git a/tests/test_config_critical.py b/tests/test_config_critical.py index 72e1a51..2e75b1d 100644 --- a/tests/test_config_critical.py +++ b/tests/test_config_critical.py @@ -1,9 +1,9 @@ """Critical tests for configuration validation.""" -import pytest -import os -from agent.config import Settings, ConfigurationError +import pytest + +from agent.config import ConfigurationError, Settings class TestConfigValidation: @@ -13,7 +13,7 @@ class TestConfigValidation: """Verify invalid temperature is rejected.""" with pytest.raises(ConfigurationError, match="Temperature"): Settings(temperature=3.0) # > 2.0 - + with pytest.raises(ConfigurationError, match="Temperature"): Settings(temperature=-0.1) # < 0.0 @@ -28,7 +28,7 @@ class TestConfigValidation: """Verify invalid max_iterations is rejected.""" with pytest.raises(ConfigurationError, match="max_tool_iterations"): Settings(max_tool_iterations=0) # < 1 - + with pytest.raises(ConfigurationError, match="max_tool_iterations"): Settings(max_tool_iterations=100) # > 20 @@ -43,7 +43,7 @@ class TestConfigValidation: """Verify invalid timeout is rejected.""" with pytest.raises(ConfigurationError, match="request_timeout"): Settings(request_timeout=0) # < 1 - + with pytest.raises(ConfigurationError, match="request_timeout"): Settings(request_timeout=500) # > 300 @@ -58,7 +58,7 @@ class TestConfigValidation: """Verify invalid DeepSeek URL is rejected.""" with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"): Settings(deepseek_base_url="not-a-url") - + with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"): Settings(deepseek_base_url="ftp://invalid.com") @@ -86,19 +86,17 @@ class TestConfigChecks: def test_is_deepseek_configured_with_key(self): """Verify is_deepseek_configured returns True with API key.""" settings = Settings( - deepseek_api_key="test-key", - deepseek_base_url="https://api.test.com" + deepseek_api_key="test-key", deepseek_base_url="https://api.test.com" ) - + assert settings.is_deepseek_configured() is True def test_is_deepseek_configured_without_key(self): """Verify is_deepseek_configured returns False without API key.""" settings = Settings( - deepseek_api_key="", - deepseek_base_url="https://api.test.com" + deepseek_api_key="", deepseek_base_url="https://api.test.com" ) - + assert settings.is_deepseek_configured() is False def test_is_deepseek_configured_without_url(self): @@ -110,19 +108,15 @@ class TestConfigChecks: def test_is_tmdb_configured_with_key(self): """Verify is_tmdb_configured returns True with API key.""" settings = Settings( - tmdb_api_key="test-key", - tmdb_base_url="https://api.test.com" + tmdb_api_key="test-key", tmdb_base_url="https://api.test.com" ) - + assert settings.is_tmdb_configured() is True def test_is_tmdb_configured_without_key(self): """Verify is_tmdb_configured returns False without API key.""" - settings = Settings( - tmdb_api_key="", - tmdb_base_url="https://api.test.com" - ) - + settings = Settings(tmdb_api_key="", tmdb_base_url="https://api.test.com") + assert settings.is_tmdb_configured() is False @@ -132,25 +126,25 @@ class TestConfigDefaults: def test_default_temperature(self): """Verify default temperature is reasonable.""" settings = Settings() - + assert 0.0 <= settings.temperature <= 2.0 def test_default_max_iterations(self): """Verify default max_iterations is reasonable.""" settings = Settings() - + assert 1 <= settings.max_tool_iterations <= 20 def test_default_timeout(self): """Verify default timeout is reasonable.""" settings = Settings() - + assert 1 <= settings.request_timeout <= 300 def test_default_urls_are_valid(self): """Verify default URLs are valid.""" settings = Settings() - + assert settings.deepseek_base_url.startswith(("http://", "https://")) assert settings.tmdb_base_url.startswith(("http://", "https://")) @@ -161,38 +155,38 @@ class TestConfigEnvironmentVariables: def test_loads_temperature_from_env(self, monkeypatch): """Verify temperature is loaded from environment.""" monkeypatch.setenv("TEMPERATURE", "0.5") - + settings = Settings() - + assert settings.temperature == 0.5 def test_loads_max_iterations_from_env(self, monkeypatch): """Verify max_iterations is loaded from environment.""" monkeypatch.setenv("MAX_TOOL_ITERATIONS", "10") - + settings = Settings() - + assert settings.max_tool_iterations == 10 def test_loads_timeout_from_env(self, monkeypatch): """Verify timeout is loaded from environment.""" monkeypatch.setenv("REQUEST_TIMEOUT", "60") - + settings = Settings() - + assert settings.request_timeout == 60 def test_loads_deepseek_url_from_env(self, monkeypatch): """Verify DeepSeek URL is loaded from environment.""" monkeypatch.setenv("DEEPSEEK_BASE_URL", "https://custom.api.com") - + settings = Settings() - + assert settings.deepseek_base_url == "https://custom.api.com" def test_invalid_env_value_raises_error(self, monkeypatch): """Verify invalid environment value raises error.""" monkeypatch.setenv("TEMPERATURE", "invalid") - + with pytest.raises(ValueError): Settings() diff --git a/tests/test_config_edge_cases.py b/tests/test_config_edge_cases.py index 9be2873..3076e85 100644 --- a/tests/test_config_edge_cases.py +++ b/tests/test_config_edge_cases.py @@ -1,12 +1,14 @@ """Edge case tests for configuration and parameters.""" -import pytest + import os from unittest.mock import patch -from agent.config import Settings, ConfigurationError +import pytest + +from agent.config import ConfigurationError, Settings from agent.parameters import ( - ParameterSchema, REQUIRED_PARAMETERS, + ParameterSchema, format_parameters_for_prompt, get_missing_required_parameters, ) @@ -110,19 +112,27 @@ class TestSettingsEdgeCases: def test_http_url_accepted(self): """Should accept http:// URLs.""" - with patch.dict(os.environ, { - "DEEPSEEK_BASE_URL": "http://localhost:8080", - "TMDB_BASE_URL": "http://localhost:3000", - }, clear=True): + with patch.dict( + os.environ, + { + "DEEPSEEK_BASE_URL": "http://localhost:8080", + "TMDB_BASE_URL": "http://localhost:3000", + }, + clear=True, + ): settings = Settings() assert settings.deepseek_base_url == "http://localhost:8080" def test_https_url_accepted(self): """Should accept https:// URLs.""" - with patch.dict(os.environ, { - "DEEPSEEK_BASE_URL": "https://api.example.com", - "TMDB_BASE_URL": "https://api.example.com", - }, clear=True): + with patch.dict( + os.environ, + { + "DEEPSEEK_BASE_URL": "https://api.example.com", + "TMDB_BASE_URL": "https://api.example.com", + }, + clear=True, + ): settings = Settings() assert settings.deepseek_base_url == "https://api.example.com" diff --git a/tests/test_memory.py b/tests/test_memory.py index 28c0430..ad1ecfe 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -1,18 +1,17 @@ """Tests for the Memory system.""" -import pytest -import json + from datetime import datetime -from pathlib import Path + +import pytest from infrastructure.persistence import ( - Memory, - LongTermMemory, - ShortTermMemory, EpisodicMemory, - init_memory, + LongTermMemory, + Memory, + ShortTermMemory, get_memory, - set_memory, has_memory, + init_memory, ) from infrastructure.persistence.context import _memory_ctx @@ -23,11 +22,12 @@ def is_iso_format(s: str) -> bool: return False try: # Attempt to parse the string as an ISO 8601 timestamp - datetime.fromisoformat(s.replace('Z', '+00:00')) + datetime.fromisoformat(s.replace("Z", "+00:00")) return True except (ValueError, TypeError): return False + class TestLongTermMemory: """Tests for LongTermMemory.""" @@ -116,12 +116,18 @@ class TestLongTermMemory: assert data["config"]["key"] == "value" def test_from_dict(self): - data = {"config": {"download_folder": "/downloads"}, "preferences": {"preferred_quality": "4K"}, "library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]}, "following": []} + data = { + "config": {"download_folder": "/downloads"}, + "preferences": {"preferred_quality": "4K"}, + "library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]}, + "following": [], + } ltm = LongTermMemory.from_dict(data) assert ltm.get_config("download_folder") == "/downloads" assert ltm.preferences["preferred_quality"] == "4K" assert len(ltm.library["movies"]) == 1 + class TestShortTermMemory: """Tests for ShortTermMemory.""" @@ -162,6 +168,7 @@ class TestShortTermMemory: assert stm.conversation_history == [] assert stm.language == "en" + class TestEpisodicMemory: """Tests for EpisodicMemory.""" @@ -192,6 +199,7 @@ class TestEpisodicMemory: assert result is not None assert result["name"] == "Result 2" + class TestMemory: """Tests for the Memory manager.""" @@ -217,11 +225,10 @@ class TestMemory: assert memory.stm.conversation_history == [] assert memory.episodic.recent_errors == [] + class TestMemoryContext: """Tests for memory context functions.""" - - def test_get_memory_not_initialized(self): _memory_ctx.set(None) with pytest.raises(RuntimeError, match="Memory not initialized"): diff --git a/tests/test_memory_edge_cases.py b/tests/test_memory_edge_cases.py index 36f8955..9153333 100644 --- a/tests/test_memory_edge_cases.py +++ b/tests/test_memory_edge_cases.py @@ -1,18 +1,17 @@ """Edge case tests for the Memory system.""" -import pytest + import json import os -from pathlib import Path -from datetime import datetime -from unittest.mock import patch, mock_open + +import pytest from infrastructure.persistence import ( - Memory, - LongTermMemory, - ShortTermMemory, EpisodicMemory, - init_memory, + LongTermMemory, + Memory, + ShortTermMemory, get_memory, + init_memory, set_memory, ) from infrastructure.persistence.context import _memory_ctx @@ -390,7 +389,7 @@ class TestMemoryEdgeCases: def test_init_with_nonexistent_directory(self, temp_dir): """Should create directory if not exists.""" new_dir = temp_dir / "new" / "nested" / "dir" - + # Create parent directories first new_dir.mkdir(parents=True, exist_ok=True) memory = Memory(storage_dir=str(new_dir)) @@ -529,7 +528,6 @@ class TestMemoryContextEdgeCases: def test_context_isolation(self, temp_dir): """Context should be isolated per context.""" - import asyncio from contextvars import copy_context _memory_ctx.set(None) diff --git a/tests/test_prompts_critical.py b/tests/test_prompts_critical.py index 5ebba34..7b9a891 100644 --- a/tests/test_prompts_critical.py +++ b/tests/test_prompts_critical.py @@ -1,10 +1,8 @@ """Critical tests for prompt builder - Tests that would have caught bugs.""" -import pytest -from agent.registry import make_tools from agent.prompts import PromptBuilder -from infrastructure.persistence import get_memory +from agent.registry import make_tools class TestPromptBuilderToolsInjection: @@ -15,20 +13,22 @@ class TestPromptBuilderToolsInjection: tools = make_tools() builder = PromptBuilder(tools) prompt = builder.build_system_prompt() - + # Verify each tool is mentioned for tool_name in tools.keys(): - assert tool_name in prompt, f"Tool {tool_name} not mentioned in system prompt" + assert ( + tool_name in prompt + ), f"Tool {tool_name} not mentioned in system prompt" def test_tools_spec_contains_all_registered_tools(self): """CRITICAL: Verify build_tools_spec() returns all tools.""" tools = make_tools() builder = PromptBuilder(tools) specs = builder.build_tools_spec() - - spec_names = {spec['function']['name'] for spec in specs} + + spec_names = {spec["function"]["name"] for spec in specs} tool_names = set(tools.keys()) - + assert spec_names == tool_names, f"Missing tools: {tool_names - spec_names}" def test_tools_spec_is_not_empty(self): @@ -36,7 +36,7 @@ class TestPromptBuilderToolsInjection: tools = make_tools() builder = PromptBuilder(tools) specs = builder.build_tools_spec() - + assert len(specs) > 0, "Tools spec is empty!" def test_tools_spec_format_matches_openai(self): @@ -44,14 +44,14 @@ class TestPromptBuilderToolsInjection: tools = make_tools() builder = PromptBuilder(tools) specs = builder.build_tools_spec() - + for spec in specs: - assert 'type' in spec - assert spec['type'] == 'function' - assert 'function' in spec - assert 'name' in spec['function'] - assert 'description' in spec['function'] - assert 'parameters' in spec['function'] + assert "type" in spec + assert spec["type"] == "function" + assert "function" in spec + assert "name" in spec["function"] + assert "description" in spec["function"] + assert "parameters" in spec["function"] class TestPromptBuilderMemoryContext: @@ -61,29 +61,29 @@ class TestPromptBuilderMemoryContext: """Verify current topic is included in prompt.""" tools = make_tools() builder = PromptBuilder(tools) - + memory.stm.set_topic("test_topic") prompt = builder.build_system_prompt() - + assert "test_topic" in prompt def test_prompt_includes_extracted_entities(self, memory): """Verify extracted entities are included in prompt.""" tools = make_tools() builder = PromptBuilder(tools) - + memory.stm.set_entity("test_key", "test_value") prompt = builder.build_system_prompt() - + assert "test_key" in prompt def test_prompt_includes_search_results(self, memory_with_search_results): """Verify search results are included in prompt.""" tools = make_tools() builder = PromptBuilder(tools) - + prompt = builder.build_system_prompt() - + assert "Inception" in prompt assert "LAST SEARCH" in prompt @@ -91,15 +91,13 @@ class TestPromptBuilderMemoryContext: """Verify active downloads are included in prompt.""" tools = make_tools() builder = PromptBuilder(tools) - - memory.episodic.add_active_download({ - "task_id": "123", - "name": "Test Movie", - "progress": 50 - }) - + + memory.episodic.add_active_download( + {"task_id": "123", "name": "Test Movie", "progress": 50} + ) + prompt = builder.build_system_prompt() - + assert "ACTIVE DOWNLOADS" in prompt assert "Test Movie" in prompt @@ -107,33 +105,33 @@ class TestPromptBuilderMemoryContext: """Verify recent errors are included in prompt.""" tools = make_tools() builder = PromptBuilder(tools) - + memory.episodic.add_error("test_action", "test error message") - + prompt = builder.build_system_prompt() - + assert "RECENT ERRORS" in prompt or "error" in prompt.lower() def test_prompt_includes_configuration(self, memory): """Verify configuration is included in prompt.""" tools = make_tools() builder = PromptBuilder(tools) - + memory.ltm.set_config("download_folder", "/test/downloads") - + prompt = builder.build_system_prompt() - + assert "CONFIGURATION" in prompt or "download_folder" in prompt def test_prompt_includes_language(self, memory): """Verify language is included in prompt.""" tools = make_tools() builder = PromptBuilder(tools) - + memory.stm.set_language("fr") - + prompt = builder.build_system_prompt() - + assert "fr" in prompt or "LANGUAGE" in prompt @@ -145,7 +143,7 @@ class TestPromptBuilderStructure: tools = make_tools() builder = PromptBuilder(tools) prompt = builder.build_system_prompt() - + assert len(prompt) > 0 assert prompt.strip() != "" @@ -154,7 +152,7 @@ class TestPromptBuilderStructure: tools = make_tools() builder = PromptBuilder(tools) prompt = builder.build_system_prompt() - + assert "assistant" in prompt.lower() or "help" in prompt.lower() def test_system_prompt_includes_rules(self): @@ -162,7 +160,7 @@ class TestPromptBuilderStructure: tools = make_tools() builder = PromptBuilder(tools) prompt = builder.build_system_prompt() - + assert "RULES" in prompt or "IMPORTANT" in prompt def test_system_prompt_includes_examples(self): @@ -170,16 +168,16 @@ class TestPromptBuilderStructure: tools = make_tools() builder = PromptBuilder(tools) prompt = builder.build_system_prompt() - + assert "EXAMPLES" in prompt or "example" in prompt.lower() def test_tools_description_format(self): """Verify tools are properly formatted in description.""" tools = make_tools() builder = PromptBuilder(tools) - + description = builder._format_tools_description() - + # Should have tool names and descriptions for tool_name, tool in tools.items(): assert tool_name in description @@ -190,9 +188,9 @@ class TestPromptBuilderStructure: """Verify episodic context is properly formatted.""" tools = make_tools() builder = PromptBuilder(tools) - + context = builder._format_episodic_context() - + assert "LAST SEARCH" in context assert "Inception" in context @@ -200,12 +198,12 @@ class TestPromptBuilderStructure: """Verify STM context is properly formatted.""" tools = make_tools() builder = PromptBuilder(tools) - + memory.stm.set_topic("test_topic") memory.stm.set_entity("key", "value") - + context = builder._format_stm_context() - + assert "TOPIC" in context or "test_topic" in context assert "ENTITIES" in context or "key" in context @@ -213,11 +211,11 @@ class TestPromptBuilderStructure: """Verify config context is properly formatted.""" tools = make_tools() builder = PromptBuilder(tools) - + memory.ltm.set_config("test_key", "test_value") - + context = builder._format_config_context() - + assert "CONFIGURATION" in context assert "test_key" in context @@ -229,10 +227,10 @@ class TestPromptBuilderEdgeCases: """Verify prompt works with empty memory.""" tools = make_tools() builder = PromptBuilder(tools) - + # Memory is empty prompt = builder.build_system_prompt() - + # Should still have base content assert len(prompt) > 0 assert "assistant" in prompt.lower() @@ -240,18 +238,18 @@ class TestPromptBuilderEdgeCases: def test_prompt_with_empty_tools(self): """Verify prompt handles empty tools dict.""" builder = PromptBuilder({}) - + prompt = builder.build_system_prompt() - + # Should still generate a prompt assert len(prompt) > 0 def test_tools_spec_with_empty_tools(self): """Verify tools spec handles empty tools dict.""" builder = PromptBuilder({}) - + specs = builder.build_tools_spec() - + assert isinstance(specs, list) assert len(specs) == 0 @@ -259,11 +257,11 @@ class TestPromptBuilderEdgeCases: """Verify prompt handles unicode in memory.""" tools = make_tools() builder = PromptBuilder(tools) - + memory.stm.set_entity("movie", "Amélie 🎬") - + prompt = builder.build_system_prompt() - + assert "Amélie" in prompt assert "🎬" in prompt @@ -271,13 +269,13 @@ class TestPromptBuilderEdgeCases: """Verify prompt handles many search results.""" tools = make_tools() builder = PromptBuilder(tools) - + # Add many results results = [{"name": f"Movie {i}", "seeders": i} for i in range(20)] memory.episodic.store_search_results("test", results, "torrent") - + prompt = builder.build_system_prompt() - + # Should include some results but not all (to avoid huge prompts) assert "Movie 0" in prompt or "Movie 1" in prompt # Should indicate there are more diff --git a/tests/test_prompts_edge_cases.py b/tests/test_prompts_edge_cases.py index c1132bc..e57e4c1 100644 --- a/tests/test_prompts_edge_cases.py +++ b/tests/test_prompts_edge_cases.py @@ -1,10 +1,8 @@ """Edge case tests for PromptBuilder.""" -import pytest -import json + from agent.prompts import PromptBuilder from agent.registry import make_tools -from infrastructure.persistence import get_memory class TestPromptBuilderEdgeCases: @@ -93,11 +91,13 @@ class TestPromptBuilderEdgeCases: def test_prompt_with_many_active_downloads(self, memory): """Should limit displayed active downloads.""" for i in range(20): - memory.episodic.add_active_download({ - "task_id": str(i), - "name": f"Download {i}", - "progress": i * 5, - }) + memory.episodic.add_active_download( + { + "task_id": str(i), + "name": f"Download {i}", + "progress": i * 5, + } + ) tools = make_tools() builder = PromptBuilder(tools) @@ -136,12 +136,15 @@ class TestPromptBuilderEdgeCases: def test_prompt_with_complex_workflow(self, memory): """Should handle complex workflow state.""" - memory.stm.start_workflow("download", { - "title": "Test Movie", - "year": 2024, - "quality": "1080p", - "nested": {"deep": {"value": "test"}}, - }) + memory.stm.start_workflow( + "download", + { + "title": "Test Movie", + "year": 2024, + "quality": "1080p", + "nested": {"deep": {"value": "test"}}, + }, + ) memory.stm.update_workflow_stage("searching_torrents") tools = make_tools() @@ -313,11 +316,14 @@ class TestFormatEpisodicContextEdgeCases: def test_format_with_search_results_none_names(self, memory): """Should handle results with None names.""" - memory.episodic.store_search_results("test", [ - {"name": None}, - {"title": None}, - {}, - ]) + memory.episodic.store_search_results( + "test", + [ + {"name": None}, + {"title": None}, + {}, + ], + ) tools = make_tools() builder = PromptBuilder(tools) diff --git a/tests/test_registry_critical.py b/tests/test_registry_critical.py index 767e25f..5b4e176 100644 --- a/tests/test_registry_critical.py +++ b/tests/test_registry_critical.py @@ -1,10 +1,11 @@ """Critical tests for tool registry - Tests that would have caught bugs.""" -import pytest import inspect -from agent.registry import make_tools, _create_tool_from_function, Tool +import pytest + from agent.prompts import PromptBuilder +from agent.registry import Tool, _create_tool_from_function, make_tools class TestToolSpecFormat: @@ -15,54 +16,59 @@ class TestToolSpecFormat: tools = make_tools() builder = PromptBuilder(tools) specs = builder.build_tools_spec() - + # Verify structure assert isinstance(specs, list), "Tool specs must be a list" assert len(specs) > 0, "Tool specs list is empty" - + for spec in specs: # OpenAI format requires these fields - assert spec['type'] == 'function', f"Tool type must be 'function', got {spec.get('type')}" - assert 'function' in spec, "Tool spec missing 'function' key" - - func = spec['function'] - assert 'name' in func, "Function missing 'name'" - assert 'description' in func, "Function missing 'description'" - assert 'parameters' in func, "Function missing 'parameters'" - - params = func['parameters'] - assert params['type'] == 'object', "Parameters type must be 'object'" - assert 'properties' in params, "Parameters missing 'properties'" - assert 'required' in params, "Parameters missing 'required'" - assert isinstance(params['required'], list), "Required must be a list" + assert ( + spec["type"] == "function" + ), f"Tool type must be 'function', got {spec.get('type')}" + assert "function" in spec, "Tool spec missing 'function' key" + + func = spec["function"] + assert "name" in func, "Function missing 'name'" + assert "description" in func, "Function missing 'description'" + assert "parameters" in func, "Function missing 'parameters'" + + params = func["parameters"] + assert params["type"] == "object", "Parameters type must be 'object'" + assert "properties" in params, "Parameters missing 'properties'" + assert "required" in params, "Parameters missing 'required'" + assert isinstance(params["required"], list), "Required must be a list" def test_tool_parameters_match_function_signature(self): """CRITICAL: Verify generated parameters match function signature.""" + def test_func(name: str, age: int, active: bool = True): """Test function with typed parameters.""" return {"status": "ok"} - + tool = _create_tool_from_function(test_func) - + # Verify types are correctly mapped - assert tool.parameters['properties']['name']['type'] == 'string' - assert tool.parameters['properties']['age']['type'] == 'integer' - assert tool.parameters['properties']['active']['type'] == 'boolean' - + assert tool.parameters["properties"]["name"]["type"] == "string" + assert tool.parameters["properties"]["age"]["type"] == "integer" + assert tool.parameters["properties"]["active"]["type"] == "boolean" + # Verify required vs optional - assert 'name' in tool.parameters['required'], "name should be required" - assert 'age' in tool.parameters['required'], "age should be required" - assert 'active' not in tool.parameters['required'], "active has default, should not be required" + assert "name" in tool.parameters["required"], "name should be required" + assert "age" in tool.parameters["required"], "age should be required" + assert ( + "active" not in tool.parameters["required"] + ), "active has default, should not be required" def test_all_registered_tools_are_callable(self): """CRITICAL: Verify all registered tools are actually callable.""" tools = make_tools() - + assert len(tools) > 0, "No tools registered" - + for name, tool in tools.items(): assert callable(tool.func), f"Tool {name} is not callable" - + # Verify function has valid signature try: sig = inspect.signature(tool.func) @@ -75,38 +81,40 @@ class TestToolSpecFormat: tools = make_tools() builder = PromptBuilder(tools) specs = builder.build_tools_spec() - - spec_names = {spec['function']['name'] for spec in specs} + + spec_names = {spec["function"]["name"] for spec in specs} tool_names = set(tools.keys()) - + missing = tool_names - spec_names extra = spec_names - tool_names - + assert not missing, f"Tools missing from specs: {missing}" assert not extra, f"Extra tools in specs: {extra}" assert spec_names == tool_names, "Tool specs don't match registered tools" def test_tool_description_extracted_from_docstring(self): """Verify tool description is extracted from function docstring.""" + def test_func(param: str): """This is the description. - + More details here. """ return {} - + tool = _create_tool_from_function(test_func) - + assert tool.description == "This is the description." assert "More details" not in tool.description def test_tool_without_docstring_uses_function_name(self): """Verify tool without docstring uses function name as description.""" + def test_func_no_doc(param: str): return {} - + tool = _create_tool_from_function(test_func_no_doc) - + assert tool.description == "test_func_no_doc" def test_tool_parameters_have_descriptions(self): @@ -114,25 +122,27 @@ class TestToolSpecFormat: tools = make_tools() builder = PromptBuilder(tools) specs = builder.build_tools_spec() - + for spec in specs: - params = spec['function']['parameters'] - properties = params.get('properties', {}) - + params = spec["function"]["parameters"] + properties = params.get("properties", {}) + for param_name, param_spec in properties.items(): - assert 'description' in param_spec, \ - f"Parameter {param_name} in {spec['function']['name']} missing description" + assert ( + "description" in param_spec + ), f"Parameter {param_name} in {spec['function']['name']} missing description" def test_required_parameters_are_marked_correctly(self): """Verify required parameters are correctly identified.""" + def func_with_optional(required: str, optional: int = 5): return {} - + tool = _create_tool_from_function(func_with_optional) - - assert 'required' in tool.parameters['required'] - assert 'optional' not in tool.parameters['required'] - assert len(tool.parameters['required']) == 1 + + assert "required" in tool.parameters["required"] + assert "optional" not in tool.parameters["required"] + assert len(tool.parameters["required"]) == 1 class TestToolRegistry: @@ -141,28 +151,28 @@ class TestToolRegistry: def test_make_tools_returns_dict(self): """Verify make_tools returns a dictionary.""" tools = make_tools() - + assert isinstance(tools, dict) assert len(tools) > 0 def test_all_tools_have_unique_names(self): """Verify all tool names are unique.""" tools = make_tools() - + names = [tool.name for tool in tools.values()] assert len(names) == len(set(names)), "Duplicate tool names found" def test_tool_names_match_dict_keys(self): """Verify tool names match their dictionary keys.""" tools = make_tools() - + for key, tool in tools.items(): assert key == tool.name, f"Key {key} doesn't match tool name {tool.name}" def test_expected_tools_are_registered(self): """Verify all expected tools are registered.""" tools = make_tools() - + expected_tools = [ "set_path_for_folder", "list_folder", @@ -173,14 +183,14 @@ class TestToolRegistry: "get_torrent_by_index", "set_language", ] - + for expected in expected_tools: assert expected in tools, f"Expected tool {expected} not registered" def test_tool_functions_return_dict(self): """Verify all tool functions return dictionaries.""" tools = make_tools() - + # Test with minimal valid arguments # Note: This is a smoke test, not full integration for name, tool in tools.items(): @@ -195,16 +205,17 @@ class TestToolDataclass: def test_tool_creation(self): """Verify Tool can be created with all fields.""" + def dummy_func(): return {} - + tool = Tool( name="test_tool", description="Test description", func=dummy_func, - parameters={"type": "object", "properties": {}, "required": []} + parameters={"type": "object", "properties": {}, "required": []}, ) - + assert tool.name == "test_tool" assert tool.description == "Test description" assert tool.func == dummy_func @@ -212,12 +223,13 @@ class TestToolDataclass: def test_tool_parameters_structure(self): """Verify Tool parameters have correct structure.""" + def dummy_func(arg: str): return {} - + tool = _create_tool_from_function(dummy_func) - - assert 'type' in tool.parameters - assert 'properties' in tool.parameters - assert 'required' in tool.parameters - assert tool.parameters['type'] == 'object' + + assert "type" in tool.parameters + assert "properties" in tool.parameters + assert "required" in tool.parameters + assert tool.parameters["type"] == "object" diff --git a/tests/test_registry_edge_cases.py b/tests/test_registry_edge_cases.py index 7b4ce65..57a8365 100644 --- a/tests/test_registry_edge_cases.py +++ b/tests/test_registry_edge_cases.py @@ -1,6 +1,7 @@ """Edge case tests for tool registry.""" + + import pytest -from unittest.mock import Mock from agent.registry import Tool, make_tools @@ -182,7 +183,9 @@ class TestMakeToolsEdgeCases: params = tool.parameters if "required" in params and "properties" in params: for req in params["required"]: - assert req in params["properties"], f"Required param {req} not in properties for {tool.name}" + assert ( + req in params["properties"] + ), f"Required param {req} not in properties for {tool.name}" def test_make_tools_descriptions_not_empty(self, memory): """Should have non-empty descriptions.""" @@ -233,7 +236,9 @@ class TestMakeToolsEdgeCases: if "properties" in tool.parameters: for prop_name, prop_schema in tool.parameters["properties"].items(): if "type" in prop_schema: - assert prop_schema["type"] in valid_types, f"Invalid type for {tool.name}.{prop_name}" + assert ( + prop_schema["type"] in valid_types + ), f"Invalid type for {tool.name}.{prop_name}" def test_make_tools_enum_values(self, memory): """Should have valid enum values.""" diff --git a/tests/test_repositories.py b/tests/test_repositories.py index 988bb8f..fa5950f 100644 --- a/tests/test_repositories.py +++ b/tests/test_repositories.py @@ -1,19 +1,18 @@ """Tests for JSON repositories.""" -import pytest -from datetime import datetime -from infrastructure.persistence.json import ( - JsonMovieRepository, - JsonTVShowRepository, - JsonSubtitleRepository, -) + from domain.movies.entities import Movie -from domain.movies.value_objects import MovieTitle, ReleaseYear, Quality -from domain.tv_shows.entities import TVShow -from domain.tv_shows.value_objects import ShowStatus +from domain.movies.value_objects import MovieTitle, Quality, ReleaseYear +from domain.shared.value_objects import FilePath, FileSize, ImdbId from domain.subtitles.entities import Subtitle from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset -from domain.shared.value_objects import ImdbId, FilePath, FileSize +from domain.tv_shows.entities import TVShow +from domain.tv_shows.value_objects import ShowStatus +from infrastructure.persistence.json import ( + JsonMovieRepository, + JsonSubtitleRepository, + JsonTVShowRepository, +) class TestJsonMovieRepository: @@ -224,7 +223,9 @@ class TestJsonTVShowRepository: """Should preserve show status.""" repo = JsonTVShowRepository() - for i, status in enumerate([ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]): + for i, status in enumerate( + [ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN] + ): show = TVShow( imdb_id=ImdbId(f"tt{i+1000000:07d}"), title=f"Show {status.value}", @@ -294,18 +295,22 @@ class TestJsonSubtitleRepository: def test_find_by_media_with_language_filter(self, memory): """Should filter by language.""" repo = JsonSubtitleRepository() - repo.save(Subtitle( - media_imdb_id=ImdbId("tt1375666"), - language=Language.ENGLISH, - format=SubtitleFormat.SRT, - file_path=FilePath("/subs/en.srt"), - )) - repo.save(Subtitle( - media_imdb_id=ImdbId("tt1375666"), - language=Language.FRENCH, - format=SubtitleFormat.SRT, - file_path=FilePath("/subs/fr.srt"), - )) + repo.save( + Subtitle( + media_imdb_id=ImdbId("tt1375666"), + language=Language.ENGLISH, + format=SubtitleFormat.SRT, + file_path=FilePath("/subs/en.srt"), + ) + ) + repo.save( + Subtitle( + media_imdb_id=ImdbId("tt1375666"), + language=Language.FRENCH, + format=SubtitleFormat.SRT, + file_path=FilePath("/subs/fr.srt"), + ) + ) results = repo.find_by_media(ImdbId("tt1375666"), language=Language.FRENCH) @@ -315,22 +320,26 @@ class TestJsonSubtitleRepository: def test_find_by_media_with_episode_filter(self, memory): """Should filter by season/episode.""" repo = JsonSubtitleRepository() - repo.save(Subtitle( - media_imdb_id=ImdbId("tt0944947"), - language=Language.ENGLISH, - format=SubtitleFormat.SRT, - file_path=FilePath("/subs/s01e01.srt"), - season_number=1, - episode_number=1, - )) - repo.save(Subtitle( - media_imdb_id=ImdbId("tt0944947"), - language=Language.ENGLISH, - format=SubtitleFormat.SRT, - file_path=FilePath("/subs/s01e02.srt"), - season_number=1, - episode_number=2, - )) + repo.save( + Subtitle( + media_imdb_id=ImdbId("tt0944947"), + language=Language.ENGLISH, + format=SubtitleFormat.SRT, + file_path=FilePath("/subs/s01e01.srt"), + season_number=1, + episode_number=1, + ) + ) + repo.save( + Subtitle( + media_imdb_id=ImdbId("tt0944947"), + language=Language.ENGLISH, + format=SubtitleFormat.SRT, + file_path=FilePath("/subs/s01e02.srt"), + season_number=1, + episode_number=2, + ) + ) results = repo.find_by_media( ImdbId("tt0944947"), diff --git a/tests/test_tools_api.py b/tests/test_tools_api.py index 6b40287..137d2e1 100644 --- a/tests/test_tools_api.py +++ b/tests/test_tools_api.py @@ -21,24 +21,30 @@ def create_mock_response(status_code, json_data=None, text=None): class TestFindMediaImdbId: """Tests for find_media_imdb_id tool.""" - @patch('infrastructure.api.tmdb.client.requests.get') + @patch("infrastructure.api.tmdb.client.requests.get") def test_success(self, mock_get, memory): """Should return movie info on success.""" + # Mock HTTP responses def mock_get_side_effect(url, **kwargs): if "search" in url: - return create_mock_response(200, json_data={ - "results": [{ - "id": 27205, - "title": "Inception", - "release_date": "2010-07-16", - "overview": "A thief...", - "media_type": "movie" - }] - }) + return create_mock_response( + 200, + json_data={ + "results": [ + { + "id": 27205, + "title": "Inception", + "release_date": "2010-07-16", + "overview": "A thief...", + "media_type": "movie", + } + ] + }, + ) elif "external_ids" in url: return create_mock_response(200, json_data={"imdb_id": "tt1375666"}) - + mock_get.side_effect = mock_get_side_effect result = api_tools.find_media_imdb_id("Inception") @@ -46,26 +52,32 @@ class TestFindMediaImdbId: assert result["status"] == "ok" assert result["imdb_id"] == "tt1375666" assert result["title"] == "Inception" - + # Verify HTTP calls assert mock_get.call_count == 2 - @patch('infrastructure.api.tmdb.client.requests.get') + @patch("infrastructure.api.tmdb.client.requests.get") def test_stores_in_stm(self, mock_get, memory): """Should store result in STM on success.""" + def mock_get_side_effect(url, **kwargs): if "search" in url: - return create_mock_response(200, json_data={ - "results": [{ - "id": 27205, - "title": "Inception", - "release_date": "2010-07-16", - "media_type": "movie" - }] - }) + return create_mock_response( + 200, + json_data={ + "results": [ + { + "id": 27205, + "title": "Inception", + "release_date": "2010-07-16", + "media_type": "movie", + } + ] + }, + ) elif "external_ids" in url: return create_mock_response(200, json_data={"imdb_id": "tt1375666"}) - + mock_get.side_effect = mock_get_side_effect api_tools.find_media_imdb_id("Inception") @@ -76,7 +88,7 @@ class TestFindMediaImdbId: assert entity["title"] == "Inception" assert mem.stm.current_topic == "searching_media" - @patch('infrastructure.api.tmdb.client.requests.get') + @patch("infrastructure.api.tmdb.client.requests.get") def test_not_found(self, mock_get, memory): """Should return error when not found.""" mock_get.return_value = create_mock_response(200, json_data={"results": []}) @@ -86,7 +98,7 @@ class TestFindMediaImdbId: assert result["status"] == "error" assert result["error"] == "not_found" - @patch('infrastructure.api.tmdb.client.requests.get') + @patch("infrastructure.api.tmdb.client.requests.get") def test_does_not_store_on_error(self, mock_get, memory): """Should not store in STM on error.""" mock_get.return_value = create_mock_response(200, json_data={"results": []}) @@ -100,49 +112,57 @@ class TestFindMediaImdbId: class TestFindTorrent: """Tests for find_torrent tool.""" - @patch('infrastructure.api.knaben.client.requests.post') + @patch("infrastructure.api.knaben.client.requests.post") def test_success(self, mock_post, memory): """Should return torrents on success.""" - mock_post.return_value = create_mock_response(200, json_data={ - "hits": [ - { - "title": "Torrent 1", - "seeders": 100, - "leechers": 10, - "magnetUrl": "magnet:?xt=...", - "size": "2.5 GB" - }, - { - "title": "Torrent 2", - "seeders": 50, - "leechers": 5, - "magnetUrl": "magnet:?xt=...", - "size": "1.8 GB" - } - ] - }) + mock_post.return_value = create_mock_response( + 200, + json_data={ + "hits": [ + { + "title": "Torrent 1", + "seeders": 100, + "leechers": 10, + "magnetUrl": "magnet:?xt=...", + "size": "2.5 GB", + }, + { + "title": "Torrent 2", + "seeders": 50, + "leechers": 5, + "magnetUrl": "magnet:?xt=...", + "size": "1.8 GB", + }, + ] + }, + ) result = api_tools.find_torrent("Inception 1080p") assert result["status"] == "ok" assert len(result["torrents"]) == 2 - - # Verify HTTP payload - payload = mock_post.call_args[1]['json'] - assert payload['query'] == "Inception 1080p" - @patch('infrastructure.api.knaben.client.requests.post') + # Verify HTTP payload + payload = mock_post.call_args[1]["json"] + assert payload["query"] == "Inception 1080p" + + @patch("infrastructure.api.knaben.client.requests.post") def test_stores_in_episodic(self, mock_post, memory): """Should store results in episodic memory.""" - mock_post.return_value = create_mock_response(200, json_data={ - "hits": [{ - "title": "Torrent 1", - "seeders": 100, - "leechers": 10, - "magnetUrl": "magnet:?xt=...", - "size": "2.5 GB" - }] - }) + mock_post.return_value = create_mock_response( + 200, + json_data={ + "hits": [ + { + "title": "Torrent 1", + "seeders": 100, + "leechers": 10, + "magnetUrl": "magnet:?xt=...", + "size": "2.5 GB", + } + ] + }, + ) api_tools.find_torrent("Inception") @@ -151,16 +171,37 @@ class TestFindTorrent: assert mem.episodic.last_search_results["query"] == "Inception" assert mem.stm.current_topic == "selecting_torrent" - @patch('infrastructure.api.knaben.client.requests.post') + @patch("infrastructure.api.knaben.client.requests.post") def test_results_have_indexes(self, mock_post, memory): """Should add indexes to results.""" - mock_post.return_value = create_mock_response(200, json_data={ - "hits": [ - {"title": "Torrent 1", "seeders": 100, "leechers": 10, "magnetUrl": "magnet:?xt=1", "size": "1GB"}, - {"title": "Torrent 2", "seeders": 50, "leechers": 5, "magnetUrl": "magnet:?xt=2", "size": "2GB"}, - {"title": "Torrent 3", "seeders": 25, "leechers": 2, "magnetUrl": "magnet:?xt=3", "size": "3GB"} - ] - }) + mock_post.return_value = create_mock_response( + 200, + json_data={ + "hits": [ + { + "title": "Torrent 1", + "seeders": 100, + "leechers": 10, + "magnetUrl": "magnet:?xt=1", + "size": "1GB", + }, + { + "title": "Torrent 2", + "seeders": 50, + "leechers": 5, + "magnetUrl": "magnet:?xt=2", + "size": "2GB", + }, + { + "title": "Torrent 3", + "seeders": 25, + "leechers": 2, + "magnetUrl": "magnet:?xt=3", + "size": "3GB", + }, + ] + }, + ) api_tools.find_torrent("Test") @@ -170,7 +211,7 @@ class TestFindTorrent: assert results[1]["index"] == 2 assert results[2]["index"] == 3 - @patch('infrastructure.api.knaben.client.requests.post') + @patch("infrastructure.api.knaben.client.requests.post") def test_not_found(self, mock_post, memory): """Should return error when no torrents found.""" mock_post.return_value = create_mock_response(200, json_data={"hits": []}) @@ -236,16 +277,16 @@ class TestGetTorrentByIndex: class TestAddTorrentToQbittorrent: """Tests for add_torrent_to_qbittorrent tool. - + Note: These tests mock the qBittorrent client because: 1. The client requires authentication/session management 2. We want to test the tool's logic (memory updates, workflow management) 3. The client itself is tested separately in infrastructure tests - + This is acceptable mocking because we're testing the TOOL logic, not the client. """ - @patch('agent.tools.api.qbittorrent_client') + @patch("agent.tools.api.qbittorrent_client") def test_success(self, mock_client, memory): """Should add torrent successfully and update memory.""" mock_client.add_torrent.return_value = True @@ -257,7 +298,7 @@ class TestAddTorrentToQbittorrent: # Verify client was called correctly mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123") - @patch('agent.tools.api.qbittorrent_client') + @patch("agent.tools.api.qbittorrent_client") def test_adds_to_active_downloads(self, mock_client, memory_with_search_results): """Should add to active downloads on success.""" mock_client.add_torrent.return_value = True @@ -267,9 +308,12 @@ class TestAddTorrentToQbittorrent: # Test memory update logic mem = get_memory() assert len(mem.episodic.active_downloads) == 1 - assert mem.episodic.active_downloads[0]["name"] == "Inception.2010.1080p.BluRay.x264" + assert ( + mem.episodic.active_downloads[0]["name"] + == "Inception.2010.1080p.BluRay.x264" + ) - @patch('agent.tools.api.qbittorrent_client') + @patch("agent.tools.api.qbittorrent_client") def test_sets_topic_and_ends_workflow(self, mock_client, memory): """Should set topic and end workflow.""" mock_client.add_torrent.return_value = True @@ -282,10 +326,11 @@ class TestAddTorrentToQbittorrent: assert mem.stm.current_topic == "downloading" assert mem.stm.current_workflow is None - @patch('agent.tools.api.qbittorrent_client') + @patch("agent.tools.api.qbittorrent_client") def test_error_handling(self, mock_client, memory): """Should handle client errors correctly.""" from infrastructure.api.qbittorrent.exceptions import QBittorrentAPIError + mock_client.add_torrent.side_effect = QBittorrentAPIError("Connection failed") result = api_tools.add_torrent_to_qbittorrent("magnet:?xt=...") @@ -296,7 +341,7 @@ class TestAddTorrentToQbittorrent: class TestAddTorrentByIndex: """Tests for add_torrent_by_index tool. - + These tests verify the tool's logic: - Getting torrent from memory by index - Extracting magnet link @@ -304,7 +349,7 @@ class TestAddTorrentByIndex: - Error handling for edge cases """ - @patch('agent.tools.api.qbittorrent_client') + @patch("agent.tools.api.qbittorrent_client") def test_success(self, mock_client, memory_with_search_results): """Should get torrent by index and add it.""" mock_client.add_torrent.return_value = True @@ -317,7 +362,7 @@ class TestAddTorrentByIndex: # Verify correct magnet was extracted and used mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123") - @patch('agent.tools.api.qbittorrent_client') + @patch("agent.tools.api.qbittorrent_client") def test_uses_correct_magnet(self, mock_client, memory_with_search_results): """Should extract correct magnet from index.""" mock_client.add_torrent.return_value = True diff --git a/tests/test_tools_edge_cases.py b/tests/test_tools_edge_cases.py index f682f03..23fb50e 100644 --- a/tests/test_tools_edge_cases.py +++ b/tests/test_tools_edge_cases.py @@ -1,7 +1,8 @@ """Edge case tests for tools.""" + +from unittest.mock import Mock, patch + import pytest -from unittest.mock import Mock, patch, MagicMock -from pathlib import Path from agent.tools import api as api_tools from agent.tools import filesystem as fs_tools @@ -15,7 +16,10 @@ class TestFindTorrentEdgeCases: def test_empty_query(self, mock_use_case_class, memory): """Should handle empty query.""" mock_response = Mock() - mock_response.to_dict.return_value = {"status": "error", "error": "invalid_query"} + mock_response.to_dict.return_value = { + "status": "error", + "error": "invalid_query", + } mock_use_case = Mock() mock_use_case.execute.return_value = mock_response mock_use_case_class.return_value = mock_use_case @@ -28,7 +32,11 @@ class TestFindTorrentEdgeCases: def test_very_long_query(self, mock_use_case_class, memory): """Should handle very long query.""" mock_response = Mock() - mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0} + mock_response.to_dict.return_value = { + "status": "ok", + "torrents": [], + "count": 0, + } mock_use_case = Mock() mock_use_case.execute.return_value = mock_response mock_use_case_class.return_value = mock_use_case @@ -43,7 +51,11 @@ class TestFindTorrentEdgeCases: def test_special_characters_in_query(self, mock_use_case_class, memory): """Should handle special characters in query.""" mock_response = Mock() - mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0} + mock_response.to_dict.return_value = { + "status": "ok", + "torrents": [], + "count": 0, + } mock_use_case = Mock() mock_use_case.execute.return_value = mock_response mock_use_case_class.return_value = mock_use_case @@ -57,7 +69,11 @@ class TestFindTorrentEdgeCases: def test_unicode_query(self, mock_use_case_class, memory): """Should handle unicode in query.""" mock_response = Mock() - mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0} + mock_response.to_dict.return_value = { + "status": "ok", + "torrents": [], + "count": 0, + } mock_use_case = Mock() mock_use_case.execute.return_value = mock_response mock_use_case_class.return_value = mock_use_case @@ -161,7 +177,10 @@ class TestAddTorrentEdgeCases: def test_empty_magnet_link(self, mock_use_case_class, memory): """Should handle empty magnet link.""" mock_response = Mock() - mock_response.to_dict.return_value = {"status": "error", "error": "empty_magnet"} + mock_response.to_dict.return_value = { + "status": "error", + "error": "empty_magnet", + } mock_use_case = Mock() mock_use_case.execute.return_value = mock_response mock_use_case_class.return_value = mock_use_case @@ -326,7 +345,10 @@ class TestFilesystemEdgeCases: for attempt in attempts: result = fs_tools.list_folder("download", attempt) # Should either be forbidden or not found - assert result.get("error") in ["forbidden", "not_found", None] or result.get("status") == "ok" + assert ( + result.get("error") in ["forbidden", "not_found", None] + or result.get("status") == "ok" + ) def test_path_with_null_byte(self, memory, real_folder): """Should block null byte injection."""