Formatting
This commit is contained in:
124
agent/agent.py
124
agent/agent.py
@@ -1,7 +1,8 @@
|
||||
"""Main agent for media library management."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
@@ -15,157 +16,156 @@ logger = logging.getLogger(__name__)
|
||||
class Agent:
|
||||
"""
|
||||
AI agent for media library management.
|
||||
|
||||
|
||||
Uses OpenAI-compatible tool calling API.
|
||||
"""
|
||||
|
||||
def __init__(self, llm, max_tool_iterations: int = 5):
|
||||
"""
|
||||
Initialize the agent.
|
||||
|
||||
|
||||
Args:
|
||||
llm: LLM client with complete() method
|
||||
max_tool_iterations: Maximum number of tool execution iterations
|
||||
"""
|
||||
self.llm = llm
|
||||
self.tools: Dict[str, Tool] = make_tools()
|
||||
self.tools: dict[str, Tool] = make_tools()
|
||||
self.prompt_builder = PromptBuilder(self.tools)
|
||||
self.max_tool_iterations = max_tool_iterations
|
||||
|
||||
def step(self, user_input: str) -> str:
|
||||
"""
|
||||
Execute one agent step with the user input.
|
||||
|
||||
|
||||
This method:
|
||||
1. Adds user message to memory
|
||||
2. Builds prompt with history and context
|
||||
3. Calls LLM, executing tools as needed
|
||||
4. Returns final response
|
||||
|
||||
|
||||
Args:
|
||||
user_input: User's message
|
||||
|
||||
|
||||
Returns:
|
||||
Agent's final response
|
||||
"""
|
||||
memory = get_memory()
|
||||
|
||||
|
||||
# Add user message to history
|
||||
memory.stm.add_message("user", user_input)
|
||||
memory.save()
|
||||
|
||||
|
||||
# Build initial messages
|
||||
system_prompt = self.prompt_builder.build_system_prompt()
|
||||
messages: List[Dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt}
|
||||
]
|
||||
|
||||
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
# Add conversation history
|
||||
history = memory.stm.get_recent_history(settings.max_history_messages)
|
||||
messages.extend(history)
|
||||
|
||||
|
||||
# Add unread events if any
|
||||
unread_events = memory.episodic.get_unread_events()
|
||||
if unread_events:
|
||||
events_text = "\n".join([
|
||||
f"- {e['type']}: {e['data']}"
|
||||
for e in unread_events
|
||||
])
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": f"Background events:\n{events_text}"
|
||||
})
|
||||
|
||||
events_text = "\n".join(
|
||||
[f"- {e['type']}: {e['data']}" for e in unread_events]
|
||||
)
|
||||
messages.append(
|
||||
{"role": "system", "content": f"Background events:\n{events_text}"}
|
||||
)
|
||||
|
||||
# Get tools specification for OpenAI format
|
||||
tools_spec = self.prompt_builder.build_tools_spec()
|
||||
|
||||
|
||||
# Tool execution loop
|
||||
for iteration in range(self.max_tool_iterations):
|
||||
# Call LLM with tools
|
||||
llm_result = self.llm.complete(messages, tools=tools_spec)
|
||||
|
||||
|
||||
# Handle both tuple (response, usage) and dict response
|
||||
if isinstance(llm_result, tuple):
|
||||
response_message, usage = llm_result
|
||||
else:
|
||||
response_message = llm_result
|
||||
|
||||
|
||||
# Check if there are tool calls
|
||||
tool_calls = response_message.get("tool_calls")
|
||||
|
||||
|
||||
if not tool_calls:
|
||||
# No tool calls, this is the final response
|
||||
final_content = response_message.get("content", "")
|
||||
memory.stm.add_message("assistant", final_content)
|
||||
memory.save()
|
||||
return final_content
|
||||
|
||||
|
||||
# Add assistant message with tool calls to conversation
|
||||
messages.append(response_message)
|
||||
|
||||
|
||||
# Execute each tool call
|
||||
for tool_call in tool_calls:
|
||||
tool_result = self._execute_tool_call(tool_call)
|
||||
|
||||
|
||||
# Add tool result to messages
|
||||
messages.append({
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"role": "tool",
|
||||
"name": tool_call.get("function", {}).get("name"),
|
||||
"content": json.dumps(tool_result, ensure_ascii=False),
|
||||
})
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"role": "tool",
|
||||
"name": tool_call.get("function", {}).get("name"),
|
||||
"content": json.dumps(tool_result, ensure_ascii=False),
|
||||
}
|
||||
)
|
||||
|
||||
# Max iterations reached, force final response
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": "Please provide a final response to the user without using any more tools."
|
||||
})
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Please provide a final response to the user without using any more tools.",
|
||||
}
|
||||
)
|
||||
|
||||
llm_result = self.llm.complete(messages)
|
||||
if isinstance(llm_result, tuple):
|
||||
final_message, usage = llm_result
|
||||
else:
|
||||
final_message = llm_result
|
||||
|
||||
final_response = final_message.get("content", "I've completed the requested actions.")
|
||||
|
||||
final_response = final_message.get(
|
||||
"content", "I've completed the requested actions."
|
||||
)
|
||||
memory.stm.add_message("assistant", final_response)
|
||||
memory.save()
|
||||
return final_response
|
||||
|
||||
def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a single tool call.
|
||||
|
||||
|
||||
Args:
|
||||
tool_call: OpenAI-format tool call dict
|
||||
|
||||
|
||||
Returns:
|
||||
Result dictionary
|
||||
"""
|
||||
function = tool_call.get("function", {})
|
||||
tool_name = function.get("name", "")
|
||||
|
||||
|
||||
try:
|
||||
args_str = function.get("arguments", "{}")
|
||||
args = json.loads(args_str)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse tool arguments: {e}")
|
||||
return {
|
||||
"error": "bad_args",
|
||||
"message": f"Invalid JSON arguments: {e}"
|
||||
}
|
||||
|
||||
return {"error": "bad_args", "message": f"Invalid JSON arguments: {e}"}
|
||||
|
||||
# Validate tool exists
|
||||
if tool_name not in self.tools:
|
||||
available = list(self.tools.keys())
|
||||
return {
|
||||
"error": "unknown_tool",
|
||||
"message": f"Tool '{tool_name}' not found",
|
||||
"available_tools": available
|
||||
"available_tools": available,
|
||||
}
|
||||
|
||||
|
||||
tool = self.tools[tool_name]
|
||||
|
||||
|
||||
# Execute tool
|
||||
try:
|
||||
result = tool.func(**args)
|
||||
@@ -177,17 +177,9 @@ class Agent:
|
||||
# Bad arguments
|
||||
memory = get_memory()
|
||||
memory.episodic.add_error(tool_name, f"bad_args: {e}")
|
||||
return {
|
||||
"error": "bad_args",
|
||||
"message": str(e),
|
||||
"tool": tool_name
|
||||
}
|
||||
return {"error": "bad_args", "message": str(e), "tool": tool_name}
|
||||
except Exception as e:
|
||||
# Other errors
|
||||
memory = get_memory()
|
||||
memory.episodic.add_error(tool_name, str(e))
|
||||
return {
|
||||
"error": "execution_failed",
|
||||
"message": str(e),
|
||||
"tool": tool_name
|
||||
}
|
||||
return {"error": "execution_failed", "message": str(e), "tool": tool_name}
|
||||
|
||||
@@ -51,7 +51,9 @@ class DeepSeekClient:
|
||||
|
||||
logger.info(f"DeepSeek client initialized with model: {self.model}")
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]:
|
||||
def complete(
|
||||
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a completion from the LLM.
|
||||
|
||||
@@ -80,7 +82,9 @@ class DeepSeekClient:
|
||||
raise ValueError(f"Invalid role: {msg['role']}")
|
||||
# Content is optional for tool messages (they may have tool_call_id instead)
|
||||
if msg["role"] != "tool" and "content" not in msg:
|
||||
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}")
|
||||
raise ValueError(
|
||||
f"Non-tool message must have 'content' key, got {msg.keys()}"
|
||||
)
|
||||
|
||||
url = f"{self.base_url}/v1/chat/completions"
|
||||
headers = {
|
||||
@@ -92,13 +96,15 @@ class DeepSeekClient:
|
||||
"messages": messages,
|
||||
"temperature": settings.temperature,
|
||||
}
|
||||
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
|
||||
try:
|
||||
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools")
|
||||
logger.debug(
|
||||
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
|
||||
)
|
||||
response = requests.post(
|
||||
url, headers=headers, json=payload, timeout=self.timeout
|
||||
)
|
||||
|
||||
@@ -66,7 +66,9 @@ class OllamaClient:
|
||||
|
||||
logger.info(f"Ollama client initialized with model: {self.model}")
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]:
|
||||
def complete(
|
||||
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a completion from the LLM.
|
||||
|
||||
@@ -95,7 +97,9 @@ class OllamaClient:
|
||||
raise ValueError(f"Invalid role: {msg['role']}")
|
||||
# Content is optional for tool messages (they may have tool_call_id instead)
|
||||
if msg["role"] != "tool" and "content" not in msg:
|
||||
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}")
|
||||
raise ValueError(
|
||||
f"Non-tool message must have 'content' key, got {msg.keys()}"
|
||||
)
|
||||
|
||||
url = f"{self.base_url}/api/chat"
|
||||
payload = {
|
||||
@@ -106,13 +110,15 @@ class OllamaClient:
|
||||
"temperature": self.temperature,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
|
||||
try:
|
||||
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools")
|
||||
logger.debug(
|
||||
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
|
||||
)
|
||||
response = requests.post(url, json=payload, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
"""Prompt builder for the agent system."""
|
||||
from typing import Dict, List, Any
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
from .registry import Tool
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
"""Builds system prompts for the agent with memory context."""
|
||||
|
||||
def __init__(self, tools: Dict[str, Tool]):
|
||||
def __init__(self, tools: dict[str, Tool]):
|
||||
self.tools = tools
|
||||
|
||||
def build_tools_spec(self) -> List[Dict[str, Any]]:
|
||||
def build_tools_spec(self) -> list[dict[str, Any]]:
|
||||
"""Build the tool specification for the LLM API."""
|
||||
tool_specs = []
|
||||
for tool in self.tools.values():
|
||||
@@ -44,11 +46,13 @@ class PromptBuilder:
|
||||
|
||||
if memory.episodic.last_search_results:
|
||||
results = memory.episodic.last_search_results
|
||||
result_list = results.get('results', [])
|
||||
lines.append(f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)")
|
||||
result_list = results.get("results", [])
|
||||
lines.append(
|
||||
f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)"
|
||||
)
|
||||
# Show first 5 results
|
||||
for i, result in enumerate(result_list[:5]):
|
||||
name = result.get('name', 'Unknown')
|
||||
name = result.get("name", "Unknown")
|
||||
lines.append(f" {i+1}. {name}")
|
||||
if len(result_list) > 5:
|
||||
lines.append(f" ... and {len(result_list) - 5} more")
|
||||
@@ -57,7 +61,7 @@ class PromptBuilder:
|
||||
question = memory.episodic.pending_question
|
||||
lines.append(f"\nPENDING QUESTION: {question.get('question')}")
|
||||
lines.append(f" Type: {question.get('type')}")
|
||||
if question.get('options'):
|
||||
if question.get("options"):
|
||||
lines.append(f" Options: {len(question.get('options'))}")
|
||||
|
||||
if memory.episodic.active_downloads:
|
||||
@@ -68,10 +72,12 @@ class PromptBuilder:
|
||||
if memory.episodic.recent_errors:
|
||||
lines.append("\nRECENT ERRORS (up to 3):")
|
||||
for error in memory.episodic.recent_errors[-3:]:
|
||||
lines.append(f" - Action '{error.get('action')}' failed: {error.get('error')}")
|
||||
lines.append(
|
||||
f" - Action '{error.get('action')}' failed: {error.get('error')}"
|
||||
)
|
||||
|
||||
# Unread events
|
||||
unread = [e for e in memory.episodic.background_events if not e.get('read')]
|
||||
unread = [e for e in memory.episodic.background_events if not e.get("read")]
|
||||
if unread:
|
||||
lines.append(f"\nUNREAD EVENTS: {len(unread)}")
|
||||
for event in unread[:3]:
|
||||
@@ -86,8 +92,10 @@ class PromptBuilder:
|
||||
|
||||
if memory.stm.current_workflow:
|
||||
workflow = memory.stm.current_workflow
|
||||
lines.append(f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})")
|
||||
if workflow.get('target'):
|
||||
lines.append(
|
||||
f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})"
|
||||
)
|
||||
if workflow.get("target"):
|
||||
lines.append(f" Target: {workflow.get('target')}")
|
||||
|
||||
if memory.stm.current_topic:
|
||||
@@ -97,7 +105,7 @@ class PromptBuilder:
|
||||
lines.append("EXTRACTED ENTITIES:")
|
||||
for key, value in memory.stm.extracted_entities.items():
|
||||
lines.append(f" - {key}: {value}")
|
||||
|
||||
|
||||
if memory.stm.language:
|
||||
lines.append(f"CONVERSATION LANGUAGE: {memory.stm.language}")
|
||||
|
||||
@@ -106,7 +114,7 @@ class PromptBuilder:
|
||||
def _format_config_context(self) -> str:
|
||||
"""Format configuration context."""
|
||||
memory = get_memory()
|
||||
|
||||
|
||||
lines = ["CURRENT CONFIGURATION:"]
|
||||
if memory.ltm.config:
|
||||
for key, value in memory.ltm.config.items():
|
||||
@@ -118,10 +126,10 @@ class PromptBuilder:
|
||||
def build_system_prompt(self) -> str:
|
||||
"""Build the complete system prompt."""
|
||||
memory = get_memory()
|
||||
|
||||
|
||||
# Base instruction
|
||||
base = "You are a helpful AI assistant for managing a media library."
|
||||
|
||||
|
||||
# Language instruction
|
||||
language_instruction = (
|
||||
"Your first task is to determine the user's language from their message "
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Tool registry - defines and registers all available tools for the agent."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Any, Dict
|
||||
import logging
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,36 +12,37 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class Tool:
|
||||
"""Represents a tool that can be used by the agent."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
func: Callable[..., Dict[str, Any]]
|
||||
parameters: Dict[str, Any]
|
||||
func: Callable[..., dict[str, Any]]
|
||||
parameters: dict[str, Any]
|
||||
|
||||
|
||||
def _create_tool_from_function(func: Callable) -> Tool:
|
||||
"""
|
||||
Create a Tool object from a function.
|
||||
|
||||
|
||||
Args:
|
||||
func: Function to convert to a tool
|
||||
|
||||
|
||||
Returns:
|
||||
Tool object with metadata extracted from function
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
doc = inspect.getdoc(func)
|
||||
|
||||
|
||||
# Extract description from docstring (first line)
|
||||
description = doc.strip().split('\n')[0] if doc else func.__name__
|
||||
|
||||
description = doc.strip().split("\n")[0] if doc else func.__name__
|
||||
|
||||
# Build JSON schema from function signature
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param_name == "self":
|
||||
continue
|
||||
|
||||
|
||||
# Map Python types to JSON schema types
|
||||
param_type = "string" # default
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
@@ -51,22 +54,22 @@ def _create_tool_from_function(func: Callable) -> Tool:
|
||||
param_type = "number"
|
||||
elif param.annotation == bool:
|
||||
param_type = "boolean"
|
||||
|
||||
|
||||
properties[param_name] = {
|
||||
"type": param_type,
|
||||
"description": f"Parameter {param_name}"
|
||||
"description": f"Parameter {param_name}",
|
||||
}
|
||||
|
||||
|
||||
# Add to required if no default value
|
||||
if param.default == inspect.Parameter.empty:
|
||||
required.append(param_name)
|
||||
|
||||
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
|
||||
return Tool(
|
||||
name=func.__name__,
|
||||
description=description,
|
||||
@@ -75,18 +78,18 @@ def _create_tool_from_function(func: Callable) -> Tool:
|
||||
)
|
||||
|
||||
|
||||
def make_tools() -> Dict[str, Tool]:
|
||||
def make_tools() -> dict[str, Tool]:
|
||||
"""
|
||||
Create and register all available tools.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tool names to Tool objects
|
||||
"""
|
||||
# Import tools here to avoid circular dependencies
|
||||
from .tools import filesystem as fs_tools
|
||||
from .tools import api as api_tools
|
||||
from .tools import filesystem as fs_tools
|
||||
from .tools import language as lang_tools
|
||||
|
||||
|
||||
# List of all tool functions
|
||||
tool_functions = [
|
||||
fs_tools.set_path_for_folder,
|
||||
@@ -98,12 +101,12 @@ def make_tools() -> Dict[str, Tool]:
|
||||
api_tools.get_torrent_by_index,
|
||||
lang_tools.set_language,
|
||||
]
|
||||
|
||||
|
||||
# Create Tool objects from functions
|
||||
tools = {}
|
||||
for func in tool_functions:
|
||||
tool = _create_tool_from_function(func)
|
||||
tools[tool.name] = tool
|
||||
|
||||
|
||||
logger.info(f"Registered {len(tools)} tools: {list(tools.keys())}")
|
||||
return tools
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
"""Language management tools for the agent."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_language(language: str) -> Dict[str, Any]:
|
||||
def set_language(language: str) -> dict[str, Any]:
|
||||
"""
|
||||
Set the conversation language.
|
||||
|
||||
|
||||
Args:
|
||||
language: Language code (e.g., 'en', 'fr', 'es', 'de')
|
||||
|
||||
|
||||
Returns:
|
||||
Status dictionary
|
||||
"""
|
||||
@@ -21,17 +22,14 @@ def set_language(language: str) -> Dict[str, Any]:
|
||||
memory = get_memory()
|
||||
memory.stm.set_language(language)
|
||||
memory.save()
|
||||
|
||||
|
||||
logger.info(f"Language set to: {language}")
|
||||
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": f"Language set to {language}",
|
||||
"language": language
|
||||
"language": language,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set language: {e}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
@@ -359,9 +359,7 @@ class EpisodicMemory:
|
||||
"""Get active downloads."""
|
||||
return self.active_downloads
|
||||
|
||||
def add_error(
|
||||
self, action: str, error: str, context: dict | None = None
|
||||
) -> None:
|
||||
def add_error(self, action: str, error: str, context: dict | None = None) -> None:
|
||||
"""Record a recent error."""
|
||||
self.recent_errors.append(
|
||||
{
|
||||
@@ -408,9 +406,7 @@ class EpisodicMemory:
|
||||
"""Get the pending question."""
|
||||
return self.pending_question
|
||||
|
||||
def resolve_pending_question(
|
||||
self, answer_index: int | None = None
|
||||
) -> dict | None:
|
||||
def resolve_pending_question(self, answer_index: int | None = None) -> dict | None:
|
||||
"""
|
||||
Resolve the pending question and return the chosen option.
|
||||
|
||||
|
||||
@@ -110,4 +110,4 @@ select = [
|
||||
"PL",
|
||||
"UP",
|
||||
]
|
||||
ignore = ["W503", "PLR0913", "PLR2004"]
|
||||
ignore = ["PLR0913", "PLR2004"]
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
"""Pytest configuration and shared fixtures."""
|
||||
import pytest
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, MagicMock
|
||||
|
||||
from infrastructure.persistence import Memory, init_memory, set_memory, get_memory
|
||||
from infrastructure.persistence.memory import (
|
||||
LongTermMemory,
|
||||
ShortTermMemory,
|
||||
EpisodicMemory,
|
||||
)
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.persistence import Memory, set_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -122,12 +119,11 @@ def memory_with_library(memory):
|
||||
def mock_llm():
|
||||
"""Create a mock LLM client that returns OpenAI-compatible format."""
|
||||
llm = Mock()
|
||||
|
||||
# Return OpenAI-style message dict without tool calls
|
||||
def complete_func(messages, tools=None):
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "I found what you're looking for!"
|
||||
}
|
||||
return {"role": "assistant", "content": "I found what you're looking for!"}
|
||||
|
||||
llm.complete = Mock(side_effect=complete_func)
|
||||
return llm
|
||||
|
||||
@@ -136,34 +132,33 @@ def mock_llm():
|
||||
def mock_llm_with_tool_call():
|
||||
"""Create a mock LLM that returns a tool call then a response."""
|
||||
llm = Mock()
|
||||
|
||||
|
||||
# First call returns a tool call, second returns final response
|
||||
def complete_side_effect(messages, tools=None):
|
||||
if not hasattr(complete_side_effect, 'call_count'):
|
||||
if not hasattr(complete_side_effect, "call_count"):
|
||||
complete_side_effect.call_count = 0
|
||||
complete_side_effect.call_count += 1
|
||||
|
||||
|
||||
if complete_side_effect.call_count == 1:
|
||||
# First call: return tool call
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "find_torrent",
|
||||
"arguments": '{"media_title": "Inception"}'
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "find_torrent",
|
||||
"arguments": '{"media_title": "Inception"}',
|
||||
},
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
else:
|
||||
# Second call: return final response
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "I found 3 torrents for Inception!"
|
||||
}
|
||||
|
||||
return {"role": "assistant", "content": "I found 3 torrents for Inception!"}
|
||||
|
||||
llm.complete = Mock(side_effect=complete_side_effect)
|
||||
return llm
|
||||
|
||||
@@ -248,36 +243,36 @@ def mock_deepseek():
|
||||
"""
|
||||
Mock DeepSeekClient for individual tests that need it.
|
||||
This prevents real API calls in tests that use this fixture.
|
||||
|
||||
|
||||
Usage:
|
||||
def test_something(mock_deepseek):
|
||||
# Your test code here
|
||||
"""
|
||||
import sys
|
||||
from unittest.mock import Mock, MagicMock
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
# Save the original module if it exists
|
||||
original_module = sys.modules.get('agent.llm.deepseek')
|
||||
|
||||
original_module = sys.modules.get("agent.llm.deepseek")
|
||||
|
||||
# Create a mock module for deepseek
|
||||
mock_deepseek_module = MagicMock()
|
||||
|
||||
|
||||
class MockDeepSeekClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.complete = Mock(return_value="Mocked LLM response")
|
||||
|
||||
|
||||
mock_deepseek_module.DeepSeekClient = MockDeepSeekClient
|
||||
|
||||
|
||||
# Inject the mock
|
||||
sys.modules['agent.llm.deepseek'] = mock_deepseek_module
|
||||
|
||||
sys.modules["agent.llm.deepseek"] = mock_deepseek_module
|
||||
|
||||
yield mock_deepseek_module
|
||||
|
||||
|
||||
# Restore the original module
|
||||
if original_module is not None:
|
||||
sys.modules['agent.llm.deepseek'] = original_module
|
||||
elif 'agent.llm.deepseek' in sys.modules:
|
||||
del sys.modules['agent.llm.deepseek']
|
||||
sys.modules["agent.llm.deepseek"] = original_module
|
||||
elif "agent.llm.deepseek" in sys.modules:
|
||||
del sys.modules["agent.llm.deepseek"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -287,8 +282,8 @@ def mock_agent_step():
|
||||
Returns a context manager that patches app.agent.step.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _mock_step(return_value="Mocked agent response"):
|
||||
return patch("app.agent.step", return_value=return_value)
|
||||
|
||||
|
||||
return _mock_step
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for the Agent."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock
|
||||
|
||||
from agent.agent import Agent
|
||||
from infrastructure.persistence import get_memory
|
||||
@@ -55,8 +55,8 @@ class TestExecuteToolCall:
|
||||
"id": "call_123",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download"}'
|
||||
}
|
||||
"arguments": '{"folder_type": "download"}',
|
||||
},
|
||||
}
|
||||
result = agent._execute_tool_call(tool_call)
|
||||
|
||||
@@ -68,10 +68,7 @@ class TestExecuteToolCall:
|
||||
|
||||
tool_call = {
|
||||
"id": "call_123",
|
||||
"function": {
|
||||
"name": "unknown_tool",
|
||||
"arguments": '{}'
|
||||
}
|
||||
"function": {"name": "unknown_tool", "arguments": "{}"},
|
||||
}
|
||||
result = agent._execute_tool_call(tool_call)
|
||||
|
||||
@@ -84,10 +81,7 @@ class TestExecuteToolCall:
|
||||
|
||||
tool_call = {
|
||||
"id": "call_123",
|
||||
"function": {
|
||||
"name": "set_path_for_folder",
|
||||
"arguments": '{}'
|
||||
}
|
||||
"function": {"name": "set_path_for_folder", "arguments": "{}"},
|
||||
}
|
||||
result = agent._execute_tool_call(tool_call)
|
||||
|
||||
@@ -102,8 +96,8 @@ class TestExecuteToolCall:
|
||||
"id": "call_123",
|
||||
"function": {
|
||||
"name": "set_path_for_folder",
|
||||
"arguments": '{"folder_name": 123}' # Wrong type
|
||||
}
|
||||
"arguments": '{"folder_name": 123}', # Wrong type
|
||||
},
|
||||
}
|
||||
result = agent._execute_tool_call(tool_call)
|
||||
|
||||
@@ -116,10 +110,7 @@ class TestExecuteToolCall:
|
||||
|
||||
tool_call = {
|
||||
"id": "call_123",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{invalid json}'
|
||||
}
|
||||
"function": {"name": "list_folder", "arguments": "{invalid json}"},
|
||||
}
|
||||
result = agent._execute_tool_call(tool_call)
|
||||
|
||||
@@ -160,40 +151,39 @@ class TestStep:
|
||||
|
||||
assert "found" in response.lower() or "torrent" in response.lower()
|
||||
assert mock_llm_with_tool_call.complete.call_count == 2
|
||||
|
||||
|
||||
# CRITICAL: Verify tools were passed to LLM
|
||||
first_call_args = mock_llm_with_tool_call.complete.call_args_list[0]
|
||||
assert first_call_args[1]['tools'] is not None, "Tools not passed to LLM!"
|
||||
assert len(first_call_args[1]['tools']) > 0, "Tools list is empty!"
|
||||
assert first_call_args[1]["tools"] is not None, "Tools not passed to LLM!"
|
||||
assert len(first_call_args[1]["tools"]) > 0, "Tools list is empty!"
|
||||
|
||||
def test_step_max_iterations(self, memory, mock_llm):
|
||||
"""Should stop after max iterations."""
|
||||
call_count = [0]
|
||||
|
||||
|
||||
def mock_complete(messages, tools=None):
|
||||
call_count[0] += 1
|
||||
# CRITICAL: Verify tools are passed (except on forced final call)
|
||||
if call_count[0] <= 3:
|
||||
assert tools is not None, f"Tools not passed on call {call_count[0]}!"
|
||||
|
||||
|
||||
if call_count[0] <= 3:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": f"call_{call_count[0]}",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download"}'
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": f"call_{call_count[0]}",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download"}',
|
||||
},
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "I couldn't complete the task."
|
||||
}
|
||||
|
||||
return {"role": "assistant", "content": "I couldn't complete the task."}
|
||||
|
||||
mock_llm.complete = Mock(side_effect=mock_complete)
|
||||
agent = Agent(llm=mock_llm, max_tool_iterations=3)
|
||||
|
||||
@@ -241,49 +231,53 @@ class TestAgentIntegration:
|
||||
memory.ltm.set_config("movie_folder", str(real_folder["movies"]))
|
||||
|
||||
call_count = [0]
|
||||
|
||||
|
||||
def mock_complete(messages, tools=None):
|
||||
call_count[0] += 1
|
||||
# CRITICAL: Verify tools are passed on every call
|
||||
assert tools is not None, f"Tools not passed on call {call_count[0]}!"
|
||||
|
||||
|
||||
if call_count[0] == 1:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download"}'
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download"}',
|
||||
},
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
elif call_count[0] == 2:
|
||||
# CRITICAL: Verify tool result was sent back
|
||||
tool_messages = [m for m in messages if m.get('role') == 'tool']
|
||||
tool_messages = [m for m in messages if m.get("role") == "tool"]
|
||||
assert len(tool_messages) > 0, "Tool result not sent back to LLM!"
|
||||
|
||||
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_2",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "movie"}'
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_2",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "movie"}',
|
||||
},
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "I listed both folders for you."
|
||||
"content": "I listed both folders for you.",
|
||||
}
|
||||
|
||||
|
||||
mock_llm.complete = Mock(side_effect=mock_complete)
|
||||
agent = Agent(llm=mock_llm)
|
||||
|
||||
|
||||
response = agent.step("List my downloads and movies")
|
||||
|
||||
assert call_count[0] == 3
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""Edge case tests for the Agent."""
|
||||
import pytest
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.agent import Agent
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
@@ -15,19 +17,14 @@ class TestExecuteToolCallEdgeCases:
|
||||
|
||||
# Mock a tool that returns None
|
||||
from agent.registry import Tool
|
||||
|
||||
agent.tools["test_tool"] = Tool(
|
||||
name="test_tool",
|
||||
description="Test",
|
||||
func=lambda: None,
|
||||
parameters={}
|
||||
name="test_tool", description="Test", func=lambda: None, parameters={}
|
||||
)
|
||||
|
||||
tool_call = {
|
||||
"id": "call_123",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{}'
|
||||
}
|
||||
"function": {"name": "test_tool", "arguments": "{}"},
|
||||
}
|
||||
result = agent._execute_tool_call(tool_call)
|
||||
|
||||
@@ -38,22 +35,17 @@ class TestExecuteToolCallEdgeCases:
|
||||
agent = Agent(llm=mock_llm)
|
||||
|
||||
from agent.registry import Tool
|
||||
|
||||
def raise_interrupt():
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
|
||||
agent.tools["test_tool"] = Tool(
|
||||
name="test_tool",
|
||||
description="Test",
|
||||
func=raise_interrupt,
|
||||
parameters={}
|
||||
name="test_tool", description="Test", func=raise_interrupt, parameters={}
|
||||
)
|
||||
|
||||
tool_call = {
|
||||
"id": "call_123",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{}'
|
||||
}
|
||||
"function": {"name": "test_tool", "arguments": "{}"},
|
||||
}
|
||||
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
@@ -68,8 +60,8 @@ class TestExecuteToolCallEdgeCases:
|
||||
"id": "call_123",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download", "extra_arg": "ignored"}'
|
||||
}
|
||||
"arguments": '{"folder_type": "download", "extra_arg": "ignored"}',
|
||||
},
|
||||
}
|
||||
|
||||
result = agent._execute_tool_call(tool_call)
|
||||
@@ -84,8 +76,8 @@ class TestExecuteToolCallEdgeCases:
|
||||
"id": "call_123",
|
||||
"function": {
|
||||
"name": "get_torrent_by_index",
|
||||
"arguments": '{"index": "not an int"}'
|
||||
}
|
||||
"arguments": '{"index": "not an int"}',
|
||||
},
|
||||
}
|
||||
|
||||
result = agent._execute_tool_call(tool_call)
|
||||
@@ -115,12 +107,10 @@ class TestStepEdgeCases:
|
||||
|
||||
def test_step_with_unicode_input(self, memory, mock_llm):
|
||||
"""Should handle unicode input."""
|
||||
|
||||
def mock_complete(messages, tools=None):
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "日本語の応答"
|
||||
}
|
||||
|
||||
return {"role": "assistant", "content": "日本語の応答"}
|
||||
|
||||
mock_llm.complete = Mock(side_effect=mock_complete)
|
||||
agent = Agent(llm=mock_llm)
|
||||
|
||||
@@ -130,12 +120,10 @@ class TestStepEdgeCases:
|
||||
|
||||
def test_step_llm_returns_empty(self, memory, mock_llm):
|
||||
"""Should handle LLM returning empty string."""
|
||||
|
||||
def mock_complete(messages, tools=None):
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": ""
|
||||
}
|
||||
|
||||
return {"role": "assistant", "content": ""}
|
||||
|
||||
mock_llm.complete = Mock(side_effect=mock_complete)
|
||||
agent = Agent(llm=mock_llm)
|
||||
|
||||
@@ -161,18 +149,17 @@ class TestStepEdgeCases:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": f"call_{call_count[0]}",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download"}'
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": f"call_{call_count[0]}",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download"}',
|
||||
},
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "Done looping"
|
||||
}
|
||||
return {"role": "assistant", "content": "Done looping"}
|
||||
|
||||
mock_llm.complete = Mock(side_effect=mock_complete)
|
||||
agent = Agent(llm=mock_llm, max_tool_iterations=3)
|
||||
@@ -212,11 +199,13 @@ class TestStepEdgeCases:
|
||||
|
||||
def test_step_with_active_downloads(self, memory, mock_llm):
|
||||
"""Should include active downloads in context."""
|
||||
memory.episodic.add_active_download({
|
||||
"task_id": "123",
|
||||
"name": "Movie.mkv",
|
||||
"progress": 50,
|
||||
})
|
||||
memory.episodic.add_active_download(
|
||||
{
|
||||
"task_id": "123",
|
||||
"name": "Movie.mkv",
|
||||
"progress": 50,
|
||||
}
|
||||
)
|
||||
agent = Agent(llm=mock_llm)
|
||||
|
||||
response = agent.step("Hello")
|
||||
@@ -257,29 +246,28 @@ class TestAgentConcurrencyEdgeCases:
|
||||
memory.ltm.set_config("download_folder", str(real_folder["downloads"]))
|
||||
|
||||
call_count = [0]
|
||||
|
||||
|
||||
def mock_complete(messages, tools=None):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "set_path_for_folder",
|
||||
"arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}'
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "set_path_for_folder",
|
||||
"arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}',
|
||||
},
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "Path set successfully."
|
||||
}
|
||||
return {"role": "assistant", "content": "Path set successfully."}
|
||||
|
||||
mock_llm.complete = Mock(side_effect=mock_complete)
|
||||
agent = Agent(llm=mock_llm)
|
||||
|
||||
|
||||
response = agent.step("Set movie folder")
|
||||
|
||||
mem = get_memory()
|
||||
@@ -292,29 +280,28 @@ class TestAgentErrorRecovery:
|
||||
def test_recovers_from_tool_error(self, memory, mock_llm):
|
||||
"""Should recover from tool error and continue."""
|
||||
call_count = [0]
|
||||
|
||||
|
||||
def mock_complete(messages, tools=None):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download"}'
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download"}',
|
||||
},
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "The folder is not configured."
|
||||
}
|
||||
return {"role": "assistant", "content": "The folder is not configured."}
|
||||
|
||||
mock_llm.complete = Mock(side_effect=mock_complete)
|
||||
agent = Agent(llm=mock_llm)
|
||||
|
||||
|
||||
response = agent.step("List downloads")
|
||||
|
||||
assert "not configured" in response.lower() or len(response) > 0
|
||||
@@ -322,29 +309,28 @@ class TestAgentErrorRecovery:
|
||||
def test_error_tracked_in_memory(self, memory, mock_llm):
|
||||
"""Should track errors in episodic memory."""
|
||||
call_count = [0]
|
||||
|
||||
|
||||
def mock_complete(messages, tools=None):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "set_path_for_folder",
|
||||
"arguments": '{}' # Missing required args
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "set_path_for_folder",
|
||||
"arguments": "{}", # Missing required args
|
||||
},
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "Error occurred."
|
||||
}
|
||||
return {"role": "assistant", "content": "Error occurred."}
|
||||
|
||||
mock_llm.complete = Mock(side_effect=mock_complete)
|
||||
agent = Agent(llm=mock_llm)
|
||||
|
||||
|
||||
agent.step("Set folder")
|
||||
|
||||
mem = get_memory()
|
||||
@@ -360,18 +346,17 @@ class TestAgentErrorRecovery:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": f"call_{call_count[0]}",
|
||||
"function": {
|
||||
"name": "set_path_for_folder",
|
||||
"arguments": '{}' # Missing required args - will error
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": f"call_{call_count[0]}",
|
||||
"function": {
|
||||
"name": "set_path_for_folder",
|
||||
"arguments": "{}", # Missing required args - will error
|
||||
},
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "All attempts failed."
|
||||
}
|
||||
return {"role": "assistant", "content": "All attempts failed."}
|
||||
|
||||
mock_llm.complete = Mock(side_effect=mock_complete)
|
||||
agent = Agent(llm=mock_llm, max_tool_iterations=3)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for FastAPI endpoints."""
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@@ -10,6 +11,7 @@ class TestHealthEndpoint:
|
||||
def test_health_check(self, memory):
|
||||
"""Should return healthy status."""
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/health")
|
||||
@@ -24,6 +26,7 @@ class TestModelsEndpoint:
|
||||
def test_list_models(self, memory):
|
||||
"""Should return model list."""
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/v1/models")
|
||||
@@ -41,6 +44,7 @@ class TestMemoryEndpoints:
|
||||
def test_get_memory_state(self, memory):
|
||||
"""Should return full memory state."""
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/memory/state")
|
||||
@@ -54,6 +58,7 @@ class TestMemoryEndpoints:
|
||||
def test_get_search_results_empty(self, memory):
|
||||
"""Should return empty when no search results."""
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/memory/episodic/search-results")
|
||||
@@ -65,6 +70,7 @@ class TestMemoryEndpoints:
|
||||
def test_get_search_results_with_data(self, memory_with_search_results):
|
||||
"""Should return search results when available."""
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/memory/episodic/search-results")
|
||||
@@ -78,6 +84,7 @@ class TestMemoryEndpoints:
|
||||
def test_clear_session(self, memory_with_search_results):
|
||||
"""Should clear session memories."""
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/memory/clear-session")
|
||||
@@ -96,14 +103,18 @@ class TestChatCompletionsEndpoint:
|
||||
def test_chat_completion_success(self, memory):
|
||||
"""Should return chat completion."""
|
||||
from app import app
|
||||
|
||||
# Patch the agent's step method directly
|
||||
with patch("app.agent.step", return_value="Hello! How can I help?"):
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -113,12 +124,16 @@ class TestChatCompletionsEndpoint:
|
||||
def test_chat_completion_no_user_message(self, memory):
|
||||
"""Should return error if no user message."""
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "system", "content": "You are helpful"}],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "system", "content": "You are helpful"}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
detail = response.json()["detail"]
|
||||
@@ -132,18 +147,23 @@ class TestChatCompletionsEndpoint:
|
||||
def test_chat_completion_empty_messages(self, memory):
|
||||
"""Should return error for empty messages."""
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_chat_completion_invalid_json(self, memory):
|
||||
"""Should return error for invalid JSON."""
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post(
|
||||
@@ -157,14 +177,18 @@ class TestChatCompletionsEndpoint:
|
||||
def test_chat_completion_streaming(self, memory):
|
||||
"""Should support streaming mode."""
|
||||
from app import app
|
||||
|
||||
with patch("app.agent.step", return_value="Streaming response"):
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"stream": True,
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"stream": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "text/event-stream" in response.headers["content-type"]
|
||||
@@ -172,17 +196,21 @@ class TestChatCompletionsEndpoint:
|
||||
def test_chat_completion_extracts_last_user_message(self, memory):
|
||||
"""Should use last user message."""
|
||||
from app import app
|
||||
|
||||
with patch("app.agent.step", return_value="Response") as mock_step:
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [
|
||||
{"role": "user", "content": "First message"},
|
||||
{"role": "assistant", "content": "Response"},
|
||||
{"role": "user", "content": "Second message"},
|
||||
],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [
|
||||
{"role": "user", "content": "First message"},
|
||||
{"role": "assistant", "content": "Response"},
|
||||
{"role": "user", "content": "Second message"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Verify the agent received the last user message
|
||||
@@ -191,13 +219,17 @@ class TestChatCompletionsEndpoint:
|
||||
def test_chat_completion_response_format(self, memory):
|
||||
"""Should return OpenAI-compatible format."""
|
||||
from app import app
|
||||
|
||||
with patch("app.agent.step", return_value="Test response"):
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "Test"}],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "Test"}],
|
||||
},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Edge case tests for FastAPI endpoints."""
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@@ -10,43 +10,46 @@ class TestChatCompletionsEdgeCases:
|
||||
|
||||
def test_very_long_message(self, memory):
|
||||
"""Should handle very long user message."""
|
||||
from app import app, agent
|
||||
|
||||
from app import agent, app
|
||||
|
||||
# Patch the agent's LLM directly
|
||||
mock_llm = Mock()
|
||||
mock_llm.complete.return_value = {
|
||||
"role": "assistant",
|
||||
"content": "Response"
|
||||
}
|
||||
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
|
||||
agent.llm = mock_llm
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
long_message = "x" * 100000
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": long_message}],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": long_message}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_unicode_message(self, memory):
|
||||
"""Should handle unicode in message."""
|
||||
from app import app, agent
|
||||
|
||||
from app import agent, app
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.complete.return_value = {
|
||||
"role": "assistant",
|
||||
"content": "日本語の応答"
|
||||
"content": "日本語の応答",
|
||||
}
|
||||
agent.llm = mock_llm
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
content = response.json()["choices"][0]["message"]["content"]
|
||||
@@ -54,22 +57,22 @@ class TestChatCompletionsEdgeCases:
|
||||
|
||||
def test_special_characters_in_message(self, memory):
|
||||
"""Should handle special characters."""
|
||||
from app import app, agent
|
||||
|
||||
from app import agent, app
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.complete.return_value = {
|
||||
"role": "assistant",
|
||||
"content": "Response"
|
||||
}
|
||||
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
|
||||
agent.llm = mock_llm
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
special_message = 'Test with "quotes" and \\backslash and \n newline'
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": special_message}],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": special_message}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -81,12 +84,16 @@ class TestChatCompletionsEdgeCases:
|
||||
mock_llm_class.return_value = mock_llm
|
||||
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": ""}],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": ""}],
|
||||
},
|
||||
)
|
||||
|
||||
# Empty content should be rejected
|
||||
assert response.status_code == 422
|
||||
@@ -98,12 +105,16 @@ class TestChatCompletionsEdgeCases:
|
||||
mock_llm_class.return_value = mock_llm
|
||||
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": None}],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": None}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -114,12 +125,16 @@ class TestChatCompletionsEdgeCases:
|
||||
mock_llm_class.return_value = mock_llm
|
||||
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user"}], # No content
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user"}], # No content
|
||||
},
|
||||
)
|
||||
|
||||
# May accept or reject depending on validation
|
||||
assert response.status_code in [200, 400, 422]
|
||||
@@ -131,12 +146,16 @@ class TestChatCompletionsEdgeCases:
|
||||
mock_llm_class.return_value = mock_llm
|
||||
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"content": "Hello"}], # No role
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"content": "Hello"}], # No role
|
||||
},
|
||||
)
|
||||
|
||||
# Should reject or accept depending on validation
|
||||
assert response.status_code in [200, 400, 422]
|
||||
@@ -149,27 +168,28 @@ class TestChatCompletionsEdgeCases:
|
||||
mock_llm_class.return_value = mock_llm
|
||||
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "invalid_role", "content": "Hello"}],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "invalid_role", "content": "Hello"}],
|
||||
},
|
||||
)
|
||||
|
||||
# Should reject or ignore invalid role
|
||||
assert response.status_code in [200, 400, 422]
|
||||
|
||||
def test_many_messages(self, memory):
|
||||
"""Should handle many messages in conversation."""
|
||||
from app import app, agent
|
||||
|
||||
from app import agent, app
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.complete.return_value = {
|
||||
"role": "assistant",
|
||||
"content": "Response"
|
||||
}
|
||||
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
|
||||
agent.llm = mock_llm
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
messages = []
|
||||
@@ -178,10 +198,13 @@ class TestChatCompletionsEdgeCases:
|
||||
messages.append({"role": "assistant", "content": f"Response {i}"})
|
||||
messages.append({"role": "user", "content": "Final message"})
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": messages,
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": messages,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -192,15 +215,19 @@ class TestChatCompletionsEdgeCases:
|
||||
mock_llm_class.return_value = mock_llm
|
||||
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "system", "content": "Be concise"},
|
||||
],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "system", "content": "Be concise"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -211,14 +238,18 @@ class TestChatCompletionsEdgeCases:
|
||||
mock_llm_class.return_value = mock_llm
|
||||
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [
|
||||
{"role": "assistant", "content": "Hello"},
|
||||
],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [
|
||||
{"role": "assistant", "content": "Hello"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -229,12 +260,16 @@ class TestChatCompletionsEdgeCases:
|
||||
mock_llm_class.return_value = mock_llm
|
||||
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": "not an array",
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": "not an array",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
# Pydantic validation error
|
||||
@@ -246,118 +281,128 @@ class TestChatCompletionsEdgeCases:
|
||||
mock_llm_class.return_value = mock_llm
|
||||
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": ["not an object", 123, None],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": ["not an object", 123, None],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
# Pydantic validation error
|
||||
|
||||
def test_extra_fields_in_request(self, memory):
|
||||
"""Should ignore extra fields in request."""
|
||||
from app import app, agent
|
||||
|
||||
from app import agent, app
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.complete.return_value = {
|
||||
"role": "assistant",
|
||||
"content": "Response"
|
||||
}
|
||||
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
|
||||
agent.llm = mock_llm
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"extra_field": "should be ignored",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100,
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"extra_field": "should be ignored",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_streaming_with_tool_call(self, memory, real_folder):
|
||||
"""Should handle streaming with tool execution."""
|
||||
from app import app, agent
|
||||
from app import agent, app
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
|
||||
mem = get_memory()
|
||||
mem.ltm.set_config("download_folder", str(real_folder["downloads"]))
|
||||
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def mock_complete(messages, tools=None):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download"}'
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "list_folder",
|
||||
"arguments": '{"folder_type": "download"}',
|
||||
},
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "Listed the folder."
|
||||
}
|
||||
|
||||
return {"role": "assistant", "content": "Listed the folder."}
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.complete = Mock(side_effect=mock_complete)
|
||||
agent.llm = mock_llm
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "List downloads"}],
|
||||
"stream": True,
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "List downloads"}],
|
||||
"stream": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_concurrent_requests_simulation(self, memory):
|
||||
"""Should handle rapid sequential requests."""
|
||||
from app import app, agent
|
||||
|
||||
from app import agent, app
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.complete.return_value = {
|
||||
"role": "assistant",
|
||||
"content": "Response"
|
||||
}
|
||||
mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
|
||||
agent.llm = mock_llm
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
for i in range(10):
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": f"Request {i}"}],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": f"Request {i}"}],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_llm_returns_json_in_response(self, memory):
|
||||
"""Should handle LLM returning JSON in text response."""
|
||||
from app import app, agent
|
||||
|
||||
from app import agent, app
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.complete.return_value = {
|
||||
"role": "assistant",
|
||||
"content": '{"result": "some data", "count": 5}'
|
||||
"content": '{"result": "some data", "count": 5}',
|
||||
}
|
||||
agent.llm = mock_llm
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/v1/chat/completions", json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "Give me JSON"}],
|
||||
})
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "agent-media",
|
||||
"messages": [{"role": "user", "content": "Give me JSON"}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
content = response.json()["choices"][0]["message"]["content"]
|
||||
@@ -425,6 +470,7 @@ class TestMemoryEndpointsEdgeCases:
|
||||
with patch("app.DeepSeekClient") as mock_llm:
|
||||
mock_llm.return_value = Mock()
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Clear multiple times
|
||||
@@ -459,6 +505,7 @@ class TestHealthEndpointEdgeCases:
|
||||
with patch("app.DeepSeekClient") as mock_llm:
|
||||
mock_llm.return_value = Mock()
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/health")
|
||||
@@ -471,6 +518,7 @@ class TestHealthEndpointEdgeCases:
|
||||
with patch("app.DeepSeekClient") as mock_llm:
|
||||
mock_llm.return_value = Mock()
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/health?extra=param&another=value")
|
||||
@@ -486,6 +534,7 @@ class TestModelsEndpointEdgeCases:
|
||||
with patch("app.DeepSeekClient") as mock_llm:
|
||||
mock_llm.return_value = Mock()
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/v1/models")
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Critical tests for configuration validation."""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
|
||||
from agent.config import Settings, ConfigurationError
|
||||
import pytest
|
||||
|
||||
from agent.config import ConfigurationError, Settings
|
||||
|
||||
|
||||
class TestConfigValidation:
|
||||
@@ -13,7 +13,7 @@ class TestConfigValidation:
|
||||
"""Verify invalid temperature is rejected."""
|
||||
with pytest.raises(ConfigurationError, match="Temperature"):
|
||||
Settings(temperature=3.0) # > 2.0
|
||||
|
||||
|
||||
with pytest.raises(ConfigurationError, match="Temperature"):
|
||||
Settings(temperature=-0.1) # < 0.0
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestConfigValidation:
|
||||
"""Verify invalid max_iterations is rejected."""
|
||||
with pytest.raises(ConfigurationError, match="max_tool_iterations"):
|
||||
Settings(max_tool_iterations=0) # < 1
|
||||
|
||||
|
||||
with pytest.raises(ConfigurationError, match="max_tool_iterations"):
|
||||
Settings(max_tool_iterations=100) # > 20
|
||||
|
||||
@@ -43,7 +43,7 @@ class TestConfigValidation:
|
||||
"""Verify invalid timeout is rejected."""
|
||||
with pytest.raises(ConfigurationError, match="request_timeout"):
|
||||
Settings(request_timeout=0) # < 1
|
||||
|
||||
|
||||
with pytest.raises(ConfigurationError, match="request_timeout"):
|
||||
Settings(request_timeout=500) # > 300
|
||||
|
||||
@@ -58,7 +58,7 @@ class TestConfigValidation:
|
||||
"""Verify invalid DeepSeek URL is rejected."""
|
||||
with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"):
|
||||
Settings(deepseek_base_url="not-a-url")
|
||||
|
||||
|
||||
with pytest.raises(ConfigurationError, match="Invalid deepseek_base_url"):
|
||||
Settings(deepseek_base_url="ftp://invalid.com")
|
||||
|
||||
@@ -86,19 +86,17 @@ class TestConfigChecks:
|
||||
def test_is_deepseek_configured_with_key(self):
|
||||
"""Verify is_deepseek_configured returns True with API key."""
|
||||
settings = Settings(
|
||||
deepseek_api_key="test-key",
|
||||
deepseek_base_url="https://api.test.com"
|
||||
deepseek_api_key="test-key", deepseek_base_url="https://api.test.com"
|
||||
)
|
||||
|
||||
|
||||
assert settings.is_deepseek_configured() is True
|
||||
|
||||
def test_is_deepseek_configured_without_key(self):
|
||||
"""Verify is_deepseek_configured returns False without API key."""
|
||||
settings = Settings(
|
||||
deepseek_api_key="",
|
||||
deepseek_base_url="https://api.test.com"
|
||||
deepseek_api_key="", deepseek_base_url="https://api.test.com"
|
||||
)
|
||||
|
||||
|
||||
assert settings.is_deepseek_configured() is False
|
||||
|
||||
def test_is_deepseek_configured_without_url(self):
|
||||
@@ -110,19 +108,15 @@ class TestConfigChecks:
|
||||
def test_is_tmdb_configured_with_key(self):
|
||||
"""Verify is_tmdb_configured returns True with API key."""
|
||||
settings = Settings(
|
||||
tmdb_api_key="test-key",
|
||||
tmdb_base_url="https://api.test.com"
|
||||
tmdb_api_key="test-key", tmdb_base_url="https://api.test.com"
|
||||
)
|
||||
|
||||
|
||||
assert settings.is_tmdb_configured() is True
|
||||
|
||||
def test_is_tmdb_configured_without_key(self):
|
||||
"""Verify is_tmdb_configured returns False without API key."""
|
||||
settings = Settings(
|
||||
tmdb_api_key="",
|
||||
tmdb_base_url="https://api.test.com"
|
||||
)
|
||||
|
||||
settings = Settings(tmdb_api_key="", tmdb_base_url="https://api.test.com")
|
||||
|
||||
assert settings.is_tmdb_configured() is False
|
||||
|
||||
|
||||
@@ -132,25 +126,25 @@ class TestConfigDefaults:
|
||||
def test_default_temperature(self):
|
||||
"""Verify default temperature is reasonable."""
|
||||
settings = Settings()
|
||||
|
||||
|
||||
assert 0.0 <= settings.temperature <= 2.0
|
||||
|
||||
def test_default_max_iterations(self):
|
||||
"""Verify default max_iterations is reasonable."""
|
||||
settings = Settings()
|
||||
|
||||
|
||||
assert 1 <= settings.max_tool_iterations <= 20
|
||||
|
||||
def test_default_timeout(self):
|
||||
"""Verify default timeout is reasonable."""
|
||||
settings = Settings()
|
||||
|
||||
|
||||
assert 1 <= settings.request_timeout <= 300
|
||||
|
||||
def test_default_urls_are_valid(self):
|
||||
"""Verify default URLs are valid."""
|
||||
settings = Settings()
|
||||
|
||||
|
||||
assert settings.deepseek_base_url.startswith(("http://", "https://"))
|
||||
assert settings.tmdb_base_url.startswith(("http://", "https://"))
|
||||
|
||||
@@ -161,38 +155,38 @@ class TestConfigEnvironmentVariables:
|
||||
def test_loads_temperature_from_env(self, monkeypatch):
|
||||
"""Verify temperature is loaded from environment."""
|
||||
monkeypatch.setenv("TEMPERATURE", "0.5")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
assert settings.temperature == 0.5
|
||||
|
||||
def test_loads_max_iterations_from_env(self, monkeypatch):
|
||||
"""Verify max_iterations is loaded from environment."""
|
||||
monkeypatch.setenv("MAX_TOOL_ITERATIONS", "10")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
assert settings.max_tool_iterations == 10
|
||||
|
||||
def test_loads_timeout_from_env(self, monkeypatch):
|
||||
"""Verify timeout is loaded from environment."""
|
||||
monkeypatch.setenv("REQUEST_TIMEOUT", "60")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
assert settings.request_timeout == 60
|
||||
|
||||
def test_loads_deepseek_url_from_env(self, monkeypatch):
|
||||
"""Verify DeepSeek URL is loaded from environment."""
|
||||
monkeypatch.setenv("DEEPSEEK_BASE_URL", "https://custom.api.com")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
assert settings.deepseek_base_url == "https://custom.api.com"
|
||||
|
||||
def test_invalid_env_value_raises_error(self, monkeypatch):
|
||||
"""Verify invalid environment value raises error."""
|
||||
monkeypatch.setenv("TEMPERATURE", "invalid")
|
||||
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Settings()
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Edge case tests for configuration and parameters."""
|
||||
import pytest
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent.config import Settings, ConfigurationError
|
||||
import pytest
|
||||
|
||||
from agent.config import ConfigurationError, Settings
|
||||
from agent.parameters import (
|
||||
ParameterSchema,
|
||||
REQUIRED_PARAMETERS,
|
||||
ParameterSchema,
|
||||
format_parameters_for_prompt,
|
||||
get_missing_required_parameters,
|
||||
)
|
||||
@@ -110,19 +112,27 @@ class TestSettingsEdgeCases:
|
||||
|
||||
def test_http_url_accepted(self):
|
||||
"""Should accept http:// URLs."""
|
||||
with patch.dict(os.environ, {
|
||||
"DEEPSEEK_BASE_URL": "http://localhost:8080",
|
||||
"TMDB_BASE_URL": "http://localhost:3000",
|
||||
}, clear=True):
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DEEPSEEK_BASE_URL": "http://localhost:8080",
|
||||
"TMDB_BASE_URL": "http://localhost:3000",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
settings = Settings()
|
||||
assert settings.deepseek_base_url == "http://localhost:8080"
|
||||
|
||||
def test_https_url_accepted(self):
|
||||
"""Should accept https:// URLs."""
|
||||
with patch.dict(os.environ, {
|
||||
"DEEPSEEK_BASE_URL": "https://api.example.com",
|
||||
"TMDB_BASE_URL": "https://api.example.com",
|
||||
}, clear=True):
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DEEPSEEK_BASE_URL": "https://api.example.com",
|
||||
"TMDB_BASE_URL": "https://api.example.com",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
settings = Settings()
|
||||
assert settings.deepseek_base_url == "https://api.example.com"
|
||||
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
"""Tests for the Memory system."""
|
||||
import pytest
|
||||
import json
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.persistence import (
|
||||
Memory,
|
||||
LongTermMemory,
|
||||
ShortTermMemory,
|
||||
EpisodicMemory,
|
||||
init_memory,
|
||||
LongTermMemory,
|
||||
Memory,
|
||||
ShortTermMemory,
|
||||
get_memory,
|
||||
set_memory,
|
||||
has_memory,
|
||||
init_memory,
|
||||
)
|
||||
from infrastructure.persistence.context import _memory_ctx
|
||||
|
||||
@@ -23,11 +22,12 @@ def is_iso_format(s: str) -> bool:
|
||||
return False
|
||||
try:
|
||||
# Attempt to parse the string as an ISO 8601 timestamp
|
||||
datetime.fromisoformat(s.replace('Z', '+00:00'))
|
||||
datetime.fromisoformat(s.replace("Z", "+00:00"))
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
class TestLongTermMemory:
|
||||
"""Tests for LongTermMemory."""
|
||||
|
||||
@@ -116,12 +116,18 @@ class TestLongTermMemory:
|
||||
assert data["config"]["key"] == "value"
|
||||
|
||||
def test_from_dict(self):
|
||||
data = {"config": {"download_folder": "/downloads"}, "preferences": {"preferred_quality": "4K"}, "library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]}, "following": []}
|
||||
data = {
|
||||
"config": {"download_folder": "/downloads"},
|
||||
"preferences": {"preferred_quality": "4K"},
|
||||
"library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]},
|
||||
"following": [],
|
||||
}
|
||||
ltm = LongTermMemory.from_dict(data)
|
||||
assert ltm.get_config("download_folder") == "/downloads"
|
||||
assert ltm.preferences["preferred_quality"] == "4K"
|
||||
assert len(ltm.library["movies"]) == 1
|
||||
|
||||
|
||||
class TestShortTermMemory:
|
||||
"""Tests for ShortTermMemory."""
|
||||
|
||||
@@ -162,6 +168,7 @@ class TestShortTermMemory:
|
||||
assert stm.conversation_history == []
|
||||
assert stm.language == "en"
|
||||
|
||||
|
||||
class TestEpisodicMemory:
|
||||
"""Tests for EpisodicMemory."""
|
||||
|
||||
@@ -192,6 +199,7 @@ class TestEpisodicMemory:
|
||||
assert result is not None
|
||||
assert result["name"] == "Result 2"
|
||||
|
||||
|
||||
class TestMemory:
|
||||
"""Tests for the Memory manager."""
|
||||
|
||||
@@ -217,11 +225,10 @@ class TestMemory:
|
||||
assert memory.stm.conversation_history == []
|
||||
assert memory.episodic.recent_errors == []
|
||||
|
||||
|
||||
class TestMemoryContext:
|
||||
"""Tests for memory context functions."""
|
||||
|
||||
|
||||
|
||||
def test_get_memory_not_initialized(self):
|
||||
_memory_ctx.set(None)
|
||||
with pytest.raises(RuntimeError, match="Memory not initialized"):
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
"""Edge case tests for the Memory system."""
|
||||
import pytest
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch, mock_open
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.persistence import (
|
||||
Memory,
|
||||
LongTermMemory,
|
||||
ShortTermMemory,
|
||||
EpisodicMemory,
|
||||
init_memory,
|
||||
LongTermMemory,
|
||||
Memory,
|
||||
ShortTermMemory,
|
||||
get_memory,
|
||||
init_memory,
|
||||
set_memory,
|
||||
)
|
||||
from infrastructure.persistence.context import _memory_ctx
|
||||
@@ -390,7 +389,7 @@ class TestMemoryEdgeCases:
|
||||
def test_init_with_nonexistent_directory(self, temp_dir):
|
||||
"""Should create directory if not exists."""
|
||||
new_dir = temp_dir / "new" / "nested" / "dir"
|
||||
|
||||
|
||||
# Create parent directories first
|
||||
new_dir.mkdir(parents=True, exist_ok=True)
|
||||
memory = Memory(storage_dir=str(new_dir))
|
||||
@@ -529,7 +528,6 @@ class TestMemoryContextEdgeCases:
|
||||
|
||||
def test_context_isolation(self, temp_dir):
|
||||
"""Context should be isolated per context."""
|
||||
import asyncio
|
||||
from contextvars import copy_context
|
||||
|
||||
_memory_ctx.set(None)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
"""Critical tests for prompt builder - Tests that would have caught bugs."""
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.registry import make_tools
|
||||
from agent.prompts import PromptBuilder
|
||||
from infrastructure.persistence import get_memory
|
||||
from agent.registry import make_tools
|
||||
|
||||
|
||||
class TestPromptBuilderToolsInjection:
|
||||
@@ -15,20 +13,22 @@ class TestPromptBuilderToolsInjection:
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
# Verify each tool is mentioned
|
||||
for tool_name in tools.keys():
|
||||
assert tool_name in prompt, f"Tool {tool_name} not mentioned in system prompt"
|
||||
assert (
|
||||
tool_name in prompt
|
||||
), f"Tool {tool_name} not mentioned in system prompt"
|
||||
|
||||
def test_tools_spec_contains_all_registered_tools(self):
|
||||
"""CRITICAL: Verify build_tools_spec() returns all tools."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
specs = builder.build_tools_spec()
|
||||
|
||||
spec_names = {spec['function']['name'] for spec in specs}
|
||||
|
||||
spec_names = {spec["function"]["name"] for spec in specs}
|
||||
tool_names = set(tools.keys())
|
||||
|
||||
|
||||
assert spec_names == tool_names, f"Missing tools: {tool_names - spec_names}"
|
||||
|
||||
def test_tools_spec_is_not_empty(self):
|
||||
@@ -36,7 +36,7 @@ class TestPromptBuilderToolsInjection:
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
specs = builder.build_tools_spec()
|
||||
|
||||
|
||||
assert len(specs) > 0, "Tools spec is empty!"
|
||||
|
||||
def test_tools_spec_format_matches_openai(self):
|
||||
@@ -44,14 +44,14 @@ class TestPromptBuilderToolsInjection:
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
specs = builder.build_tools_spec()
|
||||
|
||||
|
||||
for spec in specs:
|
||||
assert 'type' in spec
|
||||
assert spec['type'] == 'function'
|
||||
assert 'function' in spec
|
||||
assert 'name' in spec['function']
|
||||
assert 'description' in spec['function']
|
||||
assert 'parameters' in spec['function']
|
||||
assert "type" in spec
|
||||
assert spec["type"] == "function"
|
||||
assert "function" in spec
|
||||
assert "name" in spec["function"]
|
||||
assert "description" in spec["function"]
|
||||
assert "parameters" in spec["function"]
|
||||
|
||||
|
||||
class TestPromptBuilderMemoryContext:
|
||||
@@ -61,29 +61,29 @@ class TestPromptBuilderMemoryContext:
|
||||
"""Verify current topic is included in prompt."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
memory.stm.set_topic("test_topic")
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert "test_topic" in prompt
|
||||
|
||||
def test_prompt_includes_extracted_entities(self, memory):
|
||||
"""Verify extracted entities are included in prompt."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
memory.stm.set_entity("test_key", "test_value")
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert "test_key" in prompt
|
||||
|
||||
def test_prompt_includes_search_results(self, memory_with_search_results):
|
||||
"""Verify search results are included in prompt."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert "Inception" in prompt
|
||||
assert "LAST SEARCH" in prompt
|
||||
|
||||
@@ -91,15 +91,13 @@ class TestPromptBuilderMemoryContext:
|
||||
"""Verify active downloads are included in prompt."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
memory.episodic.add_active_download({
|
||||
"task_id": "123",
|
||||
"name": "Test Movie",
|
||||
"progress": 50
|
||||
})
|
||||
|
||||
|
||||
memory.episodic.add_active_download(
|
||||
{"task_id": "123", "name": "Test Movie", "progress": 50}
|
||||
)
|
||||
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert "ACTIVE DOWNLOADS" in prompt
|
||||
assert "Test Movie" in prompt
|
||||
|
||||
@@ -107,33 +105,33 @@ class TestPromptBuilderMemoryContext:
|
||||
"""Verify recent errors are included in prompt."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
memory.episodic.add_error("test_action", "test error message")
|
||||
|
||||
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert "RECENT ERRORS" in prompt or "error" in prompt.lower()
|
||||
|
||||
def test_prompt_includes_configuration(self, memory):
|
||||
"""Verify configuration is included in prompt."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
memory.ltm.set_config("download_folder", "/test/downloads")
|
||||
|
||||
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert "CONFIGURATION" in prompt or "download_folder" in prompt
|
||||
|
||||
def test_prompt_includes_language(self, memory):
|
||||
"""Verify language is included in prompt."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
memory.stm.set_language("fr")
|
||||
|
||||
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert "fr" in prompt or "LANGUAGE" in prompt
|
||||
|
||||
|
||||
@@ -145,7 +143,7 @@ class TestPromptBuilderStructure:
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert len(prompt) > 0
|
||||
assert prompt.strip() != ""
|
||||
|
||||
@@ -154,7 +152,7 @@ class TestPromptBuilderStructure:
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert "assistant" in prompt.lower() or "help" in prompt.lower()
|
||||
|
||||
def test_system_prompt_includes_rules(self):
|
||||
@@ -162,7 +160,7 @@ class TestPromptBuilderStructure:
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert "RULES" in prompt or "IMPORTANT" in prompt
|
||||
|
||||
def test_system_prompt_includes_examples(self):
|
||||
@@ -170,16 +168,16 @@ class TestPromptBuilderStructure:
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert "EXAMPLES" in prompt or "example" in prompt.lower()
|
||||
|
||||
def test_tools_description_format(self):
|
||||
"""Verify tools are properly formatted in description."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
description = builder._format_tools_description()
|
||||
|
||||
|
||||
# Should have tool names and descriptions
|
||||
for tool_name, tool in tools.items():
|
||||
assert tool_name in description
|
||||
@@ -190,9 +188,9 @@ class TestPromptBuilderStructure:
|
||||
"""Verify episodic context is properly formatted."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
context = builder._format_episodic_context()
|
||||
|
||||
|
||||
assert "LAST SEARCH" in context
|
||||
assert "Inception" in context
|
||||
|
||||
@@ -200,12 +198,12 @@ class TestPromptBuilderStructure:
|
||||
"""Verify STM context is properly formatted."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
memory.stm.set_topic("test_topic")
|
||||
memory.stm.set_entity("key", "value")
|
||||
|
||||
|
||||
context = builder._format_stm_context()
|
||||
|
||||
|
||||
assert "TOPIC" in context or "test_topic" in context
|
||||
assert "ENTITIES" in context or "key" in context
|
||||
|
||||
@@ -213,11 +211,11 @@ class TestPromptBuilderStructure:
|
||||
"""Verify config context is properly formatted."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
memory.ltm.set_config("test_key", "test_value")
|
||||
|
||||
|
||||
context = builder._format_config_context()
|
||||
|
||||
|
||||
assert "CONFIGURATION" in context
|
||||
assert "test_key" in context
|
||||
|
||||
@@ -229,10 +227,10 @@ class TestPromptBuilderEdgeCases:
|
||||
"""Verify prompt works with empty memory."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
# Memory is empty
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
# Should still have base content
|
||||
assert len(prompt) > 0
|
||||
assert "assistant" in prompt.lower()
|
||||
@@ -240,18 +238,18 @@ class TestPromptBuilderEdgeCases:
|
||||
def test_prompt_with_empty_tools(self):
|
||||
"""Verify prompt handles empty tools dict."""
|
||||
builder = PromptBuilder({})
|
||||
|
||||
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
# Should still generate a prompt
|
||||
assert len(prompt) > 0
|
||||
|
||||
def test_tools_spec_with_empty_tools(self):
|
||||
"""Verify tools spec handles empty tools dict."""
|
||||
builder = PromptBuilder({})
|
||||
|
||||
|
||||
specs = builder.build_tools_spec()
|
||||
|
||||
|
||||
assert isinstance(specs, list)
|
||||
assert len(specs) == 0
|
||||
|
||||
@@ -259,11 +257,11 @@ class TestPromptBuilderEdgeCases:
|
||||
"""Verify prompt handles unicode in memory."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
memory.stm.set_entity("movie", "Amélie 🎬")
|
||||
|
||||
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
assert "Amélie" in prompt
|
||||
assert "🎬" in prompt
|
||||
|
||||
@@ -271,13 +269,13 @@ class TestPromptBuilderEdgeCases:
|
||||
"""Verify prompt handles many search results."""
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
|
||||
# Add many results
|
||||
results = [{"name": f"Movie {i}", "seeders": i} for i in range(20)]
|
||||
memory.episodic.store_search_results("test", results, "torrent")
|
||||
|
||||
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
|
||||
# Should include some results but not all (to avoid huge prompts)
|
||||
assert "Movie 0" in prompt or "Movie 1" in prompt
|
||||
# Should indicate there are more
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
"""Edge case tests for PromptBuilder."""
|
||||
import pytest
|
||||
import json
|
||||
|
||||
|
||||
from agent.prompts import PromptBuilder
|
||||
from agent.registry import make_tools
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
|
||||
class TestPromptBuilderEdgeCases:
|
||||
@@ -93,11 +91,13 @@ class TestPromptBuilderEdgeCases:
|
||||
def test_prompt_with_many_active_downloads(self, memory):
|
||||
"""Should limit displayed active downloads."""
|
||||
for i in range(20):
|
||||
memory.episodic.add_active_download({
|
||||
"task_id": str(i),
|
||||
"name": f"Download {i}",
|
||||
"progress": i * 5,
|
||||
})
|
||||
memory.episodic.add_active_download(
|
||||
{
|
||||
"task_id": str(i),
|
||||
"name": f"Download {i}",
|
||||
"progress": i * 5,
|
||||
}
|
||||
)
|
||||
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
@@ -136,12 +136,15 @@ class TestPromptBuilderEdgeCases:
|
||||
|
||||
def test_prompt_with_complex_workflow(self, memory):
|
||||
"""Should handle complex workflow state."""
|
||||
memory.stm.start_workflow("download", {
|
||||
"title": "Test Movie",
|
||||
"year": 2024,
|
||||
"quality": "1080p",
|
||||
"nested": {"deep": {"value": "test"}},
|
||||
})
|
||||
memory.stm.start_workflow(
|
||||
"download",
|
||||
{
|
||||
"title": "Test Movie",
|
||||
"year": 2024,
|
||||
"quality": "1080p",
|
||||
"nested": {"deep": {"value": "test"}},
|
||||
},
|
||||
)
|
||||
memory.stm.update_workflow_stage("searching_torrents")
|
||||
|
||||
tools = make_tools()
|
||||
@@ -313,11 +316,14 @@ class TestFormatEpisodicContextEdgeCases:
|
||||
|
||||
def test_format_with_search_results_none_names(self, memory):
|
||||
"""Should handle results with None names."""
|
||||
memory.episodic.store_search_results("test", [
|
||||
{"name": None},
|
||||
{"title": None},
|
||||
{},
|
||||
])
|
||||
memory.episodic.store_search_results(
|
||||
"test",
|
||||
[
|
||||
{"name": None},
|
||||
{"title": None},
|
||||
{},
|
||||
],
|
||||
)
|
||||
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Critical tests for tool registry - Tests that would have caught bugs."""
|
||||
|
||||
import pytest
|
||||
import inspect
|
||||
|
||||
from agent.registry import make_tools, _create_tool_from_function, Tool
|
||||
import pytest
|
||||
|
||||
from agent.prompts import PromptBuilder
|
||||
from agent.registry import Tool, _create_tool_from_function, make_tools
|
||||
|
||||
|
||||
class TestToolSpecFormat:
|
||||
@@ -15,54 +16,59 @@ class TestToolSpecFormat:
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
specs = builder.build_tools_spec()
|
||||
|
||||
|
||||
# Verify structure
|
||||
assert isinstance(specs, list), "Tool specs must be a list"
|
||||
assert len(specs) > 0, "Tool specs list is empty"
|
||||
|
||||
|
||||
for spec in specs:
|
||||
# OpenAI format requires these fields
|
||||
assert spec['type'] == 'function', f"Tool type must be 'function', got {spec.get('type')}"
|
||||
assert 'function' in spec, "Tool spec missing 'function' key"
|
||||
|
||||
func = spec['function']
|
||||
assert 'name' in func, "Function missing 'name'"
|
||||
assert 'description' in func, "Function missing 'description'"
|
||||
assert 'parameters' in func, "Function missing 'parameters'"
|
||||
|
||||
params = func['parameters']
|
||||
assert params['type'] == 'object', "Parameters type must be 'object'"
|
||||
assert 'properties' in params, "Parameters missing 'properties'"
|
||||
assert 'required' in params, "Parameters missing 'required'"
|
||||
assert isinstance(params['required'], list), "Required must be a list"
|
||||
assert (
|
||||
spec["type"] == "function"
|
||||
), f"Tool type must be 'function', got {spec.get('type')}"
|
||||
assert "function" in spec, "Tool spec missing 'function' key"
|
||||
|
||||
func = spec["function"]
|
||||
assert "name" in func, "Function missing 'name'"
|
||||
assert "description" in func, "Function missing 'description'"
|
||||
assert "parameters" in func, "Function missing 'parameters'"
|
||||
|
||||
params = func["parameters"]
|
||||
assert params["type"] == "object", "Parameters type must be 'object'"
|
||||
assert "properties" in params, "Parameters missing 'properties'"
|
||||
assert "required" in params, "Parameters missing 'required'"
|
||||
assert isinstance(params["required"], list), "Required must be a list"
|
||||
|
||||
def test_tool_parameters_match_function_signature(self):
|
||||
"""CRITICAL: Verify generated parameters match function signature."""
|
||||
|
||||
def test_func(name: str, age: int, active: bool = True):
|
||||
"""Test function with typed parameters."""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
tool = _create_tool_from_function(test_func)
|
||||
|
||||
|
||||
# Verify types are correctly mapped
|
||||
assert tool.parameters['properties']['name']['type'] == 'string'
|
||||
assert tool.parameters['properties']['age']['type'] == 'integer'
|
||||
assert tool.parameters['properties']['active']['type'] == 'boolean'
|
||||
|
||||
assert tool.parameters["properties"]["name"]["type"] == "string"
|
||||
assert tool.parameters["properties"]["age"]["type"] == "integer"
|
||||
assert tool.parameters["properties"]["active"]["type"] == "boolean"
|
||||
|
||||
# Verify required vs optional
|
||||
assert 'name' in tool.parameters['required'], "name should be required"
|
||||
assert 'age' in tool.parameters['required'], "age should be required"
|
||||
assert 'active' not in tool.parameters['required'], "active has default, should not be required"
|
||||
assert "name" in tool.parameters["required"], "name should be required"
|
||||
assert "age" in tool.parameters["required"], "age should be required"
|
||||
assert (
|
||||
"active" not in tool.parameters["required"]
|
||||
), "active has default, should not be required"
|
||||
|
||||
def test_all_registered_tools_are_callable(self):
|
||||
"""CRITICAL: Verify all registered tools are actually callable."""
|
||||
tools = make_tools()
|
||||
|
||||
|
||||
assert len(tools) > 0, "No tools registered"
|
||||
|
||||
|
||||
for name, tool in tools.items():
|
||||
assert callable(tool.func), f"Tool {name} is not callable"
|
||||
|
||||
|
||||
# Verify function has valid signature
|
||||
try:
|
||||
sig = inspect.signature(tool.func)
|
||||
@@ -75,38 +81,40 @@ class TestToolSpecFormat:
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
specs = builder.build_tools_spec()
|
||||
|
||||
spec_names = {spec['function']['name'] for spec in specs}
|
||||
|
||||
spec_names = {spec["function"]["name"] for spec in specs}
|
||||
tool_names = set(tools.keys())
|
||||
|
||||
|
||||
missing = tool_names - spec_names
|
||||
extra = spec_names - tool_names
|
||||
|
||||
|
||||
assert not missing, f"Tools missing from specs: {missing}"
|
||||
assert not extra, f"Extra tools in specs: {extra}"
|
||||
assert spec_names == tool_names, "Tool specs don't match registered tools"
|
||||
|
||||
def test_tool_description_extracted_from_docstring(self):
|
||||
"""Verify tool description is extracted from function docstring."""
|
||||
|
||||
def test_func(param: str):
|
||||
"""This is the description.
|
||||
|
||||
|
||||
More details here.
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
tool = _create_tool_from_function(test_func)
|
||||
|
||||
|
||||
assert tool.description == "This is the description."
|
||||
assert "More details" not in tool.description
|
||||
|
||||
def test_tool_without_docstring_uses_function_name(self):
|
||||
"""Verify tool without docstring uses function name as description."""
|
||||
|
||||
def test_func_no_doc(param: str):
|
||||
return {}
|
||||
|
||||
|
||||
tool = _create_tool_from_function(test_func_no_doc)
|
||||
|
||||
|
||||
assert tool.description == "test_func_no_doc"
|
||||
|
||||
def test_tool_parameters_have_descriptions(self):
|
||||
@@ -114,25 +122,27 @@ class TestToolSpecFormat:
|
||||
tools = make_tools()
|
||||
builder = PromptBuilder(tools)
|
||||
specs = builder.build_tools_spec()
|
||||
|
||||
|
||||
for spec in specs:
|
||||
params = spec['function']['parameters']
|
||||
properties = params.get('properties', {})
|
||||
|
||||
params = spec["function"]["parameters"]
|
||||
properties = params.get("properties", {})
|
||||
|
||||
for param_name, param_spec in properties.items():
|
||||
assert 'description' in param_spec, \
|
||||
f"Parameter {param_name} in {spec['function']['name']} missing description"
|
||||
assert (
|
||||
"description" in param_spec
|
||||
), f"Parameter {param_name} in {spec['function']['name']} missing description"
|
||||
|
||||
def test_required_parameters_are_marked_correctly(self):
|
||||
"""Verify required parameters are correctly identified."""
|
||||
|
||||
def func_with_optional(required: str, optional: int = 5):
|
||||
return {}
|
||||
|
||||
|
||||
tool = _create_tool_from_function(func_with_optional)
|
||||
|
||||
assert 'required' in tool.parameters['required']
|
||||
assert 'optional' not in tool.parameters['required']
|
||||
assert len(tool.parameters['required']) == 1
|
||||
|
||||
assert "required" in tool.parameters["required"]
|
||||
assert "optional" not in tool.parameters["required"]
|
||||
assert len(tool.parameters["required"]) == 1
|
||||
|
||||
|
||||
class TestToolRegistry:
|
||||
@@ -141,28 +151,28 @@ class TestToolRegistry:
|
||||
def test_make_tools_returns_dict(self):
|
||||
"""Verify make_tools returns a dictionary."""
|
||||
tools = make_tools()
|
||||
|
||||
|
||||
assert isinstance(tools, dict)
|
||||
assert len(tools) > 0
|
||||
|
||||
def test_all_tools_have_unique_names(self):
|
||||
"""Verify all tool names are unique."""
|
||||
tools = make_tools()
|
||||
|
||||
|
||||
names = [tool.name for tool in tools.values()]
|
||||
assert len(names) == len(set(names)), "Duplicate tool names found"
|
||||
|
||||
def test_tool_names_match_dict_keys(self):
|
||||
"""Verify tool names match their dictionary keys."""
|
||||
tools = make_tools()
|
||||
|
||||
|
||||
for key, tool in tools.items():
|
||||
assert key == tool.name, f"Key {key} doesn't match tool name {tool.name}"
|
||||
|
||||
def test_expected_tools_are_registered(self):
|
||||
"""Verify all expected tools are registered."""
|
||||
tools = make_tools()
|
||||
|
||||
|
||||
expected_tools = [
|
||||
"set_path_for_folder",
|
||||
"list_folder",
|
||||
@@ -173,14 +183,14 @@ class TestToolRegistry:
|
||||
"get_torrent_by_index",
|
||||
"set_language",
|
||||
]
|
||||
|
||||
|
||||
for expected in expected_tools:
|
||||
assert expected in tools, f"Expected tool {expected} not registered"
|
||||
|
||||
def test_tool_functions_return_dict(self):
|
||||
"""Verify all tool functions return dictionaries."""
|
||||
tools = make_tools()
|
||||
|
||||
|
||||
# Test with minimal valid arguments
|
||||
# Note: This is a smoke test, not full integration
|
||||
for name, tool in tools.items():
|
||||
@@ -195,16 +205,17 @@ class TestToolDataclass:
|
||||
|
||||
def test_tool_creation(self):
|
||||
"""Verify Tool can be created with all fields."""
|
||||
|
||||
def dummy_func():
|
||||
return {}
|
||||
|
||||
|
||||
tool = Tool(
|
||||
name="test_tool",
|
||||
description="Test description",
|
||||
func=dummy_func,
|
||||
parameters={"type": "object", "properties": {}, "required": []}
|
||||
parameters={"type": "object", "properties": {}, "required": []},
|
||||
)
|
||||
|
||||
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "Test description"
|
||||
assert tool.func == dummy_func
|
||||
@@ -212,12 +223,13 @@ class TestToolDataclass:
|
||||
|
||||
def test_tool_parameters_structure(self):
|
||||
"""Verify Tool parameters have correct structure."""
|
||||
|
||||
def dummy_func(arg: str):
|
||||
return {}
|
||||
|
||||
|
||||
tool = _create_tool_from_function(dummy_func)
|
||||
|
||||
assert 'type' in tool.parameters
|
||||
assert 'properties' in tool.parameters
|
||||
assert 'required' in tool.parameters
|
||||
assert tool.parameters['type'] == 'object'
|
||||
|
||||
assert "type" in tool.parameters
|
||||
assert "properties" in tool.parameters
|
||||
assert "required" in tool.parameters
|
||||
assert tool.parameters["type"] == "object"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Edge case tests for tool registry."""
|
||||
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from agent.registry import Tool, make_tools
|
||||
|
||||
@@ -182,7 +183,9 @@ class TestMakeToolsEdgeCases:
|
||||
params = tool.parameters
|
||||
if "required" in params and "properties" in params:
|
||||
for req in params["required"]:
|
||||
assert req in params["properties"], f"Required param {req} not in properties for {tool.name}"
|
||||
assert (
|
||||
req in params["properties"]
|
||||
), f"Required param {req} not in properties for {tool.name}"
|
||||
|
||||
def test_make_tools_descriptions_not_empty(self, memory):
|
||||
"""Should have non-empty descriptions."""
|
||||
@@ -233,7 +236,9 @@ class TestMakeToolsEdgeCases:
|
||||
if "properties" in tool.parameters:
|
||||
for prop_name, prop_schema in tool.parameters["properties"].items():
|
||||
if "type" in prop_schema:
|
||||
assert prop_schema["type"] in valid_types, f"Invalid type for {tool.name}.{prop_name}"
|
||||
assert (
|
||||
prop_schema["type"] in valid_types
|
||||
), f"Invalid type for {tool.name}.{prop_name}"
|
||||
|
||||
def test_make_tools_enum_values(self, memory):
|
||||
"""Should have valid enum values."""
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
"""Tests for JSON repositories."""
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from infrastructure.persistence.json import (
|
||||
JsonMovieRepository,
|
||||
JsonTVShowRepository,
|
||||
JsonSubtitleRepository,
|
||||
)
|
||||
|
||||
from domain.movies.entities import Movie
|
||||
from domain.movies.value_objects import MovieTitle, ReleaseYear, Quality
|
||||
from domain.tv_shows.entities import TVShow
|
||||
from domain.tv_shows.value_objects import ShowStatus
|
||||
from domain.movies.value_objects import MovieTitle, Quality, ReleaseYear
|
||||
from domain.shared.value_objects import FilePath, FileSize, ImdbId
|
||||
from domain.subtitles.entities import Subtitle
|
||||
from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset
|
||||
from domain.shared.value_objects import ImdbId, FilePath, FileSize
|
||||
from domain.tv_shows.entities import TVShow
|
||||
from domain.tv_shows.value_objects import ShowStatus
|
||||
from infrastructure.persistence.json import (
|
||||
JsonMovieRepository,
|
||||
JsonSubtitleRepository,
|
||||
JsonTVShowRepository,
|
||||
)
|
||||
|
||||
|
||||
class TestJsonMovieRepository:
|
||||
@@ -224,7 +223,9 @@ class TestJsonTVShowRepository:
|
||||
"""Should preserve show status."""
|
||||
repo = JsonTVShowRepository()
|
||||
|
||||
for i, status in enumerate([ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]):
|
||||
for i, status in enumerate(
|
||||
[ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]
|
||||
):
|
||||
show = TVShow(
|
||||
imdb_id=ImdbId(f"tt{i+1000000:07d}"),
|
||||
title=f"Show {status.value}",
|
||||
@@ -294,18 +295,22 @@ class TestJsonSubtitleRepository:
|
||||
def test_find_by_media_with_language_filter(self, memory):
|
||||
"""Should filter by language."""
|
||||
repo = JsonSubtitleRepository()
|
||||
repo.save(Subtitle(
|
||||
media_imdb_id=ImdbId("tt1375666"),
|
||||
language=Language.ENGLISH,
|
||||
format=SubtitleFormat.SRT,
|
||||
file_path=FilePath("/subs/en.srt"),
|
||||
))
|
||||
repo.save(Subtitle(
|
||||
media_imdb_id=ImdbId("tt1375666"),
|
||||
language=Language.FRENCH,
|
||||
format=SubtitleFormat.SRT,
|
||||
file_path=FilePath("/subs/fr.srt"),
|
||||
))
|
||||
repo.save(
|
||||
Subtitle(
|
||||
media_imdb_id=ImdbId("tt1375666"),
|
||||
language=Language.ENGLISH,
|
||||
format=SubtitleFormat.SRT,
|
||||
file_path=FilePath("/subs/en.srt"),
|
||||
)
|
||||
)
|
||||
repo.save(
|
||||
Subtitle(
|
||||
media_imdb_id=ImdbId("tt1375666"),
|
||||
language=Language.FRENCH,
|
||||
format=SubtitleFormat.SRT,
|
||||
file_path=FilePath("/subs/fr.srt"),
|
||||
)
|
||||
)
|
||||
|
||||
results = repo.find_by_media(ImdbId("tt1375666"), language=Language.FRENCH)
|
||||
|
||||
@@ -315,22 +320,26 @@ class TestJsonSubtitleRepository:
|
||||
def test_find_by_media_with_episode_filter(self, memory):
|
||||
"""Should filter by season/episode."""
|
||||
repo = JsonSubtitleRepository()
|
||||
repo.save(Subtitle(
|
||||
media_imdb_id=ImdbId("tt0944947"),
|
||||
language=Language.ENGLISH,
|
||||
format=SubtitleFormat.SRT,
|
||||
file_path=FilePath("/subs/s01e01.srt"),
|
||||
season_number=1,
|
||||
episode_number=1,
|
||||
))
|
||||
repo.save(Subtitle(
|
||||
media_imdb_id=ImdbId("tt0944947"),
|
||||
language=Language.ENGLISH,
|
||||
format=SubtitleFormat.SRT,
|
||||
file_path=FilePath("/subs/s01e02.srt"),
|
||||
season_number=1,
|
||||
episode_number=2,
|
||||
))
|
||||
repo.save(
|
||||
Subtitle(
|
||||
media_imdb_id=ImdbId("tt0944947"),
|
||||
language=Language.ENGLISH,
|
||||
format=SubtitleFormat.SRT,
|
||||
file_path=FilePath("/subs/s01e01.srt"),
|
||||
season_number=1,
|
||||
episode_number=1,
|
||||
)
|
||||
)
|
||||
repo.save(
|
||||
Subtitle(
|
||||
media_imdb_id=ImdbId("tt0944947"),
|
||||
language=Language.ENGLISH,
|
||||
format=SubtitleFormat.SRT,
|
||||
file_path=FilePath("/subs/s01e02.srt"),
|
||||
season_number=1,
|
||||
episode_number=2,
|
||||
)
|
||||
)
|
||||
|
||||
results = repo.find_by_media(
|
||||
ImdbId("tt0944947"),
|
||||
|
||||
@@ -21,24 +21,30 @@ def create_mock_response(status_code, json_data=None, text=None):
|
||||
class TestFindMediaImdbId:
|
||||
"""Tests for find_media_imdb_id tool."""
|
||||
|
||||
@patch('infrastructure.api.tmdb.client.requests.get')
|
||||
@patch("infrastructure.api.tmdb.client.requests.get")
|
||||
def test_success(self, mock_get, memory):
|
||||
"""Should return movie info on success."""
|
||||
|
||||
# Mock HTTP responses
|
||||
def mock_get_side_effect(url, **kwargs):
|
||||
if "search" in url:
|
||||
return create_mock_response(200, json_data={
|
||||
"results": [{
|
||||
"id": 27205,
|
||||
"title": "Inception",
|
||||
"release_date": "2010-07-16",
|
||||
"overview": "A thief...",
|
||||
"media_type": "movie"
|
||||
}]
|
||||
})
|
||||
return create_mock_response(
|
||||
200,
|
||||
json_data={
|
||||
"results": [
|
||||
{
|
||||
"id": 27205,
|
||||
"title": "Inception",
|
||||
"release_date": "2010-07-16",
|
||||
"overview": "A thief...",
|
||||
"media_type": "movie",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
elif "external_ids" in url:
|
||||
return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
|
||||
|
||||
|
||||
mock_get.side_effect = mock_get_side_effect
|
||||
|
||||
result = api_tools.find_media_imdb_id("Inception")
|
||||
@@ -46,26 +52,32 @@ class TestFindMediaImdbId:
|
||||
assert result["status"] == "ok"
|
||||
assert result["imdb_id"] == "tt1375666"
|
||||
assert result["title"] == "Inception"
|
||||
|
||||
|
||||
# Verify HTTP calls
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
@patch('infrastructure.api.tmdb.client.requests.get')
|
||||
@patch("infrastructure.api.tmdb.client.requests.get")
|
||||
def test_stores_in_stm(self, mock_get, memory):
|
||||
"""Should store result in STM on success."""
|
||||
|
||||
def mock_get_side_effect(url, **kwargs):
|
||||
if "search" in url:
|
||||
return create_mock_response(200, json_data={
|
||||
"results": [{
|
||||
"id": 27205,
|
||||
"title": "Inception",
|
||||
"release_date": "2010-07-16",
|
||||
"media_type": "movie"
|
||||
}]
|
||||
})
|
||||
return create_mock_response(
|
||||
200,
|
||||
json_data={
|
||||
"results": [
|
||||
{
|
||||
"id": 27205,
|
||||
"title": "Inception",
|
||||
"release_date": "2010-07-16",
|
||||
"media_type": "movie",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
elif "external_ids" in url:
|
||||
return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
|
||||
|
||||
|
||||
mock_get.side_effect = mock_get_side_effect
|
||||
|
||||
api_tools.find_media_imdb_id("Inception")
|
||||
@@ -76,7 +88,7 @@ class TestFindMediaImdbId:
|
||||
assert entity["title"] == "Inception"
|
||||
assert mem.stm.current_topic == "searching_media"
|
||||
|
||||
@patch('infrastructure.api.tmdb.client.requests.get')
|
||||
@patch("infrastructure.api.tmdb.client.requests.get")
|
||||
def test_not_found(self, mock_get, memory):
|
||||
"""Should return error when not found."""
|
||||
mock_get.return_value = create_mock_response(200, json_data={"results": []})
|
||||
@@ -86,7 +98,7 @@ class TestFindMediaImdbId:
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] == "not_found"
|
||||
|
||||
@patch('infrastructure.api.tmdb.client.requests.get')
|
||||
@patch("infrastructure.api.tmdb.client.requests.get")
|
||||
def test_does_not_store_on_error(self, mock_get, memory):
|
||||
"""Should not store in STM on error."""
|
||||
mock_get.return_value = create_mock_response(200, json_data={"results": []})
|
||||
@@ -100,49 +112,57 @@ class TestFindMediaImdbId:
|
||||
class TestFindTorrent:
|
||||
"""Tests for find_torrent tool."""
|
||||
|
||||
@patch('infrastructure.api.knaben.client.requests.post')
|
||||
@patch("infrastructure.api.knaben.client.requests.post")
|
||||
def test_success(self, mock_post, memory):
|
||||
"""Should return torrents on success."""
|
||||
mock_post.return_value = create_mock_response(200, json_data={
|
||||
"hits": [
|
||||
{
|
||||
"title": "Torrent 1",
|
||||
"seeders": 100,
|
||||
"leechers": 10,
|
||||
"magnetUrl": "magnet:?xt=...",
|
||||
"size": "2.5 GB"
|
||||
},
|
||||
{
|
||||
"title": "Torrent 2",
|
||||
"seeders": 50,
|
||||
"leechers": 5,
|
||||
"magnetUrl": "magnet:?xt=...",
|
||||
"size": "1.8 GB"
|
||||
}
|
||||
]
|
||||
})
|
||||
mock_post.return_value = create_mock_response(
|
||||
200,
|
||||
json_data={
|
||||
"hits": [
|
||||
{
|
||||
"title": "Torrent 1",
|
||||
"seeders": 100,
|
||||
"leechers": 10,
|
||||
"magnetUrl": "magnet:?xt=...",
|
||||
"size": "2.5 GB",
|
||||
},
|
||||
{
|
||||
"title": "Torrent 2",
|
||||
"seeders": 50,
|
||||
"leechers": 5,
|
||||
"magnetUrl": "magnet:?xt=...",
|
||||
"size": "1.8 GB",
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
result = api_tools.find_torrent("Inception 1080p")
|
||||
|
||||
assert result["status"] == "ok"
|
||||
assert len(result["torrents"]) == 2
|
||||
|
||||
# Verify HTTP payload
|
||||
payload = mock_post.call_args[1]['json']
|
||||
assert payload['query'] == "Inception 1080p"
|
||||
|
||||
@patch('infrastructure.api.knaben.client.requests.post')
|
||||
# Verify HTTP payload
|
||||
payload = mock_post.call_args[1]["json"]
|
||||
assert payload["query"] == "Inception 1080p"
|
||||
|
||||
@patch("infrastructure.api.knaben.client.requests.post")
|
||||
def test_stores_in_episodic(self, mock_post, memory):
|
||||
"""Should store results in episodic memory."""
|
||||
mock_post.return_value = create_mock_response(200, json_data={
|
||||
"hits": [{
|
||||
"title": "Torrent 1",
|
||||
"seeders": 100,
|
||||
"leechers": 10,
|
||||
"magnetUrl": "magnet:?xt=...",
|
||||
"size": "2.5 GB"
|
||||
}]
|
||||
})
|
||||
mock_post.return_value = create_mock_response(
|
||||
200,
|
||||
json_data={
|
||||
"hits": [
|
||||
{
|
||||
"title": "Torrent 1",
|
||||
"seeders": 100,
|
||||
"leechers": 10,
|
||||
"magnetUrl": "magnet:?xt=...",
|
||||
"size": "2.5 GB",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
api_tools.find_torrent("Inception")
|
||||
|
||||
@@ -151,16 +171,37 @@ class TestFindTorrent:
|
||||
assert mem.episodic.last_search_results["query"] == "Inception"
|
||||
assert mem.stm.current_topic == "selecting_torrent"
|
||||
|
||||
@patch('infrastructure.api.knaben.client.requests.post')
|
||||
@patch("infrastructure.api.knaben.client.requests.post")
|
||||
def test_results_have_indexes(self, mock_post, memory):
|
||||
"""Should add indexes to results."""
|
||||
mock_post.return_value = create_mock_response(200, json_data={
|
||||
"hits": [
|
||||
{"title": "Torrent 1", "seeders": 100, "leechers": 10, "magnetUrl": "magnet:?xt=1", "size": "1GB"},
|
||||
{"title": "Torrent 2", "seeders": 50, "leechers": 5, "magnetUrl": "magnet:?xt=2", "size": "2GB"},
|
||||
{"title": "Torrent 3", "seeders": 25, "leechers": 2, "magnetUrl": "magnet:?xt=3", "size": "3GB"}
|
||||
]
|
||||
})
|
||||
mock_post.return_value = create_mock_response(
|
||||
200,
|
||||
json_data={
|
||||
"hits": [
|
||||
{
|
||||
"title": "Torrent 1",
|
||||
"seeders": 100,
|
||||
"leechers": 10,
|
||||
"magnetUrl": "magnet:?xt=1",
|
||||
"size": "1GB",
|
||||
},
|
||||
{
|
||||
"title": "Torrent 2",
|
||||
"seeders": 50,
|
||||
"leechers": 5,
|
||||
"magnetUrl": "magnet:?xt=2",
|
||||
"size": "2GB",
|
||||
},
|
||||
{
|
||||
"title": "Torrent 3",
|
||||
"seeders": 25,
|
||||
"leechers": 2,
|
||||
"magnetUrl": "magnet:?xt=3",
|
||||
"size": "3GB",
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
api_tools.find_torrent("Test")
|
||||
|
||||
@@ -170,7 +211,7 @@ class TestFindTorrent:
|
||||
assert results[1]["index"] == 2
|
||||
assert results[2]["index"] == 3
|
||||
|
||||
@patch('infrastructure.api.knaben.client.requests.post')
|
||||
@patch("infrastructure.api.knaben.client.requests.post")
|
||||
def test_not_found(self, mock_post, memory):
|
||||
"""Should return error when no torrents found."""
|
||||
mock_post.return_value = create_mock_response(200, json_data={"hits": []})
|
||||
@@ -236,16 +277,16 @@ class TestGetTorrentByIndex:
|
||||
|
||||
class TestAddTorrentToQbittorrent:
|
||||
"""Tests for add_torrent_to_qbittorrent tool.
|
||||
|
||||
|
||||
Note: These tests mock the qBittorrent client because:
|
||||
1. The client requires authentication/session management
|
||||
2. We want to test the tool's logic (memory updates, workflow management)
|
||||
3. The client itself is tested separately in infrastructure tests
|
||||
|
||||
|
||||
This is acceptable mocking because we're testing the TOOL logic, not the client.
|
||||
"""
|
||||
|
||||
@patch('agent.tools.api.qbittorrent_client')
|
||||
@patch("agent.tools.api.qbittorrent_client")
|
||||
def test_success(self, mock_client, memory):
|
||||
"""Should add torrent successfully and update memory."""
|
||||
mock_client.add_torrent.return_value = True
|
||||
@@ -257,7 +298,7 @@ class TestAddTorrentToQbittorrent:
|
||||
# Verify client was called correctly
|
||||
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
|
||||
|
||||
@patch('agent.tools.api.qbittorrent_client')
|
||||
@patch("agent.tools.api.qbittorrent_client")
|
||||
def test_adds_to_active_downloads(self, mock_client, memory_with_search_results):
|
||||
"""Should add to active downloads on success."""
|
||||
mock_client.add_torrent.return_value = True
|
||||
@@ -267,9 +308,12 @@ class TestAddTorrentToQbittorrent:
|
||||
# Test memory update logic
|
||||
mem = get_memory()
|
||||
assert len(mem.episodic.active_downloads) == 1
|
||||
assert mem.episodic.active_downloads[0]["name"] == "Inception.2010.1080p.BluRay.x264"
|
||||
assert (
|
||||
mem.episodic.active_downloads[0]["name"]
|
||||
== "Inception.2010.1080p.BluRay.x264"
|
||||
)
|
||||
|
||||
@patch('agent.tools.api.qbittorrent_client')
|
||||
@patch("agent.tools.api.qbittorrent_client")
|
||||
def test_sets_topic_and_ends_workflow(self, mock_client, memory):
|
||||
"""Should set topic and end workflow."""
|
||||
mock_client.add_torrent.return_value = True
|
||||
@@ -282,10 +326,11 @@ class TestAddTorrentToQbittorrent:
|
||||
assert mem.stm.current_topic == "downloading"
|
||||
assert mem.stm.current_workflow is None
|
||||
|
||||
@patch('agent.tools.api.qbittorrent_client')
|
||||
@patch("agent.tools.api.qbittorrent_client")
|
||||
def test_error_handling(self, mock_client, memory):
|
||||
"""Should handle client errors correctly."""
|
||||
from infrastructure.api.qbittorrent.exceptions import QBittorrentAPIError
|
||||
|
||||
mock_client.add_torrent.side_effect = QBittorrentAPIError("Connection failed")
|
||||
|
||||
result = api_tools.add_torrent_to_qbittorrent("magnet:?xt=...")
|
||||
@@ -296,7 +341,7 @@ class TestAddTorrentToQbittorrent:
|
||||
|
||||
class TestAddTorrentByIndex:
|
||||
"""Tests for add_torrent_by_index tool.
|
||||
|
||||
|
||||
These tests verify the tool's logic:
|
||||
- Getting torrent from memory by index
|
||||
- Extracting magnet link
|
||||
@@ -304,7 +349,7 @@ class TestAddTorrentByIndex:
|
||||
- Error handling for edge cases
|
||||
"""
|
||||
|
||||
@patch('agent.tools.api.qbittorrent_client')
|
||||
@patch("agent.tools.api.qbittorrent_client")
|
||||
def test_success(self, mock_client, memory_with_search_results):
|
||||
"""Should get torrent by index and add it."""
|
||||
mock_client.add_torrent.return_value = True
|
||||
@@ -317,7 +362,7 @@ class TestAddTorrentByIndex:
|
||||
# Verify correct magnet was extracted and used
|
||||
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
|
||||
|
||||
@patch('agent.tools.api.qbittorrent_client')
|
||||
@patch("agent.tools.api.qbittorrent_client")
|
||||
def test_uses_correct_magnet(self, mock_client, memory_with_search_results):
|
||||
"""Should extract correct magnet from index."""
|
||||
mock_client.add_torrent.return_value = True
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Edge case tests for tools."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools import api as api_tools
|
||||
from agent.tools import filesystem as fs_tools
|
||||
@@ -15,7 +16,10 @@ class TestFindTorrentEdgeCases:
|
||||
def test_empty_query(self, mock_use_case_class, memory):
|
||||
"""Should handle empty query."""
|
||||
mock_response = Mock()
|
||||
mock_response.to_dict.return_value = {"status": "error", "error": "invalid_query"}
|
||||
mock_response.to_dict.return_value = {
|
||||
"status": "error",
|
||||
"error": "invalid_query",
|
||||
}
|
||||
mock_use_case = Mock()
|
||||
mock_use_case.execute.return_value = mock_response
|
||||
mock_use_case_class.return_value = mock_use_case
|
||||
@@ -28,7 +32,11 @@ class TestFindTorrentEdgeCases:
|
||||
def test_very_long_query(self, mock_use_case_class, memory):
|
||||
"""Should handle very long query."""
|
||||
mock_response = Mock()
|
||||
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0}
|
||||
mock_response.to_dict.return_value = {
|
||||
"status": "ok",
|
||||
"torrents": [],
|
||||
"count": 0,
|
||||
}
|
||||
mock_use_case = Mock()
|
||||
mock_use_case.execute.return_value = mock_response
|
||||
mock_use_case_class.return_value = mock_use_case
|
||||
@@ -43,7 +51,11 @@ class TestFindTorrentEdgeCases:
|
||||
def test_special_characters_in_query(self, mock_use_case_class, memory):
|
||||
"""Should handle special characters in query."""
|
||||
mock_response = Mock()
|
||||
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0}
|
||||
mock_response.to_dict.return_value = {
|
||||
"status": "ok",
|
||||
"torrents": [],
|
||||
"count": 0,
|
||||
}
|
||||
mock_use_case = Mock()
|
||||
mock_use_case.execute.return_value = mock_response
|
||||
mock_use_case_class.return_value = mock_use_case
|
||||
@@ -57,7 +69,11 @@ class TestFindTorrentEdgeCases:
|
||||
def test_unicode_query(self, mock_use_case_class, memory):
|
||||
"""Should handle unicode in query."""
|
||||
mock_response = Mock()
|
||||
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0}
|
||||
mock_response.to_dict.return_value = {
|
||||
"status": "ok",
|
||||
"torrents": [],
|
||||
"count": 0,
|
||||
}
|
||||
mock_use_case = Mock()
|
||||
mock_use_case.execute.return_value = mock_response
|
||||
mock_use_case_class.return_value = mock_use_case
|
||||
@@ -161,7 +177,10 @@ class TestAddTorrentEdgeCases:
|
||||
def test_empty_magnet_link(self, mock_use_case_class, memory):
|
||||
"""Should handle empty magnet link."""
|
||||
mock_response = Mock()
|
||||
mock_response.to_dict.return_value = {"status": "error", "error": "empty_magnet"}
|
||||
mock_response.to_dict.return_value = {
|
||||
"status": "error",
|
||||
"error": "empty_magnet",
|
||||
}
|
||||
mock_use_case = Mock()
|
||||
mock_use_case.execute.return_value = mock_response
|
||||
mock_use_case_class.return_value = mock_use_case
|
||||
@@ -326,7 +345,10 @@ class TestFilesystemEdgeCases:
|
||||
for attempt in attempts:
|
||||
result = fs_tools.list_folder("download", attempt)
|
||||
# Should either be forbidden or not found
|
||||
assert result.get("error") in ["forbidden", "not_found", None] or result.get("status") == "ok"
|
||||
assert (
|
||||
result.get("error") in ["forbidden", "not_found", None]
|
||||
or result.get("status") == "ok"
|
||||
)
|
||||
|
||||
def test_path_with_null_byte(self, memory, real_folder):
|
||||
"""Should block null byte injection."""
|
||||
|
||||
Reference in New Issue
Block a user