Formatting

This commit is contained in:
2025-12-07 03:33:51 +01:00
parent a923a760ef
commit 4eae1d6d58
24 changed files with 1003 additions and 833 deletions

View File

@@ -1,7 +1,8 @@
"""Main agent for media library management."""
import json
import logging
from typing import Any, Dict, List, Optional
from typing import Any
from infrastructure.persistence import get_memory
@@ -28,7 +29,7 @@ class Agent:
max_tool_iterations: Maximum number of tool execution iterations
"""
self.llm = llm
self.tools: Dict[str, Tool] = make_tools()
self.tools: dict[str, Tool] = make_tools()
self.prompt_builder = PromptBuilder(self.tools)
self.max_tool_iterations = max_tool_iterations
@@ -56,9 +57,7 @@ class Agent:
# Build initial messages
system_prompt = self.prompt_builder.build_system_prompt()
messages: List[Dict[str, Any]] = [
{"role": "system", "content": system_prompt}
]
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
# Add conversation history
history = memory.stm.get_recent_history(settings.max_history_messages)
@@ -67,14 +66,12 @@ class Agent:
# Add unread events if any
unread_events = memory.episodic.get_unread_events()
if unread_events:
events_text = "\n".join([
f"- {e['type']}: {e['data']}"
for e in unread_events
])
messages.append({
"role": "system",
"content": f"Background events:\n{events_text}"
})
events_text = "\n".join(
[f"- {e['type']}: {e['data']}" for e in unread_events]
)
messages.append(
{"role": "system", "content": f"Background events:\n{events_text}"}
)
# Get tools specification for OpenAI format
tools_spec = self.prompt_builder.build_tools_spec()
@@ -108,18 +105,22 @@ class Agent:
tool_result = self._execute_tool_call(tool_call)
# Add tool result to messages
messages.append({
messages.append(
{
"tool_call_id": tool_call.get("id"),
"role": "tool",
"name": tool_call.get("function", {}).get("name"),
"content": json.dumps(tool_result, ensure_ascii=False),
})
}
)
# Max iterations reached, force final response
messages.append({
messages.append(
{
"role": "system",
"content": "Please provide a final response to the user without using any more tools."
})
"content": "Please provide a final response to the user without using any more tools.",
}
)
llm_result = self.llm.complete(messages)
if isinstance(llm_result, tuple):
@@ -127,12 +128,14 @@ class Agent:
else:
final_message = llm_result
final_response = final_message.get("content", "I've completed the requested actions.")
final_response = final_message.get(
"content", "I've completed the requested actions."
)
memory.stm.add_message("assistant", final_response)
memory.save()
return final_response
def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]:
"""
Execute a single tool call.
@@ -150,10 +153,7 @@ class Agent:
args = json.loads(args_str)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse tool arguments: {e}")
return {
"error": "bad_args",
"message": f"Invalid JSON arguments: {e}"
}
return {"error": "bad_args", "message": f"Invalid JSON arguments: {e}"}
# Validate tool exists
if tool_name not in self.tools:
@@ -161,7 +161,7 @@ class Agent:
return {
"error": "unknown_tool",
"message": f"Tool '{tool_name}' not found",
"available_tools": available
"available_tools": available,
}
tool = self.tools[tool_name]
@@ -177,17 +177,9 @@ class Agent:
# Bad arguments
memory = get_memory()
memory.episodic.add_error(tool_name, f"bad_args: {e}")
return {
"error": "bad_args",
"message": str(e),
"tool": tool_name
}
return {"error": "bad_args", "message": str(e), "tool": tool_name}
except Exception as e:
# Other errors
memory = get_memory()
memory.episodic.add_error(tool_name, str(e))
return {
"error": "execution_failed",
"message": str(e),
"tool": tool_name
}
return {"error": "execution_failed", "message": str(e), "tool": tool_name}

View File

@@ -51,7 +51,9 @@ class DeepSeekClient:
logger.info(f"DeepSeek client initialized with model: {self.model}")
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]:
def complete(
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
"""
Generate a completion from the LLM.
@@ -80,7 +82,9 @@ class DeepSeekClient:
raise ValueError(f"Invalid role: {msg['role']}")
# Content is optional for tool messages (they may have tool_call_id instead)
if msg["role"] != "tool" and "content" not in msg:
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}")
raise ValueError(
f"Non-tool message must have 'content' key, got {msg.keys()}"
)
url = f"{self.base_url}/v1/chat/completions"
headers = {
@@ -98,7 +102,9 @@ class DeepSeekClient:
payload["tools"] = tools
try:
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools")
logger.debug(
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
)
response = requests.post(
url, headers=headers, json=payload, timeout=self.timeout
)

View File

@@ -66,7 +66,9 @@ class OllamaClient:
logger.info(f"Ollama client initialized with model: {self.model}")
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]:
def complete(
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
"""
Generate a completion from the LLM.
@@ -95,7 +97,9 @@ class OllamaClient:
raise ValueError(f"Invalid role: {msg['role']}")
# Content is optional for tool messages (they may have tool_call_id instead)
if msg["role"] != "tool" and "content" not in msg:
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}")
raise ValueError(
f"Non-tool message must have 'content' key, got {msg.keys()}"
)
url = f"{self.base_url}/api/chat"
payload = {
@@ -112,7 +116,9 @@ class OllamaClient:
payload["tools"] = tools
try:
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools")
logger.debug(
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
)
response = requests.post(url, json=payload, timeout=self.timeout)
response.raise_for_status()
data = response.json()

View File

@@ -1,18 +1,20 @@
"""Prompt builder for the agent system."""
from typing import Dict, List, Any
import json
from typing import Any
from infrastructure.persistence import get_memory
from .registry import Tool
from infrastructure.persistence import get_memory
class PromptBuilder:
"""Builds system prompts for the agent with memory context."""
def __init__(self, tools: Dict[str, Tool]):
def __init__(self, tools: dict[str, Tool]):
self.tools = tools
def build_tools_spec(self) -> List[Dict[str, Any]]:
def build_tools_spec(self) -> list[dict[str, Any]]:
"""Build the tool specification for the LLM API."""
tool_specs = []
for tool in self.tools.values():
@@ -44,11 +46,13 @@ class PromptBuilder:
if memory.episodic.last_search_results:
results = memory.episodic.last_search_results
result_list = results.get('results', [])
lines.append(f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)")
result_list = results.get("results", [])
lines.append(
f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)"
)
# Show first 5 results
for i, result in enumerate(result_list[:5]):
name = result.get('name', 'Unknown')
name = result.get("name", "Unknown")
lines.append(f" {i+1}. {name}")
if len(result_list) > 5:
lines.append(f" ... and {len(result_list) - 5} more")
@@ -57,7 +61,7 @@ class PromptBuilder:
question = memory.episodic.pending_question
lines.append(f"\nPENDING QUESTION: {question.get('question')}")
lines.append(f" Type: {question.get('type')}")
if question.get('options'):
if question.get("options"):
lines.append(f" Options: {len(question.get('options'))}")
if memory.episodic.active_downloads:
@@ -68,10 +72,12 @@ class PromptBuilder:
if memory.episodic.recent_errors:
lines.append("\nRECENT ERRORS (up to 3):")
for error in memory.episodic.recent_errors[-3:]:
lines.append(f" - Action '{error.get('action')}' failed: {error.get('error')}")
lines.append(
f" - Action '{error.get('action')}' failed: {error.get('error')}"
)
# Unread events
unread = [e for e in memory.episodic.background_events if not e.get('read')]
unread = [e for e in memory.episodic.background_events if not e.get("read")]
if unread:
lines.append(f"\nUNREAD EVENTS: {len(unread)}")
for event in unread[:3]:
@@ -86,8 +92,10 @@ class PromptBuilder:
if memory.stm.current_workflow:
workflow = memory.stm.current_workflow
lines.append(f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})")
if workflow.get('target'):
lines.append(
f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})"
)
if workflow.get("target"):
lines.append(f" Target: {workflow.get('target')}")
if memory.stm.current_topic:

