Formatting

This commit is contained in:
2025-12-07 03:33:51 +01:00
parent a923a760ef
commit 4eae1d6d58
24 changed files with 1003 additions and 833 deletions

View File

@@ -1,7 +1,8 @@
"""Main agent for media library management."""
import json
import logging
from typing import Any, Dict, List, Optional
from typing import Any
from infrastructure.persistence import get_memory
@@ -15,157 +16,156 @@ logger = logging.getLogger(__name__)
class Agent:
"""
AI agent for media library management.
Uses OpenAI-compatible tool calling API.
"""
def __init__(self, llm, max_tool_iterations: int = 5):
"""
Initialize the agent.
Args:
llm: LLM client with complete() method
max_tool_iterations: Maximum number of tool execution iterations
"""
self.llm = llm
self.tools: Dict[str, Tool] = make_tools()
self.tools: dict[str, Tool] = make_tools()
self.prompt_builder = PromptBuilder(self.tools)
self.max_tool_iterations = max_tool_iterations
def step(self, user_input: str) -> str:
"""
Execute one agent step with the user input.
This method:
1. Adds user message to memory
2. Builds prompt with history and context
3. Calls LLM, executing tools as needed
4. Returns final response
Args:
user_input: User's message
Returns:
Agent's final response
"""
memory = get_memory()
# Add user message to history
memory.stm.add_message("user", user_input)
memory.save()
# Build initial messages
system_prompt = self.prompt_builder.build_system_prompt()
messages: List[Dict[str, Any]] = [
{"role": "system", "content": system_prompt}
]
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
# Add conversation history
history = memory.stm.get_recent_history(settings.max_history_messages)
messages.extend(history)
# Add unread events if any
unread_events = memory.episodic.get_unread_events()
if unread_events:
events_text = "\n".join([
f"- {e['type']}: {e['data']}"
for e in unread_events
])
messages.append({
"role": "system",
"content": f"Background events:\n{events_text}"
})
events_text = "\n".join(
[f"- {e['type']}: {e['data']}" for e in unread_events]
)
messages.append(
{"role": "system", "content": f"Background events:\n{events_text}"}
)
# Get tools specification for OpenAI format
tools_spec = self.prompt_builder.build_tools_spec()
# Tool execution loop
for iteration in range(self.max_tool_iterations):
# Call LLM with tools
llm_result = self.llm.complete(messages, tools=tools_spec)
# Handle both tuple (response, usage) and dict response
if isinstance(llm_result, tuple):
response_message, usage = llm_result
else:
response_message = llm_result
# Check if there are tool calls
tool_calls = response_message.get("tool_calls")
if not tool_calls:
# No tool calls, this is the final response
final_content = response_message.get("content", "")
memory.stm.add_message("assistant", final_content)
memory.save()
return final_content
# Add assistant message with tool calls to conversation
messages.append(response_message)
# Execute each tool call
for tool_call in tool_calls:
tool_result = self._execute_tool_call(tool_call)
# Add tool result to messages
messages.append({
"tool_call_id": tool_call.get("id"),
"role": "tool",
"name": tool_call.get("function", {}).get("name"),
"content": json.dumps(tool_result, ensure_ascii=False),
})
messages.append(
{
"tool_call_id": tool_call.get("id"),
"role": "tool",
"name": tool_call.get("function", {}).get("name"),
"content": json.dumps(tool_result, ensure_ascii=False),
}
)
# Max iterations reached, force final response
messages.append({
"role": "system",
"content": "Please provide a final response to the user without using any more tools."
})
messages.append(
{
"role": "system",
"content": "Please provide a final response to the user without using any more tools.",
}
)
llm_result = self.llm.complete(messages)
if isinstance(llm_result, tuple):
final_message, usage = llm_result
else:
final_message = llm_result
final_response = final_message.get("content", "I've completed the requested actions.")
final_response = final_message.get(
"content", "I've completed the requested actions."
)
memory.stm.add_message("assistant", final_response)
memory.save()
return final_response
def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]:
"""
Execute a single tool call.
Args:
tool_call: OpenAI-format tool call dict
Returns:
Result dictionary
"""
function = tool_call.get("function", {})
tool_name = function.get("name", "")
try:
args_str = function.get("arguments", "{}")
args = json.loads(args_str)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse tool arguments: {e}")
return {
"error": "bad_args",
"message": f"Invalid JSON arguments: {e}"
}
return {"error": "bad_args", "message": f"Invalid JSON arguments: {e}"}
# Validate tool exists
if tool_name not in self.tools:
available = list(self.tools.keys())
return {
"error": "unknown_tool",
"message": f"Tool '{tool_name}' not found",
"available_tools": available
"available_tools": available,
}
tool = self.tools[tool_name]
# Execute tool
try:
result = tool.func(**args)
@@ -177,17 +177,9 @@ class Agent:
# Bad arguments
memory = get_memory()
memory.episodic.add_error(tool_name, f"bad_args: {e}")
return {
"error": "bad_args",
"message": str(e),
"tool": tool_name
}
return {"error": "bad_args", "message": str(e), "tool": tool_name}
except Exception as e:
# Other errors
memory = get_memory()
memory.episodic.add_error(tool_name, str(e))
return {
"error": "execution_failed",
"message": str(e),
"tool": tool_name
}
return {"error": "execution_failed", "message": str(e), "tool": tool_name}

View File

@@ -51,7 +51,9 @@ class DeepSeekClient:
logger.info(f"DeepSeek client initialized with model: {self.model}")
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]:
def complete(
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
"""
Generate a completion from the LLM.
@@ -80,7 +82,9 @@ class DeepSeekClient:
raise ValueError(f"Invalid role: {msg['role']}")
# Content is optional for tool messages (they may have tool_call_id instead)
if msg["role"] != "tool" and "content" not in msg:
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}")
raise ValueError(
f"Non-tool message must have 'content' key, got {msg.keys()}"
)
url = f"{self.base_url}/v1/chat/completions"
headers = {
@@ -92,13 +96,15 @@ class DeepSeekClient:
"messages": messages,
"temperature": settings.temperature,
}
# Add tools if provided
if tools:
payload["tools"] = tools
try:
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools")
logger.debug(
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
)
response = requests.post(
url, headers=headers, json=payload, timeout=self.timeout
)

View File

@@ -66,7 +66,9 @@ class OllamaClient:
logger.info(f"Ollama client initialized with model: {self.model}")
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]:
def complete(
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
"""
Generate a completion from the LLM.
@@ -95,7 +97,9 @@ class OllamaClient:
raise ValueError(f"Invalid role: {msg['role']}")
# Content is optional for tool messages (they may have tool_call_id instead)
if msg["role"] != "tool" and "content" not in msg:
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}")
raise ValueError(
f"Non-tool message must have 'content' key, got {msg.keys()}"
)
url = f"{self.base_url}/api/chat"
payload = {
@@ -106,13 +110,15 @@ class OllamaClient:
"temperature": self.temperature,
},
}
# Add tools if provided
if tools:
payload["tools"] = tools
try:
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools")
logger.debug(
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
)
response = requests.post(url, json=payload, timeout=self.timeout)
response.raise_for_status()
data = response.json()

View File

