AI Agent开发实战(四):记忆系统让Agent记住一切

7次阅读
没有评论

AI Agent开发实战(四):记忆系统让Agent记住一切

一、开场:健忘的Agent不靠谱

大家好,我是老金。

之前我们实现了会话管理,但有个问题:

  • 会话结束后记忆就消失了
  • 无法记住用户的偏好
  • 无法从历史中学习

真正的智能Agent需要长期记忆

  • 记住用户是谁
  • 记住用户的偏好
  • 记住过去的交互
  • 能够检索相关记忆

今天我们实现Agent的记忆系统。

二、记忆系统架构

2.1 记忆类型

┌─────────────────────────────────────────────────────────┐
│                  Agent记忆系统架构                       │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  ┌─────────────────────────────────────────────────┐   │
│  │              工作记忆(Working Memory)          │   │
│  │  • 当前对话上下文                                │   │
│  │  • 临时状态                                      │   │
│  │  • 容量有限(通常4-20条消息)                    │   │
│  └─────────────────────────────────────────────────┘   │
│                         ↓                               │
│  ┌─────────────────────────────────────────────────┐   │
│  │              短期记忆(Short-term Memory)       │   │
│  │  • 最近N轮对话                                   │   │
│  │  • 会话级别                                      │   │
│  │  • 自动过期                                      │   │
│  └─────────────────────────────────────────────────┘   │
│                         ↓                               │
│  ┌─────────────────────────────────────────────────┐   │
│  │              长期记忆(Long-term Memory)        │   │
│  │  • 用户画像                                      │   │
│  │  • 重要事件                                      │   │
│  │  • 知识库                                        │   │
│  │  • 向量存储                                      │   │
│  └─────────────────────────────────────────────────┘   │
│                                                         │
└─────────────────────────────────────────────────────────┘

2.2 记忆抽象

# src/memory/base.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
from datetime import datetime

class MemoryItem(BaseModel):
    """记忆条目"""
    id: str
    content: str
    metadata: Dict[str, Any] = {}
    created_at: datetime = datetime.now()
    importance: float = 0.5  # 重要性 0-1
    access_count: int = 0    # 访问次数

class BaseMemory(ABC):
    """记忆基类"""

    @abstractmethod
    async def add(self, item: MemoryItem) -> str:
        """添加记忆"""
        pass

    @abstractmethod
    async def get(self, id: str) -> Optional[MemoryItem]:
        """获取记忆"""
        pass

    @abstractmethod
    async def search(self, query: str, top_k: int = 5) -> List[MemoryItem]:
        """搜索记忆"""
        pass

    @abstractmethod
    async def delete(self, id: str) -> bool:
        """删除记忆"""
        pass

    @abstractmethod
    async def clear(self):
        """清空记忆"""
        pass

三、工作记忆实现

3.1 工作记忆

# src/memory/working_memory.py
from .base import BaseMemory, MemoryItem
from typing import List, Optional
from collections import deque
import uuid

class WorkingMemory(BaseMemory):
    """工作记忆(固定大小的滑动窗口)"""

    def __init__(self, max_size: int = 10):
        self.max_size = max_size
        self._items: deque = deque(maxlen=max_size)

    async def add(self, item: MemoryItem) -> str:
        """添加记忆"""
        if not item.id:
            item.id = str(uuid.uuid4())
        self._items.append(item)
        return item.id

    async def get(self, id: str) -> Optional[MemoryItem]:
        """获取记忆"""
        for item in self._items:
            if item.id == id:
                item.access_count += 1
                return item
        return None

    async def search(self, query: str, top_k: int = 5) -> List[MemoryItem]:
        """搜索(简单包含匹配)"""
        results = []
        for item in self._items:
            if query.lower() in item.content.lower():
                results.append(item)
                if len(results) >= top_k:
                    break
        return results

    async def delete(self, id: str) -> bool:
        """删除记忆"""
        for i, item in enumerate(self._items):
            if item.id == id:
                del self._items[i]
                return True
        return False

    async def clear(self):
        """清空"""
        self._items.clear()

    def get_all(self) -> List[MemoryItem]:
        """获取所有"""
        return list(self._items)

    def get_context(self) -> str:
        """获取上下文字符串"""
        return "n".join([f"{item.metadata.get('role', 'unknown')}: {item.content}" 
                         for item in self._items])

