infra: reorganized repo

This commit is contained in:
2025-12-24 07:50:09 +01:00
parent e097a13221
commit 1f88e99e8b
113 changed files with 0 additions and 0 deletions

295
tests/conftest.py Normal file
View File

@@ -0,0 +1,295 @@
"""Pytest configuration and shared fixtures."""
import sys
from pathlib import Path
# Ajouter le dossier parent (brain) au PYTHONPATH
sys.path.insert(0, str(Path(__file__).parent.parent))
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, Mock
import pytest
from infrastructure.persistence import Memory, set_memory
@pytest.fixture
def temp_dir():
"""Create a temporary directory for tests."""
dirpath = tempfile.mkdtemp()
yield Path(dirpath)
shutil.rmtree(dirpath)
@pytest.fixture
def memory(temp_dir):
"""Create a fresh Memory instance for testing."""
mem = Memory(storage_dir=str(temp_dir))
set_memory(mem)
yield mem
@pytest.fixture
def memory_with_config(memory):
"""Memory with pre-configured folders."""
memory.ltm.set_config("download_folder", "/tmp/downloads")
memory.ltm.set_config("movie_folder", "/tmp/movies")
memory.ltm.set_config("tvshow_folder", "/tmp/tvshows")
memory.ltm.set_config("torrent_folder", "/tmp/torrents")
return memory
@pytest.fixture
def memory_with_search_results(memory):
"""Memory with pre-populated search results."""
memory.episodic.store_search_results(
query="Inception 1080p",
results=[
{
"name": "Inception.2010.1080p.BluRay.x264",
"size": "2.5 GB",
"seeders": 150,
"leechers": 10,
"magnet": "magnet:?xt=urn:btih:abc123",
"tracker": "ThePirateBay",
},
{
"name": "Inception.2010.1080p.WEB-DL.x265",
"size": "1.8 GB",
"seeders": 80,
"leechers": 5,
"magnet": "magnet:?xt=urn:btih:def456",
"tracker": "1337x",
},
{
"name": "Inception.2010.720p.BluRay",
"size": "1.2 GB",
"seeders": 45,
"leechers": 2,
"magnet": "magnet:?xt=urn:btih:ghi789",
"tracker": "RARBG",
},
],
search_type="torrent",
)
return memory
@pytest.fixture
def memory_with_history(memory):
"""Memory with conversation history."""
memory.stm.add_message("user", "Hello")
memory.stm.add_message("assistant", "Hi! How can I help you?")
memory.stm.add_message("user", "Find me Inception")
memory.stm.add_message("assistant", "I found Inception (2010)...")
return memory
@pytest.fixture
def memory_with_library(memory):
"""Memory with movies in library."""
memory.ltm.library["movies"] = [
{
"imdb_id": "tt1375666",
"title": "Inception",
"release_year": 2010,
"quality": "1080p",
"file_path": "/movies/Inception.2010.1080p.mkv",
"added_at": "2024-01-15T10:30:00",
},
{
"imdb_id": "tt0816692",
"title": "Interstellar",
"release_year": 2014,
"quality": "4K",
"file_path": "/movies/Interstellar.2014.4K.mkv",
"added_at": "2024-01-16T14:20:00",
},
]
memory.ltm.library["tv_shows"] = [
{
"imdb_id": "tt0944947",
"title": "Game of Thrones",
"seasons_count": 8,
"status": "ended",
"added_at": "2024-01-10T09:00:00",
},
]
return memory
@pytest.fixture
def mock_llm():
"""Create a mock LLM client that returns OpenAI-compatible format."""
llm = Mock()
# Return OpenAI-style message dict without tool calls
def complete_func(messages, tools=None):
return {"role": "assistant", "content": "I found what you're looking for!"}
llm.complete = Mock(side_effect=complete_func)
return llm
@pytest.fixture
def mock_llm_with_tool_call():
"""Create a mock LLM that returns a tool call then a response."""
llm = Mock()
# First call returns a tool call, second returns final response
def complete_side_effect(messages, tools=None):
if not hasattr(complete_side_effect, "call_count"):
complete_side_effect.call_count = 0
complete_side_effect.call_count += 1
if complete_side_effect.call_count == 1:
# First call: return tool call
return {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "find_torrent",
"arguments": '{"media_title": "Inception"}',
},
}
],
}
else:
# Second call: return final response
return {"role": "assistant", "content": "I found 3 torrents for Inception!"}
llm.complete = Mock(side_effect=complete_side_effect)
return llm
@pytest.fixture
def mock_tmdb_client():
"""Create a mock TMDB client."""
client = Mock()
client.search_movie = Mock(
return_value=Mock(
results=[
Mock(
id=27205,
title="Inception",
release_date="2010-07-16",
overview="A thief who steals corporate secrets...",
)
]
)
)
client.get_external_ids = Mock(return_value={"imdb_id": "tt1375666"})
return client
@pytest.fixture
def mock_knaben_client():
"""Create a mock Knaben client."""
client = Mock()
client.search = Mock(
return_value=[
Mock(
title="Inception.2010.1080p.BluRay",
size="2.5 GB",
seeders=150,
leechers=10,
magnet="magnet:?xt=urn:btih:abc123",
info_hash="abc123",
tracker="TPB",
upload_date="2024-01-01",
category="Movies",
),
]
)
return client
@pytest.fixture
def mock_qbittorrent_client():
"""Create a mock qBittorrent client."""
client = Mock()
client.add_torrent = Mock(return_value=True)
client.get_torrents = Mock(return_value=[])
return client
@pytest.fixture
def real_folder(temp_dir):
"""Create a real folder structure for filesystem tests."""
downloads = temp_dir / "downloads"
movies = temp_dir / "movies"
tvshows = temp_dir / "tvshows"
downloads.mkdir()
movies.mkdir()
tvshows.mkdir()
# Create some test files
(downloads / "test_movie.mkv").touch()
(downloads / "test_series").mkdir()
(downloads / "test_series" / "episode1.mkv").touch()
return {
"root": temp_dir,
"downloads": downloads,
"movies": movies,
"tvshows": tvshows,
}
@pytest.fixture(scope="function")
def mock_deepseek():
"""
Mock DeepSeekClient for individual tests that need it.
This prevents real API calls in tests that use this fixture.
Usage:
def test_something(mock_deepseek):
# Your test code here
"""
import sys
from unittest.mock import Mock
# Save the original module if it exists
original_module = sys.modules.get("agent.llm.deepseek")
# Create a mock module for deepseek
mock_deepseek_module = MagicMock()
class MockDeepSeekClient:
def __init__(self, *args, **kwargs):
self.complete = Mock(return_value="Mocked LLM response")
mock_deepseek_module.DeepSeekClient = MockDeepSeekClient
# Inject the mock
sys.modules["agent.llm.deepseek"] = mock_deepseek_module
yield mock_deepseek_module
# Restore the original module
if original_module is not None:
sys.modules["agent.llm.deepseek"] = original_module
elif "agent.llm.deepseek" in sys.modules:
del sys.modules["agent.llm.deepseek"]
@pytest.fixture
def mock_agent_step():
"""
Fixture to easily mock the agent's step method in API tests.
Returns a context manager that patches app.agent.step.
"""
from unittest.mock import patch
def _mock_step(return_value="Mocked agent response"):
return patch("app.agent.step", return_value=return_value)
return _mock_step

283
tests/test_agent.py Normal file
View File

@@ -0,0 +1,283 @@
"""Tests for the Agent."""
from unittest.mock import Mock
from agent.agent import Agent
from infrastructure.persistence import get_memory
class TestAgentInit:
"""Tests for Agent initialization."""
def test_init(self, memory, mock_llm):
"""Should initialize agent with LLM."""
agent = Agent(llm=mock_llm)
assert agent.llm is mock_llm
assert agent.tools is not None
assert agent.prompt_builder is not None
assert agent.max_tool_iterations == 5
def test_init_custom_iterations(self, memory, mock_llm):
"""Should accept custom max iterations."""
agent = Agent(llm=mock_llm, max_tool_iterations=10)
assert agent.max_tool_iterations == 10
def test_tools_registered(self, memory, mock_llm):
"""Should register all tools."""
agent = Agent(llm=mock_llm)
expected_tools = [
"set_path_for_folder",
"list_folder",
"find_media_imdb_id",
"find_torrent",
"add_torrent_by_index",
"add_torrent_to_qbittorrent",
"get_torrent_by_index",
"set_language",
]
for tool_name in expected_tools:
assert tool_name in agent.tools
class TestExecuteToolCall:
"""Tests for _execute_tool_call method."""
def test_execute_known_tool(self, memory, mock_llm, real_folder):
"""Should execute known tool."""
agent = Agent(llm=mock_llm)
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
tool_call = {
"id": "call_123",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
}
result = agent._execute_tool_call(tool_call)
assert result["status"] == "ok"
def test_execute_unknown_tool(self, memory, mock_llm):
"""Should return error for unknown tool."""
agent = Agent(llm=mock_llm)
tool_call = {
"id": "call_123",
"function": {"name": "unknown_tool", "arguments": "{}"},
}
result = agent._execute_tool_call(tool_call)
assert result["error"] == "unknown_tool"
assert "available_tools" in result
def test_execute_with_bad_args(self, memory, mock_llm):
"""Should return error for bad arguments."""
agent = Agent(llm=mock_llm)
tool_call = {
"id": "call_123",
"function": {"name": "set_path_for_folder", "arguments": "{}"},
}
result = agent._execute_tool_call(tool_call)
assert result["error"] == "bad_args"
def test_execute_tracks_errors(self, memory, mock_llm):
"""Should track errors in episodic memory."""
agent = Agent(llm=mock_llm)
# Use invalid arguments to trigger a TypeError
tool_call = {
"id": "call_123",
"function": {
"name": "set_path_for_folder",
"arguments": '{"folder_name": 123}', # Wrong type
},
}
agent._execute_tool_call(tool_call)
mem = get_memory()
assert len(mem.episodic.recent_errors) > 0
def test_execute_with_invalid_json(self, memory, mock_llm):
"""Should handle invalid JSON arguments."""
agent = Agent(llm=mock_llm)
tool_call = {
"id": "call_123",
"function": {"name": "list_folder", "arguments": "{invalid json}"},
}
result = agent._execute_tool_call(tool_call)
assert result["error"] == "bad_args"
class TestStep:
"""Tests for step method."""
def test_step_text_response(self, memory, mock_llm):
"""Should return text response when no tool call."""
agent = Agent(llm=mock_llm)
response = agent.step("Hello")
assert response == "I found what you're looking for!"
def test_step_saves_to_history(self, memory, mock_llm):
"""Should save conversation to STM history."""
agent = Agent(llm=mock_llm)
agent.step("Hi there")
mem = get_memory()
history = mem.stm.get_recent_history(10)
assert len(history) == 2
assert history[0]["role"] == "user"
assert history[0]["content"] == "Hi there"
assert history[1]["role"] == "assistant"
def test_step_with_tool_call(self, memory, mock_llm_with_tool_call, real_folder):
"""Should execute tool and continue."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
agent = Agent(llm=mock_llm_with_tool_call)
response = agent.step("List my downloads")
assert "found" in response.lower() or "torrent" in response.lower()
assert mock_llm_with_tool_call.complete.call_count == 2
# CRITICAL: Verify tools were passed to LLM
first_call_args = mock_llm_with_tool_call.complete.call_args_list[0]
assert first_call_args[1]["tools"] is not None, "Tools not passed to LLM!"
assert len(first_call_args[1]["tools"]) > 0, "Tools list is empty!"
def test_step_max_iterations(self, memory, mock_llm):
"""Should stop after max iterations."""
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
# CRITICAL: Verify tools are passed (except on forced final call)
if call_count[0] <= 3:
assert tools is not None, f"Tools not passed on call {call_count[0]}!"
if call_count[0] <= 3:
return {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": f"call_{call_count[0]}",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
}
],
}
else:
return {"role": "assistant", "content": "I couldn't complete the task."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3)
agent.step("Do something")
assert call_count[0] == 4
def test_step_includes_history(self, memory_with_history, mock_llm):
"""Should include conversation history in prompt."""
agent = Agent(llm=mock_llm)
agent.step("New message")
call_args = mock_llm.complete.call_args[0][0]
messages_content = [m.get("content", "") for m in call_args]
assert any("Hello" in str(c) for c in messages_content)
def test_step_includes_events(self, memory, mock_llm):
"""Should include unread events in prompt."""
memory.episodic.add_background_event("download_complete", {"name": "Movie.mkv"})
agent = Agent(llm=mock_llm)
agent.step("What's new?")
call_args = mock_llm.complete.call_args[0][0]
messages_content = [m.get("content", "") for m in call_args]
assert any("download" in str(c).lower() for c in messages_content)
def test_step_saves_ltm(self, memory, mock_llm, temp_dir):
"""Should save LTM after step."""
agent = Agent(llm=mock_llm)
agent.step("Hello")
ltm_file = temp_dir / "ltm.json"
assert ltm_file.exists()
class TestAgentIntegration:
"""Integration tests for Agent."""
def test_multiple_tool_calls(self, memory, mock_llm, real_folder):
"""Should handle multiple tool calls in sequence."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
memory.ltm.set_config("movie_folder", str(real_folder["movies"]))
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
# CRITICAL: Verify tools are passed on every call
assert tools is not None, f"Tools not passed on call {call_count[0]}!"
if call_count[0] == 1:
return {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
}
],
}
elif call_count[0] == 2:
# CRITICAL: Verify tool result was sent back
tool_messages = [m for m in messages if m.get("role") == "tool"]
assert len(tool_messages) > 0, "Tool result not sent back to LLM!"
return {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_2",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "movie"}',
},
}
],
}
else:
return {
"role": "assistant",
"content": "I listed both folders for you.",
}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
agent.step("List my downloads and movies")
assert call_count[0] == 3

View File

@@ -0,0 +1,6 @@
# Tests removed - too fragile with requests.post mocking
# The critical functionality is tested in test_agent.py with simpler mocks
# Key tests that were here:
# - Tools passed to LLM on every call (now in test_agent.py)
# - Tool results sent back to LLM (covered in test_agent.py)
# - Max iterations handling (covered in test_agent.py)

View File

