Files
alfred/tests/test_memory.py
Francwa 9ca31e45e0 feat!: migrate to OpenAI native tool calls and fix circular deps (#fuck-gemini)
- Fix circular dependencies in agent/tools
- Migrate from custom JSON to OpenAI tool calls format
- Add async streaming (step_stream, complete_stream)
- Simplify prompt system and remove token counting
- Add 5 new API endpoints (/health, /v1/models, /api/memory/*)
- Add 3 new tools (get_torrent_by_index, add_torrent_by_index, set_language)
- Fix all 500 tests and add coverage config (80% threshold)
- Add comprehensive docs (README, pytest guide)

BREAKING: LLM interface changed, memory injection via get_memory()
2025-12-06 19:11:05 +01:00

697 lines
22 KiB
Python

"""Tests for the Memory system."""
import json
import pytest
from infrastructure.persistence import (
EpisodicMemory,
LongTermMemory,
Memory,
ShortTermMemory,
get_memory,
has_memory,
init_memory,
set_memory,
)
from infrastructure.persistence.context import _memory_ctx
class TestLongTermMemory:
    """Unit tests for the persistent LongTermMemory store."""

    def test_default_values(self):
        """A freshly constructed LTM exposes its built-in defaults."""
        long_term = LongTermMemory()
        prefs = long_term.preferences
        assert long_term.config == {}
        assert prefs["preferred_quality"] == "1080p"
        assert "en" in prefs["preferred_languages"]
        assert long_term.library == {"movies": [], "tv_shows": []}
        assert long_term.following == []

    def test_set_and_get_config(self):
        """A value stored with set_config is returned by get_config."""
        long_term = LongTermMemory()
        folder = "/path/to/downloads"
        long_term.set_config("download_folder", folder)
        assert long_term.get_config("download_folder") == folder

    def test_get_config_default(self):
        """Unknown keys fall back to the given default, or None when omitted."""
        long_term = LongTermMemory()
        missing_key = "nonexistent"
        assert long_term.get_config(missing_key) is None
        assert long_term.get_config(missing_key, "default") == "default"

    def test_has_config(self):
        """has_config becomes truthy only after the key has been set."""
        long_term = LongTermMemory()
        key = "download_folder"
        assert not long_term.has_config(key)
        long_term.set_config(key, "/path")
        assert long_term.has_config(key)

    def test_has_config_none_value(self):
        """A key explicitly mapped to None is not considered configured."""
        long_term = LongTermMemory()
        long_term.config["key"] = None
        assert not long_term.has_config("key")

    def test_add_to_library(self):
        """Adding a movie stores it and stamps it with an added_at field."""
        long_term = LongTermMemory()
        long_term.add_to_library(
            "movies", {"imdb_id": "tt1375666", "title": "Inception"}
        )
        movies = long_term.library["movies"]
        assert len(movies) == 1
        entry = movies[0]
        assert entry["title"] == "Inception"
        assert "added_at" in entry

    def test_add_to_library_no_duplicates(self):
        """Adding the same media twice keeps a single library entry."""
        long_term = LongTermMemory()
        inception = {"imdb_id": "tt1375666", "title": "Inception"}
        for _ in range(2):
            long_term.add_to_library("movies", inception)
        assert len(long_term.library["movies"]) == 1

    def test_add_to_library_new_type(self):
        """An unseen media type gets its own library bucket on first insert."""
        long_term = LongTermMemory()
        long_term.add_to_library(
            "subtitles", {"imdb_id": "tt1375666", "language": "en"}
        )
        assert "subtitles" in long_term.library
        assert len(long_term.library["subtitles"]) == 1

    def test_get_library(self):
        """get_library returns every entry stored under the media type."""
        long_term = LongTermMemory()
        for n in (1, 2):
            long_term.add_to_library(
                "movies", {"imdb_id": f"tt{n}", "title": f"Movie {n}"}
            )
        assert len(long_term.get_library("movies")) == 2

    def test_get_library_empty(self):
        """An unknown media type yields an empty list."""
        assert LongTermMemory().get_library("unknown") == []

    def test_follow_show(self):
        """Following a show records it along with a followed_at stamp."""
        long_term = LongTermMemory()
        long_term.follow_show({"imdb_id": "tt0944947", "title": "Game of Thrones"})
        followed = long_term.following
        assert len(followed) == 1
        assert followed[0]["title"] == "Game of Thrones"
        assert "followed_at" in followed[0]

    def test_follow_show_no_duplicates(self):
        """Following the same show twice keeps a single entry."""
        long_term = LongTermMemory()
        got = {"imdb_id": "tt0944947", "title": "Game of Thrones"}
        for _ in range(2):
            long_term.follow_show(got)
        assert len(long_term.following) == 1

    def test_to_dict(self):
        """to_dict emits every top-level section, including config values."""
        long_term = LongTermMemory()
        long_term.set_config("key", "value")
        data = long_term.to_dict()
        for section in ("config", "preferences", "library", "following"):
            assert section in data
        assert data["config"]["key"] == "value"

    def test_from_dict(self):
        """from_dict rebuilds an LTM from a serialized payload."""
        payload = {
            "config": {"download_folder": "/downloads"},
            "preferences": {"preferred_quality": "4K"},
            "library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]},
            "following": [],
        }
        long_term = LongTermMemory.from_dict(payload)
        assert long_term.get_config("download_folder") == "/downloads"
        assert long_term.preferences["preferred_quality"] == "4K"
        assert len(long_term.library["movies"]) == 1

    def test_from_dict_missing_keys(self):
        """An empty payload falls back to the default values."""
        long_term = LongTermMemory.from_dict({})
        assert long_term.config == {}
        assert long_term.preferences["preferred_quality"] == "1080p"
class TestShortTermMemory:
    """Unit tests for the per-session ShortTermMemory store."""

    def test_default_values(self):
        """A fresh STM holds no history, workflow, entities or topic."""
        short_term = ShortTermMemory()
        assert short_term.conversation_history == []
        assert short_term.extracted_entities == {}
        for attr in ("current_workflow", "current_topic"):
            assert getattr(short_term, attr) is None

    def test_add_message(self):
        """A message is appended with its role, content and a timestamp."""
        short_term = ShortTermMemory()
        short_term.add_message("user", "Hello")
        history = short_term.conversation_history
        assert len(history) == 1
        first = history[0]
        assert first["role"] == "user"
        assert first["content"] == "Hello"
        assert "timestamp" in first

    def test_add_message_max_history(self):
        """History is trimmed to max_history, dropping the oldest messages."""
        short_term = ShortTermMemory()
        short_term.max_history = 5
        for idx in range(10):
            short_term.add_message("user", f"Message {idx}")
        history = short_term.conversation_history
        assert len(history) == 5
        assert history[0]["content"] == "Message 5"

    def test_get_recent_history(self):
        """get_recent_history(n) returns the n most recent messages in order."""
        short_term = ShortTermMemory()
        for idx in range(10):
            short_term.add_message("user", f"Message {idx}")
        last_three = short_term.get_recent_history(3)
        assert len(last_three) == 3
        assert last_three[0]["content"] == "Message 7"

    def test_get_recent_history_less_than_n(self):
        """Requesting more messages than exist returns the whole history."""
        short_term = ShortTermMemory()
        for role, text in (("user", "Hello"), ("assistant", "Hi")):
            short_term.add_message(role, text)
        assert len(short_term.get_recent_history(10)) == 2

    def test_start_workflow(self):
        """Starting a workflow records its type, target and initial stage."""
        short_term = ShortTermMemory()
        short_term.start_workflow("download", {"title": "Inception"})
        workflow = short_term.current_workflow
        assert workflow is not None
        assert workflow["type"] == "download"
        assert workflow["target"]["title"] == "Inception"
        assert workflow["stage"] == "started"

    def test_update_workflow_stage(self):
        """The active workflow's stage can be advanced."""
        short_term = ShortTermMemory()
        short_term.start_workflow("download", {"title": "Inception"})
        short_term.update_workflow_stage("searching")
        assert short_term.current_workflow["stage"] == "searching"

    def test_update_workflow_stage_no_workflow(self):
        """Updating the stage without an active workflow is a silent no-op."""
        short_term = ShortTermMemory()
        # Must not raise even though nothing has been started.
        short_term.update_workflow_stage("searching")
        assert short_term.current_workflow is None

    def test_end_workflow(self):
        """Ending the workflow clears it."""
        short_term = ShortTermMemory()
        short_term.start_workflow("download", {"title": "Inception"})
        short_term.end_workflow()
        assert short_term.current_workflow is None

    def test_set_and_get_entity(self):
        """Entities round-trip through set_entity / get_entity."""
        short_term = ShortTermMemory()
        entities = {"movie_title": "Inception", "year": 2010}
        for name, value in entities.items():
            short_term.set_entity(name, value)
        for name, value in entities.items():
            assert short_term.get_entity(name) == value

    def test_get_entity_default(self):
        """Missing entities fall back to the default, or None when omitted."""
        short_term = ShortTermMemory()
        missing = "nonexistent"
        assert short_term.get_entity(missing) is None
        assert short_term.get_entity(missing, "default") == "default"

    def test_clear_entities(self):
        """clear_entities wipes every extracted entity."""
        short_term = ShortTermMemory()
        for n in (1, 2):
            short_term.set_entity(f"key{n}", f"value{n}")
        short_term.clear_entities()
        assert short_term.extracted_entities == {}

    def test_set_topic(self):
        """set_topic stores the current topic."""
        short_term = ShortTermMemory()
        short_term.set_topic("searching_movie")
        assert short_term.current_topic == "searching_movie"

    def test_clear(self):
        """clear resets history, workflow, entities and topic together."""
        short_term = ShortTermMemory()
        short_term.add_message("user", "Hello")
        short_term.start_workflow("download", {})
        short_term.set_entity("key", "value")
        short_term.set_topic("topic")
        short_term.clear()
        assert short_term.conversation_history == []
        assert short_term.extracted_entities == {}
        for attr in ("current_workflow", "current_topic"):
            assert getattr(short_term, attr) is None

    def test_to_dict(self):
        """to_dict emits every STM section."""
        short_term = ShortTermMemory()
        short_term.add_message("user", "Hello")
        short_term.set_topic("test")
        data = short_term.to_dict()
        expected_keys = (
            "conversation_history",
            "current_workflow",
            "extracted_entities",
            "current_topic",
        )
        for key in expected_keys:
            assert key in data
class TestEpisodicMemory:
    """Unit tests for EpisodicMemory (searches, downloads, errors, events)."""

    def test_default_values(self):
        """A fresh episodic store holds nothing."""
        ep = EpisodicMemory()
        for attr in ("last_search_results", "pending_question"):
            assert getattr(ep, attr) is None
        for attr in ("active_downloads", "recent_errors", "background_events"):
            assert getattr(ep, attr) == []

    def test_store_search_results(self):
        """Stored results keep the query and receive 1-based indexes."""
        ep = EpisodicMemory()
        hits = [
            {"name": "Result 1", "seeders": 100},
            {"name": "Result 2", "seeders": 50},
        ]
        ep.store_search_results("test query", hits)
        stored = ep.last_search_results
        assert stored is not None
        assert stored["query"] == "test query"
        assert len(stored["results"]) == 2
        for expected_index, result in enumerate(stored["results"], start=1):
            assert result["index"] == expected_index

    def test_get_result_by_index(self):
        """get_result_by_index looks results up by their 1-based index."""
        ep = EpisodicMemory()
        ep.store_search_results(
            "query", [{"name": f"Result {n}"} for n in range(1, 4)]
        )
        second = ep.get_result_by_index(2)
        assert second is not None
        assert second["name"] == "Result 2"

    def test_get_result_by_index_not_found(self):
        """Out-of-range, zero and negative indexes yield None."""
        ep = EpisodicMemory()
        ep.store_search_results("query", [{"name": "Result 1"}])
        for bad_index in (5, 0, -1):
            assert ep.get_result_by_index(bad_index) is None

    def test_get_result_by_index_no_results(self):
        """Lookups before any search yield None."""
        assert EpisodicMemory().get_result_by_index(1) is None

    def test_clear_search_results(self):
        """clear_search_results forgets the last search."""
        ep = EpisodicMemory()
        ep.store_search_results("query", [{"name": "Result"}])
        ep.clear_search_results()
        assert ep.last_search_results is None

    def test_add_active_download(self):
        """A new download is tracked and stamped with started_at."""
        ep = EpisodicMemory()
        download = {
            "task_id": "123",
            "name": "Test Movie",
            "magnet": "magnet:?xt=...",
        }
        ep.add_active_download(download)
        active = ep.active_downloads
        assert len(active) == 1
        assert active[0]["name"] == "Test Movie"
        assert "started_at" in active[0]

    def test_update_download_progress(self):
        """Progress and status are written onto the matching download."""
        ep = EpisodicMemory()
        ep.add_active_download({"task_id": "123", "name": "Test"})
        ep.update_download_progress("123", 50, "downloading")
        tracked = ep.active_downloads[0]
        assert tracked["progress"] == 50
        assert tracked["status"] == "downloading"

    def test_update_download_progress_not_found(self):
        """An unknown task_id leaves existing downloads untouched."""
        ep = EpisodicMemory()
        ep.add_active_download({"task_id": "123", "name": "Test"})
        # Unknown id: must not raise.
        ep.update_download_progress("999", 50)
        assert ep.active_downloads[0].get("progress") is None

    def test_complete_download(self):
        """Completing a download removes it and emits a background event."""
        ep = EpisodicMemory()
        ep.add_active_download({"task_id": "123", "name": "Test Movie"})
        completed = ep.complete_download("123", "/path/to/file.mkv")
        assert len(ep.active_downloads) == 0
        assert completed["status"] == "completed"
        assert completed["file_path"] == "/path/to/file.mkv"
        events = ep.background_events
        assert len(events) == 1
        assert events[0]["type"] == "download_complete"

    def test_complete_download_not_found(self):
        """Completing an unknown task_id returns None."""
        assert EpisodicMemory().complete_download("999", "/path") is None

    def test_add_error(self):
        """Errors are recorded with their action and message."""
        ep = EpisodicMemory()
        ep.add_error("find_torrent", "API timeout", {"query": "test"})
        errors = ep.recent_errors
        assert len(errors) == 1
        latest = errors[0]
        assert latest["action"] == "find_torrent"
        assert latest["error"] == "API timeout"

    def test_add_error_max_limit(self):
        """Only the newest max_errors entries are retained."""
        ep = EpisodicMemory()
        ep.max_errors = 3
        for attempt in range(5):
            ep.add_error("action", f"Error {attempt}")
        errors = ep.recent_errors
        assert len(errors) == 3
        assert errors[0]["error"] == "Error 2"

    def test_set_pending_question(self):
        """A pending question is stored with its text and options."""
        ep = EpisodicMemory()
        choices = [{"index": n, "label": f"Option {n}"} for n in (1, 2)]
        ep.set_pending_question(
            "Which one?",
            choices,
            {"context": "test"},
            "choice",
        )
        pending = ep.pending_question
        assert pending is not None
        assert pending["question"] == "Which one?"
        assert len(pending["options"]) == 2

    def test_resolve_pending_question(self):
        """Resolving returns the chosen option and clears the question."""
        ep = EpisodicMemory()
        choices = [{"index": n, "label": f"Option {n}"} for n in (1, 2)]
        ep.set_pending_question("Which?", choices, {})
        picked = ep.resolve_pending_question(2)
        assert picked["label"] == "Option 2"
        assert ep.pending_question is None

    def test_resolve_pending_question_cancel(self):
        """Resolving with None cancels the question and returns None."""
        ep = EpisodicMemory()
        ep.set_pending_question("Which?", [], {})
        assert ep.resolve_pending_question(None) is None
        assert ep.pending_question is None

    def test_add_background_event(self):
        """Background events are stored unread."""
        ep = EpisodicMemory()
        ep.add_background_event("download_complete", {"name": "Movie"})
        events = ep.background_events
        assert len(events) == 1
        assert events[0]["type"] == "download_complete"
        assert events[0]["read"] is False

    def test_add_background_event_max_limit(self):
        """The event list is capped at max_events."""
        ep = EpisodicMemory()
        ep.max_events = 3
        for i in range(5):
            ep.add_background_event("event", {"i": i})
        assert len(ep.background_events) == 3

    def test_get_unread_events(self):
        """Fetching unread events returns them and marks them as read."""
        ep = EpisodicMemory()
        for name in ("event1", "event2"):
            ep.add_background_event(name, {})
        unread = ep.get_unread_events()
        assert len(unread) == 2
        assert all(evt["read"] for evt in ep.background_events)

    def test_get_unread_events_already_read(self):
        """Events already fetched are not returned again."""
        ep = EpisodicMemory()
        ep.add_background_event("event1", {})
        ep.get_unread_events()  # consumes event1
        ep.add_background_event("event2", {})
        unread = ep.get_unread_events()
        assert len(unread) == 1
        assert unread[0]["type"] == "event2"

    def test_clear(self):
        """clear wipes every kind of episodic state."""
        ep = EpisodicMemory()
        ep.store_search_results("query", [{}])
        ep.add_active_download({"task_id": "1", "name": "Test"})
        ep.add_error("action", "error")
        ep.set_pending_question("?", [], {})
        ep.add_background_event("event", {})
        ep.clear()
        for attr in ("last_search_results", "pending_question"):
            assert getattr(ep, attr) is None
        for attr in ("active_downloads", "recent_errors", "background_events"):
            assert getattr(ep, attr) == []
class TestMemory:
    """Tests for the Memory manager."""

    def test_init_creates_directories(self, temp_dir):
        """Should create the storage directory on construction."""
        storage = temp_dir / "memory_data"
        # Constructed only for its side effect; the instance is not inspected.
        Memory(storage_dir=str(storage))
        assert storage.is_dir()

    def test_init_loads_existing_ltm(self, temp_dir):
        """Should load an existing ltm.json from the storage directory."""
        ltm_file = temp_dir / "ltm.json"
        ltm_file.write_text(
            json.dumps(
                {
                    "config": {"download_folder": "/downloads"},
                    "preferences": {"preferred_quality": "4K"},
                    "library": {"movies": []},
                    "following": [],
                }
            ),
            encoding="utf-8",
        )
        memory = Memory(storage_dir=str(temp_dir))
        assert memory.ltm.get_config("download_folder") == "/downloads"
        assert memory.ltm.preferences["preferred_quality"] == "4K"

    def test_init_handles_corrupted_ltm(self, temp_dir):
        """Should fall back to default LTM values when ltm.json is not valid JSON."""
        ltm_file = temp_dir / "ltm.json"
        ltm_file.write_text("not valid json {{{", encoding="utf-8")
        memory = Memory(storage_dir=str(temp_dir))
        assert memory.ltm.config == {}  # Default values

    def test_save(self, temp_dir):
        """Should persist the LTM to ltm.json in the storage directory."""
        memory = Memory(storage_dir=str(temp_dir))
        memory.ltm.set_config("test_key", "test_value")
        memory.save()
        ltm_file = temp_dir / "ltm.json"
        assert ltm_file.exists()
        data = json.loads(ltm_file.read_text(encoding="utf-8"))
        assert data["config"]["test_key"] == "test_value"

    def test_get_context_for_prompt(self, memory_with_search_results):
        """Should expose config, preferences and a summary of the last search."""
        context = memory_with_search_results.get_context_for_prompt()
        assert "config" in context
        assert "preferences" in context
        # Values come from the memory_with_search_results fixture (conftest).
        assert context["last_search"]["query"] == "Inception 1080p"
        assert context["last_search"]["result_count"] == 3

    def test_get_full_state(self, memory):
        """Should return the state of all three memory tiers."""
        state = memory.get_full_state()
        assert "ltm" in state
        assert "stm" in state
        assert "episodic" in state

    def test_clear_session(self, memory_with_search_results):
        """Should clear STM and Episodic but keep LTM."""
        memory_with_search_results.ltm.set_config("key", "value")
        memory_with_search_results.stm.add_message("user", "Hello")
        memory_with_search_results.clear_session()
        assert memory_with_search_results.ltm.get_config("key") == "value"
        assert memory_with_search_results.stm.conversation_history == []
        assert memory_with_search_results.episodic.last_search_results is None
class TestMemoryContext:
    """Tests for the memory context functions (init/set/get/has_memory)."""

    @pytest.fixture(autouse=True)
    def _isolated_memory_ctx(self):
        """Run each test with no memory in context, then restore the prior value.

        Without the restore step, the Memory installed by init_memory/set_memory
        here (bound to a per-test temp dir) would leak into every test that runs
        afterwards in the same context.
        """
        # NOTE(review): assumes _memory_ctx is a contextvars.ContextVar, since
        # set() returns a token that reset() accepts. Confirm in
        # infrastructure.persistence.context.
        token = _memory_ctx.set(None)
        yield
        _memory_ctx.reset(token)

    def test_init_memory(self, temp_dir):
        """Should initialize a memory and install it in the context."""
        memory = init_memory(str(temp_dir))
        assert memory is not None
        assert has_memory()
        assert get_memory() is memory

    def test_set_memory(self, temp_dir):
        """Should install an existing memory in the context."""
        memory = Memory(storage_dir=str(temp_dir))
        set_memory(memory)
        assert get_memory() is memory

    def test_get_memory_not_initialized(self):
        """Should raise if no memory has been initialized."""
        with pytest.raises(RuntimeError, match="Memory not initialized"):
            get_memory()

    def test_has_memory(self, temp_dir):
        """Should report whether a memory is installed in the context."""
        assert not has_memory()
        init_memory(str(temp_dir))
        assert has_memory()