四、向量记忆实现

4.1 向量存储

# src/memory/vector_memory.py
from .base import BaseMemory, MemoryItem
from typing import List, Optional
import chromadb
from chromadb.config import Settings
import uuid

class VectorMemory(BaseMemory):
    """向量记忆(使用ChromaDB)"""

    def __init__(
        self,
        collection_name: str = "agent_memory",
        persist_directory: str = "./data/memory",
        embedding_function = None
    ):
        # 初始化ChromaDB
        self.client = chromadb.PersistentClient(path=persist_directory)

        # 创建集合
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            embedding_function=embedding_function
        )

    async def add(self, item: MemoryItem) -> str:
        """添加记忆"""
        if not item.id:
            item.id = str(uuid.uuid4())

        self.collection.add(
            ids=[item.id],
            documents=[item.content],
            metadatas=[{
                **item.metadata,
                "importance": item.importance,
                "access_count": item.access_count,
                "created_at": item.created_at.isoformat()
            }]
        )

        return item.id

    async def get(self, id: str) -> Optional[MemoryItem]:
        """获取记忆"""
        result = self.collection.get(ids=[id])

        if result['ids']:
            return MemoryItem(
                id=result['ids'][0],
                content=result['documents'][0],
                metadata=result['metadatas'][0]
            )
        return None

    async def search(self, query: str, top_k: int = 5) -> List[MemoryItem]:
        """语义搜索"""
        results = self.collection.query(
            query_texts=[query],
            n_results=top_k
        )

        items = []
        for i, id in enumerate(results['ids'][0]):
            items.append(MemoryItem(
                id=id,
                content=results['documents'][0][i],
                metadata=results['metadatas'][0][i]
            ))

        return items

    async def delete(self, id: str) -> bool:
        """删除记忆"""
        try:
            self.collection.delete(ids=[id])
            return True
        except:
            return False

    async def clear(self):
        """清空"""
        # 获取所有ID
        all_ids = self.collection.get()['ids']
        if all_ids:
            self.collection.delete(ids=all_ids)

    async def update_access_count(self, id: str):
        """更新访问次数"""
        item = await self.get(id)
        if item:
            item.access_count += 1
            await self.delete(id)
            await self.add(item)

4.2 使用FAISS的替代方案

# src/memory/faiss_memory.py
from .base import BaseMemory, MemoryItem
from typing import List, Optional
import faiss
import numpy as np
from openai import OpenAI
import uuid
import pickle
import os