@@ -0,0 +1,367 @@
"""Edge case tests for the Agent."""
from unittest.mock import Mock
import pytest
from agent.agent import Agent
from infrastructure.persistence import get_memory
class TestExecuteToolCallEdgeCases:
"""Edge case tests for _execute_tool_call."""
def test_tool_returns_none(self, memory, mock_llm):
"""Should handle tool returning None."""
agent = Agent(llm=mock_llm)
# Mock a tool that returns None
from agent.registry import Tool
agent.tools["test_tool"] = Tool(
name="test_tool", description="Test", func=lambda: None, parameters={}
)
tool_call = {
"id": "call_123",
"function": {"name": "test_tool", "arguments": "{}"},
}
result = agent._execute_tool_call(tool_call)
assert result is None or isinstance(result, dict)
def test_tool_raises_keyboard_interrupt(self, memory, mock_llm):
"""Should propagate KeyboardInterrupt."""
agent = Agent(llm=mock_llm)
from agent.registry import Tool
def raise_interrupt():
raise KeyboardInterrupt()
agent.tools["test_tool"] = Tool(
name="test_tool", description="Test", func=raise_interrupt, parameters={}
)
tool_call = {
"id": "call_123",
"function": {"name": "test_tool", "arguments": "{}"},
}
with pytest.raises(KeyboardInterrupt):
agent._execute_tool_call(tool_call)
def test_tool_with_extra_args(self, memory, mock_llm, real_folder):
"""Should handle extra arguments gracefully."""
agent = Agent(llm=mock_llm)
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
tool_call = {
"id": "call_123",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download", "extra_arg": "ignored"}',
},
}
result = agent._execute_tool_call(tool_call)
assert result.get("error") == "bad_args"
def test_tool_with_wrong_type_args(self, memory, mock_llm):
"""Should handle wrong argument types."""
agent = Agent(llm=mock_llm)
tool_call = {
"id": "call_123",
"function": {
"name": "get_torrent_by_index",
"arguments": '{"index": "not an int"}',
},
}
result = agent._execute_tool_call(tool_call)
assert "error" in result or "status" in result
class TestStepEdgeCases:
"""Edge case tests for step method."""
def test_step_with_empty_input(self, memory, mock_llm):
"""Should handle empty user input."""
agent = Agent(llm=mock_llm)
response = agent.step("")
assert response is not None
def test_step_with_very_long_input(self, memory, mock_llm):
"""Should handle very long user input."""
agent = Agent(llm=mock_llm)
long_input = "x" * 100000
response = agent.step(long_input)
assert response is not None
def test_step_with_unicode_input(self, memory, mock_llm):
"""Should handle unicode input."""
def mock_complete(messages, tools=None):
return {"role": "assistant", "content": "日本語の応答"}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
response = agent.step("日本語の質問")
assert response == "日本語の応答"
def test_step_llm_returns_empty(self, memory, mock_llm):
"""Should handle LLM returning empty string."""
def mock_complete(messages, tools=None):
return {"role": "assistant", "content": ""}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
response = agent.step("Hello")
assert response == ""
def test_step_llm_raises_exception(self, memory, mock_llm):
"""Should propagate LLM exceptions."""
mock_llm.complete.side_effect = Exception("LLM Error")
agent = Agent(llm=mock_llm)
with pytest.raises(Exception, match="LLM Error"):
agent.step("Hello")
def test_step_tool_loop_with_same_tool(self, memory, mock_llm):
"""Should handle tool calling same tool repeatedly."""
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
if call_count[0] <= 3:
return {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": f"call_{call_count[0]}",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
}
],
}
return {"role": "assistant", "content": "Done looping"}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3)
agent.step("Loop test")
assert call_count[0] == 4
def test_step_preserves_history_order(self, memory, mock_llm):
"""Should preserve message order in history."""
agent = Agent(llm=mock_llm)
agent.step("First")
agent.step("Second")
agent.step("Third")
mem = get_memory()
history = mem.stm.get_recent_history(10)
user_messages = [h["content"] for h in history if h["role"] == "user"]
assert user_messages == ["First", "Second", "Third"]
def test_step_with_pending_question(self, memory, mock_llm):
"""Should include pending question in context."""
memory.episodic.set_pending_question(
"Which one?",
[{"index": 1, "label": "Option 1"}],
{},
)
agent = Agent(llm=mock_llm)
agent.step("Hello")
call_args = mock_llm.complete.call_args[0][0]
system_prompt = call_args[0]["content"]
assert "PENDING QUESTION" in system_prompt
def test_step_with_active_downloads(self, memory, mock_llm):
"""Should include active downloads in context."""
memory.episodic.add_active_download(
{
"task_id": "123",
"name": "Movie.mkv",
"progress": 50,
}
)
agent = Agent(llm=mock_llm)
agent.step("Hello")
call_args = mock_llm.complete.call_args[0][0]
system_prompt = call_args[0]["content"]
assert "ACTIVE DOWNLOADS" in system_prompt
def test_step_clears_events_after_notification(self, memory, mock_llm):
"""Should mark events as read after notification."""
memory.episodic.add_background_event("test_event", {"data": "test"})
agent = Agent(llm=mock_llm)
agent.step("Hello")
unread = memory.episodic.get_unread_events()
assert len(unread) == 0
class TestAgentConcurrencyEdgeCases:
"""Edge case tests for concurrent access."""
def test_multiple_agents_same_memory(self, memory, mock_llm):
"""Should handle multiple agents with same memory."""
agent1 = Agent(llm=mock_llm)
agent2 = Agent(llm=mock_llm)
agent1.step("From agent 1")
agent2.step("From agent 2")
mem = get_memory()
history = mem.stm.get_recent_history(10)
assert len(history) == 4
def test_tool_modifies_memory_during_step(self, memory, mock_llm, real_folder):
"""Should handle memory modifications during step."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
if call_count[0] == 1:
return {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "set_path_for_folder",
"arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}',
},
}
],
}
return {"role": "assistant", "content": "Path set successfully."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
agent.step("Set movie folder")
mem = get_memory()
assert mem.ltm.get_config("movie_folder") == str(real_folder["movies"])
class TestAgentErrorRecovery:
"""Tests for agent error recovery."""
def test_recovers_from_tool_error(self, memory, mock_llm):
"""Should recover from tool error and continue."""
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
if call_count[0] == 1:
return {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
}
],
}
return {"role": "assistant", "content": "The folder is not configured."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
response = agent.step("List downloads")
assert "not configured" in response.lower() or len(response) > 0
def test_error_tracked_in_memory(self, memory, mock_llm):
"""Should track errors in episodic memory."""
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
if call_count[0] == 1:
return {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "set_path_for_folder",
"arguments": "{}", # Missing required args
},
}
],
}
return {"role": "assistant", "content": "Error occurred."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm)
agent.step("Set folder")
mem = get_memory()
assert len(mem.episodic.recent_errors) > 0
def test_multiple_errors_in_sequence(self, memory, mock_llm):
"""Should track multiple errors."""
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
if call_count[0] <= 3:
return {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": f"call_{call_count[0]}",
"function": {
"name": "set_path_for_folder",
"arguments": "{}", # Missing required args - will error
},
}
],
}
return {"role": "assistant", "content": "All attempts failed."}
mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3)
agent.step("Try multiple times")
mem = get_memory()
assert len(mem.episodic.recent_errors) >= 1

View File

@@ -0,0 +1,2 @@
# DEPRECATED - Tests removed due to mock issues
# Use test_agent_critical.py instead which has correct mock setup

242
tests/test_api.py Normal file
View File

@@ -0,0 +1,242 @@
"""Tests for FastAPI endpoints."""
from unittest.mock import patch
from fastapi.testclient import TestClient
class TestHealthEndpoint:
"""Tests for /health endpoint."""
def test_health_check(self, memory):
"""Should return healthy status."""
from app import app
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
assert response.json()["status"] == "healthy"
class TestModelsEndpoint:
"""Tests for /v1/models endpoint."""
def test_list_models(self, memory):
"""Should return model list."""
from app import app
client = TestClient(app)
response = client.get("/v1/models")
assert response.status_code == 200
data = response.json()
assert data["object"] == "list"
assert len(data["data"]) > 0
assert data["data"][0]["id"] == "agent-media"
class TestMemoryEndpoints:
"""Tests for memory debug endpoints."""
def test_get_memory_state(self, memory):
"""Should return full memory state."""
from app import app
client = TestClient(app)
response = client.get("/memory/state")
assert response.status_code == 200
data = response.json()
assert "ltm" in data
assert "stm" in data
assert "episodic" in data
def test_get_search_results_empty(self, memory):
"""Should return empty when no search results."""
from app import app
client = TestClient(app)
response = client.get("/memory/episodic/search-results")
assert response.status_code == 200
data = response.json()
assert data["status"] == "empty"
def test_get_search_results_with_data(self, memory_with_search_results):
"""Should return search results when available."""
from app import app
client = TestClient(app)
response = client.get("/memory/episodic/search-results")
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
assert data["query"] == "Inception 1080p"
assert data["result_count"] == 3
def test_clear_session(self, memory_with_search_results):
"""Should clear session memories."""
from app import app
client = TestClient(app)
response = client.post("/memory/clear-session")
assert response.status_code == 200
assert response.json()["status"] == "ok"
# Verify cleared
state = client.get("/memory/state").json()
assert state["episodic"]["last_search_results"] is None
class TestChatCompletionsEndpoint:
"""Tests for /v1/chat/completions endpoint."""
def test_chat_completion_success(self, memory):
"""Should return chat completion."""
from app import app
# Patch the agent's step method directly
with patch("app.agent.step", return_value="Hello! How can I help?"):
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
},
)
assert response.status_code == 200
data = response.json()
assert data["object"] == "chat.completion"
assert "Hello" in data["choices"][0]["message"]["content"]
def test_chat_completion_no_user_message(self, memory):
"""Should return error if no user message."""
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "system", "content": "You are helpful"}],
},
)
assert response.status_code == 422
detail = response.json()["detail"]
# Pydantic returns a list of errors or a string
if isinstance(detail, list):
detail_str = str(detail).lower()
else:
detail_str = detail.lower()
assert "user message" in detail_str
def test_chat_completion_empty_messages(self, memory):
"""Should return error for empty messages."""
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [],
},
)
assert response.status_code == 422
def test_chat_completion_invalid_json(self, memory):
"""Should return error for invalid JSON."""
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
content="not json",
headers={"Content-Type": "application/json"},
)
assert response.status_code == 422
def test_chat_completion_streaming(self, memory):
"""Should support streaming mode."""
from app import app
with patch("app.agent.step", return_value="Streaming response"):
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
},
)
assert response.status_code == 200
assert "text/event-stream" in response.headers["content-type"]
def test_chat_completion_extracts_last_user_message(self, memory):
"""Should use last user message."""
from app import app
with patch("app.agent.step", return_value="Response") as mock_step:
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [
{"role": "user", "content": "First message"},
{"role": "assistant", "content": "Response"},
{"role": "user", "content": "Second message"},
],
},
)
assert response.status_code == 200
# Verify the agent received the last user message
mock_step.assert_called_once_with("Second message")
def test_chat_completion_response_format(self, memory):
"""Should return OpenAI-compatible format."""
from app import app
with patch("app.agent.step", return_value="Test response"):
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Test"}],
},
)
data = response.json()
assert "id" in data
assert data["id"].startswith("chatcmpl-")
assert "created" in data
assert "model" in data
assert "choices" in data
assert "usage" in data
assert data["choices"][0]["finish_reason"] == "stop"
assert data["choices"][0]["message"]["role"] == "assistant"

View File

@@ -0,0 +1,2 @@
# DEPRECATED - Tests removed due to API signature mismatches
# Use test_tools_api.py instead which has been refactored with correct signatures

View File

@@ -0,0 +1,549 @@
"""Edge case tests for FastAPI endpoints."""
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
class TestChatCompletionsEdgeCases:
"""Edge case tests for /v1/chat/completions endpoint."""
def test_very_long_message(self, memory):
"""Should handle very long user message."""
from app import agent, app
# Patch the agent's LLM directly
mock_llm = Mock()
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
long_message = "x" * 100000
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": long_message}],
},
)
assert response.status_code == 200
def test_unicode_message(self, memory):
"""Should handle unicode in message."""
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": "日本語の応答",
}
agent.llm = mock_llm
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}],
},
)
assert response.status_code == 200
content = response.json()["choices"][0]["message"]["content"]
assert "日本語" in content or len(content) > 0
def test_special_characters_in_message(self, memory):
"""Should handle special characters."""
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
special_message = 'Test with "quotes" and \\backslash and \n newline'
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": special_message}],
},
)
assert response.status_code == 200
def test_empty_content_in_message(self, memory):
"""Should handle empty content in message."""
with patch("app.DeepSeekClient") as mock_llm_class:
mock_llm = Mock()
mock_llm.complete.return_value = "Response"
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": ""}],
},
)
# Empty content should be rejected
assert response.status_code == 422
def test_null_content_in_message(self, memory):
"""Should handle null content in message."""
with patch("app.DeepSeekClient") as mock_llm_class:
mock_llm = Mock()
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": None}],
},
)
assert response.status_code == 422
def test_missing_content_field(self, memory):
"""Should handle missing content field."""
with patch("app.DeepSeekClient") as mock_llm_class:
mock_llm = Mock()
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user"}], # No content
},
)
# May accept or reject depending on validation
assert response.status_code in [200, 400, 422]
def test_missing_role_field(self, memory):
"""Should handle missing role field."""
with patch("app.DeepSeekClient") as mock_llm_class:
mock_llm = Mock()
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"content": "Hello"}], # No role
},
)
# Should reject or accept depending on validation
assert response.status_code in [200, 400, 422]
def test_invalid_role(self, memory):
"""Should handle invalid role."""
with patch("app.DeepSeekClient") as mock_llm_class:
mock_llm = Mock()
mock_llm.complete.return_value = "Response"
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "invalid_role", "content": "Hello"}],
},
)
# Should reject or ignore invalid role
assert response.status_code in [200, 400, 422]
def test_many_messages(self, memory):
"""Should handle many messages in conversation."""
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
messages = []
for i in range(100):
messages.append({"role": "user", "content": f"Message {i}"})
messages.append({"role": "assistant", "content": f"Response {i}"})
messages.append({"role": "user", "content": "Final message"})
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": messages,
},
)
assert response.status_code == 200
def test_only_system_messages(self, memory):
"""Should reject if only system messages."""
with patch("app.DeepSeekClient") as mock_llm_class:
mock_llm = Mock()
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [
{"role": "system", "content": "You are helpful"},
{"role": "system", "content": "Be concise"},
],
},
)
assert response.status_code == 422
def test_only_assistant_messages(self, memory):
"""Should reject if only assistant messages."""
with patch("app.DeepSeekClient") as mock_llm_class:
mock_llm = Mock()
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [
{"role": "assistant", "content": "Hello"},
],
},
)
assert response.status_code == 422
def test_messages_not_array(self, memory):
"""Should reject if messages is not array."""
with patch("app.DeepSeekClient") as mock_llm_class:
mock_llm = Mock()
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": "not an array",
},
)
assert response.status_code == 422
# Pydantic validation error
def test_message_not_object(self, memory):
"""Should handle message that is not object."""
with patch("app.DeepSeekClient") as mock_llm_class:
mock_llm = Mock()
mock_llm_class.return_value = mock_llm
from app import app
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": ["not an object", 123, None],
},
)
assert response.status_code == 422
# Pydantic validation error
def test_extra_fields_in_request(self, memory):
"""Should ignore extra fields in request."""
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
"extra_field": "should be ignored",
"temperature": 0.7,
"max_tokens": 100,
},
)
assert response.status_code == 200
def test_streaming_with_tool_call(self, memory, real_folder):
"""Should handle streaming with tool execution."""
from app import agent, app
from infrastructure.persistence import get_memory
mem = get_memory()
mem.ltm.set_config("download_folder", str(real_folder["downloads"]))
call_count = [0]
def mock_complete(messages, tools=None):
call_count[0] += 1
if call_count[0] == 1:
return {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
}
],
}
return {"role": "assistant", "content": "Listed the folder."}
mock_llm = Mock()
mock_llm.complete = Mock(side_effect=mock_complete)
agent.llm = mock_llm
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "List downloads"}],
"stream": True,
},
)
assert response.status_code == 200
def test_concurrent_requests_simulation(self, memory):
"""Should handle rapid sequential requests."""
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
agent.llm = mock_llm
client = TestClient(app)
for i in range(10):
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": f"Request {i}"}],
},
)
assert response.status_code == 200
def test_llm_returns_json_in_response(self, memory):
"""Should handle LLM returning JSON in text response."""
from app import agent, app
mock_llm = Mock()
mock_llm.complete.return_value = {
"role": "assistant",
"content": '{"result": "some data", "count": 5}',
}
agent.llm = mock_llm
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "agent-media",
"messages": [{"role": "user", "content": "Give me JSON"}],
},
)
assert response.status_code == 200
content = response.json()["choices"][0]["message"]["content"]
assert "result" in content or len(content) > 0
class TestMemoryEndpointsEdgeCases:
"""Edge case tests for memory endpoints."""
def test_memory_state_with_large_data(self, memory):
"""Should handle large memory state."""
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
# Add lots of data to memory
for i in range(100):
memory.stm.add_message("user", f"Message {i}" * 100)
memory.episodic.add_error("action", f"Error {i}")
client = TestClient(app)
response = client.get("/memory/state")
assert response.status_code == 200
data = response.json()
assert "stm" in data
def test_memory_state_with_unicode(self, memory):
"""Should handle unicode in memory state."""
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
memory.ltm.set_config("japanese", "日本語テスト")
memory.stm.add_message("user", "🎬 Movie request")
client = TestClient(app)
response = client.get("/memory/state")
assert response.status_code == 200
data = response.json()
assert "日本語" in str(data)
def test_search_results_with_special_chars(self, memory):
"""Should handle special characters in search results."""
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
memory.episodic.store_search_results(
"Test <script>alert('xss')</script>",
[{"name": "Result with \"quotes\" and 'apostrophes'"}],
)
client = TestClient(app)
response = client.get("/memory/episodic/search-results")
assert response.status_code == 200
# Should be properly escaped in JSON
data = response.json()
assert "script" in data["query"]
def test_clear_session_idempotent(self, memory):
"""Should be idempotent - multiple clears should work."""
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
# Clear multiple times
for _ in range(5):
response = client.post("/memory/clear-session")
assert response.status_code == 200
def test_clear_session_preserves_ltm(self, memory):
"""Should preserve LTM after clear."""
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
memory.ltm.set_config("important", "data")
memory.stm.add_message("user", "Hello")
client = TestClient(app)
client.post("/memory/clear-session")
response = client.get("/memory/state")
data = response.json()
assert data["ltm"]["config"]["important"] == "data"
assert data["stm"]["conversation_history"] == []
class TestHealthEndpointEdgeCases:
"""Edge case tests for health endpoint."""
def test_health_returns_version(self, memory):
"""Should return version in health check."""
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
assert "version" in response.json()
def test_health_with_query_params(self, memory):
"""Should ignore query parameters."""
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
response = client.get("/health?extra=param&another=value")
assert response.status_code == 200
class TestModelsEndpointEdgeCases:
"""Edge case tests for models endpoint."""
def test_models_response_format(self, memory):
"""Should return OpenAI-compatible format."""
with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock()
from app import app
client = TestClient(app)
response = client.get("/v1/models")
data = response.json()
assert data["object"] == "list"
assert isinstance(data["data"], list)
assert len(data["data"]) > 0
assert "id" in data["data"][0]
assert "object" in data["data"][0]
assert "created" in data["data"][0]
assert "owned_by" in data["data"][0]

View File

@@ -0,0 +1,191 @@
"""Critical tests for configuration validation."""
import pytest
from agent.config import ConfigurationError, Settings
class TestConfigValidation:
"""Critical tests for config validation."""
def test_invalid_temperature_raises_error(self):
"""Verify invalid temperature is rejected."""
with pytest.raises(ConfigurationError, match="Temperature"):
Settings(temperature=3.0) # > 2.0
with pytest.raises(ConfigurationError, match="Temperature"):
Settings(temperature=-0.1) # < 0.0
def test_valid_temperature_accepted(self):
"""Verify valid temperature is accepted."""
# Should not raise
Settings(temperature=0.0)
Settings(temperature=1.0)
Settings(temperature=2.0)
def test_invalid_max_iterations_raises_error(self):
"""Verify invalid max_iterations is rejected."""
with pytest.raises(ConfigurationError, match="max_tool_iterations"):
Settings(max_tool_iterations=0) # < 1
with pytest.raises(ConfigurationError, match="max_tool_iterations"):
Settings(max_tool_iterations=100) # > 20
def test_valid_max_iterations_accepted(self):
"""Verify valid max_iterations is accepted."""
# Should not raise
Settings(max_tool_iterations=1)
Settings(max_tool_iterations=10)
Settings(max_tool_iterations=20)
def test_invalid_timeout_raises_error(self):
"""Verify invalid timeout is rejected."""
with pytest.raises(ConfigurationError, match="request_timeout"):
Settings(request_timeout=0) # < 1
with pytest.raises(ConfigurationError, match="request_timeout"):
Settings(request_timeout=500) # > 300
def test_valid_timeout_accepted(self):
"""Verify valid timeout is accepted."""
# Should not raise
Settings(request_timeout=1)
Settings(request_timeout=30)
Settings(request_timeout=300)
def test_invalid_deepseek_url_raises_error(self):
"""Verify invalid DeepSeek URL is rejected."""
with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"):
Settings(deepseek_base_url="not-a-url")
with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"):
Settings(deepseek_base_url="ftp://invalid.com")
def test_valid_deepseek_url_accepted(self):
"""Verify valid DeepSeek URL is accepted."""
# Should not raise
Settings(deepseek_base_url="https://api.deepseek.com")
Settings(deepseek_base_url="http://localhost:8000")
def test_invalid_tmdb_url_raises_error(self):
"""Verify invalid TMDB URL is rejected."""
with pytest.raises(ConfigurationError, match="Invalid tmdb_base_url"):
Settings(tmdb_base_url="not-a-url")
def test_valid_tmdb_url_accepted(self):
"""Verify valid TMDB URL is accepted."""
# Should not raise
Settings(tmdb_base_url="https://api.themoviedb.org/3")
Settings(tmdb_base_url="http://localhost:3000")
class TestConfigChecks:
"""Tests for configuration check methods."""
def test_is_deepseek_configured_with_key(self):
"""Verify is_deepseek_configured returns True with API key."""
settings = Settings(
deepseek_api_key="test-key", deepseek_base_url="https://api.test.com"
)
assert settings.is_deepseek_configured() is True
def test_is_deepseek_configured_without_key(self):
"""Verify is_deepseek_configured returns False without API key."""
settings = Settings(
deepseek_api_key="", deepseek_base_url="https://api.test.com"
)
assert settings.is_deepseek_configured() is False
def test_is_deepseek_configured_without_url(self):
"""Verify is_deepseek_configured returns False without URL."""
# This will fail validation, so we can't test it directly
# The validation happens in __post_init__
pass
def test_is_tmdb_configured_with_key(self):
"""Verify is_tmdb_configured returns True with API key."""
settings = Settings(
tmdb_api_key="test-key", tmdb_base_url="https://api.test.com"
)
assert settings.is_tmdb_configured() is True
def test_is_tmdb_configured_without_key(self):
"""Verify is_tmdb_configured returns False without API key."""
settings = Settings(tmdb_api_key="", tmdb_base_url="https://api.test.com")
assert settings.is_tmdb_configured() is False
class TestConfigDefaults:
"""Tests for configuration defaults."""
def test_default_temperature(self):
"""Verify default temperature is reasonable."""
settings = Settings()
assert 0.0 <= settings.temperature <= 2.0
def test_default_max_iterations(self):
"""Verify default max_iterations is reasonable."""
settings = Settings()
assert 1 <= settings.max_tool_iterations <= 20
def test_default_timeout(self):
"""Verify default timeout is reasonable."""
settings = Settings()
assert 1 <= settings.request_timeout <= 300
def test_default_urls_are_valid(self):
"""Verify default URLs are valid."""
settings = Settings()
assert settings.deepseek_base_url.startswith(("http://", "https://"))
assert settings.tmdb_base_url.startswith(("http://", "https://"))
class TestConfigEnvironmentVariables:
"""Tests for environment variable loading."""
def test_loads_temperature_from_env(self, monkeypatch):
"""Verify temperature is loaded from environment."""
monkeypatch.setenv("TEMPERATURE", "0.5")
settings = Settings()
assert settings.temperature == 0.5
def test_loads_max_iterations_from_env(self, monkeypatch):
"""Verify max_iterations is loaded from environment."""
monkeypatch.setenv("MAX_TOOL_ITERATIONS", "10")
settings = Settings()
assert settings.max_tool_iterations == 10
def test_loads_timeout_from_env(self, monkeypatch):
"""Verify timeout is loaded from environment."""
monkeypatch.setenv("REQUEST_TIMEOUT", "60")
settings = Settings()
assert settings.request_timeout == 60
def test_loads_deepseek_url_from_env(self, monkeypatch):
"""Verify DeepSeek URL is loaded from environment."""
monkeypatch.setenv("DEEPSEEK_BASE_URL", "https://custom.api.com")
settings = Settings()
assert settings.deepseek_base_url == "https://custom.api.com"
def test_invalid_env_value_raises_error(self, monkeypatch):
"""Verify invalid environment value raises error."""
monkeypatch.setenv("TEMPERATURE", "invalid")
with pytest.raises(ValueError):
Settings()

View File

@@ -0,0 +1,319 @@
"""Edge case tests for configuration and parameters."""
import os
from unittest.mock import patch
import pytest
from agent.config import ConfigurationError, Settings
from agent.parameters import (
REQUIRED_PARAMETERS,
ParameterSchema,
format_parameters_for_prompt,
get_missing_required_parameters,
)
class TestSettingsEdgeCases:
"""Edge case tests for Settings."""
def test_default_values(self):
"""Should have sensible defaults."""
with patch.dict(os.environ, {}, clear=True):
settings = Settings()
assert settings.temperature == 0.2
assert settings.max_tool_iterations == 5
assert settings.request_timeout == 30
def test_temperature_boundary_low(self):
"""Should accept temperature at lower boundary."""
with patch.dict(os.environ, {"TEMPERATURE": "0.0"}, clear=True):
settings = Settings()
assert settings.temperature == 0.0
def test_temperature_boundary_high(self):
"""Should accept temperature at upper boundary."""
with patch.dict(os.environ, {"TEMPERATURE": "2.0"}, clear=True):
settings = Settings()
assert settings.temperature == 2.0
def test_temperature_below_boundary(self):
"""Should reject temperature below 0."""
with patch.dict(os.environ, {"TEMPERATURE": "-0.1"}, clear=True):
with pytest.raises(ConfigurationError):
Settings()
def test_temperature_above_boundary(self):
"""Should reject temperature above 2."""
with patch.dict(os.environ, {"TEMPERATURE": "2.1"}, clear=True):
with pytest.raises(ConfigurationError):
Settings()
def test_max_tool_iterations_boundary_low(self):
"""Should accept max_tool_iterations at lower boundary."""
with patch.dict(os.environ, {"MAX_TOOL_ITERATIONS": "1"}, clear=True):
settings = Settings()
assert settings.max_tool_iterations == 1
def test_max_tool_iterations_boundary_high(self):
"""Should accept max_tool_iterations at upper boundary."""
with patch.dict(os.environ, {"MAX_TOOL_ITERATIONS": "20"}, clear=True):
settings = Settings()
assert settings.max_tool_iterations == 20
def test_max_tool_iterations_below_boundary(self):
"""Should reject max_tool_iterations below 1."""
with patch.dict(os.environ, {"MAX_TOOL_ITERATIONS": "0"}, clear=True):
with pytest.raises(ConfigurationError):
Settings()
def test_max_tool_iterations_above_boundary(self):
"""Should reject max_tool_iterations above 20."""
with patch.dict(os.environ, {"MAX_TOOL_ITERATIONS": "21"}, clear=True):
with pytest.raises(ConfigurationError):
Settings()
def test_request_timeout_boundary_low(self):
"""Should accept request_timeout at lower boundary."""
with patch.dict(os.environ, {"REQUEST_TIMEOUT": "1"}, clear=True):
settings = Settings()
assert settings.request_timeout == 1
def test_request_timeout_boundary_high(self):
"""Should accept request_timeout at upper boundary."""
with patch.dict(os.environ, {"REQUEST_TIMEOUT": "300"}, clear=True):
settings = Settings()
assert settings.request_timeout == 300
def test_request_timeout_below_boundary(self):
"""Should reject request_timeout below 1."""
with patch.dict(os.environ, {"REQUEST_TIMEOUT": "0"}, clear=True):
with pytest.raises(ConfigurationError):
Settings()
def test_request_timeout_above_boundary(self):
"""Should reject request_timeout above 300."""
with patch.dict(os.environ, {"REQUEST_TIMEOUT": "301"}, clear=True):
with pytest.raises(ConfigurationError):
Settings()
def test_invalid_deepseek_url(self):
"""Should reject invalid DeepSeek URL."""
with patch.dict(os.environ, {"DEEPSEEK_BASE_URL": "not-a-url"}, clear=True):
with pytest.raises(ConfigurationError):
Settings()
def test_invalid_tmdb_url(self):
"""Should reject invalid TMDB URL."""
with patch.dict(os.environ, {"TMDB_BASE_URL": "ftp://invalid"}, clear=True):
with pytest.raises(ConfigurationError):
Settings()
def test_http_url_accepted(self):
"""Should accept http:// URLs."""
with patch.dict(
os.environ,
{
"DEEPSEEK_BASE_URL": "http://localhost:8080",
"TMDB_BASE_URL": "http://localhost:3000",
},
clear=True,
):
settings = Settings()
assert settings.deepseek_base_url == "http://localhost:8080"
def test_https_url_accepted(self):
"""Should accept https:// URLs."""
with patch.dict(
os.environ,
{
"DEEPSEEK_BASE_URL": "https://api.example.com",
"TMDB_BASE_URL": "https://api.example.com",
},
clear=True,
):
settings = Settings()
assert settings.deepseek_base_url == "https://api.example.com"
def test_is_deepseek_configured_with_key(self):
"""Should return True when API key is set."""
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}, clear=True):
settings = Settings()
assert settings.is_deepseek_configured() is True
def test_is_deepseek_configured_without_key(self):
"""Should return False when API key is not set."""
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": ""}, clear=True):
settings = Settings()
assert settings.is_deepseek_configured() is False
def test_is_tmdb_configured_with_key(self):
"""Should return True when API key is set."""
with patch.dict(os.environ, {"TMDB_API_KEY": "test-key"}, clear=True):
settings = Settings()
assert settings.is_tmdb_configured() is True
def test_is_tmdb_configured_without_key(self):
"""Should return False when API key is not set."""
with patch.dict(os.environ, {"TMDB_API_KEY": ""}, clear=True):
settings = Settings()
assert settings.is_tmdb_configured() is False
def test_non_numeric_temperature(self):
"""Should handle non-numeric temperature."""
with patch.dict(os.environ, {"TEMPERATURE": "not-a-number"}, clear=True):
with pytest.raises((ConfigurationError, ValueError)):
Settings()
def test_non_numeric_max_iterations(self):
"""Should handle non-numeric max_tool_iterations."""
with patch.dict(os.environ, {"MAX_TOOL_ITERATIONS": "five"}, clear=True):
with pytest.raises((ConfigurationError, ValueError)):
Settings()
class TestParametersEdgeCases:
"""Edge case tests for parameters module."""
def test_parameter_creation(self):
"""Should create parameter with all fields."""
param = ParameterSchema(
key="test_key",
description="Test description",
why_needed="Test reason",
type="string",
)
assert param.key == "test_key"
assert param.description == "Test description"
assert param.why_needed == "Test reason"
assert param.type == "string"
def test_required_parameters_not_empty(self):
"""Should have at least one required parameter."""
assert len(REQUIRED_PARAMETERS) > 0
def test_format_parameters_for_prompt(self):
"""Should format parameters for prompt."""
result = format_parameters_for_prompt()
assert isinstance(result, str)
# Should contain parameter information
for param in REQUIRED_PARAMETERS:
assert param.key in result or param.description in result
def test_get_missing_required_parameters_all_missing(self):
"""Should return all parameters when none configured."""
memory_data = {"config": {}}
missing = get_missing_required_parameters(memory_data)
# Config may have defaults, so check it's a list
assert isinstance(missing, list)
assert len(missing) >= 0
def test_get_missing_required_parameters_none_missing(self):
"""Should return empty when all configured."""
memory_data = {"config": {}}
for param in REQUIRED_PARAMETERS:
memory_data["config"][param.key] = "/some/path"
missing = get_missing_required_parameters(memory_data)
assert len(missing) == 0
def test_get_missing_required_parameters_some_missing(self):
"""Should return only missing parameters."""
memory_data = {"config": {}}
if REQUIRED_PARAMETERS:
# Configure first parameter only
memory_data["config"][REQUIRED_PARAMETERS[0].key] = "/path"
missing = get_missing_required_parameters(memory_data)
# Config may have defaults
assert isinstance(missing, list)
assert len(missing) >= 0
def test_get_missing_required_parameters_with_none_value(self):
"""Should treat None as missing."""
memory_data = {"config": {}}
for param in REQUIRED_PARAMETERS:
memory_data["config"][param.key] = None
missing = get_missing_required_parameters(memory_data)
# Config may have defaults
assert isinstance(missing, list)
assert len(missing) >= 0
def test_get_missing_required_parameters_with_empty_string(self):
"""Should treat empty string as missing."""
memory_data = {"config": {}}
for param in REQUIRED_PARAMETERS:
memory_data["config"][param.key] = ""
missing = get_missing_required_parameters(memory_data)
# Behavior depends on implementation
# Empty string might be considered as "set" or "missing"
assert isinstance(missing, list)
def test_get_missing_required_parameters_no_config_key(self):
"""Should handle missing config key in memory."""
memory_data = {} # No config key at all
missing = get_missing_required_parameters(memory_data)
# Config may have defaults
assert isinstance(missing, list)
assert len(missing) >= 0
def test_get_missing_required_parameters_config_not_dict(self):
"""Should handle config that is not a dict."""
memory_data = {"config": "not a dict"}
# Should either handle gracefully or raise
try:
missing = get_missing_required_parameters(memory_data)
assert isinstance(missing, list)
except (TypeError, AttributeError):
pass # Also acceptable
class TestParameterValidation:
"""Tests for parameter validation."""
def test_parameter_with_unicode(self):
"""Should handle unicode in parameter fields."""
param = ParameterSchema(
key="日本語_key",
description="日本語の説明",
why_needed="日本語の理由",
type="string",
)
assert "日本語" in param.description
def test_parameter_with_special_chars(self):
"""Should handle special characters."""
param = ParameterSchema(
key="key_with_special",
description='Description with "quotes" and \\backslash',
why_needed="Reason with <html> tags",
type="string",
)
assert '"quotes"' in param.description
def test_parameter_with_empty_fields(self):
"""Should handle empty fields."""
param = ParameterSchema(
key="",
description="",
why_needed="",
type="",
)
assert param.key == ""

