240 lines
8.1 KiB
Python
240 lines
8.1 KiB
Python
"""Tests for the Memory system."""
|
|
import pytest
|
|
from datetime import datetime
|
|
|
|
from alfred.infrastructure.persistence import (
|
|
EpisodicMemory,
|
|
LongTermMemory,
|
|
Memory,
|
|
ShortTermMemory,
|
|
get_memory,
|
|
has_memory,
|
|
init_memory,
|
|
)
|
|
from alfred.infrastructure.persistence.context import _memory_ctx
|
|
|
|
|
|
def is_iso_format(s: str) -> bool:
    """Report whether *s* can be parsed as an ISO 8601 timestamp.

    Anything that is not a ``str`` is rejected immediately. For strings,
    every ``"Z"`` is rewritten to ``"+00:00"`` and the result is handed to
    ``datetime.fromisoformat``; the answer is ``True`` only when that call
    succeeds.
    """
    if isinstance(s, str):
        try:
            # Rewrite the "Z" suffix into an explicit UTC offset before parsing.
            datetime.fromisoformat(s.replace("Z", "+00:00"))
        except (ValueError, TypeError):
            return False
        return True
    return False
|
|
|
|
|
|
class TestLongTermMemory:
    """Tests for LongTermMemory."""

    def test_default_values(self):
        """Check the state exposed by a freshly constructed instance."""
        fresh = LongTermMemory()

        assert fresh.config == {}
        assert fresh.preferences["preferred_quality"] == "1080p"
        assert "en" in fresh.preferences["preferred_languages"]
        assert fresh.library == {"movies": [], "tv_shows": []}
        assert fresh.following == []

    def test_set_and_get_config(self):
        """A value written with set_config is expected back from get_config."""
        store = LongTermMemory()
        folder = "/path/to/downloads"

        store.set_config("download_folder", folder)

        assert store.get_config("download_folder") == folder

    def test_get_config_default(self):
        """Unknown keys are expected to yield None or the supplied fallback."""
        store = LongTermMemory()

        without_fallback = store.get_config("nonexistent")
        assert without_fallback is None

        with_fallback = store.get_config("nonexistent", "default")
        assert with_fallback == "default"

    def test_has_config(self):
        """has_config is expected to flip from falsy to truthy once a key is set."""
        store = LongTermMemory()
        key = "download_folder"

        assert not store.has_config(key)
        store.set_config(key, "/path")
        assert store.has_config(key)

    def test_has_config_none_value(self):
        """A key stored with a None value is expected to count as absent."""
        store = LongTermMemory()
        store.config["key"] = None

        assert not store.has_config("key")

    def test_add_to_library(self):
        """Adding a movie is expected to store it with an ISO 'added_at' stamp."""
        store = LongTermMemory()

        store.add_to_library(
            "movies", {"imdb_id": "tt1375666", "title": "Inception"}
        )

        assert len(store.library["movies"]) == 1
        stored = store.library["movies"][0]
        assert stored["title"] == "Inception"
        assert is_iso_format(stored.get("added_at"))

    def test_add_to_library_no_duplicates(self):
        """Adding the same movie twice is expected to keep a single entry."""
        store = LongTermMemory()
        inception = {"imdb_id": "tt1375666", "title": "Inception"}

        # Submit the very same dict object twice, as a repeated add would.
        for _ in range(2):
            store.add_to_library("movies", inception)

        assert len(store.library["movies"]) == 1

    def test_add_to_library_new_type(self):
        """A media type not present by default is expected to be created."""
        store = LongTermMemory()

        store.add_to_library(
            "subtitles", {"imdb_id": "tt1375666", "language": "en"}
        )

        assert "subtitles" in store.library
        assert len(store.library["subtitles"]) == 1

    def test_get_library(self):
        """get_library is expected to return every item added for a type."""
        store = LongTermMemory()
        entries = (
            {"imdb_id": "tt1", "title": "Movie 1"},
            {"imdb_id": "tt2", "title": "Movie 2"},
        )

        for entry in entries:
            store.add_to_library("movies", entry)

        assert len(store.get_library("movies")) == 2

    def test_get_library_empty(self):
        """An unknown media type is expected to yield an empty list."""
        store = LongTermMemory()

        assert store.get_library("unknown") == []

    def test_follow_show(self):
        """Following a show is expected to record it with an ISO 'followed_at'."""
        store = LongTermMemory()

        store.follow_show({"imdb_id": "tt0944947", "title": "Game of Thrones"})

        assert len(store.following) == 1
        followed = store.following[0]
        assert followed["title"] == "Game of Thrones"
        assert is_iso_format(followed.get("followed_at"))

    def test_follow_show_no_duplicates(self):
        """Following the same show twice is expected to keep a single entry."""
        store = LongTermMemory()
        got = {"imdb_id": "tt0944947", "title": "Game of Thrones"}

        # Submit the very same dict object twice, as a repeated follow would.
        for _ in range(2):
            store.follow_show(got)

        assert len(store.following) == 1

    def test_to_dict(self):
        """to_dict is expected to expose the config under a 'config' key."""
        store = LongTermMemory()
        store.set_config("key", "value")

        exported = store.to_dict()

        assert "config" in exported
        assert exported["config"]["key"] == "value"

    def test_from_dict(self):
        """from_dict is expected to restore config, preferences and library."""
        payload = {
            "config": {"download_folder": "/downloads"},
            "preferences": {"preferred_quality": "4K"},
            "library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]},
            "following": [],
        }

        restored = LongTermMemory.from_dict(payload)

        assert restored.get_config("download_folder") == "/downloads"
        assert restored.preferences["preferred_quality"] == "4K"
        assert len(restored.library["movies"]) == 1
