578 lines
19 KiB
Python
578 lines
19 KiB
Python
"""
|
|
Memory - Unified management of 3 memory types.
|
|
|
|
Architecture:
|
|
- LTM (Long-Term Memory): Configuration, library, preferences - Persistent
|
|
- STM (Short-Term Memory): Conversation, current workflow - Volatile
|
|
- Episodic Memory: Search results, transient states - Very volatile
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# =============================================================================
# LONG-TERM MEMORY (LTM) - Persistent
# =============================================================================
|
|
|
|
|
|
def _default_preferences() -> dict[str, Any]:
    """Return a fresh copy of the default user preferences."""
    return {
        "preferred_quality": "1080p",
        "preferred_languages": ["en", "fr"],
        "auto_organize": False,
        "naming_format": "{title}.{year}.{quality}",
    }


def _default_library() -> dict[str, list[dict]]:
    """Return a fresh, empty library containing the built-in media types."""
    return {"movies": [], "tv_shows": []}


@dataclass
class LongTermMemory:
    """
    Long-term memory - Persistent and static.

    Stores:
    - User configuration (folders, URLs)
    - Preferences (quality, languages)
    - Library (owned movies/TV shows)
    - Followed shows (watchlist)
    """

    # Folder and service configuration
    config: dict[str, Any] = field(default_factory=dict)

    # User preferences
    preferences: dict[str, Any] = field(default_factory=_default_preferences)

    # Library of owned media, keyed by media type
    library: dict[str, list[dict]] = field(default_factory=_default_library)

    # Followed shows (watchlist)
    following: list[dict] = field(default_factory=list)

    def get_config(self, key: str, default: Any = None) -> Any:
        """Get a configuration value, or ``default`` when the key is absent."""
        return self.config.get(key, default)

    def set_config(self, key: str, value: Any) -> None:
        """Set a configuration value."""
        self.config[key] = value
        logger.debug(f"LTM: Set config {key}")

    def has_config(self, key: str) -> bool:
        """Check if a configuration key exists with a non-None value."""
        return self.config.get(key) is not None

    def add_to_library(self, media_type: str, media: dict) -> None:
        """
        Add a media item to the library.

        The item is skipped when another item of the same media type already
        carries the same ``imdb_id``. Items without an ``imdb_id`` cannot be
        compared, so they are always added.

        Args:
            media_type: Library section (created on first use if unknown)
            media: Media description; an ``added_at`` ISO timestamp is set
                on this dict when it is added
        """
        items = self.library.setdefault(media_type, [])

        # Avoid duplicates by imdb_id. A missing id must not count as a
        # match, otherwise the second id-less item would be dropped silently.
        imdb_id = media.get("imdb_id")
        if imdb_id is not None and any(m.get("imdb_id") == imdb_id for m in items):
            return

        media["added_at"] = datetime.now().isoformat()
        items.append(media)
        logger.info(f"LTM: Added {media.get('title')} to {media_type}")

    def get_library(self, media_type: str) -> list[dict]:
        """Get the library for a media type (empty list if unknown)."""
        return self.library.get(media_type, [])

    def follow_show(self, show: dict) -> None:
        """
        Add a show to the watchlist.

        Shows are de-duplicated by ``imdb_id``; shows without an ``imdb_id``
        are always added. A ``followed_at`` ISO timestamp is set on ``show``
        when it is added.
        """
        imdb_id = show.get("imdb_id")
        if imdb_id is not None and any(
            s.get("imdb_id") == imdb_id for s in self.following
        ):
            return

        show["followed_at"] = datetime.now().isoformat()
        self.following.append(show)
        logger.info(f"LTM: Now following {show.get('title')}")

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization."""
        return {
            "config": self.config,
            "preferences": self.preferences,
            "library": self.library,
            "following": self.following,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "LongTermMemory":
        """
        Create an instance from a dictionary.

        Missing or null sections fall back to their defaults. Stored
        preferences and library sections are layered on top of the defaults,
        so keys absent from ``data`` still get their default value.

        Args:
            data: Dictionary in the shape produced by ``to_dict``
        """
        preferences = _default_preferences()
        preferences.update(data.get("preferences") or {})

        library = _default_library()
        library.update(data.get("library") or {})

        return cls(
            # ``or`` also covers an explicit null stored in the file
            config=data.get("config") or {},
            preferences=preferences,
            library=library,
            following=data.get("following") or [],
        )
|
|
|
|
|
|
# =============================================================================
# SHORT-TERM MEMORY (STM) - Conversation
# =============================================================================
|
|
|
|
|
|
@dataclass
|
|
class ShortTermMemory:
|
|
"""
|
|
Short-term memory - Volatile and conversational.
|
|
|
|
Stores:
|
|
- Current conversation history
|
|
- Current workflow (what we're doing)
|
|
- Extracted entities from conversation
|
|
- Current discussion topic
|
|
"""
|
|
|
|
# Conversation message history
|
|
conversation_history: list[dict[str, str]] = field(default_factory=list)
|
|
|
|
# Current workflow
|
|
current_workflow: dict | None = None
|
|
|
|
# Extracted entities (title, year, requested quality, etc.)
|
|
extracted_entities: dict[str, Any] = field(default_factory=dict)
|
|
|
|
# Current conversation topic
|
|
current_topic: str | None = None
|
|
|
|
# Conversation language
|
|
language: str = "en"
|
|
|
|
# History message limit
|
|
max_history: int = 20
|
|
|
|
def add_message(self, role: str, content: str) -> None:
|
|
"""Add a message to history."""
|
|
self.conversation_history.append(
|
|
{"role": role, "content": content, "timestamp": datetime.now().isoformat()}
|
|
)
|
|
# Keep only the last N messages
|
|
if len(self.conversation_history) > self.max_history:
|
|
self.conversation_history = self.conversation_history[-self.max_history :]
|
|
logger.debug(f"STM: Added {role} message")
|
|
|
|
def get_recent_history(self, n: int = 10) -> list[dict]:
|
|
"""Get the last N messages."""
|
|
return self.conversation_history[-n:]
|
|
|
|
def start_workflow(self, workflow_type: str, target: dict) -> None:
|
|
"""Start a new workflow."""
|
|
self.current_workflow = {
|
|
"type": workflow_type,
|
|
"target": target,
|
|
"stage": "started",
|
|
"started_at": datetime.now().isoformat(),
|
|
}
|
|
logger.info(f"STM: Started workflow '{workflow_type}'")
|
|
|
|
def update_workflow_stage(self, stage: str) -> None:
|
|
"""Update the workflow stage."""
|
|
if self.current_workflow:
|
|
self.current_workflow["stage"] = stage
|
|
logger.debug(f"STM: Workflow stage -> {stage}")
|
|
|
|
def end_workflow(self) -> None:
|
|
"""End the current workflow."""
|
|
if self.current_workflow:
|
|
logger.info(f"STM: Ended workflow '{self.current_workflow.get('type')}'")
|
|
self.current_workflow = None
|
|
|
|
def set_entity(self, key: str, value: Any) -> None:
|
|
"""Store an extracted entity."""
|
|
self.extracted_entities[key] = value
|
|
logger.debug(f"STM: Set entity {key}={value}")
|
|
|
|
def get_entity(self, key: str, default: Any = None) -> Any:
|
|
"""Get an extracted entity."""
|
|
return self.extracted_entities.get(key, default)
|
|
|
|
def clear_entities(self) -> None:
|
|
"""Clear extracted entities."""
|
|
self.extracted_entities = {}
|
|
|
|
def set_topic(self, topic: str) -> None:
|
|
"""Set the current topic."""
|
|
self.current_topic = topic
|
|
logger.debug(f"STM: Topic -> {topic}")
|
|
|
|
def set_language(self, language: str) -> None:
|
|
"""Set the conversation language."""
|
|
self.language = language
|
|
logger.debug(f"STM: Language -> {language}")
|
|
|
|
def clear(self) -> None:
|
|
"""Reset short-term memory."""
|
|
self.conversation_history = []
|
|
self.current_workflow = None
|
|
self.extracted_entities = {}
|
|
self.current_topic = None
|
|
self.language = "en"
|
|
logger.info("STM: Cleared")
|
|
|
|
def to_dict(self) -> dict:
|
|
"""Convert to dictionary."""
|
|
return {
|
|
"conversation_history": self.conversation_history,
|
|
"current_workflow": self.current_workflow,
|
|
"extracted_entities": self.extracted_entities,
|
|
"current_topic": self.current_topic,
|
|
"language": self.language,
|
|
}
|
|
|
|
|
|
# =============================================================================
# EPISODIC MEMORY - Transient states
# =============================================================================
|
|
|
|
|
|
@dataclass
|
|
class EpisodicMemory:
|
|
"""
|
|
Episodic/sensory memory - Temporary and event-driven.
|
|
|
|
Stores:
|
|
- Last search results
|
|
- Active downloads
|
|
- Recent errors
|
|
- Pending questions awaiting user response
|
|
- Background events
|
|
"""
|
|
|
|
# Last search results
|
|
last_search_results: dict | None = None
|
|
|
|
# Active downloads
|
|
active_downloads: list[dict] = field(default_factory=list)
|
|
|
|
# Recent errors
|
|
recent_errors: list[dict] = field(default_factory=list)
|
|
|
|
# Pending question awaiting user response
|
|
pending_question: dict | None = None
|
|
|
|
# Background events (download complete, new files, etc.)
|
|
background_events: list[dict] = field(default_factory=list)
|
|
|
|
# Limits for errors/events kept
|
|
max_errors: int = 5
|
|
max_events: int = 10
|
|
|
|
def store_search_results(
|
|
self, query: str, results: list[dict], search_type: str = "torrent"
|
|
) -> None:
|
|
"""
|
|
Store search results with index.
|
|
|
|
Args:
|
|
query: The search query
|
|
results: List of results
|
|
search_type: Type of search (torrent, movie, tvshow)
|
|
"""
|
|
self.last_search_results = {
|
|
"query": query,
|
|
"type": search_type,
|
|
"timestamp": datetime.now().isoformat(),
|
|
"results": [{"index": i + 1, **r} for i, r in enumerate(results)],
|
|
}
|
|
logger.info(f"Episodic: Stored {len(results)} search results for '{query}'")
|
|
|
|
def get_result_by_index(self, index: int) -> dict | None:
|
|
"""
|
|
Get a result by its number (1-indexed).
|
|
|
|
Args:
|
|
index: Result number (1, 2, 3, ...)
|
|
|
|
Returns:
|
|
The result or None if not found
|
|
"""
|
|
if not self.last_search_results:
|
|
logger.warning("Episodic: No search results stored")
|
|
return None
|
|
|
|
for result in self.last_search_results.get("results", []):
|
|
if result.get("index") == index:
|
|
return result
|
|
|
|
logger.warning(f"Episodic: Result #{index} not found")
|
|
return None
|
|
|
|
def get_search_results(self) -> dict | None:
|
|
"""Get the last search results."""
|
|
return self.last_search_results
|
|
|
|
def clear_search_results(self) -> None:
|
|
"""Clear search results."""
|
|
self.last_search_results = None
|
|
|
|
def add_active_download(self, download: dict) -> None:
|
|
"""Add an active download."""
|
|
download["started_at"] = datetime.now().isoformat()
|
|
self.active_downloads.append(download)
|
|
logger.info(f"Episodic: Added download '{download.get('name')}'")
|
|
|
|
def update_download_progress(
|
|
self, task_id: str, progress: int, status: str = "downloading"
|
|
) -> None:
|
|
"""Update download progress."""
|
|
for dl in self.active_downloads:
|
|
if dl.get("task_id") == task_id:
|
|
dl["progress"] = progress
|
|
dl["status"] = status
|
|
dl["updated_at"] = datetime.now().isoformat()
|
|
break
|
|
|
|
def complete_download(self, task_id: str, file_path: str) -> dict | None:
|
|
"""Mark a download as complete and remove it."""
|
|
for i, dl in enumerate(self.active_downloads):
|
|
if dl.get("task_id") == task_id:
|
|
completed = self.active_downloads.pop(i)
|
|
completed["status"] = "completed"
|
|
completed["file_path"] = file_path
|
|
completed["completed_at"] = datetime.now().isoformat()
|
|
|
|
# Add a background event
|
|
self.add_background_event(
|
|
"download_complete",
|
|
{"name": completed.get("name"), "file_path": file_path},
|
|
)
|
|
|
|
logger.info(f"Episodic: Download completed '{completed.get('name')}'")
|
|
return completed
|
|
return None
|
|
|
|
def get_active_downloads(self) -> list[dict]:
|
|
"""Get active downloads."""
|
|
return self.active_downloads
|
|
|
|
def add_error(self, action: str, error: str, context: dict | None = None) -> None:
|
|
"""Record a recent error."""
|
|
self.recent_errors.append(
|
|
{
|
|
"timestamp": datetime.now().isoformat(),
|
|
"action": action,
|
|
"error": error,
|
|
"context": context or {},
|
|
}
|
|
)
|
|
# Keep only the last N errors
|
|
self.recent_errors = self.recent_errors[-self.max_errors :]
|
|
logger.warning(f"Episodic: Error in '{action}': {error}")
|
|
|
|
def get_recent_errors(self) -> list[dict]:
|
|
"""Get recent errors."""
|
|
return self.recent_errors
|
|
|
|
def set_pending_question(
|
|
self,
|
|
question: str,
|
|
options: list[dict],
|
|
context: dict,
|
|
question_type: str = "choice",
|
|
) -> None:
|
|
"""
|
|
Record a question awaiting user response.
|
|
|
|
Args:
|
|
question: The question asked
|
|
options: List of possible options
|
|
context: Question context
|
|
question_type: Type of question (choice, confirmation, input)
|
|
"""
|
|
self.pending_question = {
|
|
"type": question_type,
|
|
"question": question,
|
|
"options": options,
|
|
"context": context,
|
|
"timestamp": datetime.now().isoformat(),
|
|
}
|
|
logger.info(f"Episodic: Pending question set ({question_type})")
|
|
|
|
def get_pending_question(self) -> dict | None:
|
|
"""Get the pending question."""
|
|
return self.pending_question
|
|
|
|
def resolve_pending_question(self, answer_index: int | None = None) -> dict | None:
|
|
"""
|
|
Resolve the pending question and return the chosen option.
|
|
|
|
Args:
|
|
answer_index: Answer index (1-indexed) or None to cancel
|
|
|
|
Returns:
|
|
The chosen option or None
|
|
"""
|
|
if not self.pending_question:
|
|
return None
|
|
|
|
result = None
|
|
if answer_index is not None and self.pending_question.get("options"):
|
|
for opt in self.pending_question["options"]:
|
|
if opt.get("index") == answer_index:
|
|
result = opt
|
|
break
|
|
|
|
self.pending_question = None
|
|
logger.info("Episodic: Pending question resolved")
|
|
return result
|
|
|
|
def add_background_event(self, event_type: str, data: dict) -> None:
|
|
"""Add a background event."""
|
|
self.background_events.append(
|
|
{
|
|
"type": event_type,
|
|
"timestamp": datetime.now().isoformat(),
|
|
"data": data,
|
|
"read": False,
|
|
}
|
|
)
|
|
# Keep only the last N events
|
|
self.background_events = self.background_events[-self.max_events :]
|
|
logger.info(f"Episodic: Background event '{event_type}'")
|
|
|
|
def get_unread_events(self) -> list[dict]:
|
|
"""Get unread events and mark them as read."""
|
|
unread = [e for e in self.background_events if not e.get("read")]
|
|
for e in self.background_events:
|
|
e["read"] = True
|
|
return unread
|
|
|
|
def clear(self) -> None:
|
|
"""Reset episodic memory."""
|
|
self.last_search_results = None
|
|
self.active_downloads = []
|
|
self.recent_errors = []
|
|
self.pending_question = None
|
|
self.background_events = []
|
|
logger.info("Episodic: Cleared")
|
|
|
|
def to_dict(self) -> dict:
|
|
"""Convert to dictionary."""
|
|
return {
|
|
"last_search_results": self.last_search_results,
|
|
"active_downloads": self.active_downloads,
|
|
"recent_errors": self.recent_errors,
|
|
"pending_question": self.pending_question,
|
|
"background_events": self.background_events,
|
|
}
|
|
|
|
|
|
# =============================================================================
# MEMORY MANAGER - Unified manager
# =============================================================================
|
|
|
|
|
|
class Memory:
    """
    Unified manager for the 3 memory types.

    Usage:
        memory = Memory("memory_data")
        memory.ltm.set_config("download_folder", "/path")
        memory.stm.add_message("user", "Hello")
        memory.episodic.store_search_results("query", results)
        memory.save()
    """

    def __init__(self, storage_dir: str = "memory_data") -> None:
        """
        Initialize the memory.

        Creates ``storage_dir`` if needed, loads the long-term memory from
        ``ltm.json`` inside it, and starts with empty short-term and episodic
        memories.

        Args:
            storage_dir: Directory for persistent storage
        """
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        self.ltm_file = self.storage_dir / "ltm.json"

        # Initialize the 3 memory types
        self.ltm = self._load_ltm()
        self.stm = ShortTermMemory()
        self.episodic = EpisodicMemory()

        logger.info(f"Memory initialized (storage: {storage_dir})")

    def _load_ltm(self) -> "LongTermMemory":
        """
        Load LTM from file.

        Returns:
            The stored long-term memory, or a fresh default one when the file
            is missing, unreadable, not valid JSON, or not a JSON object
        """
        if not self.ltm_file.exists():
            return LongTermMemory()

        try:
            data = json.loads(self.ltm_file.read_text(encoding="utf-8"))
            if not isinstance(data, dict):
                # Valid JSON but the wrong shape (e.g. a list or a string)
                raise ValueError(
                    f"expected a JSON object, got {type(data).__name__}"
                )
            ltm = LongTermMemory.from_dict(data)
        except (OSError, ValueError, TypeError) as e:
            # ValueError also covers json.JSONDecodeError and
            # UnicodeDecodeError; a damaged file must not block startup.
            logger.warning(f"Could not load LTM: {e}")
            return LongTermMemory()

        logger.info("LTM loaded from file")
        return ltm

    def save(self) -> None:
        """
        Save LTM (the only persistent memory).

        The data is written to a temporary sibling file and then moved over
        ``ltm.json``, so an interrupted write does not leave a truncated
        file behind.

        Raises:
            OSError: If the file cannot be written (logged, then re-raised)
        """
        # Serialize first: a serialization error must not touch the disk
        payload = json.dumps(self.ltm.to_dict(), indent=2, ensure_ascii=False)
        tmp_file = self.ltm_file.with_name(self.ltm_file.name + ".tmp")
        try:
            tmp_file.write_text(payload, encoding="utf-8")
            tmp_file.replace(self.ltm_file)
        except OSError as e:
            logger.error(f"Failed to save LTM: {e}")
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup; the original error is the one to report
                pass
            raise
        logger.debug("LTM saved to file")

    def get_context_for_prompt(self) -> dict:
        """
        Generate context to include in the system prompt.

        Returns:
            Dictionary with relevant context from all 3 memories
        """
        search = self.episodic.last_search_results
        if search:
            last_search = {
                "query": search.get("query"),
                "result_count": len(search.get("results", [])),
            }
        else:
            last_search = {"query": None, "result_count": 0}

        return {
            "config": self.ltm.config,
            "preferences": self.ltm.preferences,
            "current_workflow": self.stm.current_workflow,
            "current_topic": self.stm.current_topic,
            "extracted_entities": self.stm.extracted_entities,
            "last_search": last_search,
            "active_downloads_count": len(self.episodic.active_downloads),
            "pending_question": self.episodic.pending_question is not None,
            # Counted without marking anything as read
            "unread_events": sum(
                1 for e in self.episodic.background_events if not e.get("read")
            ),
        }

    def get_full_state(self) -> dict:
        """Return the full state of all 3 memories (for debug)."""
        return {
            "ltm": self.ltm.to_dict(),
            "stm": self.stm.to_dict(),
            "episodic": self.episodic.to_dict(),
        }

    def clear_session(self) -> None:
        """Clear session memories (STM + Episodic)."""
        self.stm.clear()
        self.episodic.clear()
        logger.info("Session memories cleared")
|