Formatting

This commit is contained in:
2025-12-07 03:33:51 +01:00
parent a923a760ef
commit 4eae1d6d58
24 changed files with 1003 additions and 833 deletions

View File

@@ -1,7 +1,8 @@
"""Main agent for media library management.""" """Main agent for media library management."""
import json import json
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any
from infrastructure.persistence import get_memory from infrastructure.persistence import get_memory
@@ -15,157 +16,156 @@ logger = logging.getLogger(__name__)
class Agent: class Agent:
""" """
AI agent for media library management. AI agent for media library management.
Uses OpenAI-compatible tool calling API. Uses OpenAI-compatible tool calling API.
""" """
def __init__(self, llm, max_tool_iterations: int = 5): def __init__(self, llm, max_tool_iterations: int = 5):
""" """
Initialize the agent. Initialize the agent.
Args: Args:
llm: LLM client with complete() method llm: LLM client with complete() method
max_tool_iterations: Maximum number of tool execution iterations max_tool_iterations: Maximum number of tool execution iterations
""" """
self.llm = llm self.llm = llm
self.tools: Dict[str, Tool] = make_tools() self.tools: dict[str, Tool] = make_tools()
self.prompt_builder = PromptBuilder(self.tools) self.prompt_builder = PromptBuilder(self.tools)
self.max_tool_iterations = max_tool_iterations self.max_tool_iterations = max_tool_iterations
def step(self, user_input: str) -> str: def step(self, user_input: str) -> str:
""" """
Execute one agent step with the user input. Execute one agent step with the user input.
This method: This method:
1. Adds user message to memory 1. Adds user message to memory
2. Builds prompt with history and context 2. Builds prompt with history and context
3. Calls LLM, executing tools as needed 3. Calls LLM, executing tools as needed
4. Returns final response 4. Returns final response
Args: Args:
user_input: User's message user_input: User's message
Returns: Returns:
Agent's final response Agent's final response
""" """
memory = get_memory() memory = get_memory()
# Add user message to history # Add user message to history
memory.stm.add_message("user", user_input) memory.stm.add_message("user", user_input)
memory.save() memory.save()
# Build initial messages # Build initial messages
system_prompt = self.prompt_builder.build_system_prompt() system_prompt = self.prompt_builder.build_system_prompt()
messages: List[Dict[str, Any]] = [ messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
{"role": "system", "content": system_prompt}
]
# Add conversation history # Add conversation history
history = memory.stm.get_recent_history(settings.max_history_messages) history = memory.stm.get_recent_history(settings.max_history_messages)
messages.extend(history) messages.extend(history)
# Add unread events if any # Add unread events if any
unread_events = memory.episodic.get_unread_events() unread_events = memory.episodic.get_unread_events()
if unread_events: if unread_events:
events_text = "\n".join([ events_text = "\n".join(
f"- {e['type']}: {e['data']}" [f"- {e['type']}: {e['data']}" for e in unread_events]
for e in unread_events )
]) messages.append(
messages.append({ {"role": "system", "content": f"Background events:\n{events_text}"}
"role": "system", )
"content": f"Background events:\n{events_text}"
})
# Get tools specification for OpenAI format # Get tools specification for OpenAI format
tools_spec = self.prompt_builder.build_tools_spec() tools_spec = self.prompt_builder.build_tools_spec()
# Tool execution loop # Tool execution loop
for iteration in range(self.max_tool_iterations): for iteration in range(self.max_tool_iterations):
# Call LLM with tools # Call LLM with tools
llm_result = self.llm.complete(messages, tools=tools_spec) llm_result = self.llm.complete(messages, tools=tools_spec)
# Handle both tuple (response, usage) and dict response # Handle both tuple (response, usage) and dict response
if isinstance(llm_result, tuple): if isinstance(llm_result, tuple):
response_message, usage = llm_result response_message, usage = llm_result
else: else:
response_message = llm_result response_message = llm_result
# Check if there are tool calls # Check if there are tool calls
tool_calls = response_message.get("tool_calls") tool_calls = response_message.get("tool_calls")
if not tool_calls: if not tool_calls:
# No tool calls, this is the final response # No tool calls, this is the final response
final_content = response_message.get("content", "") final_content = response_message.get("content", "")
memory.stm.add_message("assistant", final_content) memory.stm.add_message("assistant", final_content)
memory.save() memory.save()
return final_content return final_content
# Add assistant message with tool calls to conversation # Add assistant message with tool calls to conversation
messages.append(response_message) messages.append(response_message)
# Execute each tool call # Execute each tool call
for tool_call in tool_calls: for tool_call in tool_calls:
tool_result = self._execute_tool_call(tool_call) tool_result = self._execute_tool_call(tool_call)
# Add tool result to messages # Add tool result to messages
messages.append({ messages.append(
"tool_call_id": tool_call.get("id"), {
"role": "tool", "tool_call_id": tool_call.get("id"),
"name": tool_call.get("function", {}).get("name"), "role": "tool",
"content": json.dumps(tool_result, ensure_ascii=False), "name": tool_call.get("function", {}).get("name"),
}) "content": json.dumps(tool_result, ensure_ascii=False),
}
)
# Max iterations reached, force final response # Max iterations reached, force final response
messages.append({ messages.append(
"role": "system", {
"content": "Please provide a final response to the user without using any more tools." "role": "system",
}) "content": "Please provide a final response to the user without using any more tools.",
}
)
llm_result = self.llm.complete(messages) llm_result = self.llm.complete(messages)
if isinstance(llm_result, tuple): if isinstance(llm_result, tuple):
final_message, usage = llm_result final_message, usage = llm_result
else: else:
final_message = llm_result final_message = llm_result
final_response = final_message.get("content", "I've completed the requested actions.") final_response = final_message.get(
"content", "I've completed the requested actions."
)
memory.stm.add_message("assistant", final_response) memory.stm.add_message("assistant", final_response)
memory.save() memory.save()
return final_response return final_response
def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Dict[str, Any]: def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]:
""" """
Execute a single tool call. Execute a single tool call.
Args: Args:
tool_call: OpenAI-format tool call dict tool_call: OpenAI-format tool call dict
Returns: Returns:
Result dictionary Result dictionary
""" """
function = tool_call.get("function", {}) function = tool_call.get("function", {})
tool_name = function.get("name", "") tool_name = function.get("name", "")
try: try:
args_str = function.get("arguments", "{}") args_str = function.get("arguments", "{}")
args = json.loads(args_str) args = json.loads(args_str)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"Failed to parse tool arguments: {e}") logger.error(f"Failed to parse tool arguments: {e}")
return { return {"error": "bad_args", "message": f"Invalid JSON arguments: {e}"}
"error": "bad_args",
"message": f"Invalid JSON arguments: {e}"
}
# Validate tool exists # Validate tool exists
if tool_name not in self.tools: if tool_name not in self.tools:
available = list(self.tools.keys()) available = list(self.tools.keys())
return { return {
"error": "unknown_tool", "error": "unknown_tool",
"message": f"Tool '{tool_name}' not found", "message": f"Tool '{tool_name}' not found",
"available_tools": available "available_tools": available,
} }
tool = self.tools[tool_name] tool = self.tools[tool_name]
# Execute tool # Execute tool
try: try:
result = tool.func(**args) result = tool.func(**args)
@@ -177,17 +177,9 @@ class Agent:
# Bad arguments # Bad arguments
memory = get_memory() memory = get_memory()
memory.episodic.add_error(tool_name, f"bad_args: {e}") memory.episodic.add_error(tool_name, f"bad_args: {e}")
return { return {"error": "bad_args", "message": str(e), "tool": tool_name}
"error": "bad_args",
"message": str(e),
"tool": tool_name
}
except Exception as e: except Exception as e:
# Other errors # Other errors
memory = get_memory() memory = get_memory()
memory.episodic.add_error(tool_name, str(e)) memory.episodic.add_error(tool_name, str(e))
return { return {"error": "execution_failed", "message": str(e), "tool": tool_name}
"error": "execution_failed",
"message": str(e),
"tool": tool_name
}

View File

@@ -51,7 +51,9 @@ class DeepSeekClient:
logger.info(f"DeepSeek client initialized with model: {self.model}") logger.info(f"DeepSeek client initialized with model: {self.model}")
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]: def complete(
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
""" """
Generate a completion from the LLM. Generate a completion from the LLM.
@@ -80,7 +82,9 @@ class DeepSeekClient:
raise ValueError(f"Invalid role: {msg['role']}") raise ValueError(f"Invalid role: {msg['role']}")
# Content is optional for tool messages (they may have tool_call_id instead) # Content is optional for tool messages (they may have tool_call_id instead)
if msg["role"] != "tool" and "content" not in msg: if msg["role"] != "tool" and "content" not in msg:
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}") raise ValueError(
f"Non-tool message must have 'content' key, got {msg.keys()}"
)
url = f"{self.base_url}/v1/chat/completions" url = f"{self.base_url}/v1/chat/completions"
headers = { headers = {
@@ -92,13 +96,15 @@ class DeepSeekClient:
"messages": messages, "messages": messages,
"temperature": settings.temperature, "temperature": settings.temperature,
} }
# Add tools if provided # Add tools if provided
if tools: if tools:
payload["tools"] = tools payload["tools"] = tools
try: try:
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools") logger.debug(
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
)
response = requests.post( response = requests.post(
url, headers=headers, json=payload, timeout=self.timeout url, headers=headers, json=payload, timeout=self.timeout
) )

View File

@@ -66,7 +66,9 @@ class OllamaClient:
logger.info(f"Ollama client initialized with model: {self.model}") logger.info(f"Ollama client initialized with model: {self.model}")
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]: def complete(
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
""" """
Generate a completion from the LLM. Generate a completion from the LLM.
@@ -95,7 +97,9 @@ class OllamaClient:
raise ValueError(f"Invalid role: {msg['role']}") raise ValueError(f"Invalid role: {msg['role']}")
# Content is optional for tool messages (they may have tool_call_id instead) # Content is optional for tool messages (they may have tool_call_id instead)
if msg["role"] != "tool" and "content" not in msg: if msg["role"] != "tool" and "content" not in msg:
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}") raise ValueError(
f"Non-tool message must have 'content' key, got {msg.keys()}"
)
url = f"{self.base_url}/api/chat" url = f"{self.base_url}/api/chat"
payload = { payload = {
@@ -106,13 +110,15 @@ class OllamaClient:
"temperature": self.temperature, "temperature": self.temperature,
}, },
} }
# Add tools if provided # Add tools if provided
if tools: if tools:
payload["tools"] = tools payload["tools"] = tools
try: try:
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools") logger.debug(
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
)
response = requests.post(url, json=payload, timeout=self.timeout) response = requests.post(url, json=payload, timeout=self.timeout)
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()

View File

@@ -1,18 +1,20 @@
"""Prompt builder for the agent system.""" """Prompt builder for the agent system."""
from typing import Dict, List, Any
import json import json
from typing import Any
from infrastructure.persistence import get_memory
from .registry import Tool from .registry import Tool
from infrastructure.persistence import get_memory
class PromptBuilder: class PromptBuilder:
"""Builds system prompts for the agent with memory context.""" """Builds system prompts for the agent with memory context."""
def __init__(self, tools: Dict[str, Tool]): def __init__(self, tools: dict[str, Tool]):
self.tools = tools self.tools = tools
def build_tools_spec(self) -> List[Dict[str, Any]]: def build_tools_spec(self) -> list[dict[str, Any]]:
"""Build the tool specification for the LLM API.""" """Build the tool specification for the LLM API."""
tool_specs = [] tool_specs = []
for tool in self.tools.values(): for tool in self.tools.values():
@@ -44,11 +46,13 @@ class PromptBuilder:
if memory.episodic.last_search_results: if memory.episodic.last_search_results:
results = memory.episodic.last_search_results results = memory.episodic.last_search_results
result_list = results.get('results', []) result_list = results.get("results", [])
lines.append(f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)") lines.append(
f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)"
)
# Show first 5 results # Show first 5 results
for i, result in enumerate(result_list[:5]): for i, result in enumerate(result_list[:5]):
name = result.get('name', 'Unknown') name = result.get("name", "Unknown")
lines.append(f" {i+1}. {name}") lines.append(f" {i+1}. {name}")
if len(result_list) > 5: if len(result_list) > 5:
lines.append(f" ... and {len(result_list) - 5} more") lines.append(f" ... and {len(result_list) - 5} more")
@@ -57,7 +61,7 @@ class PromptBuilder:
question = memory.episodic.pending_question question = memory.episodic.pending_question
lines.append(f"\nPENDING QUESTION: {question.get('question')}") lines.append(f"\nPENDING QUESTION: {question.get('question')}")
lines.append(f" Type: {question.get('type')}") lines.append(f" Type: {question.get('type')}")
if question.get('options'): if question.get("options"):
lines.append(f" Options: {len(question.get('options'))}") lines.append(f" Options: {len(question.get('options'))}")
if memory.episodic.active_downloads: if memory.episodic.active_downloads:
@@ -68,10 +72,12 @@ class PromptBuilder:
if memory.episodic.recent_errors: if memory.episodic.recent_errors:
lines.append("\nRECENT ERRORS (up to 3):") lines.append("\nRECENT ERRORS (up to 3):")
for error in memory.episodic.recent_errors[-3:]: for error in memory.episodic.recent_errors[-3:]:
lines.append(f" - Action '{error.get('action')}' failed: {error.get('error')}") lines.append(
f" - Action '{error.get('action')}' failed: {error.get('error')}"
)
# Unread events # Unread events
unread = [e for e in memory.episodic.background_events if not e.get('read')] unread = [e for e in memory.episodic.background_events if not e.get("read")]
if unread: if unread:
lines.append(f"\nUNREAD EVENTS: {len(unread)}") lines.append(f"\nUNREAD EVENTS: {len(unread)}")
for event in unread[:3]: for event in unread[:3]:
@@ -86,8 +92,10 @@ class PromptBuilder:
if memory.stm.current_workflow: if memory.stm.current_workflow:
workflow = memory.stm.current_workflow workflow = memory.stm.current_workflow
lines.append(f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})") lines.append(
if workflow.get('target'): f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})"
)
if workflow.get("target"):
lines.append(f" Target: {workflow.get('target')}") lines.append(f" Target: {workflow.get('target')}")
if memory.stm.current_topic: if memory.stm.current_topic:
@@ -97,7 +105,7 @@ class PromptBuilder:
lines.append("EXTRACTED ENTITIES:") lines.append("EXTRACTED ENTITIES:")
for key, value in memory.stm.extracted_entities.items(): for key, value in memory.stm.extracted_entities.items():
lines.append(f" - {key}: {value}") lines.append(f" - {key}: {value}")
if memory.stm.language: if memory.stm.language:
lines.append(f"CONVERSATION LANGUAGE: {memory.stm.language}") lines.append(f"CONVERSATION LANGUAGE: {memory.stm.language}")
@@ -106,7 +114,7 @@ class PromptBuilder:
def _format_config_context(self) -> str: def _format_config_context(self) -> str:
"""Format configuration context.""" """Format configuration context."""
memory = get_memory() memory = get_memory()
lines = ["CURRENT CONFIGURATION:"] lines = ["CURRENT CONFIGURATION:"]
if memory.ltm.config: if memory.ltm.config:
for key, value in memory.ltm.config.items(): for key, value in memory.ltm.config.items():
@@ -118,10 +126,10 @@ class PromptBuilder:
def build_system_prompt(self) -> str: def build_system_prompt(self) -> str:
"""Build the complete system prompt.""" """Build the complete system prompt."""
memory = get_memory() memory = get_memory()
# Base instruction # Base instruction
base = "You are a helpful AI assistant for managing a media library." base = "You are a helpful AI assistant for managing a media library."
# Language instruction # Language instruction
language_instruction = ( language_instruction = (
"Your first task is to determine the user's language from their message " "Your first task is to determine the user's language from their message "

View File

@@ -1,8 +1,10 @@
"""Tool registry - defines and registers all available tools for the agent.""" """Tool registry - defines and registers all available tools for the agent."""
from dataclasses import dataclass
from typing import Callable, Any, Dict
import logging
import inspect import inspect
import logging
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -10,36 +12,37 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class Tool: class Tool:
"""Represents a tool that can be used by the agent.""" """Represents a tool that can be used by the agent."""
name: str name: str
description: str description: str
func: Callable[..., Dict[str, Any]] func: Callable[..., dict[str, Any]]
parameters: Dict[str, Any] parameters: dict[str, Any]
def _create_tool_from_function(func: Callable) -> Tool: def _create_tool_from_function(func: Callable) -> Tool:
""" """
Create a Tool object from a function. Create a Tool object from a function.
Args: Args:
func: Function to convert to a tool func: Function to convert to a tool
Returns: Returns:
Tool object with metadata extracted from function Tool object with metadata extracted from function
""" """
sig = inspect.signature(func) sig = inspect.signature(func)
doc = inspect.getdoc(func) doc = inspect.getdoc(func)
# Extract description from docstring (first line) # Extract description from docstring (first line)
description = doc.strip().split('\n')[0] if doc else func.__name__ description = doc.strip().split("\n")[0] if doc else func.__name__
# Build JSON schema from function signature # Build JSON schema from function signature
properties = {} properties = {}
required = [] required = []
for param_name, param in sig.parameters.items(): for param_name, param in sig.parameters.items():
if param_name == "self": if param_name == "self":
continue continue
# Map Python types to JSON schema types # Map Python types to JSON schema types
param_type = "string" # default param_type = "string" # default
if param.annotation != inspect.Parameter.empty: if param.annotation != inspect.Parameter.empty:
@@ -51,22 +54,22 @@ def _create_tool_from_function(func: Callable) -> Tool:
param_type = "number" param_type = "number"
elif param.annotation == bool: elif param.annotation == bool:
param_type = "boolean" param_type = "boolean"
properties[param_name] = { properties[param_name] = {
"type": param_type, "type": param_type,
"description": f"Parameter {param_name}" "description": f"Parameter {param_name}",
} }
# Add to required if no default value # Add to required if no default value
if param.default == inspect.Parameter.empty: if param.default == inspect.Parameter.empty:
required.append(param_name) required.append(param_name)
parameters = { parameters = {
"type": "object", "type": "object",
"properties": properties, "properties": properties,
"required": required, "required": required,
} }
return Tool( return Tool(
name=func.__name__, name=func.__name__,
description=description, description=description,
@@ -75,18 +78,18 @@ def _create_tool_from_function(func: Callable) -> Tool:
) )
def make_tools() -> Dict[str, Tool]: def make_tools() -> dict[str, Tool]:
""" """
Create and register all available tools. Create and register all available tools.
Returns: Returns:
Dictionary mapping tool names to Tool objects Dictionary mapping tool names to Tool objects
""" """
# Import tools here to avoid circular dependencies # Import tools here to avoid circular dependencies
from .tools import filesystem as fs_tools
from .tools import api as api_tools from .tools import api as api_tools
from .tools import filesystem as fs_tools
from .tools import language as lang_tools from .tools import language as lang_tools
# List of all tool functions # List of all tool functions
tool_functions = [ tool_functions = [
fs_tools.set_path_for_folder, fs_tools.set_path_for_folder,
@@ -98,12 +101,12 @@ def make_tools() -> Dict[str, Tool]:
api_tools.get_torrent_by_index, api_tools.get_torrent_by_index,
lang_tools.set_language, lang_tools.set_language,
] ]
# Create Tool objects from functions # Create Tool objects from functions
tools = {} tools = {}
for func in tool_functions: for func in tool_functions:
tool = _create_tool_from_function(func) tool = _create_tool_from_function(func)
tools[tool.name] = tool tools[tool.name] = tool
logger.info(f"Registered {len(tools)} tools: {list(tools.keys())}") logger.info(f"Registered {len(tools)} tools: {list(tools.keys())}")
return tools return tools

View File

@@ -1,19 +1,20 @@
"""Language management tools for the agent.""" """Language management tools for the agent."""
import logging import logging
from typing import Dict, Any from typing import Any
from infrastructure.persistence import get_memory from infrastructure.persistence import get_memory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def set_language(language: str) -> Dict[str, Any]: def set_language(language: str) -> dict[str, Any]:
""" """
Set the conversation language. Set the conversation language.
Args: Args:
language: Language code (e.g., 'en', 'fr', 'es', 'de') language: Language code (e.g., 'en', 'fr', 'es', 'de')
Returns: Returns:
Status dictionary Status dictionary
""" """
@@ -21,17 +22,14 @@ def set_language(language: str) -> Dict[str, Any]:
memory = get_memory() memory = get_memory()
memory.stm.set_language(language) memory.stm.set_language(language)
memory.save() memory.save()
logger.info(f"Language set to: {language}") logger.info(f"Language set to: {language}")
return { return {
"status": "ok", "status": "ok",
"message": f"Language set to {language}", "message": f"Language set to {language}",
"language": language "language": language,
} }
except Exception as e: except Exception as e:
logger.error(f"Failed to set language: {e}") logger.error(f"Failed to set language: {e}")
return { return {"status": "error", "error": str(e)}
"status": "error",
"error": str(e)
}

View File

@@ -359,9 +359,7 @@ class EpisodicMemory:
"""Get active downloads.""" """Get active downloads."""
return self.active_downloads return self.active_downloads
def add_error( def add_error(self, action: str, error: str, context: dict | None = None) -> None:
self, action: str, error: str, context: dict | None = None
) -> None:
"""Record a recent error.""" """Record a recent error."""
self.recent_errors.append( self.recent_errors.append(
{ {
@@ -408,9 +406,7 @@ class EpisodicMemory:
"""Get the pending question.""" """Get the pending question."""
return self.pending_question return self.pending_question
def resolve_pending_question( def resolve_pending_question(self, answer_index: int | None = None) -> dict | None:
self, answer_index: int | None = None
) -> dict | None:
""" """
Resolve the pending question and return the chosen option. Resolve the pending question and return the chosen option.

View File

@@ -110,4 +110,4 @@ select = [
"PL", "PL",
"UP", "UP",
] ]
ignore = ["W503", "PLR0913", "PLR2004"] ignore = ["PLR0913", "PLR2004"]

View File

@@ -1,16 +1,13 @@
"""Pytest configuration and shared fixtures.""" """Pytest configuration and shared fixtures."""
import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, MagicMock
from infrastructure.persistence import Memory, init_memory, set_memory, get_memory import shutil
from infrastructure.persistence.memory import ( import tempfile
LongTermMemory, from pathlib import Path
ShortTermMemory, from unittest.mock import MagicMock, Mock
EpisodicMemory,
) import pytest
from infrastructure.persistence import Memory, set_memory
@pytest.fixture @pytest.fixture
@@ -122,12 +119,11 @@ def memory_with_library(memory):
def mock_llm(): def mock_llm():
"""Create a mock LLM client that returns OpenAI-compatible format.""" """Create a mock LLM client that returns OpenAI-compatible format."""
llm = Mock() llm = Mock()
# Return OpenAI-style message dict without tool calls # Return OpenAI-style message dict without tool calls
def complete_func(messages, tools=None): def complete_func(messages, tools=None):
return { return {"role": "assistant", "content": "I found what you're looking for!"}
"role": "assistant",
"content": "I found what you're looking for!"
}
llm.complete = Mock(side_effect=complete_func) llm.complete = Mock(side_effect=complete_func)
return llm return llm
@@ -136,34 +132,33 @@ def mock_llm():
def mock_llm_with_tool_call(): def mock_llm_with_tool_call():
"""Create a mock LLM that returns a tool call then a response.""" """Create a mock LLM that returns a tool call then a response."""
llm = Mock() llm = Mock()
# First call returns a tool call, second returns final response # First call returns a tool call, second returns final response
def complete_side_effect(messages, tools=None): def complete_side_effect(messages, tools=None):
if not hasattr(complete_side_effect, 'call_count'): if not hasattr(complete_side_effect, "call_count"):
complete_side_effect.call_count = 0 complete_side_effect.call_count = 0
complete_side_effect.call_count += 1 complete_side_effect.call_count += 1
if complete_side_effect.call_count == 1: if complete_side_effect.call_count == 1:
# First call: return tool call # First call: return tool call
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_123", {
"type": "function", "id": "call_123",
"function": { "type": "function",
"name": "find_torrent", "function": {
"arguments": '{"media_title": "Inception"}' "name": "find_torrent",
"arguments": '{"media_title": "Inception"}',
},
} }
}] ],
} }
else: else:
# Second call: return final response # Second call: return final response
return { return {"role": "assistant", "content": "I found 3 torrents for Inception!"}
"role": "assistant",
"content": "I found 3 torrents for Inception!"
}
llm.complete = Mock(side_effect=complete_side_effect) llm.complete = Mock(side_effect=complete_side_effect)
return llm return llm
@@ -248,36 +243,36 @@ def mock_deepseek():
""" """
Mock DeepSeekClient for individual tests that need it. Mock DeepSeekClient for individual tests that need it.
This prevents real API calls in tests that use this fixture. This prevents real API calls in tests that use this fixture.
Usage: Usage:
def test_something(mock_deepseek): def test_something(mock_deepseek):
# Your test code here # Your test code here
""" """
import sys import sys
from unittest.mock import Mock, MagicMock from unittest.mock import Mock
# Save the original module if it exists # Save the original module if it exists
original_module = sys.modules.get('agent.llm.deepseek') original_module = sys.modules.get("agent.llm.deepseek")
# Create a mock module for deepseek # Create a mock module for deepseek
mock_deepseek_module = MagicMock() mock_deepseek_module = MagicMock()
class MockDeepSeekClient: class MockDeepSeekClient:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.complete = Mock(return_value="Mocked LLM response") self.complete = Mock(return_value="Mocked LLM response")
mock_deepseek_module.DeepSeekClient = MockDeepSeekClient mock_deepseek_module.DeepSeekClient = MockDeepSeekClient
# Inject the mock # Inject the mock
sys.modules['agent.llm.deepseek'] = mock_deepseek_module sys.modules["agent.llm.deepseek"] = mock_deepseek_module
yield mock_deepseek_module yield mock_deepseek_module
# Restore the original module # Restore the original module
if original_module is not None: if original_module is not None:
sys.modules['agent.llm.deepseek'] = original_module sys.modules["agent.llm.deepseek"] = original_module
elif 'agent.llm.deepseek' in sys.modules: elif "agent.llm.deepseek" in sys.modules:
del sys.modules['agent.llm.deepseek'] del sys.modules["agent.llm.deepseek"]
@pytest.fixture @pytest.fixture
@@ -287,8 +282,8 @@ def mock_agent_step():
Returns a context manager that patches app.agent.step. Returns a context manager that patches app.agent.step.
""" """
from unittest.mock import patch from unittest.mock import patch
def _mock_step(return_value="Mocked agent response"): def _mock_step(return_value="Mocked agent response"):
return patch("app.agent.step", return_value=return_value) return patch("app.agent.step", return_value=return_value)
return _mock_step return _mock_step

View File