View File

@@ -1,8 +1,10 @@
"""Tool registry - defines and registers all available tools for the agent."""
from dataclasses import dataclass
from typing import Callable, Any, Dict
import logging
import inspect
import logging
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
logger = logging.getLogger(__name__)
@@ -10,10 +12,11 @@ logger = logging.getLogger(__name__)
@dataclass
class Tool:
"""Represents a tool that can be used by the agent."""
name: str
description: str
func: Callable[..., Dict[str, Any]]
parameters: Dict[str, Any]
func: Callable[..., dict[str, Any]]
parameters: dict[str, Any]
def _create_tool_from_function(func: Callable) -> Tool:
@@ -30,7 +33,7 @@ def _create_tool_from_function(func: Callable) -> Tool:
doc = inspect.getdoc(func)
# Extract description from docstring (first line)
description = doc.strip().split('\n')[0] if doc else func.__name__
description = doc.strip().split("\n")[0] if doc else func.__name__
# Build JSON schema from function signature
properties = {}
@@ -54,7 +57,7 @@ def _create_tool_from_function(func: Callable) -> Tool:
properties[param_name] = {
"type": param_type,
"description": f"Parameter {param_name}"
"description": f"Parameter {param_name}",
}
# Add to required if no default value
@@ -75,7 +78,7 @@ def _create_tool_from_function(func: Callable) -> Tool:
)
def make_tools() -> Dict[str, Tool]:
def make_tools() -> dict[str, Tool]:
"""
Create and register all available tools.
@@ -83,8 +86,8 @@ def make_tools() -> Dict[str, Tool]:
Dictionary mapping tool names to Tool objects
"""
# Import tools here to avoid circular dependencies
from .tools import filesystem as fs_tools
from .tools import api as api_tools
from .tools import filesystem as fs_tools
from .tools import language as lang_tools
# List of all tool functions

View File

@@ -1,13 +1,14 @@
"""Language management tools for the agent."""
import logging
from typing import Dict, Any
from typing import Any
from infrastructure.persistence import get_memory
logger = logging.getLogger(__name__)
def set_language(language: str) -> Dict[str, Any]:
def set_language(language: str) -> dict[str, Any]:
"""
Set the conversation language.
@@ -27,11 +28,8 @@ def set_language(language: str) -> Dict[str, Any]:
return {
"status": "ok",
"message": f"Language set to {language}",
"language": language
"language": language,
}
except Exception as e:
logger.error(f"Failed to set language: {e}")
return {
"status": "error",
"error": str(e)
}
return {"status": "error", "error": str(e)}

View File

@@ -359,9 +359,7 @@ class EpisodicMemory:
"""Get active downloads."""
return self.active_downloads
def add_error(
self, action: str, error: str, context: dict | None = None
) -> None:
def add_error(self, action: str, error: str, context: dict | None = None) -> None:
"""Record a recent error."""
self.recent_errors.append(
{
@@ -408,9 +406,7 @@ class EpisodicMemory:
"""Get the pending question."""
return self.pending_question
def resolve_pending_question(
self, answer_index: int | None = None
) -> dict | None:
def resolve_pending_question(self, answer_index: int | None = None) -> dict | None:
"""
Resolve the pending question and return the chosen option.

View File

@@ -110,4 +110,4 @@ select = [
"PL",
"UP",
]
ignore = ["W503", "PLR0913", "PLR2004"]
ignore = ["PLR0913", "PLR2004"]

View File

