当前位置: 首页 > news >正文

langchain的简单应用案例---(2)使用Memory实现一个带记忆的对话机器人

这是一个比较旧的版本 用来练手

from langchain.prompts import ChatPromptTemplate
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParserfrom DoubaoLLM import DoubaoLLM  # 注意导入路径llm = DoubaoLLM()# 定义提示模板
prompt = ChatPromptTemplate.from_messages([("system", "你是一个 helpful 的助手。请根据对话历史回答问题。"),("placeholder", "{chat_history}"),  # 用于填充对话历史("human", "{input}"),  # 用户当前输入]
)# 初始化对话记忆
history = ChatMessageHistory()
memory = ConversationBufferMemory(memory_key="chat_history",return_messages=True,chat_memory=history,  # 显式指定聊天记忆存储
)# 构建对话链(替代 LLMChain)
def load_memory(_):return memory.load_memory_variables({})["chat_history"]chain = (RunnablePassthrough.assign(chat_history=load_memory)| prompt| llm| StrOutputParser()
)# 对话交互函数
def chat_with_bot(user_input):# 获取模型响应response = chain.invoke({"input": user_input})# 更新记忆memory.save_context({"input": user_input}, {"output": response})return response# 测试对话
if __name__ == "__main__":print(chat_with_bot("你好,我叫小明"))print(chat_with_bot("我刚才告诉你我的名字了吗?"))  # 应该能记住名字print(chat_with_bot("我叫什么名字?"))  # 验证记忆功能

继承BaseLanguageModel是为了后续更好使用chain链

import requests
from typing import Any, List, Mapping, Optional, Sequence, Dict, Union
from pydantic import Field, ConfigDict
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import (BaseMessage,HumanMessage,AIMessage,SystemMessage,get_buffer_string,AnyMessage,
)
from langchain_core.outputs import LLMResult, Generation
from langchain_core.prompt_values import PromptValue, StringPromptValue, ChatPromptValue
from langchain_core.callbacks import Callbacks
from langchain_core.tools import BaseTool
from langchain_community.agent_toolkits.load_tools import load_tools
from langgraph.prebuilt import create_react_agentAPI_KEY = "***************"
API_NAME = "***************"
API_URL = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
MODEL = "doubao-seed-1-6-flash-250715"class DoubaoLLM(BaseLanguageModel):# 遵循 BaseLanguageModel 抽象类规范,可直接用于集成豆包 API 与 LangChain/LangGraph 生态api_key: str = API_KEYmodel: str = MODELtemperature: float = Field(0.0, description="温度参数,控制输出随机性")timeout: int = Field(30, description="API调用超时时间")api_url: str = API_URL# 工具相关配置tools: List[BaseTool] = Field(default_factory=list, description="绑定的工具列表")model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")def __init__(self, **data: Any):super().__init__(**data)if not self.api_key:raise ValueError("必须提供有效的豆包API密钥")def _llm_type(self) -> str:return "doubao"def _identifying_params(self) -> Mapping[str, Any]:"""用于标识模型的参数,用于缓存和日志"""return {"model": self.model,"temperature": self.temperature,"api_url": self.api_url,}def _format_messages(self, messages: List[BaseMessage]) -> List[Dict[str, str]]:"""转换LangChain消息格式为豆包API格式"""formatted = []for msg in messages:if isinstance(msg, HumanMessage):formatted.append({"role": "user", "content": msg.content})elif isinstance(msg, AIMessage):formatted.append({"role": "assistant", "content": msg.content})elif isinstance(msg, SystemMessage):formatted.append({"role": "system", "content": msg.content})return formatteddef _call_api(self, payload: Dict[str, Any]) -> Dict[str, Any]:"""调用豆包API并返回响应"""headers = {"Content-Type": "application/json","Authorization": f"Bearer {self.api_key}",}try:# print(f"API调用promnt:  {payload}")response = requests.post(self.api_url, headers=headers, json=payload, timeout=self.timeout)response.raise_for_status()# print(f"API调用成功: 结果: {response.text}")return response.json()except requests.exceptions.RequestException as e:raise RuntimeError(f"API调用失败: {str(e)}") from edef generate_prompt(self,prompts: List[PromptValue],stop: Optional[List[str]] = None,callbacks: Callbacks = None,**kwargs: Any,) -> LLMResult:"""批量处理提示并生成结果(核心方法)"""generations = []llm_output = {"model": self.model, "usage": {}}for prompt in prompts:# 转换提示为消息列表if isinstance(prompt, StringPromptValue):messages = [HumanMessage(content=prompt.to_string())]elif isinstance(prompt, ChatPromptValue):messages = prompt.to_messages()else:raise ValueError(f"不支持的Prompt类型: {type(prompt)}")# 添加工具说明(如果有绑定工具)if self.tools:tool_desc = self._format_tools_for_prompt()system_msg = SystemMessage(content=f"可用工具:\n{tool_desc}")messages.insert(0, system_msg)# 调用APIpayload = {"model": self.model,"messages": self._format_messages(messages),"temperature": self.temperature,"stop": stop or [],}payload.update(kwargs)response = self._call_api(payload)# 解析响应if "choices" in response and response["choices"]:content = response["choices"][0]["message"]["content"]generations.append([Generation(text=content)])llm_output["usage"] = response.get("usage", {})else:raise ValueError("API返回无效响应")return LLMResult(generations=generations, llm_output=llm_output)async def agenerate_prompt(self,prompts: List[PromptValue],stop: Optional[List[str]] = None,callbacks: Callbacks = None,**kwargs: Any,) -> LLMResult:"""异步批量处理提示(简化实现,生产环境建议用aiohttp)"""return self.generate_prompt(prompts, stop, callbacks, **kwargs)def predict(self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any) -> str:"""处理文本输入并返回结果"""prompt = StringPromptValue(text=text)result = self.generate_prompt([prompt], stop=list(stop) if stop else None, **kwargs)return result.generations[0][0].textdef predict_messages(self,messages: List[BaseMessage],*,stop: Optional[Sequence[str]] = None,**kwargs: Any,) -> BaseMessage:"""处理消息列表并返回结果"""prompt = ChatPromptValue(messages=messages)result = self.generate_prompt([prompt], stop=list(stop) if stop else None, **kwargs)return AIMessage(content=result.generations[0][0].text)async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any) -> str:return self.predict(text, stop=stop, **kwargs)async def apredict_messages(self,messages: List[BaseMessage],*,stop: Optional[Sequence[str]] = None,**kwargs: Any,) -> BaseMessage:return self.predict_messages(messages, stop=stop, **kwargs)# 修复:调整invoke方法参数,符合Runnable接口规范def invoke(self,input: Union[PromptValue, str, Sequence[BaseMessage]],config: Optional[Mapping[str, Any]] = None,  # 新增config参数**kwargs: Any,) -> Union[str, BaseMessage]:"""实现Runnable接口的invoke方法,处理各种输入类型"""# 忽略config参数(如需使用可在此处添加逻辑)if isinstance(input, str):# 处理字符串输入return self.predict(input, **kwargs)elif isinstance(input, PromptValue):# 处理PromptValue输入result = self.generate_prompt([input], **kwargs)return result.generations[0][0].textelif isinstance(input, Sequence) and all(isinstance(x, BaseMessage) for x in input):# 处理消息序列输入return self.predict_messages(list(input), **kwargs)else:raise ValueError(f"不支持的输入类型: {type(input)}")# 工具绑定相关方法def bind_tools(self, tools: List[BaseTool]) -> "DoubaoLLM":"""绑定工具到模型实例"""return self.model_copy(update={"tools": tools})def _format_tools_for_prompt(self) -> str:"""格式化工具描述为提示文本"""tool_descriptions = []for tool in self.tools:tool_descriptions.append(f"- {tool.name}: {tool.description}\n  输入格式: {tool.args}")return "\n".join(tool_descriptions)# 令牌计算方法def get_num_tokens_from_messages(self,messages: List[BaseMessage],tools: Optional[Sequence] = None,) -> int:"""计算消息的令牌数量(基于GPT-2令牌器)"""return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)# 使用示例
if __name__ == "__main__":# 初始化豆包LLMdoubao_llm = DoubaoLLM()# 加载计算器工具tools = load_tools(["llm-math"], llm=doubao_llm)# 创建带工具的智能体agent = create_react_agent(model=doubao_llm.bind_tools(tools),tools=tools,prompt="你是数学助手,必须使用计算器工具解决数学问题",)#  这个测试并没有跑通# 执行测试try:result = agent.invoke({"messages": [HumanMessage(content="3的4次方加5的平方等于多少?")]})print("结果:", result["messages"][-1].content)print("结果:", result["messages"])except Exception as e:print("错误:", str(e))
http://www.dtcms.com/a/350391.html

相关文章:

  • 工作记录 2015-10-29
  • 销售额和营业收入的区别在哪?哪个值应该更大一些?
  • 新项目,如何做成本估算?
  • 本地缓存与 Redis 缓存的区别与实际应用
  • 【OpenAI】ChatGPT-4o-latest 真正的多模态、长文本模型的详细介绍+API的使用教程!
  • 2025软件测试面试题(持续更新)
  • 07-JUnit测试
  • ubuntu 卡到登录页面进不去--实测
  • 陪护系统有哪些功能?
  • 高并发内存池(4)-TLS:Thread Local Storage
  • Vue.nextTick讲解
  • kubectl 客户端访问 Kubernetes API Server 不通的原因排查与解决办法
  • 800G时代!全场景光模块矩阵解锁数据中心超高速未来
  • AR眼镜赋能矿业冶金数字化转型
  • Wireshark笔记-DHCP流程与数据包解析
  • Linux驱动开发笔记(七)——并发与竞争(上)——原子操作
  • SQLite 全面指南与常用操作
  • 没有AI背景的团队如何快速进行AI开发
  • expdp导出dmp到本地
  • docker 安装配置 redis
  • PDF处理控件Spire.PDF系列教程:在 C# 中实现 PDF 与字节数组的互转
  • 2025年06月 Python(二级)真题解析#中国电子学会#全国青少年软件编程等级考试
  • synchronized关键字的底层原理
  • 蘑兔音乐:创作好搭子
  • 嵌入式C语言进阶:深入理解sizeof操作符的精妙用法
  • 隧道监测实训模型
  • 讲解 JavaScript 中的深拷贝和浅拷贝
  • PyPI 是什么?
  • CCleaner中文版:强大的系统优化与隐私保护工具,支持清理磁盘、注册表和卸载软件
  • `mysql_query()` 数据库查询函数