使用 Spring AI Alibaba Graph 实现工作流
1 依赖
<!-- Spring AI Alibaba DashScope starter, followed by the Graph core library used to build workflows -->
<dependency><groupId>com.alibaba.cloud.ai</groupId><artifactId>spring-ai-alibaba-starter-dashscope</artifactId><version>1.0.0.2</version>
</dependency><dependency><groupId>com.alibaba.cloud.ai</groupId><artifactId>spring-ai-alibaba-graph-core</artifactId><version>1.0.0.2</version>
</dependency><!-- Integrates large models that expose an OpenAI-compatible API. NOTE(review): no version is declared here; presumably it is managed by a BOM or parent POM (confirm). --><dependency><groupId>org.springframework.ai</groupId><artifactId>spring-ai-openai</artifactId></dependency>
2 实战
以实现一个简单的模型调用工作流为例,阐述工作流实现步骤。
2.1 实现业务节点 (NodeAction)
通用模型调用节点。
@Slf4j
public class CommonModelCallNode implements NodeAction {private ChatClient chatClient;public CommonModelCallNode(ChatClient chatClient) {this.chatClient = chatClient;}@Overridepublic Map<String, Object> apply(OverAllState state) {// 调用AI模型,生成查询String result = null;try {result = chatClient.prompt().user(u -> u.text(state.value("query", ""))).call().content();} catch (Exception e) {log.error("CommonModelCallNode error", e);HashMap<String, Object> resultMap = new HashMap<>();resultMap.put("errorMsg", e.getMessage());return resultMap;}// 将结果放入Map,该Map会根据状态策略合并到OverAllState中HashMap<String, Object> resultMap = new HashMap<>();resultMap.put("commonModelCallResult", result);return resultMap;}
}
2.2 创建工作流 (StateGraph)
创建工作流时,需要定义状态策略。
@Configuration
public class GraphConfig {@Resourceprivate ChatClient chatClient4VolcesDoubaoV1;/*** 定义状态策略* ps:定义在全局状态 OverAllState 中,各个字段的更新策略。*/@Bean("overAllStateFactory4Demo1")public OverAllStateFactory overAllStateFactory4Demo1() {return () -> {OverAllState state = new OverAllState();// 入参必须在OverAllState中注册,否则无法传递给NodeActionstate.registerKeyAndStrategy("query", new ReplaceStrategy()); // 新值替换旧值state.registerKeyAndStrategy("commonModelCallResult", new ReplaceStrategy());state.registerKeyAndStrategy("errorMsg", new ReplaceStrategy());return state;};}/*** 创建工作流*/@Bean("stateGraph4Demo1")public StateGraph stateGraph4Demo1(@Qualifier("overAllStateFactory4Demo1") OverAllStateFactory factory) throws GraphStateException {return new StateGraph("stateGraph4Demo1", factory).addNode("commonModelCallNode", node_async(new CommonModelCallNode(chatClient4VolcesDoubaoV1))) // 添加业务节点.addEdge(StateGraph.START, "commonModelCallNode") // 从开始节点连接到业务节点.addEdge("commonModelCallNode", StateGraph.END); // 从业务节点连接到结束节点}}
2.3 编译工作流并执行
创建一个 HTTP 接口,用于编译并执行工作流。
@Slf4j
@RestController
@RequestMapping("/graph")
public class SimpleGraphController {@Resourceprivate StateGraph stateGraph4Demo1;@GetMapping("/demo1")@ResponseBodypublic BaseResponse<Map<String, Object>> demo1(@RequestParam(name = "query") String query,@RequestParam(name = "threadId", defaultValue = "test-thread") String threadId) {long startTime = System.currentTimeMillis();log.info("======>>> SimpleGraphController demo1 start. query: {}", query);if (StringUtils.isBlank(query)) {return BaseResponse.error("query 不能为空");}// 准备初始输入参数Map<String, Object> inputMap = new HashMap<>();inputMap.put("query", query);// 配置执行参数,如线程ID(用于会话追踪)RunnableConfig config = RunnableConfig.builder().threadId(threadId).build();// 执行Graph并获取最终状态CompiledGraph compiledGraph;try {compiledGraph = stateGraph4Demo1.compile();} catch (GraphStateException e) {log.error("Graph compile error: {}", e.getMessage());return BaseResponse.error("Graph compile error");}Optional<OverAllState> resultState = compiledGraph.invoke(inputMap, config);Map<String, Object> resultMap = resultState.map(OverAllState::data).orElse(new HashMap<>());log.info("======>>> SimpleGraphController demo1 end. cost:{}ms, query:{}", System.currentTimeMillis() - startTime, query);return BaseResponse.success(resultMap);}
}
2.4 打印工作流
可以将工作流打印为流程图,支持 PlantUML 和 Mermaid 两种图形描述语言。
/**
 * Prints the demo1 workflow as PlantUML and Mermaid diagrams and checks that both
 * renderings produce content.
 */
public class GraphPrintTest extends BaseTest {

    @Resource
    private StateGraph stateGraph4Demo1;

    /** Renders the graph in both supported diagram formats and prints the result to stdout. */
    @Test
    public void testGraphPrint() {
        GraphRepresentation plantUml =
                stateGraph4Demo1.getGraph(GraphRepresentation.Type.PLANTUML, "Demo1工作流", true);
        GraphRepresentation mermaid =
                stateGraph4Demo1.getGraph(GraphRepresentation.Type.MERMAID, "Demo1工作流", true);

        System.out.println(plantUml.content());
        System.out.println("=========================");
        System.out.println(mermaid.content());

        // Assert on the rendered output instead of a vacuous assertTrue(true).
        Assert.assertNotNull("PlantUML content should not be null", plantUml.content());
        Assert.assertNotNull("Mermaid content should not be null", mermaid.content());
    }
}

/**
 * Base class for tests that need the Spring Boot application context of
 * {@code SpringAiDemoApplication}.
 */
@RunWith(SpringRunner.class)
@SpringBootTest(classes = SpringAiDemoApplication.class)
public class BaseTest {
}
打印的流程图如下所示: