Language Agent Tree Search (1)
代码
import math
from collections import deque
from typing import Optional
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from pydantic import BaseModel, Field
class Reflection(BaseModel):
    """An LLM-generated critique of a candidate response.

    Carries a free-text critique, an integer quality score in [0, 10], and a
    flag saying whether the candidate fully solves the task.  Used by the
    search tree (``Node``) as the reward signal for backpropagation.
    """

    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    score: int = Field(
        description="Score from 0-10 on the quality of the candidate response.",
        # Bug fix: pydantic's bound keywords are ``ge``/``le``; the original
        # ``gte``/``lte`` are unknown kwargs, so the 0-10 range was never
        # actually validated.
        ge=0,
        le=10,
    )
    found_solution: bool = Field(
        description="Whether the response has fully solved the question or task."
    )

    def as_message(self):
        """Render the critique as a HumanMessage to feed back into the chat."""
        return HumanMessage(
            content=f"Reasoning: {self.reflections}\nScore: {self.score}"
        )

    @property
    def normalized_score(self) -> float:
        """The score rescaled from [0, 10] to [0.0, 1.0] for use as a reward."""
        return self.score / 10.0
class Node:
    """A node in a Monte-Carlo-style (LATS) search tree.

    Each node stores the candidate messages produced at this step together
    with the :class:`Reflection` that scored them.  Constructing a node
    immediately backpropagates its normalized reflection score up to the
    root, so ``visits`` starts at 1 and ``value`` is a running average of
    all rewards seen in this subtree.
    """

    def __init__(
        self,
        messages: list[BaseMessage],
        reflection: Reflection,
        parent: Optional["Node"] = None,
    ):
        # Candidate messages generated at this step of the search.
        self.messages = messages
        self.parent = parent
        self.children: list["Node"] = []
        # Running average of backpropagated rewards (updated in backpropagate()).
        self.value: float = 0
        self.visits: int = 0
        self.reflection = reflection
        # Note: the root sits at depth 1, not 0.
        self.depth = parent.depth + 1 if parent is not None else 1
        self._is_solved = reflection.found_solution if reflection else False
        if self._is_solved:
            # Flag every ancestor so the search loop can terminate early.
            self._mark_tree_as_solved()
        # Seed this node's value and bump visit counts all the way to the root.
        self.backpropagate(reflection.normalized_score)

    def __repr__(self) -> str:
        return (
            f"<Node value={self.value}, visits={self.visits},"
            f" solution={self.messages} reflection={self.reflection}/>"
        )

    @property
    def is_solved(self):
        """If any solutions exist, we can end the search."""
        return self._is_solved

    @property
    def is_terminal(self):
        # A node with no expanded children is a leaf of the rollout so far.
        return not self.children

    @property
    def best_child_score(self):
        """Return the child with the highest value.

        NOTE(review): despite the name, this returns the child *node*, and
        unsolved children are zeroed out by ``int(child.is_solved)``.
        """
        if not self.children:
            return None
        return max(self.children, key=lambda child: int(child.is_solved) * child.value)

    @property
    def height(self) -> int:
        """Check for how far we've rolled out the tree."""
        if self.children:
            return 1 + max([child.height for child in self.children])
        return 1

    def upper_confidence_bound(self, exploration_weight=1.0):
        """Return the UCT score. This helps balance exploration vs. exploitation of a branch."""
        if self.parent is None:
            raise ValueError("Cannot obtain UCT from root node")
        if self.visits == 0:
            return self.value
        # Encourages exploitation of high-value trajectories
        average_reward = self.value / self.visits
        # Encourages exploration of less-visited trajectories
        exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
        return average_reward + exploration_weight * exploration_term

    def backpropagate(self, reward: float):
        """Update the score of this node and its parents."""
        node = self
        while node:
            node.visits += 1
            # Incremental running mean: fold the new reward into the average.
            node.value = (node.value * (node.visits - 1) + reward) / node.visits
            node = node.parent

    def get_messages(self, include_reflections: bool = True):
        """This node's messages, optionally followed by its reflection message."""
        if include_reflections:
            return self.messages + [self.reflection.as_message()]
        return self.messages

    def get_trajectory(self, include_reflections: bool = True) -> list[BaseMessage]:
        """Get messages representing this search branch."""
        messages = []
        node = self
        while node:
            # Each node's messages are appended reversed while walking leaf->root,
            # so one final reversal yields root-to-leaf order.
            messages.extend(
                node.get_messages(include_reflections=include_reflections)[::-1]
            )
            node = node.parent
        # Reverse the final back-tracked trajectory to return in the correct order
        return messages[::-1]  # root solution, reflection, child 1, ...

    def _get_all_children(self):
        # Breadth-first collection of every descendant (excludes self).
        all_nodes = []
        nodes = deque()
        nodes.append(self)
        while nodes:
            node = nodes.popleft()
            all_nodes.extend(node.children)
            for n in node.children:
                nodes.append(n)
        return all_nodes

    def get_best_solution(self):
        """Return the best solution from within the current sub-tree."""
        all_nodes = [self] + self._get_all_children()
        best_node = max(
            all_nodes,
            # We filter out all non-terminal, non-solution trajectories
            key=lambda node: int(node.is_terminal and node.is_solved) * node.value,
        )
        return best_node

    def _mark_tree_as_solved(self):
        # Walk the parent chain, flagging every ancestor as solved.
        parent = self.parent
        while parent:
            parent._is_solved = True
            parent = parent.parent