@@ -1,18 +1,20 @@
"""Prompt builder for the agent system."""
from typing import Dict, List, Any
import json
from typing import Any
from infrastructure.persistence import get_memory
from .registry import Tool
from infrastructure.persistence import get_memory
class PromptBuilder:
"""Builds system prompts for the agent with memory context."""
def __init__(self, tools: Dict[str, Tool]):
def __init__(self, tools: dict[str, Tool]):
self.tools = tools
def build_tools_spec(self) -> List[Dict[str, Any]]:
def build_tools_spec(self) -> list[dict[str, Any]]:
"""Build the tool specification for the LLM API."""
tool_specs = []
for tool in self.tools.values():
@@ -44,11 +46,13 @@ class PromptBuilder:
if memory.episodic.last_search_results:
results = memory.episodic.last_search_results
result_list = results.get('results', [])
lines.append(f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)")
result_list = results.get("results", [])
lines.append(
f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)"
)
# Show first 5 results
for i, result in enumerate(result_list[:5]):
name = result.get('name', 'Unknown')
name = result.get("name", "Unknown")
lines.append(f" {i+1}. {name}")
if len(result_list) > 5:
lines.append(f" ... and {len(result_list) - 5} more")
@@ -57,7 +61,7 @@ class PromptBuilder:
question = memory.episodic.pending_question
lines.append(f"\nPENDING QUESTION: {question.get('question')}")
lines.append(f" Type: {question.get('type')}")
if question.get('options'):
if question.get("options"):
lines.append(f" Options: {len(question.get('options'))}")
if memory.episodic.active_downloads:
@@ -68,10 +72,12 @@ class PromptBuilder:
if memory.episodic.recent_errors:
lines.append("\nRECENT ERRORS (up to 3):")
for error in memory.episodic.recent_errors[-3:]:
lines.append(f" - Action '{error.get('action')}' failed: {error.get('error')}")
lines.append(
f" - Action '{error.get('action')}' failed: {error.get('error')}"
)
# Unread events
unread = [e for e in memory.episodic.background_events if not e.get('read')]
unread = [e for e in memory.episodic.background_events if not e.get("read")]
if unread:
lines.append(f"\nUNREAD EVENTS: {len(unread)}")
for event in unread[:3]:
@@ -86,8 +92,10 @@ class PromptBuilder:
if memory.stm.current_workflow:
workflow = memory.stm.current_workflow
lines.append(f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})")
if workflow.get('target'):
lines.append(
f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})"
)
if workflow.get("target"):
lines.append(f" Target: {workflow.get('target')}")
if memory.stm.current_topic:
@@ -97,7 +105,7 @@ class PromptBuilder:
lines.append("EXTRACTED ENTITIES:")
for key, value in memory.stm.extracted_entities.items():
lines.append(f" - {key}: {value}")
if memory.stm.language:
lines.append(f"CONVERSATION LANGUAGE: {memory.stm.language}")
@@ -106,7 +114,7 @@ class PromptBuilder:
def _format_config_context(self) -> str:
"""Format configuration context."""
memory = get_memory()
lines = ["CURRENT CONFIGURATION:"]
if memory.ltm.config:
for key, value in memory.ltm.config.items():
@@ -118,10 +126,10 @@ class PromptBuilder:
def build_system_prompt(self) -> str:
"""Build the complete system prompt."""
memory = get_memory()
# Base instruction
base = "You are a helpful AI assistant for managing a media library."
# Language instruction
language_instruction = (
"Your first task is to determine the user's language from their message "

View File

@@ -1,8 +1,10 @@
"""Tool registry - defines and registers all available tools for the agent."""
from dataclasses import dataclass
from typing import Callable, Any, Dict
import logging
import inspect
import logging
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
logger = logging.getLogger(__name__)
@@ -10,36 +12,37 @@ logger = logging.getLogger(__name__)
@dataclass
class Tool:
"""Represents a tool that can be used by the agent."""
name: str
description: str
func: Callable[..., Dict[str, Any]]
parameters: Dict[str, Any]
func: Callable[..., dict[str, Any]]
parameters: dict[str, Any]
def _create_tool_from_function(func: Callable) -> Tool:
"""
Create a Tool object from a function.
Args:
func: Function to convert to a tool
Returns:
Tool object with metadata extracted from function
"""
sig = inspect.signature(func)
doc = inspect.getdoc(func)
# Extract description from docstring (first line)
description = doc.strip().split('\n')[0] if doc else func.__name__
description = doc.strip().split("\n")[0] if doc else func.__name__
# Build JSON schema from function signature
properties = {}
required = []
for param_name, param in sig.parameters.items():
if param_name == "self":
continue
# Map Python types to JSON schema types
param_type = "string" # default
if param.annotation != inspect.Parameter.empty:
@@ -51,22 +54,22 @@ def _create_tool_from_function(func: Callable) -> Tool:
param_type = "number"
elif param.annotation == bool:
param_type = "boolean"
properties[param_name] = {
"type": param_type,
"description": f"Parameter {param_name}"
"description": f"Parameter {param_name}",
}
# Add to required if no default value
if param.default == inspect.Parameter.empty:
required.append(param_name)
parameters = {
"type": "object",
"properties": properties,
"required": required,
}
return Tool(
name=func.__name__,
description=description,
@@ -75,18 +78,18 @@ def _create_tool_from_function(func: Callable) -> Tool:
)
def make_tools() -> Dict[str, Tool]:
def make_tools() -> dict[str, Tool]:
"""
Create and register all available tools.
Returns:
Dictionary mapping tool names to Tool objects
"""
# Import tools here to avoid circular dependencies
from .tools import filesystem as fs_tools
from .tools import api as api_tools
from .tools import filesystem as fs_tools
from .tools import language as lang_tools
# List of all tool functions
tool_functions = [
fs_tools.set_path_for_folder,
@@ -98,12 +101,12 @@ def make_tools() -> Dict[str, Tool]:
api_tools.get_torrent_by_index,
lang_tools.set_language,
]
# Create Tool objects from functions
tools = {}
for func in tool_functions:
tool = _create_tool_from_function(func)
tools[tool.name] = tool
logger.info(f"Registered {len(tools)} tools: {list(tools.keys())}")
return tools

View File

@@ -1,19 +1,20 @@
"""Language management tools for the agent."""
import logging
from typing import Dict, Any
from typing import Any
from infrastructure.persistence import get_memory
logger = logging.getLogger(__name__)
def set_language(language: str) -> Dict[str, Any]:
def set_language(language: str) -> dict[str, Any]:
"""
Set the conversation language.
Args:
language: Language code (e.g., 'en', 'fr', 'es', 'de')
Returns:
Status dictionary
"""
@@ -21,17 +22,14 @@ def set_language(language: str) -> Dict[str, Any]:
memory = get_memory()
memory.stm.set_language(language)
memory.save()
logger.info(f"Language set to: {language}")
return {
"status": "ok",
"message": f"Language set to {language}",
"language": language
"language": language,
}
except Exception as e:
logger.error(f"Failed to set language: {e}")
return {
"status": "error",
"error": str(e)
}
return {"status": "error", "error": str(e)}

View File

@@ -359,9 +359,7 @@ class EpisodicMemory:
"""Get active downloads."""
return self.active_downloads
def add_error(
self, action: str, error: str, context: dict | None = None
) -> None:
def add_error(self, action: str, error: str, context: dict | None = None) -> None:
"""Record a recent error."""
self.recent_errors.append(
{
@@ -408,9 +406,7 @@ class EpisodicMemory:
"""Get the pending question."""
return self.pending_question
def resolve_pending_question(
self, answer_index: int | None = None
) -> dict | None:
def resolve_pending_question(self, answer_index: int | None = None) -> dict | None:
"""
Resolve the pending question and return the chosen option.

View File

@@ -110,4 +110,4 @@ select = [
"PL",
"UP",
]
ignore = ["W503", "PLR0913", "PLR2004"]
ignore = ["PLR0913", "PLR2004"]

View File

@@ -1,16 +1,13 @@
"""Pytest configuration and shared fixtures."""
import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, MagicMock
from infrastructure.persistence import Memory, init_memory, set_memory, get_memory
from infrastructure.persistence.memory import (
LongTermMemory,
ShortTermMemory,
EpisodicMemory,
)
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, Mock
import pytest
from infrastructure.persistence import Memory, set_memory
@pytest.fixture
@@ -122,12 +119,11 @@ def memory_with_library(memory):
def mock_llm():
"""Create a mock LLM client that returns OpenAI-compatible format."""
llm = Mock()
# Return OpenAI-style message dict without tool calls
def complete_func(messages, tools=None):
return {
"role": "assistant",
"content": "I found what you're looking for!"
}
return {"role": "assistant", "content": "I found what you're looking for!"}
llm.complete = Mock(side_effect=complete_func)
return llm
@@ -136,34 +132,33 @@ def mock_llm():
def mock_llm_with_tool_call():
"""Create a mock LLM that returns a tool call then a response."""
llm = Mock()
# First call returns a tool call, second returns final response
def complete_side_effect(messages, tools=None):
if not hasattr(complete_side_effect, 'call_count'):
if not hasattr(complete_side_effect, "call_count"):
complete_side_effect.call_count = 0
complete_side_effect.call_count += 1
if complete_side_effect.call_count == 1:
# First call: return tool call
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_123",
"type": "function",
"function": {
"name": "find_torrent",
"arguments": '{"media_title": "Inception"}'
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "find_torrent",
"arguments": '{"media_title": "Inception"}',
},
}
}]
],
}
else:
# Second call: return final response
return {
"role": "assistant",
"content": "I found 3 torrents for Inception!"
}
return {"role": "assistant", "content": "I found 3 torrents for Inception!"}
llm.complete = Mock(side_effect=complete_side_effect)
return llm
@@ -248,36 +243,36 @@ def mock_deepseek():
"""
Mock DeepSeekClient for individual tests that need it.
This prevents real API calls in tests that use this fixture.
Usage:
def test_something(mock_deepseek):
# Your test code here
"""
import sys
from unittest.mock import Mock, MagicMock
from unittest.mock import Mock
# Save the original module if it exists
original_module = sys.modules.get('agent.llm.deepseek')
original_module = sys.modules.get("agent.llm.deepseek")
# Create a mock module for deepseek
mock_deepseek_module = MagicMock()
class MockDeepSeekClient:
def __init__(self, *args, **kwargs):
self.complete = Mock(return_value="Mocked LLM response")
mock_deepseek_module.DeepSeekClient = MockDeepSeekClient
# Inject the mock
sys.modules['agent.llm.deepseek'] = mock_deepseek_module
sys.modules["agent.llm.deepseek"] = mock_deepseek_module
yield mock_deepseek_module
# Restore the original module
if original_module is not None:
sys.modules['agent.llm.deepseek'] = original_module
elif 'agent.llm.deepseek' in sys.modules:
del sys.modules['agent.llm.deepseek']
sys.modules["agent.llm.deepseek"] = original_module
elif "agent.llm.deepseek" in sys.modules:
del sys.modules["agent.llm.deepseek"]
@pytest.fixture
@@ -287,8 +282,8 @@ def mock_agent_step():
Returns a context manager that patches app.agent.step.
"""
from unittest.mock import patch
def _mock_step(return_value="Mocked agent response"):
return patch("app.agent.step", return_value=return_value)
return _mock_step

View File

@@ -1,6 +1,6 @@
"""Tests for the Agent."""
from unittest.mock import Mock, patch
from unittest.mock import Mock
from agent.agent import Agent
from infrastructure.persistence import get_memory
@@ -55,8 +55,8 @@ class TestExecuteToolCall:
"id": "call_123",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
}
"arguments": '{"folder_type": "download"}',
},
}
result = agent._execute_tool_call(tool_call)
@@ -68,10 +68,7 @@ class TestExecuteToolCall:
tool_call = {
"id": "call_123",
"function": {
"name": "unknown_tool",
"arguments": '{}'
}
"function": {"name": "unknown_tool", "arguments": "{}"},
}
result = agent._execute_tool_call(tool_call)
@@ -84,10 +81,7 @@ class TestExecuteToolCall:
tool_call = {
"id": "call_123",
"function": {
"name": "set_path_for_folder",
"arguments": '{}'
}
"function": {"name": "set_path_for_folder", "arguments": "{}"},
}
result = agent._execute_tool_call(tool_call)
@@ -102,8 +96,8 @@ class TestExecuteToolCall:
"id": "call_123",
"function": {
"name": "set_path_for_folder",
"arguments": '{"folder_name": 123}' # Wrong type
}
"arguments": '{"folder_name": 123}', # Wrong type
},
}
result = agent._execute_tool_call(tool_call)
@@ -116,10 +110,7 @@ class TestExecuteToolCall:
tool_call = {
"id": "call_123",
"function": {
"name": "list_folder",
"arguments": '{invalid json}'
}
"function": {"name": "list_folder", "arguments": "{invalid json}"},
}
result = agent._execute_tool_call(tool_call)
@@ -160,40 +151,39 @@ class TestStep:
assert "found" in response.lower() or "torrent" in response.lower()
assert mock_llm_with_tool_call.complete.call_count == 2
# CRITICAL: Verify tools were passed to LLM
first_call_args = mock_llm_with_tool_call.complete.call_args_list[0]
assert first_call_args[1]['tools'] is not None, "Tools not passed to LLM!"
assert len(first_call_args[1]['tools']) > 0, "Tools list is empty!"
assert first_call_args[1]["tools"] is not None, "Tools not passed to LLM!"
assert len(first_call_args[1]["tools"]) > 0, "Tools list is empty!"
def test_step_max_iterations(self, memory, mock_llm):
"""Should stop after max iterations."""
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
# CRITICAL: Verify tools are passed (except on forced final call)
if call_count[0] <= 3:
assert tools is not None, f"Tools not passed on call {call_count[0]}!"
if call_count[0] <= 3:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": f"call_{call_count[0]}",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
"tool_calls": [
{
"id": f"call_{call_count[0]}",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
}
}]
],
}
else:
return {
"role": "assistant",
"content": "I couldn't complete the task."
}
return {"role": "assistant", "content": "I couldn't complete the task."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3)
@@ -241,49 +231,53 @@ class TestAgentIntegration:
memory.ltm.set_config("movie_folder", str(real_folder["movies"]))
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
# CRITICAL: Verify tools are passed on every call
assert tools is not None, f"Tools not passed on call {call_count[0]}!"
if call_count[0] == 1:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
}
}]
],
}
elif call_count[0] == 2:
# CRITICAL: Verify tool result was sent back
tool_messages = [m for m in messages if m.get('role') == 'tool']
tool_messages = [m for m in messages if m.get("role") == "tool"]
assert len(tool_messages) > 0, "Tool result not sent back to LLM!"
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_2",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "movie"}'
"tool_calls": [
{
"id": "call_2",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "movie"}',
},
}
}]
],
}
else:
return {
"role": "assistant",
"content": "I listed both folders for you."
"content": "I listed both folders for you.",
}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
response = agent.step("List my downloads and movies")
assert call_count[0] == 3