class FAISSMemory(BaseMemory):
    """使用FAISS的向量记忆"""

    def __init__(
        self,
        dimension: int = 1536,  # OpenAI embedding维度
        index_path: str = "./data/memory/faiss.index"
    ):
        self.dimension = dimension
        self.index_path = index_path

        # 初始化FAISS索引
        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
            with open(index_path + ".meta", "rb") as f:
                self.metadata = pickle.load(f)
        else:
            self.index = faiss.IndexFlatIP(dimension)  # 内积相似度
            self.metadata = {}

        # OpenAI客户端
        self.openai_client = OpenAI()

    def _get_embedding(self, text: str) -> np.ndarray:
        """获取文本向量"""
        response = self.openai_client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return np.array(response.data[0].embedding, dtype=np.float32)

    async def add(self, item: MemoryItem) -> str:
        """添加记忆"""
        if not item.id:
            item.id = str(uuid.uuid4())

        # 获取向量
        embedding = self._get_embedding(item.content)

        # 添加到索引
        self.index.add(np.array([embedding]))

        # 保存元数据
        self.metadata[item.id] = {
            "content": item.content,
            "metadata": item.metadata,
            "importance": item.importance,
            "created_at": item.created_at.isoformat()
        }

        # 持久化
        self._save()

        return item.id

    async def search(self, query: str, top_k: int = 5) -> List[MemoryItem]:
        """搜索"""
        if self.index.ntotal == 0:
            return []

        # 获取查询向量
        query_embedding = self._get_embedding(query)

        # 搜索
        k = min(top_k, self.index.ntotal)
        scores, indices = self.index.search(np.array([query_embedding]), k)

        items = []
        for i, idx in enumerate(indices[0]):
            if idx  str:
        """记忆内容"""
        item = MemoryItem(
            id=str(uuid.uuid4()),
            content=content,
            metadata={"role": role},
            importance=importance
        )

        # 添加到工作记忆
        await self.working_memory.add(item)

        # 重要内容存储到长期记忆
        if should_store_long_term or importance > 0.7:
            await self.long_term_memory.add(item)

        return item.id

    async def recall(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """回忆相关内容"""
        # 先搜索工作记忆
        working_results = await self.working_memory.search(query, top_k)

        # 再搜索长期记忆
        long_term_results = await self.long_term_memory.search(query, top_k)

        # 合并去重
        all_results = []
        seen_ids = set()

        for item in working_results:
            if item.id not in seen_ids:
                all_results.append({
                    "content": item.content,
                    "source": "working_memory",
                    "importance": item.importance
                })
                seen_ids.add(item.id)

        for item in long_term_results:
            if item.id not in seen_ids:
                all_results.append({
                    "content": item.content,
                    "source": "long_term_memory",
                    "importance": item.importance
                })
                seen_ids.add(item.id)

        return all_results[:top_k]

    async def get_context_for_llm(self, query: str = None) -> str:
        """获取用于LLM的上下文"""
        context_parts = []

        # 工作记忆
        working_context = self.working_memory.get_context()
        if working_context:
            context_parts.append(f"最近对话:n{working_context}")

        # 长期记忆
        if query:
            long_term_memories = await self.long_term_memory.search(query, top_k=3)
            if long_term_memories:
                long_term_context = "n".join([f"- {m.content}" for m in long_term_memories])
                context_parts.append(f"相关记忆:n{long_term_context}")

        return "nn".join(context_parts)

    async def forget(self, id: str):
        """遗忘"""
        await self.working_memory.delete(id)
        await self.long_term_memory.delete(id)

    async def clear_working_memory(self):
        """清空工作记忆"""
        await self.working_memory.clear()

    async def consolidate(self):
        """巩固记忆(将重要的工作记忆转移到长期记忆)"""
        items = self.working_memory.get_all()
        for item in items:
            if item.importance > 0.7 or item.access_count > 3:
                await self.long_term_memory.add(item)

六、带记忆的Agent

6.1 记忆Agent实现

# src/agents/memory_agent.py
from .tool_agent import ToolAgent
from ..memory.manager import MemoryManager
from ..utils.llm_client import LLMClient
from typing import List, Optional

class MemoryAgent(ToolAgent):
    """带记忆的Agent"""

    def __init__(
        self,
        llm_client: LLMClient,
        user_id: str = "default",
        **kwargs
    ):
        super().__init__(llm_client, **kwargs)
        self.user_id = user_id
        self.memory = MemoryManager()

    async def chat(self, user_input: str) -> str:
        """对话(带记忆)"""
        # 1. 记住用户输入
        await self.memory.remember(
            content=user_input,
            role="user",
            importance=self._calculate_importance(user_input)
        )

        # 2. 获取相关记忆
        relevant_memories = await self.memory.recall(user_input, top_k=5)

        # 3. 构建上下文
        memory_context = ""
        if relevant_memories:
            memory_context = "相关历史记忆:n" + "n".join([
                f"- {m['content']}" for m in relevant_memories
            ])

        # 4. 增强系统提示
        enhanced_prompt = f"""{self.system_prompt}

{memory_context}
"""

        # 临时更新系统提示
        original_prompt = self.system_prompt
        self.state.messages[0].content = enhanced_prompt

        # 5. 执行对话
        response = await self.run(user_input)

        # 恢复原始系统提示
        self.state.messages[0].content = original_prompt

        # 6. 记住回复
        await self.memory.remember(
            content=response,
            role="assistant",
            importance=0.3  # 回复重要性较低
        )

        return response

    def _calculate_importance(self, message: str) -> float:
        """计算消息重要性"""
        importance = 0.5

        # 包含个人信息的更重要
        personal_keywords = ["我", "我的", "喜欢", "讨厌", "工作", "名字"]
        for keyword in personal_keywords:
            if keyword in message:
                importance += 0.1

        # 问题是中等重要
        if "?" in message or "?" in message:
            importance += 0.05

        # 长消息可能更重要
        if len(message) > 100:
            importance += 0.1

        return min(importance, 1.0)

6.2 使用示例

# examples/memory_agent_demo.py
import asyncio
from src.utils.llm_client import LLMClient
from src.agents.memory_agent import MemoryAgent

async def main():
    llm = LLMClient(provider="openai", model="gpt-4-turbo-preview")
    agent = MemoryAgent(
        llm_client=llm,
        user_id="user_001",
        system_prompt="你是一个友好的AI助手,能记住用户的偏好和重要信息。"
    )

    # 多轮对话
    conversations = [
        "我叫张三,是一名软件工程师",
        "我喜欢Python和AI",
        "我正在学习AI Agent开发",
        "你还记得我是谁吗?",
        "我喜欢什么编程语言?",
        "我最近在学什么?"
    ]

    for user_input in conversations:
        print(f"用户: {user_input}")
        response = await agent.chat(user_input)
        print(f"Agent: {response}n")

if __name__ == "__main__":
    asyncio.run(main())

6.3 运行效果

用户: 我叫张三,是一名软件工程师
Agent: 你好张三!很高兴认识一位软件工程师。有什么我可以帮你的吗?

用户: 我喜欢Python和AI
Agent: Python和AI是非常棒的组合!你主要用Python做哪方面的开发呢?机器学习、深度学习,还是其他方向?

用户: 我正在学习AI Agent开发
Agent: AI Agent开发是目前的热门方向!需要我分享一些学习资源或者解答相关问题吗?

用户: 你还记得我是谁吗?
Agent: 当然记得!你叫张三,是一名软件工程师,你喜欢Python和AI,目前正在学习AI Agent开发。

用户: 我喜欢什么编程语言?
Agent: 你告诉我你喜欢Python。

用户: 我最近在学什么?
Agent: 你最近在学习AI Agent开发。

七、记忆压缩与总结

7.1 记忆压缩

# src/memory/compressor.py
from ..utils.llm_client import LLMClient
from typing import List
from .base import MemoryItem

class MemoryCompressor:
    """记忆压缩器"""

    def __init__(self, llm_client: LLMClient):
        self.llm = llm_client

    async def compress(self, memories: List[MemoryItem]) -> str:
        """压缩多条记忆为摘要"""
        if not memories:
            return ""

        # 合并所有记忆
        combined = "n".join([f"- {m.content}" for m in memories])

        # 使用LLM生成摘要
        prompt = f"""请将以下多条记忆压缩为一条简洁的摘要:

{combined}

要求:
1. 保留关键信息
2. 去除重复内容
3. 保持逻辑连贯
4. 不超过100字

摘要:"""

        summary = await self.llm.chat([{"role": "user", "content": prompt}])
        return summary

    async def summarize_session(self, messages: List[dict]) -> str:
        """总结会话"""
        conversation = "n".join([
            f"{m['role']}: {m['content']}" 
            for m in messages
        ])

        prompt = f"""请总结以下对话的关键信息:

{conversation}

总结要点:"""

        return await self.llm.chat([{"role": "user", "content": prompt}])

八、最佳实践

8.1 记忆策略

策略 适用场景
工作记忆优先 当前任务相关
语义搜索 检索相关知识
时间衰减 旧记忆降权
重要性排序 关键信息优先

8.2 记忆优化

# 记忆管理最佳实践
class MemoryOptimization:
    """记忆优化"""

    @staticmethod
    async def optimize(memory_manager: MemoryManager):
        """优化记忆"""
        # 1. 清理过期的短期记忆
        # 2. 压缩重复记忆
        # 3. 合并相似记忆
        # 4. 删除低重要性记忆
        pass

九、总结

记忆系统要点

  1. 工作记忆:当前上下文,容量有限
  2. 长期记忆:向量存储,语义检索
  3. 记忆管理:统一管理,智能检索
  4. 记忆压缩:减少冗余,保留关键

下期预告

下一篇:Multi-Agent协作系统——让多个Agent一起工作!


往期回顾

正文完
 0
技术老金
版权声明:本站原创文章,由 技术老金 于2026-04-01发表,共计11144字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)