View File

@@ -0,0 +1,525 @@
"""Edge case tests for domain entities and value objects."""
from datetime import datetime
import pytest
from domain.movies.entities import Movie
from domain.movies.value_objects import MovieTitle, Quality, ReleaseYear
from domain.shared.exceptions import ValidationError
from domain.shared.value_objects import FilePath, FileSize, ImdbId
from domain.subtitles.entities import Subtitle
from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset
from domain.tv_shows.entities import TVShow
from domain.tv_shows.value_objects import ShowStatus
class TestImdbIdEdgeCases:
"""Edge case tests for ImdbId."""
def test_valid_imdb_id(self):
"""Should accept valid IMDb ID."""
imdb_id = ImdbId("tt1375666")
assert str(imdb_id) == "tt1375666"
def test_imdb_id_with_leading_zeros(self):
"""Should accept IMDb ID with leading zeros."""
imdb_id = ImdbId("tt0000001")
assert str(imdb_id) == "tt0000001"
def test_imdb_id_long_number(self):
"""Should accept IMDb ID with 8 digits."""
imdb_id = ImdbId("tt12345678")
assert str(imdb_id) == "tt12345678"
def test_imdb_id_lowercase(self):
"""Should accept lowercase tt prefix."""
imdb_id = ImdbId("tt1234567")
assert str(imdb_id) == "tt1234567"
def test_imdb_id_uppercase(self):
"""Should handle uppercase TT prefix."""
# Behavior depends on implementation
try:
imdb_id = ImdbId("TT1234567")
# If accepted, should work
assert imdb_id is not None
except (ValidationError, ValueError):
# If rejected, that's also valid
pass
def test_imdb_id_without_prefix(self):
"""Should reject ID without tt prefix."""
with pytest.raises((ValidationError, ValueError)):
ImdbId("1234567")
def test_imdb_id_empty(self):
"""Should reject empty string."""
with pytest.raises((ValidationError, ValueError)):
ImdbId("")
def test_imdb_id_none(self):
"""Should reject None."""
with pytest.raises((ValidationError, ValueError, TypeError)):
ImdbId(None)
def test_imdb_id_with_spaces(self):
"""Should reject ID with spaces."""
with pytest.raises((ValidationError, ValueError)):
ImdbId("tt 1234567")
def test_imdb_id_with_special_chars(self):
"""Should reject ID with special characters."""
with pytest.raises((ValidationError, ValueError)):
ImdbId("tt1234567!")
def test_imdb_id_equality(self):
"""Should compare equal IDs."""
id1 = ImdbId("tt1234567")
id2 = ImdbId("tt1234567")
assert id1 == id2 or str(id1) == str(id2)
def test_imdb_id_hash(self):
"""Should be hashable for use in sets/dicts."""
id1 = ImdbId("tt1234567")
id2 = ImdbId("tt1234567")
# Should be usable in set
_s = {id1, id2} # Test hashability
# Depending on implementation, might be 1 or 2 items
class TestFilePathEdgeCases:
"""Edge case tests for FilePath."""
def test_absolute_path(self):
"""Should accept absolute path."""
path = FilePath("/home/user/movies/movie.mkv")
assert "/home/user/movies/movie.mkv" in str(path)
def test_relative_path(self):
"""Should accept relative path."""
path = FilePath("movies/movie.mkv")
assert "movies/movie.mkv" in str(path)
def test_path_with_spaces(self):
"""Should accept path with spaces."""
path = FilePath("/home/user/My Movies/movie file.mkv")
assert "My Movies" in str(path)
def test_path_with_unicode(self):
"""Should accept path with unicode."""
path = FilePath("/home/user/映画/日本語.mkv")
assert "映画" in str(path)
def test_windows_path(self):
"""Should handle Windows-style path."""
path = FilePath("C:\\Users\\user\\Movies\\movie.mkv")
assert "movie.mkv" in str(path)
def test_empty_path(self):
"""Should handle empty path."""
try:
path = FilePath("")
# If accepted, may return "." for current directory
assert str(path) in ["", "."]
except (ValidationError, ValueError):
# If rejected, that's also valid
pass
def test_path_with_dots(self):
"""Should handle path with . and .."""
path = FilePath("/home/user/../other/./movie.mkv")
assert "movie.mkv" in str(path)
class TestFileSizeEdgeCases:
"""Edge case tests for FileSize."""
def test_zero_size(self):
"""Should accept zero size."""
size = FileSize(0)
assert size.bytes == 0
def test_very_large_size(self):
"""Should accept very large size (petabytes)."""
size = FileSize(1024**5) # 1 PB
assert size.bytes == 1024**5
def test_negative_size(self):
"""Should reject negative size."""
with pytest.raises((ValidationError, ValueError)):
FileSize(-1)
def test_human_readable_bytes(self):
"""Should format bytes correctly."""
size = FileSize(500)
readable = size.to_human_readable()
assert "500" in readable or "B" in readable
def test_human_readable_kb(self):
"""Should format KB correctly."""
size = FileSize(1024)
readable = size.to_human_readable()
assert "KB" in readable or "1" in readable
def test_human_readable_mb(self):
"""Should format MB correctly."""
size = FileSize(1024 * 1024)
readable = size.to_human_readable()
assert "MB" in readable or "1" in readable
def test_human_readable_gb(self):
"""Should format GB correctly."""
size = FileSize(1024 * 1024 * 1024)
readable = size.to_human_readable()
assert "GB" in readable or "1" in readable
class TestMovieTitleEdgeCases:
"""Edge case tests for MovieTitle."""
def test_normal_title(self):
"""Should accept normal title."""
title = MovieTitle("Inception")
assert title.value == "Inception"
def test_title_with_year(self):
"""Should accept title with year."""
title = MovieTitle("Blade Runner 2049")
assert "2049" in title.value
def test_title_with_special_chars(self):
"""Should accept title with special characters."""
title = MovieTitle("Se7en")
assert title.value == "Se7en"
def test_title_with_colon(self):
"""Should accept title with colon."""
title = MovieTitle("Star Wars: A New Hope")
assert ":" in title.value
def test_title_with_unicode(self):
"""Should accept unicode title."""
title = MovieTitle("千と千尋の神隠し")
assert title.value == "千と千尋の神隠し"
def test_empty_title(self):
"""Should reject empty title."""
with pytest.raises((ValidationError, ValueError)):
MovieTitle("")
def test_whitespace_title(self):
"""Should handle whitespace title (may strip or reject)."""
try:
title = MovieTitle(" ")
# If accepted after stripping, that's valid
assert title.value is not None
except (ValidationError, ValueError):
# If rejected, that's also valid
pass
def test_very_long_title(self):
"""Should handle very long title."""
long_title = "A" * 1000
try:
title = MovieTitle(long_title)
assert len(title.value) == 1000
except (ValidationError, ValueError):
# If there's a length limit, that's valid
pass
class TestReleaseYearEdgeCases:
"""Edge case tests for ReleaseYear."""
def test_valid_year(self):
"""Should accept valid year."""
year = ReleaseYear(2024)
assert year.value == 2024
def test_old_movie_year(self):
"""Should accept old movie year."""
year = ReleaseYear(1895) # First movie ever
assert year.value == 1895
def test_future_year(self):
"""Should accept near future year."""
year = ReleaseYear(2030)
assert year.value == 2030
def test_very_old_year(self):
"""Should reject very old year."""
with pytest.raises((ValidationError, ValueError)):
ReleaseYear(1800)
def test_very_future_year(self):
"""Should reject very future year."""
with pytest.raises((ValidationError, ValueError)):
ReleaseYear(3000)
def test_negative_year(self):
"""Should reject negative year."""
with pytest.raises((ValidationError, ValueError)):
ReleaseYear(-2024)
def test_zero_year(self):
"""Should reject zero year."""
with pytest.raises((ValidationError, ValueError)):
ReleaseYear(0)
class TestQualityEdgeCases:
"""Edge case tests for Quality."""
def test_standard_qualities(self):
"""Should accept standard qualities."""
qualities = [
(Quality.SD, "480p"),
(Quality.HD, "720p"),
(Quality.FULL_HD, "1080p"),
(Quality.UHD_4K, "2160p"),
]
for quality_enum, expected_value in qualities:
assert quality_enum.value == expected_value
def test_unknown_quality(self):
"""Should accept unknown quality."""
quality = Quality.UNKNOWN
assert quality.value == "unknown"
def test_from_string_quality(self):
"""Should parse quality from string."""
assert Quality.from_string("1080p") == Quality.FULL_HD
assert Quality.from_string("720p") == Quality.HD
assert Quality.from_string("2160p") == Quality.UHD_4K
assert Quality.from_string("HDTV") == Quality.UNKNOWN
def test_empty_quality(self):
"""Should handle empty quality string."""
quality = Quality.from_string("")
assert quality == Quality.UNKNOWN
class TestShowStatusEdgeCases:
"""Edge case tests for ShowStatus."""
def test_all_statuses(self):
"""Should have all expected statuses."""
assert ShowStatus.ONGOING is not None
assert ShowStatus.ENDED is not None
assert ShowStatus.UNKNOWN is not None
def test_from_string_valid(self):
"""Should parse valid status strings."""
assert ShowStatus.from_string("ongoing") == ShowStatus.ONGOING
assert ShowStatus.from_string("ended") == ShowStatus.ENDED
def test_from_string_case_insensitive(self):
"""Should be case insensitive."""
assert ShowStatus.from_string("ONGOING") == ShowStatus.ONGOING
assert ShowStatus.from_string("Ended") == ShowStatus.ENDED
def test_from_string_unknown(self):
"""Should return UNKNOWN for invalid strings."""
assert ShowStatus.from_string("invalid") == ShowStatus.UNKNOWN
assert ShowStatus.from_string("") == ShowStatus.UNKNOWN
class TestLanguageEdgeCases:
"""Edge case tests for Language."""
def test_common_languages(self):
"""Should have common languages."""
assert Language.ENGLISH is not None
assert Language.FRENCH is not None
def test_from_code_valid(self):
"""Should parse valid language codes."""
assert Language.from_code("en") == Language.ENGLISH
assert Language.from_code("fr") == Language.FRENCH
def test_from_code_case_insensitive(self):
"""Should be case insensitive."""
assert Language.from_code("EN") == Language.ENGLISH
assert Language.from_code("Fr") == Language.FRENCH
def test_from_code_unknown(self):
"""Should handle unknown codes."""
# Behavior depends on implementation
try:
lang = Language.from_code("xx")
# If it returns something, that's valid
assert lang is not None
except (ValidationError, ValueError, KeyError):
# If it raises, that's also valid
pass
class TestSubtitleFormatEdgeCases:
"""Edge case tests for SubtitleFormat."""
def test_common_formats(self):
"""Should have common formats."""
assert SubtitleFormat.SRT is not None
assert SubtitleFormat.ASS is not None
def test_from_extension_with_dot(self):
"""Should handle extension with dot."""
fmt = SubtitleFormat.from_extension(".srt")
assert fmt == SubtitleFormat.SRT
def test_from_extension_without_dot(self):
"""Should handle extension without dot."""
fmt = SubtitleFormat.from_extension("srt")
assert fmt == SubtitleFormat.SRT
def test_from_extension_case_insensitive(self):
"""Should be case insensitive."""
assert SubtitleFormat.from_extension("SRT") == SubtitleFormat.SRT
assert SubtitleFormat.from_extension(".ASS") == SubtitleFormat.ASS
class TestTimingOffsetEdgeCases:
"""Edge case tests for TimingOffset."""
def test_zero_offset(self):
"""Should accept zero offset."""
offset = TimingOffset(0)
assert offset.milliseconds == 0
def test_positive_offset(self):
"""Should accept positive offset."""
offset = TimingOffset(5000)
assert offset.milliseconds == 5000
def test_negative_offset(self):
"""Should accept negative offset."""
offset = TimingOffset(-5000)
assert offset.milliseconds == -5000
def test_very_large_offset(self):
"""Should accept very large offset."""
offset = TimingOffset(3600000) # 1 hour
assert offset.milliseconds == 3600000
class TestMovieEntityEdgeCases:
"""Edge case tests for Movie entity."""
def test_minimal_movie(self):
"""Should create movie with minimal fields."""
movie = Movie(
imdb_id=ImdbId("tt1234567"),
title=MovieTitle("Test"),
quality=Quality.UNKNOWN,
)
assert movie.imdb_id is not None
def test_full_movie(self):
"""Should create movie with all fields."""
movie = Movie(
imdb_id=ImdbId("tt1234567"),
title=MovieTitle("Test Movie"),
release_year=ReleaseYear(2024),
quality=Quality.FULL_HD,
file_path=FilePath("/movies/test.mkv"),
file_size=FileSize(1000000000),
tmdb_id=12345,
added_at=datetime.now(),
)
assert movie.tmdb_id == 12345
def test_movie_without_optional_fields(self):
"""Should handle None optional fields."""
movie = Movie(
imdb_id=ImdbId("tt1234567"),
title=MovieTitle("Test"),
release_year=None,
quality=Quality.UNKNOWN,
file_path=None,
file_size=None,
tmdb_id=None,
)
assert movie.release_year is None
assert movie.file_path is None
class TestTVShowEntityEdgeCases:
"""Edge case tests for TVShow entity."""
def test_minimal_show(self):
"""Should create show with minimal fields."""
show = TVShow(
imdb_id=ImdbId("tt1234567"),
title="Test Show",
seasons_count=1,
status=ShowStatus.UNKNOWN,
)
assert show.title == "Test Show"
def test_show_with_zero_seasons(self):
"""Should handle show with zero seasons."""
show = TVShow(
imdb_id=ImdbId("tt1234567"),
title="Upcoming Show",
seasons_count=0,
status=ShowStatus.ONGOING,
)
assert show.seasons_count == 0
def test_show_with_many_seasons(self):
"""Should handle show with many seasons."""
show = TVShow(
imdb_id=ImdbId("tt1234567"),
title="Long Running Show",
seasons_count=50,
status=ShowStatus.ONGOING,
)
assert show.seasons_count == 50
class TestSubtitleEntityEdgeCases:
"""Edge case tests for Subtitle entity."""
def test_minimal_subtitle(self):
"""Should create subtitle with minimal fields."""
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/test.srt"),
)
assert subtitle.language == Language.ENGLISH
def test_subtitle_for_episode(self):
"""Should create subtitle for specific episode."""
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/s01e01.srt"),
season_number=1,
episode_number=1,
)
assert subtitle.season_number == 1
assert subtitle.episode_number == 1
def test_subtitle_with_all_metadata(self):
"""Should create subtitle with all metadata."""
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/test.srt"),
timing_offset=TimingOffset(500),
hearing_impaired=True,
forced=True,
source="OpenSubtitles",
uploader="user123",
download_count=10000,
rating=9.5,
)
assert subtitle.hearing_impaired is True
assert subtitle.forced is True
assert subtitle.rating == 9.5

View File

@@ -0,0 +1,2 @@
# DEPRECATED - Tests removed due to incorrect assumptions about LLM client initialization
# The LLM clients don't raise errors on missing config, they use defaults