View File

@@ -1,7 +1,9 @@
"""Edge case tests for the Agent."""
import pytest
from unittest.mock import Mock
import pytest
from agent.agent import Agent
from infrastructure.persistence import get_memory
@@ -15,19 +17,14 @@ class TestExecuteToolCallEdgeCases:
# Mock a tool that returns None
from agent.registry import Tool
agent.tools["test_tool"] = Tool(
name="test_tool",
description="Test",
func=lambda: None,
parameters={}
name="test_tool", description="Test", func=lambda: None, parameters={}
)
tool_call = {
"id": "call_123",
"function": {
"name": "test_tool",
"arguments": '{}'
}
"function": {"name": "test_tool", "arguments": "{}"},
}
result = agent._execute_tool_call(tool_call)
@@ -38,22 +35,17 @@ class TestExecuteToolCallEdgeCases:
agent = Agent(llm=mock_llm)
from agent.registry import Tool
def raise_interrupt():
raise KeyboardInterrupt()
agent.tools["test_tool"] = Tool(
name="test_tool",
description="Test",
func=raise_interrupt,
parameters={}
name="test_tool", description="Test", func=raise_interrupt, parameters={}
)
tool_call = {
"id": "call_123",
"function": {
"name": "test_tool",
"arguments": '{}'
}
"function": {"name": "test_tool", "arguments": "{}"},
}
with pytest.raises(KeyboardInterrupt):
@@ -68,8 +60,8 @@ class TestExecuteToolCallEdgeCases:
"id": "call_123",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download", "extra_arg": "ignored"}'
}
"arguments": '{"folder_type": "download", "extra_arg": "ignored"}',
},
}
result = agent._execute_tool_call(tool_call)
@@ -84,8 +76,8 @@ class TestExecuteToolCallEdgeCases:
"id": "call_123",
"function": {
"name": "get_torrent_by_index",
"arguments": '{"index": "not an int"}'
}
"arguments": '{"index": "not an int"}',
},
}
result = agent._execute_tool_call(tool_call)
@@ -115,12 +107,10 @@ class TestStepEdgeCases:
def test_step_with_unicode_input(self, memory, mock_llm):
"""Should handle unicode input."""
def mock_complete(messages, tools=None):
return {
"role": "assistant",
"content": "日本語の応答"
}
return {"role": "assistant", "content": "日本語の応答"}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
@@ -130,12 +120,10 @@ class TestStepEdgeCases:
def test_step_llm_returns_empty(self, memory, mock_llm):
"""Should handle LLM returning empty string."""
def mock_complete(messages, tools=None):
return {
"role": "assistant",
"content": ""
}
return {"role": "assistant", "content": ""}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
@@ -161,18 +149,17 @@ class TestStepEdgeCases:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": f"call_{call_count[0]}",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
"tool_calls": [
{
"id": f"call_{call_count[0]}",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
}
}]
],
}
return {
"role": "assistant",
"content": "Done looping"
}
return {"role": "assistant", "content": "Done looping"}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3)
@@ -212,11 +199,13 @@ class TestStepEdgeCases:
def test_step_with_active_downloads(self, memory, mock_llm):
"""Should include active downloads in context."""
memory.episodic.add_active_download({
"task_id": "123",
"name": "Movie.mkv",
"progress": 50,
})
memory.episodic.add_active_download(
{
"task_id": "123",
"name": "Movie.mkv",
"progress": 50,
}
)
agent = Agent(llm=mock_llm)
response = agent.step("Hello")
@@ -257,29 +246,28 @@ class TestAgentConcurrencyEdgeCases:
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
if call_count[0] == 1:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_1",
"function": {
"name": "set_path_for_folder",
"arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}'
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "set_path_for_folder",
"arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}',
},
}
}]
],
}
return {
"role": "assistant",
"content": "Path set successfully."
}
return {"role": "assistant", "content": "Path set successfully."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
response = agent.step("Set movie folder")
mem = get_memory()
@@ -292,29 +280,28 @@ class TestAgentErrorRecovery:
def test_recovers_from_tool_error(self, memory, mock_llm):
"""Should recover from tool error and continue."""
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
if call_count[0] == 1:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
}
}]
],
}
return {
"role": "assistant",
"content": "The folder is not configured."
}
return {"role": "assistant", "content": "The folder is not configured."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
response = agent.step("List downloads")
assert "not configured" in response.lower() or len(response) > 0
@@ -322,29 +309,28 @@ class TestAgentErrorRecovery:
def test_error_tracked_in_memory(self, memory, mock_llm):
"""Should track errors in episodic memory."""
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
if call_count[0] == 1:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_1",
"function": {
"name": "set_path_for_folder",
"arguments": '{}' # Missing required args
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "set_path_for_folder",
"arguments": "{}", # Missing required args
},
}
}]
],
}
return {
"role": "assistant",
"content": "Error occurred."
}
return {"role": "assistant", "content": "Error occurred."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
agent.step("Set folder")
mem = get_memory()
@@ -360,18 +346,17 @@ class TestAgentErrorRecovery:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": f"call_{call_count[0]}",
"function": {
"name": "set_path_for_folder",
"arguments": '{}' # Missing required args - will error
"tool_calls": [
{
"id": f"call_{call_count[0]}",
"function": {
"name": "set_path_for_folder",
"arguments": "{}", # Missing required args - will error
},
}
}]
],
}
return {
"role": "assistant",
"content": "All attempts failed."
}
return {"role": "assistant", "content": "All attempts failed."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3)

View File

