# LangChain simple application example (2): a chat bot with conversational
# memory, built with the (older) Memory API. Written as a practice exercise.
# Imports for the memory-enabled chat bot example.
from langchain.prompts import ChatPromptTemplate
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from DoubaoLLM import DoubaoLLM  # NOTE: local module — mind the import path

# Single shared LLM instance used by the chain below.
llm = DoubaoLLM()
# Prompt template: system instruction, chat-history placeholder, current input.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "你是一个 helpful 的助手。请根据对话历史回答问题。"),
        ("placeholder", "{chat_history}"),  # filled with the conversation history
        ("human", "{input}"),  # the user's current message
    ]
)
# Initialize conversation memory, backed by an explicit in-memory message store.
history = ChatMessageHistory()
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    chat_memory=history,  # explicitly specify the chat-history store
)


def load_memory(_):
    """Return the stored chat-history messages for injection into the prompt.

    The single (ignored) argument is the chain input dict passed by
    ``RunnablePassthrough.assign``.
    """
    return memory.load_memory_variables({})["chat_history"]


# Build the conversation chain (replacement for the deprecated LLMChain).
chain = (
    RunnablePassthrough.assign(chat_history=load_memory)
    | prompt
    | llm
    | StrOutputParser()
)
def chat_with_bot(user_input):
    """Send one user turn through the chain and record it in memory.

    Args:
        user_input: the user's message for this turn.

    Returns:
        The model's string response.
    """
    # Get the model response for this turn.
    response = chain.invoke({"input": user_input})
    # Persist the exchange so later turns can reference it.
    memory.save_context({"input": user_input}, {"output": response})
    return response
# Quick manual test of the memory feature.
if __name__ == "__main__":
    print(chat_with_bot("你好,我叫小明"))
    print(chat_with_bot("我刚才告诉你我的名字了吗?"))  # should remember the name
    print(chat_with_bot("我叫什么名字?"))  # verify the memory works
# DoubaoLLM module. Subclassing BaseLanguageModel makes the model easier to
# use in LangChain chains later on.
import requests
from typing import Any, List, Mapping, Optional, Sequence, Dict, Union

from pydantic import Field, ConfigDict

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    AIMessage,
    SystemMessage,
    get_buffer_string,
    AnyMessage,
)
from langchain_core.outputs import LLMResult, Generation
from langchain_core.prompt_values import PromptValue, StringPromptValue, ChatPromptValue
from langchain_core.callbacks import Callbacks
from langchain_core.tools import BaseTool
from langchain_community.agent_toolkits.load_tools import load_tools
from langgraph.prebuilt import create_react_agent

# NOTE(review): credentials should come from environment variables or a secret
# store, not from source code.
API_KEY = "***************"
# Doubao (Volcano Ark) endpoint configuration.
API_NAME = "***************"  # not referenced below — presumably kept for reference
API_URL = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
MODEL = "doubao-seed-1-6-flash-250715"