@@ -1,16 +1,13 @@
"""Pytest configuration and shared fixtures."""
import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, MagicMock
from infrastructure.persistence import Memory, init_memory, set_memory, get_memory
from infrastructure.persistence.memory import (
LongTermMemory,
ShortTermMemory,
EpisodicMemory,
)
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, Mock
import pytest
from infrastructure.persistence import Memory, set_memory
@pytest.fixture
@@ -122,12 +119,11 @@ def memory_with_library(memory):
def mock_llm():
"""Create a mock LLM client that returns OpenAI-compatible format."""
llm = Mock()
# Return OpenAI-style message dict without tool calls
def complete_func(messages, tools=None):
return {
"role": "assistant",
"content": "I found what you're looking for!"
}
return {"role": "assistant", "content": "I found what you're looking for!"}
llm.complete = Mock(side_effect=complete_func)
return llm
@@ -139,7 +135,7 @@ def mock_llm_with_tool_call():
# First call returns a tool call, second returns final response
def complete_side_effect(messages, tools=None):
if not hasattr(complete_side_effect, 'call_count'):
if not hasattr(complete_side_effect, "call_count"):
complete_side_effect.call_count = 0
complete_side_effect.call_count += 1
@@ -148,21 +144,20 @@ def mock_llm_with_tool_call():
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "find_torrent",
"arguments": '{"media_title": "Inception"}'
"arguments": '{"media_title": "Inception"}',
},
}
}]
],
}
else:
# Second call: return final response
return {
"role": "assistant",
"content": "I found 3 torrents for Inception!"
}
return {"role": "assistant", "content": "I found 3 torrents for Inception!"}
llm.complete = Mock(side_effect=complete_side_effect)
return llm
@@ -254,10 +249,10 @@ def mock_deepseek():
# Your test code here
"""
import sys
from unittest.mock import Mock, MagicMock
from unittest.mock import Mock
# Save the original module if it exists
original_module = sys.modules.get('agent.llm.deepseek')
original_module = sys.modules.get("agent.llm.deepseek")
# Create a mock module for deepseek
mock_deepseek_module = MagicMock()
@@ -269,15 +264,15 @@ def mock_deepseek():
mock_deepseek_module.DeepSeekClient = MockDeepSeekClient
# Inject the mock
sys.modules['agent.llm.deepseek'] = mock_deepseek_module
sys.modules["agent.llm.deepseek"] = mock_deepseek_module
yield mock_deepseek_module
# Restore the original module
if original_module is not None:
sys.modules['agent.llm.deepseek'] = original_module
elif 'agent.llm.deepseek' in sys.modules:
del sys.modules['agent.llm.deepseek']
sys.modules["agent.llm.deepseek"] = original_module
elif "agent.llm.deepseek" in sys.modules:
del sys.modules["agent.llm.deepseek"]
@pytest.fixture

View File

@@ -1,6 +1,6 @@
"""Tests for the Agent."""
from unittest.mock import Mock, patch
from unittest.mock import Mock
from agent.agent import Agent
from infrastructure.persistence import get_memory
@@ -55,8 +55,8 @@ class TestExecuteToolCall:
"id": "call_123",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
}
"arguments": '{"folder_type": "download"}',
},
}
result = agent._execute_tool_call(tool_call)
@@ -68,10 +68,7 @@ class TestExecuteToolCall:
tool_call = {
"id": "call_123",
"function": {
"name": "unknown_tool",
"arguments": '{}'
}
"function": {"name": "unknown_tool", "arguments": "{}"},
}
result = agent._execute_tool_call(tool_call)
@@ -84,10 +81,7 @@ class TestExecuteToolCall:
tool_call = {
"id": "call_123",
"function": {
"name": "set_path_for_folder",
"arguments": '{}'
}
"function": {"name": "set_path_for_folder", "arguments": "{}"},
}
result = agent._execute_tool_call(tool_call)
@@ -102,8 +96,8 @@ class TestExecuteToolCall:
"id": "call_123",
"function": {
"name": "set_path_for_folder",
"arguments": '{"folder_name": 123}' # Wrong type
}
"arguments": '{"folder_name": 123}', # Wrong type
},
}
result = agent._execute_tool_call(tool_call)
@@ -116,10 +110,7 @@ class TestExecuteToolCall:
tool_call = {
"id": "call_123",
"function": {
"name": "list_folder",
"arguments": '{invalid json}'
}
"function": {"name": "list_folder", "arguments": "{invalid json}"},
}
result = agent._execute_tool_call(tool_call)
@@ -163,8 +154,8 @@ class TestStep:
# CRITICAL: Verify tools were passed to LLM
first_call_args = mock_llm_with_tool_call.complete.call_args_list[0]
assert first_call_args[1]['tools'] is not None, "Tools not passed to LLM!"
assert len(first_call_args[1]['tools']) > 0, "Tools list is empty!"
assert first_call_args[1]["tools"] is not None, "Tools not passed to LLM!"
assert len(first_call_args[1]["tools"]) > 0, "Tools list is empty!"
def test_step_max_iterations(self, memory, mock_llm):
"""Should stop after max iterations."""
@@ -180,19 +171,18 @@ class TestStep:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"tool_calls": [
{
"id": f"call_{call_count[0]}",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
"arguments": '{"folder_type": "download"}',
},
}
}]
],
}
else:
return {
"role": "assistant",
"content": "I couldn't complete the task."
}
return {"role": "assistant", "content": "I couldn't complete the task."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3)
@@ -251,34 +241,38 @@ class TestAgentIntegration:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
"arguments": '{"folder_type": "download"}',
},
}
}]
],
}
elif call_count[0] == 2:
# CRITICAL: Verify tool result was sent back
tool_messages = [m for m in messages if m.get('role') == 'tool']
tool_messages = [m for m in messages if m.get("role") == "tool"]
assert len(tool_messages) > 0, "Tool result not sent back to LLM!"
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"tool_calls": [
{
"id": "call_2",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "movie"}'
"arguments": '{"folder_type": "movie"}',
},
}
}]
],
}
else:
return {
"role": "assistant",
"content": "I listed both folders for you."
"content": "I listed both folders for you.",
}
mock_llm.complete = Mock(side_effect=mock_complete)

View File

@@ -1,7 +1,9 @@
"""Edge case tests for the Agent."""
import pytest
from unittest.mock import Mock
import pytest
from agent.agent import Agent
from infrastructure.persistence import get_memory
@@ -15,19 +17,14 @@ class TestExecuteToolCallEdgeCases:
# Mock a tool that returns None
from agent.registry import Tool
agent.tools["test_tool"] = Tool(
name="test_tool",
description="Test",
func=lambda: None,
parameters={}
name="test_tool", description="Test", func=lambda: None, parameters={}
)
tool_call = {
"id": "call_123",
"function": {
"name": "test_tool",
"arguments": '{}'
}
"function": {"name": "test_tool", "arguments": "{}"},
}
result = agent._execute_tool_call(tool_call)
@@ -38,22 +35,17 @@ class TestExecuteToolCallEdgeCases:
agent = Agent(llm=mock_llm)
from agent.registry import Tool
def raise_interrupt():
raise KeyboardInterrupt()
agent.tools["test_tool"] = Tool(
name="test_tool",
description="Test",
func=raise_interrupt,
parameters={}
name="test_tool", description="Test", func=raise_interrupt, parameters={}
)
tool_call = {
"id": "call_123",
"function": {
"name": "test_tool",
"arguments": '{}'
}
"function": {"name": "test_tool", "arguments": "{}"},
}
with pytest.raises(KeyboardInterrupt):
@@ -68,8 +60,8 @@ class TestExecuteToolCallEdgeCases:
"id": "call_123",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download", "extra_arg": "ignored"}'
}
"arguments": '{"folder_type": "download", "extra_arg": "ignored"}',
},
}
result = agent._execute_tool_call(tool_call)
@@ -84,8 +76,8 @@ class TestExecuteToolCallEdgeCases:
"id": "call_123",
"function": {
"name": "get_torrent_by_index",
"arguments": '{"index": "not an int"}'
}
"arguments": '{"index": "not an int"}',
},
}
result = agent._execute_tool_call(tool_call)
@@ -115,11 +107,9 @@ class TestStepEdgeCases:
def test_step_with_unicode_input(self, memory, mock_llm):
"""Should handle unicode input."""
def mock_complete(messages, tools=None):
return {
"role": "assistant",
"content": "日本語の応答"
}
return {"role": "assistant", "content": "日本語の応答"}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
@@ -130,11 +120,9 @@ class TestStepEdgeCases:
def test_step_llm_returns_empty(self, memory, mock_llm):
"""Should handle LLM returning empty string."""
def mock_complete(messages, tools=None):
return {
"role": "assistant",
"content": ""
}
return {"role": "assistant", "content": ""}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
@@ -161,18 +149,17 @@ class TestStepEdgeCases:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"tool_calls": [
{
"id": f"call_{call_count[0]}",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
"arguments": '{"folder_type": "download"}',
},
}
}]
}
return {
"role": "assistant",
"content": "Done looping"
],
}
return {"role": "assistant", "content": "Done looping"}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3)
@@ -212,11 +199,13 @@ class TestStepEdgeCases:
def test_step_with_active_downloads(self, memory, mock_llm):
"""Should include active downloads in context."""
memory.episodic.add_active_download({
memory.episodic.add_active_download(
{
"task_id": "123",
"name": "Movie.mkv",
"progress": 50,
})
}
)
agent = Agent(llm=mock_llm)
response = agent.step("Hello")
@@ -264,18 +253,17 @@ class TestAgentConcurrencyEdgeCases:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "set_path_for_folder",
"arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}'
"arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}',
},
}
}]
}
return {
"role": "assistant",
"content": "Path set successfully."
],
}
return {"role": "assistant", "content": "Path set successfully."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
@@ -299,18 +287,17 @@ class TestAgentErrorRecovery:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
"arguments": '{"folder_type": "download"}',
},
}
}]
}
return {
"role": "assistant",
"content": "The folder is not configured."
],
}
return {"role": "assistant", "content": "The folder is not configured."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
@@ -329,18 +316,17 @@ class TestAgentErrorRecovery:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "set_path_for_folder",
"arguments": '{}' # Missing required args
"arguments": "{}", # Missing required args
},
}
}]
}
return {
"role": "assistant",
"content": "Error occurred."
],
}
return {"role": "assistant", "content": "Error occurred."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
@@ -360,18 +346,17 @@ class TestAgentErrorRecovery:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"tool_calls": [
{
"id": f"call_{call_count[0]}",
"function": {
"name": "set_path_for_folder",
"arguments": '{}' # Missing required args - will error
"arguments": "{}", # Missing required args - will error
},
}
}]
}
return {
"role": "assistant",
"content": "All attempts failed."
],
}
return {"role": "assistant", "content": "All attempts failed."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3)

View File

@@ -1,6 +1,7 @@
"""Tests for FastAPI endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from unittest.mock import patch
from fastapi.testclient import TestClient
@@ -10,6 +11,7 @@ class TestHealthEndpoint:
def test_health_check(self, memory):
"""Should return healthy status."""
from app import app
client = TestClient(app)
response = client.get("/health")
@@ -24,6 +26,7 @@ class TestModelsEndpoint:
def test_list_models(self, memory):
"""Should return model list."""
from app import app
client = TestClient(app)
response = client.get("/v1/models")
@@ -41,6 +44,7 @@ class TestMemoryEndpoints:
def test_get_memory_state(self, memory):
"""Should return full memory state."""
from app import app
client = TestClient(app)
response = client.get("/memory/state")
@@ -54,6 +58,7 @@ class TestMemoryEndpoints:
def test_get_search_results_empty(self, memory):
"""Should return empty when no search results."""
from app import app
client = TestClient(app)
response = client.get("/memory/episodic/search-results")
@@ -65,6 +70,7 @@ class TestMemoryEndpoints:
def test_get_search_results_with_data(self, memory_with_search_results):
"""Should return search results when available."""
from app import app
client = TestClient(app)
response = client.get("/memory/episodic/search-results")
@@ -78,6 +84,7 @@ class TestMemoryEndpoints:
def test_clear_session(self, memory_with_search_results):
"""Should clear session memories."""
from app import app
client = TestClient(app)
response = client.post("/memory/clear-session")
@@ -96,14 +103,18 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_success(self, memory):
"""Should return chat completion."""
from app import app
# Patch the agent's step method directly
with patch("app.agent.step", return_value="Hello! How can I help?"):
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
})
},
)
assert response.status_code == 200
data = response.json()
@@ -113,12 +124,16 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_no_user_message(self, memory):
"""Should return error if no user message."""
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "system", "content": "You are helpful"}],
})
},
)
assert response.status_code == 422
detail = response.json()["detail"]
@@ -132,18 +147,23 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_empty_messages(self, memory):
"""Should return error for empty messages."""
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [],
})
},
)
assert response.status_code == 422
def test_chat_completion_invalid_json(self, memory):
"""Should return error for invalid JSON."""
from app import app
client = TestClient(app)
response = client.post(
@@ -157,14 +177,18 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_streaming(self, memory):
"""Should support streaming mode."""
from app import app
with patch("app.agent.step", return_value="Streaming response"):
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
})
},
)
assert response.status_code == 200
assert "text/event-stream" in response.headers["content-type"]
@@ -172,17 +196,21 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_extracts_last_user_message(self, memory):
"""Should use last user message."""
from app import app
with patch("app.agent.step", return_value="Response") as mock_step:
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [
{"role": "user", "content": "First message"},
{"role": "assistant", "content": "Response"},
{"role": "user", "content": "Second message"},
],
})
},
)
assert response.status_code == 200
# Verify the agent received the last user message
@@ -191,13 +219,17 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_response_format(self, memory):
"""Should return OpenAI-compatible format."""
from app import app
with patch("app.agent.step", return_value="Test response"):
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Test"}],
})
},
)
data = response.json()
assert "id" in data

