509-Spring AI Alibaba Graph Parallel Stream Node Example

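This example uses the Spring AI Alibaba Graph framework to show how nodes in a Graph workflow can run in parallel to speed up execution, while each node streams the output of its AI model call in real time.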
1. Goals
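We will build a web application that demonstrates parallel streaming with the Spring AI Alibaba Graph framework:
- Parallel streaming (/graph/parallel-stream/expand-translate): the Graph runs a query-expansion node and a translation node in parallel and streams their results in real time.
- Result merging (MergeResultsNode): once both parallel nodes have finished, a merge node combines their results.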
2. Tech Stack and Core Dependencies
- Spring Boot 3.x
- Spring AI Alibaba (for connecting to Alibaba Cloud DashScope Tongyi (Qwen) models)
- Spring AI Alibaba Graph (for building the parallel processing flow)
- Maven (build tool)
Add the following core dependencies to pom.xml:
```xml
<dependencies>
    <!-- Spring AI Alibaba core starter, integrates DashScope -->
    <dependency>
        <groupId>com.alibaba.cloud.ai</groupId>
        <artifactId>spring-ai-alibaba-starter-dashscope</artifactId>
    </dependency>
    <!-- Spring AI Chat Client -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
    </dependency>
    <!-- Spring AI Alibaba Graph core library -->
    <dependency>
        <groupId>com.alibaba.cloud.ai</groupId>
        <artifactId>spring-ai-alibaba-graph-core</artifactId>
        <version>1.0.0.3</version>
    </dependency>
    <!-- Spring Web for building the RESTful API -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>
```
3. Project Configuration
Configure your DashScope API key in src/main/resources/application.yml:
```yaml
server:
  port: 8080
spring:
  application:
    name: parallel-stream-node
  ai:
    dashscope:
      api-key: ${AI_DASHSCOPE_API_KEY}
      chat:
        options:
          model: qwen-max
```
Important: set the AI_DASHSCOPE_API_KEY environment variable to a valid API key obtained from Alibaba Cloud.
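A missing key typically surfaces only when the first model call fails. If you prefer to fail fast, a small check can run before startup. ApiKeyCheck and requireApiKey are hypothetical names for illustration, not part of Spring AI Alibaba:

```java
import java.util.Map;

// Hypothetical helper: verifies that the DashScope key variable is present and non-blank.
public class ApiKeyCheck {

    static final String KEY_VAR = "AI_DASHSCOPE_API_KEY";

    // Returns the key from the given environment map, or throws if it is missing or blank.
    static String requireApiKey(Map<String, String> env) {
        String key = env.get(KEY_VAR);
        if (key == null || key.isBlank()) {
            throw new IllegalStateException(KEY_VAR + " is not set");
        }
        return key;
    }

    public static void main(String[] args) {
        // System.getenv() exposes the real process environment as a Map.
        String key = requireApiKey(System.getenv());
        System.out.println("DashScope API key found (" + key.length() + " chars)");
    }
}
```

Taking the environment as a Map parameter keeps the check testable without touching real environment variables.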
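4. Graph Flow Design
This example implements a parallel streaming Graph flow with the following nodes:
- ExpanderNode: the expansion node; the AI model streams the expanded query text.
- TranslateNode: the translation node; the AI model streams the translated text.
- MergeResultsNode: the merge node; once both the expander and translate nodes have finished, it merges their results.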
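The START → {ExpanderNode, TranslateNode} → MergeResultsNode shape is a classic fan-out/fan-in. It can be modeled with the JDK's CompletableFuture; this is an illustration only, not how Spring AI Alibaba Graph schedules nodes internally, and the two branch functions return canned strings instead of calling a model:

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.IntStream;

// Plain-Java model of START -> {expander, translate} -> merge -> END.
public class FanOutFanIn {

    // Stand-in for ExpanderNode: produces n variants of the query.
    static List<String> expand(String query, int n) {
        return IntStream.rangeClosed(1, n)
                .mapToObj(i -> query + " (variant " + i + ")")
                .toList();
    }

    // Stand-in for TranslateNode.
    static String translate(String query, String language) {
        return "[" + language + "] " + query;
    }

    static Map<String, Object> run(String query, int n, String language) {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try {
            // Fan out: both branches start from the same input and run concurrently.
            CompletableFuture<List<String>> expanded =
                    CompletableFuture.supplyAsync(() -> expand(query, n), pool);
            CompletableFuture<String> translated =
                    CompletableFuture.supplyAsync(() -> translate(query, language), pool);
            // Fan in: the merge step runs only after both branches have completed.
            return expanded.thenCombine(translated, (e, t) -> Map.<String, Object>of(
                            "merge_result", Map.of("expander_content", e, "translate_content", t)))
                    .join();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(run("hello", 3, "English"));
    }
}
```

With real model calls, the total latency of this shape is roughly that of the slower branch rather than the sum of both.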
4.1 Graph Flow Configuration
GraphConfiguration.java defines the graph structure and how the nodes are connected:
```java
@Bean
public StateGraph parallelStreamGraph(ChatClient.Builder chatClientBuilder) throws GraphStateException {
    KeyStrategyFactory keyStrategyFactory = () -> {
        HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
        // User input
        keyStrategyHashMap.put("query", new ReplaceStrategy());
        keyStrategyHashMap.put("expander_number", new ReplaceStrategy());
        keyStrategyHashMap.put("expander_content", new ReplaceStrategy());
        keyStrategyHashMap.put("translate_language", new ReplaceStrategy());
        keyStrategyHashMap.put("translate_content", new ReplaceStrategy());
        keyStrategyHashMap.put("merge_result", new ReplaceStrategy());
        return keyStrategyHashMap;
    };

    // Written by nodes that run in parallel, so use a thread-safe map
    Map<String, NodeStatus> node2Status = new ConcurrentHashMap<>();

    StateGraph stateGraph = new StateGraph(keyStrategyFactory)
            .addNode(ExpanderNode.NODE_NAME, node_async(new ExpanderNode(chatClientBuilder, node2Status)))
            .addNode(TranslateNode.NODE_NAME, node_async(new TranslateNode(chatClientBuilder, node2Status)))
            .addNode(MergeResultsNode.NODE_NAME, node_async(new MergeResultsNode(node2Status)))
            // Two edges out of START: translate and expander run in parallel
            .addEdge(StateGraph.START, TranslateNode.NODE_NAME)
            .addEdge(StateGraph.START, ExpanderNode.NODE_NAME)
            // Both branches converge on the merge node
            .addEdge(TranslateNode.NODE_NAME, MergeResultsNode.NODE_NAME)
            .addEdge(ExpanderNode.NODE_NAME, MergeResultsNode.NODE_NAME)
            .addEdge(MergeResultsNode.NODE_NAME, StateGraph.END);

    // Print the graph as PlantUML (logger is a field of the enclosing configuration class)
    GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML, "expander flow");
    logger.info("\n=== expander UML Flow ===");
    logger.info(representation.content());
    logger.info("==================================\n");

    return stateGraph;
}
```
5. Node Implementation
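Every value a node returns is written into OverAllState according to the key strategies registered in GraphConfiguration. This example uses ReplaceStrategy for every key, so a node's output overwrites the previous value for that key. The plain-Java model below contrasts replace with an append-style strategy; KeyStrategyModel, REPLACE, APPEND, and apply are hypothetical names, not the framework's API:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;

// Hypothetical model of per-key update strategies; not the framework's KeyStrategy classes.
public class KeyStrategyModel {

    // Replace: the new value simply overwrites the old one.
    static final BinaryOperator<Object> REPLACE = (oldValue, newValue) -> newValue;

    // Append: values accumulate into a list.
    static final BinaryOperator<Object> APPEND = (oldValue, newValue) -> {
        List<Object> list = new ArrayList<>();
        if (oldValue instanceof List<?> existing) {
            list.addAll(existing);
        } else if (oldValue != null) {
            list.add(oldValue);
        }
        list.add(newValue);
        return list;
    };

    // Applies one node's partial output to the state, key by key.
    // Keys without a registered strategy fall back to REPLACE in this sketch.
    static Map<String, Object> apply(Map<String, Object> state,
                                     Map<String, Object> nodeOutput,
                                     Map<String, BinaryOperator<Object>> strategies) {
        Map<String, Object> next = new HashMap<>(state);
        nodeOutput.forEach((key, value) -> {
            BinaryOperator<Object> strategy = strategies.getOrDefault(key, REPLACE);
            next.put(key, strategy.apply(state.get(key), value));
        });
        return next;
    }
}
```

Replace fits this flow because each key is written by exactly one node, so there is nothing to accumulate.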
5.1 ExpanderNode
The expansion node: the AI model streams the expanded query variants.
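The implementation below turns the model's final answer into a list with text.split("\n"), which keeps blank lines and, for CRLF output, a trailing \r on each variant. If that matters downstream, a stricter parser can be used instead. The sketch assumes plain-text output with one variant per line; QueryVariants and parseVariants are hypothetical names, not part of the example:

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: a stricter alternative to text.split("\n").
public class QueryVariants {

    // Splits model output on any line break (\R also matches \r\n), trims each line,
    // drops blank lines, and keeps at most `expected` entries.
    static List<String> parseVariants(String text, int expected) {
        if (text == null) {
            return List.of();
        }
        return Arrays.stream(text.split("\\R"))
                .map(String::strip)
                .filter(line -> !line.isEmpty())
                .limit(expected)
                .toList();
    }

    public static void main(String[] args) {
        String raw = "What is Spring AI?\r\n\r\n  How does Spring AI work?\nSpring AI use cases\n";
        // prints [What is Spring AI?, How does Spring AI work?, Spring AI use cases]
        System.out.println(parseVariants(raw, 3));
    }
}
```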
```java
public class ExpanderNode implements NodeAction {

    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(
            "You are an expert at information retrieval and search optimization.\n"
                    + "Your task is to generate {number} different versions of the given query.\n\n"
                    + "Each variant must cover different perspectives or aspects of the topic,\n"
                    + "while maintaining the core intent of the original query. The goal is to\n"
                    + "expand the search space and improve the chances of finding relevant information.\n\n"
                    + "Do not explain your choices or add any other text.\n"
                    + "Provide the query variants separated by newlines.\n\n"
                    + "Original query: {query}\n\n"
                    + "Query variants:\n");

    private final ChatClient chatClient;

    private final Integer NUMBER = 3;

    private final Map<String, NodeStatus> node2Status;

    public static final String NODE_NAME = "expander";

    public ExpanderNode(ChatClient.Builder chatClientBuilder, Map<String, NodeStatus> node2Status) {
        this.chatClient = chatClientBuilder.build();
        this.node2Status = node2Status;
    }

    @Override
    public Map<String, Object> apply(OverAllState state) {
        node2Status.put(NODE_NAME, NodeStatus.RUNNING);

        String query = state.value("query", "");
        Integer expanderNumber = state.value("expander_number", this.NUMBER);

        // Request a streamed response instead of waiting for the full answer
        Flux<ChatResponse> chatResponseFlux = this.chatClient.prompt()
                .user(user -> user.text(DEFAULT_PROMPT_TEMPLATE.getTemplate())
                        .param("number", expanderNumber)
                        .param("query", query))
                .stream()
                .chatResponse();

        // Chunks are streamed out as they arrive; mapResult builds the final state update
        AsyncGenerator<? extends NodeOutput> generator = StreamingChatGenerator.builder()
                .startingNode("expander_llm_stream")
                .startingState(state)
                .mapResult(response -> {
                    String text = response.getResult().getOutput().getText();
                    List<String> queryVariants = Arrays.asList(text.split("\n"));
                    node2Status.put(NODE_NAME, NodeStatus.COMPLETED);
                    return Map.of("expander_content", queryVariants);
                })
                .build(chatResponseFlux);

        return Map.of("expander_content", generator);
    }
}
```
5.2 TranslateNode
The translation node: the AI model streams the translated text.
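Like ExpanderNode, the translation node hands a Flux<ChatResponse> to StreamingChatGenerator: chunks reach the client as they arrive, and mapResult turns the finished response into a state update. The plain-Java model below illustrates that accumulate-then-map idea, assuming each chunk carries a text fragment; ChunkAccumulator is a hypothetical name and this is not the library's implementation:

```java
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;

// Hypothetical model of a streaming node: forward each chunk, then map the full text once.
public class ChunkAccumulator {

    static Map<String, Object> consume(List<String> chunks,
                                       Consumer<String> onChunk,
                                       Function<String, Map<String, Object>> mapResult) {
        StringBuilder full = new StringBuilder();
        for (String chunk : chunks) {
            onChunk.accept(chunk);   // what the client sees in real time
            full.append(chunk);      // accumulated for the final state update
        }
        return mapResult.apply(full.toString());  // runs once, after the last chunk
    }

    public static void main(String[] args) {
        List<String> chunks = List.of("Hello", ", nice to", " meet you");
        Map<String, Object> result = consume(chunks,
                c -> System.out.println("chunk: " + c),
                text -> Map.of("translate_content", text));
        System.out.println(result);
    }
}
```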
```java
public class TranslateNode implements NodeAction {

    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(
            "Given a user query, translate it to {targetLanguage}.\n"
                    + "If the query is already in {targetLanguage}, return it unchanged.\n"
                    + "If you don't know the language of the query, return it unchanged.\n"
                    + "Do not add explanations nor any other text.\n\n"
                    + "Original query: {query}\n\n"
                    + "Translated query:\n");

    private final ChatClient chatClient;

    private final String TARGET_LANGUAGE = "English";

    private final Map<String, NodeStatus> node2Status;

    public static final String NODE_NAME = "translate";

    public TranslateNode(ChatClient.Builder chatClientBuilder, Map<String, NodeStatus> node2Status) {
        this.chatClient = chatClientBuilder.build();
        this.node2Status = node2Status;
    }

    @Override
    public Map<String, Object> apply(OverAllState state) {
        node2Status.put(NODE_NAME, NodeStatus.RUNNING);

        String query = state.value("query", "");
        String targetLanguage = state.value("translate_language", TARGET_LANGUAGE);

        // Request a streamed response instead of waiting for the full answer
        Flux<ChatResponse> chatResponseFlux = this.chatClient.prompt()
                .user(user -> user.text(DEFAULT_PROMPT_TEMPLATE.getTemplate())
                        .param("targetLanguage", targetLanguage)
                        .param("query", query))
                .stream()
                .chatResponse();

        // Chunks are streamed out as they arrive; mapResult builds the final state update
        AsyncGenerator<? extends NodeOutput> generator = StreamingChatGenerator.builder()
                .startingNode("translate_llm_stream")
                .startingState(state)
                .mapResult(response -> {
                    String text = response.getResult().getOutput().getText();
                    node2Status.put(NODE_NAME, NodeStatus.COMPLETED);
                    assert text != null;
                    return Map.of("translate_content", text);
                })
                .build(chatResponseFlux);

        return Map.of("translate_content", generator);
    }
}
```
5.3 MergeResultsNode
The merge node: once both the expander and translate nodes have finished, it merges their results.
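The merge node works as a gate: it returns an empty update until node2Status shows both branches as COMPLETED, and only then writes merge_result. A stand-alone model of that gate, with a hypothetical MergeGate class and a local Status enum standing in for NodeStatus:

```java
import java.util.Map;

// Hypothetical stand-alone model of MergeResultsNode's completion gate.
public class MergeGate {

    enum Status { RUNNING, COMPLETED, FAILED }

    // Produces merge_result only when both branches report COMPLETED; otherwise an empty update.
    static Map<String, Object> merge(Map<String, Status> statuses,
                                     Object expanderContent, String translateContent) {
        boolean done = statuses.get("expander") == Status.COMPLETED
                && statuses.get("translate") == Status.COMPLETED;
        if (!done) {
            return Map.of();
        }
        return Map.of("merge_result", Map.of(
                "expander_content", expanderContent,
                "translate_content", translateContent));
    }
}
```

Returning an empty map rather than throwing lets the node run safely even if it is reached before both flags are set.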
```java
public class MergeResultsNode implements NodeAction {

    public static final String NODE_NAME = "merge";

    private final Map<String, NodeStatus> node2Status;

    public MergeResultsNode(Map<String, NodeStatus> node2Status) {
        this.node2Status = node2Status;
    }

    @Override
    public Map<String, Object> apply(OverAllState state) {
        // Contribute nothing to the state until both parallel branches have completed
        if (!isDone(node2Status)) {
            return Map.of();
        }
        Object expanderContent = state.value("expander_content").orElse("unknown");
        String translateContent = (String) state.value("translate_content").orElse("");
        return Map.of("merge_result", Map.of(
                "expander_content", expanderContent,
                "translate_content", translateContent));
    }

    private boolean isDone(Map<String, NodeStatus> node2Status) {
        return node2Status.get(ExpanderNode.NODE_NAME) == NodeStatus.COMPLETED
                && node2Status.get(TranslateNode.NODE_NAME) == NodeStatus.COMPLETED;
    }
}
```
5.4 NodeStatus
An enum of node states, used to track each node's execution status.
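The two parallel nodes may run on different threads while writing to the same node2Status map, so a thread-safe map such as ConcurrentHashMap is the safer choice. The sketch below combines a NodeStatus-like enum (with a hypothetical fromCode lookup) and two simulated nodes updating a shared ConcurrentHashMap; StatusTracking and runTwoNodes are illustrative names only:

```java
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical sketch: a status enum with a code lookup, updated concurrently by two "nodes".
public class StatusTracking {

    enum NodeStatus {
        RUNNING("running"), COMPLETED("completed"), FAILED("failed");

        private final String code;

        NodeStatus(String code) {
            this.code = code;
        }

        // Hypothetical helper: maps a code string back to its enum constant.
        static NodeStatus fromCode(String code) {
            return Arrays.stream(values())
                    .filter(s -> s.code.equals(code))
                    .findFirst()
                    .orElseThrow(() -> new IllegalArgumentException("unknown status: " + code));
        }
    }

    // Simulates two parallel nodes that mark themselves RUNNING and then COMPLETED.
    static Map<String, NodeStatus> runTwoNodes() {
        Map<String, NodeStatus> node2Status = new ConcurrentHashMap<>();
        ExecutorService pool = Executors.newFixedThreadPool(2);
        for (String node : new String[] {"expander", "translate"}) {
            pool.submit(() -> {
                node2Status.put(node, NodeStatus.RUNNING);
                // ... the model call would happen here ...
                node2Status.put(node, NodeStatus.COMPLETED);
            });
        }
        pool.shutdown();
        try {
            pool.awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return node2Status;
    }
}
```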
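```java
public enum NodeStatus {

    // desc holds a Chinese display label: "running", "completed", "failed"
    RUNNING("running", "运行中"),
    COMPLETED("completed", "已完成"),
    FAILED("failed", "失败");

    final String code;

    final String desc;

    NodeStatus(String code, String desc) {
        this.code = code;
        this.desc = desc;
    }
}
```
6. Controller Implementation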
GraphStreamController.java implements the REST API that triggers the graph flow:
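The endpoint produces text/event-stream, so each ServerSentEvent the controller returns is written as an SSE frame. Roughly, per the SSE format, a data-only event becomes one data: line per payload line followed by a blank line. The helper below (SseFraming, hypothetical) approximates this; Spring's actual writer may differ in details such as spacing after the colon:

```java
// Hypothetical helper: approximates how a data-only server-sent event is framed on the wire.
public class SseFraming {

    // Each line of the payload becomes its own "data:" field; a blank line ends the event.
    static String frame(String data) {
        StringBuilder sb = new StringBuilder();
        for (String line : data.split("\n", -1)) {
            sb.append("data:").append(line).append('\n');
        }
        return sb.append('\n').toString();
    }

    public static void main(String[] args) {
        System.out.print(frame("{\"translate\":\"Hello\"}"));
    }
}
```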
```java
@RestController
@RequestMapping("/graph/parallel-stream")
public class GraphStreamController {

    private static final Logger logger = LoggerFactory.getLogger(GraphStreamController.class);

    private final CompiledGraph compiledGraph;

    public GraphStreamController(@Qualifier("parallelStreamGraph") StateGraph stateGraph) throws GraphStateException {
        // Compile once; the compiled graph is reused for every request
        this.compiledGraph = stateGraph.compile();
    }

    // Default query: "Hello, nice to meet you. Could you briefly introduce yourself?"
    @GetMapping(value = "/expand-translate", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> expand(
            @RequestParam(value = "query", defaultValue = "你好,很高兴认识你,能简单介绍一下自己吗?", required = false) String query,
            @RequestParam(value = "expander_number", defaultValue = "3", required = false) Integer expanderNumber,
            @RequestParam(value = "translate_language", defaultValue = "english", required = false) String translateLanguage,
            @RequestParam(value = "thread_id", defaultValue = "yingzi", required = false) String threadId)
            throws GraphRunnerException {
        RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();

        // Initial state; the keys match those registered in the KeyStrategyFactory
        Map<String, Object> objectMap = new HashMap<>();
        objectMap.put("query", query);
        objectMap.put("expander_number", expanderNumber);
        objectMap.put("translate_language", translateLanguage);

        GraphProcess graphProcess = new GraphProcess(this.compiledGraph);
        // Unicast sink: one subscriber (this HTTP response), buffering if the client is slow
        Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().unicast().onBackpressureBuffer();
        AsyncGenerator<NodeOutput> resultFuture = compiledGraph.stream(objectMap, runnableConfig);
        graphProcess.processStream(resultFuture, sink);

        return sink.asFlux()
                .doOnCancel(() -> logger.info("Client disconnected from stream"))
                .doOnError(e -> logger.error("Error occurred during streaming", e));
    }
}
```
7. Stream Processing
GraphProcess.java implements the streaming logic that handles the event stream produced while the graph runs:
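GraphProcess drains the AsyncGenerator on a background thread, emits one event per node output, and then signals either completion or an error. The pattern can be modeled with the JDK alone; StreamBridge and RecordingSink are hypothetical stand-ins for GraphProcess and Reactor's Sinks.Many, and an Iterator stands in for AsyncGenerator:

```java
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;

// Hypothetical model of GraphProcess: drain a source asynchronously, emit each item,
// then signal completion or error exactly once.
public class StreamBridge {

    // Minimal stand-in for Sinks.Many: records what was emitted.
    static class RecordingSink {
        final List<String> items = new CopyOnWriteArrayList<>();
        volatile boolean completed;
        volatile Throwable error;

        void next(String item) { items.add(item); }
        void complete() { completed = true; }
        void error(Throwable t) { error = t; }
    }

    static <T> CompletableFuture<Void> process(Iterator<T> source, Function<T, String> toJson,
                                               RecordingSink sink, ExecutorService executor) {
        return CompletableFuture.runAsync(() -> {
                    while (source.hasNext()) {
                        sink.next(toJson.apply(source.next()));
                    }
                }, executor)
                .thenRun(sink::complete)        // normal completion
                .exceptionally(e -> {           // a failure ends the stream with an error
                    Throwable cause = e instanceof CompletionException && e.getCause() != null
                            ? e.getCause() : e;
                    sink.error(cause);
                    return null;
                });
    }
}
```

One design difference: the controller creates a new GraphProcess, and with it a new single-thread executor, for every request, and that executor is never shut down. Passing the executor in, as the sketch does, lets the caller own its lifecycle.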
```java
// JSON and JSONObject are presumably from Alibaba fastjson, which is not listed in the pom above
public class GraphProcess {

    private static final Logger logger = LoggerFactory.getLogger(GraphProcess.class);

    private final ExecutorService executor = Executors.newSingleThreadExecutor();

    private CompiledGraph compiledGraph;

    public GraphProcess(CompiledGraph compiledGraph) {
        this.compiledGraph = compiledGraph;
    }

    public void processStream(AsyncGenerator<NodeOutput> generator, Sinks.Many<ServerSentEvent<String>> sink) {
        executor.submit(() -> {
            generator.forEachAsync(output -> {
                try {
                    logger.info("output = {}", output);
                    String nodeName = output.node();
                    String content;
                    if (output instanceof StreamingOutput streamingOutput) {
                        // Streaming chunk from an LLM node: {"<node>": "<chunk>"}
                        content = JSON.toJSONString(Map.of(nodeName, streamingOutput.chunk()));
                    } else {
                        // Regular node output: node name plus a snapshot of the state data
                        JSONObject nodeOutput = new JSONObject();
                        nodeOutput.put("data", output.state().data());
                        nodeOutput.put("node", nodeName);
                        content = JSON.toJSONString(nodeOutput);
                    }
                    sink.tryEmitNext(ServerSentEvent.builder(content).build());
                } catch (Exception e) {
                    throw new CompletionException(e);
                }
            }).thenAccept(v -> {
                // Normal completion
                sink.tryEmitComplete();
            }).exceptionally(e -> {
                sink.tryEmitError(e);
                return null;
            });
        });
    }
}
```
8. Run and Test
- Start the application: run your Spring Boot main class.
- Test with a browser or an API tool (such as Postman or curl).
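Besides a browser, Postman, or curl, the endpoint can be called from plain Java with java.net.http.HttpClient. A minimal sketch, assuming the app runs on localhost:8080; the non-ASCII query must be percent-encoded (browsers do this automatically), and ExpandTranslateClient and buildUri are hypothetical names:

```java
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Sketch of a Java client for the example endpoint; assumes the app runs on localhost:8080.
public class ExpandTranslateClient {

    // Builds the request URI, percent-encoding each parameter value as UTF-8.
    static URI buildUri(String base, Map<String, String> params) {
        String query = params.entrySet().stream()
                .map(e -> e.getKey() + "=" + URLEncoder.encode(e.getValue(), StandardCharsets.UTF_8))
                .collect(Collectors.joining("&"));
        return URI.create(base + "?" + query);
    }

    public static void main(String[] args) throws Exception {
        Map<String, String> params = new LinkedHashMap<>();
        params.put("query", "你好,很高兴认识你,能简单介绍一下自己吗?");
        params.put("expander_number", "3");
        params.put("translate_language", "english");
        params.put("thread_id", "test123");
        URI uri = buildUri("http://localhost:8080/graph/parallel-stream/expand-translate", params);

        HttpRequest request = HttpRequest.newBuilder(uri)
                .header("Accept", "text/event-stream")
                .build();
        // ofLines() exposes the body as a lazily consumed stream of lines,
        // so events are printed as the server sends them.
        HttpResponse<Stream<String>> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofLines());
        response.body().forEach(System.out::println);
    }
}
```

URLEncoder encodes spaces as "+", which servers decode back to spaces in query strings.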
Test: parallel streaming
Open the following URL to trigger the parallel streaming flow:
http://localhost:8080/graph/parallel-stream/expand-translate?query=你好,很高兴认识你,能简单介绍一下自己吗?&expander_number=3&translate_language=english&thread_id=test123
Expected response (streaming):
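You will receive a streaming response in Server-Sent Events (SSE) format containing:
- Streamed output from the expander node (several query variants)
- Streamed output from the translate node (the translated text)
- The final result of the merge node (the complete expansion and translation results)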
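To post-process the stream, the data payloads have to be extracted from the raw text/event-stream body. A minimal parser following the SSE rules for data lines (one leading space stripped, multiple data lines joined with \n, an event dispatched at a blank line, an unterminated final event discarded); SseParser is a hypothetical helper and ignores event:, id:, and retry: fields:

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: extracts event payloads from a raw text/event-stream body.
public class SseParser {

    static List<String> parse(String body) {
        List<String> events = new ArrayList<>();
        StringBuilder current = null;
        for (String line : body.split("\\R", -1)) {
            if (line.isEmpty()) {
                // A blank line dispatches the pending event, if any.
                if (current != null) {
                    events.add(current.toString());
                    current = null;
                }
            } else if (line.startsWith("data:")) {
                String value = line.substring(5);
                if (value.startsWith(" ")) {
                    value = value.substring(1);  // the spec strips one leading space
                }
                if (current == null) {
                    current = new StringBuilder(value);
                } else {
                    current.append('\n').append(value);
                }
            }
            // Other fields (event:, id:, retry:) and comments are ignored in this sketch.
        }
        // An event not terminated by a blank line is not dispatched, as in the spec.
        return events;
    }
}
```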
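9. Implementation Approach and Extension Ideas
Implementation approach
The core idea of this example is "parallel processing with streaming output". The flow achieves this as follows:
- Parallel execution: the expander and translate nodes do not depend on each other, so they run concurrently and the total latency is roughly that of the slower node rather than the sum of both.
- Status tracking: NodeStatus records each node's execution status, so merging happens only after all nodes have completed.
- Streaming output: StreamingChatGenerator streams each AI model call in real time, improving the user experience.
- Result merging: the merge node waits until all parallel nodes have completed and then processes their results together.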
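Extension ideas
- Add more parallel nodes: depending on business needs, add further parallel nodes such as summarization or sentiment analysis.
- Dynamic node selection: choose which combination of parallel nodes to run based on the input or on user configuration.
- Error handling and retries: give parallel nodes proper error handling and retry mechanisms to make the system more robust.
- Performance monitoring: measure each node's execution time to find bottlenecks and optimize the flow.
- Result caching: cache results for repeated requests with the same input to avoid recomputation.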
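For the caching idea, here is a minimal sketch that caches the final merged result per (query, expander_number, translate_language). MergeResultCache and its Key record are hypothetical, and the compute function stands in for running the graph:

```java
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

// Hypothetical result cache keyed by the inputs that determine the merged output.
public class MergeResultCache {

    record Key(String query, int expanderNumber, String translateLanguage) {}

    private final Map<Key, Map<String, Object>> cache = new ConcurrentHashMap<>();
    private final Function<Key, Map<String, Object>> compute;

    MergeResultCache(Function<Key, Map<String, Object>> compute) {
        this.compute = Objects.requireNonNull(compute);
    }

    // computeIfAbsent runs the expensive flow at most once per distinct key.
    Map<String, Object> get(String query, int expanderNumber, String translateLanguage) {
        // Normalizing the language avoids separate entries for "english" and "English".
        Key key = new Key(query, expanderNumber, translateLanguage.toLowerCase(Locale.ROOT));
        return cache.computeIfAbsent(key, compute);
    }
}
```

Two caveats: a cache hit returns only the final result, without the token-by-token stream, and the ConcurrentHashMap documentation advises keeping computeIfAbsent computations short because other updates may block while one is in progress, so a long model call is better cached as a CompletableFuture value.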