|
|
|
|
|
|
class TestShortTermMemory:
    """Tests for ShortTermMemory."""

    def test_default_values(self):
        """Check the state exposed by a freshly constructed instance."""
        session = ShortTermMemory()

        assert session.conversation_history == []
        assert session.current_workflow is None
        assert session.extracted_entities == {}
        assert session.current_topic is None
        assert session.language == "en"

    def test_add_message(self):
        """A recorded message is expected to carry an ISO 'timestamp'."""
        session = ShortTermMemory()

        session.add_message("user", "Hello")

        assert len(session.conversation_history) == 1
        first = session.conversation_history[0]
        assert is_iso_format(first.get("timestamp"))

    def test_add_message_max_history(self):
        """With max_history=5, ten messages are expected to leave the last five."""
        session = ShortTermMemory(max_history=5)

        for i in range(10):
            session.add_message("user", f"Message {i}")

        assert len(session.conversation_history) == 5
        oldest_kept = session.conversation_history[0]
        assert oldest_kept["content"] == "Message 5"

    def test_language_management(self):
        """Language is expected to start at 'en', follow set_language, and reset on clear."""
        session = ShortTermMemory()
        assert session.language == "en"

        session.set_language("fr")
        assert session.language == "fr"

        session.clear()
        assert session.language == "en"

    def test_clear(self):
        """clear is expected to drop the history and restore the 'en' language."""
        session = ShortTermMemory()
        session.add_message("user", "Hello")
        session.set_language("fr")

        session.clear()

        assert session.conversation_history == []
        assert session.language == "en"
|
|
|
|
|
|
class TestEpisodicMemory:
    """Tests for EpisodicMemory."""

    def test_add_error(self):
        """A recorded error is expected to carry an ISO 'timestamp'."""
        events = EpisodicMemory()

        events.add_error("find_torrent", "API timeout")

        assert len(events.recent_errors) == 1
        latest = events.recent_errors[0]
        assert is_iso_format(latest.get("timestamp"))

    def test_add_error_max_limit(self):
        """With max_errors=3, five errors are expected to leave the last three."""
        events = EpisodicMemory(max_errors=3)

        for i in range(5):
            events.add_error("action", f"Error {i}")

        assert len(events.recent_errors) == 3
        kept = [entry["error"] for entry in events.recent_errors]
        assert kept == ["Error 2", "Error 3", "Error 4"]

    def test_store_search_results(self):
        """Stored search results are expected to carry an ISO 'timestamp'."""
        events = EpisodicMemory()

        events.store_search_results("test query", [])

        stamp = events.last_search_results.get("timestamp")
        assert is_iso_format(stamp)

    def test_get_result_by_index(self):
        """Index 2 is expected to yield the second of the stored results."""
        events = EpisodicMemory()
        events.store_search_results(
            "query", [{"name": "Result 1"}, {"name": "Result 2"}]
        )

        picked = events.get_result_by_index(2)

        assert picked is not None
        assert picked["name"] == "Result 2"
|
|
|
|
|
|
class TestMemory:
    """Tests for the Memory manager."""

    def test_init_creates_directories(self, temp_dir):
        """Constructing Memory is expected to create its storage directory."""
        target = temp_dir / "memory_data"

        Memory(storage_dir=str(target))

        assert target.exists()

    def test_save_and_load_ltm(self, temp_dir):
        """Config saved by one Memory is expected to be visible to a new one."""
        location = str(temp_dir)

        writer = Memory(storage_dir=location)
        writer.ltm.set_config("test_key", "test_value")
        writer.save()

        # A second manager over the same directory should see the saved value.
        reader = Memory(storage_dir=location)
        assert reader.ltm.get_config("test_key") == "test_value"

    def test_clear_session(self, memory):
        """clear_session is expected to keep LTM config but empty STM and episodic data."""
        memory.ltm.set_config("key", "value")
        memory.stm.add_message("user", "Hello")
        memory.episodic.add_error("action", "error")

        memory.clear_session()

        # Long-term data survives; per-session data is wiped.
        assert memory.ltm.get_config("key") == "value"
        assert memory.stm.conversation_history == []
        assert memory.episodic.recent_errors == []
|
|
|
|
|
|
class TestMemoryContext:
    """Tests for memory context functions."""

    def test_get_memory_not_initialized(self):
        """get_memory is expected to raise RuntimeError when the context holds None."""
        # Force the "no memory" state before calling the accessor.
        _memory_ctx.set(None)

        with pytest.raises(RuntimeError, match="Memory not initialized"):
            get_memory()

    def test_init_memory(self, temp_dir):
        """init_memory is expected to register the instance that get_memory returns."""
        # Start from the "no memory" state so the assertions reflect init_memory.
        _memory_ctx.set(None)

        created = init_memory(str(temp_dir))

        assert has_memory()
        assert get_memory() is created
|