AI Agent开发实战(四):记忆系统让Agent记住一切
一、开场:健忘的Agent不靠谱
大家好,我是老金。
之前我们实现了会话管理,但有个问题:
- 会话结束后记忆就消失了
- 无法记住用户的偏好
- 无法从历史中学习
真正的智能Agent需要长期记忆:
- 记住用户是谁
- 记住用户的偏好
- 记住过去的交互
- 能够检索相关记忆
今天我们实现Agent的记忆系统。
二、记忆系统架构
2.1 记忆类型
┌─────────────────────────────────────────────────────────┐
│ Agent记忆系统架构 │
├─────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 工作记忆(Working Memory) │ │
│ │ • 当前对话上下文 │ │
│ │ • 临时状态 │ │
│ │ • 容量有限(通常4-20条消息) │ │
│ └─────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 短期记忆(Short-term Memory) │ │
│ │ • 最近N轮对话 │ │
│ │ • 会话级别 │ │
│ │ • 自动过期 │ │
│ └─────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 长期记忆(Long-term Memory) │ │
│ │ • 用户画像 │ │
│ │ • 重要事件 │ │
│ │ • 知识库 │ │
│ │ • 向量存储 │ │
│ └─────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────┘
2.2 记忆抽象
# src/memory/base.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
from datetime import datetime
class MemoryItem(BaseModel):
"""记忆条目"""
id: str
content: str
metadata: Dict[str, Any] = {}
created_at: datetime = datetime.now()
importance: float = 0.5 # 重要性 0-1
access_count: int = 0 # 访问次数
class BaseMemory(ABC):
"""记忆基类"""
@abstractmethod
async def add(self, item: MemoryItem) -> str:
"""添加记忆"""
pass
@abstractmethod
async def get(self, id: str) -> Optional[MemoryItem]:
"""获取记忆"""
pass
@abstractmethod
async def search(self, query: str, top_k: int = 5) -> List[MemoryItem]:
"""搜索记忆"""
pass
@abstractmethod
async def delete(self, id: str) -> bool:
"""删除记忆"""
pass
@abstractmethod
async def clear(self):
"""清空记忆"""
pass
三、工作记忆实现
3.1 工作记忆
# src/memory/working_memory.py
from .base import BaseMemory, MemoryItem
from typing import List, Optional
from collections import deque
import uuid
class WorkingMemory(BaseMemory):
"""工作记忆(固定大小的滑动窗口)"""
def __init__(self, max_size: int = 10):
self.max_size = max_size
self._items: deque = deque(maxlen=max_size)
async def add(self, item: MemoryItem) -> str:
"""添加记忆"""
if not item.id:
item.id = str(uuid.uuid4())
self._items.append(item)
return item.id
async def get(self, id: str) -> Optional[MemoryItem]:
"""获取记忆"""
for item in self._items:
if item.id == id:
item.access_count += 1
return item
return None
async def search(self, query: str, top_k: int = 5) -> List[MemoryItem]:
"""搜索(简单包含匹配)"""
results = []
for item in self._items:
if query.lower() in item.content.lower():
results.append(item)
if len(results) >= top_k:
break
return results
async def delete(self, id: str) -> bool:
"""删除记忆"""
for i, item in enumerate(self._items):
if item.id == id:
del self._items[i]
return True
return False
async def clear(self):
"""清空"""
self._items.clear()
def get_all(self) -> List[MemoryItem]:
"""获取所有"""
return list(self._items)
def get_context(self) -> str:
"""获取上下文字符串"""
return "n".join([f"{item.metadata.get('role', 'unknown')}: {item.content}"
for item in self._items])
四、向量记忆实现
4.1 向量存储
# src/memory/vector_memory.py
from .base import BaseMemory, MemoryItem
from typing import List, Optional
import chromadb
from chromadb.config import Settings
import uuid
class VectorMemory(BaseMemory):
"""向量记忆(使用ChromaDB)"""
def __init__(
self,
collection_name: str = "agent_memory",
persist_directory: str = "./data/memory",
embedding_function = None
):
# 初始化ChromaDB
self.client = chromadb.PersistentClient(path=persist_directory)
# 创建集合
self.collection = self.client.get_or_create_collection(
name=collection_name,
embedding_function=embedding_function
)
async def add(self, item: MemoryItem) -> str:
"""添加记忆"""
if not item.id:
item.id = str(uuid.uuid4())
self.collection.add(
ids=[item.id],
documents=[item.content],
metadatas=[{
**item.metadata,
"importance": item.importance,
"access_count": item.access_count,
"created_at": item.created_at.isoformat()
}]
)
return item.id
async def get(self, id: str) -> Optional[MemoryItem]:
"""获取记忆"""
result = self.collection.get(ids=[id])
if result['ids']:
return MemoryItem(
id=result['ids'][0],
content=result['documents'][0],
metadata=result['metadatas'][0]
)
return None
async def search(self, query: str, top_k: int = 5) -> List[MemoryItem]:
"""语义搜索"""
results = self.collection.query(
query_texts=[query],
n_results=top_k
)
items = []
for i, id in enumerate(results['ids'][0]):
items.append(MemoryItem(
id=id,
content=results['documents'][0][i],
metadata=results['metadatas'][0][i]
))
return items
async def delete(self, id: str) -> bool:
"""删除记忆"""
try:
self.collection.delete(ids=[id])
return True
except:
return False
async def clear(self):
"""清空"""
# 获取所有ID
all_ids = self.collection.get()['ids']
if all_ids:
self.collection.delete(ids=all_ids)
async def update_access_count(self, id: str):
"""更新访问次数"""
item = await self.get(id)
if item:
item.access_count += 1
await self.delete(id)
await self.add(item)
4.2 使用FAISS的替代方案
# src/memory/faiss_memory.py
from .base import BaseMemory, MemoryItem
from typing import List, Optional
import faiss
import numpy as np
from openai import OpenAI
import uuid
import pickle
import os
class FAISSMemory(BaseMemory):
"""使用FAISS的向量记忆"""
def __init__(
self,
dimension: int = 1536, # OpenAI embedding维度
index_path: str = "./data/memory/faiss.index"
):
self.dimension = dimension
self.index_path = index_path
# 初始化FAISS索引
if os.path.exists(index_path):
self.index = faiss.read_index(index_path)
with open(index_path + ".meta", "rb") as f:
self.metadata = pickle.load(f)
else:
self.index = faiss.IndexFlatIP(dimension) # 内积相似度
self.metadata = {}
# OpenAI客户端
self.openai_client = OpenAI()
def _get_embedding(self, text: str) -> np.ndarray:
"""获取文本向量"""
response = self.openai_client.embeddings.create(
model="text-embedding-3-small",
input=text
)
return np.array(response.data[0].embedding, dtype=np.float32)
async def add(self, item: MemoryItem) -> str:
"""添加记忆"""
if not item.id:
item.id = str(uuid.uuid4())
# 获取向量
embedding = self._get_embedding(item.content)
# 添加到索引
self.index.add(np.array([embedding]))
# 保存元数据
self.metadata[item.id] = {
"content": item.content,
"metadata": item.metadata,
"importance": item.importance,
"created_at": item.created_at.isoformat()
}
# 持久化
self._save()
return item.id
async def search(self, query: str, top_k: int = 5) -> List[MemoryItem]:
"""搜索"""
if self.index.ntotal == 0:
return []
# 获取查询向量
query_embedding = self._get_embedding(query)
# 搜索
k = min(top_k, self.index.ntotal)
scores, indices = self.index.search(np.array([query_embedding]), k)
items = []
for i, idx in enumerate(indices[0]):
if idx str:
"""记忆内容"""
item = MemoryItem(
id=str(uuid.uuid4()),
content=content,
metadata={"role": role},
importance=importance
)
# 添加到工作记忆
await self.working_memory.add(item)
# 重要内容存储到长期记忆
if should_store_long_term or importance > 0.7:
await self.long_term_memory.add(item)
return item.id
async def recall(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""回忆相关内容"""
# 先搜索工作记忆
working_results = await self.working_memory.search(query, top_k)
# 再搜索长期记忆
long_term_results = await self.long_term_memory.search(query, top_k)
# 合并去重
all_results = []
seen_ids = set()
for item in working_results:
if item.id not in seen_ids:
all_results.append({
"content": item.content,
"source": "working_memory",
"importance": item.importance
})
seen_ids.add(item.id)
for item in long_term_results:
if item.id not in seen_ids:
all_results.append({
"content": item.content,
"source": "long_term_memory",
"importance": item.importance
})
seen_ids.add(item.id)
return all_results[:top_k]
async def get_context_for_llm(self, query: str = None) -> str:
"""获取用于LLM的上下文"""
context_parts = []
# 工作记忆
working_context = self.working_memory.get_context()
if working_context:
context_parts.append(f"最近对话:n{working_context}")
# 长期记忆
if query:
long_term_memories = await self.long_term_memory.search(query, top_k=3)
if long_term_memories:
long_term_context = "n".join([f"- {m.content}" for m in long_term_memories])
context_parts.append(f"相关记忆:n{long_term_context}")
return "nn".join(context_parts)
async def forget(self, id: str):
"""遗忘"""
await self.working_memory.delete(id)
await self.long_term_memory.delete(id)
async def clear_working_memory(self):
"""清空工作记忆"""
await self.working_memory.clear()
async def consolidate(self):
"""巩固记忆(将重要的工作记忆转移到长期记忆)"""
items = self.working_memory.get_all()
for item in items:
if item.importance > 0.7 or item.access_count > 3:
await self.long_term_memory.add(item)
六、带记忆的Agent
6.1 记忆Agent实现
# src/agents/memory_agent.py
from .tool_agent import ToolAgent
from ..memory.manager import MemoryManager
from ..utils.llm_client import LLMClient
from typing import List, Optional
class MemoryAgent(ToolAgent):
"""带记忆的Agent"""
def __init__(
self,
llm_client: LLMClient,
user_id: str = "default",
**kwargs
):
super().__init__(llm_client, **kwargs)
self.user_id = user_id
self.memory = MemoryManager()
async def chat(self, user_input: str) -> str:
"""对话(带记忆)"""
# 1. 记住用户输入
await self.memory.remember(
content=user_input,
role="user",
importance=self._calculate_importance(user_input)
)
# 2. 获取相关记忆
relevant_memories = await self.memory.recall(user_input, top_k=5)
# 3. 构建上下文
memory_context = ""
if relevant_memories:
memory_context = "相关历史记忆:n" + "n".join([
f"- {m['content']}" for m in relevant_memories
])
# 4. 增强系统提示
enhanced_prompt = f"""{self.system_prompt}
{memory_context}
"""
# 临时更新系统提示
original_prompt = self.system_prompt
self.state.messages[0].content = enhanced_prompt
# 5. 执行对话
response = await self.run(user_input)
# 恢复原始系统提示
self.state.messages[0].content = original_prompt
# 6. 记住回复
await self.memory.remember(
content=response,
role="assistant",
importance=0.3 # 回复重要性较低
)
return response
def _calculate_importance(self, message: str) -> float:
"""计算消息重要性"""
importance = 0.5
# 包含个人信息的更重要
personal_keywords = ["我", "我的", "喜欢", "讨厌", "工作", "名字"]
for keyword in personal_keywords:
if keyword in message:
importance += 0.1
# 问题是中等重要
if "?" in message or "?" in message:
importance += 0.05
# 长消息可能更重要
if len(message) > 100:
importance += 0.1
return min(importance, 1.0)
6.2 使用示例
# examples/memory_agent_demo.py
import asyncio
from src.utils.llm_client import LLMClient
from src.agents.memory_agent import MemoryAgent
async def main():
llm = LLMClient(provider="openai", model="gpt-4-turbo-preview")
agent = MemoryAgent(
llm_client=llm,
user_id="user_001",
system_prompt="你是一个友好的AI助手,能记住用户的偏好和重要信息。"
)
# 多轮对话
conversations = [
"我叫张三,是一名软件工程师",
"我喜欢Python和AI",
"我正在学习AI Agent开发",
"你还记得我是谁吗?",
"我喜欢什么编程语言?",
"我最近在学什么?"
]
for user_input in conversations:
print(f"用户: {user_input}")
response = await agent.chat(user_input)
print(f"Agent: {response}n")
if __name__ == "__main__":
asyncio.run(main())
6.3 运行效果
用户: 我叫张三,是一名软件工程师
Agent: 你好张三!很高兴认识一位软件工程师。有什么我可以帮你的吗?
用户: 我喜欢Python和AI
Agent: Python和AI是非常棒的组合!你主要用Python做哪方面的开发呢?机器学习、深度学习,还是其他方向?
用户: 我正在学习AI Agent开发
Agent: AI Agent开发是目前的热门方向!需要我分享一些学习资源或者解答相关问题吗?
用户: 你还记得我是谁吗?
Agent: 当然记得!你叫张三,是一名软件工程师,你喜欢Python和AI,目前正在学习AI Agent开发。
用户: 我喜欢什么编程语言?
Agent: 你告诉我你喜欢Python。
用户: 我最近在学什么?
Agent: 你最近在学习AI Agent开发。
七、记忆压缩与总结
7.1 记忆压缩
# src/memory/compressor.py
from ..utils.llm_client import LLMClient
from typing import List
from .base import MemoryItem
class MemoryCompressor:
"""记忆压缩器"""
def __init__(self, llm_client: LLMClient):
self.llm = llm_client
async def compress(self, memories: List[MemoryItem]) -> str:
"""压缩多条记忆为摘要"""
if not memories:
return ""
# 合并所有记忆
combined = "n".join([f"- {m.content}" for m in memories])
# 使用LLM生成摘要
prompt = f"""请将以下多条记忆压缩为一条简洁的摘要:
{combined}
要求:
1. 保留关键信息
2. 去除重复内容
3. 保持逻辑连贯
4. 不超过100字
摘要:"""
summary = await self.llm.chat([{"role": "user", "content": prompt}])
return summary
async def summarize_session(self, messages: List[dict]) -> str:
"""总结会话"""
conversation = "n".join([
f"{m['role']}: {m['content']}"
for m in messages
])
prompt = f"""请总结以下对话的关键信息:
{conversation}
总结要点:"""
return await self.llm.chat([{"role": "user", "content": prompt}])
八、最佳实践
8.1 记忆策略
| 策略 | 适用场景 |
|---|---|
| 工作记忆优先 | 当前任务相关 |
| 语义搜索 | 检索相关知识 |
| 时间衰减 | 旧记忆降权 |
| 重要性排序 | 关键信息优先 |
8.2 记忆优化
# 记忆管理最佳实践
class MemoryOptimization:
"""记忆优化"""
@staticmethod
async def optimize(memory_manager: MemoryManager):
"""优化记忆"""
# 1. 清理过期的短期记忆
# 2. 压缩重复记忆
# 3. 合并相似记忆
# 4. 删除低重要性记忆
pass
九、总结
记忆系统要点
- 工作记忆:当前上下文,容量有限
- 长期记忆:向量存储,语义检索
- 记忆管理:统一管理,智能检索
- 记忆压缩:减少冗余,保留关键
下期预告
下一篇:Multi-Agent协作系统——让多个Agent一起工作!
往期回顾
正文完