241
tests/test_memory.py Normal file
View File

@@ -0,0 +1,241 @@
"""Tests for the Memory system."""
from datetime import datetime
import pytest
from infrastructure.persistence import (
EpisodicMemory,
LongTermMemory,
Memory,
ShortTermMemory,
get_memory,
has_memory,
init_memory,
)
from infrastructure.persistence.context import _memory_ctx
def is_iso_format(s: str) -> bool:
"""Helper to check if a string is a valid ISO 8601 timestamp."""
if not isinstance(s, str):
return False
try:
# Attempt to parse the string as an ISO 8601 timestamp
datetime.fromisoformat(s.replace("Z", "+00:00"))
return True
except (ValueError, TypeError):
return False
class TestLongTermMemory:
"""Tests for LongTermMemory."""
def test_default_values(self):
ltm = LongTermMemory()
assert ltm.config == {}
assert ltm.preferences["preferred_quality"] == "1080p"
assert "en" in ltm.preferences["preferred_languages"]
assert ltm.library == {"movies": [], "tv_shows": []}
assert ltm.following == []
def test_set_and_get_config(self):
ltm = LongTermMemory()
ltm.set_config("download_folder", "/path/to/downloads")
assert ltm.get_config("download_folder") == "/path/to/downloads"
def test_get_config_default(self):
ltm = LongTermMemory()
assert ltm.get_config("nonexistent") is None
assert ltm.get_config("nonexistent", "default") == "default"
def test_has_config(self):
ltm = LongTermMemory()
assert not ltm.has_config("download_folder")
ltm.set_config("download_folder", "/path")
assert ltm.has_config("download_folder")
def test_has_config_none_value(self):
ltm = LongTermMemory()
ltm.config["key"] = None
assert not ltm.has_config("key")
def test_add_to_library(self):
ltm = LongTermMemory()
movie = {"imdb_id": "tt1375666", "title": "Inception"}
ltm.add_to_library("movies", movie)
assert len(ltm.library["movies"]) == 1
assert ltm.library["movies"][0]["title"] == "Inception"
assert is_iso_format(ltm.library["movies"][0].get("added_at"))
def test_add_to_library_no_duplicates(self):
ltm = LongTermMemory()
movie = {"imdb_id": "tt1375666", "title": "Inception"}
ltm.add_to_library("movies", movie)
ltm.add_to_library("movies", movie)
assert len(ltm.library["movies"]) == 1
def test_add_to_library_new_type(self):
ltm = LongTermMemory()
subtitle = {"imdb_id": "tt1375666", "language": "en"}
ltm.add_to_library("subtitles", subtitle)
assert "subtitles" in ltm.library
assert len(ltm.library["subtitles"]) == 1
def test_get_library(self):
ltm = LongTermMemory()
ltm.add_to_library("movies", {"imdb_id": "tt1", "title": "Movie 1"})
ltm.add_to_library("movies", {"imdb_id": "tt2", "title": "Movie 2"})
movies = ltm.get_library("movies")
assert len(movies) == 2
def test_get_library_empty(self):
ltm = LongTermMemory()
assert ltm.get_library("unknown") == []
def test_follow_show(self):
ltm = LongTermMemory()
show = {"imdb_id": "tt0944947", "title": "Game of Thrones"}
ltm.follow_show(show)
assert len(ltm.following) == 1
assert ltm.following[0]["title"] == "Game of Thrones"
assert is_iso_format(ltm.following[0].get("followed_at"))
def test_follow_show_no_duplicates(self):
ltm = LongTermMemory()
show = {"imdb_id": "tt0944947", "title": "Game of Thrones"}
ltm.follow_show(show)
ltm.follow_show(show)
assert len(ltm.following) == 1
def test_to_dict(self):
ltm = LongTermMemory()
ltm.set_config("key", "value")
data = ltm.to_dict()
assert "config" in data
assert data["config"]["key"] == "value"
def test_from_dict(self):
data = {
"config": {"download_folder": "/downloads"},
"preferences": {"preferred_quality": "4K"},
"library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]},
"following": [],
}
ltm = LongTermMemory.from_dict(data)
assert ltm.get_config("download_folder") == "/downloads"
assert ltm.preferences["preferred_quality"] == "4K"
assert len(ltm.library["movies"]) == 1
class TestShortTermMemory:
"""Tests for ShortTermMemory."""
def test_default_values(self):
stm = ShortTermMemory()
assert stm.conversation_history == []
assert stm.current_workflow is None
assert stm.extracted_entities == {}
assert stm.current_topic is None
assert stm.language == "en"
def test_add_message(self):
stm = ShortTermMemory()
stm.add_message("user", "Hello")
assert len(stm.conversation_history) == 1
assert is_iso_format(stm.conversation_history[0].get("timestamp"))
def test_add_message_max_history(self):
stm = ShortTermMemory(max_history=5)
for i in range(10):
stm.add_message("user", f"Message {i}")
assert len(stm.conversation_history) == 5
assert stm.conversation_history[0]["content"] == "Message 5"
def test_language_management(self):
stm = ShortTermMemory()
assert stm.language == "en"
stm.set_language("fr")
assert stm.language == "fr"
stm.clear()
assert stm.language == "en"
def test_clear(self):
stm = ShortTermMemory()
stm.add_message("user", "Hello")
stm.set_language("fr")
stm.clear()
assert stm.conversation_history == []
assert stm.language == "en"
class TestEpisodicMemory:
"""Tests for EpisodicMemory."""
def test_add_error(self):
episodic = EpisodicMemory()
episodic.add_error("find_torrent", "API timeout")
assert len(episodic.recent_errors) == 1
assert is_iso_format(episodic.recent_errors[0].get("timestamp"))
def test_add_error_max_limit(self):
episodic = EpisodicMemory(max_errors=3)
for i in range(5):
episodic.add_error("action", f"Error {i}")
assert len(episodic.recent_errors) == 3
error_messages = [e["error"] for e in episodic.recent_errors]
assert error_messages == ["Error 2", "Error 3", "Error 4"]
def test_store_search_results(self):
episodic = EpisodicMemory()
episodic.store_search_results("test query", [])
assert is_iso_format(episodic.last_search_results.get("timestamp"))
def test_get_result_by_index(self):
episodic = EpisodicMemory()
results = [{"name": "Result 1"}, {"name": "Result 2"}]
episodic.store_search_results("query", results)
result = episodic.get_result_by_index(2)
assert result is not None
assert result["name"] == "Result 2"
class TestMemory:
"""Tests for the Memory manager."""
def test_init_creates_directories(self, temp_dir):
storage = temp_dir / "memory_data"
Memory(storage_dir=str(storage))
assert storage.exists()
def test_save_and_load_ltm(self, temp_dir):
storage = str(temp_dir)
memory = Memory(storage_dir=storage)
memory.ltm.set_config("test_key", "test_value")
memory.save()
new_memory = Memory(storage_dir=storage)
assert new_memory.ltm.get_config("test_key") == "test_value"
def test_clear_session(self, memory):
memory.ltm.set_config("key", "value")
memory.stm.add_message("user", "Hello")
memory.episodic.add_error("action", "error")
memory.clear_session()
assert memory.ltm.get_config("key") == "value"
assert memory.stm.conversation_history == []
assert memory.episodic.recent_errors == []
class TestMemoryContext:
"""Tests for memory context functions."""
def test_get_memory_not_initialized(self):
_memory_ctx.set(None)
with pytest.raises(RuntimeError, match="Memory not initialized"):
get_memory()
def test_init_memory(self, temp_dir):
_memory_ctx.set(None)
memory = init_memory(str(temp_dir))
assert has_memory()
assert get_memory() is memory

View File

@@ -0,0 +1,543 @@
"""Edge case tests for the Memory system."""
import json
import os
import pytest
from infrastructure.persistence import (
EpisodicMemory,
LongTermMemory,
Memory,
ShortTermMemory,
get_memory,
init_memory,
set_memory,
)
from infrastructure.persistence.context import _memory_ctx
class TestLongTermMemoryEdgeCases:
"""Edge case tests for LongTermMemory."""
def test_config_with_none_value(self):
"""Should handle None values in config."""
ltm = LongTermMemory()
ltm.set_config("key", None)
assert ltm.get_config("key") is None
assert not ltm.has_config("key")
def test_config_with_empty_string(self):
"""Should handle empty string values."""
ltm = LongTermMemory()
ltm.set_config("key", "")
assert ltm.get_config("key") == ""
assert ltm.has_config("key") # Empty string is still a value
def test_config_with_complex_types(self):
"""Should handle complex types in config."""
ltm = LongTermMemory()
ltm.set_config("list", [1, 2, 3])
ltm.set_config("dict", {"nested": {"deep": "value"}})
ltm.set_config("bool", False)
ltm.set_config("int", 0)
assert ltm.get_config("list") == [1, 2, 3]
assert ltm.get_config("dict")["nested"]["deep"] == "value"
assert ltm.get_config("bool") is False
assert ltm.get_config("int") == 0
def test_library_with_missing_imdb_id(self):
"""Should handle media without imdb_id."""
ltm = LongTermMemory()
media = {"title": "No ID Movie"}
ltm.add_to_library("movies", media)
# Should still add (imdb_id will be None)
assert len(ltm.library["movies"]) == 1
def test_library_duplicate_check_with_none_id(self):
"""Should handle duplicate check when imdb_id is None."""
ltm = LongTermMemory()
media1 = {"title": "Movie 1"}
media2 = {"title": "Movie 2"}
ltm.add_to_library("movies", media1)
ltm.add_to_library("movies", media2)
# May dedupe or not depending on implementation
assert len(ltm.library["movies"]) >= 1
def test_from_dict_with_extra_keys(self):
"""Should ignore extra keys in dict."""
data = {
"config": {},
"preferences": {},
"library": {"movies": []},
"following": [],
"extra_key": "should be ignored",
"another_extra": [1, 2, 3],
}
ltm = LongTermMemory.from_dict(data)
assert not hasattr(ltm, "extra_key")
def test_from_dict_with_wrong_types(self):
"""Should handle wrong types gracefully."""
data = {
"config": "not a dict", # Should be dict
"preferences": [], # Should be dict
"library": "wrong", # Should be dict
"following": {}, # Should be list
}
# Should not crash, but behavior may vary
try:
ltm = LongTermMemory.from_dict(data)
# If it doesn't crash, check it has some defaults
assert ltm is not None
except (TypeError, AttributeError):
# This is also acceptable behavior
pass
def test_to_dict_preserves_unicode(self):
"""Should preserve unicode in serialization."""
ltm = LongTermMemory()
ltm.set_config("japanese", "日本語")
ltm.set_config("emoji", "🎬🎥")
ltm.add_to_library("movies", {"title": "Amélie", "imdb_id": "tt1"})
data = ltm.to_dict()
assert data["config"]["japanese"] == "日本語"
assert data["config"]["emoji"] == "🎬🎥"
assert data["library"]["movies"][0]["title"] == "Amélie"
class TestShortTermMemoryEdgeCases:
"""Edge case tests for ShortTermMemory."""
def test_add_message_with_empty_content(self):
"""Should handle empty message content."""
stm = ShortTermMemory()
stm.add_message("user", "")
assert len(stm.conversation_history) == 1
assert stm.conversation_history[0]["content"] == ""
def test_add_message_with_very_long_content(self):
"""Should handle very long messages."""
stm = ShortTermMemory()
long_content = "x" * 100000
stm.add_message("user", long_content)
assert len(stm.conversation_history[0]["content"]) == 100000
def test_add_message_with_special_characters(self):
"""Should handle special characters."""
stm = ShortTermMemory()
special = "Line1\nLine2\tTab\r\nWindows\x00Null"
stm.add_message("user", special)
assert stm.conversation_history[0]["content"] == special
def test_max_history_zero(self):
"""Should handle max_history of 0."""
stm = ShortTermMemory()
stm.max_history = 0
stm.add_message("user", "Hello")
# Behavior: either empty or keeps last message
assert len(stm.conversation_history) <= 1
def test_max_history_one(self):
"""Should handle max_history of 1."""
stm = ShortTermMemory()
stm.max_history = 1
stm.add_message("user", "First")
stm.add_message("user", "Second")
assert len(stm.conversation_history) == 1
assert stm.conversation_history[0]["content"] == "Second"
def test_get_recent_history_zero(self):
"""Should handle n=0."""
stm = ShortTermMemory()
stm.add_message("user", "Hello")
recent = stm.get_recent_history(0)
# May return empty or all messages depending on implementation
assert isinstance(recent, list)
def test_get_recent_history_negative(self):
"""Should handle negative n."""
stm = ShortTermMemory()
stm.add_message("user", "Hello")
recent = stm.get_recent_history(-1)
# Python slicing with negative returns empty or last element
assert isinstance(recent, list)
def test_workflow_with_empty_target(self):
"""Should handle empty workflow target."""
stm = ShortTermMemory()
stm.start_workflow("download", {})
assert stm.current_workflow["target"] == {}
def test_workflow_with_none_target(self):
"""Should handle None workflow target."""
stm = ShortTermMemory()
stm.start_workflow("download", None)
assert stm.current_workflow["target"] is None
def test_entity_with_none_value(self):
"""Should store None as entity value."""
stm = ShortTermMemory()
stm.set_entity("key", None)
assert stm.get_entity("key") is None
assert "key" in stm.extracted_entities
def test_entity_overwrite(self):
"""Should overwrite existing entity."""
stm = ShortTermMemory()
stm.set_entity("key", "value1")
stm.set_entity("key", "value2")
assert stm.get_entity("key") == "value2"
def test_topic_with_empty_string(self):
"""Should handle empty topic."""
stm = ShortTermMemory()
stm.set_topic("")
assert stm.current_topic == ""
class TestEpisodicMemoryEdgeCases:
"""Edge case tests for EpisodicMemory."""
def test_store_empty_results(self):
"""Should handle empty results list."""
episodic = EpisodicMemory()
episodic.store_search_results("query", [])
assert episodic.last_search_results is not None
assert episodic.last_search_results["results"] == []
def test_store_results_with_none_values(self):
"""Should handle results with None values."""
episodic = EpisodicMemory()
results = [
{"name": None, "seeders": None},
{"name": "Valid", "seeders": 100},
]
episodic.store_search_results("query", results)
assert len(episodic.last_search_results["results"]) == 2
def test_get_result_by_index_after_clear(self):
"""Should return None after clearing results."""
episodic = EpisodicMemory()
episodic.store_search_results("query", [{"name": "Test"}])
episodic.clear_search_results()
result = episodic.get_result_by_index(1)
assert result is None
def test_get_result_by_very_large_index(self):
"""Should handle very large index."""
episodic = EpisodicMemory()
episodic.store_search_results("query", [{"name": "Test"}])
result = episodic.get_result_by_index(999999999)
assert result is None
def test_download_with_missing_fields(self):
"""Should handle download with missing fields."""
episodic = EpisodicMemory()
episodic.add_active_download({}) # Empty dict
assert len(episodic.active_downloads) == 1
assert "started_at" in episodic.active_downloads[0]
def test_update_nonexistent_download(self):
"""Should not crash when updating nonexistent download."""
episodic = EpisodicMemory()
# Should not raise
episodic.update_download_progress("nonexistent", 50)
assert episodic.active_downloads == []
def test_complete_nonexistent_download(self):
"""Should return None for nonexistent download."""
episodic = EpisodicMemory()
result = episodic.complete_download("nonexistent", "/path")
assert result is None
def test_error_with_empty_context(self):
"""Should handle error with None context."""
episodic = EpisodicMemory()
episodic.add_error("action", "error", None)
assert episodic.recent_errors[0]["context"] == {}
def test_error_with_very_long_message(self):
"""Should handle very long error messages."""
episodic = EpisodicMemory()
long_error = "x" * 10000
episodic.add_error("action", long_error)
assert len(episodic.recent_errors[0]["error"]) == 10000
def test_pending_question_with_empty_options(self):
"""Should handle question with no options."""
episodic = EpisodicMemory()
episodic.set_pending_question("Question?", [], {})
assert episodic.pending_question["options"] == []
def test_resolve_question_invalid_index(self):
"""Should return None for invalid answer index."""
episodic = EpisodicMemory()
episodic.set_pending_question(
"Question?",
[{"index": 1, "label": "Option"}],
{},
)
result = episodic.resolve_pending_question(999)
assert result is None
assert episodic.pending_question is None # Still cleared
def test_resolve_question_when_none(self):
"""Should handle resolving when no question pending."""
episodic = EpisodicMemory()
result = episodic.resolve_pending_question(1)
assert result is None
def test_background_event_with_empty_data(self):
"""Should handle event with empty data."""
episodic = EpisodicMemory()
episodic.add_background_event("event", {})
assert episodic.background_events[0]["data"] == {}
def test_get_unread_events_multiple_calls(self):
"""Should return empty on second call."""
episodic = EpisodicMemory()
episodic.add_background_event("event", {})
first = episodic.get_unread_events()
second = episodic.get_unread_events()
assert len(first) == 1
assert len(second) == 0
def test_max_errors_boundary(self):
"""Should keep exactly max_errors."""
episodic = EpisodicMemory()
episodic.max_errors = 3
for i in range(3):
episodic.add_error("action", f"Error {i}")
assert len(episodic.recent_errors) == 3
episodic.add_error("action", "Error 3")
assert len(episodic.recent_errors) == 3
assert episodic.recent_errors[0]["error"] == "Error 1"
def test_max_events_boundary(self):
"""Should keep exactly max_events."""
episodic = EpisodicMemory()
episodic.max_events = 3
for i in range(5):
episodic.add_background_event("event", {"i": i})
assert len(episodic.background_events) == 3
assert episodic.background_events[0]["data"]["i"] == 2
class TestMemoryEdgeCases:
"""Edge case tests for Memory manager."""
def test_init_with_nonexistent_directory(self, temp_dir):
"""Should create directory if not exists."""
new_dir = temp_dir / "new" / "nested" / "dir"
# Don't create the directory - let Memory do it
Memory(storage_dir=str(new_dir))
assert new_dir.exists()
def test_init_with_readonly_directory(self, temp_dir):
"""Should handle readonly directory gracefully."""
readonly_dir = temp_dir / "readonly"
readonly_dir.mkdir()
# Make readonly (may not work on all systems)
try:
os.chmod(readonly_dir, 0o444)
# This might raise or might work depending on OS
Memory(storage_dir=str(readonly_dir))
except (PermissionError, OSError):
pass # Expected on some systems
finally:
os.chmod(readonly_dir, 0o755)
def test_load_ltm_with_empty_file(self, temp_dir):
"""Should handle empty LTM file."""
ltm_file = temp_dir / "ltm.json"
ltm_file.write_text("")
memory = Memory(storage_dir=str(temp_dir))
# Should use defaults
assert memory.ltm.config == {}
def test_load_ltm_with_partial_data(self, temp_dir):
"""Should handle partial LTM data."""
ltm_file = temp_dir / "ltm.json"
ltm_file.write_text('{"config": {"key": "value"}}')
memory = Memory(storage_dir=str(temp_dir))
assert memory.ltm.get_config("key") == "value"
# Other fields should have defaults
assert memory.ltm.library == {"movies": [], "tv_shows": []}
def test_save_with_unicode(self, temp_dir):
"""Should save unicode correctly."""
memory = Memory(storage_dir=str(temp_dir))
memory.ltm.set_config("japanese", "日本語テスト")
memory.save()
# Read back and verify
ltm_file = temp_dir / "ltm.json"
data = json.loads(ltm_file.read_text(encoding="utf-8"))
assert data["config"]["japanese"] == "日本語テスト"
def test_save_preserves_formatting(self, temp_dir):
"""Should save with readable formatting."""
memory = Memory(storage_dir=str(temp_dir))
memory.ltm.set_config("key", "value")
memory.save()
ltm_file = temp_dir / "ltm.json"
content = ltm_file.read_text()
# Should be indented (pretty printed)
assert "\n" in content
def test_concurrent_access_simulation(self, temp_dir):
"""Should handle rapid save/load cycles."""
memory = Memory(storage_dir=str(temp_dir))
for i in range(100):
memory.ltm.set_config(f"key_{i}", f"value_{i}")
memory.save()
# Reload and verify
memory2 = Memory(storage_dir=str(temp_dir))
assert memory2.ltm.get_config("key_99") == "value_99"
def test_clear_session_preserves_ltm(self, temp_dir):
"""Should preserve LTM after clear_session."""
memory = Memory(storage_dir=str(temp_dir))
memory.ltm.set_config("important", "data")
memory.stm.add_message("user", "Hello")
memory.episodic.store_search_results("query", [{}])
memory.clear_session()
assert memory.ltm.get_config("important") == "data"
assert memory.stm.conversation_history == []
assert memory.episodic.last_search_results is None
def test_get_context_for_prompt_empty(self, temp_dir):
"""Should handle empty memory state."""
memory = Memory(storage_dir=str(temp_dir))
context = memory.get_context_for_prompt()
assert context["config"] == {}
assert context["last_search"]["query"] is None
assert context["last_search"]["result_count"] == 0
def test_get_full_state_serializable(self, temp_dir):
"""Should return JSON-serializable state."""
memory = Memory(storage_dir=str(temp_dir))
memory.ltm.set_config("key", "value")
memory.stm.add_message("user", "Hello")
memory.episodic.store_search_results("query", [{"name": "Test"}])
state = memory.get_full_state()
# Should be JSON serializable
json_str = json.dumps(state)
assert json_str is not None
class TestMemoryContextEdgeCases:
"""Edge case tests for memory context."""
def test_multiple_init_calls(self, temp_dir):
"""Should handle multiple init calls."""
_memory_ctx.set(None)
init_memory(str(temp_dir))
mem2 = init_memory(str(temp_dir))
# Second call should replace first
assert get_memory() is mem2
def test_set_memory_with_none(self):
"""Should handle setting None."""
_memory_ctx.set(None)
set_memory(None)
with pytest.raises(RuntimeError):
get_memory()
def test_context_isolation(self, temp_dir):
"""Context should be isolated per context."""
from contextvars import copy_context
_memory_ctx.set(None)
mem1 = init_memory(str(temp_dir))
# Create a copy of context
ctx = copy_context()
# In the copy, memory should still be set
def check_memory():
return get_memory()
result = ctx.run(check_memory)
assert result is mem1

