Spring AI alibaba对话上下文持久化数据库
本文为个人学习笔记整理,仅供交流参考,非专业教学资料,内容请自行甄别。
文章目录
- 前言
- 一、对话上下文持久化数据库
- 2.1、add
- 2.2、get
- 2.3、clear
前言
本篇介绍Spring AI alibaba中,将对话上下文记忆持久化到数据库的实现。ChatMemory
是Spring AI提供的接口,定义了聊天对话历史记录的存储的规范。
它的默认实现是InMemoryChatMemory
,即将对话上下文保存到内存中,伴随着服务器的重启,记录会丢失。
在ChatMemory
中定义的三个方法:
- add:将一组消息添加到指定对话的记忆中。
conversationId
参数是对话的唯一标识,用于区分不同用户或不同会话的上下文。List<Message>
代表了要添加的消息列表。Message有三种不同的类型,分别是UserMessage(用户角色)
,AssistantMessage(助手角色)
和SystemMessage(系统角色)
实际场景中,当用户发送新消息或 AI 生成响应后,会通过此方法将这些消息存入记忆,确保后续对话能参考历史内容。 - get:从指定对话的记忆中获取最近的lastN条消息。因为在调用chatClient时,通常需要指定保留的上下文记忆条目。每次调用大模型都是需要消耗token的,如果保留的记忆条数过多,成本也会增加。实际实现中,可能会对消息数量做限制(避免上下文过长),lastN参数就是为了灵活控制上下文长度。
- clear:清空指定对话的所有记忆消息。
一、对话上下文持久化数据库
如果需要将对话上下文持久化数据库,需要自己写一个类,实现ChatMemory
接口。在进行存储之前,首先需要考虑到表结构的设计,首先需要记录contextId ,作为会话的标识,一次对话中的contextId都是相同的。还需要记录message
,即消息的正文。这里采用了blob的格式,因为存入数据库的消息正文需要序列化。
create table if not exists `ai-agent`.context_memory_record
(id bigint not null comment '主键Id'primary key,contextId varchar(255) null comment '上下文Id',message blob null comment '消息',createTime datetime default CURRENT_TIMESTAMP not null comment '创建时间'
)comment '上下文历史消息记录表';
而序列化的选择,采用kryo
。不同的Message的实现,其属性也是不同的:
如果采用JSON进行反序列化,那么每次还需要获取到消息的类型,并且进行分支判断处理,无法统一进行处理。所以采用kryo
的方案,Kryo 是带类型的二进制格式,序列化时会自动将对象的类信息(如类名、类型标识)写入二进制数据中,反序列化时,Kryo 可以根据二进制数据中包含的类型信息,直接还原出原始对象(可能是 A 或 B),无需显式指定目标类型。
需要引入依赖(还需要引入mybatis-plus 和 mysql的依赖):
<dependency><groupId>com.esotericsoftware</groupId><artifactId>kryo</artifactId><version>5.6.2</version></dependency>
在自定义的类中,进行初始化操作:
编写序列化和反序列化的方法:
2.1、add
重写add方法:
整体思路是首先根据conversationId
从数据库中查询结果,并且反序列化,如果是第一次操作,那么结果为空,最终会执行插入到数据库的操作,后续则是查询出前一次的上下文,然后进行追加,执行根据conversationId
更新数据库的操作。
2.2、get
重写get方法:
这里参数中的lastN,是在调用chatClient
时指定的:
2.3、clear
完整代码:
package org.ragdollcat.secondaiagent.chatmemory;import cn.hutool.core.util.ObjUtil;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.objenesis.strategy.StdInstantiatorStrategy;
import org.ragdollcat.secondaiagent.model.ContextMemoryRecord;
import org.ragdollcat.secondaiagent.service.ContextMemoryRecordService;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.stereotype.Component;import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.List;/*** 基于数据库的对话记忆持久化*/
@Slf4j
@Component
public class DbBasedChatMemory implements ChatMemory {@Resourceprivate ContextMemoryRecordService contextMemoryRecordService;private static final Kryo kryo = new Kryo();static {kryo.setRegistrationRequired(false);// 设置实例化策略kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());}@Overridepublic void add(String conversationId, List<Message> messages) {log.info("conversationId:{},messages:{}", conversationId, messages);//根据conversationId从数据库中查询结果,并且反序列化List<Message> conversationMessages = getOrCreateConversation(conversationId);//将本次结果进行追加conversationMessages.addAll(messages);//重新存入数据库saveConversation(conversationId, conversationMessages);}@Overridepublic List<Message> get(String conversationId, int lastN) {List<Message> allMessages = getOrCreateConversation(conversationId);return allMessages.stream().skip(Math.max(0, allMessages.size() - lastN)).toList();}@Overridepublic void clear(String conversationId) {ContextMemoryRecord contextMemoryRecord = getConversationDB(conversationId);if (ObjUtil.isNotEmpty(contextMemoryRecord)){contextMemoryRecordService.removeById(contextMemoryRecord.getId());}}private List<Message> getOrCreateConversation(String conversationId) {ContextMemoryRecord memoryRecord = getConversationDB(conversationId);List<Message> messages = new ArrayList<>();if (ObjUtil.isNotEmpty(memoryRecord) && memoryRecord.getMessage() != null) {// 从数据库记录中获取字节数组byte[] messageBytes = memoryRecord.getMessage();// 使用Kryo反序列化为List<Message>messages = deserializeMessages(messageBytes);}return messages;}/*** 根据Id查询上下文对象** @param conversationId* @return*/private ContextMemoryRecord getConversationDB(String conversationId) {return contextMemoryRecordService.getOne(new QueryWrapper<>(ContextMemoryRecord.class).eq("contextId", conversationId));}private void saveConversation(String conversationId, List<Message> conversationMessages) {ContextMemoryRecord conversationDB = getConversationDB(conversationId);//新增if (ObjUtil.isEmpty(conversationDB)){ContextMemoryRecord memoryRecord = new ContextMemoryRecord();memoryRecord.setContextId(conversationId);//将消息重新序列化memoryRecord.setMessage(serializeMessages(conversationMessages));contextMemoryRecordService.save(memoryRecord);}//更新else {ContextMemoryRecord memoryRecord = new ContextMemoryRecord();memoryRecord.setContextId(conversationId);//将消息重新序列化memoryRecord.setMessage(serializeMessages(conversationMessages));//根据contextId字段更新contextMemoryRecordService.update(memoryRecord,new QueryWrapper<>(ContextMemoryRecord.class).eq("contextId",conversationId));}}/*** 反序列化字节数组为List<Message>*/private List<Message> deserializeMessages(byte[] data) {try (Input input = new Input(new ByteArrayInputStream(data))) {// 反序列化为List<Message>return kryo.readObject(input, ArrayList.class);} catch (Exception e) {// 处理反序列化异常,例如日志记录和返回空列表log.error("Failed to deserialize messages", e);return new ArrayList<>();}}// 对应的序列化方法(当你需要保存消息到数据库时使用)private byte[] serializeMessages(List<Message> messages) {if (messages == null || messages.isEmpty()) {return null;}try (ByteArrayOutputStream baos = new ByteArrayOutputStream();Output output = new Output(baos)) {kryo.writeObject(output, messages);output.flush();return baos.toByteArray();} catch (Exception e) {log.error("Failed to serialize messages", e);return null;}}}
用户第一次提问,conversationId下的记录为空,就执行序列化然后新增的操作。
助手进行回答,查询到第一次相同conversationId的记录,进行追加,然后根据conversationId更新数据库。