infra: reorganized repo
This commit is contained in:
241
tests/test_memory.py
Normal file
241
tests/test_memory.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Tests for the Memory system."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.persistence import (
|
||||
EpisodicMemory,
|
||||
LongTermMemory,
|
||||
Memory,
|
||||
ShortTermMemory,
|
||||
get_memory,
|
||||
has_memory,
|
||||
init_memory,
|
||||
)
|
||||
from infrastructure.persistence.context import _memory_ctx
|
||||
|
||||
|
||||
def is_iso_format(s: str) -> bool:
|
||||
"""Helper to check if a string is a valid ISO 8601 timestamp."""
|
||||
if not isinstance(s, str):
|
||||
return False
|
||||
try:
|
||||
# Attempt to parse the string as an ISO 8601 timestamp
|
||||
datetime.fromisoformat(s.replace("Z", "+00:00"))
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
class TestLongTermMemory:
|
||||
"""Tests for LongTermMemory."""
|
||||
|
||||
def test_default_values(self):
|
||||
ltm = LongTermMemory()
|
||||
assert ltm.config == {}
|
||||
assert ltm.preferences["preferred_quality"] == "1080p"
|
||||
assert "en" in ltm.preferences["preferred_languages"]
|
||||
assert ltm.library == {"movies": [], "tv_shows": []}
|
||||
assert ltm.following == []
|
||||
|
||||
def test_set_and_get_config(self):
|
||||
ltm = LongTermMemory()
|
||||
ltm.set_config("download_folder", "/path/to/downloads")
|
||||
assert ltm.get_config("download_folder") == "/path/to/downloads"
|
||||
|
||||
def test_get_config_default(self):
|
||||
ltm = LongTermMemory()
|
||||
assert ltm.get_config("nonexistent") is None
|
||||
assert ltm.get_config("nonexistent", "default") == "default"
|
||||
|
||||
def test_has_config(self):
|
||||
ltm = LongTermMemory()
|
||||
assert not ltm.has_config("download_folder")
|
||||
ltm.set_config("download_folder", "/path")
|
||||
assert ltm.has_config("download_folder")
|
||||
|
||||
def test_has_config_none_value(self):
|
||||
ltm = LongTermMemory()
|
||||
ltm.config["key"] = None
|
||||
assert not ltm.has_config("key")
|
||||
|
||||
def test_add_to_library(self):
|
||||
ltm = LongTermMemory()
|
||||
movie = {"imdb_id": "tt1375666", "title": "Inception"}
|
||||
ltm.add_to_library("movies", movie)
|
||||
assert len(ltm.library["movies"]) == 1
|
||||
assert ltm.library["movies"][0]["title"] == "Inception"
|
||||
assert is_iso_format(ltm.library["movies"][0].get("added_at"))
|
||||
|
||||
def test_add_to_library_no_duplicates(self):
|
||||
ltm = LongTermMemory()
|
||||
movie = {"imdb_id": "tt1375666", "title": "Inception"}
|
||||
ltm.add_to_library("movies", movie)
|
||||
ltm.add_to_library("movies", movie)
|
||||
assert len(ltm.library["movies"]) == 1
|
||||
|
||||
def test_add_to_library_new_type(self):
|
||||
ltm = LongTermMemory()
|
||||
subtitle = {"imdb_id": "tt1375666", "language": "en"}
|
||||
ltm.add_to_library("subtitles", subtitle)
|
||||
assert "subtitles" in ltm.library
|
||||
assert len(ltm.library["subtitles"]) == 1
|
||||
|
||||
def test_get_library(self):
|
||||
ltm = LongTermMemory()
|
||||
ltm.add_to_library("movies", {"imdb_id": "tt1", "title": "Movie 1"})
|
||||
ltm.add_to_library("movies", {"imdb_id": "tt2", "title": "Movie 2"})
|
||||
movies = ltm.get_library("movies")
|
||||
assert len(movies) == 2
|
||||
|
||||
def test_get_library_empty(self):
|
||||
ltm = LongTermMemory()
|
||||
assert ltm.get_library("unknown") == []
|
||||
|
||||
def test_follow_show(self):
|
||||
ltm = LongTermMemory()
|
||||
show = {"imdb_id": "tt0944947", "title": "Game of Thrones"}
|
||||
ltm.follow_show(show)
|
||||
assert len(ltm.following) == 1
|
||||
assert ltm.following[0]["title"] == "Game of Thrones"
|
||||
assert is_iso_format(ltm.following[0].get("followed_at"))
|
||||
|
||||
def test_follow_show_no_duplicates(self):
|
||||
ltm = LongTermMemory()
|
||||
show = {"imdb_id": "tt0944947", "title": "Game of Thrones"}
|
||||
ltm.follow_show(show)
|
||||
ltm.follow_show(show)
|
||||
assert len(ltm.following) == 1
|
||||
|
||||
def test_to_dict(self):
|
||||
ltm = LongTermMemory()
|
||||
ltm.set_config("key", "value")
|
||||
data = ltm.to_dict()
|
||||
assert "config" in data
|
||||
assert data["config"]["key"] == "value"
|
||||
|
||||
def test_from_dict(self):
|
||||
data = {
|
||||
"config": {"download_folder": "/downloads"},
|
||||
"preferences": {"preferred_quality": "4K"},
|
||||
"library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]},
|
||||
"following": [],
|
||||
}
|
||||
ltm = LongTermMemory.from_dict(data)
|
||||
assert ltm.get_config("download_folder") == "/downloads"
|
||||
assert ltm.preferences["preferred_quality"] == "4K"
|
||||
assert len(ltm.library["movies"]) == 1
|
||||
|
||||
|
||||
class TestShortTermMemory:
|
||||
"""Tests for ShortTermMemory."""
|
||||
|
||||
def test_default_values(self):
|
||||
stm = ShortTermMemory()
|
||||
assert stm.conversation_history == []
|
||||
assert stm.current_workflow is None
|
||||
assert stm.extracted_entities == {}
|
||||
assert stm.current_topic is None
|
||||
assert stm.language == "en"
|
||||
|
||||
def test_add_message(self):
|
||||
stm = ShortTermMemory()
|
||||
stm.add_message("user", "Hello")
|
||||
assert len(stm.conversation_history) == 1
|
||||
assert is_iso_format(stm.conversation_history[0].get("timestamp"))
|
||||
|
||||
def test_add_message_max_history(self):
|
||||
stm = ShortTermMemory(max_history=5)
|
||||
for i in range(10):
|
||||
stm.add_message("user", f"Message {i}")
|
||||
assert len(stm.conversation_history) == 5
|
||||
assert stm.conversation_history[0]["content"] == "Message 5"
|
||||
|
||||
def test_language_management(self):
|
||||
stm = ShortTermMemory()
|
||||
assert stm.language == "en"
|
||||
stm.set_language("fr")
|
||||
assert stm.language == "fr"
|
||||
stm.clear()
|
||||
assert stm.language == "en"
|
||||
|
||||
def test_clear(self):
|
||||
stm = ShortTermMemory()
|
||||
stm.add_message("user", "Hello")
|
||||
stm.set_language("fr")
|
||||
stm.clear()
|
||||
assert stm.conversation_history == []
|
||||
assert stm.language == "en"
|
||||
|
||||
|
||||
class TestEpisodicMemory:
|
||||
"""Tests for EpisodicMemory."""
|
||||
|
||||
def test_add_error(self):
|
||||
episodic = EpisodicMemory()
|
||||
episodic.add_error("find_torrent", "API timeout")
|
||||
assert len(episodic.recent_errors) == 1
|
||||
assert is_iso_format(episodic.recent_errors[0].get("timestamp"))
|
||||
|
||||
def test_add_error_max_limit(self):
|
||||
episodic = EpisodicMemory(max_errors=3)
|
||||
for i in range(5):
|
||||
episodic.add_error("action", f"Error {i}")
|
||||
assert len(episodic.recent_errors) == 3
|
||||
error_messages = [e["error"] for e in episodic.recent_errors]
|
||||
assert error_messages == ["Error 2", "Error 3", "Error 4"]
|
||||
|
||||
def test_store_search_results(self):
|
||||
episodic = EpisodicMemory()
|
||||
episodic.store_search_results("test query", [])
|
||||
assert is_iso_format(episodic.last_search_results.get("timestamp"))
|
||||
|
||||
def test_get_result_by_index(self):
|
||||
episodic = EpisodicMemory()
|
||||
results = [{"name": "Result 1"}, {"name": "Result 2"}]
|
||||
episodic.store_search_results("query", results)
|
||||
result = episodic.get_result_by_index(2)
|
||||
assert result is not None
|
||||
assert result["name"] == "Result 2"
|
||||
|
||||
|
||||
class TestMemory:
|
||||
"""Tests for the Memory manager."""
|
||||
|
||||
def test_init_creates_directories(self, temp_dir):
|
||||
storage = temp_dir / "memory_data"
|
||||
Memory(storage_dir=str(storage))
|
||||
assert storage.exists()
|
||||
|
||||
def test_save_and_load_ltm(self, temp_dir):
|
||||
storage = str(temp_dir)
|
||||
memory = Memory(storage_dir=storage)
|
||||
memory.ltm.set_config("test_key", "test_value")
|
||||
memory.save()
|
||||
new_memory = Memory(storage_dir=storage)
|
||||
assert new_memory.ltm.get_config("test_key") == "test_value"
|
||||
|
||||
def test_clear_session(self, memory):
|
||||
memory.ltm.set_config("key", "value")
|
||||
memory.stm.add_message("user", "Hello")
|
||||
memory.episodic.add_error("action", "error")
|
||||
memory.clear_session()
|
||||
assert memory.ltm.get_config("key") == "value"
|
||||
assert memory.stm.conversation_history == []
|
||||
assert memory.episodic.recent_errors == []
|
||||
|
||||
|
||||
class TestMemoryContext:
|
||||
"""Tests for memory context functions."""
|
||||
|
||||
def test_get_memory_not_initialized(self):
|
||||
_memory_ctx.set(None)
|
||||
with pytest.raises(RuntimeError, match="Memory not initialized"):
|
||||
get_memory()
|
||||
|
||||
def test_init_memory(self, temp_dir):
|
||||
_memory_ctx.set(None)
|
||||
memory = init_memory(str(temp_dir))
|
||||
assert has_memory()
|
||||
assert get_memory() is memory
|
||||
Reference in New Issue
Block a user