【Spring AI 1.0.0】Spring AI 1.0.0框架快速入门(4)——Chat Memory(聊天记录)
Spring AI框架快速入门
- Spring AI 系列文章:
- 一、前言
- 二、源码解析
- 三、自定义MYSQL持久化ChatMemoryRepository
- 3.1 自定义结构化数据库
- 3.2 序列化方式
- 四、持久化到Redis
- 五、总结
Spring AI 系列文章:
【Spring AI 1.0.0】Spring AI 1.0.0框架快速入门(1)——Chat Client API
【Spring AI 1.0.0】Spring AI 1.0.0框架快速入门(2)——Prompt(提示词)
【Spring AI 1.0.0】Spring AI 1.0.0框架快速入门(3)——Structured Output Converter(结构化输出转换器)
【Spring AI 1.0.0】Spring AI 1.0.0框架快速入门(4)——Chat Memory(聊天记录)
一、前言
前一篇文章《【Spring AI 1.0.0】Spring AI 1.0.0框架快速入门(3)——Structured Output Converter(结构化输出转换器)》》中,介绍了Structured Output Converter的基本用法,这篇文章说说什么是聊天记录。
在大型语言模型(LLMs)的架构设计中,其无状态特性导致模型无法保留历史交互信息,这在需要持续维护对话上下文的应用场景中构成了显著限制。针对这一技术瓶颈,Spring AI框架创新性地提出了聊天内存功能,通过实现交互信息的持久化存储与动态检索机制,有效解决了多轮对话中的状态保持问题。
二、源码解析
ChatMemory
和ChatMemoryRepository
是对话记忆的核心接口
ChatMemory
用于短期对话状态管理,它是基于内存的实时操作,有一个实现类MessageWindowChatMemory
即活动窗口机制,用于保留最近的N条信息。
ChatMemoryRepository
接口用于长期对话管理,即持久化能力,默认是InMemoryChatMemoryRepository
内存缓存,也可以自定方法实现ChatMemoryRepository
来持久化到JDBC、MongoDB 等
InMemoryChatMemoryRepository
的源码,其实就是通过一个ConcurrentHashMap 来维护对话信息,key 是对话 conversationId(相当于房间号),value 是该对话 id 对应的消息列表。
三、自定义MYSQL持久化ChatMemoryRepository
持久化本质是将数据存储到MYSQL中,由于ChatMemoryRepository
返回的消息是List<Message>
类型,Message是一个接口,虽然实现他的接口不多,实现起来还是有一定的复杂度,最主要的问题是消息和文本的转换。保存消息时,需要将Message对象转换成文本对象;读取消息时,需要将文本对象转换成Message对象。也就是序列化和反序列化。
序列化通常采用json,实际并不容易:
- 要持久化的 Message 是一个接口,有很多种不同的子类实现(比如 UserMessage、SystemMessage 等)
- 每种子类所拥有的字段都不一样,结构不统一
- 子类没有无参构造函数,而且没有实现 Serializable 序列化接口
在这里有两个方案:
- 方案一: 自己结构化数据库,使用结构化的数据库然后自己手动创建Message的实现对象来序列化.
- 方案二: 使用序列化库来实现,这里分别尝试了jackson和Kryo序列化库,最终选择了Kryo序列化库,其可以动态注册,减少代码量.
3.1 自定义结构化数据库
maven:
<dependency><groupId>org.springframework.ai</groupId><artifactId>spring-ai-starter-model-deepseek</artifactId>
</dependency><dependency><groupId>org.springframework.ai</groupId><artifactId>spring-ai-starter-model-zhipuai</artifactId>
</dependency>
<dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId><scope>provided</scope>
</dependency><!-- JdbcChatMemoryRepository 是一个内置实现,使用 JDBC 在关系数据库中存储消息。
它开箱即用地支持多个数据库,适合需要持久存储聊天内存的应用程序。-->
<dependency><groupId>org.springframework.ai</groupId><artifactId>spring-ai-starter-model-chat-memory-repository-jdbc</artifactId>
</dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-webflux</artifactId>
</dependency><!-- MySQL 驱动 -->
<dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId><version>8.0.32</version>
</dependency><!-- https://mvnrepository.com/artifact/com.baomidou/mybatis-plus-boot-starter -->
<dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-spring-boot3-starter</artifactId><version>3.5.12</version>
</dependency><!-- 3.5.9及以上版本想使用mybatis plus分页配置需要单独引入-->
<dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-jsqlparser</artifactId><version>3.5.12</version> <!-- 确保版本和 MyBatis Plus 主包一致 -->
</dependency>
SQL
CREATE TABLE ai_chat_memory (id BIGINT AUTO_INCREMENT PRIMARY KEY,conversation_id VARCHAR(255) NOT NULL comment '会话id',type VARCHAR(20) NOT NULL comment '消息类型',content TEXT NOT NULL comment '消息内容',create_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间',update_time TIMESTAMP default CURRENT_TIMESTAMP not null on update CURRENT_TIMESTAMP comment '更新时间',is_delete tinyint default 0 not null comment '是否删除',INDEX idx_conv (conversation_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
配置:
#https://doc.spring4all.com/spring-ai/reference/api/chat/deepseek-chat.html# deepseek ai
spring.ai.chat.client.enabled=false
spring.ai.deepseek.api-key=your api key
#spring.ai.deepseek.base-url=https://api.deepseek.com
spring.ai.deepseek.chat.options.model=deepseek-chat
#spring.ai.deepseek.chat.options.model=deepseek-reasoner
spring.ai.deepseek.chat.options.temperature=0.8# zhipu ai
spring.ai.zhipuai.api-key=49dadd9c9d504acbb60580f6d53cf30b.vlX0Fp67MTwxdZ5i
spring.ai.zhipuai.base-url=https://open.bigmodel.cn/api/paas
#spring.ai.zhipuai.image.options.model=cogview-3
#spring.ai.zhipuai.image.options.model=embedding-2
spring.ai.zhipuai.chat.options.model=glm-4v-flashlogging.level.org.springframework.ai.chat.client.advisor=DEBUG# mysql
spring.datasource.url=jdbc:mysql://localhost:3307/ai?useUnicode=true&characterEncoding=utf-8&useSSL=false&serverTimezone=UTC
spring.datasource.username=root
spring.datasource.password=123456
spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver# mybatis-plus
mybatis-plus.configuration.map-underscore-to-camel-case=false
mybatis-plus.configuration.log-impl=org.apache.ibatis.logging.stdout.StdOutImpl
# 全局逻辑删除的实体字段名
mybatis-plus.global-config.db-config.logic-delete-field=isDelete
# 逻辑已删除值(默认为 1)
mybatis-plus.global-config.db-config.logic-delete-value=1
# 逻辑未删除值(默认为 0)
mybatis-plus.global-config.db-config.logic-not-delete-value=0
实体类:
import TableName(value ="ai_chat_memory")
@Data
public class AiChatMemory implements Serializable {/****/@TableId(type = IdType.AUTO)private Long id;/*** 会话id*/@TableField("conversation_id")private String conversationId;/*** 消息类型*/@TableField("type")private String type;/*** 消息内容*/@TableField("content")private String content;/*** 创建时间*/@TableField("create_time")private Date createTime;/*** 更新时间*/@TableField("update_time")private Date updateTime;/*** 是否删除*/@TableLogic@TableField("is_delete")private Integer isDelete
}
ChatMemoryRepository实现:
@Service
public class MyBatisPlusChatMemoryRepository implements ChatMemoryRepository {@Resourceprivate AiChatMemoryMapper mapper;@Overridepublic List<String> findConversationIds() {// 查询所有会话的IDLambdaQueryWrapper<AiChatMemory> lqw = new LambdaQueryWrapper<>();return mapper.selectList(lqw).stream().map(AiChatMemory::getConversationId).distinct().collect(Collectors.toList());}@Overridepublic List<Message> findByConversationId(String conversationId) {LambdaQueryWrapper<AiChatMemory> lqw = new LambdaQueryWrapper<>();lqw.eq(AiChatMemory::getConversationId, conversationId);lqw.orderByDesc(AiChatMemory::getCreateTime);List<AiChatMemory> aiChatMemories = mapper.selectList(lqw);List<Message> messages = new ArrayList<>();for (AiChatMemory aiChatMemory : aiChatMemories) {String type = aiChatMemory.getType();switch (type) {case "user" -> messages.add(new UserMessage(aiChatMemory.getContent()));case "assistant" -> messages.add(new AssistantMessage(aiChatMemory.getContent()));case "system" -> messages.add(new SystemMessage(aiChatMemory.getContent()));default -> throw new IllegalArgumentException("Unknown message type: " + type);}}return messages;}@Overridepublic void saveAll(String conversationId, List<Message> messages) {List<AiChatMemory> list = new ArrayList<>();messages.stream().forEach(message -> {AiChatMemory aiChatMemory = new AiChatMemory();aiChatMemory.setConversationId(conversationId);aiChatMemory.setType(message.getMessageType().getValue());aiChatMemory.setContent(message.getText());list.add(aiChatMemory);});mapper.insertBatch(list);}@Overridepublic void deleteByConversationId(String conversationId) {// 删除指定会话的所有消息LambdaQueryWrapper<AiChatMemory> wrapper = new LambdaQueryWrapper<>();wrapper.eq(AiChatMemory::getConversationId, conversationId);mapper.delete(wrapper);}public List<AiChatMemory> findAiChatMemoryList(String conversationId) {LambdaQueryWrapper<AiChatMemory> lqw = new LambdaQueryWrapper<>();lqw.eq(AiChatMemory::getConversationId, conversationId);lqw.orderByDesc(AiChatMemory::getCreateTime);return mapper.selectList(lqw);}
}
核心代码实现:
@Slf4j
@Service
public class ChatService {@Resource(name = "zhiPuAiChatClient")private ChatClient chatClient;@Resourceprivate MyBatisPlusChatMemoryRepository chatMemoryRepository;private final Map<String, ChatMemory> memoryMap = new ConcurrentHashMap<>();public Flux<String> chat(String conversationId, String message) {ChatMemory chatMemory = this.getMemory(conversationId);// 添加用户消息UserMessage userMessage = new UserMessage(message);chatMemory.add(conversationId, userMessage);chatMemoryRepository.saveAll(conversationId, List.of(userMessage));// 构建包含上下文的PromptList<Message> messages = chatMemory.get(conversationId);// 调用ChatClient的流式接口Flux<String> responseStream = chatClient.prompt(new Prompt(messages)).stream().content();// 使用StringBuilder累积完整响应StringBuilder fullResponse = new StringBuilder();return responseStream// 每收到一个流片段就追加到StringBuilder.doOnNext(chunk -> fullResponse.append(chunk))// 当流完成时,异步保存完整响应.doOnComplete(() -> {// 使用异步线程执行保存操作,避免阻塞流处理Mono.fromRunnable(() -> {String output = fullResponse.toString();log.warn("AI Response: {}", output);AssistantMessage assistantMessage = new AssistantMessage(output);chatMemory.add(conversationId, assistantMessage);chatMemoryRepository.saveAll(conversationId, List.of(assistantMessage));}).subscribeOn(Schedulers.boundedElastic()).subscribe(); // 订阅以触发异步执行});}public List<AiChatMemory> getHistory(String conversationId) {return chatMemoryRepository.findAiChatMemoryList(conversationId);}private ChatMemory getMemory(String conversationId) {return memoryMap.computeIfAbsent(conversationId, id -> MessageWindowChatMemory.builder().build());}
}
上述代码中使用ConcurrentHashMap()
来管理当前的会话记录,key为conversationId,value为ChatMemory,将当前消息和历史消息整合成Prompt
,然后通过chatClient.prompt()
来调用大模型,使用stream()
,来流式返回。最后通过StringBuilder fullResponse
来异步接收完整的答案,最后保存到数据库中。
接口:
@RestController
@RequiredArgsConstructor
@RequestMapping("/api/chat")
public class ChatMemoryController {@Resourceprivate ChatService chatService;@GetMapping("/history")public List<AiChatMemory> history(@RequestParam String conversationId) {return chatService.getHistory(conversationId);}@GetMapping(value = "/testMysqlChatMemory", produces = "text/html;charset=UTF-8")Flux<String> testMysqlChatMemory(@RequestParam String conversationId, @RequestParam String message) {return chatService.chat(conversationId, message);}
}
执行结果:
在浏览器中执行:
http://localhost:8080/api/chat/testMysqlChatMemory?conversationId=test-sessonId-001&message=分章节写一本小说,先写第一章的内容,以穿越之我为女帝稳江山为题目
可以看到应答是流式出来的
表中有了这一条数据
然后接着执行:
http://localhost:8080/api/chat/testMysqlChatMemory?conversationId=test-sessonId-001&message=继续写第二章
数据库表中新增数据:
3.2 序列化方式
先创建数据库表:
DROP TABLE IF EXISTS logger;
create table logger
(id varchar(255) not null,userId bigint not null,message text not null,time datetime default CURRENT_TIMESTAMP not null
);DROP TABLE IF EXISTS request;
create table request
(id varchar(255) not null,userId bigint not null,name varchar(255) not null
);#会话DROP TABLE IF EXISTS user;
create table user
(id bigint not nullprimary key,name varchar(255) not null,status tinyint not null comment '用户身份
0 - 无ai权限
1 - 有ai权限'
);
引入相关依赖:
<!-- 自定义持久化的序列化库-->
<dependency><groupId>com.esotericsoftware</groupId><artifactId>kryo</artifactId><version>5.6.2</version>
</dependency>
创建序列化工具类:
@Component
public class MessageSerializer {// ⚠️ 静态 Kryo 实例(线程不安全,建议改用局部实例)private static final Kryo kryo = new Kryo();static {kryo.setRegistrationRequired(false);// 设置实例化策略(需确保兼容所有 Message 实现类)kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());}/*** 使用 Kryo 将 Message 序列化为 Base64 字符串*/public static String serialize(Message message) {try (ByteArrayOutputStream baos = new ByteArrayOutputStream();Output output = new Output(baos)) {kryo.writeClassAndObject(output, message); // ⚠️ 依赖动态注册和实例化策略output.flush();return Base64.getEncoder().encodeToString(baos.toByteArray());} catch (IOException e) {throw new RuntimeException("序列化失败", e);}}/*** 使用 Kryo 将 Base64 字符串反序列化为 Message 对象*/public static Message deserialize(String base64) {try (ByteArrayInputStream bais = new ByteArrayInputStream(Base64.getDecoder().decode(base64));Input input = new Input(bais)) {return (Message) kryo.readClassAndObject(input); // ⚠️ 依赖动态注册和实例化策略} catch (IOException e) {throw new RuntimeException("反序列化失败", e);}}
}
创建自定义的数据库持久化ChatMemoryRepository:
@Service
@Slf4j
public class SerializerMethodChatMemory implements ChatMemoryRepository {@Resourceprivate LoggerMapper loggerMapper;@Overridepublic List<String> findConversationIds() {QueryWrapper<Logger> wrapper = new QueryWrapper<>();List<Logger> loggerList = loggerMapper.selectList(wrapper);return loggerList.stream().map(Logger::getId).distinct().collect(Collectors.toList());}@Overridepublic List<Message> findByConversationId(String conversationId) {Long userId = parseUserId(conversationId);QueryWrapper<Logger> wrapper = new QueryWrapper<>();wrapper.eq("id", conversationId).eq("userId", userId) // 添加用户 ID 过滤.orderByDesc("time"); // 按时间倒序List<Logger> loggerList = loggerMapper.selectList(wrapper);List<Message> messages = new ArrayList<>();for (Logger logger : loggerList) {messages.add(MessageSerializer.deserialize(logger.getMessage()));}return messages;}/*** 添加多条数据到数据库中** @param conversationId* @param messages*/@Overridepublic void saveAll(String conversationId, List<Message> messages) {Long userId = parseUserId(conversationId);List<Logger> loggerList = new ArrayList<>();for (Message message : messages) {Logger logger = new Logger();logger.setId(conversationId);logger.setUserId(userId);logger.setTime(LocalDateTime.now());logger.setMessage(MessageSerializer.serialize(message));loggerList.add(logger);}loggerMapper.insert(loggerList);}@Overridepublic void deleteByConversationId(String conversationId) {Long userId = parseUserId(conversationId);QueryWrapper<Logger> loggerQueryWrapper = new QueryWrapper<>();loggerQueryWrapper.eq("id", conversationId);loggerMapper.deleteById(loggerQueryWrapper);}// 从 conversationId 解析用户 ID(格式:chat-{userId})private long parseUserId(String conversationId) {String[] parts = conversationId.split("-");if (parts.length == 2 && "chat".equals(parts[0])) {return Long.parseLong(parts[1]);}throw new IllegalArgumentException("无效的 conversationId 格式: " + conversationId);}public List<Logger> findAiChatMemoryList(String conversationId) {LambdaQueryWrapper<Logger> lqw = new LambdaQueryWrapper<>();lqw.eq(Logger::getUserId, parseUserId(conversationId));lqw.orderByDesc(Logger::getTime);return loggerMapper.selectList(lqw);}
}
核心接口和上面类似,我们直接执行来看效果:
浏览器执行:
http://localhost:8080/api/chat/test/serializer?conversationId=chat-000000001&message=解说一下成龙的一部电影
查表可知内容数据被序列化成了特殊字符
调用查询接口
http://localhost:8080/api/chat/history/serializer?conversationId=chat-000000001
表中的数据又被序列化成了Message对象
四、持久化到Redis
上述代码中,我们讲了通过Kryo
来序列化List<Message>
对象,这同样可以适用于redis中
引入redis依赖
<!-- Redis -->
<dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
配置文件:
spring.data.redis.port=6379
spring.data.redis.host=localhost
spring.data.redis.database=0
redis配置类:
@Configuration
public class RedisTemplateConfig {@Beanpublic RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory connectionFactory) {RedisTemplate<String, Object> template = new RedisTemplate<>();template.setConnectionFactory(connectionFactory);template.setKeySerializer(RedisSerializer.string());return template;}
}
实现自定义RedisChatMemory
@Service
@Slf4j
public class RedisChatMemory implements ChatMemoryRepository {@Resourceprivate RedisTemplate<String, Object> redisTemplate;// 用于存储所有对话ID的键private static final String CONVERSATION_IDS_KEY = "ALL_CONVERSATION_IDS";// 用于JSON序列化/反序列化private static final ObjectMapper objectMapper = new ObjectMapper();@Overridepublic List<String> findConversationIds() {// 从Redis的集合中获取所有对话IDSet<Object> members = redisTemplate.opsForSet().members(CONVERSATION_IDS_KEY);if (members == null || members.isEmpty()) {return Collections.emptyList();}return members.stream().map(Object::toString).collect(Collectors.toList());}@Overridepublic List<Message> findByConversationId(String conversationId) {return getFromRedis(conversationId);}@Overridepublic void saveAll(String conversationId, List<Message> messages) {if (messages.isEmpty()) {return;}List<Message> messageList = getFromRedis(conversationId);messageList.addAll(messages);// 保存消息列表setToRedis(conversationId, messageList);// 将对话ID添加到集合中(自动去重)redisTemplate.opsForSet().add(CONVERSATION_IDS_KEY, conversationId);}@Overridepublic void deleteByConversationId(String conversationId) {// 删除对话消息redisTemplate.delete(conversationId);// 从集合中移除对话IDredisTemplate.opsForSet().remove(CONVERSATION_IDS_KEY, conversationId);}/*** 从Redis获取数据工具方法* @param conversationId* @return*/private List<Message> getFromRedis(String conversationId) {Object obj = redisTemplate.opsForValue().get(conversationId);List<Message> messageList = new ArrayList<>();if (obj != null) {try {// 将obj转换为List<String>List<String> messageJsons = objectMapper.convertValue(obj, new TypeReference<List<String>>() {});// 逐个反序列化为Message对象for (String json : messageJsons) {Message message = MessageSerializer.deserialize(json);messageList.add(message);}} catch (IllegalArgumentException e) {log.error("Failed to convert Redis value to List<String>", e);}}return messageList;}/*** 将数据存入Redis工具方法* @param conversationId* @param messages*/private void setToRedis(String conversationId,List<Message> messages){List<String> stringList = new ArrayList<>();for (Message message : messages) {String serialize = MessageSerializer.serialize(message);stringList.add(serialize);}redisTemplate.opsForValue().set(conversationId,stringList);}
}
核心代码还是比较简单,使用CONVERSATION_IDS_KEY
来存所有的会话id,使用Kryo来序列化数据
来看执行结果:
http://localhost:8080/api/chat/test/redis?conversationId=test-sessonId-001&message=讲一下你最喜欢的小说
redis中的数据:
获取结果看看:
又被反序列化成List<Message>
对象
五、总结
这篇文章深入讲解了Spring Ai 1.0.0中Chat Memory(聊天记录)功能,介绍了如何上下文流式交互应答,介绍了通过实现ChatMemoryRepository
自定义持久化数据到MySQL和Redis,下篇文章,介绍Spring Ai 中 工具调用
创作不易,不妨点赞、收藏、关注支持一下,各位的支持就是我创作的最大动力❤️