View File

@@ -1,7 +1,7 @@
"""Edge case tests for FastAPI endpoints."""
import pytest
import json
from unittest.mock import Mock, patch, MagicMock
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
@@ -10,43 +10,46 @@ class TestChatCompletionsEdgeCases:
def test_very_long_message(self, memory):
"""Should handle very long user message."""
from app import app, agent
from app import agent, app
# Patch the agent's LLM directly
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "Response"
}
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
long_message = "x" * 100000
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": long_message}],
})
},
)
assert response.status_code == 200
def test_unicode_message(self, memory):
"""Should handle unicode in message."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "日本語の応答"
"content": "日本語の応答",
}
agent.llm = mock_llm
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}],
})
},
)
assert response.status_code == 200
content = response.json()["choices"][0]["message"]["content"]
@@ -54,22 +57,22 @@ class TestChatCompletionsEdgeCases:
def test_special_characters_in_message(self, memory):
"""Should handle special characters."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "Response"
}
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
special_message = 'Test with "quotes" and \\backslash and \n newline'
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": special_message}],
})
},
)
assert response.status_code == 200
@@ -81,12 +84,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": ""}],
})
},
)
# Empty content should be rejected
assert response.status_code == 422
@@ -98,12 +105,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": None}],
})
},
)
assert response.status_code == 422
@@ -114,12 +125,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user"}], # No content
})
},
)
# May accept or reject depending on validation
assert response.status_code in [200, 400, 422]
@@ -131,12 +146,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"content": "Hello"}], # No role
})
},
)
# Should reject or accept depending on validation
assert response.status_code in [200, 400, 422]
@@ -149,25 +168,26 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "invalid_role", "content": "Hello"}],
})
},
)
# Should reject or ignore invalid role
assert response.status_code in [200, 400, 422]
def test_many_messages(self, memory):
"""Should handle many messages in conversation."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "Response"
}
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
@@ -178,10 +198,13 @@ class TestChatCompletionsEdgeCases:
messages.append({"role": "assistant", "content": f"Response {i}"})
messages.append({"role": "user", "content": "Final message"})
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": messages,
})
},
)
assert response.status_code == 200
@@ -192,15 +215,19 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [
{"role": "system", "content": "You are helpful"},
{"role": "system", "content": "Be concise"},
],
})
},
)
assert response.status_code == 422
@@ -211,14 +238,18 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [
{"role": "assistant", "content": "Hello"},
],
})
},
)
assert response.status_code == 422
@@ -229,12 +260,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": "not an array",
})
},
)
assert response.status_code == 422
# Pydantic validation error
@@ -246,66 +281,70 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": ["not an object", 123, None],
})
},
)
assert response.status_code == 422
# Pydantic validation error
def test_extra_fields_in_request(self, memory):
"""Should ignore extra fields in request."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "Response"
}
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
"extra_field": "should be ignored",
"temperature": 0.7,
"max_tokens": 100,
})
},
)
assert response.status_code == 200
def test_streaming_with_tool_call(self, memory, real_folder):
"""Should handle streaming with tool execution."""
from app import app, agent
from app import agent, app
from infrastructure.persistence import get_memory
mem = get_memory()
mem.ltm.set_config("download_folder", str(real_folder["downloads"]))
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
if call_count[0] == 1:
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}'
"arguments": '{"folder_type": "download"}',
},
}
}]
}
return {
"role": "assistant",
"content": "Listed the folder."
],
}
return {"role": "assistant", "content": "Listed the folder."}
mock_llm = Mock()
mock_llm.complete = Mock(side_effect=mock_complete)
@@ -313,51 +352,57 @@ class TestChatCompletionsEdgeCases:
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "List downloads"}],
"stream": True,
})
},
)
assert response.status_code == 200
def test_concurrent_requests_simulation(self, memory):
"""Should handle rapid sequential requests."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "Response"
}
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
for i in range(10):
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": f"Request {i}"}],
})
},
)
assert response.status_code == 200
def test_llm_returns_json_in_response(self, memory):
"""Should handle LLM returning JSON in text response."""
from app import app, agent
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": '{"result": "some data", "count": 5}'
"content": '{"result": "some data", "count": 5}',
}
agent.llm = mock_llm
client = TestClient(app)
response = client.post("/v1/chat/completions", json={
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Give me JSON"}],
})
},
)
assert response.status_code == 200
content = response.json()["choices"][0]["message"]["content"]
@@ -425,6 +470,7 @@ class TestMemoryEndpointsEdgeCases:
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
# Clear multiple times
@@ -459,6 +505,7 @@ class TestHealthEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
response = client.get("/health")
@@ -471,6 +518,7 @@ class TestHealthEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
response = client.get("/health?extra=param&another=value")
@@ -486,6 +534,7 @@ class TestModelsEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
response = client.get("/v1/models")

View File

@@ -1,9 +1,9 @@
"""Critical tests for configuration validation."""
import pytest
import os
from agent.config import Settings, ConfigurationError
import pytest
from agent.config import ConfigurationError, Settings
class TestConfigValidation:
@@ -86,8 +86,7 @@ class TestConfigChecks:
def test_is_deepseek_configured_with_key(self):
"""Verify is_deepseek_configured returns True with API key."""
settings = Settings(
deepseek_api_key="test-key",
deepseek_base_url="https://api.test.com"
deepseek_api_key="test-key", deepseek_base_url="https://api.test.com"
)
assert settings.is_deepseek_configured() is True
@@ -95,8 +94,7 @@ class TestConfigChecks:
def test_is_deepseek_configured_without_key(self):
"""Verify is_deepseek_configured returns False without API key."""
settings = Settings(
deepseek_api_key="",
deepseek_base_url="https://api.test.com"
deepseek_api_key="", deepseek_base_url="https://api.test.com"
)
assert settings.is_deepseek_configured() is False
@@ -110,18 +108,14 @@ class TestConfigChecks:
def test_is_tmdb_configured_with_key(self):
"""Verify is_tmdb_configured returns True with API key."""
settings = Settings(
tmdb_api_key="test-key",
tmdb_base_url="https://api.test.com"
tmdb_api_key="test-key", tmdb_base_url="https://api.test.com"
)
assert settings.is_tmdb_configured() is True
def test_is_tmdb_configured_without_key(self):
"""Verify is_tmdb_configured returns False without API key."""
settings = Settings(
tmdb_api_key="",
tmdb_base_url="https://api.test.com"
)
settings = Settings(tmdb_api_key="", tmdb_base_url="https://api.test.com")
assert settings.is_tmdb_configured() is False

View File

@@ -1,12 +1,14 @@
"""Edge case tests for configuration and parameters."""
import pytest
import os
from unittest.mock import patch
from agent.config import Settings, ConfigurationError
import pytest
from agent.config import ConfigurationError, Settings
from agent.parameters import (
ParameterSchema,
REQUIRED_PARAMETERS,
ParameterSchema,
format_parameters_for_prompt,
get_missing_required_parameters,
)
@@ -110,19 +112,27 @@ class TestSettingsEdgeCases:
def test_http_url_accepted(self):
"""Should accept http:// URLs."""
with patch.dict(os.environ, {
with patch.dict(
os.environ,
{
"DEEPSEEK_BASE_URL": "http://localhost:8080",
"TMDB_BASE_URL": "http://localhost:3000",
}, clear=True):
},
clear=True,
):
settings = Settings()
assert settings.deepseek_base_url == "http://localhost:8080"
def test_https_url_accepted(self):
"""Should accept https:// URLs."""
with patch.dict(os.environ, {
with patch.dict(
os.environ,
{
"DEEPSEEK_BASE_URL": "https://api.example.com",
"TMDB_BASE_URL": "https://api.example.com",
}, clear=True):
},
clear=True,
):
settings = Settings()
assert settings.deepseek_base_url == "https://api.example.com"