@@ -1,6 +1,6 @@
"""Tests for the Agent.""" """Tests for the Agent."""
from unittest.mock import Mock, patch from unittest.mock import Mock
from agent.agent import Agent from agent.agent import Agent
from infrastructure.persistence import get_memory from infrastructure.persistence import get_memory
@@ -55,8 +55,8 @@ class TestExecuteToolCall:
"id": "call_123", "id": "call_123",
"function": { "function": {
"name": "list_folder", "name": "list_folder",
"arguments": '{"folder_type": "download"}' "arguments": '{"folder_type": "download"}',
} },
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -68,10 +68,7 @@ class TestExecuteToolCall:
tool_call = { tool_call = {
"id": "call_123", "id": "call_123",
"function": { "function": {"name": "unknown_tool", "arguments": "{}"},
"name": "unknown_tool",
"arguments": '{}'
}
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -84,10 +81,7 @@ class TestExecuteToolCall:
tool_call = { tool_call = {
"id": "call_123", "id": "call_123",
"function": { "function": {"name": "set_path_for_folder", "arguments": "{}"},
"name": "set_path_for_folder",
"arguments": '{}'
}
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -102,8 +96,8 @@ class TestExecuteToolCall:
"id": "call_123", "id": "call_123",
"function": { "function": {
"name": "set_path_for_folder", "name": "set_path_for_folder",
"arguments": '{"folder_name": 123}' # Wrong type "arguments": '{"folder_name": 123}', # Wrong type
} },
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -116,10 +110,7 @@ class TestExecuteToolCall:
tool_call = { tool_call = {
"id": "call_123", "id": "call_123",
"function": { "function": {"name": "list_folder", "arguments": "{invalid json}"},
"name": "list_folder",
"arguments": '{invalid json}'
}
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -160,40 +151,39 @@ class TestStep:
assert "found" in response.lower() or "torrent" in response.lower() assert "found" in response.lower() or "torrent" in response.lower()
assert mock_llm_with_tool_call.complete.call_count == 2 assert mock_llm_with_tool_call.complete.call_count == 2
# CRITICAL: Verify tools were passed to LLM # CRITICAL: Verify tools were passed to LLM
first_call_args = mock_llm_with_tool_call.complete.call_args_list[0] first_call_args = mock_llm_with_tool_call.complete.call_args_list[0]
assert first_call_args[1]['tools'] is not None, "Tools not passed to LLM!" assert first_call_args[1]["tools"] is not None, "Tools not passed to LLM!"
assert len(first_call_args[1]['tools']) > 0, "Tools list is empty!" assert len(first_call_args[1]["tools"]) > 0, "Tools list is empty!"
def test_step_max_iterations(self, memory, mock_llm): def test_step_max_iterations(self, memory, mock_llm):
"""Should stop after max iterations.""" """Should stop after max iterations."""
call_count = [0] call_count = [0]
def mock_complete(messages, tools=None): def mock_complete(messages, tools=None):
call_count[0] += 1 call_count[0] += 1
# CRITICAL: Verify tools are passed (except on forced final call) # CRITICAL: Verify tools are passed (except on forced final call)
if call_count[0] <= 3: if call_count[0] <= 3:
assert tools is not None, f"Tools not passed on call {call_count[0]}!" assert tools is not None, f"Tools not passed on call {call_count[0]}!"
if call_count[0] <= 3: if call_count[0] <= 3:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": f"call_{call_count[0]}", {
"function": { "id": f"call_{call_count[0]}",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "download"}' "name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
} }
}] ],
} }
else: else:
return { return {"role": "assistant", "content": "I couldn't complete the task."}
"role": "assistant",
"content": "I couldn't complete the task."
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3) agent = Agent(llm=mock_llm, max_tool_iterations=3)
@@ -241,49 +231,53 @@ class TestAgentIntegration:
memory.ltm.set_config("movie_folder", str(real_folder["movies"])) memory.ltm.set_config("movie_folder", str(real_folder["movies"]))
call_count = [0] call_count = [0]
def mock_complete(messages, tools=None): def mock_complete(messages, tools=None):
call_count[0] += 1 call_count[0] += 1
# CRITICAL: Verify tools are passed on every call # CRITICAL: Verify tools are passed on every call
assert tools is not None, f"Tools not passed on call {call_count[0]}!" assert tools is not None, f"Tools not passed on call {call_count[0]}!"
if call_count[0] == 1: if call_count[0] == 1:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_1", {
"function": { "id": "call_1",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "download"}' "name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
} }
}] ],
} }
elif call_count[0] == 2: elif call_count[0] == 2:
# CRITICAL: Verify tool result was sent back # CRITICAL: Verify tool result was sent back
tool_messages = [m for m in messages if m.get('role') == 'tool'] tool_messages = [m for m in messages if m.get("role") == "tool"]
assert len(tool_messages) > 0, "Tool result not sent back to LLM!" assert len(tool_messages) > 0, "Tool result not sent back to LLM!"
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_2", {
"function": { "id": "call_2",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "movie"}' "name": "list_folder",
"arguments": '{"folder_type": "movie"}',
},
} }
}] ],
} }
else: else:
return { return {
"role": "assistant", "role": "assistant",
"content": "I listed both folders for you." "content": "I listed both folders for you.",
} }
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
response = agent.step("List my downloads and movies") response = agent.step("List my downloads and movies")
assert call_count[0] == 3 assert call_count[0] == 3

View File

@@ -1,7 +1,9 @@
"""Edge case tests for the Agent.""" """Edge case tests for the Agent."""
import pytest
from unittest.mock import Mock from unittest.mock import Mock
import pytest
from agent.agent import Agent from agent.agent import Agent
from infrastructure.persistence import get_memory from infrastructure.persistence import get_memory
@@ -15,19 +17,14 @@ class TestExecuteToolCallEdgeCases:
# Mock a tool that returns None # Mock a tool that returns None
from agent.registry import Tool from agent.registry import Tool
agent.tools["test_tool"] = Tool( agent.tools["test_tool"] = Tool(
name="test_tool", name="test_tool", description="Test", func=lambda: None, parameters={}
description="Test",
func=lambda: None,
parameters={}
) )
tool_call = { tool_call = {
"id": "call_123", "id": "call_123",
"function": { "function": {"name": "test_tool", "arguments": "{}"},
"name": "test_tool",
"arguments": '{}'
}
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -38,22 +35,17 @@ class TestExecuteToolCallEdgeCases:
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
from agent.registry import Tool from agent.registry import Tool
def raise_interrupt(): def raise_interrupt():
raise KeyboardInterrupt() raise KeyboardInterrupt()
agent.tools["test_tool"] = Tool( agent.tools["test_tool"] = Tool(
name="test_tool", name="test_tool", description="Test", func=raise_interrupt, parameters={}
description="Test",
func=raise_interrupt,
parameters={}
) )
tool_call = { tool_call = {
"id": "call_123", "id": "call_123",
"function": { "function": {"name": "test_tool", "arguments": "{}"},
"name": "test_tool",
"arguments": '{}'
}
} }
with pytest.raises(KeyboardInterrupt): with pytest.raises(KeyboardInterrupt):
@@ -68,8 +60,8 @@ class TestExecuteToolCallEdgeCases:
"id": "call_123", "id": "call_123",
"function": { "function": {
"name": "list_folder", "name": "list_folder",
"arguments": '{"folder_type": "download", "extra_arg": "ignored"}' "arguments": '{"folder_type": "download", "extra_arg": "ignored"}',
} },
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -84,8 +76,8 @@ class TestExecuteToolCallEdgeCases:
"id": "call_123", "id": "call_123",
"function": { "function": {
"name": "get_torrent_by_index", "name": "get_torrent_by_index",
"arguments": '{"index": "not an int"}' "arguments": '{"index": "not an int"}',
} },
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -115,12 +107,10 @@ class TestStepEdgeCases:
def test_step_with_unicode_input(self, memory, mock_llm): def test_step_with_unicode_input(self, memory, mock_llm):
"""Should handle unicode input.""" """Should handle unicode input."""
def mock_complete(messages, tools=None): def mock_complete(messages, tools=None):
return { return {"role": "assistant", "content": "日本語の応答"}
"role": "assistant",
"content": "日本語の応答"
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
@@ -130,12 +120,10 @@ class TestStepEdgeCases:
def test_step_llm_returns_empty(self, memory, mock_llm): def test_step_llm_returns_empty(self, memory, mock_llm):
"""Should handle LLM returning empty string.""" """Should handle LLM returning empty string."""
def mock_complete(messages, tools=None): def mock_complete(messages, tools=None):
return { return {"role": "assistant", "content": ""}
"role": "assistant",
"content": ""
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
@@ -161,18 +149,17 @@ class TestStepEdgeCases:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": f"call_{call_count[0]}", {
"function": { "id": f"call_{call_count[0]}",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "download"}' "name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "Done looping"}
"role": "assistant",
"content": "Done looping"
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3) agent = Agent(llm=mock_llm, max_tool_iterations=3)
@@ -212,11 +199,13 @@ class TestStepEdgeCases:
def test_step_with_active_downloads(self, memory, mock_llm): def test_step_with_active_downloads(self, memory, mock_llm):
"""Should include active downloads in context.""" """Should include active downloads in context."""
memory.episodic.add_active_download({ memory.episodic.add_active_download(
"task_id": "123", {
"name": "Movie.mkv", "task_id": "123",
"progress": 50, "name": "Movie.mkv",
}) "progress": 50,
}
)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
response = agent.step("Hello") response = agent.step("Hello")
@@ -257,29 +246,28 @@ class TestAgentConcurrencyEdgeCases:
memory.ltm.set_config("download_folder", str(real_folder["downloads"])) memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
call_count = [0] call_count = [0]
def mock_complete(messages, tools=None): def mock_complete(messages, tools=None):
call_count[0] += 1 call_count[0] += 1
if call_count[0] == 1: if call_count[0] == 1:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_1", {
"function": { "id": "call_1",
"name": "set_path_for_folder", "function": {
"arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}' "name": "set_path_for_folder",
"arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}',
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "Path set successfully."}
"role": "assistant",
"content": "Path set successfully."
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
response = agent.step("Set movie folder") response = agent.step("Set movie folder")
mem = get_memory() mem = get_memory()
@@ -292,29 +280,28 @@ class TestAgentErrorRecovery:
def test_recovers_from_tool_error(self, memory, mock_llm): def test_recovers_from_tool_error(self, memory, mock_llm):
"""Should recover from tool error and continue.""" """Should recover from tool error and continue."""
call_count = [0] call_count = [0]
def mock_complete(messages, tools=None): def mock_complete(messages, tools=None):
call_count[0] += 1 call_count[0] += 1
if call_count[0] == 1: if call_count[0] == 1:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_1", {
"function": { "id": "call_1",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "download"}' "name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "The folder is not configured."}
"role": "assistant",
"content": "The folder is not configured."
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
response = agent.step("List downloads") response = agent.step("List downloads")
assert "not configured" in response.lower() or len(response) > 0 assert "not configured" in response.lower() or len(response) > 0
@@ -322,29 +309,28 @@ class TestAgentErrorRecovery:
def test_error_tracked_in_memory(self, memory, mock_llm): def test_error_tracked_in_memory(self, memory, mock_llm):
"""Should track errors in episodic memory.""" """Should track errors in episodic memory."""
call_count = [0] call_count = [0]
def mock_complete(messages, tools=None): def mock_complete(messages, tools=None):
call_count[0] += 1 call_count[0] += 1
if call_count[0] == 1: if call_count[0] == 1:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_1", {
"function": { "id": "call_1",
"name": "set_path_for_folder", "function": {
"arguments": '{}' # Missing required args "name": "set_path_for_folder",
"arguments": "{}", # Missing required args
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "Error occurred."}
"role": "assistant",
"content": "Error occurred."
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
agent.step("Set folder") agent.step("Set folder")
mem = get_memory() mem = get_memory()
@@ -360,18 +346,17 @@ class TestAgentErrorRecovery:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": f"call_{call_count[0]}", {
"function": { "id": f"call_{call_count[0]}",
"name": "set_path_for_folder", "function": {
"arguments": '{}' # Missing required args - will error "name": "set_path_for_folder",
"arguments": "{}", # Missing required args - will error
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "All attempts failed."}
"role": "assistant",
"content": "All attempts failed."
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3) agent = Agent(llm=mock_llm, max_tool_iterations=3)

View File

@@ -1,6 +1,7 @@
"""Tests for FastAPI endpoints.""" """Tests for FastAPI endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock from unittest.mock import patch
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@@ -10,6 +11,7 @@ class TestHealthEndpoint:
def test_health_check(self, memory): def test_health_check(self, memory):
"""Should return healthy status.""" """Should return healthy status."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/health") response = client.get("/health")
@@ -24,6 +26,7 @@ class TestModelsEndpoint:
def test_list_models(self, memory): def test_list_models(self, memory):
"""Should return model list.""" """Should return model list."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/v1/models") response = client.get("/v1/models")
@@ -41,6 +44,7 @@ class TestMemoryEndpoints:
def test_get_memory_state(self, memory): def test_get_memory_state(self, memory):
"""Should return full memory state.""" """Should return full memory state."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/memory/state") response = client.get("/memory/state")
@@ -54,6 +58,7 @@ class TestMemoryEndpoints:
def test_get_search_results_empty(self, memory): def test_get_search_results_empty(self, memory):
"""Should return empty when no search results.""" """Should return empty when no search results."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/memory/episodic/search-results") response = client.get("/memory/episodic/search-results")
@@ -65,6 +70,7 @@ class TestMemoryEndpoints:
def test_get_search_results_with_data(self, memory_with_search_results): def test_get_search_results_with_data(self, memory_with_search_results):
"""Should return search results when available.""" """Should return search results when available."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/memory/episodic/search-results") response = client.get("/memory/episodic/search-results")
@@ -78,6 +84,7 @@ class TestMemoryEndpoints:
def test_clear_session(self, memory_with_search_results): def test_clear_session(self, memory_with_search_results):
"""Should clear session memories.""" """Should clear session memories."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/memory/clear-session") response = client.post("/memory/clear-session")
@@ -96,14 +103,18 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_success(self, memory): def test_chat_completion_success(self, memory):
"""Should return chat completion.""" """Should return chat completion."""
from app import app from app import app
# Patch the agent's step method directly # Patch the agent's step method directly
with patch("app.agent.step", return_value="Hello! How can I help?"): with patch("app.agent.step", return_value="Hello! How can I help?"):
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "Hello"}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
},
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -113,12 +124,16 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_no_user_message(self, memory): def test_chat_completion_no_user_message(self, memory):
"""Should return error if no user message.""" """Should return error if no user message."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "system", "content": "You are helpful"}], json={
}) "model": "agent-media",
"messages": [{"role": "system", "content": "You are helpful"}],
},
)
assert response.status_code == 422 assert response.status_code == 422
detail = response.json()["detail"] detail = response.json()["detail"]
@@ -132,18 +147,23 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_empty_messages(self, memory): def test_chat_completion_empty_messages(self, memory):
"""Should return error for empty messages.""" """Should return error for empty messages."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [], json={
}) "model": "agent-media",
"messages": [],
},
)
assert response.status_code == 422 assert response.status_code == 422
def test_chat_completion_invalid_json(self, memory): def test_chat_completion_invalid_json(self, memory):
"""Should return error for invalid JSON.""" """Should return error for invalid JSON."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post( response = client.post(
@@ -157,14 +177,18 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_streaming(self, memory): def test_chat_completion_streaming(self, memory):
"""Should support streaming mode.""" """Should support streaming mode."""
from app import app from app import app
with patch("app.agent.step", return_value="Streaming response"): with patch("app.agent.step", return_value="Streaming response"):
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "Hello"}], json={
"stream": True, "model": "agent-media",
}) "messages": [{"role": "user", "content": "Hello"}],
"stream": True,
},
)
assert response.status_code == 200 assert response.status_code == 200
assert "text/event-stream" in response.headers["content-type"] assert "text/event-stream" in response.headers["content-type"]
@@ -172,17 +196,21 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_extracts_last_user_message(self, memory): def test_chat_completion_extracts_last_user_message(self, memory):
"""Should use last user message.""" """Should use last user message."""
from app import app from app import app
with patch("app.agent.step", return_value="Response") as mock_step: with patch("app.agent.step", return_value="Response") as mock_step:
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [ json={
{"role": "user", "content": "First message"}, "model": "agent-media",
{"role": "assistant", "content": "Response"}, "messages": [
{"role": "user", "content": "Second message"}, {"role": "user", "content": "First message"},
], {"role": "assistant", "content": "Response"},
}) {"role": "user", "content": "Second message"},
],
},
)
assert response.status_code == 200 assert response.status_code == 200
# Verify the agent received the last user message # Verify the agent received the last user message
@@ -191,13 +219,17 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_response_format(self, memory): def test_chat_completion_response_format(self, memory):
"""Should return OpenAI-compatible format.""" """Should return OpenAI-compatible format."""
from app import app from app import app
with patch("app.agent.step", return_value="Test response"): with patch("app.agent.step", return_value="Test response"):
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "Test"}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": "Test"}],
},
)
data = response.json() data = response.json()
assert "id" in data assert "id" in data

