102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
# agent/parameters.py
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
|
|
@dataclass
|
|
class ParameterSchema:
|
|
"""Describes a required parameter for the agent."""
|
|
|
|
key: str
|
|
description: str
|
|
why_needed: str # Explanation for the AI
|
|
type: str # "string", "number", "object", etc.
|
|
validator: Callable[[Any], bool] | None = None
|
|
default: Any = None
|
|
required: bool = True
|
|
|
|
|
|
# Registry of every parameter the agent knows about. Despite its name, the
# list holds both required and optional entries; consult each entry's
# `required` flag.
# NOTE(review): the `default` objects below ({} and []) are created once at
# import time and are shared by everything that reads them. Copy them before
# mutating.
REQUIRED_PARAMETERS = [
    # Folder-path configuration; mandatory, must be a dict.
    ParameterSchema(
        key="config",
        type="object",
        required=True,
        default={},
        validator=lambda x: isinstance(x, dict),
        description="Configuration object containing all folder paths",
        why_needed=(
            "This contains the paths to all important folders:\n"
            "- download_folder: Where downloaded files arrive before being organized\n"
            "- tvshow_folder: Where TV show files are organized and stored\n"
            "- movie_folder: Where movie files are organized and stored\n"
            "- torrent_folder: Where torrent structures are saved for the torrent client"
        ),
    ),
    # Followed TV shows; optional, must be a list.
    ParameterSchema(
        key="tv_shows",
        type="array",
        required=False,
        default=[],
        validator=lambda x: isinstance(x, list),
        description="List of TV shows the user is following",
        why_needed=(
            "This tracks which TV shows you're following. "
            "Each show includes: IMDB ID, title, number of seasons, and status (ongoing or ended)."
        ),
    ),
]
|
|
|
|
|
|
def get_parameter_schema(key: str) -> ParameterSchema | None:
    """Return the schema registered under ``key``, or None when none matches.

    Entries in REQUIRED_PARAMETERS are scanned in order; the first one whose
    ``key`` equals the argument is returned.
    """
    candidates = (entry for entry in REQUIRED_PARAMETERS if entry.key == key)
    return next(candidates, None)
|
|
|
|
|
|
def get_missing_required_parameters(memory_data: dict) -> list[ParameterSchema]:
    """List the required schemas whose value is absent from, or None in, ``memory_data``.

    Optional schemas are never reported. Results keep the order of
    REQUIRED_PARAMETERS.
    """
    # `and` short-circuits, so memory_data is only queried for required entries.
    return [
        entry
        for entry in REQUIRED_PARAMETERS
        if entry.required and memory_data.get(entry.key) is None
    ]
|
|
|
|
|
|
def format_parameters_for_prompt(
    parameters: "list[ParameterSchema] | None" = None,
) -> str:
    """Format parameter descriptions for the AI system prompt.

    Args:
        parameters: Schemas to describe. When omitted (None), the module's
            REQUIRED_PARAMETERS list is used, looked up at call time, so
            existing zero-argument callers behave as before. Pass a subset
            (for example only the missing parameters) to describe fewer.

    Returns:
        A newline-joined text block that starts with the "REQUIRED PARAMETERS:"
        header. Each entry is tagged REQUIRED or OPTIONAL from its ``required``
        flag and lists its description, rationale and type.
    """
    if parameters is None:
        parameters = REQUIRED_PARAMETERS

    lines = ["REQUIRED PARAMETERS:"]
    for param in parameters:
        status = "REQUIRED" if param.required else "OPTIONAL"
        # A multi-line rationale (e.g. a bulleted folder list) would otherwise
        # put its continuation lines at column 0, where "- name: ..." bullets
        # read like sibling "- key (STATUS):" entries. Indent them so they
        # stay nested under "Why needed".
        why_needed = param.why_needed.replace("\n", "\n   ")
        lines.append(f"\n- {param.key} ({status}):")
        lines.append(f" Description: {param.description}")
        lines.append(f" Why needed: {why_needed}")
        lines.append(f" Type: {param.type}")
    return "\n".join(lines)
|
|
|
|
|
|
def validate_parameter(key: str, value: Any) -> tuple[bool, str | None]:
    """
    Check ``value`` against the schema registered for ``key``.

    Keys with no registered schema, and schemas with no validator, are
    accepted unconditionally.

    Returns:
        (is_valid, error_message): error_message is None whenever is_valid
        is True.
    """
    schema = get_parameter_schema(key)
    check = schema.validator if schema else None
    if not check:
        # Unknown parameter, or nothing to validate against.
        return True, None

    try:
        # Truthiness is evaluated inside the try so a failing __bool__ is
        # reported like any other validator error.
        accepted = bool(check(value))
    except Exception as exc:
        # Validators are arbitrary callables; turn their failures into a
        # result instead of propagating them.
        return False, f"Validation error for {key}: {exc!s}"

    if accepted:
        return True, None
    return False, f"Validation failed for {key}"
|