class DoubaoLLM(BaseLanguageModel):
    """Doubao (Volcano Ark) chat-completions wrapper.

    Follows the ``BaseLanguageModel`` abstract-class contract so it can be
    plugged directly into the LangChain / LangGraph ecosystem.
    """

    api_key: str = API_KEY
    model: str = MODEL
    temperature: float = Field(0.0, description="温度参数,控制输出随机性")
    timeout: int = Field(30, description="API调用超时时间")
    api_url: str = API_URL
    # Tool-related configuration.
    tools: List[BaseTool] = Field(default_factory=list, description="绑定的工具列表")

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")

    def __init__(self, **data: Any):
        super().__init__(**data)
        if not self.api_key:
            raise ValueError("必须提供有效的豆包API密钥")

    # NOTE(review): LangChain convention makes _llm_type/_identifying_params
    # properties; kept as plain methods here to preserve the original call form.
    def _llm_type(self) -> str:
        return "doubao"

    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters identifying the model, used for caching and logging."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "api_url": self.api_url,
        }

    def _format_messages(self, messages: List[BaseMessage]) -> List[Dict[str, str]]:
        """Convert LangChain message objects into the Doubao API wire format.

        Message types other than Human/AI/System are silently dropped.
        """
        formatted: List[Dict[str, str]] = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                formatted.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                formatted.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                formatted.append({"role": "system", "content": msg.content})
        return formatted

    def _call_api(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST *payload* to the Doubao endpoint and return the parsed JSON.

        Raises:
            RuntimeError: on any network/HTTP failure (chained from the
                underlying ``requests`` exception).
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        try:
            response = requests.post(
                self.api_url, headers=headers, json=payload, timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"API调用失败: {str(e)}") from e

    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Process a batch of prompts and generate results (core method).

        Raises:
            ValueError: for unsupported prompt types or an invalid API response.
        """
        generations = []
        llm_output: Dict[str, Any] = {"model": self.model, "usage": {}}
        for prompt in prompts:
            # Convert the prompt into a message list.
            if isinstance(prompt, StringPromptValue):
                messages = [HumanMessage(content=prompt.to_string())]
            elif isinstance(prompt, ChatPromptValue):
                messages = prompt.to_messages()
            else:
                raise ValueError(f"不支持的Prompt类型: {type(prompt)}")
            # Prepend tool descriptions when tools are bound.
            if self.tools:
                tool_desc = self._format_tools_for_prompt()
                system_msg = SystemMessage(content=f"可用工具:\n{tool_desc}")
                messages.insert(0, system_msg)
            # Call the API.
            payload = {
                "model": self.model,
                "messages": self._format_messages(messages),
                "temperature": self.temperature,
                "stop": stop or [],
            }
            payload.update(kwargs)
            response = self._call_api(payload)
            # Parse the response.
            if "choices" in response and response["choices"]:
                content = response["choices"][0]["message"]["content"]
                generations.append([Generation(text=content)])
                # Only the usage of the last prompt is retained.
                llm_output["usage"] = response.get("usage", {})
            else:
                raise ValueError("API返回无效响应")
        return LLMResult(generations=generations, llm_output=llm_output)

    async def agenerate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Async batch processing (simplified: blocks the event loop;
        production code should use aiohttp)."""
        return self.generate_prompt(prompts, stop, callbacks, **kwargs)

    def predict(
        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> str:
        """Process a plain-text input and return the model's text response."""
        prompt = StringPromptValue(text=text)
        result = self.generate_prompt(
            [prompt], stop=list(stop) if stop else None, **kwargs
        )
        return result.generations[0][0].text

    def predict_messages(
        self,
        messages: List[BaseMessage],
        *,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """Process a message list and return the result as an AIMessage."""
        prompt = ChatPromptValue(messages=messages)
        result = self.generate_prompt(
            [prompt], stop=list(stop) if stop else None, **kwargs
        )
        return AIMessage(content=result.generations[0][0].text)

    async def apredict(
        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> str:
        # Simplified async wrapper: delegates to the synchronous predict.
        return self.predict(text, stop=stop, **kwargs)

    async def apredict_messages(
        self,
        messages: List[BaseMessage],
        *,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        # Simplified async wrapper: delegates to the synchronous predict_messages.
        return self.predict_messages(messages, stop=stop, **kwargs)

    def invoke(
        self,
        input: Union[PromptValue, str, Sequence[BaseMessage]],
        config: Optional[Mapping[str, Any]] = None,  # per the Runnable interface
        **kwargs: Any,
    ) -> Union[str, BaseMessage]:
        """Runnable-interface entry point handling several input types.

        ``config`` is currently ignored; add handling here if needed.
        """
        if isinstance(input, str):
            # Plain-string input.
            return self.predict(input, **kwargs)
        elif isinstance(input, PromptValue):
            # PromptValue input.
            result = self.generate_prompt([input], **kwargs)
            return result.generations[0][0].text
        elif isinstance(input, Sequence) and all(
            isinstance(x, BaseMessage) for x in input
        ):
            # Message-sequence input.
            return self.predict_messages(list(input), **kwargs)
        else:
            raise ValueError(f"不支持的输入类型: {type(input)}")

    def bind_tools(self, tools: List[BaseTool]) -> "DoubaoLLM":
        """Return a copy of this model instance with *tools* bound."""
        return self.model_copy(update={"tools": tools})

    def _format_tools_for_prompt(self) -> str:
        """Format the bound tools' descriptions as prompt text."""
        tool_descriptions = []
        for tool in self.tools:
            tool_descriptions.append(
                f"- {tool.name}: {tool.description}\n 输入格式: {tool.args}"
            )
        return "\n".join(tool_descriptions)

    def get_num_tokens_from_messages(
        self,
        messages: List[BaseMessage],
        tools: Optional[Sequence] = None,
    ) -> int:
        """Count tokens across *messages* (delegates to ``get_num_tokens``)."""
        return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)


# Usage example
if __name__ == "__main__":
    # Initialize the Doubao LLM.
    doubao_llm = DoubaoLLM()
    # Load the calculator tool.
    tools = load_tools(["llm-math"], llm=doubao_llm)
    # Create a tool-using agent.
    agent = create_react_agent(
        model=doubao_llm.bind_tools(tools),
        tools=tools,
        prompt="你是数学助手,必须使用计算器工具解决数学问题",
    )
    # NOTE(review): this end-to-end test has not been verified to pass.
    try:
        result = agent.invoke(
            {"messages": [HumanMessage(content="3的4次方加5的平方等于多少?")]}
        )
        print("结果:", result["messages"][-1].content)
        print("结果:", result["messages"])
    except Exception as e:
        print("错误:", str(e))