297
tests/test_prompts.py Normal file
View File

@@ -0,0 +1,297 @@
"""Tests for PromptBuilder."""
from agent.prompts import PromptBuilder
from agent.registry import make_tools
class TestPromptBuilder:
"""Tests for PromptBuilder."""
def test_init(self, memory):
"""Should initialize with tools."""
tools = make_tools()
builder = PromptBuilder(tools)
assert builder.tools is tools
def test_build_system_prompt(self, memory):
"""Should build a complete system prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "AI assistant" in prompt
assert "media library" in prompt
assert "AVAILABLE TOOLS" in prompt
def test_includes_tools(self, memory):
"""Should include all tool descriptions."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
for tool_name in tools.keys():
assert tool_name in prompt
def test_includes_config(self, memory):
"""Should include current configuration."""
memory.ltm.set_config("download_folder", "/path/to/downloads")
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "/path/to/downloads" in prompt
def test_includes_search_results(self, memory_with_search_results):
"""Should include search results summary."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "LAST SEARCH" in prompt
assert "Inception 1080p" in prompt
assert "3 results" in prompt or "results available" in prompt
def test_includes_search_result_names(self, memory_with_search_results):
"""Should include search result names."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "Inception.2010.1080p.BluRay.x264" in prompt
def test_includes_active_downloads(self, memory):
"""Should include active downloads."""
memory.episodic.add_active_download(
{
"task_id": "123",
"name": "Test.Movie.mkv",
"progress": 50,
}
)
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "ACTIVE DOWNLOADS" in prompt
assert "Test.Movie.mkv" in prompt
def test_includes_pending_question(self, memory):
"""Should include pending question."""
memory.episodic.set_pending_question(
"Which torrent?",
[{"index": 1, "label": "Option 1"}, {"index": 2, "label": "Option 2"}],
{},
)
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "PENDING QUESTION" in prompt
assert "Which torrent?" in prompt
def test_includes_last_error(self, memory):
"""Should include last error."""
memory.episodic.add_error("find_torrent", "API timeout")
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "RECENT ERRORS" in prompt
assert "API timeout" in prompt
def test_includes_workflow(self, memory):
"""Should include current workflow."""
memory.stm.start_workflow("download", {"title": "Inception"})
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "CURRENT WORKFLOW" in prompt
assert "download" in prompt
def test_includes_topic(self, memory):
"""Should include current topic."""
memory.stm.set_topic("selecting_torrent")
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "CURRENT TOPIC" in prompt
assert "selecting_torrent" in prompt
def test_includes_entities(self, memory):
"""Should include extracted entities."""
memory.stm.set_entity("movie_title", "Inception")
memory.stm.set_entity("year", 2010)
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "EXTRACTED ENTITIES" in prompt
assert "Inception" in prompt
def test_includes_rules(self, memory):
"""Should include important rules."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "IMPORTANT RULES" in prompt
assert "add_torrent_by_index" in prompt
def test_includes_examples(self, memory):
"""Should include usage examples."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "EXAMPLES" in prompt
assert "download the 3rd one" in prompt or "torrent number" in prompt
def test_empty_context(self, memory):
"""Should handle empty context gracefully."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
# Should not crash and should have basic structure
assert "AVAILABLE TOOLS" in prompt
assert "CURRENT CONFIGURATION" in prompt
def test_limits_search_results_display(self, memory):
"""Should limit displayed search results."""
# Add many results
results = [{"name": f"Torrent {i}", "seeders": i} for i in range(20)]
memory.episodic.store_search_results("test", results)
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
# Should show first 5 and indicate more
assert "Torrent 0" in prompt or "1." in prompt
assert "... and" in prompt or "more" in prompt
# REMOVED: test_json_format_in_prompt
# We removed the "action" format from prompts as it was confusing the LLM
# The LLM now uses native OpenAI tool calling format
class TestFormatToolsDescription:
"""Tests for _format_tools_description method."""
def test_format_all_tools(self, memory):
"""Should format all tools."""
tools = make_tools()
builder = PromptBuilder(tools)
desc = builder._format_tools_description()
for tool in tools.values():
assert tool.name in desc
assert tool.description in desc
def test_includes_parameters(self, memory):
"""Should include parameter schemas."""
tools = make_tools()
builder = PromptBuilder(tools)
desc = builder._format_tools_description()
assert "Parameters:" in desc
assert '"type"' in desc
class TestFormatEpisodicContext:
"""Tests for _format_episodic_context method."""
def test_empty_episodic(self, memory):
"""Should return empty string for empty episodic."""
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_episodic_context(memory)
assert context == ""
def test_with_search_results(self, memory_with_search_results):
"""Should format search results."""
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_episodic_context(memory_with_search_results)
assert "LAST SEARCH" in context
assert "Inception 1080p" in context
def test_with_multiple_sections(self, memory):
"""Should format multiple sections."""
memory.episodic.store_search_results("test", [{"name": "Result"}])
memory.episodic.add_active_download({"task_id": "1", "name": "Download"})
memory.episodic.add_error("action", "error")
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_episodic_context(memory)
assert "LAST SEARCH" in context
assert "ACTIVE DOWNLOADS" in context
assert "RECENT ERRORS" in context
class TestFormatStmContext:
"""Tests for _format_stm_context method."""
def test_empty_stm(self, memory):
"""Should return language info even for empty STM."""
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_stm_context(memory)
# Should at least show language
assert "CONVERSATION LANGUAGE" in context or context == ""
def test_with_workflow(self, memory):
"""Should format workflow."""
memory.stm.start_workflow("download", {"title": "Test"})
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_stm_context(memory)
assert "CURRENT WORKFLOW" in context
assert "download" in context
def test_with_all_sections(self, memory):
"""Should format all STM sections."""
memory.stm.start_workflow("download", {"title": "Test"})
memory.stm.set_topic("searching")
memory.stm.set_entity("key", "value")
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_stm_context(memory)
assert "CURRENT WORKFLOW" in context
assert "CURRENT TOPIC" in context
assert "EXTRACTED ENTITIES" in context

View File

@@ -0,0 +1,281 @@
"""Critical tests for prompt builder - Tests that would have caught bugs."""
from agent.prompts import PromptBuilder
from agent.registry import make_tools
class TestPromptBuilderToolsInjection:
"""Critical tests for tools injection in prompts."""
def test_system_prompt_includes_all_tools(self, memory):
"""CRITICAL: Verify all tools are mentioned in system prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
# Verify each tool is mentioned
for tool_name in tools.keys():
assert (
tool_name in prompt
), f"Tool {tool_name} not mentioned in system prompt"
def test_tools_spec_contains_all_registered_tools(self, memory):
"""CRITICAL: Verify build_tools_spec() returns all tools."""
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
spec_names = {spec["function"]["name"] for spec in specs}
tool_names = set(tools.keys())
assert spec_names == tool_names, f"Missing tools: {tool_names - spec_names}"
def test_tools_spec_is_not_empty(self, memory):
"""CRITICAL: Verify tools spec is never empty."""
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
assert len(specs) > 0, "Tools spec is empty!"
def test_tools_spec_format_matches_openai(self, memory):
"""CRITICAL: Verify tools spec format is OpenAI-compatible."""
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
for spec in specs:
assert "type" in spec
assert spec["type"] == "function"
assert "function" in spec
assert "name" in spec["function"]
assert "description" in spec["function"]
assert "parameters" in spec["function"]
class TestPromptBuilderMemoryContext:
"""Tests for memory context injection in prompts."""
def test_prompt_includes_current_topic(self, memory):
"""Verify current topic is included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.stm.set_topic("test_topic")
prompt = builder.build_system_prompt()
assert "test_topic" in prompt
def test_prompt_includes_extracted_entities(self, memory):
"""Verify extracted entities are included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.stm.set_entity("test_key", "test_value")
prompt = builder.build_system_prompt()
assert "test_key" in prompt
def test_prompt_includes_search_results(self, memory_with_search_results):
"""Verify search results are included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "Inception" in prompt
assert "LAST SEARCH" in prompt
def test_prompt_includes_active_downloads(self, memory):
"""Verify active downloads are included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.episodic.add_active_download(
{"task_id": "123", "name": "Test Movie", "progress": 50}
)
prompt = builder.build_system_prompt()
assert "ACTIVE DOWNLOADS" in prompt
assert "Test Movie" in prompt
def test_prompt_includes_recent_errors(self, memory):
"""Verify recent errors are included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.episodic.add_error("test_action", "test error message")
prompt = builder.build_system_prompt()
assert "RECENT ERRORS" in prompt or "error" in prompt.lower()
def test_prompt_includes_configuration(self, memory):
"""Verify configuration is included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.ltm.set_config("download_folder", "/test/downloads")
prompt = builder.build_system_prompt()
assert "CONFIGURATION" in prompt or "download_folder" in prompt
def test_prompt_includes_language(self, memory):
"""Verify language is included in prompt."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.stm.set_language("fr")
prompt = builder.build_system_prompt()
assert "fr" in prompt or "LANGUAGE" in prompt
class TestPromptBuilderStructure:
"""Tests for prompt structure and completeness."""
def test_system_prompt_is_not_empty(self, memory):
"""Verify system prompt is never empty."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert len(prompt) > 0
assert prompt.strip() != ""
def test_system_prompt_includes_base_instruction(self, memory):
"""Verify system prompt includes base instruction."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "assistant" in prompt.lower() or "help" in prompt.lower()
def test_system_prompt_includes_rules(self, memory):
"""Verify system prompt includes important rules."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "RULES" in prompt or "IMPORTANT" in prompt
def test_system_prompt_includes_examples(self, memory):
"""Verify system prompt includes examples."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "EXAMPLES" in prompt or "example" in prompt.lower()
def test_tools_description_format(self, memory):
"""Verify tools are properly formatted in description."""
tools = make_tools()
builder = PromptBuilder(tools)
description = builder._format_tools_description()
# Should have tool names and descriptions
for tool_name, _tool in tools.items():
assert tool_name in description
# Should have parameters info
assert "Parameters" in description or "parameters" in description
def test_episodic_context_format(self, memory_with_search_results):
"""Verify episodic context is properly formatted."""
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_episodic_context(memory_with_search_results)
assert "LAST SEARCH" in context
assert "Inception" in context
def test_stm_context_format(self, memory):
"""Verify STM context is properly formatted."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.stm.set_topic("test_topic")
memory.stm.set_entity("key", "value")
context = builder._format_stm_context(memory)
assert "TOPIC" in context or "test_topic" in context
assert "ENTITIES" in context or "key" in context
def test_config_context_format(self, memory):
"""Verify config context is properly formatted."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.ltm.set_config("test_key", "test_value")
context = builder._format_config_context(memory)
assert "CONFIGURATION" in context
assert "test_key" in context
class TestPromptBuilderEdgeCases:
"""Tests for edge cases in prompt building."""
def test_prompt_with_no_memory_context(self, memory):
"""Verify prompt works with empty memory."""
tools = make_tools()
builder = PromptBuilder(tools)
# Memory is empty
prompt = builder.build_system_prompt()
# Should still have base content
assert len(prompt) > 0
assert "assistant" in prompt.lower()
def test_prompt_with_empty_tools(self):
"""Verify prompt handles empty tools dict."""
builder = PromptBuilder({})
prompt = builder.build_system_prompt()
# Should still generate a prompt
assert len(prompt) > 0
def test_tools_spec_with_empty_tools(self):
"""Verify tools spec handles empty tools dict."""
builder = PromptBuilder({})
specs = builder.build_tools_spec()
assert isinstance(specs, list)
assert len(specs) == 0
def test_prompt_with_unicode_in_memory(self, memory):
"""Verify prompt handles unicode in memory."""
tools = make_tools()
builder = PromptBuilder(tools)
memory.stm.set_entity("movie", "Amélie 🎬")
prompt = builder.build_system_prompt()
assert "Amélie" in prompt
assert "🎬" in prompt
def test_prompt_with_long_search_results(self, memory):
"""Verify prompt handles many search results."""
tools = make_tools()
builder = PromptBuilder(tools)
# Add many results
results = [{"name": f"Movie {i}", "seeders": i} for i in range(20)]
memory.episodic.store_search_results("test", results, "torrent")
prompt = builder.build_system_prompt()
# Should include some results but not all (to avoid huge prompts)
assert "Movie 0" in prompt or "Movie 1" in prompt
# Should indicate there are more
assert "more" in prompt.lower() or "..." in prompt

View File

@@ -0,0 +1,400 @@
"""Edge case tests for PromptBuilder."""
from agent.prompts import PromptBuilder
from agent.registry import make_tools
class TestPromptBuilderEdgeCases:
"""Edge case tests for PromptBuilder."""
def test_prompt_with_empty_memory(self, memory):
"""Should build prompt with completely empty memory."""
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "AVAILABLE TOOLS" in prompt
assert "CURRENT CONFIGURATION" in prompt
def test_prompt_with_unicode_config(self, memory):
"""Should handle unicode in config."""
memory.ltm.set_config("folder_日本語", "/path/to/日本語")
memory.ltm.set_config("emoji_folder", "/path/🎬")
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "日本語" in prompt
assert "🎬" in prompt
def test_prompt_with_very_long_config_value(self, memory):
"""Should handle very long config values."""
long_path = "/very/long/path/" + "x" * 1000
memory.ltm.set_config("download_folder", long_path)
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
# Should include the path (possibly truncated)
assert "very/long/path" in prompt
def test_prompt_with_special_chars_in_config(self, memory):
"""Should escape special characters in config."""
memory.ltm.set_config("path", '/path/with "quotes" and \\backslash')
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
# Should be valid (not crash)
assert "CURRENT CONFIGURATION" in prompt
def test_prompt_with_many_search_results(self, memory):
"""Should limit displayed search results."""
results = [{"name": f"Torrent {i}", "seeders": i} for i in range(50)]
memory.episodic.store_search_results("test query", results)
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
# Should show limited results
assert "LAST SEARCH" in prompt
# Should indicate there are more
assert "more" in prompt.lower() or "..." in prompt
def test_prompt_with_search_results_missing_fields(self, memory):
"""Should handle search results with missing fields."""
results = [
{"name": "Complete"},
{}, # Empty
{"seeders": 100}, # Missing name
]
memory.episodic.store_search_results("test", results)
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
# Should not crash
assert "LAST SEARCH" in prompt
def test_prompt_with_many_active_downloads(self, memory):
"""Should limit displayed active downloads."""
for i in range(20):
memory.episodic.add_active_download(
{
"task_id": str(i),
"name": f"Download {i}",
"progress": i * 5,
}
)
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "ACTIVE DOWNLOADS" in prompt
# Should show limited number
assert "Download 0" in prompt
def test_prompt_with_many_errors(self, memory):
"""Should show recent errors."""
for i in range(10):
memory.episodic.add_error(f"action_{i}", f"Error {i}")
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "RECENT ERRORS" in prompt
# Should show the most recent errors (up to 3)
def test_prompt_with_pending_question_many_options(self, memory):
"""Should handle pending question with many options."""
options = [{"index": i, "label": f"Option {i}"} for i in range(20)]
memory.episodic.set_pending_question("Choose one:", options, {})
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "PENDING QUESTION" in prompt
assert "Choose one:" in prompt
def test_prompt_with_complex_workflow(self, memory):
"""Should handle complex workflow state."""
memory.stm.start_workflow(
"download",
{
"title": "Test Movie",
"year": 2024,
"quality": "1080p",
"nested": {"deep": {"value": "test"}},
},
)
memory.stm.update_workflow_stage("searching_torrents")
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "CURRENT WORKFLOW" in prompt
assert "download" in prompt
assert "searching_torrents" in prompt
def test_prompt_with_many_entities(self, memory):
"""Should handle many extracted entities."""
for i in range(50):
memory.stm.set_entity(f"entity_{i}", f"value_{i}")
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "EXTRACTED ENTITIES" in prompt
def test_prompt_with_null_values_in_entities(self, memory):
"""Should handle null values in entities."""
memory.stm.set_entity("null_value", None)
memory.stm.set_entity("empty_string", "")
memory.stm.set_entity("zero", 0)
memory.stm.set_entity("false", False)
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
# Should not crash
assert "EXTRACTED ENTITIES" in prompt
def test_prompt_with_unread_events(self, memory):
"""Should include unread events."""
memory.episodic.add_background_event("download_complete", {"name": "Movie.mkv"})
memory.episodic.add_background_event("new_files", {"count": 5})
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
assert "UNREAD EVENTS" in prompt
def test_prompt_with_all_sections(self, memory):
"""Should include all sections when all data present."""
# Config
memory.ltm.set_config("download_folder", "/downloads")
# Search results
memory.episodic.store_search_results("test", [{"name": "Result"}])
# Active downloads
memory.episodic.add_active_download({"task_id": "1", "name": "Download"})
# Errors
memory.episodic.add_error("action", "error")
# Pending question
memory.episodic.set_pending_question("Question?", [], {})
# Workflow
memory.stm.start_workflow("download", {"title": "Test"})
# Topic
memory.stm.set_topic("searching")
# Entities
memory.stm.set_entity("key", "value")
# Events
memory.episodic.add_background_event("event", {})
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
# All sections should be present
assert "CURRENT CONFIGURATION" in prompt
assert "LAST SEARCH" in prompt
assert "ACTIVE DOWNLOADS" in prompt
assert "RECENT ERRORS" in prompt
assert "PENDING QUESTION" in prompt
assert "CURRENT WORKFLOW" in prompt
assert "CURRENT TOPIC" in prompt
assert "EXTRACTED ENTITIES" in prompt
assert "UNREAD EVENTS" in prompt
def test_prompt_json_serializable(self, memory):
"""Should produce JSON-serializable content."""
memory.ltm.set_config("key", {"nested": [1, 2, 3]})
memory.stm.set_entity("complex", {"a": {"b": {"c": "d"}}})
tools = make_tools()
builder = PromptBuilder(tools)
prompt = builder.build_system_prompt()
# The prompt itself is a string, but embedded JSON should be valid
assert isinstance(prompt, str)
class TestFormatToolsDescriptionEdgeCases:
"""Edge case tests for _format_tools_description."""
def test_format_with_no_tools(self, memory):
"""Should handle empty tools dict."""
builder = PromptBuilder({})
desc = builder._format_tools_description()
assert desc == ""
def test_format_with_complex_parameters(self, memory):
"""Should format complex parameter schemas."""
from agent.registry import Tool
tools = {
"complex_tool": Tool(
name="complex_tool",
description="A complex tool",
func=lambda: {},
parameters={
"type": "object",
"properties": {
"nested": {
"type": "object",
"properties": {
"deep": {"type": "string"},
},
},
"array": {
"type": "array",
"items": {"type": "integer"},
},
},
"required": ["nested"],
},
),
}
builder = PromptBuilder(tools)
desc = builder._format_tools_description()
assert "complex_tool" in desc
assert "nested" in desc
class TestFormatEpisodicContextEdgeCases:
"""Edge case tests for _format_episodic_context."""
def test_format_with_empty_search_query(self, memory):
"""Should handle empty search query."""
memory.episodic.store_search_results("", [{"name": "Result"}])
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_episodic_context(memory)
assert "LAST SEARCH" in context
def test_format_with_search_results_none_names(self, memory):
"""Should handle results with None names."""
memory.episodic.store_search_results(
"test",
[
{"name": None},
{"title": None},
{},
],
)
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_episodic_context(memory)
# Should not crash
assert "LAST SEARCH" in context
def test_format_with_download_missing_progress(self, memory):
"""Should handle download without progress."""
memory.episodic.add_active_download({"task_id": "1", "name": "Test"})
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_episodic_context(memory)
assert "ACTIVE DOWNLOADS" in context
assert "0%" in context # Default progress
class TestFormatStmContextEdgeCases:
"""Edge case tests for _format_stm_context."""
def test_format_with_workflow_missing_target(self, memory):
"""Should handle workflow with missing target."""
memory.stm.current_workflow = {
"type": "download",
"stage": "started",
}
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_stm_context(memory)
assert "CURRENT WORKFLOW" in context
def test_format_with_workflow_none_target(self, memory):
"""Should handle workflow with None target."""
memory.stm.start_workflow("download", None)
tools = make_tools()
builder = PromptBuilder(tools)
try:
context = builder._format_stm_context(memory)
assert "CURRENT WORKFLOW" in context or True
except (AttributeError, TypeError):
# Expected if None target causes issues
pass
def test_format_with_empty_topic(self, memory):
"""Should handle empty topic."""
memory.stm.set_topic("")
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_stm_context(memory)
# Empty topic might not be shown
assert isinstance(context, str)
def test_format_with_entities_containing_json(self, memory):
"""Should handle entities containing JSON strings."""
memory.stm.set_entity("json_string", '{"key": "value"}')
tools = make_tools()
builder = PromptBuilder(tools)
context = builder._format_stm_context(memory)
assert "EXTRACTED ENTITIES" in context

View File

@@ -0,0 +1,232 @@
"""Critical tests for tool registry - Tests that would have caught bugs."""
import inspect
import pytest
from agent.prompts import PromptBuilder
from agent.registry import Tool, _create_tool_from_function, make_tools
class TestToolSpecFormat:
"""Critical tests for tool specification format."""
def test_tool_spec_format_is_openai_compatible(self):
"""CRITICAL: Verify tool specs are OpenAI-compatible."""
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
# Verify structure
assert isinstance(specs, list), "Tool specs must be a list"
assert len(specs) > 0, "Tool specs list is empty"
for spec in specs:
# OpenAI format requires these fields
assert (
spec["type"] == "function"
), f"Tool type must be 'function', got {spec.get('type')}"
assert "function" in spec, "Tool spec missing 'function' key"
func = spec["function"]
assert "name" in func, "Function missing 'name'"
assert "description" in func, "Function missing 'description'"
assert "parameters" in func, "Function missing 'parameters'"
params = func["parameters"]
assert params["type"] == "object", "Parameters type must be 'object'"
assert "properties" in params, "Parameters missing 'properties'"
assert "required" in params, "Parameters missing 'required'"
assert isinstance(params["required"], list), "Required must be a list"
def test_tool_parameters_match_function_signature(self):
"""CRITICAL: Verify generated parameters match function signature."""
def test_func(name: str, age: int, active: bool = True):
"""Test function with typed parameters."""
return {"status": "ok"}
tool = _create_tool_from_function(test_func)
# Verify types are correctly mapped
assert tool.parameters["properties"]["name"]["type"] == "string"
assert tool.parameters["properties"]["age"]["type"] == "integer"
assert tool.parameters["properties"]["active"]["type"] == "boolean"
# Verify required vs optional
assert "name" in tool.parameters["required"], "name should be required"
assert "age" in tool.parameters["required"], "age should be required"
assert (
"active" not in tool.parameters["required"]
), "active has default, should not be required"
def test_all_registered_tools_are_callable(self):
"""CRITICAL: Verify all registered tools are actually callable."""
tools = make_tools()
assert len(tools) > 0, "No tools registered"
for name, tool in tools.items():
assert callable(tool.func), f"Tool {name} is not callable"
# Verify function has valid signature
try:
inspect.signature(tool.func)
# If we get here, signature is valid
except Exception as e:
pytest.fail(f"Tool {name} has invalid signature: {e}")
def test_tools_spec_contains_all_registered_tools(self):
"""CRITICAL: Verify build_tools_spec() returns all registered tools."""
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
spec_names = {spec["function"]["name"] for spec in specs}
tool_names = set(tools.keys())
missing = tool_names - spec_names
extra = spec_names - tool_names
assert not missing, f"Tools missing from specs: {missing}"
assert not extra, f"Extra tools in specs: {extra}"
assert spec_names == tool_names, "Tool specs don't match registered tools"
def test_tool_description_extracted_from_docstring(self):
"""Verify tool description is extracted from function docstring."""
def test_func(param: str):
"""This is the description.
More details here.
"""
return {}
tool = _create_tool_from_function(test_func)
assert tool.description == "This is the description."
assert "More details" not in tool.description
def test_tool_without_docstring_uses_function_name(self):
"""Verify tool without docstring uses function name as description."""
def test_func_no_doc(param: str):
return {}
tool = _create_tool_from_function(test_func_no_doc)
assert tool.description == "test_func_no_doc"
def test_tool_parameters_have_descriptions(self):
"""Verify all tool parameters have descriptions."""
tools = make_tools()
builder = PromptBuilder(tools)
specs = builder.build_tools_spec()
for spec in specs:
params = spec["function"]["parameters"]
properties = params.get("properties", {})
for param_name, param_spec in properties.items():
assert (
"description" in param_spec
), f"Parameter {param_name} in {spec['function']['name']} missing description"
def test_required_parameters_are_marked_correctly(self):
"""Verify required parameters are correctly identified."""
def func_with_optional(required: str, optional: int = 5):
return {}
tool = _create_tool_from_function(func_with_optional)
assert "required" in tool.parameters["required"]
assert "optional" not in tool.parameters["required"]
assert len(tool.parameters["required"]) == 1
class TestToolRegistry:
"""Tests for tool registry functionality."""
def test_make_tools_returns_dict(self):
"""Verify make_tools returns a dictionary."""
tools = make_tools()
assert isinstance(tools, dict)
assert len(tools) > 0
def test_all_tools_have_unique_names(self):
"""Verify all tool names are unique."""
tools = make_tools()
names = [tool.name for tool in tools.values()]
assert len(names) == len(set(names)), "Duplicate tool names found"
def test_tool_names_match_dict_keys(self):
"""Verify tool names match their dictionary keys."""
tools = make_tools()
for key, tool in tools.items():
assert key == tool.name, f"Key {key} doesn't match tool name {tool.name}"
def test_expected_tools_are_registered(self):
"""Verify all expected tools are registered."""
tools = make_tools()
expected_tools = [
"set_path_for_folder",
"list_folder",
"find_media_imdb_id",
"find_torrent",
"add_torrent_by_index",
"add_torrent_to_qbittorrent",
"get_torrent_by_index",
"set_language",
]
for expected in expected_tools:
assert expected in tools, f"Expected tool {expected} not registered"
def test_tool_functions_are_valid(self):
"""Verify all tool functions are properly structured."""
tools = make_tools()
# Verify structure without calling functions
# (calling would require full setup with memory, clients, etc.)
for name, tool in tools.items():
assert callable(tool.func), f"Tool {name} function is not callable"
class TestToolDataclass:
"""Tests for Tool dataclass."""
def test_tool_creation(self):
"""Verify Tool can be created with all fields."""
def dummy_func():
return {}
tool = Tool(
name="test_tool",
description="Test description",
func=dummy_func,
parameters={"type": "object", "properties": {}, "required": []},
)
assert tool.name == "test_tool"
assert tool.description == "Test description"
assert tool.func == dummy_func
assert isinstance(tool.parameters, dict)
def test_tool_parameters_structure(self):
"""Verify Tool parameters have correct structure."""
def dummy_func(arg: str):
return {}
tool = _create_tool_from_function(dummy_func)
assert "type" in tool.parameters
assert "properties" in tool.parameters
assert "required" in tool.parameters
assert tool.parameters["type"] == "object"

View File

@@ -0,0 +1,304 @@
"""Edge case tests for tool registry."""
import pytest
from agent.registry import Tool, make_tools
class TestToolEdgeCases:
"""Edge case tests for Tool dataclass."""
def test_tool_creation(self):
"""Should create tool with all fields."""
tool = Tool(
name="test_tool",
description="Test description",
func=lambda: {"status": "ok"},
parameters={"type": "object", "properties": {}},
)
assert tool.name == "test_tool"
assert tool.description == "Test description"
assert callable(tool.func)
def test_tool_with_unicode_name(self):
"""Should handle unicode in tool name."""
tool = Tool(
name="tool_日本語",
description="Japanese tool",
func=lambda: {},
parameters={},
)
assert "日本語" in tool.name
def test_tool_with_unicode_description(self):
"""Should handle unicode in description."""
tool = Tool(
name="test",
description="日本語の説明 🔧",
func=lambda: {},
parameters={},
)
assert "日本語" in tool.description
def test_tool_with_complex_parameters(self):
"""Should handle complex parameter schemas."""
tool = Tool(
name="complex",
description="Complex tool",
func=lambda: {},
parameters={
"type": "object",
"properties": {
"nested": {
"type": "object",
"properties": {
"deep": {
"type": "array",
"items": {"type": "string"},
},
},
},
"enum_field": {
"type": "string",
"enum": ["a", "b", "c"],
},
},
"required": ["nested"],
},
)
assert "nested" in tool.parameters["properties"]
def test_tool_with_empty_parameters(self):
"""Should handle empty parameters."""
tool = Tool(
name="no_params",
description="Tool with no parameters",
func=lambda: {},
parameters={},
)
assert tool.parameters == {}
def test_tool_with_none_func(self):
"""Should handle None func (though invalid)."""
tool = Tool(
name="invalid",
description="Invalid tool",
func=None,
parameters={},
)
assert tool.func is None
def test_tool_func_execution(self):
"""Should execute tool function."""
result_value = {"status": "ok", "data": "test"}
tool = Tool(
name="test",
description="Test",
func=lambda: result_value,
parameters={},
)
result = tool.func()
assert result == result_value
def test_tool_func_with_args(self):
"""Should execute tool function with arguments."""
tool = Tool(
name="test",
description="Test",
func=lambda x, y: {"sum": x + y},
parameters={},
)
result = tool.func(1, 2)
assert result["sum"] == 3
def test_tool_func_with_kwargs(self):
"""Should execute tool function with keyword arguments."""
tool = Tool(
name="test",
description="Test",
func=lambda **kwargs: {"received": kwargs},
parameters={},
)
result = tool.func(a=1, b=2)
assert result["received"]["a"] == 1
class TestMakeToolsEdgeCases:
"""Edge case tests for make_tools function."""
def test_make_tools_returns_dict(self, memory):
"""Should return dictionary of tools."""
tools = make_tools()
assert isinstance(tools, dict)
def test_make_tools_all_tools_have_required_fields(self, memory):
"""Should have all required fields for each tool."""
tools = make_tools()
for name, tool in tools.items():
assert tool.name == name
assert isinstance(tool.description, str)
assert len(tool.description) > 0
assert callable(tool.func)
assert isinstance(tool.parameters, dict)
def test_make_tools_unique_names(self, memory):
"""Should have unique tool names."""
tools = make_tools()
names = list(tools.keys())
assert len(names) == len(set(names))
def test_make_tools_valid_parameter_schemas(self, memory):
"""Should have valid JSON Schema for parameters."""
tools = make_tools()
for tool in tools.values():
params = tool.parameters
if params:
assert "type" in params
assert params["type"] == "object"
if "properties" in params:
assert isinstance(params["properties"], dict)
def test_make_tools_required_params_in_properties(self, memory):
"""Should have required params defined in properties."""
tools = make_tools()
for tool in tools.values():
params = tool.parameters
if "required" in params and "properties" in params:
for req in params["required"]:
assert (
req in params["properties"]
), f"Required param {req} not in properties for {tool.name}"
def test_make_tools_descriptions_not_empty(self, memory):
"""Should have non-empty descriptions."""
tools = make_tools()
for tool in tools.values():
assert tool.description.strip() != ""
def test_make_tools_funcs_callable(self, memory):
"""Should have callable functions."""
tools = make_tools()
for tool in tools.values():
assert callable(tool.func)
def test_make_tools_expected_tools_present(self, memory):
"""Should have expected tools."""
tools = make_tools()
expected = [
"set_path_for_folder",
"list_folder",
"find_media_imdb_id",
"find_torrent", # Changed from find_torrents
"add_torrent_by_index",
"add_torrent_to_qbittorrent",
"get_torrent_by_index",
"set_language",
]
for name in expected:
assert name in tools, f"Expected tool {name} not found"
def test_make_tools_idempotent(self, memory):
"""Should return same tools on multiple calls."""
tools1 = make_tools()
tools2 = make_tools()
assert set(tools1.keys()) == set(tools2.keys())
def test_make_tools_parameter_types(self, memory):
"""Should have valid parameter types."""
tools = make_tools()
valid_types = ["string", "integer", "number", "boolean", "array", "object"]
for tool in tools.values():
if "properties" in tool.parameters:
for prop_name, prop_schema in tool.parameters["properties"].items():
if "type" in prop_schema:
assert (
prop_schema["type"] in valid_types
), f"Invalid type for {tool.name}.{prop_name}"
def test_make_tools_enum_values(self, memory):
"""Should have valid enum values."""
tools = make_tools()
for tool in tools.values():
if "properties" in tool.parameters:
for _prop_name, prop_schema in tool.parameters["properties"].items():
if "enum" in prop_schema:
assert isinstance(prop_schema["enum"], list)
assert len(prop_schema["enum"]) > 0
class TestToolExecution:
"""Tests for tool execution edge cases."""
def test_tool_returns_dict(self, memory, real_folder):
"""Should return dict from tool execution."""
tools = make_tools()
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = tools["list_folder"].func(folder_type="download")
assert isinstance(result, dict)
def test_tool_returns_status(self, memory, real_folder):
"""Should return status in result."""
tools = make_tools()
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = tools["list_folder"].func(folder_type="download")
assert "status" in result or "error" in result
def test_tool_handles_missing_args(self, memory):
"""Should handle missing required arguments."""
tools = make_tools()
with pytest.raises(TypeError):
tools["set_path_for_folder"].func() # Missing required args
def test_tool_handles_wrong_type_args(self, memory):
"""Should handle wrong type arguments."""
tools = make_tools()
# Pass wrong type - should either work or raise
try:
result = tools["get_torrent_by_index"].func(index="not an int")
# If it doesn't raise, should return error
assert "error" in result or "status" in result
except (TypeError, ValueError):
pass # Also acceptable
def test_tool_handles_extra_args(self, memory, real_folder):
"""Should handle extra arguments."""
tools = make_tools()
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
# Extra args should raise TypeError
with pytest.raises(TypeError):
tools["list_folder"].func(
folder_type="download",
extra_arg="should fail",
)

422
tests/test_repositories.py Normal file
View File

@@ -0,0 +1,422 @@
"""Tests for JSON repositories."""
from domain.movies.entities import Movie
from domain.movies.value_objects import MovieTitle, Quality, ReleaseYear
from domain.shared.value_objects import FilePath, FileSize, ImdbId
from domain.subtitles.entities import Subtitle
from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset
from domain.tv_shows.entities import TVShow
from domain.tv_shows.value_objects import ShowStatus
from infrastructure.persistence.json import (
JsonMovieRepository,
JsonSubtitleRepository,
JsonTVShowRepository,
)
class TestJsonMovieRepository:
"""Tests for JsonMovieRepository."""
def test_save_movie(self, memory):
"""Should save a movie."""
repo = JsonMovieRepository()
movie = Movie(
imdb_id=ImdbId("tt1375666"),
title=MovieTitle("Inception"),
release_year=ReleaseYear(2010),
quality=Quality.FULL_HD,
)
repo.save(movie)
assert len(memory.ltm.library["movies"]) == 1
assert memory.ltm.library["movies"][0]["imdb_id"] == "tt1375666"
def test_save_updates_existing(self, memory):
"""Should update existing movie."""
repo = JsonMovieRepository()
movie1 = Movie(
imdb_id=ImdbId("tt1375666"),
title=MovieTitle("Inception"),
quality=Quality.HD,
)
movie2 = Movie(
imdb_id=ImdbId("tt1375666"),
title=MovieTitle("Inception"),
quality=Quality.FULL_HD,
)
repo.save(movie1)
repo.save(movie2)
assert len(memory.ltm.library["movies"]) == 1
assert memory.ltm.library["movies"][0]["quality"] == "1080p"
def test_find_by_imdb_id(self, memory_with_library):
"""Should find movie by IMDb ID."""
repo = JsonMovieRepository()
movie = repo.find_by_imdb_id(ImdbId("tt1375666"))
assert movie is not None
assert movie.title.value == "Inception"
def test_find_by_imdb_id_not_found(self, memory):
"""Should return None if not found."""
repo = JsonMovieRepository()
movie = repo.find_by_imdb_id(ImdbId("tt9999999"))
assert movie is None
def test_find_all(self, memory_with_library):
"""Should return all movies."""
repo = JsonMovieRepository()
movies = repo.find_all()
assert len(movies) >= 2
titles = [m.title.value for m in movies]
assert "Inception" in titles
assert "Interstellar" in titles
def test_find_all_empty(self, memory):
"""Should return empty list if no movies."""
repo = JsonMovieRepository()
movies = repo.find_all()
assert movies == []
def test_delete(self, memory_with_library):
"""Should delete movie."""
repo = JsonMovieRepository()
result = repo.delete(ImdbId("tt1375666"))
assert result is True
assert len(memory_with_library.ltm.library["movies"]) == 1
def test_delete_not_found(self, memory):
"""Should return False if not found."""
repo = JsonMovieRepository()
result = repo.delete(ImdbId("tt9999999"))
assert result is False
def test_exists(self, memory_with_library):
"""Should check if movie exists."""
repo = JsonMovieRepository()
assert repo.exists(ImdbId("tt1375666")) is True
assert repo.exists(ImdbId("tt9999999")) is False
def test_preserves_all_fields(self, memory):
"""Should preserve all movie fields."""
repo = JsonMovieRepository()
movie = Movie(
imdb_id=ImdbId("tt1375666"),
title=MovieTitle("Inception"),
release_year=ReleaseYear(2010),
quality=Quality.FULL_HD,
file_path=FilePath("/movies/inception.mkv"),
file_size=FileSize(2500000000),
tmdb_id=27205,
)
repo.save(movie)
loaded = repo.find_by_imdb_id(ImdbId("tt1375666"))
assert loaded.title.value == "Inception"
assert loaded.release_year.value == 2010
assert loaded.quality.value == "1080p"
assert str(loaded.file_path) == "/movies/inception.mkv"
assert loaded.file_size.bytes == 2500000000
assert loaded.tmdb_id == 27205
class TestJsonTVShowRepository:
"""Tests for JsonTVShowRepository."""
def test_save_show(self, memory):
"""Should save a TV show."""
repo = JsonTVShowRepository()
show = TVShow(
imdb_id=ImdbId("tt0944947"),
title="Game of Thrones",
seasons_count=8,
status=ShowStatus.ENDED,
)
repo.save(show)
assert len(memory.ltm.library["tv_shows"]) == 1
assert memory.ltm.library["tv_shows"][0]["title"] == "Game of Thrones"
def test_save_updates_existing(self, memory):
"""Should update existing show."""
repo = JsonTVShowRepository()
show1 = TVShow(
imdb_id=ImdbId("tt0944947"),
title="Game of Thrones",
seasons_count=7,
status=ShowStatus.ONGOING,
)
show2 = TVShow(
imdb_id=ImdbId("tt0944947"),
title="Game of Thrones",
seasons_count=8,
status=ShowStatus.ENDED,
)
repo.save(show1)
repo.save(show2)
assert len(memory.ltm.library["tv_shows"]) == 1
assert memory.ltm.library["tv_shows"][0]["seasons_count"] == 8
def test_find_by_imdb_id(self, memory_with_library):
"""Should find show by IMDb ID."""
repo = JsonTVShowRepository()
show = repo.find_by_imdb_id(ImdbId("tt0944947"))
assert show is not None
assert show.title == "Game of Thrones"
def test_find_by_imdb_id_not_found(self, memory):
"""Should return None if not found."""
repo = JsonTVShowRepository()
show = repo.find_by_imdb_id(ImdbId("tt9999999"))
assert show is None
def test_find_all(self, memory_with_library):
"""Should return all shows."""
repo = JsonTVShowRepository()
shows = repo.find_all()
assert len(shows) == 1
assert shows[0].title == "Game of Thrones"
def test_delete(self, memory_with_library):
"""Should delete show."""
repo = JsonTVShowRepository()
result = repo.delete(ImdbId("tt0944947"))
assert result is True
assert len(memory_with_library.ltm.library["tv_shows"]) == 0
def test_exists(self, memory_with_library):
"""Should check if show exists."""
repo = JsonTVShowRepository()
assert repo.exists(ImdbId("tt0944947")) is True
assert repo.exists(ImdbId("tt9999999")) is False
def test_preserves_status(self, memory):
"""Should preserve show status."""
repo = JsonTVShowRepository()
for i, status in enumerate(
[ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]
):
show = TVShow(
imdb_id=ImdbId(f"tt{i+1000000:07d}"),
title=f"Show {status.value}",
seasons_count=1,
status=status,
)
repo.save(show)
loaded = repo.find_by_imdb_id(ImdbId(f"tt{i+1000000:07d}"))
assert loaded.status == status
class TestJsonSubtitleRepository:
"""Tests for JsonSubtitleRepository."""
def test_save_subtitle(self, memory):
"""Should save a subtitle."""
repo = JsonSubtitleRepository()
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/inception.en.srt"),
)
repo.save(subtitle)
assert "subtitles" in memory.ltm.library
assert len(memory.ltm.library["subtitles"]) == 1
def test_save_multiple_for_same_media(self, memory):
"""Should allow multiple subtitles for same media."""
repo = JsonSubtitleRepository()
sub_en = Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/inception.en.srt"),
)
sub_fr = Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.FRENCH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/inception.fr.srt"),
)
repo.save(sub_en)
repo.save(sub_fr)
assert len(memory.ltm.library["subtitles"]) == 2
def test_find_by_media(self, memory):
"""Should find subtitles by media ID."""
repo = JsonSubtitleRepository()
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/inception.en.srt"),
)
repo.save(subtitle)
results = repo.find_by_media(ImdbId("tt1375666"))
assert len(results) == 1
assert results[0].language == Language.ENGLISH
def test_find_by_media_with_language_filter(self, memory):
"""Should filter by language."""
repo = JsonSubtitleRepository()
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/en.srt"),
)
)
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.FRENCH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/fr.srt"),
)
)
results = repo.find_by_media(ImdbId("tt1375666"), language=Language.FRENCH)
assert len(results) == 1
assert results[0].language == Language.FRENCH
def test_find_by_media_with_episode_filter(self, memory):
"""Should filter by season/episode."""
repo = JsonSubtitleRepository()
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt0944947"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/s01e01.srt"),
season_number=1,
episode_number=1,
)
)
repo.save(
Subtitle(
media_imdb_id=ImdbId("tt0944947"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/s01e02.srt"),
season_number=1,
episode_number=2,
)
)
results = repo.find_by_media(
ImdbId("tt0944947"),
season=1,
episode=1,
)
assert len(results) == 1
assert results[0].episode_number == 1
def test_find_by_media_not_found(self, memory):
"""Should return empty list if not found."""
repo = JsonSubtitleRepository()
results = repo.find_by_media(ImdbId("tt9999999"))
assert results == []
def test_delete(self, memory):
"""Should delete subtitle by file path."""
repo = JsonSubtitleRepository()
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/inception.en.srt"),
)
repo.save(subtitle)
result = repo.delete(subtitle)
assert result is True
assert len(memory.ltm.library["subtitles"]) == 0
def test_delete_not_found(self, memory):
"""Should return False if not found."""
repo = JsonSubtitleRepository()
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/nonexistent.srt"),
)
result = repo.delete(subtitle)
assert result is False
def test_preserves_all_fields(self, memory):
"""Should preserve all subtitle fields."""
repo = JsonSubtitleRepository()
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1375666"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/inception.en.srt"),
season_number=1,
episode_number=5,
timing_offset=TimingOffset(500),
hearing_impaired=True,
forced=False,
source="OpenSubtitles",
uploader="user123",
download_count=1000,
rating=8.5,
)
repo.save(subtitle)
results = repo.find_by_media(ImdbId("tt1375666"))
assert len(results) == 1
loaded = results[0]
assert loaded.season_number == 1
assert loaded.episode_number == 5
assert loaded.timing_offset.milliseconds == 500
assert loaded.hearing_impaired is True
assert loaded.forced is False
assert loaded.source == "OpenSubtitles"
assert loaded.uploader == "user123"
assert loaded.download_count == 1000
assert loaded.rating == 8.5

View File

@@ -0,0 +1,513 @@
"""Edge case tests for JSON repositories."""
from datetime import datetime
from domain.movies.entities import Movie
from domain.movies.value_objects import MovieTitle, Quality
from domain.shared.value_objects import FilePath, FileSize, ImdbId
from domain.subtitles.entities import Subtitle
from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset
from domain.tv_shows.entities import TVShow
from domain.tv_shows.value_objects import ShowStatus
from infrastructure.persistence.json import (
JsonMovieRepository,
JsonSubtitleRepository,
JsonTVShowRepository,
)
class TestJsonMovieRepositoryEdgeCases:
"""Edge case tests for JsonMovieRepository."""
def test_save_movie_with_unicode_title(self, memory):
"""Should save movie with unicode title."""
repo = JsonMovieRepository()
movie = Movie(
imdb_id=ImdbId("tt1234567"),
title=MovieTitle("千と千尋の神隠し"),
quality=Quality.FULL_HD,
)
repo.save(movie)
loaded = repo.find_by_imdb_id(ImdbId("tt1234567"))
assert loaded.title.value == "千と千尋の神隠し"
def test_save_movie_with_special_chars_in_path(self, memory):
"""Should save movie with special characters in path."""
repo = JsonMovieRepository()
movie = Movie(
imdb_id=ImdbId("tt1234567"),
title=MovieTitle("Test"),
quality=Quality.FULL_HD,
file_path=FilePath("/movies/Test (2024) [1080p] {x265}.mkv"),
)
repo.save(movie)
loaded = repo.find_by_imdb_id(ImdbId("tt1234567"))
assert "[1080p]" in str(loaded.file_path)
def test_save_movie_with_very_long_title(self, memory):
"""Should save movie with very long title."""
repo = JsonMovieRepository()
long_title = "A" * 500
movie = Movie(
imdb_id=ImdbId("tt1234567"),
title=MovieTitle(long_title),
quality=Quality.FULL_HD,
)
repo.save(movie)
loaded = repo.find_by_imdb_id(ImdbId("tt1234567"))
assert len(loaded.title.value) == 500
def test_save_movie_with_zero_file_size(self, memory):
"""Should save movie with zero file size."""
repo = JsonMovieRepository()
movie = Movie(
imdb_id=ImdbId("tt1234567"),
title=MovieTitle("Test"),
quality=Quality.FULL_HD,
file_size=FileSize(0),
)
repo.save(movie)
loaded = repo.find_by_imdb_id(ImdbId("tt1234567"))
# May be None or 0 depending on implementation
assert loaded.file_size is None or loaded.file_size.bytes == 0
def test_save_movie_with_very_large_file_size(self, memory):
"""Should save movie with very large file size."""
repo = JsonMovieRepository()
large_size = 100 * 1024 * 1024 * 1024 # 100 GB
movie = Movie(
imdb_id=ImdbId("tt1234567"),
title=MovieTitle("Test"),
quality=Quality.UHD_4K, # Use valid quality enum
file_size=FileSize(large_size),
)
repo.save(movie)
loaded = repo.find_by_imdb_id(ImdbId("tt1234567"))
assert loaded.file_size.bytes == large_size
def test_find_all_with_corrupted_entry(self, memory):
"""Should handle corrupted entries gracefully."""
# Manually add corrupted data with valid IMDb IDs
memory.ltm.library["movies"] = [
{
"imdb_id": "tt1234567",
"title": "Valid",
"quality": "1080p",
"added_at": datetime.now().isoformat(),
},
{"imdb_id": "tt2345678"}, # Missing required fields
{
"imdb_id": "tt3456789",
"title": "Also Valid",
"quality": "720p",
"added_at": datetime.now().isoformat(),
},
]
repo = JsonMovieRepository()
# Should either skip corrupted or raise
try:
movies = repo.find_all()
# If it works, should have at least the valid ones
assert len(movies) >= 1
except (KeyError, TypeError, Exception):
# If it raises, that's also acceptable
pass
def test_delete_nonexistent_movie(self, memory):
"""Should return False for nonexistent movie."""
repo = JsonMovieRepository()
result = repo.delete(ImdbId("tt9999999"))
assert result is False
def test_delete_from_empty_library(self, memory):
"""Should handle delete from empty library."""
repo = JsonMovieRepository()
memory.ltm.library["movies"] = []
result = repo.delete(ImdbId("tt1234567"))
assert result is False
def test_exists_with_similar_ids(self, memory):
"""Should distinguish similar IMDb IDs."""
repo = JsonMovieRepository()
movie = Movie(
imdb_id=ImdbId("tt1234567"),
title=MovieTitle("Test"),
quality=Quality.FULL_HD,
)
repo.save(movie)
assert repo.exists(ImdbId("tt1234567")) is True
assert repo.exists(ImdbId("tt12345678")) is False
assert repo.exists(ImdbId("tt7654321")) is False
def test_save_preserves_added_at(self, memory):
"""Should preserve original added_at on update."""
repo = JsonMovieRepository()
# Save first version
movie1 = Movie(
imdb_id=ImdbId("tt1234567"),
title=MovieTitle("Test"),
quality=Quality.HD,
added_at=datetime(2020, 1, 1, 12, 0, 0),
)
repo.save(movie1)
# Update with new quality
movie2 = Movie(
imdb_id=ImdbId("tt1234567"),
title=MovieTitle("Test"),
quality=Quality.FULL_HD,
added_at=datetime(2024, 1, 1, 12, 0, 0),
)
repo.save(movie2)
loaded = repo.find_by_imdb_id(ImdbId("tt1234567"))
# The new added_at should be used (since it's a full replacement)
assert loaded.quality.value == "1080p"
def test_concurrent_saves(self, memory):
"""Should handle rapid saves."""
repo = JsonMovieRepository()
for i in range(100):
movie = Movie(
imdb_id=ImdbId(f"tt{i:07d}"),
title=MovieTitle(f"Movie {i}"),
quality=Quality.FULL_HD,
)
repo.save(movie)
movies = repo.find_all()
assert len(movies) == 100
class TestJsonTVShowRepositoryEdgeCases:
"""Edge case tests for JsonTVShowRepository."""
def test_save_show_with_zero_seasons(self, memory):
"""Should save show with zero seasons."""
repo = JsonTVShowRepository()
show = TVShow(
imdb_id=ImdbId("tt1234567"),
title="Upcoming Show",
seasons_count=0,
status=ShowStatus.ONGOING,
)
repo.save(show)
loaded = repo.find_by_imdb_id(ImdbId("tt1234567"))
assert loaded.seasons_count == 0
def test_save_show_with_many_seasons(self, memory):
"""Should save show with many seasons."""
repo = JsonTVShowRepository()
show = TVShow(
imdb_id=ImdbId("tt1234567"),
title="Long Running Show",
seasons_count=100,
status=ShowStatus.ONGOING,
)
repo.save(show)
loaded = repo.find_by_imdb_id(ImdbId("tt1234567"))
assert loaded.seasons_count == 100
def test_save_show_with_all_statuses(self, memory):
"""Should save shows with all status types."""
repo = JsonTVShowRepository()
for i, status in enumerate(
[ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]
):
show = TVShow(
imdb_id=ImdbId(f"tt{i:07d}"),
title=f"Show {i}",
seasons_count=1,
status=status,
)
repo.save(show)
loaded = repo.find_by_imdb_id(ImdbId(f"tt{i:07d}"))
assert loaded.status == status
def test_save_show_with_unicode_title(self, memory):
"""Should save show with unicode title."""
repo = JsonTVShowRepository()
show = TVShow(
imdb_id=ImdbId("tt1234567"),
title="日本のドラマ",
seasons_count=1,
status=ShowStatus.ONGOING,
)
repo.save(show)
loaded = repo.find_by_imdb_id(ImdbId("tt1234567"))
assert loaded.title == "日本のドラマ"
def test_save_show_with_first_air_date(self, memory):
"""Should save show with first air date."""
repo = JsonTVShowRepository()
show = TVShow(
imdb_id=ImdbId("tt1234567"),
title="Test Show",
seasons_count=1,
status=ShowStatus.ONGOING,
first_air_date="2024-01-15",
)
repo.save(show)
loaded = repo.find_by_imdb_id(ImdbId("tt1234567"))
assert loaded.first_air_date == "2024-01-15"
def test_find_all_empty(self, memory):
"""Should return empty list for empty library."""
repo = JsonTVShowRepository()
memory.ltm.library["tv_shows"] = []
shows = repo.find_all()
assert shows == []
def test_update_show_seasons(self, memory):
"""Should update show seasons count."""
repo = JsonTVShowRepository()
# Save initial
show1 = TVShow(
imdb_id=ImdbId("tt1234567"),
title="Test Show",
seasons_count=5,
status=ShowStatus.ONGOING,
)
repo.save(show1)
# Update seasons
show2 = TVShow(
imdb_id=ImdbId("tt1234567"),
title="Test Show",
seasons_count=6,
status=ShowStatus.ONGOING,
)
repo.save(show2)
loaded = repo.find_by_imdb_id(ImdbId("tt1234567"))
assert loaded.seasons_count == 6
class TestJsonSubtitleRepositoryEdgeCases:
"""Edge case tests for JsonSubtitleRepository."""
def test_save_subtitle_with_large_timing_offset(self, memory):
"""Should save subtitle with large timing offset."""
repo = JsonSubtitleRepository()
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/test.srt"),
timing_offset=TimingOffset(3600000), # 1 hour
)
repo.save(subtitle)
results = repo.find_by_media(ImdbId("tt1234567"))
assert results[0].timing_offset.milliseconds == 3600000
def test_save_subtitle_with_negative_timing_offset(self, memory):
"""Should save subtitle with negative timing offset."""
repo = JsonSubtitleRepository()
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/test.srt"),
timing_offset=TimingOffset(-5000),
)
repo.save(subtitle)
results = repo.find_by_media(ImdbId("tt1234567"))
assert results[0].timing_offset.milliseconds == -5000
def test_find_by_media_multiple_languages(self, memory):
"""Should find subtitles for multiple languages."""
repo = JsonSubtitleRepository()
# Only use existing languages
for lang in [Language.ENGLISH, Language.FRENCH]:
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=lang,
format=SubtitleFormat.SRT,
file_path=FilePath(f"/subs/test.{lang.value}.srt"),
)
repo.save(subtitle)
all_subs = repo.find_by_media(ImdbId("tt1234567"))
en_subs = repo.find_by_media(ImdbId("tt1234567"), language=Language.ENGLISH)
assert len(all_subs) == 2
assert len(en_subs) == 1
def test_find_by_media_specific_episode(self, memory):
"""Should find subtitle for specific episode."""
repo = JsonSubtitleRepository()
# Add subtitles for multiple episodes
for ep in range(1, 4):
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath(f"/subs/s01e{ep:02d}.srt"),
season_number=1,
episode_number=ep,
)
repo.save(subtitle)
results = repo.find_by_media(
ImdbId("tt1234567"),
season=1,
episode=2,
)
assert len(results) == 1
assert results[0].episode_number == 2
def test_find_by_media_season_only(self, memory):
"""Should find all subtitles for a season."""
repo = JsonSubtitleRepository()
# Add subtitles for multiple seasons
for season in [1, 2]:
for ep in range(1, 3):
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath(f"/subs/s{season:02d}e{ep:02d}.srt"),
season_number=season,
episode_number=ep,
)
repo.save(subtitle)
results = repo.find_by_media(ImdbId("tt1234567"), season=1)
assert len(results) == 2
def test_delete_subtitle_by_path(self, memory):
"""Should delete subtitle by file path."""
repo = JsonSubtitleRepository()
sub1 = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/test1.srt"),
)
sub2 = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.FRENCH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/test2.srt"),
)
repo.save(sub1)
repo.save(sub2)
result = repo.delete(sub1)
assert result is True
remaining = repo.find_by_media(ImdbId("tt1234567"))
assert len(remaining) == 1
assert remaining[0].language == Language.FRENCH
def test_save_subtitle_with_all_metadata(self, memory):
"""Should save subtitle with all metadata fields."""
repo = JsonSubtitleRepository()
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/test.srt"),
season_number=1,
episode_number=5,
timing_offset=TimingOffset(500),
hearing_impaired=True,
forced=True,
source="OpenSubtitles",
uploader="user123",
download_count=10000,
rating=9.5,
)
repo.save(subtitle)
results = repo.find_by_media(ImdbId("tt1234567"))
loaded = results[0]
assert loaded.hearing_impaired is True
assert loaded.forced is True
assert loaded.source == "OpenSubtitles"
assert loaded.uploader == "user123"
assert loaded.download_count == 10000
assert loaded.rating == 9.5
def test_save_subtitle_with_unicode_path(self, memory):
"""Should save subtitle with unicode in path."""
repo = JsonSubtitleRepository()
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.FRENCH, # Use existing language
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/日本語字幕.srt"),
)
repo.save(subtitle)
results = repo.find_by_media(ImdbId("tt1234567"))
assert "日本語" in str(results[0].file_path)
def test_find_by_media_no_results(self, memory):
"""Should return empty list when no subtitles found."""
repo = JsonSubtitleRepository()
results = repo.find_by_media(ImdbId("tt9999999"))
assert results == []
def test_find_by_media_wrong_language(self, memory):
"""Should return empty when language doesn't match."""
repo = JsonSubtitleRepository()
subtitle = Subtitle(
media_imdb_id=ImdbId("tt1234567"),
language=Language.ENGLISH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/test.srt"),
)
repo.save(subtitle)
results = repo.find_by_media(ImdbId("tt1234567"), language=Language.FRENCH)
assert results == []