View File

@@ -1,7 +1,7 @@
"""Edge case tests for FastAPI endpoints.""" """Edge case tests for FastAPI endpoints."""
import pytest
import json from unittest.mock import Mock, patch
from unittest.mock import Mock, patch, MagicMock
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@@ -10,43 +10,46 @@ class TestChatCompletionsEdgeCases:
def test_very_long_message(self, memory): def test_very_long_message(self, memory):
"""Should handle very long user message.""" """Should handle very long user message."""
from app import app, agent from app import agent, app
# Patch the agent's LLM directly # Patch the agent's LLM directly
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
"role": "assistant",
"content": "Response"
}
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
long_message = "x" * 100000 long_message = "x" * 100000
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": long_message}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": long_message}],
},
)
assert response.status_code == 200 assert response.status_code == 200
def test_unicode_message(self, memory): def test_unicode_message(self, memory):
"""Should handle unicode in message.""" """Should handle unicode in message."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {
"role": "assistant", "role": "assistant",
"content": "日本語の応答" "content": "日本語の応答",
} }
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}],
},
)
assert response.status_code == 200 assert response.status_code == 200
content = response.json()["choices"][0]["message"]["content"] content = response.json()["choices"][0]["message"]["content"]
@@ -54,22 +57,22 @@ class TestChatCompletionsEdgeCases:
def test_special_characters_in_message(self, memory): def test_special_characters_in_message(self, memory):
"""Should handle special characters.""" """Should handle special characters."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
"role": "assistant",
"content": "Response"
}
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
special_message = 'Test with "quotes" and \\backslash and \n newline' special_message = 'Test with "quotes" and \\backslash and \n newline'
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": special_message}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": special_message}],
},
)
assert response.status_code == 200 assert response.status_code == 200
@@ -81,12 +84,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": ""}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": ""}],
},
)
# Empty content should be rejected # Empty content should be rejected
assert response.status_code == 422 assert response.status_code == 422
@@ -98,12 +105,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": None}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": None}],
},
)
assert response.status_code == 422 assert response.status_code == 422
@@ -114,12 +125,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user"}], # No content json={
}) "model": "agent-media",
"messages": [{"role": "user"}], # No content
},
)
# May accept or reject depending on validation # May accept or reject depending on validation
assert response.status_code in [200, 400, 422] assert response.status_code in [200, 400, 422]
@@ -131,12 +146,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"content": "Hello"}], # No role json={
}) "model": "agent-media",
"messages": [{"content": "Hello"}], # No role
},
)
# Should reject or accept depending on validation # Should reject or accept depending on validation
assert response.status_code in [200, 400, 422] assert response.status_code in [200, 400, 422]
@@ -149,27 +168,28 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "invalid_role", "content": "Hello"}], json={
}) "model": "agent-media",
"messages": [{"role": "invalid_role", "content": "Hello"}],
},
)
# Should reject or ignore invalid role # Should reject or ignore invalid role
assert response.status_code in [200, 400, 422] assert response.status_code in [200, 400, 422]
def test_many_messages(self, memory): def test_many_messages(self, memory):
"""Should handle many messages in conversation.""" """Should handle many messages in conversation."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
"role": "assistant",
"content": "Response"
}
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
messages = [] messages = []
@@ -178,10 +198,13 @@ class TestChatCompletionsEdgeCases:
messages.append({"role": "assistant", "content": f"Response {i}"}) messages.append({"role": "assistant", "content": f"Response {i}"})
messages.append({"role": "user", "content": "Final message"}) messages.append({"role": "user", "content": "Final message"})
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": messages, json={
}) "model": "agent-media",
"messages": messages,
},
)
assert response.status_code == 200 assert response.status_code == 200
@@ -192,15 +215,19 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [ json={
{"role": "system", "content": "You are helpful"}, "model": "agent-media",
{"role": "system", "content": "Be concise"}, "messages": [
], {"role": "system", "content": "You are helpful"},
}) {"role": "system", "content": "Be concise"},
],
},
)
assert response.status_code == 422 assert response.status_code == 422
@@ -211,14 +238,18 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [ json={
{"role": "assistant", "content": "Hello"}, "model": "agent-media",
], "messages": [
}) {"role": "assistant", "content": "Hello"},
],
},
)
assert response.status_code == 422 assert response.status_code == 422
@@ -229,12 +260,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": "not an array", json={
}) "model": "agent-media",
"messages": "not an array",
},
)
assert response.status_code == 422 assert response.status_code == 422
# Pydantic validation error # Pydantic validation error
@@ -246,118 +281,128 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": ["not an object", 123, None], json={
}) "model": "agent-media",
"messages": ["not an object", 123, None],
},
)
assert response.status_code == 422 assert response.status_code == 422
# Pydantic validation error # Pydantic validation error
def test_extra_fields_in_request(self, memory): def test_extra_fields_in_request(self, memory):
"""Should ignore extra fields in request.""" """Should ignore extra fields in request."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
"role": "assistant",
"content": "Response"
}
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "Hello"}], json={
"extra_field": "should be ignored", "model": "agent-media",
"temperature": 0.7, "messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 100, "extra_field": "should be ignored",
}) "temperature": 0.7,
"max_tokens": 100,
},
)
assert response.status_code == 200 assert response.status_code == 200
def test_streaming_with_tool_call(self, memory, real_folder): def test_streaming_with_tool_call(self, memory, real_folder):
"""Should handle streaming with tool execution.""" """Should handle streaming with tool execution."""
from app import app, agent from app import agent, app
from infrastructure.persistence import get_memory from infrastructure.persistence import get_memory
mem = get_memory() mem = get_memory()
mem.ltm.set_config("download_folder", str(real_folder["downloads"])) mem.ltm.set_config("download_folder", str(real_folder["downloads"]))
call_count = [0] call_count = [0]
def mock_complete(messages, tools=None): def mock_complete(messages, tools=None):
call_count[0] += 1 call_count[0] += 1
if call_count[0] == 1: if call_count[0] == 1:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_1", {
"function": { "id": "call_1",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "download"}' "name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "Listed the folder."}
"role": "assistant",
"content": "Listed the folder."
}
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "List downloads"}], json={
"stream": True, "model": "agent-media",
}) "messages": [{"role": "user", "content": "List downloads"}],
"stream": True,
},
)
assert response.status_code == 200 assert response.status_code == 200
def test_concurrent_requests_simulation(self, memory): def test_concurrent_requests_simulation(self, memory):
"""Should handle rapid sequential requests.""" """Should handle rapid sequential requests."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
"role": "assistant",
"content": "Response"
}
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
for i in range(10): for i in range(10):
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": f"Request {i}"}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": f"Request {i}"}],
},
)
assert response.status_code == 200 assert response.status_code == 200
def test_llm_returns_json_in_response(self, memory): def test_llm_returns_json_in_response(self, memory):
"""Should handle LLM returning JSON in text response.""" """Should handle LLM returning JSON in text response."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {
"role": "assistant", "role": "assistant",
"content": '{"result": "some data", "count": 5}' "content": '{"result": "some data", "count": 5}',
} }
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "Give me JSON"}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": "Give me JSON"}],
},
)
assert response.status_code == 200 assert response.status_code == 200
content = response.json()["choices"][0]["message"]["content"] content = response.json()["choices"][0]["message"]["content"]
@@ -425,6 +470,7 @@ class TestMemoryEndpointsEdgeCases:
with patch("app.DeepSeekClient") as mock_llm: with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock() mock_llm.return_value = Mock()
from app import app from app import app
client = TestClient(app) client = TestClient(app)
# Clear multiple times # Clear multiple times
@@ -459,6 +505,7 @@ class TestHealthEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm: with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock() mock_llm.return_value = Mock()
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/health") response = client.get("/health")
@@ -471,6 +518,7 @@ class TestHealthEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm: with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock() mock_llm.return_value = Mock()
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/health?extra=param&another=value") response = client.get("/health?extra=param&another=value")
@@ -486,6 +534,7 @@ class TestModelsEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm: with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock() mock_llm.return_value = Mock()
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/v1/models") response = client.get("/v1/models")

View File

@@ -1,9 +1,9 @@
"""Critical tests for configuration validation.""" """Critical tests for configuration validation."""
import pytest
import os
from agent.config import Settings, ConfigurationError import pytest
from agent.config import ConfigurationError, Settings
class TestConfigValidation: class TestConfigValidation:
@@ -13,7 +13,7 @@ class TestConfigValidation:
"""Verify invalid temperature is rejected.""" """Verify invalid temperature is rejected."""
with pytest.raises(ConfigurationError, match="Temperature"): with pytest.raises(ConfigurationError, match="Temperature"):
Settings(temperature=3.0) # > 2.0 Settings(temperature=3.0) # > 2.0
with pytest.raises(ConfigurationError, match="Temperature"): with pytest.raises(ConfigurationError, match="Temperature"):
Settings(temperature=-0.1) # < 0.0 Settings(temperature=-0.1) # < 0.0
@@ -28,7 +28,7 @@ class TestConfigValidation:
"""Verify invalid max_iterations is rejected.""" """Verify invalid max_iterations is rejected."""
with pytest.raises(ConfigurationError, match="max_tool_iterations"): with pytest.raises(ConfigurationError, match="max_tool_iterations"):
Settings(max_tool_iterations=0) # < 1 Settings(max_tool_iterations=0) # < 1
with pytest.raises(ConfigurationError, match="max_tool_iterations"): with pytest.raises(ConfigurationError, match="max_tool_iterations"):
Settings(max_tool_iterations=100) # > 20 Settings(max_tool_iterations=100) # > 20
@@ -43,7 +43,7 @@ class TestConfigValidation:
"""Verify invalid timeout is rejected.""" """Verify invalid timeout is rejected."""
with pytest.raises(ConfigurationError, match="request_timeout"): with pytest.raises(ConfigurationError, match="request_timeout"):
Settings(request_timeout=0) # < 1 Settings(request_timeout=0) # < 1
with pytest.raises(ConfigurationError, match="request_timeout"): with pytest.raises(ConfigurationError, match="request_timeout"):
Settings(request_timeout=500) # > 300 Settings(request_timeout=500) # > 300
@@ -58,7 +58,7 @@ class TestConfigValidation:
"""Verify invalid DeepSeek URL is rejected.""" """Verify invalid DeepSeek URL is rejected."""
with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"): with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"):
Settings(deepseek_base_url="not-a-url") Settings(deepseek_base_url="not-a-url")
with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"): with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"):
Settings(deepseek_base_url="ftp://invalid.com") Settings(deepseek_base_url="ftp://invalid.com")
@@ -86,19 +86,17 @@ class TestConfigChecks:
def test_is_deepseek_configured_with_key(self): def test_is_deepseek_configured_with_key(self):
"""Verify is_deepseek_configured returns True with API key.""" """Verify is_deepseek_configured returns True with API key."""
settings = Settings( settings = Settings(
deepseek_api_key="test-key", deepseek_api_key="test-key", deepseek_base_url="https://api.test.com"
deepseek_base_url="https://api.test.com"
) )
assert settings.is_deepseek_configured() is True assert settings.is_deepseek_configured() is True
def test_is_deepseek_configured_without_key(self): def test_is_deepseek_configured_without_key(self):
"""Verify is_deepseek_configured returns False without API key.""" """Verify is_deepseek_configured returns False without API key."""
settings = Settings( settings = Settings(
deepseek_api_key="", deepseek_api_key="", deepseek_base_url="https://api.test.com"
deepseek_base_url="https://api.test.com"
) )
assert settings.is_deepseek_configured() is False assert settings.is_deepseek_configured() is False
def test_is_deepseek_configured_without_url(self): def test_is_deepseek_configured_without_url(self):
@@ -110,19 +108,15 @@ class TestConfigChecks:
def test_is_tmdb_configured_with_key(self): def test_is_tmdb_configured_with_key(self):
"""Verify is_tmdb_configured returns True with API key.""" """Verify is_tmdb_configured returns True with API key."""
settings = Settings( settings = Settings(
tmdb_api_key="test-key", tmdb_api_key="test-key", tmdb_base_url="https://api.test.com"
tmdb_base_url="https://api.test.com"
) )
assert settings.is_tmdb_configured() is True assert settings.is_tmdb_configured() is True
def test_is_tmdb_configured_without_key(self): def test_is_tmdb_configured_without_key(self):
"""Verify is_tmdb_configured returns False without API key.""" """Verify is_tmdb_configured returns False without API key."""
settings = Settings( settings = Settings(tmdb_api_key="", tmdb_base_url="https://api.test.com")
tmdb_api_key="",
tmdb_base_url="https://api.test.com"
)
assert settings.is_tmdb_configured() is False assert settings.is_tmdb_configured() is False
@@ -132,25 +126,25 @@ class TestConfigDefaults:
def test_default_temperature(self): def test_default_temperature(self):
"""Verify default temperature is reasonable.""" """Verify default temperature is reasonable."""
settings = Settings() settings = Settings()
assert 0.0 <= settings.temperature <= 2.0 assert 0.0 <= settings.temperature <= 2.0
def test_default_max_iterations(self): def test_default_max_iterations(self):
"""Verify default max_iterations is reasonable.""" """Verify default max_iterations is reasonable."""
settings = Settings() settings = Settings()
assert 1 <= settings.max_tool_iterations <= 20 assert 1 <= settings.max_tool_iterations <= 20
def test_default_timeout(self): def test_default_timeout(self):
"""Verify default timeout is reasonable.""" """Verify default timeout is reasonable."""
settings = Settings() settings = Settings()
assert 1 <= settings.request_timeout <= 300 assert 1 <= settings.request_timeout <= 300
def test_default_urls_are_valid(self): def test_default_urls_are_valid(self):
"""Verify default URLs are valid.""" """Verify default URLs are valid."""
settings = Settings() settings = Settings()
assert settings.deepseek_base_url.startswith(("http://", "https://")) assert settings.deepseek_base_url.startswith(("http://", "https://"))
assert settings.tmdb_base_url.startswith(("http://", "https://")) assert settings.tmdb_base_url.startswith(("http://", "https://"))
@@ -161,38 +155,38 @@ class TestConfigEnvironmentVariables:
def test_loads_temperature_from_env(self, monkeypatch): def test_loads_temperature_from_env(self, monkeypatch):
"""Verify temperature is loaded from environment.""" """Verify temperature is loaded from environment."""
monkeypatch.setenv("TEMPERATURE", "0.5") monkeypatch.setenv("TEMPERATURE", "0.5")
settings = Settings() settings = Settings()
assert settings.temperature == 0.5 assert settings.temperature == 0.5
def test_loads_max_iterations_from_env(self, monkeypatch): def test_loads_max_iterations_from_env(self, monkeypatch):
"""Verify max_iterations is loaded from environment.""" """Verify max_iterations is loaded from environment."""
monkeypatch.setenv("MAX_TOOL_ITERATIONS", "10") monkeypatch.setenv("MAX_TOOL_ITERATIONS", "10")
settings = Settings() settings = Settings()
assert settings.max_tool_iterations == 10 assert settings.max_tool_iterations == 10
def test_loads_timeout_from_env(self, monkeypatch): def test_loads_timeout_from_env(self, monkeypatch):
"""Verify timeout is loaded from environment.""" """Verify timeout is loaded from environment."""
monkeypatch.setenv("REQUEST_TIMEOUT", "60") monkeypatch.setenv("REQUEST_TIMEOUT", "60")
settings = Settings() settings = Settings()
assert settings.request_timeout == 60 assert settings.request_timeout == 60
def test_loads_deepseek_url_from_env(self, monkeypatch): def test_loads_deepseek_url_from_env(self, monkeypatch):
"""Verify DeepSeek URL is loaded from environment.""" """Verify DeepSeek URL is loaded from environment."""
monkeypatch.setenv("DEEPSEEK_BASE_URL", "https://custom.api.com") monkeypatch.setenv("DEEPSEEK_BASE_URL", "https://custom.api.com")
settings = Settings() settings = Settings()
assert settings.deepseek_base_url == "https://custom.api.com" assert settings.deepseek_base_url == "https://custom.api.com"
def test_invalid_env_value_raises_error(self, monkeypatch): def test_invalid_env_value_raises_error(self, monkeypatch):
"""Verify invalid environment value raises error.""" """Verify invalid environment value raises error."""
monkeypatch.setenv("TEMPERATURE", "invalid") monkeypatch.setenv("TEMPERATURE", "invalid")
with pytest.raises(ValueError): with pytest.raises(ValueError):
Settings() Settings()

