Spring AI Alibaba Field Classification and Grading Graph Example

This example demonstrates how to use the Graph feature of Spring AI Alibaba to build an intelligent workflow for field classification and grading. Multiple nodes work together to provide sensitive-word detection, field classification, human review, and result persistence.
1. Objectives
We will build a field classification and grading system with the following core capabilities:
- Sensitive-word detection: automatically checks whether the input field contains sensitive words and decides the subsequent flow.
- Field classification: classifies and grades fields intelligently based on a knowledge base.
- Human review: a reviewer can approve or reject the AI's classification result.
- Result persistence: saves the final classification result to the database.
2. Tech Stack and Core Dependencies
- Spring Boot 3.x
- Spring AI Alibaba Graph (for building the intelligent workflow)
- MyBatis-Plus (for database access)
- MySQL (data storage)
- OpenAI API (calls the Qwen models through DashScope's OpenAI-compatible mode)
In pom.xml, the core dependencies are:
```xml
<dependencies>
    <!-- Spring Web, for building the RESTful API -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- OpenAI model support -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-openai</artifactId>
    </dependency>
    <!-- Spring AI Alibaba Graph core -->
    <dependency>
        <groupId>com.alibaba.cloud.ai</groupId>
        <artifactId>spring-ai-alibaba-graph-core</artifactId>
        <version>1.0.0.3</version>
    </dependency>
    <!-- MyBatis-Plus starter -->
    <dependency>
        <groupId>com.baomidou</groupId>
        <artifactId>mybatis-plus-boot-starter</artifactId>
        <version>3.5.6</version>
    </dependency>
    <!-- MySQL connector -->
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <version>8.2.0</version>
    </dependency>
</dependencies>
```

3. Project Configuration
In src/main/resources/application.yml, configure the database connection and AI model parameters:
```yaml
server:
  port: 8080

spring:
  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    url: jdbc:mysql://127.0.0.1:3306/test?characterEncoding=utf8&autoReconnect=true&useUnicode=true&useSSL=false&serverTimezone=UTC
    username: your_username
    password: your_password
  application:
    name: spring-ai-alibaba-graph-sec
  ai:
    openai:
      api-key: your_api_key
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      embedding:
        options:
          model: text-embedding-v1
      chat:
        options:
          model: qwen-max

mybatis-plus:
  configuration:
    map-underscore-to-camel-case: true
```

4. System Architecture
The system is built on Spring AI Alibaba Graph and consists of the following core components:
4.1 Graph Workflow
The workflow consists of multiple nodes; a state graph (StateGraph) defines how execution flows between them:
```java
@Configuration
public class SecGraphBuilder {

    @Bean
    public StateGraph secGraph() {
        // Build the state graph
        return StateGraph.builder(OverAllState.class)
                .addNode("sensitive", node_async(new SensitiveWordDecNode()))
                .addNode("answer", node_async(new AnswerNode()))
                .addNode("clft", node_async(new ClftNode()))
                .addNode("human", node_async(new HumanFeedbackNode()))
                .addNode("saveTool", node_async(new ToolNode(List.of(new FieldSaveTool()))))
                .addEdge(START, "sensitive")
                .addEdge("answer", END)
                .addConditionalEdges("sensitive", AsyncEdgeAction.of(new SensitiveDispatcher()),
                        Map.of("yes", "answer", "no", "clft"))
                .addConditionalEdges("clft", AsyncEdgeAction.of(new ClftDispatcher()),
                        Map.of("yes", "human", "no", "saveTool"))
                .addConditionalEdges("human", AsyncEdgeAction.of(new HumanFeedbackDispatcher()),
                        Map.of("saveTool", "saveTool", "clft", "clft"))
                .addEdge("saveTool", END)
                .build();
    }
}
```

4.2 Nodes
The system contains the following core nodes:
- SensitiveWordDecNode: sensitive-word detection node; decides whether the input field contains sensitive words.
- AnswerNode: answer node; returns a response when a sensitive word is detected.
- ClftNode: classification node; classifies and grades the field.
- HumanFeedbackNode: human feedback node; processes the reviewer's decision.
- ToolNode: tool node; executes the field-save operation.
4.3 Dispatchers
A dispatcher inspects a node's output and decides which node runs next:
- SensitiveDispatcher: routes based on the sensitive-word detection result.
- ClftDispatcher: routes based on whether the classification result needs human review.
- HumanFeedbackDispatcher: routes based on the human feedback result.
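The dispatcher classes themselves are not listed in this example. As an illustration, here is a minimal sketch of what SensitiveDispatcher's routing logic could look like. The real class would implement the graph library's EdgeAction interface and read an OverAllState; this sketch uses a plain Map so the logic stands alone. The returned route keys "yes"/"no" match the addConditionalEdges mapping in section 4.1.

```java
import java.util.Map;

// Hypothetical sketch of SensitiveDispatcher's routing logic.
// The real class would implement EdgeAction and receive an OverAllState.
class SensitiveDispatcher {
    public String apply(Map<String, Object> state) {
        // SensitiveWordDecNode writes "is_sensitive" into the state
        Object sensitive = state.get("is_sensitive");
        // "yes" routes to the answer node, "no" to the classification node
        return Boolean.TRUE.equals(sensitive) ? "yes" : "no";
    }
}
```

ClftDispatcher and HumanFeedbackDispatcher would follow the same pattern, keying off `need_human_review` and `human_next_node` respectively.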
4.4 Data Storage
The system stores field classification results in a MySQL database through the following classes:
- Field: the field entity, holding the field name, classification, level, and reasoning.
- FieldMapper: the data-access (mapper) interface.
- IFieldService: the service-layer interface providing field-related business logic.
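The entity class is not listed in this example; a minimal sketch is shown below. The property names are assumptions inferred from the setters used by FieldSaveTool in section 5.6, and in the real project the class would also carry MyBatis-Plus annotations such as @TableName and @TableId.

```java
// Hypothetical sketch of the Field entity; property names are inferred from
// FieldSaveTool (section 5.6). MyBatis-Plus annotations are omitted here.
class Field {
    private Long id;               // primary key
    private String fieldName;      // field name
    private String classification; // classification category
    private String level;          // sensitivity level
    private String reasoning;      // model's reasoning

    public Long getId() { return id; }
    public void setId(Long id) { this.id = id; }
    public String getFieldName() { return fieldName; }
    public void setFieldName(String fieldName) { this.fieldName = fieldName; }
    public String getClassification() { return classification; }
    public void setClassification(String classification) { this.classification = classification; }
    public String getLevel() { return level; }
    public void setLevel(String level) { this.level = level; }
    public String getReasoning() { return reasoning; }
    public void setReasoning(String reasoning) { this.reasoning = reasoning; }
}
```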
5. Core Code Implementation
5.1 Main Application Class
```java
@SpringBootApplication
@Slf4j
public class Application {

    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }

    @Bean
    CommandLineRunner vectorIngestRunner(
            @Value("${rag.source:classpath:rag/rag_friendly_classification.txt}") Resource ragSource,
            EmbeddingModel embeddingModel,
            @Qualifier("classificationVectorStore") VectorStore classificationVectorStore) {
        return args -> {
            log.info("🔄 Vectorizing and loading the classification knowledge base...");
            var chunks = new TokenTextSplitter().transform(new TextReader(ragSource).read());
            classificationVectorStore.write(chunks);
        };
    }

    @Bean
    public VectorStore classificationVectorStore(EmbeddingModel embeddingModel) {
        return SimpleVectorStore.builder(embeddingModel).build();
    }

    @Bean
    public ChatMemory chatMemory() {
        return MessageWindowChatMemory.builder().build();
    }
}
```

5.2 Controller Implementation
```java
@RestController
@RequestMapping("/sec/graph")
@Slf4j
public class SecGraphController {

    private final CompiledGraph compiledGraph;

    public SecGraphController(@Qualifier("secGraph") StateGraph stateGraph) throws GraphStateException {
        SaverConfig saverConfig = SaverConfig.builder()
                .register(SaverConstant.MEMORY, new MemorySaver())
                .build();
        // Pause before the "human" node so a reviewer can approve or reject
        this.compiledGraph = stateGraph.compile(
                CompileConfig.builder().saverConfig(saverConfig).interruptBefore("human").build());
    }

    @GetMapping(value = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> simpleChat(
            @RequestParam("fieldName") String fieldName,
            @RequestParam(value = "thread_id", defaultValue = "yhong", required = false) String threadId)
            throws Exception {
        RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();
        GraphProcess graphProcess = new GraphProcess(this.compiledGraph);
        Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().unicast().onBackpressureBuffer();
        AsyncGenerator<NodeOutput> resultFuture = compiledGraph.stream(Map.of("field", fieldName), runnableConfig);
        graphProcess.processStream(resultFuture, sink);
        return sink.asFlux()
                .doOnCancel(() -> log.info("Client disconnected from stream"))
                .doOnError(e -> log.error("Error occurred during streaming", e));
    }

    @GetMapping(value = "/resume", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> resume(
            @RequestParam(value = "thread_id", defaultValue = "yhong", required = false) String threadId,
            @RequestParam(value = "feed_back", defaultValue = "true", required = false) boolean feedBack,
            @RequestParam(value = "feedback_reason", defaultValue = "", required = false) String humanReason)
            throws GraphRunnerException {
        RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();
        StateSnapshot stateSnapshot = this.compiledGraph.getState(runnableConfig);
        OverAllState state = stateSnapshot.state();
        state.withResume();

        // Attach the reviewer's decision as human feedback on the saved state
        Map<String, Object> objectMap = new HashMap<>();
        objectMap.put("feed_back", feedBack);
        objectMap.put("feedback_reason", humanReason);
        state.withHumanFeedback(new OverAllState.HumanFeedback(objectMap, "feed_back"));

        Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().unicast().onBackpressureBuffer();
        GraphProcess graphProcess = new GraphProcess(this.compiledGraph);
        AsyncGenerator<NodeOutput> resultFuture = compiledGraph.streamFromInitialNode(state, runnableConfig);
        graphProcess.processStream(resultFuture, sink);
        return sink.asFlux()
                .doOnCancel(() -> log.info("Client disconnected from stream"))
                .doOnError(e -> log.error("Error occurred during streaming", e));
    }
}
```

5.3 Sensitive Word Detection Node
```java
public class SensitiveWordDecNode implements NodeAction {

    @Override
    public Map<String, Object> apply(OverAllState state) {
        String field = state.value("field");
        boolean isSensitive = checkSensitiveWords(field);

        Map<String, Object> result = new HashMap<>();
        result.put("is_sensitive", isSensitive);
        if (isSensitive) {
            result.put("sensitive_reason", "Sensitive word detected: " + field);
        }
        return result;
    }

    private boolean checkSensitiveWords(String field) {
        // Implement the sensitive-word detection logic here
        return false; // simplified for this example
    }
}
```

5.4 Classification Node
```java
public class ClftNode implements NodeAction {

    private final ChatClient.Builder chatClientBuilder;
    private final VectorStore vectorStore;

    public ClftNode(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
        this.chatClientBuilder = chatClientBuilder;
        this.vectorStore = vectorStore;
    }

    @Override
    public Map<String, Object> apply(OverAllState state) {
        String field = state.value("field");

        // Retrieve related classification knowledge via RAG
        List<Document> similarDocs = vectorStore.similaritySearch(
                SearchRequest.query(field).withTopK(3));

        // Build the prompt context
        String context = similarDocs.stream()
                .map(Document::getContent)
                .collect(Collectors.joining("\n\n"));

        // Ask the LLM to classify the field
        String classification = chatClientBuilder.build()
                .prompt()
                .user(u -> u.text("Based on the following context, classify and grade the field '{field}'.\n\nContext:\n{context}")
                        .param("field", field)
                        .param("context", context))
                .call()
                .content();

        // Parse the classification result
        return parseClassificationResult(classification);
    }

    private Map<String, Object> parseClassificationResult(String classification) {
        // Simplified for this example -- a real implementation should parse
        // the model's actual output format
        Map<String, Object> result = new HashMap<>();
        result.put("classification", "Personal information");
        result.put("level", "High");
        result.put("reasoning", "The field contains personally identifiable information");
        result.put("need_human_review", true);
        return result;
    }
}
```

5.5 Human Feedback Node
```java
public class HumanFeedbackNode implements NodeAction {

    @Override
    public Map<String, Object> apply(OverAllState state) {
        if (state.humanFeedback() == null || !state.humanFeedback().isResume()) {
            throw new GraphRunnerException("Waiting for human feedback...");
        }

        Map<String, Object> feedbackData = state.humanFeedback().data();
        boolean isApproved = (boolean) feedbackData.get("feed_back");
        String feedbackReason = (String) feedbackData.get("feedback_reason");

        Map<String, Object> result = new HashMap<>();
        // Approved results go to the save tool; rejected ones are re-classified
        result.put("human_next_node", isApproved ? "saveTool" : "clft");
        result.put("feedback_reason", feedbackReason);
        result.put("feedback", isApproved ? "approved" : "rejected");
        return result;
    }
}
```

5.6 Field Save Tool
```java
public class FieldSaveTool implements ToolCallback {

    private final IFieldService fieldService;
    private final ObjectMapper objectMapper;

    public FieldSaveTool(IFieldService fieldService, ObjectMapper objectMapper) {
        this.fieldService = fieldService;
        this.objectMapper = objectMapper;
    }

    @Override
    public String getName() {
        return "save_field_classification";
    }

    @Override
    public String getDescription() {
        return "Save field classification and grading information";
    }

    @Override
    public String getInputSchema() {
        return "{\"type\":\"object\",\"properties\":{"
                + "\"fieldName\":{\"type\":\"string\",\"description\":\"Field name\"},"
                + "\"classification\":{\"type\":\"string\",\"description\":\"Classification\"},"
                + "\"level\":{\"type\":\"string\",\"description\":\"Level\"},"
                + "\"reasoning\":{\"type\":\"string\",\"description\":\"Reasoning\"}},"
                + "\"required\":[\"fieldName\",\"classification\",\"level\",\"reasoning\"]}";
    }

    @Override
    public String call(String toolInput) {
        try {
            JsonNode jsonNode = objectMapper.readTree(toolInput);
            Field field = new Field();
            field.setFieldName(jsonNode.get("fieldName").asText());
            field.setClassification(jsonNode.get("classification").asText());
            field.setLevel(jsonNode.get("level").asText());
            field.setReasoning(jsonNode.get("reasoning").asText());
            boolean success = fieldService.save(field);
            return success ? "Field classification saved successfully" : "Failed to save field classification";
        } catch (Exception e) {
            return "Save failed: " + e.getMessage();
        }
    }
}
```

6. Running and Testing
- Start the application: run the Spring Boot main class.
- Test with a browser or an API tool.
Test 1: Field Classification Flow
Visit the following URL to classify the field "用户姓名" (user name):
http://localhost:8080/sec/graph/chat?fieldName=用户姓名
Expected response: the system returns a streamed (SSE) response showing each step of the classification flow.
Test 2: Human Review Flow
When the system decides that human review is needed, the workflow pauses at the human feedback node. Resume it through the following endpoint:
http://localhost:8080/sec/graph/resume?thread_id=yhong&feed_back=true&feedback_reason=分类准确
Parameters:
- thread_id: the workflow thread ID, identifying a specific workflow instance.
- feed_back: whether to approve the classification result; true approves, false rejects.
- feedback_reason: the reviewer's reason or notes.
7. Design Notes and Extension Ideas
Design notes
The core idea of this example is graph-based workflow orchestration: a complex business process is split into nodes, and the transitions between them are defined explicitly, yielding a flexible, extensible process. This brings:
- Process visibility: the state graph makes the execution path of the whole flow easy to follow.
- Node reuse: each node is an independent component that can be reused in other flows.
- Human-in-the-loop collaboration: the human feedback node adds human review, improving reliability.
- State persistence: the checkpoint mechanism persists workflow state, supporting long-running processes.
Extension ideas
- Multimodal support: extend the system to classify image, audio, and other multimodal fields.
- Rule-engine integration: integrate a rule engine for more complex business-rule decisions.
- Distributed execution: distribute node execution across multiple service instances to increase throughput.
- Process monitoring: add execution monitoring and alerting to detect and handle failures promptly.
- Batch processing: support classifying fields in batches for higher throughput.
