Spring AI 系列之三十三 - Spring AI Alibaba-Graph框架之人类反馈
之前做个几个大模型的应用,都是使用Python语言,后来有一个项目使用了Java,并使用了Spring AI框架。随着Spring AI不断地完善,最近它发布了1.0正式版,意味着它已经能很好的作为企业级生产环境的使用。对于Java开发者来说真是一个福音,其功能已经能满足基于大模型开发企业级应用。借着这次机会,给大家分享一下Spring AI框架。
注意:由于框架不同版本改造会有些使用的不同,因此本次系列中使用基本框架是 Spring AI-1.0.0,JDK版本使用的是19,Spring-AI-Alibaba-1.0.0.3-SNAPSHOT。
代码参考: https://github.com/forever1986/springai-study
目录
- 1 人类反馈示例
- 1.1 初始化
- 1.2 创建3个节点和边
- 1.3 构建图和演示接口
- 1.4 演示
- 2 底层原理
- 2.1 初始化流程
- 2.2 执行流程
- 2.3 NodeAction 和 EdgeAction
- 3 回顾设计思路
上一章讲解了关于Spring AI Alibaba-Graph框架的基本入门,这里并未展现Spring AI Alibaba-Graph框架的强大之处。这一章通过一个更为复杂的示例来说明Spring AI Alibaba-Graph的支持交互式工作流。
1 人类反馈示例
代码参考lesson26子模块下的graph-human-feedback子模块
示例说明:在实际业务场景中,经常会遇到人类介入的场景,人类的不同操作将影响工作流不同的走向。以下实现一个简单案例:包含三个节点,扩展节点、人类节点、翻译节点
- 扩展节点:AI 模型流式对问题进行扩展输出
- 人类节点:通过对用户的反馈,决定是直接结束,还是接着执行翻译节点。决定参数为feedback,它是true时,进行翻译,false时则直接结束
- 翻译节点:将问题翻译为其他英文
注意:为了演示Graph框架是一个独立的模块,这里使用智谱聊天大模型GLM-4-Flash-250414,而非阿里的千问模型,也就是说它并没有一定和dashscope 模块强依赖
说明:上图就是代码打印的PlantUML格式,将其新建一个PlantUML文件展示的效果
1.1 初始化
1)在lesson26子模块下,新建graph-human-feedback子模块,其pom引入如下:
<dependencies><!-- 引入智谱的model插件 --><dependency><groupId>org.springframework.ai</groupId><artifactId>spring-ai-starter-model-zhipuai</artifactId></dependency><dependency><groupId>com.alibaba.cloud.ai</groupId><artifactId>spring-ai-alibaba-graph-core</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><!-- 需要引入gson插件 --><dependency><groupId>com.google.code.gson</groupId><artifactId>gson</artifactId><version>2.8.6</version></dependency>
</dependencies>
2)新建application.properties配置文件
# 聊天模型
spring.ai.zhipuai.api-key=你的智谱API KEY
spring.ai.zhipuai.chat.options.model=GLM-4-Flash-250414
spring.ai.zhipuai.chat.options.temperature=0.7
1.2 创建3个节点和边
1)ExpanderNode节点:将用户问题通过多角度生成多个问题
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import reactor.core.publisher.Flux;import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;/*** 自定义的ExpanderNode节点:将用户的问题,*/
public class ExpanderNode implements NodeAction {private static final PromptTemplate DEFAULTPROMPTTEMPLATE = new PromptTemplate("""您是信息检索和搜索优化方面的专家。您的任务是生成给定查询的 {number} 种不同版本。每个变体都必须涵盖该主题的不同视角或方面,同时保持原始查询的核心意图。其目的是扩大搜索范围,并提高找到相关信息的可能性。请勿解释您的选择或添加任何其他文字。请将查询变体以换行的方式分隔展示。原始查询:{query}查询变体:""");private final ChatClient chatClient;private final Integer NUMBER = 3;public ExpanderNode(ChatClient.Builder chatClientBuilder) {this.chatClient = chatClientBuilder.build();}@Overridepublic Map<String, Object> apply(OverAllState state) throws Exception {String query = state.value("query", "");Integer expanderNumber = state.value("expandernumber", this.NUMBER);Flux<String> streamResult = this.chatClient.prompt().user((user) -> user.text(DEFAULTPROMPTTEMPLATE.getTemplate()).param("number", expanderNumber).param("query", query)).stream().content();String result = streamResult.reduce("", (acc, item) -> acc + item).block();List<String> queryVariants = Arrays.asList(result.split("\n"));HashMap<String, Object> resultMap = new HashMap<>();resultMap.put("expandercontent", queryVariants);return resultMap;}
}
2)HumanFeedbackNode节点:实现人类反馈,进行不同跳转
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.NodeAction;import java.util.HashMap;
import java.util.Map;public class HumanFeedbackNode implements NodeAction {@Overridepublic Map<String, Object> apply(OverAllState state) {System.out.println("humanfeedback node is running.");HashMap<String, Object> resultMap = new HashMap<>();String nextStep = StateGraph.END;// 获取OverAllState中humanFeedback参数的值Map<String, Object> feedBackData = state.humanFeedback().data();// 判断如果是true,则将humannextnode设置为TranslateNode节点的ID,如果是flase,则将humannextnode设置为END节点boolean feedback = (boolean) feedBackData.getOrDefault("feedback", true);if (feedback) {nextStep = "translate";}resultMap.put("humannextnode", nextStep);System.out.println("humanfeedback node -> "+ nextStep+" node");return resultMap;}
}
3)TranslateNode节点:将用户问题翻译为其它语言
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import reactor.core.publisher.Flux;import java.util.HashMap;
import java.util.Map;public class TranslateNode implements NodeAction {private static final PromptTemplate DEFAULTPROMPTTEMPLATE = new PromptTemplate("""对于用户输入的查询,将其翻译成 {targetLanguage}。如果查询已经是 {targetLanguage} 的形式,则无需更改,直接返回。如果不知道查询的语言,则也无需更改。请勿添加解释或任何其他文字。原始查询:{query}原始查询:""");private final ChatClient chatClient;private final String TARGETLANGUAGE= "English"; // 默认英语public TranslateNode(ChatClient.Builder chatClientBuilder) {this.chatClient = chatClientBuilder.build();}@Overridepublic Map<String, Object> apply(OverAllState state) {String query = state.value("query", "");String targetLanguage = state.value("translatelanguage", TARGETLANGUAGE);Flux<String> streamResult = this.chatClient.prompt().user((user) -> user.text(DEFAULTPROMPTTEMPLATE.getTemplate()).param("targetLanguage", targetLanguage).param("query", query)).stream().content();String result = streamResult.reduce("", (acc, item) -> acc + item).block();HashMap<String, Object> resultMap = new HashMap<>();resultMap.put("translatecontent", result);return resultMap;}
}
4)定义HumanFeedbackEdge边:
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.EdgeAction;/*** 条件边*/
public class HumanFeedbackEdge implements EdgeAction {@Overridepublic String apply(OverAllState state) throws Exception {// 获取OverAllState的key=humannextnode的值,这个值在该边的上一个节点HumanFeedbackNode中设置return (String) state.value("humannextnode", StateGraph.END);}}
1.3 构建图和演示接口
1)新建GraphHumanConfiguration 配置类,构建图
import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.demo.lesson26.human.feedback.edge.HumanFeedbackEdge;
import com.demo.lesson26.human.feedback.node.ExpanderNode;
import com.demo.lesson26.human.feedback.node.HumanFeedbackNode;
import com.demo.lesson26.human.feedback.node.TranslateNode;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;import java.util.HashMap;
import java.util.Map;@Configuration
public class GraphHumanConfiguration {@Beanpublic StateGraph humanGraph(ChatClient.Builder chatClientBuilder) throws GraphStateException {// 全局变量的替换策略(ReplaceStrategy为替换,AppendStrategy为追加)KeyStrategyFactory keyStrategyFactory = () -> {HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();// 用户输入keyStrategyHashMap.put("query", new ReplaceStrategy());keyStrategyHashMap.put("threadid", new ReplaceStrategy());keyStrategyHashMap.put("expandernumber", new ReplaceStrategy());keyStrategyHashMap.put("expandercontent", new ReplaceStrategy());// 人类反馈keyStrategyHashMap.put("feedback", new ReplaceStrategy());keyStrategyHashMap.put("humannextnode", new ReplaceStrategy());// 是否需要翻译keyStrategyHashMap.put("translatelanguage", new ReplaceStrategy());keyStrategyHashMap.put("translatecontent", new ReplaceStrategy());return keyStrategyHashMap;};// 构造图StateGraph stateGraph = new StateGraph(keyStrategyFactory)// 节点ExpanderNode.addNode("expander", AsyncNodeAction.node_async(new ExpanderNode(chatClientBuilder)))// 节点TranslateNode.addNode("translate", AsyncNodeAction.node_async(new TranslateNode(chatClientBuilder)))// 节点HumanFeedbackNode.addNode("humanfeedback", AsyncNodeAction.node_async(new HumanFeedbackNode()))// 边:START -> ExpanderNode.addEdge(StateGraph.START, "expander")// 边:ExpanderNode -> HumanFeedbackNode.addEdge("expander", "humanfeedback")// 条件边:参数humanfeedback为true,则HumanFeedbackNode -> TranslateNode; 否则HumanFeedbackNode -> END.addConditionalEdges("humanfeedback", AsyncEdgeAction.edge_async((new HumanFeedbackEdge())), Map.of("translate", "translate", StateGraph.END, StateGraph.END))// 边:TranslateNode -> END.addEdge("translate", StateGraph.END);// 将图打印出来,可以使用 PlantUML 插件查看GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML,"human flow");System.out.println("\n=== expander UML Flow ===");System.out.println(representation.content());System.out.println("==================================\n");return stateGraph;}
}
2)新建GraphHumanController访问示例:这里通过2个接口模拟人类反馈,第二个接口会重新加载之前的会话状态
import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.constant.SaverConstant;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import com.alibaba.cloud.ai.graph.exception.GraphRunnerException;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.StateSnapshot;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;import java.util.HashMap;
import java.util.Map;
import java.util.Optional;@RestController
public class GraphHumanController {private final CompiledGraph compiledGraph;@Autowiredpublic GraphHumanController(@Qualifier("humanGraph") StateGraph stateGraph) throws GraphStateException {SaverConfig saverConfig = SaverConfig.builder().register(SaverConstant.MEMORY, new MemorySaver()).build();this.compiledGraph = stateGraph.compile(CompileConfig.builder().saverConfig(saverConfig).interruptBefore("humanfeedback") // 关键点:在humanfeedback节点前打断流程.build());}@GetMapping("/graph/human/expand")public Map<String, Object> expand(@RequestParam(value = "query", defaultValue = "你好,很高兴认识你,能简单介绍一下自己吗?", required = false) String query,@RequestParam(value = "expandernumber", defaultValue = "3", required = false) Integer expanderNumber,@RequestParam(value = "threadid", defaultValue = "1", required = false) String threadId) throws GraphRunnerException {RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();Map<String, Object> objectMap = new HashMap<>();objectMap.put("query", query);objectMap.put("expandernumber", expanderNumber);Optional<OverAllState> invoke = compiledGraph.invoke(objectMap,runnableConfig);return invoke.map(OverAllState::data).orElse(new HashMap<>());}@GetMapping("/graph/human/resume")public Map<String, Object> resume(@RequestParam(value = "threadid", defaultValue = "1", required = false) String threadId,@RequestParam(value = "feedback", defaultValue = "true", required = false) boolean feedBack) throws GraphRunnerException {RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();// 重新加载threadid=1的stateSnapshotStateSnapshot stateSnapshot = this.compiledGraph.getState(runnableConfig);OverAllState state = stateSnapshot.state();state.withResume();// 设置resume标志,表示从snapshot开始继续// 添加新的参数feedbackMap<String, Object> objectMap = new HashMap<>();objectMap.put("feedback", feedBack);state.withHumanFeedback(new OverAllState.HumanFeedback(objectMap, ""));// 调用执行,入参是从snapshot中重新加载的OverAllState,并且添加了feedback参数Optional<OverAllState> invoke = compiledGraph.invoke(state,runnableConfig);return invoke.map(OverAllState::data).orElse(new HashMap<>());}}
3)新建Lesson26HumanFeedbackApplication 启动类:
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;@SpringBootApplication
public class Lesson26HumanFeedbackApplication {public static void main(String[] args) {SpringApplication.run(Lesson26HumanFeedbackApplication.class, args);}}
1.4 演示
1)访问地址:http://localhost:8080/graph/human/expand
2)访问地址:http://localhost:8080/graph/human/resume?feedback=true
说明:传入feedback=true,true表示需要翻译。当你改为false,则不会输出translatecontent内容。
但是如果你先执行了true,那么再次访问/resume,则每次返回都是一样的,这是因为/graph/human/resume的代码里每次都是加载snapshot=1的工作流数据,它已经执行结束了。所以你要重新演示false的场景,需要重启服务器或者将threadid改为其它的执行一遍/expand。
2 底层原理
通过上面的示例,来看看Graph的底层原理
2.1 初始化流程
1)如何构建Graph,当调用:stateGraph.compile()方法,可以进去源码看看,源码如下,可以看到最终是新建一个CompiledGraph类
2)进入CompiledGraph类的构造方法,可以看到其将配置转换,并保存节点、边还有中断配置数据,这时候一个CompiledGraph就已经构建完成
2.2 执行流程
1)先从compiledGraph.invoke()方法进入,可以看到通过stream方法构建执行流程
2)看一下stream方法,其将节点都变成一个个AsyncNodeGenerator
下图这个是有snapshot的初始化的情况,也是转换为AsyncNodeGenerator
3)再来看看AsyncNodeGenerator,其最终是next()方法进行执行,其中有两处跟示例相关,一个是判断是否有终止配置,一个是通过evaluateAction()方法执行节点
3)进入evaluateAction方法看看,可以看到其调用的是apply()方法,并更新state并返回数据。同时通过nextNodeId()方法获取下一个执行节点id。
4)进入nextNodeId()方法,可以看到其流程是获取到边,并执行EdgeAction的apply()方法获取到下一个节点id
2.3 NodeAction 和 EdgeAction
1)从上面可以看到每个Node最终执行的是NodeAction的apply方法,该方法返回的数据会加入到OverAllState 中
@FunctionalInterface
public interface NodeAction {Map<String, Object> apply(OverAllState state) throws Exception;}
2)从上面可以看到,EdgeAction决定的是下一个执行节点,因此apply方法返回的是下一个节点id
@FunctionalInterface
public interface EdgeAction {/*** Applies this action to the given agent state.* @param state the agent state* @return a result of the action* @throws Exception if an error occurs during the action*/String apply(OverAllState state) throws Exception;
}
3 回顾设计思路
从上面知道底层原理,设计示例的思路就比较明朗。首先有ExpanderNode和TranslateNode节点是必需的,用于处理真正业务。而需要增加一个HumanFeedbackNode节点用于通过参数设置下一个决定边如何走向。
1)第一个关键点:设置HumanFeedbackNode之前的需要中断
2)第二个关键点:反馈收到后,重启原先的流程snapshot,将反馈放入到state中
结语:本章通过一个更为复杂的示例,演示了Spring AI Alibaba-Graph框架的强大之处,最后分析了Graph的底层原理,让用户可以了解Graph的工作方式,其中着重了解了中断机制以及NodeAction 和 EdgeAction 的作用。下一章还将继续探索Graph的其它功能。
Spring AI系列上一章:《Spring AI 系列之三十二 - Spring AI Alibaba-Graph框架之入门》
Spring AI系列下一章:《Spring AI 系列之三十四 - Spring AI Alibaba-Graph框架之并行执行》