View File

@@ -1,12 +1,14 @@
"""Edge case tests for configuration and parameters.""" """Edge case tests for configuration and parameters."""
import pytest
import os import os
from unittest.mock import patch from unittest.mock import patch
from agent.config import Settings, ConfigurationError import pytest
from agent.config import ConfigurationError, Settings
from agent.parameters import ( from agent.parameters import (
ParameterSchema,
REQUIRED_PARAMETERS, REQUIRED_PARAMETERS,
ParameterSchema,
format_parameters_for_prompt, format_parameters_for_prompt,
get_missing_required_parameters, get_missing_required_parameters,
) )
@@ -110,19 +112,27 @@ class TestSettingsEdgeCases:
def test_http_url_accepted(self): def test_http_url_accepted(self):
"""Should accept http:// URLs.""" """Should accept http:// URLs."""
with patch.dict(os.environ, { with patch.dict(
"DEEPSEEK_BASE_URL": "http://localhost:8080", os.environ,
"TMDB_BASE_URL": "http://localhost:3000", {
}, clear=True): "DEEPSEEK_BASE_URL": "http://localhost:8080",
"TMDB_BASE_URL": "http://localhost:3000",
},
clear=True,
):
settings = Settings() settings = Settings()
assert settings.deepseek_base_url == "http://localhost:8080" assert settings.deepseek_base_url == "http://localhost:8080"
def test_https_url_accepted(self): def test_https_url_accepted(self):
"""Should accept https:// URLs.""" """Should accept https:// URLs."""
with patch.dict(os.environ, { with patch.dict(
"DEEPSEEK_BASE_URL": "https://api.example.com", os.environ,
"TMDB_BASE_URL": "https://api.example.com", {
}, clear=True): "DEEPSEEK_BASE_URL": "https://api.example.com",
"TMDB_BASE_URL": "https://api.example.com",
},
clear=True,
):
settings = Settings() settings = Settings()
assert settings.deepseek_base_url == "https://api.example.com" assert settings.deepseek_base_url == "https://api.example.com"

View File

@@ -1,18 +1,17 @@
"""Tests for the Memory system.""" """Tests for the Memory system."""
import pytest
import json
from datetime import datetime from datetime import datetime
from pathlib import Path
import pytest
from infrastructure.persistence import ( from infrastructure.persistence import (
Memory,
LongTermMemory,
ShortTermMemory,
EpisodicMemory, EpisodicMemory,
init_memory, LongTermMemory,
Memory,
ShortTermMemory,
get_memory, get_memory,
set_memory,
has_memory, has_memory,
init_memory,
) )
from infrastructure.persistence.context import _memory_ctx from infrastructure.persistence.context import _memory_ctx
@@ -23,11 +22,12 @@ def is_iso_format(s: str) -> bool:
return False return False
try: try:
# Attempt to parse the string as an ISO 8601 timestamp # Attempt to parse the string as an ISO 8601 timestamp
datetime.fromisoformat(s.replace('Z', '+00:00')) datetime.fromisoformat(s.replace("Z", "+00:00"))
return True return True
except (ValueError, TypeError): except (ValueError, TypeError):
return False return False
class TestLongTermMemory: class TestLongTermMemory:
"""Tests for LongTermMemory.""" """Tests for LongTermMemory."""
@@ -116,12 +116,18 @@ class TestLongTermMemory:
assert data["config"]["key"] == "value" assert data["config"]["key"] == "value"
def test_from_dict(self): def test_from_dict(self):
data = {"config": {"download_folder": "/downloads"}, "preferences": {"preferred_quality": "4K"}, "library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]}, "following": []} data = {
"config": {"download_folder": "/downloads"},
"preferences": {"preferred_quality": "4K"},
"library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]},
"following": [],
}
ltm = LongTermMemory.from_dict(data) ltm = LongTermMemory.from_dict(data)
assert ltm.get_config("download_folder") == "/downloads" assert ltm.get_config("download_folder") == "/downloads"
assert ltm.preferences["preferred_quality"] == "4K" assert ltm.preferences["preferred_quality"] == "4K"
assert len(ltm.library["movies"]) == 1 assert len(ltm.library["movies"]) == 1
class TestShortTermMemory: class TestShortTermMemory:
"""Tests for ShortTermMemory.""" """Tests for ShortTermMemory."""
@@ -162,6 +168,7 @@ class TestShortTermMemory:
assert stm.conversation_history == [] assert stm.conversation_history == []
assert stm.language == "en" assert stm.language == "en"
class TestEpisodicMemory: class TestEpisodicMemory:
"""Tests for EpisodicMemory.""" """Tests for EpisodicMemory."""
@@ -192,6 +199,7 @@ class TestEpisodicMemory:
assert result is not None assert result is not None
assert result["name"] == "Result 2" assert result["name"] == "Result 2"
class TestMemory: class TestMemory:
"""Tests for the Memory manager.""" """Tests for the Memory manager."""
@@ -217,11 +225,10 @@ class TestMemory:
assert memory.stm.conversation_history == [] assert memory.stm.conversation_history == []
assert memory.episodic.recent_errors == [] assert memory.episodic.recent_errors == []
class TestMemoryContext: class TestMemoryContext:
"""Tests for memory context functions.""" """Tests for memory context functions."""
def test_get_memory_not_initialized(self): def test_get_memory_not_initialized(self):
_memory_ctx.set(None) _memory_ctx.set(None)
with pytest.raises(RuntimeError, match="Memory not initialized"): with pytest.raises(RuntimeError, match="Memory not initialized"):

View File

@@ -1,18 +1,17 @@
"""Edge case tests for the Memory system.""" """Edge case tests for the Memory system."""
import pytest
import json import json
import os import os
from pathlib import Path
from datetime import datetime import pytest
from unittest.mock import patch, mock_open
from infrastructure.persistence import ( from infrastructure.persistence import (
Memory,
LongTermMemory,
ShortTermMemory,
EpisodicMemory, EpisodicMemory,
init_memory, LongTermMemory,
Memory,
ShortTermMemory,
get_memory, get_memory,
init_memory,
set_memory, set_memory,
) )
from infrastructure.persistence.context import _memory_ctx from infrastructure.persistence.context import _memory_ctx
@@ -390,7 +389,7 @@ class TestMemoryEdgeCases:
def test_init_with_nonexistent_directory(self, temp_dir): def test_init_with_nonexistent_directory(self, temp_dir):
"""Should create directory if not exists.""" """Should create directory if not exists."""
new_dir = temp_dir / "new" / "nested" / "dir" new_dir = temp_dir / "new" / "nested" / "dir"
# Create parent directories first # Create parent directories first
new_dir.mkdir(parents=True, exist_ok=True) new_dir.mkdir(parents=True, exist_ok=True)
memory = Memory(storage_dir=str(new_dir)) memory = Memory(storage_dir=str(new_dir))
@@ -529,7 +528,6 @@ class TestMemoryContextEdgeCases:
def test_context_isolation(self, temp_dir): def test_context_isolation(self, temp_dir):
"""Context should be isolated per context.""" """Context should be isolated per context."""
import asyncio
from contextvars import copy_context from contextvars import copy_context
_memory_ctx.set(None) _memory_ctx.set(None)

View File