@@ -1,6 +1,7 @@
"""Tests for FastAPI endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from unittest.mock import patch
from fastapi.testclient import TestClient
@@ -10,6 +11,7 @@ class TestHealthEndpoint:
def test_health_check(self, memory):
"""Should return healthy status."""
from app import app
client = TestClient(app)
response = client.get("/health")
@@ -24,6 +26,7 @@ class TestModelsEndpoint:
def test_list_models(self, memory):
"""Should return model list."""
from app import app
client = TestClient(app)
response = client.get("/v1/models")
@@ -41,6 +44,7 @@ class TestMemoryEndpoints:
def test_get_memory_state(self, memory):
"""Should return full memory state."""
from app import app
client = TestClient(app)
response = client.get("/memory/state")
@@ -54,6 +58,7 @@ class TestMemoryEndpoints:
def test_get_search_results_empty(self, memory):
"""Should return empty when no search results."""
from app import app
client = TestClient(app)
response = client.get("/memory/episodic/search-results")
@@ -65,6 +70,7 @@ class TestMemoryEndpoints:
def test_get_search_results_with_data(self, memory_with_search_results):
"""Should return search results when available."""
from app import app
client = TestClient(app)
response = client.get("/memory/episodic/search-results")
@@ -78,6 +84,7 @@ class TestMemoryEndpoints:
def test_clear_session(self, memory_with_search_results):
"""Should clear session memories."""
from app import app
client = TestClient(app)
response = client.post("/memory/clear-session")
@@ -96,14 +103,18 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_success(self, memory):
"""Should return chat completion."""
from app import app
# Patch the agent's step method directly
with patch("app.agent.step", return_value="Hello! How can I help?"):
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
},
)
assert response.status_code == 200
data = response.json()
@@ -113,12 +124,16 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_no_user_message(self, memory):
"""Should return error if no user message."""
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "system", "content": "You are helpful"}],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "system", "content": "You are helpful"}],
},
)
assert response.status_code == 422
detail = response.json()["detail"]
@@ -132,18 +147,23 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_empty_messages(self, memory):
"""Should return error for empty messages."""
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [],
},
)
assert response.status_code == 422
def test_chat_completion_invalid_json(self, memory):
"""Should return error for invalid JSON."""
from app import app
client = TestClient(app)
response = client.post(
@@ -157,14 +177,18 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_streaming(self, memory):
"""Should support streaming mode."""
from app import app
with patch("app.agent.step", return_value="Streaming response"):
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
},
)
assert response.status_code == 200
assert "text/event-stream" in response.headers["content-type"]
@@ -172,17 +196,21 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_extracts_last_user_message(self, memory):
"""Should use last user message."""
from app import app
with patch("app.agent.step", return_value="Response") as mock_step:
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [
{"role": "user", "content": "First message"},
{"role": "assistant", "content": "Response"},
{"role": "user", "content": "Second message"},
],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [
{"role": "user", "content": "First message"},
{"role": "assistant", "content": "Response"},
{"role": "user", "content": "Second message"},
],
},
)
assert response.status_code == 200
# Verify the agent received the last user message
@@ -191,13 +219,17 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_response_format(self, memory):
"""Should return OpenAI-compatible format."""
from app import app
with patch("app.agent.step", return_value="Test response"):
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Test"}],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Test"}],
},
)
data = response.json()
assert "id" in data

View File

@@ -1,7 +1,7 @@
"""Edge case tests for FastAPI endpoints."""
import pytest
import json
from unittest.mock import Mock, patch, MagicMock
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
@@ -10,43 +10,46 @@ class TestChatCompletionsEdgeCases:
def test_very_long_message(self, memory):
"""Should handle very long user message."""
from app import app, agent
from app import agent, app
# Patch the agent's LLM directly
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "Response"
}
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
long_message = "x" * 100000
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": long_message}],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": long_message}],
},
)
assert response.status_code == 200
def test_unicode_message(self, memory):
"""Should handle unicode in message."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "日本語の応答"
"content": "日本語の応答",
}
agent.llm = mock_llm
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}],
},
)
assert response.status_code == 200
content = response.json()["choices"][0]["message"]["content"]
@@ -54,22 +57,22 @@ class TestChatCompletionsEdgeCases:
def test_special_characters_in_message(self, memory):
"""Should handle special characters."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "Response"
}
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
special_message = 'Test with "quotes" and \\backslash and \n newline'
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": special_message}],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": special_message}],
},
)
assert response.status_code == 200
@@ -81,12 +84,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": ""}],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": ""}],
},
)
# Empty content should be rejected
assert response.status_code == 422
@@ -98,12 +105,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": None}],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": None}],
},
)
assert response.status_code == 422
@@ -114,12 +125,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user"}], # No content
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user"}], # No content
},
)
# May accept or reject depending on validation
assert response.status_code in [200, 400, 422]
@@ -131,12 +146,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"content": "Hello"}], # No role
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"content": "Hello"}], # No role
},
)
# Should reject or accept depending on validation
assert response.status_code in [200, 400, 422]
@@ -149,27 +168,28 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "invalid_role", "content": "Hello"}],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "invalid_role", "content": "Hello"}],
},
)
# Should reject or ignore invalid role
assert response.status_code in [200, 400, 422]
def test_many_messages(self, memory):
"""Should handle many messages in conversation."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "Response"
}
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
messages = []
@@ -178,10 +198,13 @@ class TestChatCompletionsEdgeCases:
messages.append({"role": "assistant", "content": f"Response {i}"})
messages.append({"role": "user", "content": "Final message"})
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": messages,
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": messages,
},
)
assert response.status_code == 200
@@ -192,15 +215,19 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [
{"role": "system", "content": "You are helpful"},
{"role": "system", "content": "Be concise"},
],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [
{"role": "system", "content": "You are helpful"},
{"role": "system", "content": "Be concise"},
],
},
)
assert response.status_code == 422
@@ -211,14 +238,18 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [
{"role": "assistant", "content": "Hello"},
],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [
{"role": "assistant", "content": "Hello"},
],
},
)
assert response.status_code == 422
@@ -229,12 +260,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": "not an array",
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": "not an array",
},
)
assert response.status_code == 422
# Pydantic validation error
@@ -246,118 +281,128 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": ["not an object", 123, None],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": ["not an object", 123, None],
},
)
assert response.status_code == 422
# Pydantic validation error
def test_extra_fields_in_request(self, memory):
"""Should ignore extra fields in request."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "Response"
}
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
"extra_field": "should be ignored",
"temperature": 0.7,
"max_tokens": 100,
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
"extra_field": "should be ignored",
"temperature": 0.7,
"max_tokens": 100,
},
)
assert response.status_code == 200
def test_streaming_with_tool_call(self, memory, real_folder):
"""Should handle streaming with tool execution."""
from app import app, agent
from app import agent, app
from infrastructure.persistence import get_memory
mem = get_memory()
mem.ltm.set_config("download_folder", str(real_folder["downloads"]))
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
if call_count[0] == 1:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
}
}]
],
}
return {
"role": "assistant",
"content": "Listed the folder."
}
return {"role": "assistant", "content": "Listed the folder."}
mock_llm = Mock()
mock_llm.complete = Mock(side_effect=mock_complete)
agent.llm = mock_llm
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": "List downloads"}],
"stream": True,
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "List downloads"}],
"stream": True,
},
)
assert response.status_code == 200
def test_concurrent_requests_simulation(self, memory):
"""Should handle rapid sequential requests."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "Response"
}
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
for i in range(10):
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": f"Request {i}"}],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": f"Request {i}"}],
},
)
assert response.status_code == 200
def test_llm_returns_json_in_response(self, memory):
"""Should handle LLM returning JSON in text response."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": '{"result": "some data", "count": 5}'
"content": '{"result": "some data", "count": 5}',
}
agent.llm = mock_llm
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Give me JSON"}],
})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Give me JSON"}],
},
)
assert response.status_code == 200
content = response.json()["choices"][0]["message"]["content"]
@@ -425,6 +470,7 @@ class TestMemoryEndpointsEdgeCases:
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
# Clear multiple times
@@ -459,6 +505,7 @@ class TestHealthEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
response = client.get("/health")
@@ -471,6 +518,7 @@ class TestHealthEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
response = client.get("/health?extra=param&another=value")
@@ -486,6 +534,7 @@ class TestModelsEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
response = client.get("/v1/models")

View File

@@ -1,9 +1,9 @@
"""Critical tests for configuration validation."""
import pytest
import os
from agent.config import Settings, ConfigurationError
import pytest
from agent.config import ConfigurationError, Settings
class TestConfigValidation:
@@ -13,7 +13,7 @@ class TestConfigValidation:
"""Verify invalid temperature is rejected."""
with pytest.raises(ConfigurationError, match="Temperature"):
Settings(temperature=3.0) # > 2.0
with pytest.raises(ConfigurationError, match="Temperature"):
Settings(temperature=-0.1) # < 0.0
@@ -28,7 +28,7 @@ class TestConfigValidation:
"""Verify invalid max_iterations is rejected."""
with pytest.raises(ConfigurationError, match="max_tool_iterations"):
Settings(max_tool_iterations=0) # < 1
with pytest.raises(ConfigurationError, match="max_tool_iterations"):
Settings(max_tool_iterations=100) # > 20
@@ -43,7 +43,7 @@ class TestConfigValidation:
"""Verify invalid timeout is rejected."""
with pytest.raises(ConfigurationError, match="request_timeout"):
Settings(request_timeout=0) # < 1
with pytest.raises(ConfigurationError, match="request_timeout"):
Settings(request_timeout=500) # > 300
@@ -58,7 +58,7 @@ class TestConfigValidation:
"""Verify invalid DeepSeek URL is rejected."""
with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"):
Settings(deepseek_base_url="not-a-url")
with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"):
Settings(deepseek_base_url="ftp://invalid.com")
@@ -86,19 +86,17 @@ class TestConfigChecks:
def test_is_deepseek_configured_with_key(self):
"""Verify is_deepseek_configured returns True with API key."""
settings = Settings(
deepseek_api_key="test-key",
deepseek_base_url="https://api.test.com"
deepseek_api_key="test-key", deepseek_base_url="https://api.test.com"
)
assert settings.is_deepseek_configured() is True
def test_is_deepseek_configured_without_key(self):
"""Verify is_deepseek_configured returns False without API key."""
settings = Settings(
deepseek_api_key="",
deepseek_base_url="https://api.test.com"
deepseek_api_key="", deepseek_base_url="https://api.test.com"
)
assert settings.is_deepseek_configured() is False
def test_is_deepseek_configured_without_url(self):
@@ -110,19 +108,15 @@ class TestConfigChecks:
def test_is_tmdb_configured_with_key(self):
"""Verify is_tmdb_configured returns True with API key."""
settings = Settings(
tmdb_api_key="test-key",
tmdb_base_url="https://api.test.com"
tmdb_api_key="test-key", tmdb_base_url="https://api.test.com"
)
assert settings.is_tmdb_configured() is True
def test_is_tmdb_configured_without_key(self):
"""Verify is_tmdb_configured returns False without API key."""
settings = Settings(
tmdb_api_key="",
tmdb_base_url="https://api.test.com"
)
settings = Settings(tmdb_api_key="", tmdb_base_url="https://api.test.com")
assert settings.is_tmdb_configured() is False
@@ -132,25 +126,25 @@ class TestConfigDefaults:
def test_default_temperature(self):
"""Verify default temperature is reasonable."""
settings = Settings()
assert 0.0 <= settings.temperature <= 2.0
def test_default_max_iterations(self):
"""Verify default max_iterations is reasonable."""
settings = Settings()
assert 1 <= settings.max_tool_iterations <= 20
def test_default_timeout(self):
"""Verify default timeout is reasonable."""
settings = Settings()
assert 1 <= settings.request_timeout <= 300
def test_default_urls_are_valid(self):
"""Verify default URLs are valid."""
settings = Settings()
assert settings.deepseek_base_url.startswith(("http://", "https://"))
assert settings.tmdb_base_url.startswith(("http://", "https://"))
@@ -161,38 +155,38 @@ class TestConfigEnvironmentVariables:
def test_loads_temperature_from_env(self, monkeypatch):
"""Verify temperature is loaded from environment."""
monkeypatch.setenv("TEMPERATURE", "0.5")
settings = Settings()
assert settings.temperature == 0.5
def test_loads_max_iterations_from_env(self, monkeypatch):
"""Verify max_iterations is loaded from environment."""
monkeypatch.setenv("MAX_TOOL_ITERATIONS", "10")
settings = Settings()
assert settings.max_tool_iterations == 10
def test_loads_timeout_from_env(self, monkeypatch):
"""Verify timeout is loaded from environment."""
monkeypatch.setenv("REQUEST_TIMEOUT", "60")
settings = Settings()
assert settings.request_timeout == 60
def test_loads_deepseek_url_from_env(self, monkeypatch):
"""Verify DeepSeek URL is loaded from environment."""
monkeypatch.setenv("DEEPSEEK_BASE_URL", "https://custom.api.com")
settings = Settings()
assert settings.deepseek_base_url == "https://custom.api.com"
def test_invalid_env_value_raises_error(self, monkeypatch):
"""Verify invalid environment value raises error."""
monkeypatch.setenv("TEMPERATURE", "invalid")
with pytest.raises(ValueError):
Settings()

View File

@@ -1,12 +1,14 @@
"""Edge case tests for configuration and parameters."""
import pytest
import os
from unittest.mock import patch
from agent.config import Settings, ConfigurationError
import pytest
from agent.config import ConfigurationError, Settings
from agent.parameters import (
ParameterSchema,
REQUIRED_PARAMETERS,
ParameterSchema,
format_parameters_for_prompt,
get_missing_required_parameters,
)
@@ -110,19 +112,27 @@ class TestSettingsEdgeCases:
def test_http_url_accepted(self):
"""Should accept http:// URLs."""
with patch.dict(os.environ, {
"DEEPSEEK_BASE_URL": "http://localhost:8080",
"TMDB_BASE_URL": "http://localhost:3000",
}, clear=True):
with patch.dict(
os.environ,
{
"DEEPSEEK_BASE_URL": "http://localhost:8080",
"TMDB_BASE_URL": "http://localhost:3000",
},
clear=True,
):
settings = Settings()
assert settings.deepseek_base_url == "http://localhost:8080"
def test_https_url_accepted(self):
"""Should accept https:// URLs."""
with patch.dict(os.environ, {
"DEEPSEEK_BASE_URL": "https://api.example.com",
"TMDB_BASE_URL": "https://api.example.com",
}, clear=True):
with patch.dict(
os.environ,
{
"DEEPSEEK_BASE_URL": "https://api.example.com",
"TMDB_BASE_URL": "https://api.example.com",
},
clear=True,
):
settings = Settings()
assert settings.deepseek_base_url == "https://api.example.com"

