"""Memory service handles memory retrieval, scoring, and prompt injection.""" from datetime import datetime, timezone from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.logging import get_logger from app.db.models import ChildProfile, MemoryItem, StoryUniverse logger = get_logger(__name__) class MemoryType: """记忆类型常量及配置。""" # 基础类型 RECENT_STORY = "recent_story" FAVORITE_CHARACTER = "favorite_character" SCARY_ELEMENT = "scary_element" VOCABULARY_GROWTH = "vocabulary_growth" EMOTIONAL_HIGHLIGHT = "emotional_highlight" # Phase 1 新增类型 READING_PREFERENCE = "reading_preference" # 阅读偏好 MILESTONE = "milestone" # 里程碑事件 SKILL_MASTERED = "skill_mastered" # 掌握的技能 # 类型配置: (默认权重, 默认TTL天数, 描述) CONFIG = { RECENT_STORY: (1.0, 30, "最近阅读的故事"), FAVORITE_CHARACTER: (1.5, None, "喜欢的角色"), # None = 永久 SCARY_ELEMENT: (2.0, None, "回避的元素"), # 高权重,永久有效 VOCABULARY_GROWTH: (0.8, 90, "词汇积累"), EMOTIONAL_HIGHLIGHT: (1.2, 60, "情感高光"), READING_PREFERENCE: (1.0, None, "阅读偏好"), MILESTONE: (1.5, None, "里程碑事件"), SKILL_MASTERED: (1.0, 180, "掌握的技能"), } @classmethod def get_default_weight(cls, memory_type: str) -> float: """获取类型的默认权重。""" config = cls.CONFIG.get(memory_type) return config[0] if config else 1.0 @classmethod def get_default_ttl(cls, memory_type: str) -> int | None: """获取类型的默认 TTL 天数。""" config = cls.CONFIG.get(memory_type) return config[1] if config else None def _decay_factor(days: float) -> float: """计算时间衰减因子。""" if days <= 7: return 1.0 if days <= 30: return 0.7 if days <= 90: return 0.4 return 0.2 async def build_enhanced_memory_context( profile_id: str | None, universe_id: str | None, db: AsyncSession, ) -> str | None: """构建增强版记忆上下文(自然语言 Prompt)。""" if not profile_id and not universe_id: return None context_parts: list[str] = [] # 1. 基础档案 (Identity Layer) if profile_id: profile = await db.scalar(select(ChildProfile).where(ChildProfile.id == profile_id)) if profile: context_parts.append(f"【目标读者】\n姓名:{profile.name}") if profile.age: context_parts.append(f"年龄:{profile.age}岁") if profile.interests: context_parts.append(f"兴趣爱好:{'、'.join(profile.interests)}") if profile.growth_themes: context_parts.append(f"当前成长关注点:{'、'.join(profile.growth_themes)}") context_parts.append("") # 空行 # 2. 故事宇宙 (Universe Layer) if universe_id: universe = await db.scalar(select(StoryUniverse).where(StoryUniverse.id == universe_id)) if universe: context_parts.append("【故事宇宙设定】") context_parts.append(f"世界观:{universe.name}") # 主角 protagonist = universe.protagonist or {} p_desc = f"{protagonist.get('name', '主角')} ({protagonist.get('personality', '')})" context_parts.append(f"主角设定:{p_desc}") # 常驻角色 if universe.recurring_characters: chars = [f"{c.get('name')} ({c.get('type')})" for c in universe.recurring_characters if isinstance(c, dict)] context_parts.append(f"已知伙伴:{'、'.join(chars)}") # 成就 if universe.achievements: badges = [str(a.get('type')) for a in universe.achievements if isinstance(a, dict)] if badges: context_parts.append(f"已获荣誉:{'、'.join(badges[:5])}") context_parts.append("") # 3. 动态记忆 (Working Memory) if profile_id: memories = await _fetch_scored_memories(profile_id, universe_id, db) if memories: memory_text = _format_memories_to_prompt(memories) if memory_text: context_parts.append("【关键记忆回忆】(请在故事中自然地融入或致敬以下元素)") context_parts.append(memory_text) return "\n".join(context_parts) async def _fetch_scored_memories( profile_id: str, universe_id: str | None, db: AsyncSession, limit: int = 8 ) -> list[MemoryItem]: """获取并评分记忆项,返回 Top N。""" query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id) if universe_id: query = query.where( (MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None)) ) # 取最近 50 条进行评分 query = query.order_by(MemoryItem.last_used_at.desc(), MemoryItem.created_at.desc()).limit(50) result = await db.execute(query) items = result.scalars().all() scored: list[tuple[float, MemoryItem]] = [] now = datetime.now(timezone.utc) for item in items: reference = item.last_used_at or item.created_at or now delta_days = max((now - reference).total_seconds() / 86400, 0) if item.ttl_days and delta_days > item.ttl_days: continue score = (item.base_weight or 1.0) * _decay_factor(delta_days) if score <= 0.1: # 忽略低权重 continue scored.append((score, item)) scored.sort(key=lambda x: x[0], reverse=True) return [item for _, item in scored[:limit]] def _format_memories_to_prompt(memories: list[MemoryItem]) -> str: """将记忆项转换为自然语言指令。""" lines = [] # 分类处理 recent_stories = [] favorites = [] scary = [] vocab = [] for m in memories: if m.type == MemoryType.RECENT_STORY: recent_stories.append(m) elif m.type == MemoryType.FAVORITE_CHARACTER: favorites.append(m) elif m.type == MemoryType.SCARY_ELEMENT: scary.append(m) elif m.type == MemoryType.VOCABULARY_GROWTH: vocab.append(m) # 1. 喜欢的角色 if favorites: names = [] for m in favorites: val = m.value if isinstance(val, dict): names.append(f"{val.get('name')} ({val.get('description', '')})") if names: lines.append(f"- 孩子特别喜欢这些角色,可以让他们客串出场:{', '.join(names)}") # 2. 避雷区 if scary: items = [] for m in scary: val = m.value if isinstance(val, dict): items.append(val.get('keyword', '')) elif isinstance(val, str): items.append(val) if items: lines.append(f"- 【注意禁止】不要出现以下让孩子害怕的元素:{', '.join(items)}") # 3. 近期故事 (取最近 2 个) if recent_stories: lines.append("- 近期经历(可作为彩蛋提及):") for m in recent_stories[:2]: val = m.value if isinstance(val, dict): title = val.get('title', '未知故事') lines.append(f" * 之前读过《{title}》") # 4. 词汇积累 if vocab: words = [] for m in vocab: val = m.value if isinstance(val, dict): words.append(val.get('word')) if words: lines.append(f"- 已掌握词汇(可适当复现以巩固):{', '.join([w for w in words if w])}") return "\n".join(lines) async def prune_expired_memories(db: AsyncSession) -> int: """清理过期的记忆项。 Returns: 删除的记录数量 """ from sqlalchemy import delete now = datetime.now(timezone.utc) # 查找所有设置了 TTL 的项目 stmt = select(MemoryItem).where(MemoryItem.ttl_days.is_not(None)) result = await db.execute(stmt) candidates = result.scalars().all() to_delete_ids = [] for item in candidates: if not item.ttl_days: continue reference = item.last_used_at or item.created_at or now delta_days = (now - reference).total_seconds() / 86400 if delta_days > item.ttl_days: to_delete_ids.append(item.id) if not to_delete_ids: return 0 delete_stmt = delete(MemoryItem).where(MemoryItem.id.in_(to_delete_ids)) await db.execute(delete_stmt) await db.commit() logger.info("memory_pruned", count=len(to_delete_ids)) return len(to_delete_ids) async def create_memory( db: AsyncSession, profile_id: str, memory_type: str, value: dict, universe_id: str | None = None, weight: float | None = None, ttl_days: int | None = None, ) -> MemoryItem: """创建新的记忆项。 Args: db: 数据库会话 profile_id: 孩子档案 ID memory_type: 记忆类型 (使用 MemoryType 常量) value: 记忆内容 (JSON 格式) universe_id: 可选,关联的故事宇宙 ID weight: 可选,权重 (默认使用类型配置) ttl_days: 可选,过期天数 (默认使用类型配置) Returns: 创建的 MemoryItem """ memory = MemoryItem( child_profile_id=profile_id, universe_id=universe_id, type=memory_type, value=value, base_weight=weight or MemoryType.get_default_weight(memory_type), ttl_days=ttl_days if ttl_days is not None else MemoryType.get_default_ttl(memory_type), ) db.add(memory) await db.commit() await db.refresh(memory) logger.info( "memory_created", memory_id=memory.id, profile_id=profile_id, type=memory_type, ) return memory async def update_memory_usage(db: AsyncSession, memory_id: str) -> None: """更新记忆的最后使用时间。 Args: db: 数据库会话 memory_id: 记忆项 ID """ result = await db.execute(select(MemoryItem).where(MemoryItem.id == memory_id)) memory = result.scalar_one_or_none() if memory: memory.last_used_at = datetime.now(timezone.utc) await db.commit() logger.debug("memory_usage_updated", memory_id=memory_id) async def get_profile_memories( db: AsyncSession, profile_id: str, memory_type: str | None = None, universe_id: str | None = None, limit: int = 50, ) -> list[MemoryItem]: """获取档案的记忆列表。 Args: db: 数据库会话 profile_id: 孩子档案 ID memory_type: 可选,按类型筛选 universe_id: 可选,按宇宙筛选 limit: 返回数量限制 Returns: MemoryItem 列表 """ query = select(MemoryItem).where(MemoryItem.child_profile_id == profile_id) if memory_type: query = query.where(MemoryItem.type == memory_type) if universe_id: query = query.where( (MemoryItem.universe_id == universe_id) | (MemoryItem.universe_id.is_(None)) ) query = query.order_by(MemoryItem.created_at.desc()).limit(limit) result = await db.execute(query) return list(result.scalars().all()) async def create_story_memory( db: AsyncSession, profile_id: str, story_id: int, title: str, summary: str | None = None, keywords: list[str] | None = None, universe_id: str | None = None, ) -> MemoryItem: """为故事创建记忆项。 这是一个便捷函数,专门用于在故事阅读后创建 recent_story 类型的记忆。 Args: db: 数据库会话 profile_id: 孩子档案 ID story_id: 故事 ID title: 故事标题 summary: 故事梗概 keywords: 关键词列表 universe_id: 可选,关联的故事宇宙 ID Returns: 创建的 MemoryItem """ value = { "story_id": story_id, "title": title, "summary": summary or "", "keywords": keywords or [], } return await create_memory( db=db, profile_id=profile_id, memory_type=MemoryType.RECENT_STORY, value=value, universe_id=universe_id, ) async def create_character_memory( db: AsyncSession, profile_id: str, name: str, description: str | None = None, source_story_id: int | None = None, affinity_score: float = 1.0, universe_id: str | None = None, ) -> MemoryItem: """为喜欢的角色创建记忆项。 Args: db: 数据库会话 profile_id: 孩子档案 ID name: 角色名称 description: 角色描述 source_story_id: 来源故事 ID affinity_score: 喜爱程度 (0.0-1.0) universe_id: 可选,关联的故事宇宙 ID Returns: 创建的 MemoryItem """ value = { "name": name, "description": description or "", "source_story_id": source_story_id, "affinity_score": min(1.0, max(0.0, affinity_score)), } return await create_memory( db=db, profile_id=profile_id, memory_type=MemoryType.FAVORITE_CHARACTER, value=value, universe_id=universe_id, ) async def create_scary_element_memory( db: AsyncSession, profile_id: str, keyword: str, category: str = "other", source_story_id: int | None = None, ) -> MemoryItem: """为回避元素创建记忆项。 Args: db: 数据库会话 profile_id: 孩子档案 ID keyword: 回避的关键词 category: 分类 (creature/scene/action/other) source_story_id: 来源故事 ID Returns: 创建的 MemoryItem """ value = { "keyword": keyword, "category": category, "source_story_id": source_story_id, } return await create_memory( db=db, profile_id=profile_id, memory_type=MemoryType.SCARY_ELEMENT, value=value, )