@@ -1,10 +1,8 @@
"""Critical tests for prompt builder - Tests that would have caught bugs.""" """Critical tests for prompt builder - Tests that would have caught bugs."""
import pytest
from agent.registry import make_tools
from agent.prompts import PromptBuilder from agent.prompts import PromptBuilder
from infrastructure.persistence import get_memory from agent.registry import make_tools
class TestPromptBuilderToolsInjection: class TestPromptBuilderToolsInjection:
@@ -15,20 +13,22 @@ class TestPromptBuilderToolsInjection:
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
# Verify each tool is mentioned # Verify each tool is mentioned
for tool_name in tools.keys(): for tool_name in tools.keys():
assert tool_name in prompt, f"Tool {tool_name} not mentioned in system prompt" assert (
tool_name in prompt
), f"Tool {tool_name} not mentioned in system prompt"
def test_tools_spec_contains_all_registered_tools(self): def test_tools_spec_contains_all_registered_tools(self):
"""CRITICAL: Verify build_tools_spec() returns all tools.""" """CRITICAL: Verify build_tools_spec() returns all tools."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
specs = builder.build_tools_spec() specs = builder.build_tools_spec()
spec_names = {spec['function']['name'] for spec in specs} spec_names = {spec["function"]["name"] for spec in specs}
tool_names = set(tools.keys()) tool_names = set(tools.keys())
assert spec_names == tool_names, f"Missing tools: {tool_names - spec_names}" assert spec_names == tool_names, f"Missing tools: {tool_names - spec_names}"
def test_tools_spec_is_not_empty(self): def test_tools_spec_is_not_empty(self):
@@ -36,7 +36,7 @@ class TestPromptBuilderToolsInjection:
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
specs = builder.build_tools_spec() specs = builder.build_tools_spec()
assert len(specs) > 0, "Tools spec is empty!" assert len(specs) > 0, "Tools spec is empty!"
def test_tools_spec_format_matches_openai(self): def test_tools_spec_format_matches_openai(self):
@@ -44,14 +44,14 @@ class TestPromptBuilderToolsInjection:
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
specs = builder.build_tools_spec() specs = builder.build_tools_spec()
for spec in specs: for spec in specs:
assert 'type' in spec assert "type" in spec
assert spec['type'] == 'function' assert spec["type"] == "function"
assert 'function' in spec assert "function" in spec
assert 'name' in spec['function'] assert "name" in spec["function"]
assert 'description' in spec['function'] assert "description" in spec["function"]
assert 'parameters' in spec['function'] assert "parameters" in spec["function"]
class TestPromptBuilderMemoryContext: class TestPromptBuilderMemoryContext:
@@ -61,29 +61,29 @@ class TestPromptBuilderMemoryContext:
"""Verify current topic is included in prompt.""" """Verify current topic is included in prompt."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
memory.stm.set_topic("test_topic") memory.stm.set_topic("test_topic")
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert "test_topic" in prompt assert "test_topic" in prompt
def test_prompt_includes_extracted_entities(self, memory): def test_prompt_includes_extracted_entities(self, memory):
"""Verify extracted entities are included in prompt.""" """Verify extracted entities are included in prompt."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
memory.stm.set_entity("test_key", "test_value") memory.stm.set_entity("test_key", "test_value")
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert "test_key" in prompt assert "test_key" in prompt
def test_prompt_includes_search_results(self, memory_with_search_results): def test_prompt_includes_search_results(self, memory_with_search_results):
"""Verify search results are included in prompt.""" """Verify search results are included in prompt."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert "Inception" in prompt assert "Inception" in prompt
assert "LAST SEARCH" in prompt assert "LAST SEARCH" in prompt
@@ -91,15 +91,13 @@ class TestPromptBuilderMemoryContext:
"""Verify active downloads are included in prompt.""" """Verify active downloads are included in prompt."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
memory.episodic.add_active_download({ memory.episodic.add_active_download(
"task_id": "123", {"task_id": "123", "name": "Test Movie", "progress": 50}
"name": "Test Movie", )
"progress": 50
})
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert "ACTIVE DOWNLOADS" in prompt assert "ACTIVE DOWNLOADS" in prompt
assert "Test Movie" in prompt assert "Test Movie" in prompt
@@ -107,33 +105,33 @@ class TestPromptBuilderMemoryContext:
"""Verify recent errors are included in prompt.""" """Verify recent errors are included in prompt."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
memory.episodic.add_error("test_action", "test error message") memory.episodic.add_error("test_action", "test error message")
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert "RECENT ERRORS" in prompt or "error" in prompt.lower() assert "RECENT ERRORS" in prompt or "error" in prompt.lower()
def test_prompt_includes_configuration(self, memory): def test_prompt_includes_configuration(self, memory):
"""Verify configuration is included in prompt.""" """Verify configuration is included in prompt."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
memory.ltm.set_config("download_folder", "/test/downloads") memory.ltm.set_config("download_folder", "/test/downloads")
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert "CONFIGURATION" in prompt or "download_folder" in prompt assert "CONFIGURATION" in prompt or "download_folder" in prompt
def test_prompt_includes_language(self, memory): def test_prompt_includes_language(self, memory):
"""Verify language is included in prompt.""" """Verify language is included in prompt."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
memory.stm.set_language("fr") memory.stm.set_language("fr")
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert "fr" in prompt or "LANGUAGE" in prompt assert "fr" in prompt or "LANGUAGE" in prompt
@@ -145,7 +143,7 @@ class TestPromptBuilderStructure:
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert len(prompt) > 0 assert len(prompt) > 0
assert prompt.strip() != "" assert prompt.strip() != ""
@@ -154,7 +152,7 @@ class TestPromptBuilderStructure:
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert "assistant" in prompt.lower() or "help" in prompt.lower() assert "assistant" in prompt.lower() or "help" in prompt.lower()
def test_system_prompt_includes_rules(self): def test_system_prompt_includes_rules(self):
@@ -162,7 +160,7 @@ class TestPromptBuilderStructure:
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert "RULES" in prompt or "IMPORTANT" in prompt assert "RULES" in prompt or "IMPORTANT" in prompt
def test_system_prompt_includes_examples(self): def test_system_prompt_includes_examples(self):
@@ -170,16 +168,16 @@ class TestPromptBuilderStructure:
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert "EXAMPLES" in prompt or "example" in prompt.lower() assert "EXAMPLES" in prompt or "example" in prompt.lower()
def test_tools_description_format(self): def test_tools_description_format(self):
"""Verify tools are properly formatted in description.""" """Verify tools are properly formatted in description."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
description = builder._format_tools_description() description = builder._format_tools_description()
# Should have tool names and descriptions # Should have tool names and descriptions
for tool_name, tool in tools.items(): for tool_name, tool in tools.items():
assert tool_name in description assert tool_name in description
@@ -190,9 +188,9 @@ class TestPromptBuilderStructure:
"""Verify episodic context is properly formatted.""" """Verify episodic context is properly formatted."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
context = builder._format_episodic_context() context = builder._format_episodic_context()
assert "LAST SEARCH" in context assert "LAST SEARCH" in context
assert "Inception" in context assert "Inception" in context
@@ -200,12 +198,12 @@ class TestPromptBuilderStructure:
"""Verify STM context is properly formatted.""" """Verify STM context is properly formatted."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
memory.stm.set_topic("test_topic") memory.stm.set_topic("test_topic")
memory.stm.set_entity("key", "value") memory.stm.set_entity("key", "value")
context = builder._format_stm_context() context = builder._format_stm_context()
assert "TOPIC" in context or "test_topic" in context assert "TOPIC" in context or "test_topic" in context
assert "ENTITIES" in context or "key" in context assert "ENTITIES" in context or "key" in context
@@ -213,11 +211,11 @@ class TestPromptBuilderStructure:
"""Verify config context is properly formatted.""" """Verify config context is properly formatted."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
memory.ltm.set_config("test_key", "test_value") memory.ltm.set_config("test_key", "test_value")
context = builder._format_config_context() context = builder._format_config_context()
assert "CONFIGURATION" in context assert "CONFIGURATION" in context
assert "test_key" in context assert "test_key" in context
@@ -229,10 +227,10 @@ class TestPromptBuilderEdgeCases:
"""Verify prompt works with empty memory.""" """Verify prompt works with empty memory."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
# Memory is empty # Memory is empty
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
# Should still have base content # Should still have base content
assert len(prompt) > 0 assert len(prompt) > 0
assert "assistant" in prompt.lower() assert "assistant" in prompt.lower()
@@ -240,18 +238,18 @@ class TestPromptBuilderEdgeCases:
def test_prompt_with_empty_tools(self): def test_prompt_with_empty_tools(self):
"""Verify prompt handles empty tools dict.""" """Verify prompt handles empty tools dict."""
builder = PromptBuilder({}) builder = PromptBuilder({})
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
# Should still generate a prompt # Should still generate a prompt
assert len(prompt) > 0 assert len(prompt) > 0
def test_tools_spec_with_empty_tools(self): def test_tools_spec_with_empty_tools(self):
"""Verify tools spec handles empty tools dict.""" """Verify tools spec handles empty tools dict."""
builder = PromptBuilder({}) builder = PromptBuilder({})
specs = builder.build_tools_spec() specs = builder.build_tools_spec()
assert isinstance(specs, list) assert isinstance(specs, list)
assert len(specs) == 0 assert len(specs) == 0
@@ -259,11 +257,11 @@ class TestPromptBuilderEdgeCases:
"""Verify prompt handles unicode in memory.""" """Verify prompt handles unicode in memory."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
memory.stm.set_entity("movie", "Amélie 🎬") memory.stm.set_entity("movie", "Amélie 🎬")
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
assert "Amélie" in prompt assert "Amélie" in prompt
assert "🎬" in prompt assert "🎬" in prompt
@@ -271,13 +269,13 @@ class TestPromptBuilderEdgeCases:
"""Verify prompt handles many search results.""" """Verify prompt handles many search results."""
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
# Add many results # Add many results
results = [{"name": f"Movie {i}", "seeders": i} for i in range(20)] results = [{"name": f"Movie {i}", "seeders": i} for i in range(20)]
memory.episodic.store_search_results("test", results, "torrent") memory.episodic.store_search_results("test", results, "torrent")
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
# Should include some results but not all (to avoid huge prompts) # Should include some results but not all (to avoid huge prompts)
assert "Movie 0" in prompt or "Movie 1" in prompt assert "Movie 0" in prompt or "Movie 1" in prompt
# Should indicate there are more # Should indicate there are more

View File

@@ -1,10 +1,8 @@
"""Edge case tests for PromptBuilder.""" """Edge case tests for PromptBuilder."""
import pytest
import json
from agent.prompts import PromptBuilder from agent.prompts import PromptBuilder
from agent.registry import make_tools from agent.registry import make_tools
from infrastructure.persistence import get_memory
class TestPromptBuilderEdgeCases: class TestPromptBuilderEdgeCases:
@@ -93,11 +91,13 @@ class TestPromptBuilderEdgeCases:
def test_prompt_with_many_active_downloads(self, memory): def test_prompt_with_many_active_downloads(self, memory):
"""Should limit displayed active downloads.""" """Should limit displayed active downloads."""
for i in range(20): for i in range(20):
memory.episodic.add_active_download({ memory.episodic.add_active_download(
"task_id": str(i), {
"name": f"Download {i}", "task_id": str(i),
"progress": i * 5, "name": f"Download {i}",
}) "progress": i * 5,
}
)
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
@@ -136,12 +136,15 @@ class TestPromptBuilderEdgeCases:
def test_prompt_with_complex_workflow(self, memory): def test_prompt_with_complex_workflow(self, memory):
"""Should handle complex workflow state.""" """Should handle complex workflow state."""
memory.stm.start_workflow("download", { memory.stm.start_workflow(
"title": "Test Movie", "download",
"year": 2024, {
"quality": "1080p", "title": "Test Movie",
"nested": {"deep": {"value": "test"}}, "year": 2024,
}) "quality": "1080p",
"nested": {"deep": {"value": "test"}},
},
)
memory.stm.update_workflow_stage("searching_torrents") memory.stm.update_workflow_stage("searching_torrents")
tools = make_tools() tools = make_tools()
@@ -313,11 +316,14 @@ class TestFormatEpisodicContextEdgeCases:
def test_format_with_search_results_none_names(self, memory): def test_format_with_search_results_none_names(self, memory):
"""Should handle results with None names.""" """Should handle results with None names."""
memory.episodic.store_search_results("test", [ memory.episodic.store_search_results(
{"name": None}, "test",
{"title": None}, [
{}, {"name": None},
]) {"title": None},
{},
],
)
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)

View File

