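Spring AI from Beginner to Mastery (2)
"Without accumulating small steps, one cannot travel a thousand miles."
Spring AI Advanced Features
Chapter 2: Conversation Context Management
In real applications, we need to preserve context across a multi-turn conversation so that the AI can "remember" what was said earlier. Chat models are stateless between requests, so this is done by resending the earlier messages along with each new prompt.
1. Create the conversation manager
Create ConversationMemory.java: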
package com.example.springaiproject.service;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Conversation memory manager.
 */
@Component
public class ConversationMemory {

    // Message history of each session, keyed by session ID
    private final Map<String, List<Message>> conversationHistory = new ConcurrentHashMap<>();

    // Maximum number of messages kept per session
    private static final int MAX_MESSAGES = 10;

    /**
     * Adds a user message.
     * Synchronized because the per-session ArrayList is not thread-safe on its own.
     */
    public synchronized void addUserMessage(String sessionId, String content) {
        List<Message> messages = conversationHistory.computeIfAbsent(sessionId, k -> new ArrayList<>());
        messages.add(new UserMessage(content));
        trimMessages(sessionId);
    }

    /**
     * Adds an AI reply message (ignored if the session does not exist).
     */
    public synchronized void addAssistantMessage(String sessionId, String content) {
        List<Message> messages = conversationHistory.get(sessionId);
        if (messages != null) {
            messages.add(new AssistantMessage(content));
            trimMessages(sessionId);
        }
    }

    /**
     * Returns a copy of the session history, so callers never see later modifications.
     */
    public synchronized List<Message> getHistory(String sessionId) {
        return new ArrayList<>(conversationHistory.getOrDefault(sessionId, List.of()));
    }

    /**
     * Clears the session history.
     */
    public synchronized void clearHistory(String sessionId) {
        conversationHistory.remove(sessionId);
    }

    /**
     * Limits the number of messages so the context does not grow too long.
     */
    private void trimMessages(String sessionId) {
        List<Message> messages = conversationHistory.get(sessionId);
        if (messages != null && messages.size() > MAX_MESSAGES) {
            // Drop the oldest messages in place, keeping the newest MAX_MESSAGES
            messages.subList(0, messages.size() - MAX_MESSAGES).clear();
        }
    }
}
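The trimming behaviour can be tried out without Spring. The following is a minimal sketch, assuming a hypothetical ChatTurn record and SlidingWindowMemory class (neither is part of Spring AI), that keeps only the newest N messages per session and updates each session atomically:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical stand-in for Spring AI's Message type: a role plus text.
record ChatTurn(String role, String text) {}

// Hypothetical sliding-window memory: keeps only the newest maxMessages per session.
class SlidingWindowMemory {
    private final Map<String, List<ChatTurn>> history = new ConcurrentHashMap<>();
    private final int maxMessages;

    SlidingWindowMemory(int maxMessages) {
        if (maxMessages <= 0) {
            throw new IllegalArgumentException("maxMessages must be positive");
        }
        this.maxMessages = maxMessages;
    }

    // Appends a message and drops the oldest entries beyond the window.
    void add(String sessionId, String role, String text) {
        // compute() runs atomically per key, so concurrent adds to one session do not interleave
        history.compute(sessionId, (id, turns) -> {
            List<ChatTurn> list = (turns == null) ? new ArrayList<>() : turns;
            list.add(new ChatTurn(role, text));
            if (list.size() > maxMessages) {
                list.subList(0, list.size() - maxMessages).clear();
            }
            return list;
        });
    }

    // Returns an immutable snapshot of the session history (empty for an unknown session).
    List<ChatTurn> history(String sessionId) {
        List<ChatTurn> snapshot = new ArrayList<>();
        // Copy inside computeIfPresent so the read is atomic with respect to add()
        history.computeIfPresent(sessionId, (id, turns) -> {
            snapshot.addAll(turns);
            return turns;
        });
        return List.copyOf(snapshot);
    }
}
```

With a window of 3, adding five messages m1 to m5 leaves m3, m4, and m5 in the history. A count-based window is the simplest policy; a token-based budget would track the model's context limit more closely.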
2. Update AIService
Modify AIService.java to add context management:
package com.example.springaiproject.service;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.stereotype.Service;

import java.util.List;

@Service
public class AIService {

    private final ChatClient chatClient;
    private final ConversationMemory conversationMemory;

    public AIService(ChatClient.Builder chatClientBuilder, ConversationMemory conversationMemory) {
        this.chatClient = chatClientBuilder.build();
        this.conversationMemory = conversationMemory;
    }

    /**
     * Simple chat without context.
     */
    public String generateText(String prompt) {
        return this.chatClient.prompt()
                .user(prompt)
                .call()
                .content();
    }

    /**
     * Chat with context.
     */
    public String chat(String sessionId, String userMessage) {
        // Add the user message to the history
        conversationMemory.addUserMessage(sessionId, userMessage);

        // Fetch the message history (it now ends with the new user message)
        List<Message> history = conversationMemory.getHistory(sessionId);

        // Send the request, including the historical context
        String response = this.chatClient.prompt()
                .messages(history)
                .call()
                .content();

        // Save the AI reply to the history
        conversationMemory.addAssistantMessage(sessionId, response);
        return response;
    }

    /**
     * Clears the conversation history.
     */
    public void clearHistory(String sessionId) {
        conversationMemory.clearHistory(sessionId);
    }
}
3. Update the Controller
Add the new endpoints to AIController.java:
package com.example.springaiproject.controller;

import com.example.springaiproject.service.AIService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;

import java.util.Map;

@RestController
@RequestMapping("/api/ai")
public class AIController {

    private final AIService aiService;

    @Autowired
    public AIController(AIService aiService) {
        this.aiService = aiService;
    }

    @PostMapping("/generate")
    public String generateText(@RequestBody String prompt) {
        return aiService.generateText(prompt);
    }

    /**
     * Chat endpoint with a session.
     */
    @PostMapping("/chat")
    public Map<String, String> chat(
            @RequestParam(defaultValue = "default") String sessionId,
            @RequestBody Map<String, String> request) {
        String message = request.get("message");
        String response = aiService.chat(sessionId, message);
        return Map.of("response", response, "sessionId", sessionId);
    }

    /**
     * Clears the session history.
     */
    @DeleteMapping("/chat/{sessionId}")
    public Map<String, String> clearHistory(@PathVariable String sessionId) {
        aiService.clearHistory(sessionId);
        return Map.of("message", "History cleared for session: " + sessionId);
    }
}
4. Test the conversation context
# First turn
curl -X POST "http://localhost:8080/api/ai/chat?sessionId=user123" \
  -H "Content-Type: application/json" \
  -d '{"message": "My name is Zhang San"}'

# Second turn (the AI should remember your name)
curl -X POST "http://localhost:8080/api/ai/chat?sessionId=user123" \
  -H "Content-Type: application/json" \
  -d '{"message": "Do you still remember my name?"}'

# Clear the history
curl -X DELETE "http://localhost:8080/api/ai/chat/user123"
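Chapter 3: Retrieval-Augmented Generation (RAG)
RAG lets the AI retrieve relevant information from an external knowledge base and use it to generate more accurate answers.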
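To make the retrieval step concrete, here is a minimal sketch of what a vector store does conceptually: rank the stored embeddings by cosine similarity to the query embedding and return the top K. The names EmbeddedChunk and TinyVectorSearch and the hand-written 3-dimensional vectors are assumptions for illustration only; in a real setup the embeddings come from an embedding model and the search is done by a VectorStore implementation.

```java
import java.util.Comparator;
import java.util.List;

// Hypothetical in-memory entry: a text chunk plus its embedding vector.
record EmbeddedChunk(String text, double[] embedding) {}

class TinyVectorSearch {

    // Cosine similarity: dot(a, b) / (|a| * |b|); returns 0 if either vector has zero length.
    static double cosine(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0 || normB == 0) {
            return 0;
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Returns the topK chunks most similar to the query, best match first.
    static List<EmbeddedChunk> topK(List<EmbeddedChunk> chunks, double[] query, int topK) {
        return chunks.stream()
                .sorted(Comparator.comparingDouble(
                        (EmbeddedChunk c) -> cosine(c.embedding(), query)).reversed())
                .limit(topK)
                .toList();
    }
}
```

This linear scan is fine for a handful of documents; dedicated vector databases use approximate indexes so that search stays fast as the collection grows.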
1. Add dependencies
Add the vector store and document processing dependencies to pom.xml:
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-pdf-document-reader</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
</dependency>
<!-- Or, for a simple in-memory setup, a local Transformers (ONNX) embedding model -->
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-transformers-spring-boot-starter</artifactId>
</dependency>
2. Configure the vector store
Update application.yml:
spring:
  ai:
    vectorstore:
      simple:
        enabled: true
    embedding:
      transformer:
        enabled: true
3. Create the document loader
Create DocumentLoader.java:
package com.example.springaiproject.service;

import org.springframework.ai.document.Document;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.Map;

/**
 * Document loading and vectorization service.
 */
@Service
public class DocumentLoader {

    private final VectorStore vectorStore;

    public DocumentLoader(VectorStore vectorStore) {
        this.vectorStore = vectorStore;
    }

    /**
     * Loads a document into the vector store.
     */
    public void loadDocument(Resource resource) {
        // Read the document
        TextReader textReader = new TextReader(resource);
        List<Document> documents = textReader.get();

        // Split the document into small chunks
        TokenTextSplitter textSplitter = new TokenTextSplitter();
        List<Document> splitDocuments = textSplitter.apply(documents);

        // Store the chunks in the vector database
        vectorStore.add(splitDocuments);
    }

    /**
     * Loads from plain text.
     */
    public void loadText(String content, Map<String, Object> metadata) {
        Document document = new Document(content, metadata);
        TokenTextSplitter textSplitter = new TokenTextSplitter();
        List<Document> splitDocuments = textSplitter.apply(List.of(document));
        vectorStore.add(splitDocuments);
    }
}
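Splitting keeps each stored chunk small enough to embed and to retrieve precisely. As a simplified illustration of the idea, the sketch below splits by character count with an overlap between neighbouring chunks. This is an assumption made for demonstration: TokenTextSplitter works on tokens rather than characters, and the SimpleChunker class and its parameters are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical character-based splitter, used only to illustrate chunking with overlap.
class SimpleChunker {

    // Splits text into windows of at most chunkSize characters;
    // consecutive windows share `overlap` characters.
    static List<String> split(String text, int chunkSize, int overlap) {
        if (chunkSize <= 0 || overlap < 0 || overlap >= chunkSize) {
            throw new IllegalArgumentException("require chunkSize > overlap >= 0");
        }
        List<String> chunks = new ArrayList<>();
        int step = chunkSize - overlap;
        for (int start = 0; start < text.length(); start += step) {
            int end = Math.min(start + chunkSize, text.length());
            chunks.add(text.substring(start, end));
            if (end == text.length()) {
                break; // the last window reached the end of the text
            }
        }
        return chunks;
    }
}
```

For example, splitting "abcdefghij" with a chunk size of 4 and an overlap of 1 yields "abcd", "defg", and "ghij". The overlap means a sentence cut at a chunk boundary still appears whole in at least one chunk, as long as it is shorter than the overlap.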
4. Create the RAG service
Create RAGService.java:
package com.example.springaiproject.service;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

/**
 * RAG (retrieval-augmented generation) service.
 */
@Service
public class RAGService {

    private final ChatClient chatClient;
    private final VectorStore vectorStore;

    public RAGService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
        this.vectorStore = vectorStore;
        this.chatClient = chatClientBuilder
                .defaultAdvisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()))
                .build();
    }

    /**
     * Answers a question based on the knowledge base.
     */
    public String askWithContext(String question) {
        return this.chatClient.prompt()
                .user(question)
                .call()
                .content();
    }

    /**
     * Retrieves relevant documents manually, then answers.
     * Note: this reuses the advisor-equipped client, so the advisor's retrieval still runs as well.
     */
    public String askWithManualRetrieval(String question, int topK) {
        // Retrieve relevant documents
        var similarDocuments = vectorStore.similaritySearch(
                SearchRequest.query(question).withTopK(topK));

        // Build the context
        StringBuilder context = new StringBuilder("Answer the question using the following information:\n\n");
        similarDocuments.forEach(doc -> {
            context.append(doc.getContent()).append("\n\n");
        });
        context.append("Question: ").append(question);

        // Generate the answer
        return this.chatClient.prompt()
                .user(context.toString())
                .call()
                .content();
    }
}
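The context-building step in askWithManualRetrieval is plain string assembly, so it can be tested on its own. Below is a minimal sketch under stated assumptions: the RagPromptBuilder class, the passage numbering, and the character budget are hypothetical additions that are not in the service above, and the passages are assumed to arrive in relevance order.

```java
import java.util.List;

// Hypothetical helper: assembles a RAG prompt from retrieved passages.
class RagPromptBuilder {

    private static final String HEADER = "Answer the question using the following information:\n\n";

    // Builds one user prompt: instruction, numbered passages, then the question.
    // Stops at the first passage that would exceed maxContextChars, and omits it and all later ones.
    static String build(List<String> passages, String question, int maxContextChars) {
        StringBuilder prompt = new StringBuilder(HEADER);
        int used = 0;
        int index = 1;
        for (String passage : passages) {
            if (used + passage.length() > maxContextChars) {
                break; // budget exhausted: this passage and every later one are left out
            }
            prompt.append('[').append(index++).append("] ").append(passage).append("\n\n");
            used += passage.length();
        }
        prompt.append("Question: ").append(question);
        return prompt.toString();
    }
}
```

Capping the context this way keeps a large topK from producing an oversized prompt; a production version would more likely count tokens than characters.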
5. Add the RAG Controller
package com.example.springaiproject.controller;

import com.example.springaiproject.service.DocumentLoader;
import com.example.springaiproject.service.RAGService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;

import java.util.Map;

@RestController
@RequestMapping("/api/rag")
public class RAGController {

    private final RAGService ragService;
    private final DocumentLoader documentLoader;

    @Autowired
    public RAGController(RAGService ragService, DocumentLoader documentLoader) {
        this.ragService = ragService;
        this.documentLoader = documentLoader;
    }

    /**
     * Uploads a knowledge document.
     */
    @PostMapping("/upload")
    public Map<String, String> uploadDocument(@RequestBody Map<String, String> request) {
        String content = request.get("content");
        documentLoader.loadText(content, Map.of("source", "user-upload"));
        return Map.of("message", "Document uploaded successfully");
    }

    /**
     * Question answering based on the knowledge base.
     */
    @PostMapping("/ask")
    public Map<String, String> ask(@RequestBody Map<String, String> request) {
        String question = request.get("question");
        String answer = ragService.askWithContext(question);
        return Map.of("answer", answer);
    }
}
6. Test RAG
# 1. Upload a knowledge document
curl -X POST "http://localhost:8080/api/rag/upload" \
  -H "Content-Type: application/json" \
  -d '{"content": "Spring AI is the AI integration module of the Spring framework. Version 1.0.1 supports multiple model providers such as OpenAI and Azure OpenAI."}'

# 2. Ask a question based on the knowledge base
curl -X POST "http://localhost:8080/api/rag/ask" \
  -H "Content-Type: application/json" \
  -d '{"question": "Which model providers does Spring AI support?"}'
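Chapter 4: Multimodal Input and Output
Spring AI supports processing multiple data types such as images and audio.
1. Add multimodal dependencies
Make sure you use a model that supports multimodal input (such as GPT-4 Vision):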
spring:
  ai:
    openai:
      chat:
        options:
          model: gpt-4-vision-preview  # supports image input
2. Create the multimodal service
Create MultiModalService.java:
package com.example.springaiproject.service;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Service;
import org.springframework.util.MimeTypeUtils;

import java.net.URL;
import java.util.List;

/**
 * Multimodal AI service.
 */
@Service
public class MultiModalService {

    private final ChatClient chatClient;

    public MultiModalService(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    /**
     * Analyzes an image given by URL (the MIME type is assumed to be PNG).
     */
    public String analyzeImage(String imageUrl, String prompt) {
        try {
            var imageMedia = new Media(MimeTypeUtils.IMAGE_PNG, new URL(imageUrl));
            var userMessage = new UserMessage(prompt, List.of(imageMedia));
            return this.chatClient.prompt()
                    .messages(userMessage)
                    .call()
                    .content();
        } catch (Exception e) {
            return "Error analyzing image: " + e.getMessage();
        }
    }

    /**
     * Analyzes a local image passed as Base64.
     */
    public String analyzeImageBase64(String base64Image, String prompt) {
        try {
            // Strip the "data:image/png;base64," prefix, if present
            String imageData = base64Image.replaceFirst("^data:image/[^;]+;base64,", "");
            byte[] imageBytes = java.util.Base64.getDecoder().decode(imageData);
            var imageResource = new org.springframework.core.io.ByteArrayResource(imageBytes);
            var imageMedia = new Media(MimeTypeUtils.IMAGE_PNG, imageResource);
            var userMessage = new UserMessage(prompt, List.of(imageMedia));
            return this.chatClient.prompt()
                    .messages(userMessage)
                    .call()
                    .content();
        } catch (Exception e) {
            return "Error analyzing image: " + e.getMessage();
        }
    }

    /**
     * Analyzes several images in one request.
     */
    public String analyzeMultipleImages(List<String> imageUrls, String prompt) {
        try {
            List<Media> mediaList = imageUrls.stream()
                    .map(url -> {
                        try {
                            return new Media(MimeTypeUtils.IMAGE_PNG, new URL(url));
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    })
                    .toList();
            var userMessage = new UserMessage(prompt, mediaList);
            return this.chatClient.prompt()
                    .messages(userMessage)
                    .call()
                    .content();
        } catch (Exception e) {
            return "Error analyzing images: " + e.getMessage();
        }
    }
}
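The service methods above hard-code MimeTypeUtils.IMAGE_PNG, so a JPEG upload would be labelled as PNG. One way to avoid this is to read the MIME type from the data URI prefix instead of discarding it. The following is a minimal sketch using only the JDK; the DecodedImage record and DataUriParser class are hypothetical names that do not appear in the service above.

```java
import java.util.Base64;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical holder for a decoded data URI: the declared MIME type plus the raw bytes.
record DecodedImage(String mimeType, byte[] bytes) {}

class DataUriParser {

    // Matches "data:image/<subtype>;base64,<payload>"
    private static final Pattern DATA_URI =
            Pattern.compile("^data:(image/[^;]+);base64,(.*)$", Pattern.DOTALL);

    // Parses a data URI; bare Base64 input without a prefix falls back to defaultMimeType.
    // Throws IllegalArgumentException if the payload is not valid Base64.
    static DecodedImage parse(String input, String defaultMimeType) {
        Matcher matcher = DATA_URI.matcher(input);
        if (matcher.matches()) {
            byte[] bytes = Base64.getDecoder().decode(matcher.group(2));
            return new DecodedImage(matcher.group(1), bytes);
        }
        return new DecodedImage(defaultMimeType, Base64.getDecoder().decode(input));
    }
}
```

The resulting MIME type string could then be converted with MimeTypeUtils.parseMimeType(...) before constructing the Media object.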
3. Create the multimodal Controller
package com.example.springaiproject.controller;

import com.example.springaiproject.service.MultiModalService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.util.Base64;
import java.util.List;
import java.util.Map;

@RestController
@RequestMapping("/api/multimodal")
public class MultiModalController {

    private final MultiModalService multiModalService;

    @Autowired
    public MultiModalController(MultiModalService multiModalService) {
        this.multiModalService = multiModalService;
    }

    /**
     * Analyzes an image URL.
     */
    @PostMapping("/analyze-image")
    public Map<String, String> analyzeImage(@RequestBody Map<String, String> request) {
        String imageUrl = request.get("imageUrl");
        String prompt = request.getOrDefault("prompt", "Please describe the content of this image");
        String result = multiModalService.analyzeImage(imageUrl, prompt);
        return Map.of("result", result);
    }

    /**
     * Uploads and analyzes an image.
     */
    @PostMapping("/upload-image")
    public Map<String, String> uploadAndAnalyzeImage(
            @RequestParam("file") MultipartFile file,
            @RequestParam(defaultValue = "Please describe the content of this image") String prompt) {
        try {
            byte[] bytes = file.getBytes();
            String base64 = Base64.getEncoder().encodeToString(bytes);
            String base64WithPrefix = "data:image/png;base64," + base64;
            String result = multiModalService.analyzeImageBase64(base64WithPrefix, prompt);
            return Map.of("result", result);
        } catch (Exception e) {
            return Map.of("error", e.getMessage());
        }
    }

    /**
     * Analyzes several images in one request.
     */
    @PostMapping("/analyze-multiple")
    public Map<String, String> analyzeMultipleImages(@RequestBody Map<String, Object> request) {
        @SuppressWarnings("unchecked")
        List<String> imageUrls = (List<String>) request.get("imageUrls");
        String prompt = (String) request.getOrDefault("prompt", "Compare the similarities and differences between these images");
        String result = multiModalService.analyzeMultipleImages(imageUrls, prompt);
        return Map.of("result", result);
    }
}
4. Test the multimodal features
# Analyze an image URL
curl -X POST "http://localhost:8080/api/multimodal/analyze-image" \
  -H "Content-Type: application/json" \
  -d '{"imageUrl": "https://example.com/image.jpg", "prompt": "What is in this image?"}'

# Upload an image file
curl -X POST "http://localhost:8080/api/multimodal/upload-image" \
  -F "file=@/path/to/image.jpg" \
  -F "prompt=Describe the main elements in the image"

# Batch analysis
curl -X POST "http://localhost:8080/api/multimodal/analyze-multiple" \
  -H "Content-Type: application/json" \
  -d '{"imageUrls": ["https://example.com/image1.jpg", "https://example.com/image2.jpg"], "prompt": "Compare these two images"}'
Complete project structure
spring-ai-demo/
├── src/main/java/com/example/springaiproject/
│   ├── SpringAiDemoApplication.java
│   ├── controller/
│   │   ├── AIController.java
│   │   ├── RAGController.java
│   │   └── MultiModalController.java
│   └── service/
│       ├── AIService.java
│       ├── ConversationMemory.java
│       ├── DocumentLoader.java
│       ├── RAGService.java
│       └── MultiModalService.java
└── src/main/resources/
    └── application.yml
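Summary
In these three chapters, we implemented:
- Conversation context management: supports multi-turn conversations, so the AI can remember earlier exchanges
- RAG (retrieval-augmented generation): integrates a vector database and retrieves relevant information from a knowledge base
- Multimodal processing: supports visual tasks such as image analysis and multi-image comparison
These features form a solid foundation for building enterprise-grade AI applications.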