View File

@@ -1,18 +1,17 @@
"""Tests for the Memory system."""
import pytest
import json
from datetime import datetime
from pathlib import Path
import pytest
from infrastructure.persistence import (
Memory,
LongTermMemory,
ShortTermMemory,
EpisodicMemory,
init_memory,
LongTermMemory,
Memory,
ShortTermMemory,
get_memory,
set_memory,
has_memory,
init_memory,
)
from infrastructure.persistence.context import _memory_ctx
@@ -23,11 +22,12 @@ def is_iso_format(s: str) -> bool:
return False
try:
# Attempt to parse the string as an ISO 8601 timestamp
datetime.fromisoformat(s.replace('Z', '+00:00'))
datetime.fromisoformat(s.replace("Z", "+00:00"))
return True
except (ValueError, TypeError):
return False
class TestLongTermMemory:
"""Tests for LongTermMemory."""
@@ -116,12 +116,18 @@ class TestLongTermMemory:
assert data["config"]["key"] == "value"
def test_from_dict(self):
data = {"config": {"download_folder": "/downloads"}, "preferences": {"preferred_quality": "4K"}, "library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]}, "following": []}
data = {
"config": {"download_folder": "/downloads"},
"preferences": {"preferred_quality": "4K"},
"library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]},
"following": [],
}
ltm = LongTermMemory.from_dict(data)
assert ltm.get_config("download_folder") == "/downloads"
assert ltm.preferences["preferred_quality"] == "4K"
assert len(ltm.library["movies"]) == 1
class TestShortTermMemory:
"""Tests for ShortTermMemory."""
@@ -162,6 +168,7 @@ class TestShortTermMemory:
assert stm.conversation_history == []
assert stm.language == "en"
class TestEpisodicMemory:
"""Tests for EpisodicMemory."""
@@ -192,6 +199,7 @@ class TestEpisodicMemory:
assert result is not None
assert result["name"] == "Result 2"
class TestMemory:
"""Tests for the Memory manager."""
@@ -217,11 +225,10 @@ class TestMemory:
assert memory.stm.conversation_history == []
assert memory.episodic.recent_errors == []
class TestMemoryContext:
"""Tests for memory context functions."""
def test_get_memory_not_initialized(self):
_memory_ctx.set(None)
with pytest.raises(RuntimeError, match="Memory not initialized"):

View File

@@ -1,18 +1,17 @@
"""Edge case tests for the Memory system."""
import pytest
import json
import os
from pathlib import Path
from datetime import datetime
from unittest.mock import patch, mock_open
import pytest
from infrastructure.persistence import (
Memory,
LongTermMemory,
ShortTermMemory,
EpisodicMemory,
init_memory,
LongTermMemory,
Memory,
ShortTermMemory,
get_memory,
init_memory,
set_memory,
)
from infrastructure.persistence.context import _memory_ctx
@@ -390,7 +389,7 @@ class TestMemoryEdgeCases:
def test_init_with_nonexistent_directory(self, temp_dir):
"""Should create directory if not exists."""
new_dir = temp_dir / "new" / "nested" / "dir"
# Create parent directories first
new_dir.mkdir(parents=True, exist_ok=True)
memory = Memory(storage_dir=str(new_dir))
@@ -529,7 +528,6 @@ class TestMemoryContextEdgeCases:
def test_context_isolation(self, temp_dir):
"""Context should be isolated per context."""
import asyncio
from contextvars import copy_context
_memory_ctx.set(None)

View File

@@ -1,10 +1,8 @@
"""Critical tests for prompt builder - Tests that would have caught bugs."""
import pytest
from agent.registry import make_tools
from agent.prompts import PromptBuilder
from infrastructure.persistence import get_memory
from agent.registry import make_tools
class TestPromptBuilderToolsInjection:
@@ -15,20 +13,22 @@ class TestPromptBuilderToolsInjection:
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
# Verify each tool is mentioned
for tool_name in tools.keys():
assert tool_name in prompt, f"Tool {tool_name} not mentioned in system prompt"
assert (
tool_name in prompt
), f"Tool {tool_name} not mentioned in system prompt"
def test_tools_spec_contains_all_registered_tools(self):
"""CRITICAL: Verify build_tools_spec() returns all tools."""
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
spec_names = {spec['function']['name'] for spec in specs}
spec_names = {spec["function"]["name"] for spec in specs}
tool_names = set(tools.keys())
assert spec_names == tool_names, f"Missing tools: {tool_names - spec_names}"
def test_tools_spec_is_not_empty(self):
@@ -36,7 +36,7 @@ class TestPromptBuilderToolsInjection:
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
assert len(specs) > 0, "Tools spec is empty!"
def test_tools_spec_format_matches_openai(self):
@@ -44,14 +44,14 @@ class TestPromptBuilderToolsInjection:
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
for spec in specs:
assert 'type' in spec
assert spec['type'] == 'function'
assert 'function' in spec
assert 'name' in spec['function']
assert 'description' in spec['function']
assert 'parameters' in spec['function']
assert "type" in spec
assert spec["type"] == "function"
assert "function" in spec
assert "name" in spec["function"]
assert "description" in spec["function"]
assert "parameters" in spec["function"]
class TestPromptBuilderMemoryContext:
@@ -61,29 +61,29 @@ class TestPromptBuilderMemoryContext:
"""Verify current topic is included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.stm.set_topic("test_topic")
prompt = builder.build_system_prompt()
assert "test_topic" in prompt
def test_prompt_includes_extracted_entities(self, memory):
"""Verify extracted entities are included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.stm.set_entity("test_key", "test_value")
prompt = builder.build_system_prompt()
assert "test_key" in prompt
def test_prompt_includes_search_results(self, memory_with_search_results):
"""Verify search results are included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "Inception" in prompt
assert "LAST SEARCH" in prompt
@@ -91,15 +91,13 @@ class TestPromptBuilderMemoryContext:
"""Verify active downloads are included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.episodic.add_active_download({
"task_id": "123",
"name": "Test Movie",
"progress": 50
})
memory.episodic.add_active_download(
{"task_id": "123", "name": "Test Movie", "progress": 50}
)
prompt = builder.build_system_prompt()
assert "ACTIVE DOWNLOADS" in prompt
assert "Test Movie" in prompt
@@ -107,33 +105,33 @@ class TestPromptBuilderMemoryContext:
"""Verify recent errors are included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.episodic.add_error("test_action", "test error message")
prompt = builder.build_system_prompt()
assert "RECENT ERRORS" in prompt or "error" in prompt.lower()
def test_prompt_includes_configuration(self, memory):
"""Verify configuration is included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.ltm.set_config("download_folder", "/test/downloads")
prompt = builder.build_system_prompt()
assert "CONFIGURATION" in prompt or "download_folder" in prompt
def test_prompt_includes_language(self, memory):
"""Verify language is included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.stm.set_language("fr")
prompt = builder.build_system_prompt()
assert "fr" in prompt or "LANGUAGE" in prompt
@@ -145,7 +143,7 @@ class TestPromptBuilderStructure:
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert len(prompt) > 0
assert prompt.strip() != ""
@@ -154,7 +152,7 @@ class TestPromptBuilderStructure:
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "assistant" in prompt.lower() or "help" in prompt.lower()
def test_system_prompt_includes_rules(self):
@@ -162,7 +160,7 @@ class TestPromptBuilderStructure:
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "RULES" in prompt or "IMPORTANT" in prompt
def test_system_prompt_includes_examples(self):
@@ -170,16 +168,16 @@ class TestPromptBuilderStructure:
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "EXAMPLES" in prompt or "example" in prompt.lower()
def test_tools_description_format(self):
"""Verify tools are properly formatted in description."""
tools = make_tools()
builder = PromptBuilder(tools)
description = builder._format_tools_description()
# Should have tool names and descriptions
for tool_name, tool in tools.items():
assert tool_name in description
@@ -190,9 +188,9 @@ class TestPromptBuilderStructure:
"""Verify episodic context is properly formatted."""
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_episodic_context()
assert "LAST SEARCH" in context
assert "Inception" in context
@@ -200,12 +198,12 @@ class TestPromptBuilderStructure:
"""Verify STM context is properly formatted."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.stm.set_topic("test_topic")
memory.stm.set_entity("key", "value")
context = builder._format_stm_context()
assert "TOPIC" in context or "test_topic" in context
assert "ENTITIES" in context or "key" in context
@@ -213,11 +211,11 @@ class TestPromptBuilderStructure:
"""Verify config context is properly formatted."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.ltm.set_config("test_key", "test_value")
context = builder._format_config_context()
assert "CONFIGURATION" in context
assert "test_key" in context
@@ -229,10 +227,10 @@ class TestPromptBuilderEdgeCases:
"""Verify prompt works with empty memory."""
tools = make_tools()
builder = PromptBuilder(tools)
# Memory is empty
prompt = builder.build_system_prompt()
# Should still have base content
assert len(prompt) > 0
assert "assistant" in prompt.lower()
@@ -240,18 +238,18 @@ class TestPromptBuilderEdgeCases:
def test_prompt_with_empty_tools(self):
    """A builder given an empty tools dict must still produce a non-empty prompt."""
    toolless_builder = PromptBuilder({})
    system_prompt = toolless_builder.build_system_prompt()

    # No registered tools is a valid state; the prompt text must survive it.
    assert len(system_prompt) > 0
def test_tools_spec_with_empty_tools(self):
    """With an empty tools dict, build_tools_spec() returns an empty list."""
    tool_specs = PromptBuilder({}).build_tools_spec()

    # Both the container type and its emptiness are part of the contract.
    assert isinstance(tool_specs, list)
    assert len(tool_specs) == 0
@@ -259,11 +257,11 @@ class TestPromptBuilderEdgeCases:
"""Verify prompt handles unicode in memory."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.stm.set_entity("movie", "Amélie 🎬")
prompt = builder.build_system_prompt()
assert "Amélie" in prompt
assert "🎬" in prompt
@@ -271,13 +269,13 @@ class TestPromptBuilderEdgeCases:
"""Verify prompt handles many search results."""
tools = make_tools()
builder = PromptBuilder(tools)
# Add many results
results = [{"name": f"Movie {i}", "seeders": i} for i in range(20)]
memory.episodic.store_search_results("test", results, "torrent")
prompt = builder.build_system_prompt()
# Should include some results but not all (to avoid huge prompts)
assert "Movie 0" in prompt or "Movie 1" in prompt
# Should indicate there are more

View File

@@ -1,10 +1,8 @@
"""Edge case tests for PromptBuilder."""
import pytest
import json
from agent.prompts import PromptBuilder
from agent.registry import make_tools
from infrastructure.persistence import get_memory
class TestPromptBuilderEdgeCases:
@@ -93,11 +91,13 @@ class TestPromptBuilderEdgeCases:
def test_prompt_with_many_active_downloads(self, memory):
"""Should limit displayed active downloads."""
for i in range(20):
memory.episodic.add_active_download({
"task_id": str(i),
"name": f"Download {i}",
"progress": i * 5,
})
memory.episodic.add_active_download(
{
"task_id": str(i),
"name": f"Download {i}",
"progress": i * 5,
}
)
tools = make_tools()
builder = PromptBuilder(tools)
@@ -136,12 +136,15 @@ class TestPromptBuilderEdgeCases:
def test_prompt_with_complex_workflow(self, memory):
"""Should handle complex workflow state."""
memory.stm.start_workflow("download", {
"title": "Test Movie",
"year": 2024,
"quality": "1080p",
"nested": {"deep": {"value": "test"}},
})
memory.stm.start_workflow(
"download",
{
"title": "Test Movie",
"year": 2024,
"quality": "1080p",
"nested": {"deep": {"value": "test"}},
},
)
memory.stm.update_workflow_stage("searching_torrents")
tools = make_tools()
@@ -313,11 +316,14 @@ class TestFormatEpisodicContextEdgeCases:
def test_format_with_search_results_none_names(self, memory):
"""Should handle results with None names."""
memory.episodic.store_search_results("test", [
{"name": None},
{"title": None},
{},
])
memory.episodic.store_search_results(
"test",
[
{"name": None},
{"title": None},
{},
],
)
tools = make_tools()
builder = PromptBuilder(tools)

View File

@@ -1,10 +1,11 @@
"""Critical tests for tool registry - Tests that would have caught bugs."""
import pytest
import inspect
from agent.registry import make_tools, _create_tool_from_function, Tool
import pytest
from agent.prompts import PromptBuilder
from agent.registry import Tool, _create_tool_from_function, make_tools
class TestToolSpecFormat:
@@ -15,54 +16,59 @@ class TestToolSpecFormat:
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
# Verify structure
assert isinstance(specs, list), "Tool specs must be a list"
assert len(specs) > 0, "Tool specs list is empty"
for spec in specs:
# OpenAI format requires these fields
assert spec['type'] == 'function', f"Tool type must be 'function', got {spec.get('type')}"
assert 'function' in spec, "Tool spec missing 'function' key"
func = spec['function']
assert 'name' in func, "Function missing 'name'"
assert 'description' in func, "Function missing 'description'"
assert 'parameters' in func, "Function missing 'parameters'"
params = func['parameters']
assert params['type'] == 'object', "Parameters type must be 'object'"
assert 'properties' in params, "Parameters missing 'properties'"
assert 'required' in params, "Parameters missing 'required'"
assert isinstance(params['required'], list), "Required must be a list"
assert (
spec["type"] == "function"
), f"Tool type must be 'function', got {spec.get('type')}"
assert "function" in spec, "Tool spec missing 'function' key"
func = spec["function"]
assert "name" in func, "Function missing 'name'"
assert "description" in func, "Function missing 'description'"
assert "parameters" in func, "Function missing 'parameters'"
params = func["parameters"]
assert params["type"] == "object", "Parameters type must be 'object'"
assert "properties" in params, "Parameters missing 'properties'"
assert "required" in params, "Parameters missing 'required'"
assert isinstance(params["required"], list), "Required must be a list"
def test_tool_parameters_match_function_signature(self):
"""CRITICAL: Verify generated parameters match function signature."""
# Local fixture: two parameters without defaults (str, int) and one bool
# parameter with a default, so both the type mapping and the
# required/optional split can be checked on the generated schema.
def test_func(name: str, age: int, active: bool = True):
"""Test function with typed parameters."""
return {"status": "ok"}
tool = _create_tool_from_function(test_func)
# Verify types are correctly mapped
# NOTE(review): the three single-quoted assertions below and the three
# double-quoted ones after them check the same keys against the same values;
# this looks like the before/after of a quote-style reformat shown together.
# Confirm which copy belongs in the file.
assert tool.parameters['properties']['name']['type'] == 'string'
assert tool.parameters['properties']['age']['type'] == 'integer'
assert tool.parameters['properties']['active']['type'] == 'boolean'
assert tool.parameters["properties"]["name"]["type"] == "string"
assert tool.parameters["properties"]["age"]["type"] == "integer"
assert tool.parameters["properties"]["active"]["type"] == "boolean"
# Verify required vs optional
# NOTE(review): same pattern here; the single-quoted and double-quoted
# groups assert identical membership conditions with identical messages.
assert 'name' in tool.parameters['required'], "name should be required"
assert 'age' in tool.parameters['required'], "age should be required"
assert 'active' not in tool.parameters['required'], "active has default, should not be required"
assert "name" in tool.parameters["required"], "name should be required"
assert "age" in tool.parameters["required"], "age should be required"
assert (
"active" not in tool.parameters["required"]
), "active has default, should not be required"
def test_all_registered_tools_are_callable(self):
"""CRITICAL: Verify all registered tools are actually callable."""
tools = make_tools()
assert len(tools) > 0, "No tools registered"
for name, tool in tools.items():
assert callable(tool.func), f"Tool {name} is not callable"
# Verify function has valid signature
try:
sig = inspect.signature(tool.func)
@@ -75,38 +81,40 @@ class TestToolSpecFormat:
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
spec_names = {spec['function']['name'] for spec in specs}
spec_names = {spec["function"]["name"] for spec in specs}
tool_names = set(tools.keys())
missing = tool_names - spec_names
extra = spec_names - tool_names
assert not missing, f"Tools missing from specs: {missing}"
assert not extra, f"Extra tools in specs: {extra}"
assert spec_names == tool_names, "Tool specs don't match registered tools"
def test_tool_description_extracted_from_docstring(self):
"""Verify tool description is extracted from function docstring."""
# Local fixture whose docstring has a summary sentence followed by extra
# detail; the docstring text itself is the input under test, so it must
# not be reworded.
def test_func(param: str):
"""This is the description.
More details here.
"""
return {}
tool = _create_tool_from_function(test_func)
# The description must be exactly the summary sentence ...
assert tool.description == "This is the description."
# ... and must not carry over the detail text that follows it.
assert "More details" not in tool.description
def test_tool_without_docstring_uses_function_name(self):
    """An undocumented function gets its own name as the tool description."""

    # Deliberately left without a docstring: that absence is what is tested,
    # and the function's name is the expected description.
    def test_func_no_doc(param: str):
        return {}

    built = _create_tool_from_function(test_func_no_doc)
    assert built.description == "test_func_no_doc"
def test_tool_parameters_have_descriptions(self):
@@ -114,25 +122,27 @@ class TestToolSpecFormat:
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
for spec in specs:
params = spec['function']['parameters']
properties = params.get('properties', {})
params = spec["function"]["parameters"]
properties = params.get("properties", {})
for param_name, param_spec in properties.items():
assert 'description' in param_spec, \
f"Parameter {param_name} in {spec['function']['name']} missing description"
assert (
"description" in param_spec
), f"Parameter {param_name} in {spec['function']['name']} missing description"
def test_required_parameters_are_marked_correctly(self):
"""Verify required parameters are correctly identified."""
# Local fixture: one parameter without a default and one with a default.
# The parameter names are literally "required" and "optional", which is
# why those strings appear in the assertions below.
def func_with_optional(required: str, optional: int = 5):
return {}
tool = _create_tool_from_function(func_with_optional)
# NOTE(review): the single-quoted group and the double-quoted group below
# assert the same three conditions; this looks like the before/after of a
# quote-style reformat shown together. Confirm which copy belongs in the file.
assert 'required' in tool.parameters['required']
assert 'optional' not in tool.parameters['required']
assert len(tool.parameters['required']) == 1
assert "required" in tool.parameters["required"]
assert "optional" not in tool.parameters["required"]
assert len(tool.parameters["required"]) == 1
class TestToolRegistry:
@@ -141,28 +151,28 @@ class TestToolRegistry:
def test_make_tools_returns_dict(self):
    """make_tools() hands back a dict with at least one entry."""
    registry = make_tools()

    assert isinstance(registry, dict)
    assert len(registry) > 0
def test_all_tools_have_unique_names(self):
    """No two registered tools may share the same ``name`` attribute."""
    seen_names = []
    for registered in make_tools().values():
        seen_names.append(registered.name)

    # A set drops repeats, so any length difference means a duplicate name.
    assert len(seen_names) == len(set(seen_names)), "Duplicate tool names found"
def test_tool_names_match_dict_keys(self):
    """Each registry key must equal the ``name`` of the tool stored under it."""
    for key, tool in make_tools().items():
        # The failure message names both sides so a mismatch is easy to locate.
        assert key == tool.name, f"Key {key} doesn't match tool name {tool.name}"
def test_expected_tools_are_registered(self):
"""Verify all expected tools are registered."""
tools = make_tools()
expected_tools = [
"set_path_for_folder",
"list_folder",
@@ -173,14 +183,14 @@ class TestToolRegistry:
"get_torrent_by_index",
"set_language",
]
for expected in expected_tools:
assert expected in tools, f"Expected tool {expected} not registered"
def test_tool_functions_return_dict(self):
"""Verify all tool functions return dictionaries."""
tools = make_tools()
# Test with minimal valid arguments
# Note: This is a smoke test, not full integration
for name, tool in tools.items():
@@ -195,16 +205,17 @@ class TestToolDataclass:
def test_tool_creation(self):
"""Verify Tool can be created with all fields."""
def dummy_func():
return {}
tool = Tool(
name="test_tool",
description="Test description",
func=dummy_func,
parameters={"type": "object", "properties": {}, "required": []}
parameters={"type": "object", "properties": {}, "required": []},
)
assert tool.name == "test_tool"
assert tool.description == "Test description"
assert tool.func == dummy_func
@@ -212,12 +223,13 @@ class TestToolDataclass:
def test_tool_parameters_structure(self):
"""Verify Tool parameters have correct structure."""
# Minimal single-argument fixture; only the top-level keys of the generated
# parameters dict are inspected, not the per-argument schema.
def dummy_func(arg: str):
return {}
tool = _create_tool_from_function(dummy_func)
# NOTE(review): the single-quoted group and the double-quoted group below
# assert the same four conditions; this looks like the before/after of a
# quote-style reformat shown together. Confirm which copy belongs in the file.
assert 'type' in tool.parameters
assert 'properties' in tool.parameters
assert 'required' in tool.parameters
assert tool.parameters['type'] == 'object'
assert "type" in tool.parameters
assert "properties" in tool.parameters
assert "required" in tool.parameters
assert tool.parameters["type"] == "object"

View File

@@ -1,6 +1,7 @@
"""Edge case tests for tool registry."""
import pytest
from unittest.mock import Mock
from agent.registry import Tool, make_tools
@@ -182,7 +183,9 @@ class TestMakeToolsEdgeCases:
params = tool.parameters
if "required" in params and "properties" in params:
for req in params["required"]:
assert req in params["properties"], f"Required param {req} not in properties for {tool.name}"
assert (
req in params["properties"]
), f"Required param {req} not in properties for {tool.name}"
def test_make_tools_descriptions_not_empty(self, memory):
"""Should have non-empty descriptions."""
@@ -233,7 +236,9 @@ class TestMakeToolsEdgeCases:
if "properties" in tool.parameters:
for prop_name, prop_schema in tool.parameters["properties"].items():
if "type" in prop_schema:
assert prop_schema["type"] in valid_types, f"Invalid type for {tool.name}.{prop_name}"
assert (
prop_schema["type"] in valid_types
), f"Invalid type for {tool.name}.{prop_name}"
def test_make_tools_enum_values(self, memory):
"""Should have valid enum values."""

View File

@@ -1,19 +1,18 @@
"""Tests for JSON repositories."""
import pytest
from datetime import datetime
from infrastructure.persistence.json import (
JsonMovieRepository,
JsonTVShowRepository,
JsonSubtitleRepository,
)
from domain.movies.entities import Movie
from domain.movies.value_objects import MovieTitle, ReleaseYear, Quality
from domain.tv_shows.entities import TVShow
from domain.tv_shows.value_objects import ShowStatus
from domain.movies.value_objects import MovieTitle, Quality, ReleaseYear
from domain.shared.value_objects import FilePath, FileSize, ImdbId
from domain.subtitles.entities import Subtitle
from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset
from domain.shared.value_objects import ImdbId, FilePath, FileSize
from domain.tv_shows.entities import TVShow
from domain.tv_shows.value_objects import ShowStatus
from infrastructure.persistence.json import (
JsonMovieRepository,
JsonSubtitleRepository,
JsonTVShowRepository,
)
class TestJsonMovieRepository:
@@ -224,7 +223,9 @@ class TestJsonTVShowRepository:
"""Should preserve show status."""
repo = JsonTVShowRepository()
for i, status in enumerate([ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]):
for i, status in enumerate(
[ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]
):
show = TVShow(
imdb_id=ImdbId(f"tt{i+1000000:07d}"),
title=f"Show {status.value}",
@@ -294,18 +295,22 @@ class TestJsonSubtitleRepository:
def test_find_by_media_with_language_filter(self, memory):
"""Should filter by language."""
repo = JsonSubtitleRepository()
repo.save(Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/en.srt"),
))
repo.save(Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.FRENCH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/fr.srt"),
))
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/en.srt"),
)
)
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.FRENCH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/fr.srt"),
)
)
results = repo.find_by_media(ImdbId("tt1375666"), language=Language.FRENCH)
@@ -315,22 +320,26 @@ class TestJsonSubtitleRepository:
def test_find_by_media_with_episode_filter(self, memory):
"""Should filter by season/episode."""
repo = JsonSubtitleRepository()
repo.save(Subtitle(
media_imdb_id=ImdbId("tt0944947"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/s01e01.srt"),
season_number=1,
episode_number=1,
))
repo.save(Subtitle(
media_imdb_id=ImdbId("tt0944947"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/s01e02.srt"),
season_number=1,
episode_number=2,
))
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt0944947"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/s01e01.srt"),
season_number=1,
episode_number=1,
)
)
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt0944947"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/s01e02.srt"),
season_number=1,
episode_number=2,
)
)
results = repo.find_by_media(
ImdbId("tt0944947"),

View File

@@ -21,24 +21,30 @@ def create_mock_response(status_code, json_data=None, text=None):
class TestFindMediaImdbId:
"""Tests for find_media_imdb_id tool."""
@patch('infrastructure.api.tmdb.client.requests.get')
@patch("infrastructure.api.tmdb.client.requests.get")
def test_success(self, mock_get, memory):
"""Should return movie info on success."""
# Mock HTTP responses
def mock_get_side_effect(url, **kwargs):
if "search" in url:
return create_mock_response(200, json_data={
"results": [{
"id": 27205,
"title": "Inception",
"release_date": "2010-07-16",
"overview": "A thief...",
"media_type": "movie"
}]
})
return create_mock_response(
200,
json_data={
"results": [
{
"id": 27205,
"title": "Inception",
"release_date": "2010-07-16",
"overview": "A thief...",
"media_type": "movie",
}
]
},
)
elif "external_ids" in url:
return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
mock_get.side_effect = mock_get_side_effect
result = api_tools.find_media_imdb_id("Inception")
@@ -46,26 +52,32 @@ class TestFindMediaImdbId:
assert result["status"] == "ok"
assert result["imdb_id"] == "tt1375666"
assert result["title"] == "Inception"
# Verify HTTP calls
assert mock_get.call_count == 2
@patch('infrastructure.api.tmdb.client.requests.get')
@patch("infrastructure.api.tmdb.client.requests.get")
def test_stores_in_stm(self, mock_get, memory):
"""Should store result in STM on success."""
def mock_get_side_effect(url, **kwargs):
if "search" in url:
return create_mock_response(200, json_data={
"results": [{
"id": 27205,
"title": "Inception",
"release_date": "2010-07-16",
"media_type": "movie"
}]
})
return create_mock_response(
200,
json_data={
"results": [
{
"id": 27205,
"title": "Inception",
"release_date": "2010-07-16",
"media_type": "movie",
}
]
},
)
elif "external_ids" in url:
return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
mock_get.side_effect = mock_get_side_effect
api_tools.find_media_imdb_id("Inception")
@@ -76,7 +88,7 @@ class TestFindMediaImdbId:
assert entity["title"] == "Inception"
assert mem.stm.current_topic == "searching_media"
@patch('infrastructure.api.tmdb.client.requests.get')
@patch("infrastructure.api.tmdb.client.requests.get")
def test_not_found(self, mock_get, memory):
"""Should return error when not found."""
mock_get.return_value = create_mock_response(200, json_data={"results": []})
@@ -86,7 +98,7 @@ class TestFindMediaImdbId:
assert result["status"] == "error"
assert result["error"] == "not_found"
@patch('infrastructure.api.tmdb.client.requests.get')
@patch("infrastructure.api.tmdb.client.requests.get")
def test_does_not_store_on_error(self, mock_get, memory):
"""Should not store in STM on error."""
mock_get.return_value = create_mock_response(200, json_data={"results": []})
@@ -100,49 +112,57 @@ class TestFindMediaImdbId:
class TestFindTorrent:
"""Tests for find_torrent tool."""
@patch('infrastructure.api.knaben.client.requests.post')
@patch("infrastructure.api.knaben.client.requests.post")
def test_success(self, mock_post, memory):
"""Should return torrents on success."""
mock_post.return_value = create_mock_response(200, json_data={
"hits": [
{
"title": "Torrent 1",
"seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=...",
"size": "2.5 GB"
},
{
"title": "Torrent 2",
"seeders": 50,
"leechers": 5,
"magnetUrl": "magnet:?xt=...",
"size": "1.8 GB"
}
]
})
mock_post.return_value = create_mock_response(
200,
json_data={
"hits": [
{
"title": "Torrent 1",
"seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=...",
"size": "2.5 GB",
},
{
"title": "Torrent 2",
"seeders": 50,
"leechers": 5,
"magnetUrl": "magnet:?xt=...",
"size": "1.8 GB",
},
]
},
)
result = api_tools.find_torrent("Inception 1080p")
assert result["status"] == "ok"
assert len(result["torrents"]) == 2
# Verify HTTP payload
payload = mock_post.call_args[1]['json']
assert payload['query'] == "Inception 1080p"
@patch('infrastructure.api.knaben.client.requests.post')
# Verify HTTP payload
payload = mock_post.call_args[1]["json"]
assert payload["query"] == "Inception 1080p"
@patch("infrastructure.api.knaben.client.requests.post")
def test_stores_in_episodic(self, mock_post, memory):
"""Should store results in episodic memory."""
mock_post.return_value = create_mock_response(200, json_data={
"hits": [{
"title": "Torrent 1",
"seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=...",
"size": "2.5 GB"
}]
})
mock_post.return_value = create_mock_response(
200,
json_data={
"hits": [
{
"title": "Torrent 1",
"seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=...",
"size": "2.5 GB",
}
]
},
)
api_tools.find_torrent("Inception")
@@ -151,16 +171,37 @@ class TestFindTorrent:
assert mem.episodic.last_search_results["query"] == "Inception"
assert mem.stm.current_topic == "selecting_torrent"
@patch('infrastructure.api.knaben.client.requests.post')
@patch("infrastructure.api.knaben.client.requests.post")
def test_results_have_indexes(self, mock_post, memory):
"""Should add indexes to results."""
mock_post.return_value = create_mock_response(200, json_data={
"hits": [
{"title": "Torrent 1", "seeders": 100, "leechers": 10, "magnetUrl": "magnet:?xt=1", "size": "1GB"},
{"title": "Torrent 2", "seeders": 50, "leechers": 5, "magnetUrl": "magnet:?xt=2", "size": "2GB"},
{"title": "Torrent 3", "seeders": 25, "leechers": 2, "magnetUrl": "magnet:?xt=3", "size": "3GB"}
]
})
mock_post.return_value = create_mock_response(
200,
json_data={
"hits": [
{
"title": "Torrent 1",
"seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=1",
"size": "1GB",
},
{
"title": "Torrent 2",
"seeders": 50,
"leechers": 5,
"magnetUrl": "magnet:?xt=2",
"size": "2GB",
},
{
"title": "Torrent 3",
"seeders": 25,
"leechers": 2,
"magnetUrl": "magnet:?xt=3",
"size": "3GB",
},
]
},
)
api_tools.find_torrent("Test")
@@ -170,7 +211,7 @@ class TestFindTorrent:
assert results[1]["index"] == 2
assert results[2]["index"] == 3
@patch('infrastructure.api.knaben.client.requests.post')
@patch("infrastructure.api.knaben.client.requests.post")
def test_not_found(self, mock_post, memory):
"""Should return error when no torrents found."""
mock_post.return_value = create_mock_response(200, json_data={"hits": []})
@@ -236,16 +277,16 @@ class TestGetTorrentByIndex:
class TestAddTorrentToQbittorrent:
"""Tests for add_torrent_to_qbittorrent tool.
Note: These tests mock the qBittorrent client because:
1. The client requires authentication/session management
2. We want to test the tool's logic (memory updates, workflow management)
3. The client itself is tested separately in infrastructure tests
This is acceptable mocking because we're testing the TOOL logic, not the client.
"""
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_success(self, mock_client, memory):
"""Should add torrent successfully and update memory."""
mock_client.add_torrent.return_value = True
@@ -257,7 +298,7 @@ class TestAddTorrentToQbittorrent:
# Verify client was called correctly
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_adds_to_active_downloads(self, mock_client, memory_with_search_results):
"""Should add to active downloads on success."""
mock_client.add_torrent.return_value = True
@@ -267,9 +308,12 @@ class TestAddTorrentToQbittorrent:
# Test memory update logic
mem = get_memory()
assert len(mem.episodic.active_downloads) == 1
assert mem.episodic.active_downloads[0]["name"] == "Inception.2010.1080p.BluRay.x264"
assert (
mem.episodic.active_downloads[0]["name"]
== "Inception.2010.1080p.BluRay.x264"
)
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_sets_topic_and_ends_workflow(self, mock_client, memory):
"""Should set topic and end workflow."""
mock_client.add_torrent.return_value = True
@@ -282,10 +326,11 @@ class TestAddTorrentToQbittorrent:
assert mem.stm.current_topic == "downloading"
assert mem.stm.current_workflow is None
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_error_handling(self, mock_client, memory):
"""Should handle client errors correctly."""
from infrastructure.api.qbittorrent.exceptions import QBittorrentAPIError
mock_client.add_torrent.side_effect = QBittorrentAPIError("Connection failed")
result = api_tools.add_torrent_to_qbittorrent("magnet:?xt=...")
@@ -296,7 +341,7 @@ class TestAddTorrentToQbittorrent:
class TestAddTorrentByIndex:
"""Tests for add_torrent_by_index tool.
These tests verify the tool's logic:
- Getting torrent from memory by index
- Extracting magnet link
@@ -304,7 +349,7 @@ class TestAddTorrentByIndex:
- Error handling for edge cases
"""
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_success(self, mock_client, memory_with_search_results):
"""Should get torrent by index and add it."""
mock_client.add_torrent.return_value = True
@@ -317,7 +362,7 @@ class TestAddTorrentByIndex:
# Verify correct magnet was extracted and used
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_uses_correct_magnet(self, mock_client, memory_with_search_results):
"""Should extract correct magnet from index."""
mock_client.add_torrent.return_value = True

View File

@@ -1,7 +1,8 @@
"""Edge case tests for tools."""
from unittest.mock import Mock, patch
import pytest
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
from agent.tools import api as api_tools
from agent.tools import filesystem as fs_tools
@@ -15,7 +16,10 @@ class TestFindTorrentEdgeCases:
def test_empty_query(self, mock_use_case_class, memory):
"""Should handle empty query."""
mock_response = Mock()
mock_response.to_dict.return_value = {"status": "error", "error": "invalid_query"}
mock_response.to_dict.return_value = {
"status": "error",
"error": "invalid_query",
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
@@ -28,7 +32,11 @@ class TestFindTorrentEdgeCases:
def test_very_long_query(self, mock_use_case_class, memory):
"""Should handle very long query."""
mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0}
mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
@@ -43,7 +51,11 @@ class TestFindTorrentEdgeCases:
def test_special_characters_in_query(self, mock_use_case_class, memory):
"""Should handle special characters in query."""
mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0}
mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
@@ -57,7 +69,11 @@ class TestFindTorrentEdgeCases:
def test_unicode_query(self, mock_use_case_class, memory):
"""Should handle unicode in query."""
mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0}
mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
@@ -161,7 +177,10 @@ class TestAddTorrentEdgeCases:
def test_empty_magnet_link(self, mock_use_case_class, memory):
"""Should handle empty magnet link."""
mock_response = Mock()
mock_response.to_dict.return_value = {"status": "error", "error": "empty_magnet"}
mock_response.to_dict.return_value = {
"status": "error",
"error": "empty_magnet",
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
@@ -326,7 +345,10 @@ class TestFilesystemEdgeCases:
for attempt in attempts:
result = fs_tools.list_folder("download", attempt)
# Should either be forbidden or not found
assert result.get("error") in ["forbidden", "not_found", None] or result.get("status") == "ok"
assert (
result.get("error") in ["forbidden", "not_found", None]
or result.get("status") == "ok"
)
def test_path_with_null_byte(self, memory, real_folder):
"""Should block null byte injection."""