@@ -1,10 +1,11 @@
"""Critical tests for tool registry - Tests that would have caught bugs.""" """Critical tests for tool registry - Tests that would have caught bugs."""
import pytest
import inspect import inspect
from agent.registry import make_tools, _create_tool_from_function, Tool import pytest
from agent.prompts import PromptBuilder from agent.prompts import PromptBuilder
from agent.registry import Tool, _create_tool_from_function, make_tools
class TestToolSpecFormat: class TestToolSpecFormat:
@@ -15,54 +16,59 @@ class TestToolSpecFormat:
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
specs = builder.build_tools_spec() specs = builder.build_tools_spec()
# Verify structure # Verify structure
assert isinstance(specs, list), "Tool specs must be a list" assert isinstance(specs, list), "Tool specs must be a list"
assert len(specs) > 0, "Tool specs list is empty" assert len(specs) > 0, "Tool specs list is empty"
for spec in specs: for spec in specs:
# OpenAI format requires these fields # OpenAI format requires these fields
assert spec['type'] == 'function', f"Tool type must be 'function', got {spec.get('type')}" assert (
assert 'function' in spec, "Tool spec missing 'function' key" spec["type"] == "function"
), f"Tool type must be 'function', got {spec.get('type')}"
func = spec['function'] assert "function" in spec, "Tool spec missing 'function' key"
assert 'name' in func, "Function missing 'name'"
assert 'description' in func, "Function missing 'description'" func = spec["function"]
assert 'parameters' in func, "Function missing 'parameters'" assert "name" in func, "Function missing 'name'"
assert "description" in func, "Function missing 'description'"
params = func['parameters'] assert "parameters" in func, "Function missing 'parameters'"
assert params['type'] == 'object', "Parameters type must be 'object'"
assert 'properties' in params, "Parameters missing 'properties'" params = func["parameters"]
assert 'required' in params, "Parameters missing 'required'" assert params["type"] == "object", "Parameters type must be 'object'"
assert isinstance(params['required'], list), "Required must be a list" assert "properties" in params, "Parameters missing 'properties'"
assert "required" in params, "Parameters missing 'required'"
assert isinstance(params["required"], list), "Required must be a list"
def test_tool_parameters_match_function_signature(self): def test_tool_parameters_match_function_signature(self):
"""CRITICAL: Verify generated parameters match function signature.""" """CRITICAL: Verify generated parameters match function signature."""
def test_func(name: str, age: int, active: bool = True): def test_func(name: str, age: int, active: bool = True):
"""Test function with typed parameters.""" """Test function with typed parameters."""
return {"status": "ok"} return {"status": "ok"}
tool = _create_tool_from_function(test_func) tool = _create_tool_from_function(test_func)
# Verify types are correctly mapped # Verify types are correctly mapped
assert tool.parameters['properties']['name']['type'] == 'string' assert tool.parameters["properties"]["name"]["type"] == "string"
assert tool.parameters['properties']['age']['type'] == 'integer' assert tool.parameters["properties"]["age"]["type"] == "integer"
assert tool.parameters['properties']['active']['type'] == 'boolean' assert tool.parameters["properties"]["active"]["type"] == "boolean"
# Verify required vs optional # Verify required vs optional
assert 'name' in tool.parameters['required'], "name should be required" assert "name" in tool.parameters["required"], "name should be required"
assert 'age' in tool.parameters['required'], "age should be required" assert "age" in tool.parameters["required"], "age should be required"
assert 'active' not in tool.parameters['required'], "active has default, should not be required" assert (
"active" not in tool.parameters["required"]
), "active has default, should not be required"
def test_all_registered_tools_are_callable(self): def test_all_registered_tools_are_callable(self):
"""CRITICAL: Verify all registered tools are actually callable.""" """CRITICAL: Verify all registered tools are actually callable."""
tools = make_tools() tools = make_tools()
assert len(tools) > 0, "No tools registered" assert len(tools) > 0, "No tools registered"
for name, tool in tools.items(): for name, tool in tools.items():
assert callable(tool.func), f"Tool {name} is not callable" assert callable(tool.func), f"Tool {name} is not callable"
# Verify function has valid signature # Verify function has valid signature
try: try:
sig = inspect.signature(tool.func) sig = inspect.signature(tool.func)
@@ -75,38 +81,40 @@ class TestToolSpecFormat:
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
specs = builder.build_tools_spec() specs = builder.build_tools_spec()
spec_names = {spec['function']['name'] for spec in specs} spec_names = {spec["function"]["name"] for spec in specs}
tool_names = set(tools.keys()) tool_names = set(tools.keys())
missing = tool_names - spec_names missing = tool_names - spec_names
extra = spec_names - tool_names extra = spec_names - tool_names
assert not missing, f"Tools missing from specs: {missing}" assert not missing, f"Tools missing from specs: {missing}"
assert not extra, f"Extra tools in specs: {extra}" assert not extra, f"Extra tools in specs: {extra}"
assert spec_names == tool_names, "Tool specs don't match registered tools" assert spec_names == tool_names, "Tool specs don't match registered tools"
def test_tool_description_extracted_from_docstring(self): def test_tool_description_extracted_from_docstring(self):
"""Verify tool description is extracted from function docstring.""" """Verify tool description is extracted from function docstring."""
def test_func(param: str): def test_func(param: str):
"""This is the description. """This is the description.
More details here. More details here.
""" """
return {} return {}
tool = _create_tool_from_function(test_func) tool = _create_tool_from_function(test_func)
assert tool.description == "This is the description." assert tool.description == "This is the description."
assert "More details" not in tool.description assert "More details" not in tool.description
def test_tool_without_docstring_uses_function_name(self): def test_tool_without_docstring_uses_function_name(self):
"""Verify tool without docstring uses function name as description.""" """Verify tool without docstring uses function name as description."""
def test_func_no_doc(param: str): def test_func_no_doc(param: str):
return {} return {}
tool = _create_tool_from_function(test_func_no_doc) tool = _create_tool_from_function(test_func_no_doc)
assert tool.description == "test_func_no_doc" assert tool.description == "test_func_no_doc"
def test_tool_parameters_have_descriptions(self): def test_tool_parameters_have_descriptions(self):
@@ -114,25 +122,27 @@ class TestToolSpecFormat:
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
specs = builder.build_tools_spec() specs = builder.build_tools_spec()
for spec in specs: for spec in specs:
params = spec['function']['parameters'] params = spec["function"]["parameters"]
properties = params.get('properties', {}) properties = params.get("properties", {})
for param_name, param_spec in properties.items(): for param_name, param_spec in properties.items():
assert 'description' in param_spec, \ assert (
f"Parameter {param_name} in {spec['function']['name']} missing description" "description" in param_spec
), f"Parameter {param_name} in {spec['function']['name']} missing description"
def test_required_parameters_are_marked_correctly(self): def test_required_parameters_are_marked_correctly(self):
"""Verify required parameters are correctly identified.""" """Verify required parameters are correctly identified."""
def func_with_optional(required: str, optional: int = 5): def func_with_optional(required: str, optional: int = 5):
return {} return {}
tool = _create_tool_from_function(func_with_optional) tool = _create_tool_from_function(func_with_optional)
assert 'required' in tool.parameters['required'] assert "required" in tool.parameters["required"]
assert 'optional' not in tool.parameters['required'] assert "optional" not in tool.parameters["required"]
assert len(tool.parameters['required']) == 1 assert len(tool.parameters["required"]) == 1
class TestToolRegistry: class TestToolRegistry:
@@ -141,28 +151,28 @@ class TestToolRegistry:
def test_make_tools_returns_dict(self): def test_make_tools_returns_dict(self):
"""Verify make_tools returns a dictionary.""" """Verify make_tools returns a dictionary."""
tools = make_tools() tools = make_tools()
assert isinstance(tools, dict) assert isinstance(tools, dict)
assert len(tools) > 0 assert len(tools) > 0
def test_all_tools_have_unique_names(self): def test_all_tools_have_unique_names(self):
"""Verify all tool names are unique.""" """Verify all tool names are unique."""
tools = make_tools() tools = make_tools()
names = [tool.name for tool in tools.values()] names = [tool.name for tool in tools.values()]
assert len(names) == len(set(names)), "Duplicate tool names found" assert len(names) == len(set(names)), "Duplicate tool names found"
def test_tool_names_match_dict_keys(self): def test_tool_names_match_dict_keys(self):
"""Verify tool names match their dictionary keys.""" """Verify tool names match their dictionary keys."""
tools = make_tools() tools = make_tools()
for key, tool in tools.items(): for key, tool in tools.items():
assert key == tool.name, f"Key {key} doesn't match tool name {tool.name}" assert key == tool.name, f"Key {key} doesn't match tool name {tool.name}"
def test_expected_tools_are_registered(self): def test_expected_tools_are_registered(self):
"""Verify all expected tools are registered.""" """Verify all expected tools are registered."""
tools = make_tools() tools = make_tools()
expected_tools = [ expected_tools = [
"set_path_for_folder", "set_path_for_folder",
"list_folder", "list_folder",
@@ -173,14 +183,14 @@ class TestToolRegistry:
"get_torrent_by_index", "get_torrent_by_index",
"set_language", "set_language",
] ]
for expected in expected_tools: for expected in expected_tools:
assert expected in tools, f"Expected tool {expected} not registered" assert expected in tools, f"Expected tool {expected} not registered"
def test_tool_functions_return_dict(self): def test_tool_functions_return_dict(self):
"""Verify all tool functions return dictionaries.""" """Verify all tool functions return dictionaries."""
tools = make_tools() tools = make_tools()
# Test with minimal valid arguments # Test with minimal valid arguments
# Note: This is a smoke test, not full integration # Note: This is a smoke test, not full integration
for name, tool in tools.items(): for name, tool in tools.items():
@@ -195,16 +205,17 @@ class TestToolDataclass:
def test_tool_creation(self): def test_tool_creation(self):
"""Verify Tool can be created with all fields.""" """Verify Tool can be created with all fields."""
def dummy_func(): def dummy_func():
return {} return {}
tool = Tool( tool = Tool(
name="test_tool", name="test_tool",
description="Test description", description="Test description",
func=dummy_func, func=dummy_func,
parameters={"type": "object", "properties": {}, "required": []} parameters={"type": "object", "properties": {}, "required": []},
) )
assert tool.name == "test_tool" assert tool.name == "test_tool"
assert tool.description == "Test description" assert tool.description == "Test description"
assert tool.func == dummy_func assert tool.func == dummy_func
@@ -212,12 +223,13 @@ class TestToolDataclass:
def test_tool_parameters_structure(self): def test_tool_parameters_structure(self):
"""Verify Tool parameters have correct structure.""" """Verify Tool parameters have correct structure."""
def dummy_func(arg: str): def dummy_func(arg: str):
return {} return {}
tool = _create_tool_from_function(dummy_func) tool = _create_tool_from_function(dummy_func)
assert 'type' in tool.parameters assert "type" in tool.parameters
assert 'properties' in tool.parameters assert "properties" in tool.parameters
assert 'required' in tool.parameters assert "required" in tool.parameters
assert tool.parameters['type'] == 'object' assert tool.parameters["type"] == "object"

View File

@@ -1,6 +1,7 @@
"""Edge case tests for tool registry.""" """Edge case tests for tool registry."""
import pytest import pytest
from unittest.mock import Mock
from agent.registry import Tool, make_tools from agent.registry import Tool, make_tools
@@ -182,7 +183,9 @@ class TestMakeToolsEdgeCases:
params = tool.parameters params = tool.parameters
if "required" in params and "properties" in params: if "required" in params and "properties" in params:
for req in params["required"]: for req in params["required"]:
assert req in params["properties"], f"Required param {req} not in properties for {tool.name}" assert (
req in params["properties"]
), f"Required param {req} not in properties for {tool.name}"
def test_make_tools_descriptions_not_empty(self, memory): def test_make_tools_descriptions_not_empty(self, memory):
"""Should have non-empty descriptions.""" """Should have non-empty descriptions."""
@@ -233,7 +236,9 @@ class TestMakeToolsEdgeCases:
if "properties" in tool.parameters: if "properties" in tool.parameters:
for prop_name, prop_schema in tool.parameters["properties"].items(): for prop_name, prop_schema in tool.parameters["properties"].items():
if "type" in prop_schema: if "type" in prop_schema:
assert prop_schema["type"] in valid_types, f"Invalid type for {tool.name}.{prop_name}" assert (
prop_schema["type"] in valid_types
), f"Invalid type for {tool.name}.{prop_name}"
def test_make_tools_enum_values(self, memory): def test_make_tools_enum_values(self, memory):
"""Should have valid enum values.""" """Should have valid enum values."""

View File

