A Look at PlantUMLGenerator in Spring AI Alibaba
Preface
This article takes a look at PlantUMLGenerator in Spring AI Alibaba.
DiagramGenerator
spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/DiagramGenerator.java
public abstract class DiagramGenerator {

    public enum CallStyle {

        DEFAULT, START, END, CONDITIONAL, PARALLEL

    }

    public record Context(StringBuilder sb, String title, boolean printConditionalEdge, boolean isSubGraph) {

        static Builder builder() {
            return new Builder();
        }

        static public class Builder {

            String title;

            boolean printConditionalEdge;

            boolean IsSubGraph;

            private Builder() {
            }

            public Builder title(String title) {
                this.title = title;
                return this;
            }

            public Builder printConditionalEdge(boolean value) {
                this.printConditionalEdge = value;
                return this;
            }

            public Builder isSubGraph(boolean value) {
                this.IsSubGraph = value;
                return this;
            }

            public Context build() {
                return new Context(new StringBuilder(), title, printConditionalEdge, IsSubGraph);
            }

        }

        /**
         * Converts a given title string to snake_case format by replacing all
         * non-alphanumeric characters with underscores.
         * @return the snake_case formatted string
         */
        public String titleToSnakeCase() {
            return title.replaceAll("[^a-zA-Z0-9]", "_");
        }

        /**
         * Returns a string representation of this object by returning the string built in
         * {@link #sb}.
         * @return a string representation of this object.
         */
        @Override
        public String toString() {
            return sb.toString();
        }

    }

    /**
     * Appends a header to the output based on the provided context.
     * @param ctx The {@link Context} containing the information needed for appending the
     * header.
     */
    protected abstract void appendHeader(Context ctx);

    /**
     * Appends a footer to the content.
     * @param ctx Context object containing the necessary information.
     */
    protected abstract void appendFooter(Context ctx);

    /**
     * This method is an abstract method that must be implemented by subclasses. It is
     * used to initiate a communication call between two parties identified by their phone
     * numbers.
     * @param ctx The current context in which the call is being made.
     * @param from The phone number of the caller.
     * @param to The phone number of the recipient.
     */
    protected abstract void call(Context ctx, String from, String to, CallStyle style);

    /**
     * Abstract method that must be implemented by subclasses to handle the logic of
     * making a call.
     * @param ctx The context in which the call is being made.
     * @param from The phone number of the caller.
     * @param to The phone number of the recipient.
     * @param description A brief description of the call.
     */
    protected abstract void call(Context ctx, String from, String to, String description, CallStyle style);

    /**
     * Declares a conditional element in the configuration or template. This method is
     * used to mark the start of a conditional section based on the provided {@code name}.
     * It takes a {@code Context} object that may contain additional parameters necessary
     * for the declaration, and a {@code name} which identifies the type or key associated
     * with the conditional section.
     * @param ctx The context containing contextual information needed for the
     * declaration.
     * @param name The name of the conditional section to be declared.
     */
    protected abstract void declareConditionalStart(Context ctx, String name);

    /**
     * Declares a node in the specified context with the given name.
     * @param ctx the context in which to declare the node {@code @literal (not null)}
     * @param name the name of the node to be declared
     * {@code @literal (not null, not empty)}
     */
    protected abstract void declareNode(Context ctx, String name);

    /**
     * Declares a conditional edge in the context with a specified ordinal.
     * @param ctx the context
     * @param ordinal the ordinal value
     */
    protected abstract void declareConditionalEdge(Context ctx, int ordinal);

    /**
     * Comment a line in the given context.
     * @param ctx The context in which the line is to be commented.
     * @param yesOrNo Whether the line should be uncommented ({@literal true}) or
     * commented ({@literal false}).
     */
    protected abstract void commentLine(Context ctx, boolean yesOrNo);

    /**
     * Generate a textual representation of the given graph.
     * @param nodes the state graph nodes used to generate the context, which must not be
     * null
     * @param edges the state graph edges used to generate the context, which must not be
     * null
     * @param title The title of the graph.
     * @param printConditionalEdge Whether to print the conditional edge condition.
     * @return A string representation of the graph.
     */
    public final String generate(StateGraph.Nodes nodes, StateGraph.Edges edges, String title,
            boolean printConditionalEdge) {
        return generate(nodes, edges,
                Context.builder()
                    .title(title)
                    .isSubGraph(false)
                    .printConditionalEdge(printConditionalEdge)
                    .build())
            .toString();
    }

    /**
     * Generates a context based on the given state graph.
     * @param nodes the state graph nodes used to generate the context, which must not be
     * null
     * @param edges the state graph edges used to generate the context, which must not be
     * null
     * @param ctx the initial context, which must not be null
     * @return the generated context, which will not be null
     */
    protected final Context generate(StateGraph.Nodes nodes, StateGraph.Edges edges, Context ctx) {

        appendHeader(ctx);

        for (var n : nodes.elements) {

            if (n instanceof SubGraphNode subGraphNode) {

                @SuppressWarnings("unchecked")
                var subGraph = (StateGraph) subGraphNode.subGraph();
                Context subgraphCtx = generate(subGraph.nodes, subGraph.edges,
                        Context.builder()
                            .title(n.id())
                            .printConditionalEdge(ctx.printConditionalEdge)
                            .isSubGraph(true)
                            .build());
                ctx.sb().append(subgraphCtx);
            }
            else {
                declareNode(ctx, n.id());
            }
        }

        final int[] conditionalEdgeCount = { 0 };

        edges.elements.stream()
            .filter(e -> !Objects.equals(e.sourceId(), START))
            .filter(e -> !e.isParallel())
            .forEach(e -> {
                if (e.target().value() != null) {
                    conditionalEdgeCount[0] += 1;
                    commentLine(ctx, !ctx.printConditionalEdge());
                    declareConditionalEdge(ctx, conditionalEdgeCount[0]);
                }
            });

        var edgeStart = edges.elements.stream()
            .filter(e -> Objects.equals(e.sourceId(), START))
            .findFirst()
            .orElseThrow();
        if (edgeStart.isParallel()) {
            edgeStart.targets().forEach(target -> {
                call(ctx, START, target.id(), CallStyle.START);
            });
        }
        else if (edgeStart.target().id() != null) {
            call(ctx, START, edgeStart.target().id(), CallStyle.START);
        }
        else if (edgeStart.target().value() != null) {
            String conditionName = "startcondition";
            commentLine(ctx, !ctx.printConditionalEdge());
            declareConditionalStart(ctx, conditionName);
            edgeCondition(ctx, edgeStart.target().value(), START, conditionName);
        }

        conditionalEdgeCount[0] = 0; // reset

        edges.elements.stream().filter(e -> !Objects.equals(e.sourceId(), START)).forEach(v -> {
            if (v.isParallel()) {
                v.targets().forEach(target -> {
                    call(ctx, v.sourceId(), target.id(), CallStyle.PARALLEL);
                });
            }
            else if (v.target().id() != null) {
                call(ctx, v.sourceId(), v.target().id(), CallStyle.DEFAULT);
            }
            else if (v.target().value() != null) {
                conditionalEdgeCount[0] += 1;
                String conditionName = format("condition%d", conditionalEdgeCount[0]);
                edgeCondition(ctx, v.targets().get(0).value(), v.sourceId(), conditionName);
            }
        });

        appendFooter(ctx);

        return ctx;
    }

    /**
     * Evaluates an edge condition based on the given context and condition.
     * @param ctx the current context used for evaluation
     * @param condition the condition to be evaluated
     * @param k a string identifier for the condition
     * @param conditionName the name of the condition being processed
     */
    private void edgeCondition(Context ctx, EdgeCondition condition, String k, String conditionName) {
        commentLine(ctx, !ctx.printConditionalEdge());
        call(ctx, k, conditionName, CallStyle.CONDITIONAL);

        condition.mappings().forEach((cond, to) -> {
            commentLine(ctx, !ctx.printConditionalEdge());
            call(ctx, conditionName, to, cond, CallStyle.CONDITIONAL);
            commentLine(ctx, ctx.printConditionalEdge());
            call(ctx, k, to, cond, CallStyle.CONDITIONAL);
        });
    }

}
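DiagramGenerator is an abstract base class for diagram generators. It declares the abstract methods appendHeader, appendFooter, call, declareConditionalStart, declareNode, declareConditionalEdge, and commentLine, and provides the generate method, which builds a textual representation of the graph from nodes, edges, and ctx. generate always works in the same order: it appends the header, declares every node (a SubGraphNode is rendered recursively with isSubGraph set to true), declares a condition element for each conditional edge that does not start at START, renders the START edge, renders the remaining edges, and finally appends the footer. Despite the upstream Javadoc's talk of phone numbers, the from and to arguments of call are node ids or condition names.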
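The structure above is the template method pattern: the base class fixes the traversal order and subclasses only supply syntax. The following is a minimal standalone sketch of that idea, not the library's code. MiniDiagramGenerator, MiniEdge, and ArrowTextGenerator are hypothetical names, the nodes are plain string ids, and conditional edges and subgraphs are left out.

```java
import java.util.List;

// Hypothetical edge type for this sketch: a plain from -> to pair.
record MiniEdge(String from, String to) {}

// Hypothetical base class mirroring DiagramGenerator's shape (not the real API).
abstract class MiniDiagramGenerator {

    // Syntax-specific hooks that concrete generators must provide.
    protected abstract void appendHeader(StringBuilder sb, String title);

    protected abstract void declareNode(StringBuilder sb, String name);

    protected abstract void call(StringBuilder sb, String from, String to);

    protected abstract void appendFooter(StringBuilder sb);

    // The order is fixed here, as in DiagramGenerator.generate:
    // header, node declarations, edges, footer.
    public final String generate(List<String> nodes, List<MiniEdge> edges, String title) {
        StringBuilder sb = new StringBuilder();
        appendHeader(sb, title);
        for (String node : nodes) {
            declareNode(sb, node);
        }
        for (MiniEdge edge : edges) {
            call(sb, edge.from(), edge.to());
        }
        appendFooter(sb);
        return sb.toString();
    }
}

// A toy output format, used only to show how a subclass plugs into the hooks.
class ArrowTextGenerator extends MiniDiagramGenerator {

    @Override
    protected void appendHeader(StringBuilder sb, String title) {
        sb.append("# ").append(title).append('\n');
    }

    @Override
    protected void declareNode(StringBuilder sb, String name) {
        sb.append("node ").append(name).append('\n');
    }

    @Override
    protected void call(StringBuilder sb, String from, String to) {
        sb.append(from).append(" -> ").append(to).append('\n');
    }

    @Override
    protected void appendFooter(StringBuilder sb) {
        sb.append("# end\n");
    }
}
```

Because generate is final, a new output format only has to implement the hooks. It cannot accidentally change the order in which nodes and edges are emitted.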
PlantUMLGenerator
spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/diagram/PlantUMLGenerator.java
public class PlantUMLGenerator extends DiagramGenerator {

    @Override
    protected void appendHeader(Context ctx) {

        if (ctx.isSubGraph()) {
            ctx.sb()
                .append(format("rectangle %s [ {{\ntitle \"%s\"\n", ctx.title(), ctx.title()))
                .append(format("circle \" \" as %s\n", START))
                .append(format("circle exit as %s\n", END));
        }
        else {
            ctx.sb()
                .append(format("@startuml %s\n", ctx.titleToSnakeCase()))
                .append("skinparam usecaseFontSize 14\n")
                .append("skinparam usecaseStereotypeFontSize 12\n")
                .append("skinparam hexagonFontSize 14\n")
                .append("skinparam hexagonStereotypeFontSize 12\n")
                .append(format("title \"%s\"\n", ctx.title()))
                .append("footer\n\n")
                .append("powered by spring-ai-alibaba\n")
                .append("end footer\n")
                .append(format("circle start<<input>> as %s\n", START))
                .append(format("circle stop as %s\n", END));
        }
    }

    @Override
    protected void appendFooter(Context ctx) {
        if (ctx.isSubGraph()) {
            ctx.sb().append("\n}} ]\n");
        }
        else {
            ctx.sb().append("@enduml\n");
        }
    }

    @Override
    protected void call(Context ctx, String from, String to, CallStyle style) {
        ctx.sb().append(switch (style) {
            case CONDITIONAL -> format("\"%s\" .down.> \"%s\"\n", from, to);
            default -> format("\"%s\" -down-> \"%s\"\n", from, to);
        });
    }

    @Override
    protected void call(Context ctx, String from, String to, String description, CallStyle style) {
        ctx.sb().append(switch (style) {
            case CONDITIONAL -> format("\"%s\" .down.> \"%s\": \"%s\"\n", from, to, description);
            default -> format("\"%s\" -down-> \"%s\": \"%s\"\n", from, to, description);
        });
    }

    @Override
    protected void declareConditionalStart(Context ctx, String name) {
        ctx.sb().append(format("hexagon \"check state\" as %s<<Condition>>\n", name));
    }

    @Override
    protected void declareNode(Context ctx, String name) {
        ctx.sb().append(format("usecase \"%s\"<<Node>>\n", name));
    }

    @Override
    protected void declareConditionalEdge(Context ctx, int ordinal) {
        ctx.sb().append(format("hexagon \"check state\" as condition%d<<Condition>>\n", ordinal));
    }

    @Override
    protected void commentLine(Context ctx, boolean yesOrNo) {
        if (yesOrNo)
            ctx.sb().append("'");
    }

}
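PlantUMLGenerator implements the abstract methods of DiagramGenerator using PlantUML syntax. A top-level graph is wrapped in @startuml/@enduml with a title, a footer, and start/stop circles, while a subgraph is wrapped in a rectangle block. Each node is declared as a usecase with the <<Node>> stereotype, and each condition as a hexagon labeled "check state" with the <<Condition>> stereotype. Regular, START, and PARALLEL edges are drawn as solid -down-> arrows, and conditional edges as dotted .down.> arrows. commentLine(ctx, true) prefixes the next line with ', PlantUML's single-line comment marker, so true actually means "comment out", the opposite of what the base-class Javadoc suggests. This is how printConditionalEdge works: edgeCondition emits both the route through the condition hexagon and a direct source-to-target arrow, and exactly one of the two is commented out.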
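To make the printConditionalEdge toggle concrete, the sketch below reproduces the format strings of PlantUMLGenerator together with the loop in DiagramGenerator.edgeCondition, as plain static methods. PlantUmlLines is a hypothetical helper written for illustration only; it does not exist in Spring AI Alibaba, and it takes a Map of condition label to target node in place of EdgeCondition.mappings().

```java
import java.util.Map;

// Hypothetical helper mirroring PlantUMLGenerator's output (not the library's API).
final class PlantUmlLines {

    private PlantUmlLines() {
    }

    // declareNode: every graph node becomes a usecase with the <<Node>> stereotype.
    static String node(String name) {
        return String.format("usecase \"%s\"<<Node>>\n", name);
    }

    // call(...) with a non-CONDITIONAL style: a solid arrow.
    static String edge(String from, String to) {
        return String.format("\"%s\" -down-> \"%s\"\n", from, to);
    }

    // call(...) with CONDITIONAL style and a description: a dotted, labeled arrow.
    static String conditionalEdge(String from, String to, String label) {
        return String.format("\"%s\" .down.> \"%s\": \"%s\"\n", from, to, label);
    }

    // commentLine(ctx, true) emits ', PlantUML's line-comment marker.
    static String maybeComment(boolean comment) {
        return comment ? "'" : "";
    }

    // commentLine(!print) followed by declareConditionalEdge(ordinal).
    static String conditionHexagon(int ordinal, boolean printConditionalEdge) {
        return maybeComment(!printConditionalEdge)
                + String.format("hexagon \"check state\" as condition%d<<Condition>>\n", ordinal);
    }

    // Mirrors edgeCondition(ctx, condition, k, conditionName): one arrow into the
    // hexagon, then for each mapping an arrow out of the hexagon plus a direct
    // arrow from the source. Exactly one of the two variants is commented out.
    static String condition(String source, String conditionName, Map<String, String> mappings,
            boolean printConditionalEdge) {
        StringBuilder sb = new StringBuilder();
        sb.append(maybeComment(!printConditionalEdge))
            .append(String.format("\"%s\" .down.> \"%s\"\n", source, conditionName));
        mappings.forEach((cond, to) -> {
            sb.append(maybeComment(!printConditionalEdge)).append(conditionalEdge(conditionName, to, cond));
            sb.append(maybeComment(printConditionalEdge)).append(conditionalEdge(source, to, cond));
        });
        return sb.toString();
    }
}
```

With printConditionalEdge set to true, a source agent_1 with the single mapping next → agent_2 yields an active route agent_1 → condition1 → agent_2 and a commented-out direct arrow agent_1 → agent_2. With false, the hexagon route is commented out and only the direct labeled arrow is drawn.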
StateGraph
spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/StateGraph.java
/**
 * Represents a state graph with nodes and edges.
 *
 */
public class StateGraph {

    public static String END = "__END__";

    public static String START = "__START__";

    final Nodes nodes = new Nodes();

    final Edges edges = new Edges();

    private OverAllState overAllState;

    private String name;

    public OverAllState getOverAllState() {
        return overAllState;
    }

    public StateGraph setOverAllState(OverAllState overAllState) {
        this.overAllState = overAllState;
        return this;
    }

    private final PlainTextStateSerializer stateSerializer;

    //......

    /**
     * Instantiates a new State graph.
     * @param overAllState the over all state
     * @param plainTextStateSerializer the plain text state serializer
     */
    public StateGraph(OverAllState overAllState, PlainTextStateSerializer plainTextStateSerializer) {
        this.overAllState = overAllState;
        this.stateSerializer = plainTextStateSerializer;
    }

    public StateGraph(String name, OverAllState overAllState) {
        this.name = name;
        this.overAllState = overAllState;
        this.stateSerializer = new GsonSerializer();
    }

    /**
     * Instantiates a new State graph.
     * @param overAllState the over all state
     */
    public StateGraph(OverAllState overAllState) {
        this.overAllState = overAllState;
        this.stateSerializer = new GsonSerializer();
    }

    public StateGraph(String name, AgentStateFactory<OverAllState> factory) {
        this.name = name;
        this.overAllState = factory.apply(Map.of());
        this.stateSerializer = new GsonSerializer2(factory);
    }

    public StateGraph(AgentStateFactory<OverAllState> factory) {
        this.overAllState = factory.apply(Map.of());
        this.stateSerializer = new GsonSerializer2(factory);
    }

    /**
     * Instantiates a new State graph.
     */
    public StateGraph() {
        this.stateSerializer = new GsonSerializer();
    }

    public String getName() {
        return name;
    }

    /**
     * Key strategies map.
     * @return the map
     */
    public Map<String, KeyStrategy> keyStrategies() {
        return overAllState.keyStrategies();
    }

    /**
     * Gets state serializer.
     * @return the state serializer
     */
    public StateSerializer getStateSerializer() {
        return stateSerializer;
    }

    /**
     * Gets state factory.
     * @return the state factory
     */
    public final AgentStateFactory<OverAllState> getStateFactory() {
        return stateSerializer.stateFactory();
    }

    /**
     * /** Adds a node to the graph.
     * @param id the identifier of the node
     * @param action the action to be performed by the node
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeAction action) throws GraphStateException {
        return addNode(id, AsyncNodeActionWithConfig.of(action));
    }

    /**
     * @param id the identifier of the node
     * @param actionWithConfig the action to be performed by the node
     * @return this
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeActionWithConfig actionWithConfig) throws GraphStateException {
        Node node = new Node(id, (config) -> actionWithConfig);
        return addNode(id, node);
    }

    /**
     * @param id the identifier of the node
     * @param node the node to be added
     * @return this
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, Node node) throws GraphStateException {
        if (Objects.equals(node.id(), END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        if (!Objects.equals(node.id(), id)) {
            throw Errors.invalidNodeIdentifier.exception(node.id(), id);
        }

        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }

        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that Subgraph share the same state with parent graph
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the compiled subgraph to be added
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, CompiledGraph subGraph) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }

        var node = new SubCompiledGraphNode(id, subGraph);

        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }

        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that Subgraph share the same state with parent graph
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the subgraph to be added. it will be compiled on compilation of the
     * parent
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, StateGraph subGraph) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }

        subGraph.validateGraph();
        OverAllState subGraphOverAllState = subGraph.getOverAllState();
        OverAllState superOverAllState = getOverAllState();
        if (subGraphOverAllState != null) {
            Map<String, KeyStrategy> strategies = subGraphOverAllState.keyStrategies();
            for (Map.Entry<String, KeyStrategy> strategyEntry : strategies.entrySet()) {
                if (!superOverAllState.containStrategy(strategyEntry.getKey())) {
                    superOverAllState.registerKeyAndStrategy(strategyEntry.getKey(), strategyEntry.getValue());
                }
            }
        }
        subGraph.setOverAllState(getOverAllState());

        var node = new SubStateGraphNode(id, subGraph);

        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }

        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds an edge to the graph.
     * @param sourceId the identifier of the source node
     * @param targetId the identifier of the target node
     * @throws GraphStateException if the edge identifier is invalid or the edge already
     * exists
     */
    public StateGraph addEdge(String sourceId, String targetId) throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }

        // if (Objects.equals(sourceId, START)) {
        // this.entryPoint = new EdgeValue<>(targetId);
        // return this;
        // }

        var newEdge = new Edge(sourceId, new EdgeValue(targetId));

        int index = edges.elements.indexOf(newEdge);
        if (index >= 0) {
            var newTargets = new ArrayList<>(edges.elements.get(index).targets());
            newTargets.add(newEdge.target());
            edges.elements.set(index, new Edge(sourceId, newTargets));
        }
        else {
            edges.elements.add(newEdge);
        }

        return this;
    }

    /**
     * Adds conditional edges to the graph.
     * @param sourceId the identifier of the source node
     * @param condition the condition to determine the target node
     * @param mappings the mappings of conditions to target nodes
     * @throws GraphStateException if the edge identifier is invalid, the mappings are
     * empty, or the edge already exists
     */
    public StateGraph addConditionalEdges(String sourceId, AsyncEdgeAction condition, Map<String, String> mappings)
            throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }
        if (mappings == null || mappings.isEmpty()) {
            throw Errors.edgeMappingIsEmpty.exception(sourceId);
        }

        var newEdge = new Edge(sourceId, new EdgeValue(new EdgeCondition(condition, mappings)));

        if (edges.elements.contains(newEdge)) {
            throw Errors.duplicateConditionalEdgeError.exception(sourceId);
        }
        else {
            edges.elements.add(newEdge);
        }
        return this;
    }

    void validateGraph() throws GraphStateException {
        var edgeStart = edges.edgeBySourceId(START).orElseThrow(Errors.missingEntryPoint::exception);

        edgeStart.validate(nodes);

        for (Edge edge : edges.elements) {
            edge.validate(nodes);
        }
    }

    /**
     * Compiles the state graph into a compiled graph.
     * @param config the compile configuration
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile(CompileConfig config) throws GraphStateException {
        Objects.requireNonNull(config, "config cannot be null");

        validateGraph();

        return new CompiledGraph(this, config);
    }

    /**
     * Compiles the state graph into a compiled graph.
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile() throws GraphStateException {
        SaverConfig saverConfig = SaverConfig.builder().register(SaverConstant.MEMORY, new MemorySaver()).build();
        return compile(CompileConfig.builder()
            .plainTextStateSerializer(new JacksonSerializer())
            .saverConfig(saverConfig)
            .build());
    }

    /**
     * Generates a drawable graph representation of the state graph.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @param printConditionalEdges whether to print conditional edges
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {

        String content = type.generator.generate(nodes, edges, title, printConditionalEdges);

        return new GraphRepresentation(type, content);
    }

    /**
     * Generates a drawable graph representation of the state graph.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {

        String content = type.generator.generate(nodes, edges, title, true);

        return new GraphRepresentation(type, content);
    }

    public GraphRepresentation getGraph(GraphRepresentation.Type type) {

        String content = type.generator.generate(nodes, edges, name, true);

        return new GraphRepresentation(type, content);
    }

    //......
}
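StateGraph provides methods such as addNode, addEdge, and addConditionalEdges. Its getGraph methods use the DiagramGenerator attached to the given GraphRepresentation.Type to generate the diagram text. The overloads without a printConditionalEdges argument pass true, and the single-argument overload uses the graph's name as the title. Note also that addEdge does not reject a second edge from the same source: it merges the new target into the existing Edge, which turns it into a parallel edge.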
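The merge behavior of addEdge can be modeled with a map from source id to target list. This is a simplified sketch, not the real implementation. It assumes, as the indexOf/set calls in addEdge suggest, that two Edge objects are equal when they share a sourceId, and that Edge.isParallel() means "more than one target". EdgeTable is a hypothetical name, and the sketch throws IllegalArgumentException where the real code throws GraphStateException.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified model of StateGraph.addEdge (not the library's API).
final class EdgeTable {

    static final String END = "__END__";

    // Insertion-ordered, like the list behind StateGraph.Edges.
    private final Map<String, List<String>> targetsBySource = new LinkedHashMap<>();

    EdgeTable addEdge(String sourceId, String targetId) {
        if (END.equals(sourceId)) {
            // The real addEdge throws GraphStateException here.
            throw new IllegalArgumentException("END cannot be the source of an edge");
        }
        // Same source as an existing edge: append the target instead of adding a new edge.
        targetsBySource.computeIfAbsent(sourceId, k -> new ArrayList<>()).add(targetId);
        return this;
    }

    // Assumed meaning of Edge.isParallel(): the edge has more than one target.
    boolean isParallel(String sourceId) {
        List<String> targets = targetsBySource.get(sourceId);
        return targets != null && targets.size() > 1;
    }

    List<String> targets(String sourceId) {
        return List.copyOf(targetsBySource.getOrDefault(sourceId, List.of()));
    }
}
```

For such a merged edge, DiagramGenerator.generate calls call(..., CallStyle.PARALLEL) once per target. PlantUMLGenerator's call handles PARALLEL in its default branch, so the edge shows up as several ordinary solid -down-> arrows.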
Example
@Test
public void testGraph() throws GraphStateException {
    OverAllState overAllState = getOverAllState();
    StateGraph workflow = new StateGraph(overAllState)
        .addNode("agent_1", node_async(state -> {
            System.out.println("agent_1");
            return Map.of("messages", "message1");
        }))
        .addNode("agent_2", node_async(state -> {
            System.out.println("agent_2");
            return Map.of("messages", new String[] { "message2" });
        }))
        .addNode("agent_3", node_async(state -> {
            System.out.println("agent_3");
            List<String> messages = Optional.ofNullable(state.value("messages").get())
                .filter(List.class::isInstance)
                .map(List.class::cast)
                .orElse(new ArrayList<>());
            int steps = messages.size() + 1;
            return Map.of("messages", "message3", "steps", steps);
        }))
        .addEdge("agent_1", "agent_2")
        .addEdge("agent_2", "agent_3")
        .addEdge(StateGraph.START, "agent_1")
        .addEdge("agent_3", StateGraph.END);

    GraphRepresentation representation = workflow.getGraph(GraphRepresentation.Type.PLANTUML, "demo");
    System.out.println(representation.content());
}
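The output is as follows. The "__START__" -down-> "agent_1" arrow comes first even though that edge was added third, because generate renders the START edge before all other edges: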
@startuml demo
skinparam usecaseFontSize 14
skinparam usecaseStereotypeFontSize 12
skinparam hexagonFontSize 14
skinparam hexagonStereotypeFontSize 12
title "demo"
footer

powered by spring-ai-alibaba
end footer
circle start<<input>> as __START__
circle stop as __END__
usecase "agent_1"<<Node>>
usecase "agent_2"<<Node>>
usecase "agent_3"<<Node>>
"__START__" -down-> "agent_1"
"agent_1" -down-> "agent_2"
"agent_2" -down-> "agent_3"
"agent_3" -down-> "__END__"
@enduml
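To render the diagram, the printed text usually has to end up in a .puml file that a PlantUML renderer can read. The helper below is a hypothetical sketch, not part of Spring AI Alibaba. It takes the file name from the "@startuml <name>" line and falls back to "diagram" when there is none. Because PlantUMLGenerator builds that name with titleToSnakeCase, it contains only letters, digits, and underscores, so it is safe to use as a file name.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper: writes generated PlantUML text to <name>.puml in dir.
final class PumlFiles {

    private static final String START_TAG = "@startuml ";

    private PumlFiles() {
    }

    static Path save(String content, Path dir) throws IOException {
        // e.g. "@startuml demo" -> "demo"
        String firstLine = content.lines().findFirst().orElse("");
        String name = firstLine.startsWith(START_TAG)
                ? firstLine.substring(START_TAG.length()).trim()
                : "";
        if (name.isEmpty()) {
            name = "diagram";
        }
        Path file = dir.resolve(name + ".puml");
        Files.writeString(file, content);
        return file;
    }
}
```

In the test above, PumlFiles.save(representation.content(), Path.of(".")) would write demo.puml, which can then be rendered with any PlantUML installation, for example the plantuml.jar command-line tool.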
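Summary
DiagramGenerator is an abstract base class for diagram generation. It declares the abstract methods appendHeader, appendFooter, call, declareConditionalStart, declareNode, declareConditionalEdge, and commentLine, and provides a generate method that builds a textual representation of the graph from nodes, edges, and ctx. PlantUMLGenerator extends DiagramGenerator and implements those abstract methods according to PlantUML syntax. StateGraph.getGraph picks the generator through GraphRepresentation.Type and returns the generated diagram text.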
doc
- spring-ai-alibaba