402
tests/test_tools_api.py Normal file
View File

@@ -0,0 +1,402 @@
"""Tests for API tools - Refactored to use real components with minimal mocking."""
from unittest.mock import Mock, patch
from agent.tools import api as api_tools
from infrastructure.persistence import get_memory
def create_mock_response(status_code, json_data=None, text=None):
"""Helper to create properly mocked HTTP response."""
response = Mock()
response.status_code = status_code
response.raise_for_status = Mock()
if json_data is not None:
response.json = Mock(return_value=json_data)
if text is not None:
response.text = text
return response
class TestFindMediaImdbId:
"""Tests for find_media_imdb_id tool."""
@patch("infrastructure.api.tmdb.client.requests.get")
def test_success(self, mock_get, memory):
"""Should return movie info on success."""
# Mock HTTP responses
def mock_get_side_effect(url, **kwargs):
if "search" in url:
return create_mock_response(
200,
json_data={
"results": [
{
"id": 27205,
"title": "Inception",
"release_date": "2010-07-16",
"overview": "A thief...",
"media_type": "movie",
}
]
},
)
elif "external_ids" in url:
return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
mock_get.side_effect = mock_get_side_effect
result = api_tools.find_media_imdb_id("Inception")
assert result["status"] == "ok"
assert result["imdb_id"] == "tt1375666"
assert result["title"] == "Inception"
# Verify HTTP calls
assert mock_get.call_count == 2
@patch("infrastructure.api.tmdb.client.requests.get")
def test_stores_in_stm(self, mock_get, memory):
"""Should store result in STM on success."""
def mock_get_side_effect(url, **kwargs):
if "search" in url:
return create_mock_response(
200,
json_data={
"results": [
{
"id": 27205,
"title": "Inception",
"release_date": "2010-07-16",
"media_type": "movie",
}
]
},
)
elif "external_ids" in url:
return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
mock_get.side_effect = mock_get_side_effect
api_tools.find_media_imdb_id("Inception")
mem = get_memory()
entity = mem.stm.get_entity("last_media_search")
assert entity is not None
assert entity["title"] == "Inception"
assert mem.stm.current_topic == "searching_media"
@patch("infrastructure.api.tmdb.client.requests.get")
def test_not_found(self, mock_get, memory):
"""Should return error when not found."""
mock_get.return_value = create_mock_response(200, json_data={"results": []})
result = api_tools.find_media_imdb_id("NonexistentMovie12345")
assert result["status"] == "error"
assert result["error"] == "not_found"
@patch("infrastructure.api.tmdb.client.requests.get")
def test_does_not_store_on_error(self, mock_get, memory):
"""Should not store in STM on error."""
mock_get.return_value = create_mock_response(200, json_data={"results": []})
api_tools.find_media_imdb_id("Test")
mem = get_memory()
assert mem.stm.get_entity("last_media_search") is None
class TestFindTorrent:
"""Tests for find_torrent tool."""
@patch("infrastructure.api.knaben.client.requests.post")
def test_success(self, mock_post, memory):
"""Should return torrents on success."""
mock_post.return_value = create_mock_response(
200,
json_data={
"hits": [
{
"title": "Torrent 1",
"seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=...",
"size": "2.5 GB",
},
{
"title": "Torrent 2",
"seeders": 50,
"leechers": 5,
"magnetUrl": "magnet:?xt=...",
"size": "1.8 GB",
},
]
},
)
result = api_tools.find_torrent("Inception 1080p")
assert result["status"] == "ok"
assert len(result["torrents"]) == 2
# Verify HTTP payload
payload = mock_post.call_args[1]["json"]
assert payload["query"] == "Inception 1080p"
@patch("infrastructure.api.knaben.client.requests.post")
def test_stores_in_episodic(self, mock_post, memory):
"""Should store results in episodic memory."""
mock_post.return_value = create_mock_response(
200,
json_data={
"hits": [
{
"title": "Torrent 1",
"seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=...",
"size": "2.5 GB",
}
]
},
)
api_tools.find_torrent("Inception")
mem = get_memory()
assert mem.episodic.last_search_results is not None
assert mem.episodic.last_search_results["query"] == "Inception"
assert mem.stm.current_topic == "selecting_torrent"
@patch("infrastructure.api.knaben.client.requests.post")
def test_results_have_indexes(self, mock_post, memory):
"""Should add indexes to results."""
mock_post.return_value = create_mock_response(
200,
json_data={
"hits": [
{
"title": "Torrent 1",
"seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=1",
"size": "1GB",
},
{
"title": "Torrent 2",
"seeders": 50,
"leechers": 5,
"magnetUrl": "magnet:?xt=2",
"size": "2GB",
},
{
"title": "Torrent 3",
"seeders": 25,
"leechers": 2,
"magnetUrl": "magnet:?xt=3",
"size": "3GB",
},
]
},
)
api_tools.find_torrent("Test")
mem = get_memory()
results = mem.episodic.last_search_results["results"]
assert results[0]["index"] == 1
assert results[1]["index"] == 2
assert results[2]["index"] == 3
@patch("infrastructure.api.knaben.client.requests.post")
def test_not_found(self, mock_post, memory):
"""Should return error when no torrents found."""
mock_post.return_value = create_mock_response(200, json_data={"hits": []})
result = api_tools.find_torrent("NonexistentMovie12345")
assert result["status"] == "error"
class TestGetTorrentByIndex:
"""Tests for get_torrent_by_index tool."""
def test_success(self, memory_with_search_results):
"""Should return torrent at index."""
result = api_tools.get_torrent_by_index(2)
assert result["status"] == "ok"
assert result["torrent"]["name"] == "Inception.2010.1080p.WEB-DL.x265"
def test_first_index(self, memory_with_search_results):
"""Should return first torrent."""
result = api_tools.get_torrent_by_index(1)
assert result["status"] == "ok"
assert result["torrent"]["name"] == "Inception.2010.1080p.BluRay.x264"
def test_last_index(self, memory_with_search_results):
"""Should return last torrent."""
result = api_tools.get_torrent_by_index(3)
assert result["status"] == "ok"
assert result["torrent"]["name"] == "Inception.2010.720p.BluRay"
def test_index_out_of_range(self, memory_with_search_results):
"""Should return error for invalid index."""
result = api_tools.get_torrent_by_index(10)
assert result["status"] == "error"
assert result["error"] == "not_found"
def test_index_zero(self, memory_with_search_results):
"""Should return error for index 0."""
result = api_tools.get_torrent_by_index(0)
assert result["status"] == "error"
assert result["error"] == "not_found"
def test_negative_index(self, memory_with_search_results):
"""Should return error for negative index."""
result = api_tools.get_torrent_by_index(-1)
assert result["status"] == "error"
assert result["error"] == "not_found"
def test_no_search_results(self, memory):
"""Should return error if no search results."""
result = api_tools.get_torrent_by_index(1)
assert result["status"] == "error"
assert result["error"] == "not_found"
assert "Search for torrents first" in result["message"]
class TestAddTorrentToQbittorrent:
"""Tests for add_torrent_to_qbittorrent tool.
Note: These tests mock the qBittorrent client because:
1. The client requires authentication/session management
2. We want to test the tool's logic (memory updates, workflow management)
3. The client itself is tested separately in infrastructure tests
This is acceptable mocking because we're testing the TOOL logic, not the client.
"""
@patch("agent.tools.api.qbittorrent_client")
def test_success(self, mock_client, memory):
"""Should add torrent successfully and update memory."""
mock_client.add_torrent.return_value = True
result = api_tools.add_torrent_to_qbittorrent("magnet:?xt=urn:btih:abc123")
# Test tool logic
assert result["status"] == "ok"
# Verify client was called correctly
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
@patch("agent.tools.api.qbittorrent_client")
def test_adds_to_active_downloads(self, mock_client, memory_with_search_results):
"""Should add to active downloads on success."""
mock_client.add_torrent.return_value = True
api_tools.add_torrent_to_qbittorrent("magnet:?xt=urn:btih:abc123")
# Test memory update logic
mem = get_memory()
assert len(mem.episodic.active_downloads) == 1
assert (
mem.episodic.active_downloads[0]["name"]
== "Inception.2010.1080p.BluRay.x264"
)
@patch("agent.tools.api.qbittorrent_client")
def test_sets_topic_and_ends_workflow(self, mock_client, memory):
"""Should set topic and end workflow."""
mock_client.add_torrent.return_value = True
memory.stm.start_workflow("download", {"title": "Test"})
api_tools.add_torrent_to_qbittorrent("magnet:?xt=...")
# Test workflow management logic
mem = get_memory()
assert mem.stm.current_topic == "downloading"
assert mem.stm.current_workflow is None
@patch("agent.tools.api.qbittorrent_client")
def test_error_handling(self, mock_client, memory):
"""Should handle client errors correctly."""
from infrastructure.api.qbittorrent.exceptions import QBittorrentAPIError
mock_client.add_torrent.side_effect = QBittorrentAPIError("Connection failed")
result = api_tools.add_torrent_to_qbittorrent("magnet:?xt=...")
# Test error handling logic
assert result["status"] == "error"
class TestAddTorrentByIndex:
"""Tests for add_torrent_by_index tool.
These tests verify the tool's logic:
- Getting torrent from memory by index
- Extracting magnet link
- Calling add_torrent_to_qbittorrent
- Error handling for edge cases
"""
@patch("agent.tools.api.qbittorrent_client")
def test_success(self, mock_client, memory_with_search_results):
"""Should get torrent by index and add it."""
mock_client.add_torrent.return_value = True
result = api_tools.add_torrent_by_index(1)
# Test tool logic
assert result["status"] == "ok"
assert result["torrent_name"] == "Inception.2010.1080p.BluRay.x264"
# Verify correct magnet was extracted and used
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
@patch("agent.tools.api.qbittorrent_client")
def test_uses_correct_magnet(self, mock_client, memory_with_search_results):
"""Should extract correct magnet from index."""
mock_client.add_torrent.return_value = True
api_tools.add_torrent_by_index(2)
# Test magnet extraction logic
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:def456")
def test_invalid_index(self, memory_with_search_results):
"""Should return error for invalid index."""
result = api_tools.add_torrent_by_index(99)
# Test error handling logic (no mock needed)
assert result["status"] == "error"
assert result["error"] == "not_found"
def test_no_search_results(self, memory):
"""Should return error if no search results."""
result = api_tools.add_torrent_by_index(1)
# Test error handling logic (no mock needed)
assert result["status"] == "error"
assert result["error"] == "not_found"
def test_no_magnet_link(self, memory):
"""Should return error if torrent has no magnet."""
memory.episodic.store_search_results(
"test",
[{"name": "Torrent without magnet", "seeders": 100}],
)
result = api_tools.add_torrent_by_index(1)
# Test error handling logic (no mock needed)
assert result["status"] == "error"
assert result["error"] == "no_magnet"