@@ -1,19 +1,18 @@
"""Tests for JSON repositories.""" """Tests for JSON repositories."""
import pytest
from datetime import datetime
from infrastructure.persistence.json import (
JsonMovieRepository,
JsonTVShowRepository,
JsonSubtitleRepository,
)
from domain.movies.entities import Movie from domain.movies.entities import Movie
from domain.movies.value_objects import MovieTitle, ReleaseYear, Quality from domain.movies.value_objects import MovieTitle, Quality, ReleaseYear
from domain.tv_shows.entities import TVShow from domain.shared.value_objects import FilePath, FileSize, ImdbId
from domain.tv_shows.value_objects import ShowStatus
from domain.subtitles.entities import Subtitle from domain.subtitles.entities import Subtitle
from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset
from domain.shared.value_objects import ImdbId, FilePath, FileSize from domain.tv_shows.entities import TVShow
from domain.tv_shows.value_objects import ShowStatus
from infrastructure.persistence.json import (
JsonMovieRepository,
JsonSubtitleRepository,
JsonTVShowRepository,
)
class TestJsonMovieRepository: class TestJsonMovieRepository:
@@ -224,7 +223,9 @@ class TestJsonTVShowRepository:
"""Should preserve show status.""" """Should preserve show status."""
repo = JsonTVShowRepository() repo = JsonTVShowRepository()
for i, status in enumerate([ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]): for i, status in enumerate(
[ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]
):
show = TVShow( show = TVShow(
imdb_id=ImdbId(f"tt{i+1000000:07d}"), imdb_id=ImdbId(f"tt{i+1000000:07d}"),
title=f"Show {status.value}", title=f"Show {status.value}",
@@ -294,18 +295,22 @@ class TestJsonSubtitleRepository:
def test_find_by_media_with_language_filter(self, memory): def test_find_by_media_with_language_filter(self, memory):
"""Should filter by language.""" """Should filter by language."""
repo = JsonSubtitleRepository() repo = JsonSubtitleRepository()
repo.save(Subtitle( repo.save(
media_imdb_id=ImdbId("tt1375666"), Subtitle(
language=Language.ENGLISH, media_imdb_id=ImdbId("tt1375666"),
format=SubtitleFormat.SRT, language=Language.ENGLISH,
file_path=FilePath("/subs/en.srt"), format=SubtitleFormat.SRT,
)) file_path=FilePath("/subs/en.srt"),
repo.save(Subtitle( )
media_imdb_id=ImdbId("tt1375666"), )
language=Language.FRENCH, repo.save(
format=SubtitleFormat.SRT, Subtitle(
file_path=FilePath("/subs/fr.srt"), media_imdb_id=ImdbId("tt1375666"),
)) language=Language.FRENCH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/fr.srt"),
)
)
results = repo.find_by_media(ImdbId("tt1375666"), language=Language.FRENCH) results = repo.find_by_media(ImdbId("tt1375666"), language=Language.FRENCH)
@@ -315,22 +320,26 @@ class TestJsonSubtitleRepository:
def test_find_by_media_with_episode_filter(self, memory): def test_find_by_media_with_episode_filter(self, memory):
"""Should filter by season/episode.""" """Should filter by season/episode."""
repo = JsonSubtitleRepository() repo = JsonSubtitleRepository()
repo.save(Subtitle( repo.save(
media_imdb_id=ImdbId("tt0944947"), Subtitle(
language=Language.ENGLISH, media_imdb_id=ImdbId("tt0944947"),
format=SubtitleFormat.SRT, language=Language.ENGLISH,
file_path=FilePath("/subs/s01e01.srt"), format=SubtitleFormat.SRT,
season_number=1, file_path=FilePath("/subs/s01e01.srt"),
episode_number=1, season_number=1,
)) episode_number=1,
repo.save(Subtitle( )
media_imdb_id=ImdbId("tt0944947"), )
language=Language.ENGLISH, repo.save(
format=SubtitleFormat.SRT, Subtitle(
file_path=FilePath("/subs/s01e02.srt"), media_imdb_id=ImdbId("tt0944947"),
season_number=1, language=Language.ENGLISH,
episode_number=2, format=SubtitleFormat.SRT,
)) file_path=FilePath("/subs/s01e02.srt"),
season_number=1,
episode_number=2,
)
)
results = repo.find_by_media( results = repo.find_by_media(
ImdbId("tt0944947"), ImdbId("tt0944947"),

View File

@@ -21,24 +21,30 @@ def create_mock_response(status_code, json_data=None, text=None):
class TestFindMediaImdbId: class TestFindMediaImdbId:
"""Tests for find_media_imdb_id tool.""" """Tests for find_media_imdb_id tool."""
@patch('infrastructure.api.tmdb.client.requests.get') @patch("infrastructure.api.tmdb.client.requests.get")
def test_success(self, mock_get, memory): def test_success(self, mock_get, memory):
"""Should return movie info on success.""" """Should return movie info on success."""
# Mock HTTP responses # Mock HTTP responses
def mock_get_side_effect(url, **kwargs): def mock_get_side_effect(url, **kwargs):
if "search" in url: if "search" in url:
return create_mock_response(200, json_data={ return create_mock_response(
"results": [{ 200,
"id": 27205, json_data={
"title": "Inception", "results": [
"release_date": "2010-07-16", {
"overview": "A thief...", "id": 27205,
"media_type": "movie" "title": "Inception",
}] "release_date": "2010-07-16",
}) "overview": "A thief...",
"media_type": "movie",
}
]
},
)
elif "external_ids" in url: elif "external_ids" in url:
return create_mock_response(200, json_data={"imdb_id": "tt1375666"}) return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
mock_get.side_effect = mock_get_side_effect mock_get.side_effect = mock_get_side_effect
result = api_tools.find_media_imdb_id("Inception") result = api_tools.find_media_imdb_id("Inception")
@@ -46,26 +52,32 @@ class TestFindMediaImdbId:
assert result["status"] == "ok" assert result["status"] == "ok"
assert result["imdb_id"] == "tt1375666" assert result["imdb_id"] == "tt1375666"
assert result["title"] == "Inception" assert result["title"] == "Inception"
# Verify HTTP calls # Verify HTTP calls
assert mock_get.call_count == 2 assert mock_get.call_count == 2
@patch('infrastructure.api.tmdb.client.requests.get') @patch("infrastructure.api.tmdb.client.requests.get")
def test_stores_in_stm(self, mock_get, memory): def test_stores_in_stm(self, mock_get, memory):
"""Should store result in STM on success.""" """Should store result in STM on success."""
def mock_get_side_effect(url, **kwargs): def mock_get_side_effect(url, **kwargs):
if "search" in url: if "search" in url:
return create_mock_response(200, json_data={ return create_mock_response(
"results": [{ 200,
"id": 27205, json_data={
"title": "Inception", "results": [
"release_date": "2010-07-16", {
"media_type": "movie" "id": 27205,
}] "title": "Inception",
}) "release_date": "2010-07-16",
"media_type": "movie",
}
]
},
)
elif "external_ids" in url: elif "external_ids" in url:
return create_mock_response(200, json_data={"imdb_id": "tt1375666"}) return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
mock_get.side_effect = mock_get_side_effect mock_get.side_effect = mock_get_side_effect
api_tools.find_media_imdb_id("Inception") api_tools.find_media_imdb_id("Inception")
@@ -76,7 +88,7 @@ class TestFindMediaImdbId:
assert entity["title"] == "Inception" assert entity["title"] == "Inception"
assert mem.stm.current_topic == "searching_media" assert mem.stm.current_topic == "searching_media"
@patch('infrastructure.api.tmdb.client.requests.get') @patch("infrastructure.api.tmdb.client.requests.get")
def test_not_found(self, mock_get, memory): def test_not_found(self, mock_get, memory):
"""Should return error when not found.""" """Should return error when not found."""
mock_get.return_value = create_mock_response(200, json_data={"results": []}) mock_get.return_value = create_mock_response(200, json_data={"results": []})
@@ -86,7 +98,7 @@ class TestFindMediaImdbId:
assert result["status"] == "error" assert result["status"] == "error"
assert result["error"] == "not_found" assert result["error"] == "not_found"
@patch('infrastructure.api.tmdb.client.requests.get') @patch("infrastructure.api.tmdb.client.requests.get")
def test_does_not_store_on_error(self, mock_get, memory): def test_does_not_store_on_error(self, mock_get, memory):
"""Should not store in STM on error.""" """Should not store in STM on error."""
mock_get.return_value = create_mock_response(200, json_data={"results": []}) mock_get.return_value = create_mock_response(200, json_data={"results": []})
@@ -100,49 +112,57 @@ class TestFindMediaImdbId:
class TestFindTorrent: class TestFindTorrent:
"""Tests for find_torrent tool.""" """Tests for find_torrent tool."""
@patch('infrastructure.api.knaben.client.requests.post') @patch("infrastructure.api.knaben.client.requests.post")
def test_success(self, mock_post, memory): def test_success(self, mock_post, memory):
"""Should return torrents on success.""" """Should return torrents on success."""
mock_post.return_value = create_mock_response(200, json_data={ mock_post.return_value = create_mock_response(
"hits": [ 200,
{ json_data={
"title": "Torrent 1", "hits": [
"seeders": 100, {
"leechers": 10, "title": "Torrent 1",
"magnetUrl": "magnet:?xt=...", "seeders": 100,
"size": "2.5 GB" "leechers": 10,
}, "magnetUrl": "magnet:?xt=...",
{ "size": "2.5 GB",
"title": "Torrent 2", },
"seeders": 50, {
"leechers": 5, "title": "Torrent 2",
"magnetUrl": "magnet:?xt=...", "seeders": 50,
"size": "1.8 GB" "leechers": 5,
} "magnetUrl": "magnet:?xt=...",
] "size": "1.8 GB",
}) },
]
},
)
result = api_tools.find_torrent("Inception 1080p") result = api_tools.find_torrent("Inception 1080p")
assert result["status"] == "ok" assert result["status"] == "ok"
assert len(result["torrents"]) == 2 assert len(result["torrents"]) == 2
# Verify HTTP payload
payload = mock_post.call_args[1]['json']
assert payload['query'] == "Inception 1080p"
@patch('infrastructure.api.knaben.client.requests.post') # Verify HTTP payload
payload = mock_post.call_args[1]["json"]
assert payload["query"] == "Inception 1080p"
@patch("infrastructure.api.knaben.client.requests.post")
def test_stores_in_episodic(self, mock_post, memory): def test_stores_in_episodic(self, mock_post, memory):
"""Should store results in episodic memory.""" """Should store results in episodic memory."""
mock_post.return_value = create_mock_response(200, json_data={ mock_post.return_value = create_mock_response(
"hits": [{ 200,
"title": "Torrent 1", json_data={
"seeders": 100, "hits": [
"leechers": 10, {
"magnetUrl": "magnet:?xt=...", "title": "Torrent 1",
"size": "2.5 GB" "seeders": 100,
}] "leechers": 10,
}) "magnetUrl": "magnet:?xt=...",
"size": "2.5 GB",
}
]
},
)
api_tools.find_torrent("Inception") api_tools.find_torrent("Inception")
@@ -151,16 +171,37 @@ class TestFindTorrent:
assert mem.episodic.last_search_results["query"] == "Inception" assert mem.episodic.last_search_results["query"] == "Inception"
assert mem.stm.current_topic == "selecting_torrent" assert mem.stm.current_topic == "selecting_torrent"
@patch('infrastructure.api.knaben.client.requests.post') @patch("infrastructure.api.knaben.client.requests.post")
def test_results_have_indexes(self, mock_post, memory): def test_results_have_indexes(self, mock_post, memory):
"""Should add indexes to results.""" """Should add indexes to results."""
mock_post.return_value = create_mock_response(200, json_data={ mock_post.return_value = create_mock_response(
"hits": [ 200,
{"title": "Torrent 1", "seeders": 100, "leechers": 10, "magnetUrl": "magnet:?xt=1", "size": "1GB"}, json_data={
{"title": "Torrent 2", "seeders": 50, "leechers": 5, "magnetUrl": "magnet:?xt=2", "size": "2GB"}, "hits": [
{"title": "Torrent 3", "seeders": 25, "leechers": 2, "magnetUrl": "magnet:?xt=3", "size": "3GB"} {
] "title": "Torrent 1",
}) "seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=1",
"size": "1GB",
},
{
"title": "Torrent 2",
"seeders": 50,
"leechers": 5,
"magnetUrl": "magnet:?xt=2",
"size": "2GB",
},
{
"title": "Torrent 3",
"seeders": 25,
"leechers": 2,
"magnetUrl": "magnet:?xt=3",
"size": "3GB",
},
]
},
)
api_tools.find_torrent("Test") api_tools.find_torrent("Test")
@@ -170,7 +211,7 @@ class TestFindTorrent:
assert results[1]["index"] == 2 assert results[1]["index"] == 2
assert results[2]["index"] == 3 assert results[2]["index"] == 3
@patch('infrastructure.api.knaben.client.requests.post') @patch("infrastructure.api.knaben.client.requests.post")
def test_not_found(self, mock_post, memory): def test_not_found(self, mock_post, memory):
"""Should return error when no torrents found.""" """Should return error when no torrents found."""
mock_post.return_value = create_mock_response(200, json_data={"hits": []}) mock_post.return_value = create_mock_response(200, json_data={"hits": []})
@@ -236,16 +277,16 @@ class TestGetTorrentByIndex:
class TestAddTorrentToQbittorrent: class TestAddTorrentToQbittorrent:
"""Tests for add_torrent_to_qbittorrent tool. """Tests for add_torrent_to_qbittorrent tool.
Note: These tests mock the qBittorrent client because: Note: These tests mock the qBittorrent client because:
1. The client requires authentication/session management 1. The client requires authentication/session management
2. We want to test the tool's logic (memory updates, workflow management) 2. We want to test the tool's logic (memory updates, workflow management)
3. The client itself is tested separately in infrastructure tests 3. The client itself is tested separately in infrastructure tests
This is acceptable mocking because we're testing the TOOL logic, not the client. This is acceptable mocking because we're testing the TOOL logic, not the client.
""" """
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_success(self, mock_client, memory): def test_success(self, mock_client, memory):
"""Should add torrent successfully and update memory.""" """Should add torrent successfully and update memory."""
mock_client.add_torrent.return_value = True mock_client.add_torrent.return_value = True
@@ -257,7 +298,7 @@ class TestAddTorrentToQbittorrent:
# Verify client was called correctly # Verify client was called correctly
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123") mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_adds_to_active_downloads(self, mock_client, memory_with_search_results): def test_adds_to_active_downloads(self, mock_client, memory_with_search_results):
"""Should add to active downloads on success.""" """Should add to active downloads on success."""
mock_client.add_torrent.return_value = True mock_client.add_torrent.return_value = True
@@ -267,9 +308,12 @@ class TestAddTorrentToQbittorrent:
# Test memory update logic # Test memory update logic
mem = get_memory() mem = get_memory()
assert len(mem.episodic.active_downloads) == 1 assert len(mem.episodic.active_downloads) == 1
assert mem.episodic.active_downloads[0]["name"] == "Inception.2010.1080p.BluRay.x264" assert (
mem.episodic.active_downloads[0]["name"]
== "Inception.2010.1080p.BluRay.x264"
)
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_sets_topic_and_ends_workflow(self, mock_client, memory): def test_sets_topic_and_ends_workflow(self, mock_client, memory):
"""Should set topic and end workflow.""" """Should set topic and end workflow."""
mock_client.add_torrent.return_value = True mock_client.add_torrent.return_value = True
@@ -282,10 +326,11 @@ class TestAddTorrentToQbittorrent:
assert mem.stm.current_topic == "downloading" assert mem.stm.current_topic == "downloading"
assert mem.stm.current_workflow is None assert mem.stm.current_workflow is None
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_error_handling(self, mock_client, memory): def test_error_handling(self, mock_client, memory):
"""Should handle client errors correctly.""" """Should handle client errors correctly."""
from infrastructure.api.qbittorrent.exceptions import QBittorrentAPIError from infrastructure.api.qbittorrent.exceptions import QBittorrentAPIError
mock_client.add_torrent.side_effect = QBittorrentAPIError("Connection failed") mock_client.add_torrent.side_effect = QBittorrentAPIError("Connection failed")
result = api_tools.add_torrent_to_qbittorrent("magnet:?xt=...") result = api_tools.add_torrent_to_qbittorrent("magnet:?xt=...")
@@ -296,7 +341,7 @@ class TestAddTorrentToQbittorrent:
class TestAddTorrentByIndex: class TestAddTorrentByIndex:
"""Tests for add_torrent_by_index tool. """Tests for add_torrent_by_index tool.
These tests verify the tool's logic: These tests verify the tool's logic:
- Getting torrent from memory by index - Getting torrent from memory by index
- Extracting magnet link - Extracting magnet link
@@ -304,7 +349,7 @@ class TestAddTorrentByIndex:
- Error handling for edge cases - Error handling for edge cases
""" """
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_success(self, mock_client, memory_with_search_results): def test_success(self, mock_client, memory_with_search_results):
"""Should get torrent by index and add it.""" """Should get torrent by index and add it."""
mock_client.add_torrent.return_value = True mock_client.add_torrent.return_value = True
@@ -317,7 +362,7 @@ class TestAddTorrentByIndex:
# Verify correct magnet was extracted and used # Verify correct magnet was extracted and used
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123") mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_uses_correct_magnet(self, mock_client, memory_with_search_results): def test_uses_correct_magnet(self, mock_client, memory_with_search_results):
"""Should extract correct magnet from index.""" """Should extract correct magnet from index."""
mock_client.add_torrent.return_value = True mock_client.add_torrent.return_value = True

View File

@@ -1,7 +1,8 @@
"""Edge case tests for tools.""" """Edge case tests for tools."""
from unittest.mock import Mock, patch
import pytest import pytest
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
from agent.tools import api as api_tools from agent.tools import api as api_tools
from agent.tools import filesystem as fs_tools from agent.tools import filesystem as fs_tools
@@ -15,7 +16,10 @@ class TestFindTorrentEdgeCases:
def test_empty_query(self, mock_use_case_class, memory): def test_empty_query(self, mock_use_case_class, memory):
"""Should handle empty query.""" """Should handle empty query."""
mock_response = Mock() mock_response = Mock()
mock_response.to_dict.return_value = {"status": "error", "error": "invalid_query"} mock_response.to_dict.return_value = {
"status": "error",
"error": "invalid_query",
}
mock_use_case = Mock() mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case mock_use_case_class.return_value = mock_use_case
@@ -28,7 +32,11 @@ class TestFindTorrentEdgeCases:
def test_very_long_query(self, mock_use_case_class, memory): def test_very_long_query(self, mock_use_case_class, memory):
"""Should handle very long query.""" """Should handle very long query."""
mock_response = Mock() mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0} mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock() mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case mock_use_case_class.return_value = mock_use_case
@@ -43,7 +51,11 @@ class TestFindTorrentEdgeCases:
def test_special_characters_in_query(self, mock_use_case_class, memory): def test_special_characters_in_query(self, mock_use_case_class, memory):
"""Should handle special characters in query.""" """Should handle special characters in query."""
mock_response = Mock() mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0} mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock() mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case mock_use_case_class.return_value = mock_use_case
@@ -57,7 +69,11 @@ class TestFindTorrentEdgeCases:
def test_unicode_query(self, mock_use_case_class, memory): def test_unicode_query(self, mock_use_case_class, memory):
"""Should handle unicode in query.""" """Should handle unicode in query."""
mock_response = Mock() mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0} mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock() mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case mock_use_case_class.return_value = mock_use_case
@@ -161,7 +177,10 @@ class TestAddTorrentEdgeCases:
def test_empty_magnet_link(self, mock_use_case_class, memory): def test_empty_magnet_link(self, mock_use_case_class, memory):
"""Should handle empty magnet link.""" """Should handle empty magnet link."""
mock_response = Mock() mock_response = Mock()
mock_response.to_dict.return_value = {"status": "error", "error": "empty_magnet"} mock_response.to_dict.return_value = {
"status": "error",
"error": "empty_magnet",
}
mock_use_case = Mock() mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case mock_use_case_class.return_value = mock_use_case
@@ -326,7 +345,10 @@ class TestFilesystemEdgeCases:
for attempt in attempts: for attempt in attempts:
result = fs_tools.list_folder("download", attempt) result = fs_tools.list_folder("download", attempt)
# Should either be forbidden or not found # Should either be forbidden or not found
assert result.get("error") in ["forbidden", "not_found", None] or result.get("status") == "ok" assert (
result.get("error") in ["forbidden", "not_found", None]
or result.get("status") == "ok"
)
def test_path_with_null_byte(self, memory, real_folder): def test_path_with_null_byte(self, memory, real_folder):
"""Should block null byte injection.""" """Should block null byte injection."""