代码解释
这段代码实现了一种搜索树(用于强化学习或AI决策树)来评估对话响应的质量。核心思想是 模拟蒙特卡洛树搜索(MCTS) 的方式,通过 探索和回溯传播(backpropagation) 来优化搜索路径。
1. Reflection 类
用于存储对某个回答的评价,包括:
- reflections:对回答的评论(是否充分、冗余等)。
- score(0-10):评分,数值越高表示质量越好。
- found_solution(布尔值):该回答是否完全解决了问题。
- as_message():将评价转换为 HumanMessage(LangChain 消息类型)。
- normalized_score(属性):将评分归一化到 [0, 1]。
2. Node 类
用于构建搜索树节点,每个节点包含:
- messages:当前对话消息列表(BaseMessage 类型)。
- reflection:当前回答的评价(Reflection 对象)。
- parent:父节点(如果存在)。
- children:子节点列表。
- value:当前节点的分数,初始化时等于 normalized_score。
- visits:当前节点被访问的次数。
- depth:当前节点的深度(从根节点算起,根节点深度为 1)。
- _is_solved:表示当前节点是否是完全解决问题的答案。
主要方法
- backpropagate(reward: float)
  - 迭代回溯更新当前节点及所有父节点的分值,类似 MCTS 反向传播。
  - 更新公式(对历史奖励取运行平均):$value \leftarrow \dfrac{value \times (visits - 1) + reward}{visits}$
- upper_confidence_bound(exploration_weight=1.0)
  - UCT 公式(用于平衡探索 vs 开发):$UCT = \dfrac{value}{visits} + c\,\sqrt{\dfrac{\ln(parent.visits)}{visits}}$,其中 $c$ 为 exploration_weight。
  - 第一项(平均奖励):倾向于选择评分高的路径(开发)。
  - 第二项(探索因子):鼓励探索访问次数少的路径(探索)。
- get_best_solution()
  - 返回当前子树中评分最高的终端节点(已解决问题的最佳解)。
- get_trajectory(include_reflections=True)
  - 返回该搜索路径上的所有消息(对话+评分)。
- _mark_tree_as_solved()
  - 向上迭代标记所有祖先节点为"已解决",以便提前结束搜索。
代码应用场景
- 对话评估:该代码可用于自动评估 AI 回应质量,比如 LangChain 的 AI 代理,判断哪个回答更好。
- 强化学习中的搜索树:类似 MCTS,在决策过程中利用探索(exploration)和开发(exploitation)策略寻找最优解。
- 优化搜索路径:通过 backpropagate() 和 upper_confidence_bound(),可以优化 AI 的回答策略,找到更合适的答案。
总结
这段代码本质上是一个基于蒙特卡洛树搜索(MCTS)+ 评分反馈的搜索树,可以用于强化学习、智能对话、任务规划等领域。它利用反向传播更新节点价值,结合UCT 策略来平衡探索和利用,最终找到最优解答。
示例
示例1
import random
from langchain_core.messages import HumanMessage

# Reflection attached to the root of the search tree.
seed_reflection = Reflection(reflections="Starting point", score=5, found_solution=False)

# Root node: the user's question, with no parent.
root = Node(
    messages=[HumanMessage(content="What is the capital of France?")],
    reflection=seed_reflection,
)