View File

@@ -1,18 +1,17 @@
"""Tests for the Memory system."""
import pytest
import json
from datetime import datetime
from pathlib import Path
import pytest
from infrastructure.persistence import (
Memory,
LongTermMemory,
ShortTermMemory,
EpisodicMemory,
init_memory,
LongTermMemory,
Memory,
ShortTermMemory,
get_memory,
set_memory,
has_memory,
init_memory,
)
from infrastructure.persistence.context import _memory_ctx
@@ -23,11 +22,12 @@ def is_iso_format(s: str) -> bool:
return False
try:
# Attempt to parse the string as an ISO 8601 timestamp
datetime.fromisoformat(s.replace('Z', '+00:00'))
datetime.fromisoformat(s.replace("Z", "+00:00"))
return True
except (ValueError, TypeError):
return False
class TestLongTermMemory:
"""Tests for LongTermMemory."""
@@ -116,12 +116,18 @@ class TestLongTermMemory:
assert data["config"]["key"] == "value"
def test_from_dict(self):
data = {"config": {"download_folder": "/downloads"}, "preferences": {"preferred_quality": "4K"}, "library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]}, "following": []}
data = {
"config": {"download_folder": "/downloads"},
"preferences": {"preferred_quality": "4K"},
"library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]},
"following": [],
}
ltm = LongTermMemory.from_dict(data)
assert ltm.get_config("download_folder") == "/downloads"
assert ltm.preferences["preferred_quality"] == "4K"
assert len(ltm.library["movies"]) == 1
class TestShortTermMemory:
"""Tests for ShortTermMemory."""
@@ -162,6 +168,7 @@ class TestShortTermMemory:
assert stm.conversation_history == []
assert stm.language == "en"
class TestEpisodicMemory:
"""Tests for EpisodicMemory."""
@@ -192,6 +199,7 @@ class TestEpisodicMemory:
assert result is not None
assert result["name"] == "Result 2"
class TestMemory:
"""Tests for the Memory manager."""
@@ -217,11 +225,10 @@ class TestMemory:
assert memory.stm.conversation_history == []
assert memory.episodic.recent_errors == []
class TestMemoryContext:
"""Tests for memory context functions."""
def test_get_memory_not_initialized(self):
_memory_ctx.set(None)
with pytest.raises(RuntimeError, match="Memory not initialized"):

View File

@@ -1,18 +1,17 @@
"""Edge case tests for the Memory system."""
import pytest
import json
import os
from pathlib import Path
from datetime import datetime
from unittest.mock import patch, mock_open
import pytest
from infrastructure.persistence import (
Memory,
LongTermMemory,
ShortTermMemory,
EpisodicMemory,
init_memory,
LongTermMemory,
Memory,
ShortTermMemory,
get_memory,
init_memory,
set_memory,
)
from infrastructure.persistence.context import _memory_ctx
@@ -529,7 +528,6 @@ class TestMemoryContextEdgeCases:
def test_context_isolation(self, temp_dir):
"""Context should be isolated per context."""
import asyncio
from contextvars import copy_context
_memory_ctx.set(None)

View File

@@ -1,10 +1,8 @@
"""Critical tests for prompt builder - Tests that would have caught bugs."""
import pytest
from agent.registry import make_tools
from agent.prompts import PromptBuilder
from infrastructure.persistence import get_memory
from agent.registry import make_tools
class TestPromptBuilderToolsInjection:
@@ -18,7 +16,9 @@ class TestPromptBuilderToolsInjection:
# Verify each tool is mentioned
for tool_name in tools.keys():
assert tool_name in prompt, f"Tool {tool_name} not mentioned in system prompt"
assert (
tool_name in prompt
), f"Tool {tool_name} not mentioned in system prompt"
def test_tools_spec_contains_all_registered_tools(self):
"""CRITICAL: Verify build_tools_spec() returns all tools."""
@@ -26,7 +26,7 @@ class TestPromptBuilderToolsInjection:
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
spec_names = {spec['function']['name'] for spec in specs}
spec_names = {spec["function"]["name"] for spec in specs}
tool_names = set(tools.keys())
assert spec_names == tool_names, f"Missing tools: {tool_names - spec_names}"
@@ -46,12 +46,12 @@ class TestPromptBuilderToolsInjection:
specs = builder.build_tools_spec()
for spec in specs:
assert 'type' in spec
assert spec['type'] == 'function'
assert 'function' in spec
assert 'name' in spec['function']
assert 'description' in spec['function']
assert 'parameters' in spec['function']
assert "type" in spec
assert spec["type"] == "function"
assert "function" in spec
assert "name" in spec["function"]
assert "description" in spec["function"]
assert "parameters" in spec["function"]
class TestPromptBuilderMemoryContext:
@@ -92,11 +92,9 @@ class TestPromptBuilderMemoryContext:
tools = make_tools()
builder = PromptBuilder(tools)
memory.episodic.add_active_download({
"task_id": "123",
"name": "Test Movie",
"progress": 50
})
memory.episodic.add_active_download(
{"task_id": "123", "name": "Test Movie", "progress": 50}
)
prompt = builder.build_system_prompt()

View File

@@ -1,10 +1,8 @@
"""Edge case tests for PromptBuilder."""
import pytest
import json
from agent.prompts import PromptBuilder
from agent.registry import make_tools
from infrastructure.persistence import get_memory
class TestPromptBuilderEdgeCases:
@@ -93,11 +91,13 @@ class TestPromptBuilderEdgeCases:
def test_prompt_with_many_active_downloads(self, memory):
"""Should limit displayed active downloads."""
for i in range(20):
memory.episodic.add_active_download({
memory.episodic.add_active_download(
{
"task_id": str(i),
"name": f"Download {i}",
"progress": i * 5,
})
}
)
tools = make_tools()
builder = PromptBuilder(tools)
@@ -136,12 +136,15 @@ class TestPromptBuilderEdgeCases:
def test_prompt_with_complex_workflow(self, memory):
"""Should handle complex workflow state."""
memory.stm.start_workflow("download", {
memory.stm.start_workflow(
"download",
{
"title": "Test Movie",
"year": 2024,
"quality": "1080p",
"nested": {"deep": {"value": "test"}},
})
},
)
memory.stm.update_workflow_stage("searching_torrents")
tools = make_tools()
@@ -313,11 +316,14 @@ class TestFormatEpisodicContextEdgeCases:
def test_format_with_search_results_none_names(self, memory):
"""Should handle results with None names."""
memory.episodic.store_search_results("test", [
memory.episodic.store_search_results(
"test",
[
{"name": None},
{"title": None},
{},
])
],
)
tools = make_tools()
builder = PromptBuilder(tools)

View File

@@ -1,10 +1,11 @@
"""Critical tests for tool registry - Tests that would have caught bugs."""
import pytest
import inspect
from agent.registry import make_tools, _create_tool_from_function, Tool
import pytest
from agent.prompts import PromptBuilder
from agent.registry import Tool, _create_tool_from_function, make_tools
class TestToolSpecFormat:
@@ -22,22 +23,25 @@ class TestToolSpecFormat:
for spec in specs:
# OpenAI format requires these fields
assert spec['type'] == 'function', f"Tool type must be 'function', got {spec.get('type')}"
assert 'function' in spec, "Tool spec missing 'function' key"
assert (
spec["type"] == "function"
), f"Tool type must be 'function', got {spec.get('type')}"
assert "function" in spec, "Tool spec missing 'function' key"
func = spec['function']
assert 'name' in func, "Function missing 'name'"
assert 'description' in func, "Function missing 'description'"
assert 'parameters' in func, "Function missing 'parameters'"
func = spec["function"]
assert "name" in func, "Function missing 'name'"
assert "description" in func, "Function missing 'description'"
assert "parameters" in func, "Function missing 'parameters'"
params = func['parameters']
assert params['type'] == 'object', "Parameters type must be 'object'"
assert 'properties' in params, "Parameters missing 'properties'"
assert 'required' in params, "Parameters missing 'required'"
assert isinstance(params['required'], list), "Required must be a list"
params = func["parameters"]
assert params["type"] == "object", "Parameters type must be 'object'"
assert "properties" in params, "Parameters missing 'properties'"
assert "required" in params, "Parameters missing 'required'"
assert isinstance(params["required"], list), "Required must be a list"
def test_tool_parameters_match_function_signature(self):
"""CRITICAL: Verify generated parameters match function signature."""
def test_func(name: str, age: int, active: bool = True):
"""Test function with typed parameters."""
return {"status": "ok"}
@@ -45,14 +49,16 @@ class TestToolSpecFormat:
tool = _create_tool_from_function(test_func)
# Verify types are correctly mapped
assert tool.parameters['properties']['name']['type'] == 'string'
assert tool.parameters['properties']['age']['type'] == 'integer'
assert tool.parameters['properties']['active']['type'] == 'boolean'
assert tool.parameters["properties"]["name"]["type"] == "string"
assert tool.parameters["properties"]["age"]["type"] == "integer"
assert tool.parameters["properties"]["active"]["type"] == "boolean"
# Verify required vs optional
assert 'name' in tool.parameters['required'], "name should be required"
assert 'age' in tool.parameters['required'], "age should be required"
assert 'active' not in tool.parameters['required'], "active has default, should not be required"
assert "name" in tool.parameters["required"], "name should be required"
assert "age" in tool.parameters["required"], "age should be required"
assert (
"active" not in tool.parameters["required"]
), "active has default, should not be required"
def test_all_registered_tools_are_callable(self):
"""CRITICAL: Verify all registered tools are actually callable."""
@@ -76,7 +82,7 @@ class TestToolSpecFormat:
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
spec_names = {spec['function']['name'] for spec in specs}
spec_names = {spec["function"]["name"] for spec in specs}
tool_names = set(tools.keys())
missing = tool_names - spec_names
@@ -88,6 +94,7 @@ class TestToolSpecFormat:
def test_tool_description_extracted_from_docstring(self):
"""Verify tool description is extracted from function docstring."""
def test_func(param: str):
"""This is the description.
@@ -102,6 +109,7 @@ class TestToolSpecFormat:
def test_tool_without_docstring_uses_function_name(self):
"""Verify tool without docstring uses function name as description."""
def test_func_no_doc(param: str):
return {}
@@ -116,23 +124,25 @@ class TestToolSpecFormat:
specs = builder.build_tools_spec()
for spec in specs:
params = spec['function']['parameters']
properties = params.get('properties', {})
params = spec["function"]["parameters"]
properties = params.get("properties", {})
for param_name, param_spec in properties.items():
assert 'description' in param_spec, \
f"Parameter {param_name} in {spec['function']['name']} missing description"
assert (
"description" in param_spec
), f"Parameter {param_name} in {spec['function']['name']} missing description"
def test_required_parameters_are_marked_correctly(self):
"""Verify required parameters are correctly identified."""
def func_with_optional(required: str, optional: int = 5):
return {}
tool = _create_tool_from_function(func_with_optional)
assert 'required' in tool.parameters['required']
assert 'optional' not in tool.parameters['required']
assert len(tool.parameters['required']) == 1
assert "required" in tool.parameters["required"]
assert "optional" not in tool.parameters["required"]
assert len(tool.parameters["required"]) == 1
class TestToolRegistry:
@@ -195,6 +205,7 @@ class TestToolDataclass:
def test_tool_creation(self):
"""Verify Tool can be created with all fields."""
def dummy_func():
return {}
@@ -202,7 +213,7 @@ class TestToolDataclass:
name="test_tool",
description="Test description",
func=dummy_func,
parameters={"type": "object", "properties": {}, "required": []}
parameters={"type": "object", "properties": {}, "required": []},
)
assert tool.name == "test_tool"
@@ -212,12 +223,13 @@ class TestToolDataclass:
def test_tool_parameters_structure(self):
"""Verify Tool parameters have correct structure."""
def dummy_func(arg: str):
return {}
tool = _create_tool_from_function(dummy_func)
assert 'type' in tool.parameters
assert 'properties' in tool.parameters
assert 'required' in tool.parameters
assert tool.parameters['type'] == 'object'
assert "type" in tool.parameters
assert "properties" in tool.parameters
assert "required" in tool.parameters
assert tool.parameters["type"] == "object"

View File

@@ -1,6 +1,7 @@
"""Edge case tests for tool registry."""
import pytest
from unittest.mock import Mock
from agent.registry import Tool, make_tools
@@ -182,7 +183,9 @@ class TestMakeToolsEdgeCases:
params = tool.parameters
if "required" in params and "properties" in params:
for req in params["required"]:
assert req in params["properties"], f"Required param {req} not in properties for {tool.name}"
assert (
req in params["properties"]
), f"Required param {req} not in properties for {tool.name}"
def test_make_tools_descriptions_not_empty(self, memory):
"""Should have non-empty descriptions."""
@@ -233,7 +236,9 @@ class TestMakeToolsEdgeCases:
if "properties" in tool.parameters:
for prop_name, prop_schema in tool.parameters["properties"].items():
if "type" in prop_schema:
assert prop_schema["type"] in valid_types, f"Invalid type for {tool.name}.{prop_name}"
assert (
prop_schema["type"] in valid_types
), f"Invalid type for {tool.name}.{prop_name}"
def test_make_tools_enum_values(self, memory):
"""Should have valid enum values."""

View File

@@ -1,19 +1,18 @@
"""Tests for JSON repositories."""
import pytest
from datetime import datetime
from infrastructure.persistence.json import (
JsonMovieRepository,
JsonTVShowRepository,
JsonSubtitleRepository,
)
from domain.movies.entities import Movie
from domain.movies.value_objects import MovieTitle, ReleaseYear, Quality
from domain.tv_shows.entities import TVShow
from domain.tv_shows.value_objects import ShowStatus
from domain.movies.value_objects import MovieTitle, Quality, ReleaseYear
from domain.shared.value_objects import FilePath, FileSize, ImdbId
from domain.subtitles.entities import Subtitle
from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset
from domain.shared.value_objects import ImdbId, FilePath, FileSize
from domain.tv_shows.entities import TVShow
from domain.tv_shows.value_objects import ShowStatus
from infrastructure.persistence.json import (
JsonMovieRepository,
JsonSubtitleRepository,
JsonTVShowRepository,
)
class TestJsonMovieRepository:
@@ -224,7 +223,9 @@ class TestJsonTVShowRepository:
"""Should preserve show status."""
repo = JsonTVShowRepository()
for i, status in enumerate([ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]):
for i, status in enumerate(
[ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]
):
show = TVShow(
imdb_id=ImdbId(f"tt{i+1000000:07d}"),
title=f"Show {status.value}",
@@ -294,18 +295,22 @@ class TestJsonSubtitleRepository:
def test_find_by_media_with_language_filter(self, memory):
"""Should filter by language."""
repo = JsonSubtitleRepository()
repo.save(Subtitle(
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/en.srt"),
))
repo.save(Subtitle(
)
)
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.FRENCH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/fr.srt"),
))
)
)
results = repo.find_by_media(ImdbId("tt1375666"), language=Language.FRENCH)
@@ -315,22 +320,26 @@ class TestJsonSubtitleRepository:
def test_find_by_media_with_episode_filter(self, memory):
"""Should filter by season/episode."""
repo = JsonSubtitleRepository()
repo.save(Subtitle(
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt0944947"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/s01e01.srt"),
season_number=1,
episode_number=1,
))
repo.save(Subtitle(
)
)
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt0944947"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/s01e02.srt"),
season_number=1,
episode_number=2,
))
)
)
results = repo.find_by_media(
ImdbId("tt0944947"),

View File

@@ -21,21 +21,27 @@ def create_mock_response(status_code, json_data=None, text=None):
class TestFindMediaImdbId:
"""Tests for find_media_imdb_id tool."""
@patch('infrastructure.api.tmdb.client.requests.get')
@patch("infrastructure.api.tmdb.client.requests.get")
def test_success(self, mock_get, memory):
"""Should return movie info on success."""
# Mock HTTP responses
def mock_get_side_effect(url, **kwargs):
if "search" in url:
return create_mock_response(200, json_data={
"results": [{
return create_mock_response(
200,
json_data={
"results": [
{
"id": 27205,
"title": "Inception",
"release_date": "2010-07-16",
"overview": "A thief...",
"media_type": "movie"
}]
})
"media_type": "movie",
}
]
},
)
elif "external_ids" in url:
return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
@@ -50,19 +56,25 @@ class TestFindMediaImdbId:
# Verify HTTP calls
assert mock_get.call_count == 2
@patch('infrastructure.api.tmdb.client.requests.get')
@patch("infrastructure.api.tmdb.client.requests.get")
def test_stores_in_stm(self, mock_get, memory):
"""Should store result in STM on success."""
def mock_get_side_effect(url, **kwargs):
if "search" in url:
return create_mock_response(200, json_data={
"results": [{
return create_mock_response(
200,
json_data={
"results": [
{
"id": 27205,
"title": "Inception",
"release_date": "2010-07-16",
"media_type": "movie"
}]
})
"media_type": "movie",
}
]
},
)
elif "external_ids" in url:
return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
@@ -76,7 +88,7 @@ class TestFindMediaImdbId:
assert entity["title"] == "Inception"
assert mem.stm.current_topic == "searching_media"
@patch('infrastructure.api.tmdb.client.requests.get')
@patch("infrastructure.api.tmdb.client.requests.get")
def test_not_found(self, mock_get, memory):
"""Should return error when not found."""
mock_get.return_value = create_mock_response(200, json_data={"results": []})
@@ -86,7 +98,7 @@ class TestFindMediaImdbId:
assert result["status"] == "error"
assert result["error"] == "not_found"
@patch('infrastructure.api.tmdb.client.requests.get')
@patch("infrastructure.api.tmdb.client.requests.get")
def test_does_not_store_on_error(self, mock_get, memory):
"""Should not store in STM on error."""
mock_get.return_value = create_mock_response(200, json_data={"results": []})
@@ -100,27 +112,30 @@ class TestFindMediaImdbId:
class TestFindTorrent:
"""Tests for find_torrent tool."""
@patch('infrastructure.api.knaben.client.requests.post')
@patch("infrastructure.api.knaben.client.requests.post")
def test_success(self, mock_post, memory):
"""Should return torrents on success."""
mock_post.return_value = create_mock_response(200, json_data={
mock_post.return_value = create_mock_response(
200,
json_data={
"hits": [
{
"title": "Torrent 1",
"seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=...",
"size": "2.5 GB"
"size": "2.5 GB",
},
{
"title": "Torrent 2",
"seeders": 50,
"leechers": 5,
"magnetUrl": "magnet:?xt=...",
"size": "1.8 GB"
}
"size": "1.8 GB",
},
]
})
},
)
result = api_tools.find_torrent("Inception 1080p")
@@ -128,21 +143,26 @@ class TestFindTorrent:
assert len(result["torrents"]) == 2
# Verify HTTP payload
payload = mock_post.call_args[1]['json']
assert payload['query'] == "Inception 1080p"
payload = mock_post.call_args[1]["json"]
assert payload["query"] == "Inception 1080p"
@patch('infrastructure.api.knaben.client.requests.post')
@patch("infrastructure.api.knaben.client.requests.post")
def test_stores_in_episodic(self, mock_post, memory):
"""Should store results in episodic memory."""
mock_post.return_value = create_mock_response(200, json_data={
"hits": [{
mock_post.return_value = create_mock_response(
200,
json_data={
"hits": [
{
"title": "Torrent 1",
"seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=...",
"size": "2.5 GB"
}]
})
"size": "2.5 GB",
}
]
},
)
api_tools.find_torrent("Inception")
@@ -151,16 +171,37 @@ class TestFindTorrent:
assert mem.episodic.last_search_results["query"] == "Inception"
assert mem.stm.current_topic == "selecting_torrent"
@patch('infrastructure.api.knaben.client.requests.post')
@patch("infrastructure.api.knaben.client.requests.post")
def test_results_have_indexes(self, mock_post, memory):
"""Should add indexes to results."""
mock_post.return_value = create_mock_response(200, json_data={
mock_post.return_value = create_mock_response(
200,
json_data={
"hits": [
{"title": "Torrent 1", "seeders": 100, "leechers": 10, "magnetUrl": "magnet:?xt=1", "size": "1GB"},
{"title": "Torrent 2", "seeders": 50, "leechers": 5, "magnetUrl": "magnet:?xt=2", "size": "2GB"},
{"title": "Torrent 3", "seeders": 25, "leechers": 2, "magnetUrl": "magnet:?xt=3", "size": "3GB"}
{
"title": "Torrent 1",
"seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=1",
"size": "1GB",
},
{
"title": "Torrent 2",
"seeders": 50,
"leechers": 5,
"magnetUrl": "magnet:?xt=2",
"size": "2GB",
},
{
"title": "Torrent 3",
"seeders": 25,
"leechers": 2,
"magnetUrl": "magnet:?xt=3",
"size": "3GB",
},
]
})
},
)
api_tools.find_torrent("Test")
@@ -170,7 +211,7 @@ class TestFindTorrent:
assert results[1]["index"] == 2
assert results[2]["index"] == 3
@patch('infrastructure.api.knaben.client.requests.post')
@patch("infrastructure.api.knaben.client.requests.post")
def test_not_found(self, mock_post, memory):
"""Should return error when no torrents found."""
mock_post.return_value = create_mock_response(200, json_data={"hits": []})
@@ -245,7 +286,7 @@ class TestAddTorrentToQbittorrent:
This is acceptable mocking because we're testing the TOOL logic, not the client.
"""
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_success(self, mock_client, memory):
"""Should add torrent successfully and update memory."""
mock_client.add_torrent.return_value = True
@@ -257,7 +298,7 @@ class TestAddTorrentToQbittorrent:
# Verify client was called correctly
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_adds_to_active_downloads(self, mock_client, memory_with_search_results):
"""Should add to active downloads on success."""
mock_client.add_torrent.return_value = True
@@ -267,9 +308,12 @@ class TestAddTorrentToQbittorrent:
# Test memory update logic
mem = get_memory()
assert len(mem.episodic.active_downloads) == 1
assert mem.episodic.active_downloads[0]["name"] == "Inception.2010.1080p.BluRay.x264"
assert (
mem.episodic.active_downloads[0]["name"]
== "Inception.2010.1080p.BluRay.x264"
)
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_sets_topic_and_ends_workflow(self, mock_client, memory):
"""Should set topic and end workflow."""
mock_client.add_torrent.return_value = True
@@ -282,10 +326,11 @@ class TestAddTorrentToQbittorrent:
assert mem.stm.current_topic == "downloading"
assert mem.stm.current_workflow is None
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_error_handling(self, mock_client, memory):
"""Should handle client errors correctly."""
from infrastructure.api.qbittorrent.exceptions import QBittorrentAPIError
mock_client.add_torrent.side_effect = QBittorrentAPIError("Connection failed")
result = api_tools.add_torrent_to_qbittorrent("magnet:?xt=...")
@@ -304,7 +349,7 @@ class TestAddTorrentByIndex:
- Error handling for edge cases
"""
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_success(self, mock_client, memory_with_search_results):
"""Should get torrent by index and add it."""
mock_client.add_torrent.return_value = True
@@ -317,7 +362,7 @@ class TestAddTorrentByIndex:
# Verify correct magnet was extracted and used
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
@patch('agent.tools.api.qbittorrent_client')
@patch("agent.tools.api.qbittorrent_client")
def test_uses_correct_magnet(self, mock_client, memory_with_search_results):
"""Should extract correct magnet from index."""
mock_client.add_torrent.return_value = True

View File

@@ -1,7 +1,8 @@
"""Edge case tests for tools."""
from unittest.mock import Mock, patch
import pytest
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
from agent.tools import api as api_tools
from agent.tools import filesystem as fs_tools
@@ -15,7 +16,10 @@ class TestFindTorrentEdgeCases:
def test_empty_query(self, mock_use_case_class, memory):
"""Should handle empty query."""
mock_response = Mock()
mock_response.to_dict.return_value = {"status": "error", "error": "invalid_query"}
mock_response.to_dict.return_value = {
"status": "error",
"error": "invalid_query",
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
@@ -28,7 +32,11 @@ class TestFindTorrentEdgeCases:
def test_very_long_query(self, mock_use_case_class, memory):
"""Should handle very long query."""
mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0}
mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
@@ -43,7 +51,11 @@ class TestFindTorrentEdgeCases:
def test_special_characters_in_query(self, mock_use_case_class, memory):
"""Should handle special characters in query."""
mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0}
mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
@@ -57,7 +69,11 @@ class TestFindTorrentEdgeCases:
def test_unicode_query(self, mock_use_case_class, memory):
"""Should handle unicode in query."""
mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0}
mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
@@ -161,7 +177,10 @@ class TestAddTorrentEdgeCases:
def test_empty_magnet_link(self, mock_use_case_class, memory):
"""Should handle empty magnet link."""
mock_response = Mock()
mock_response.to_dict.return_value = {"status": "error", "error": "empty_magnet"}
mock_response.to_dict.return_value = {
"status": "error",
"error": "empty_magnet",
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
@@ -326,7 +345,10 @@ class TestFilesystemEdgeCases:
for attempt in attempts:
result = fs_tools.list_folder("download", attempt)
# Should either be forbidden or not found
assert result.get("error") in ["forbidden", "not_found", None] or result.get("status") == "ok"
assert (
result.get("error") in ["forbidden", "not_found", None]
or result.get("status") == "ok"
)
def test_path_with_null_byte(self, memory, real_folder):
"""Should block null byte injection."""