View File

@@ -0,0 +1,445 @@
"""Edge case tests for tools."""
from unittest.mock import Mock, patch
import pytest
from agent.tools import api as api_tools
from agent.tools import filesystem as fs_tools
from infrastructure.persistence import get_memory
class TestFindTorrentEdgeCases:
"""Edge case tests for find_torrent."""
@patch("agent.tools.api.SearchTorrentsUseCase")
def test_empty_query(self, mock_use_case_class, memory):
"""Should handle empty query."""
mock_response = Mock()
mock_response.to_dict.return_value = {
"status": "error",
"error": "invalid_query",
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
result = api_tools.find_torrent("")
assert result["status"] == "error"
@patch("agent.tools.api.SearchTorrentsUseCase")
def test_very_long_query(self, mock_use_case_class, memory):
"""Should handle very long query."""
mock_response = Mock()
mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
long_query = "x" * 10000
result = api_tools.find_torrent(long_query)
# Should not crash
assert "status" in result
@patch("agent.tools.api.SearchTorrentsUseCase")
def test_special_characters_in_query(self, mock_use_case_class, memory):
"""Should handle special characters in query."""
mock_response = Mock()
mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
special_query = "Movie (2024) [1080p] {x265} <HDR>"
result = api_tools.find_torrent(special_query)
assert "status" in result
@patch("agent.tools.api.SearchTorrentsUseCase")
def test_unicode_query(self, mock_use_case_class, memory):
"""Should handle unicode in query."""
mock_response = Mock()
mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
result = api_tools.find_torrent("日本語映画 2024")
assert "status" in result
@patch("agent.tools.api.SearchTorrentsUseCase")
def test_results_with_missing_fields(self, mock_use_case_class, memory):
"""Should handle results with missing fields."""
mock_response = Mock()
mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [
{"name": "Torrent 1"}, # Missing seeders, magnet, etc.
{}, # Completely empty
],
"count": 2,
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
result = api_tools.find_torrent("Test")
assert result["status"] == "ok"
mem = get_memory()
assert len(mem.episodic.last_search_results["results"]) == 2
@patch("agent.tools.api.SearchTorrentsUseCase")
def test_api_timeout(self, mock_use_case_class, memory):
"""Should handle API timeout."""
mock_use_case = Mock()
mock_use_case.execute.side_effect = TimeoutError("Connection timed out")
mock_use_case_class.return_value = mock_use_case
with pytest.raises(TimeoutError):
api_tools.find_torrent("Test")
class TestGetTorrentByIndexEdgeCases:
"""Edge case tests for get_torrent_by_index."""
def test_index_as_float(self, memory_with_search_results):
"""Should handle float index (converted to int)."""
# Python will convert 2.0 to 2 when passed as int
result = api_tools.get_torrent_by_index(int(2.9))
assert result["status"] == "ok"
assert result["torrent"]["index"] == 2
def test_results_modified_between_calls(self, memory):
"""Should handle results being modified."""
memory.episodic.store_search_results("query1", [{"name": "Result 1"}])
# Get first result
result1 = api_tools.get_torrent_by_index(1)
assert result1["status"] == "ok"
# Store new results
memory.episodic.store_search_results("query2", [{"name": "New Result"}])
# Get first result again - should be new result
result2 = api_tools.get_torrent_by_index(1)
assert result2["torrent"]["name"] == "New Result"
def test_result_with_index_already_set(self, memory):
"""Should handle results that already have index field."""
memory.episodic.store_search_results(
"query",
[{"name": "Result", "index": 999}], # Pre-existing index
)
result = api_tools.get_torrent_by_index(1)
# May overwrite or error depending on implementation
assert result["status"] in ["ok", "error"]
class TestAddTorrentEdgeCases:
"""Edge case tests for add_torrent functions."""
@patch("agent.tools.api.AddTorrentUseCase")
def test_invalid_magnet_link(self, mock_use_case_class, memory):
"""Should handle invalid magnet link."""
mock_response = Mock()
mock_response.to_dict.return_value = {
"status": "error",
"error": "invalid_magnet",
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
result = api_tools.add_torrent_to_qbittorrent("not a magnet link")
assert result["status"] == "error"
@patch("agent.tools.api.AddTorrentUseCase")
def test_empty_magnet_link(self, mock_use_case_class, memory):
"""Should handle empty magnet link."""
mock_response = Mock()
mock_response.to_dict.return_value = {
"status": "error",
"error": "empty_magnet",
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
result = api_tools.add_torrent_to_qbittorrent("")
assert result["status"] == "error"
@patch("agent.tools.api.AddTorrentUseCase")
def test_very_long_magnet_link(self, mock_use_case_class, memory):
"""Should handle very long magnet link."""
mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok"}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
long_magnet = "magnet:?xt=urn:btih:" + "a" * 10000
result = api_tools.add_torrent_to_qbittorrent(long_magnet)
assert "status" in result
@patch("agent.tools.api.AddTorrentUseCase")
def test_qbittorrent_connection_refused(self, mock_use_case_class, memory):
"""Should handle qBittorrent connection refused."""
mock_use_case = Mock()
mock_use_case.execute.side_effect = ConnectionRefusedError()
mock_use_case_class.return_value = mock_use_case
with pytest.raises(ConnectionRefusedError):
api_tools.add_torrent_to_qbittorrent("magnet:?xt=...")
def test_add_by_index_with_empty_magnet(self, memory):
"""Should handle torrent with empty magnet."""
memory.episodic.store_search_results(
"query",
[{"name": "Torrent", "magnet": ""}],
)
result = api_tools.add_torrent_by_index(1)
assert result["status"] == "error"
assert result["error"] == "no_magnet"
def test_add_by_index_with_whitespace_magnet(self, memory):
"""Should handle torrent with whitespace magnet."""
memory.episodic.store_search_results(
"query",
[{"name": "Torrent", "magnet": " "}],
)
result = api_tools.add_torrent_by_index(1)
# Whitespace-only magnet should be treated as no magnet
# Behavior depends on implementation
assert "status" in result
class TestFilesystemEdgeCases:
"""Edge case tests for filesystem tools."""
def test_set_path_with_trailing_slash(self, memory, real_folder):
"""Should handle path with trailing slash."""
path_with_slash = str(real_folder["downloads"]) + "/"
result = fs_tools.set_path_for_folder("download", path_with_slash)
assert result["status"] == "ok"
def test_set_path_with_double_slashes(self, memory, real_folder):
"""Should handle path with double slashes."""
path_double = str(real_folder["downloads"]).replace("/", "//")
result = fs_tools.set_path_for_folder("download", path_double)
# Should normalize and work
assert result["status"] == "ok"
def test_set_path_with_dot_segments(self, memory, real_folder):
"""Should handle path with . segments."""
path_with_dots = str(real_folder["downloads"]) + "/./."
result = fs_tools.set_path_for_folder("download", path_with_dots)
assert result["status"] == "ok"
def test_list_folder_with_hidden_files(self, memory, real_folder):
"""Should list hidden files."""
hidden_file = real_folder["downloads"] / ".hidden"
hidden_file.touch()
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download")
assert ".hidden" in result["entries"]
def test_list_folder_with_broken_symlink(self, memory, real_folder):
"""Should handle broken symlinks."""
broken_link = real_folder["downloads"] / "broken_link"
try:
broken_link.symlink_to("/nonexistent/target")
except OSError:
pytest.skip("Cannot create symlinks")
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download")
# Should still list the symlink
assert "broken_link" in result["entries"]
def test_list_folder_with_permission_denied_file(self, memory, real_folder):
"""Should handle files with no read permission."""
import os
no_read = real_folder["downloads"] / "no_read.txt"
no_read.touch()
try:
os.chmod(no_read, 0o000)
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download")
# Should still list the file (listing doesn't require read permission)
assert "no_read.txt" in result["entries"]
finally:
os.chmod(no_read, 0o644)
def test_list_folder_case_sensitivity(self, memory, real_folder):
"""Should handle case sensitivity correctly."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
# Try with different cases
result_lower = fs_tools.list_folder("download")
# Note: folder_type is validated, so "DOWNLOAD" would fail validation
assert result_lower["status"] == "ok"
def test_list_folder_with_spaces_in_path(self, memory, real_folder):
"""Should handle spaces in path."""
space_dir = real_folder["downloads"] / "folder with spaces"
space_dir.mkdir()
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "folder with spaces")
assert result["status"] == "ok"
def test_path_traversal_with_encoded_chars(self, memory, real_folder):
"""Should block URL-encoded traversal attempts."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
# Various encoding attempts
attempts = [
"..%2f",
"..%5c",
"%2e%2e/",
"..%252f",
]
for attempt in attempts:
result = fs_tools.list_folder("download", attempt)
# Should either be forbidden or not found
assert (
result.get("error") in ["forbidden", "not_found", None]
or result.get("status") == "ok"
)
def test_path_with_null_byte(self, memory, real_folder):
"""Should block null byte injection."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "file\x00.txt")
assert result["error"] == "forbidden"
def test_very_deep_path(self, memory, real_folder):
"""Should handle very deep paths."""
# Create deep directory structure
deep_path = real_folder["downloads"]
for i in range(20):
deep_path = deep_path / f"level{i}"
deep_path.mkdir(parents=True)
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
# Navigate to deep path
relative_path = "/".join([f"level{i}" for i in range(20)])
result = fs_tools.list_folder("download", relative_path)
assert result["status"] == "ok"
def test_folder_with_many_files(self, memory, real_folder):
"""Should handle folder with many files."""
# Create many files
for i in range(1000):
(real_folder["downloads"] / f"file_{i:04d}.txt").touch()
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download")
assert result["status"] == "ok"
assert result["count"] >= 1000
class TestFindMediaImdbIdEdgeCases:
"""Edge case tests for find_media_imdb_id."""
@patch("agent.tools.api.SearchMovieUseCase")
def test_movie_with_same_name_different_years(self, mock_use_case_class, memory):
"""Should handle movies with same name."""
mock_response = Mock()
mock_response.to_dict.return_value = {
"status": "ok",
"imdb_id": "tt1234567",
"title": "The Thing",
"year": 1982,
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
result = api_tools.find_media_imdb_id("The Thing 1982")
assert result["status"] == "ok"
@patch("agent.tools.api.SearchMovieUseCase")
def test_movie_with_special_title(self, mock_use_case_class, memory):
"""Should handle movies with special characters in title."""
mock_response = Mock()
mock_response.to_dict.return_value = {
"status": "ok",
"imdb_id": "tt1234567",
"title": "Se7en",
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
result = api_tools.find_media_imdb_id("Se7en")
assert result["status"] == "ok"
@patch("agent.tools.api.SearchMovieUseCase")
def test_tv_show_vs_movie(self, mock_use_case_class, memory):
"""Should distinguish TV shows from movies."""
mock_response = Mock()
mock_response.to_dict.return_value = {
"status": "ok",
"imdb_id": "tt0944947",
"title": "Game of Thrones",
"media_type": "tv",
}
mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case
result = api_tools.find_media_imdb_id("Game of Thrones")
assert result["media_type"] == "tv"

View File

@@ -0,0 +1,240 @@
"""Tests for filesystem tools."""
from pathlib import Path
import pytest
from agent.tools import filesystem as fs_tools
from infrastructure.persistence import get_memory
class TestSetPathForFolder:
"""Tests for set_path_for_folder tool."""
def test_success(self, memory, real_folder):
"""Should set folder path successfully."""
result = fs_tools.set_path_for_folder("download", str(real_folder["downloads"]))
assert result["status"] == "ok"
assert result["folder_name"] == "download"
assert result["path"] == str(real_folder["downloads"])
def test_saves_to_ltm(self, memory, real_folder):
"""Should save path to LTM config."""
fs_tools.set_path_for_folder("download", str(real_folder["downloads"]))
mem = get_memory()
assert mem.ltm.get_config("download_folder") == str(real_folder["downloads"])
def test_all_folder_types(self, memory, real_folder):
"""Should accept all valid folder types."""
for folder_type in ["download", "movie", "tvshow", "torrent"]:
result = fs_tools.set_path_for_folder(
folder_type, str(real_folder["downloads"])
)
assert result["status"] == "ok"
def test_invalid_folder_type(self, memory, real_folder):
"""Should reject invalid folder type."""
result = fs_tools.set_path_for_folder("invalid", str(real_folder["downloads"]))
assert result["error"] == "validation_failed"
def test_path_not_exists(self, memory):
"""Should reject non-existent path."""
result = fs_tools.set_path_for_folder("download", "/nonexistent/path/12345")
assert result["error"] == "invalid_path"
assert "does not exist" in result["message"]
def test_path_is_file(self, memory, real_folder):
"""Should reject file path."""
file_path = real_folder["downloads"] / "test_movie.mkv"
result = fs_tools.set_path_for_folder("download", str(file_path))
assert result["error"] == "invalid_path"
assert "not a directory" in result["message"]
def test_resolves_path(self, memory, real_folder):
"""Should resolve relative paths."""
# Create a symlink or use relative path
relative_path = real_folder["downloads"]
result = fs_tools.set_path_for_folder("download", str(relative_path))
assert result["status"] == "ok"
# Path should be absolute
assert Path(result["path"]).is_absolute()
class TestListFolder:
"""Tests for list_folder tool."""
def test_success(self, memory, real_folder):
"""Should list folder contents."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download")
assert result["status"] == "ok"
assert "test_movie.mkv" in result["entries"]
assert "test_series" in result["entries"]
assert result["count"] == 2
def test_subfolder(self, memory, real_folder):
"""Should list subfolder contents."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "test_series")
assert result["status"] == "ok"
assert "episode1.mkv" in result["entries"]
def test_folder_not_configured(self, memory):
"""Should return error if folder not configured."""
result = fs_tools.list_folder("download")
assert result["error"] == "folder_not_set"
def test_invalid_folder_type(self, memory):
"""Should reject invalid folder type."""
result = fs_tools.list_folder("invalid")
assert result["error"] == "validation_failed"
def test_path_traversal_dotdot(self, memory, real_folder):
"""Should block path traversal with .."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "../")
assert result["error"] == "forbidden"
def test_path_traversal_absolute(self, memory, real_folder):
"""Should block absolute paths."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "/etc/passwd")
assert result["error"] == "forbidden"
def test_path_traversal_encoded(self, memory, real_folder):
"""Should block encoded traversal attempts."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "..%2F..%2Fetc")
# Should either be forbidden or not found (depending on normalization)
assert result.get("error") in ["forbidden", "not_found"]
def test_path_not_exists(self, memory, real_folder):
"""Should return error for non-existent path."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "nonexistent_folder")
assert result["error"] == "not_found"
def test_path_is_file(self, memory, real_folder):
"""Should return error if path is a file."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "test_movie.mkv")
assert result["error"] == "not_a_directory"
def test_empty_folder(self, memory, real_folder):
"""Should handle empty folder."""
empty_dir = real_folder["downloads"] / "empty"
empty_dir.mkdir()
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "empty")
assert result["status"] == "ok"
assert result["entries"] == []
assert result["count"] == 0
def test_sorted_entries(self, memory, real_folder):
"""Should return sorted entries."""
# Create files with different names
(real_folder["downloads"] / "zebra.txt").touch()
(real_folder["downloads"] / "alpha.txt").touch()
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download")
assert result["status"] == "ok"
# Check that entries are sorted
entries = result["entries"]
assert entries == sorted(entries)
class TestFileManagerSecurity:
"""Security-focused tests for FileManager."""
def test_null_byte_injection(self, memory, real_folder):
"""Should block null byte injection."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "test\x00.txt")
assert result["error"] == "forbidden"
def test_path_outside_root(self, memory, real_folder):
"""Should block paths that escape root."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
# Try to access parent directory
result = fs_tools.list_folder("download", "test_series/../../")
assert result["error"] == "forbidden"
def test_symlink_escape(self, memory, real_folder):
"""Should handle symlinks that point outside root."""
# Create a symlink pointing outside
symlink = real_folder["downloads"] / "escape_link"
try:
symlink.symlink_to("/tmp")
except OSError:
pytest.skip("Cannot create symlinks")
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "escape_link")
# Should either be forbidden or work (depending on policy)
# The important thing is it doesn't crash
assert "error" in result or "status" in result
def test_special_characters_in_path(self, memory, real_folder):
"""Should handle special characters in path."""
special_dir = real_folder["downloads"] / "special !@#$%"
special_dir.mkdir()
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "special !@#$%")
assert result["status"] == "ok"
def test_unicode_path(self, memory, real_folder):
"""Should handle unicode in path."""
unicode_dir = real_folder["downloads"] / "日本語フォルダ"
unicode_dir.mkdir()
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
result = fs_tools.list_folder("download", "日本語フォルダ")
assert result["status"] == "ok"
def test_very_long_path(self, memory, real_folder):
"""Should handle very long paths gracefully."""
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
long_path = "a" * 1000
result = fs_tools.list_folder("download", long_path)
# Should return an error, not crash
assert "error" in result