# Candidate answers paired with their evaluations.
candidates = [
    ("Paris", Reflection(reflections="Correct and concise", score=10, found_solution=True)),
    ("France is a country in Europe.", Reflection(reflections="Too general, lacks specificity", score=4, found_solution=False)),
    ("I don't know.", Reflection(reflections="Unhelpful response", score=1, found_solution=False)),
    ("Paris, the capital of France, is known for the Eiffel Tower.", Reflection(reflections="Correct, but verbose", score=8, found_solution=True)),
]

# Expand the tree: one child node per candidate answer.
root.children.extend(
    Node(messages=[HumanMessage(content=text)], reflection=critique, parent=root)
    for text, critique in candidates
)

# Pick the best solved terminal node (MCTS-style evaluation).
winner = root.get_best_solution()

# Report the winning answer.
print("Best Solution Found:")
print(f"Message: {winner.messages[0].content}")
print(f"Score: {winner.reflection.score}")
print(f"Reflections: {winner.reflection.reflections}")
Best Solution Found:
Message: Paris
Score: 10
Reflections: Correct and concise
示例2
import random

# Root: the user's question plus an initial critique of the (implicit) answer.
question = [HumanMessage(content="How do I fine-tune a YOLO model?")]
seed_eval = Reflection(reflections="The answer provides basic steps.", score=6, found_solution=False)
root = Node(messages=question, reflection=seed_eval)

# First expansion candidate.
branch_a_msgs = root.get_messages() + [AIMessage(content="First, you need a labeled dataset.")]
branch_a_eval = Reflection(reflections="Good start but lacks details.", score=7, found_solution=False)
branch_a = Node(messages=branch_a_msgs, reflection=branch_a_eval, parent=root)
root.children.append(branch_a)

# Second expansion candidate.
branch_b_msgs = root.get_messages() + [AIMessage(content="Use transfer learning to fine-tune.")]
branch_b_eval = Reflection(reflections="More helpful but still incomplete.", score=8, found_solution=False)
branch_b = Node(messages=branch_b_msgs, reflection=branch_b_eval, parent=root)
root.children.append(branch_b)

# A complete answer hanging off the second branch.
final_msgs = branch_b.get_messages() + [AIMessage(content="Train the model using PyTorch with MMYOLO.")]
final_eval = Reflection(reflections="Complete answer with code examples.", score=10, found_solution=True)
leaf = Node(messages=final_msgs, reflection=final_eval, parent=branch_b)
branch_b.children.append(leaf)

# Print the root-to-leaf trajectory for the solved leaf.
print("搜索轨迹:")
for message in leaf.get_trajectory():
    print(f"{message.__class__.__name__}: {message.content}")

# Best solved terminal node in the whole tree.
best = root.get_best_solution()
print("\n最佳解答:")
for message in best.get_messages():
    print(f"{message.__class__.__name__}: {message.content}")

# UCT scores for the two first-level branches (guarded against zero visits).
if branch_a.visits > 0:
    print("\nChild 1 UCT Score:", branch_a.upper_confidence_bound())
if branch_b.visits > 0:
    print("Child 2 UCT Score:", branch_b.upper_confidence_bound())

# Depth of the rolled-out tree.
print("\n搜索树深度:", root.height)
搜索轨迹:
HumanMessage: How do I fine-tune a YOLO model?
HumanMessage: Reasoning: The answer provides basic steps.
Score: 6
HumanMessage: How do I fine-tune a YOLO model?
HumanMessage: Reasoning: The answer provides basic steps.
Score: 6
AIMessage: Use transfer learning to fine-tune.
HumanMessage: Reasoning: More helpful but still incomplete.
Score: 8
HumanMessage: How do I fine-tune a YOLO model?
HumanMessage: Reasoning: The answer provides basic steps.
Score: 6
AIMessage: Use transfer learning to fine-tune.
HumanMessage: Reasoning: More helpful but still incomplete.
Score: 8
AIMessage: Train the model using PyTorch with MMYOLO.
HumanMessage: Reasoning: Complete answer with code examples.
Score: 10
最佳解答:
HumanMessage: How do I fine-tune a YOLO model?
HumanMessage: Reasoning: The answer provides basic steps.
Score: 6
AIMessage: Use transfer learning to fine-tune.
HumanMessage: Reasoning: More helpful but still incomplete.
Score: 8
AIMessage: Train the model using PyTorch with MMYOLO.
HumanMessage: Reasoning: Complete answer with code examples.
Score: 10
Child 1 UCT Score: 1.8774100225154746
Child 2 UCT Score: 1.2825546111576978
搜索树深度: 3
参考链接:https://langchain-ai.github.io/langgraph/tutorials/lats/lats/