
Language Agent Tree Search (1)

Code

import math
from collections import deque
from typing import Optional

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage

from pydantic import BaseModel, Field


class Reflection(BaseModel):
    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    score: int = Field(
        description="Score from 0-10 on the quality of the candidate response.",
        ge=0,
        le=10,
    )
    found_solution: bool = Field(
        description="Whether the response has fully solved the question or task."
    )

    def as_message(self):
        return HumanMessage(
            content=f"Reasoning: {self.reflections}\nScore: {self.score}"
        )

    @property
    def normalized_score(self) -> float:
        return self.score / 10.0


class Node:
    def __init__(
        self,
        messages: list[BaseMessage],
        reflection: Reflection,
        parent: Optional["Node"] = None,
    ):
        self.messages = messages
        self.parent = parent
        self.children = []
        self.value = 0
        self.visits = 0
        self.reflection = reflection
        self.depth = parent.depth + 1 if parent is not None else 1
        self._is_solved = reflection.found_solution if reflection else False
        if self._is_solved:
            self._mark_tree_as_solved()
        self.backpropagate(reflection.normalized_score)

    def __repr__(self) -> str:
        return (
            f"<Node value={self.value}, visits={self.visits},"
            f" solution={self.messages} reflection={self.reflection}/>"
        )

    @property
    def is_solved(self):
        """If any solutions exist, we can end the search."""
        return self._is_solved

    @property
    def is_terminal(self):
        return not self.children

    @property
    def best_child_score(self):
        """Return the child with the highest value."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: int(child.is_solved) * child.value)

    @property
    def height(self) -> int:
        """Check for how far we've rolled out the tree."""
        if self.children:
            return 1 + max([child.height for child in self.children])
        return 1

    def upper_confidence_bound(self, exploration_weight=1.0):
        """Return the UCT score. This helps balance exploration vs. exploitation of a branch."""
        if self.parent is None:
            raise ValueError("Cannot obtain UCT from root node")
        if self.visits == 0:
            return self.value
        # Encourages exploitation of high-value trajectories
        average_reward = self.value / self.visits
        # Encourages exploration of less-visited trajectories
        exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
        return average_reward + exploration_weight * exploration_term

    def backpropagate(self, reward: float):
        """Update the score of this node and its parents."""
        node = self
        while node:
            node.visits += 1
            node.value = (node.value * (node.visits - 1) + reward) / node.visits
            node = node.parent

    def get_messages(self, include_reflections: bool = True):
        if include_reflections:
            return self.messages + [self.reflection.as_message()]
        return self.messages

    def get_trajectory(self, include_reflections: bool = True) -> list[BaseMessage]:
        """Get messages representing this search branch."""
        messages = []
        node = self
        while node:
            messages.extend(
                node.get_messages(include_reflections=include_reflections)[::-1]
            )
            node = node.parent
        # Reverse the final back-tracked trajectory to return in the correct order
        return messages[::-1]  # root solution, reflection, child 1, ...

    def _get_all_children(self):
        all_nodes = []
        nodes = deque()
        nodes.append(self)
        while nodes:
            node = nodes.popleft()
            all_nodes.extend(node.children)
            for n in node.children:
                nodes.append(n)
        return all_nodes

    def get_best_solution(self):
        """Return the best solution from within the current sub-tree."""
        all_nodes = [self] + self._get_all_children()
        best_node = max(
            all_nodes,
            # We filter out all non-terminal, non-solution trajectories
            key=lambda node: int(node.is_terminal and node.is_solved) * node.value,
        )
        return best_node

    def _mark_tree_as_solved(self):
        parent = self.parent
        while parent:
            parent._is_solved = True
            parent = parent.parent

Code Explanation

This code implements a search tree (in the spirit of reinforcement learning / AI decision trees) for evaluating the quality of conversational responses. The core idea is to mimic Monte Carlo Tree Search (MCTS): the search path is refined through exploration and backpropagation.
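
Before diving into the classes, the overall cycle (select a branch by UCT, expand it, score it with a reflection, backpropagate the reward) can be illustrated with a toy, dependency-free sketch. Everything here (`toy_score`, the two candidate strings, the `stats` table) is made up for illustration and is not part of the real code:

```python
import math

def toy_score(text: str) -> float:
    # Hypothetical stand-in for the LLM "reflection" scorer: longer answers score higher
    return min(len(text) / 20.0, 1.0)

# Per-candidate [running average of rewards, visit count]
stats = {"short": [0.0, 0], "a much longer answer": [0.0, 0]}

def uct(key: str, total_visits: int, c: float = 1.0) -> float:
    value, visits = stats[key]
    if visits == 0:
        return float("inf")  # always try an unvisited candidate first
    return value + c * math.sqrt(math.log(total_visits) / visits)

for step in range(1, 21):
    choice = max(stats, key=lambda k: uct(k, step))  # selection via UCT
    reward = toy_score(choice)                       # "reflect" on the rollout
    value, visits = stats[choice]
    stats[choice] = [(value * visits + reward) / (visits + 1), visits + 1]  # backpropagate

best = max(stats, key=lambda k: stats[k][0])
print(best)  # -> "a much longer answer"
```

Even though the weaker candidate keeps getting revisited occasionally (the exploration term grows for neglected branches), the higher-reward branch accumulates most of the visits.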


1. The Reflection class

Stores the evaluation of a candidate response:

  • reflections: a critique of the response (sufficiency, redundancy, and overall quality).
  • score (0-10): a quality rating; higher is better.
  • found_solution (boolean): whether the response fully solves the question or task.
  • as_message(): converts the reflection into a HumanMessage (a LangChain message type).
  • normalized_score: the score mapped onto [0, 1].
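
As a dependency-free illustration of the same idea (the real class relies on pydantic's Field validation), here is a minimal sketch using a plain dataclass. SimpleReflection and its hand-written validation are stand-ins, not part of the original code:

```python
from dataclasses import dataclass

@dataclass
class SimpleReflection:
    """Hypothetical dependency-free stand-in for the pydantic Reflection model."""
    reflections: str
    score: int
    found_solution: bool = False

    def __post_init__(self):
        # Mirror the 0-10 score constraint by validating at construction time
        if not 0 <= self.score <= 10:
            raise ValueError("score must be between 0 and 10")

    @property
    def normalized_score(self) -> float:
        # Map the 0-10 score onto [0, 1], the range the search tree works with
        return self.score / 10.0

r = SimpleReflection(reflections="Correct and concise", score=7)
print(r.normalized_score)  # -> 0.7
```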

2. The Node class

Represents a node of the search tree. Each node holds:

  • messages: the list of conversation messages so far (BaseMessage objects).
  • reflection: the evaluation of the current response (a Reflection object).
  • parent: the parent node, if any.
  • children: the list of child nodes.
  • value: the node's score; the constructor backpropagates normalized_score into it.
  • visits: how many times the node has been visited.
  • depth: the node's depth, counted from the root.
  • _is_solved: whether this node fully solves the task.
Key methods

  1. backpropagate(reward: float)

    • Walks from the current node up through every ancestor and updates their values, as in MCTS backpropagation.
    • Each node keeps a running average of the rewards it has received:

      value_n = (value_(n-1) * (n - 1) + reward) / n

      where n is the node's visit count after the update.

  2. upper_confidence_bound(exploration_weight=1.0)

    • The UCT formula (exploration vs. exploitation):

      UCT = value / visits + exploration_weight * sqrt(ln(parent.visits) / visits)

    • The first term (average reward) favors high-scoring paths (exploitation).

    • The second term (exploration bonus) favors rarely visited paths (exploration).

  3. get_best_solution()

    • Returns the highest-value solved terminal node within the current subtree.

  4. get_trajectory(include_reflections=True)

    • Returns all messages along this search branch (conversation plus reflections).

  5. _mark_tree_as_solved()

    • Walks up the tree marking every ancestor as solved, allowing the search to stop early.
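
The two update rules can be checked numerically in isolation. Below is a dependency-free sketch using a simplified MiniNode (a stand-in keeping only value, visits, and the parent link), with backpropagate and uct copied from the class above:

```python
import math

class MiniNode:
    """Simplified stand-in for the Node class: only value, visits, and parent."""
    def __init__(self, parent=None):
        self.parent = parent
        self.value = 0.0
        self.visits = 0

    def backpropagate(self, reward: float) -> None:
        # Running average, exactly as in Node.backpropagate above
        node = self
        while node:
            node.visits += 1
            node.value = (node.value * (node.visits - 1) + reward) / node.visits
            node = node.parent

    def uct(self, exploration_weight: float = 1.0) -> float:
        # Same formula as Node.upper_confidence_bound above
        if self.visits == 0:
            return self.value
        return self.value / self.visits + exploration_weight * math.sqrt(
            math.log(self.parent.visits) / self.visits
        )

root = MiniNode()
root.backpropagate(0.6)      # root: visits=1, value=0.6
a = MiniNode(parent=root)
a.backpropagate(1.0)         # a: value=1.0; root: value=(0.6 + 1.0)/2 = 0.8
b = MiniNode(parent=root)
b.backpropagate(0.2)         # b: value=0.2; root: visits=3
print(round(root.value, 4))  # -> 0.6, the running average of 0.6, 1.0, 0.2
print(a.uct() > b.uct())     # -> True: a has the higher average reward
```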

Application Scenarios

  1. Response evaluation

    • The code can automatically assess the quality of AI responses, e.g. letting a LangChain agent judge which of several candidate answers is better.

  2. Search trees in reinforcement learning

    • As in MCTS, it balances exploration and exploitation while searching for the best solution.

  3. Optimizing the search path

    • backpropagate() and upper_confidence_bound() together steer the search toward better answers.

Summary

At its core, this code is a Monte Carlo Tree Search (MCTS)-style search tree driven by score feedback, applicable to reinforcement learning, conversational agents, and task planning. Backpropagation updates node values, and the UCT policy balances exploration and exploitation to converge on the best answer.

Examples

Example 1

from langchain_core.messages import HumanMessage

# Reflection object for the initial root node
root_reflection = Reflection(reflections="Starting point", score=5, found_solution=False)

# Create the root node (it has no parent)
root_node = Node(messages=[HumanMessage(content="What is the capital of France?")], reflection=root_reflection)

# Candidate responses to evaluate
candidate_responses = [
    ("Paris", Reflection(reflections="Correct and concise", score=10, found_solution=True)),
    ("France is a country in Europe.", Reflection(reflections="Too general, lacks specificity", score=4, found_solution=False)),
    ("I don't know.", Reflection(reflections="Unhelpful response", score=1, found_solution=False)),
    ("Paris, the capital of France, is known for the Eiffel Tower.", Reflection(reflections="Correct, but verbose", score=8, found_solution=True)),
]

# Create a child node for each candidate and attach it to the search tree
for content, reflection in candidate_responses:
    child_node = Node(messages=[HumanMessage(content=content)], reflection=reflection, parent=root_node)
    root_node.children.append(child_node)

# Pick the best answer (MCTS-style evaluation)
best_solution = root_node.get_best_solution()

# Print the best solution
print("Best Solution Found:")
print(f"Message: {best_solution.messages[0].content}")
print(f"Score: {best_solution.reflection.score}")
print(f"Reflections: {best_solution.reflection.reflections}")

Best Solution Found:
Message: Paris
Score: 10
Reflections: Correct and concise

Example 2

# Messages and evaluation for the root node
root_messages = [HumanMessage(content="How do I fine-tune a YOLO model?")]
root_reflection = Reflection(reflections="The answer provides basic steps.", score=6, found_solution=False)

# Create the root node
root = Node(messages=root_messages, reflection=root_reflection)

# Create child nodes
child_messages_1 = root.get_messages() + [AIMessage(content="First, you need a labeled dataset.")]
child_reflection_1 = Reflection(reflections="Good start but lacks details.", score=7, found_solution=False)
child_1 = Node(messages=child_messages_1, reflection=child_reflection_1, parent=root)
root.children.append(child_1)

child_messages_2 = root.get_messages() + [AIMessage(content="Use transfer learning to fine-tune.")]
child_reflection_2 = Reflection(reflections="More helpful but still incomplete.", score=8, found_solution=False)
child_2 = Node(messages=child_messages_2, reflection=child_reflection_2, parent=root)
root.children.append(child_2)

# Create the node holding the final answer
final_messages = child_2.get_messages() + [AIMessage(content="Train the model using PyTorch with MMYOLO.")]
final_reflection = Reflection(reflections="Complete answer with code examples.", score=10, found_solution=True)
final_node = Node(messages=final_messages, reflection=final_reflection, parent=child_2)
child_2.children.append(final_node)

# Print the search trajectory
print("Search trajectory:")
for msg in final_node.get_trajectory():
    print(f"{msg.__class__.__name__}: {msg.content}")

# Find the best solution
best_solution = root.get_best_solution()
print("\nBest solution:")
for msg in best_solution.get_messages():
    print(f"{msg.__class__.__name__}: {msg.content}")

# Compute UCT scores
if child_1.visits > 0:
    print("\nChild 1 UCT Score:", child_1.upper_confidence_bound())
if child_2.visits > 0:
    print("Child 2 UCT Score:", child_2.upper_confidence_bound())

# Inspect the height of the tree
print("\nTree height:", root.height)

Search trajectory:
HumanMessage: How do I fine-tune a YOLO model?
HumanMessage: Reasoning: The answer provides basic steps.
Score: 6
HumanMessage: How do I fine-tune a YOLO model?
HumanMessage: Reasoning: The answer provides basic steps.
Score: 6
AIMessage: Use transfer learning to fine-tune.
HumanMessage: Reasoning: More helpful but still incomplete.
Score: 8
HumanMessage: How do I fine-tune a YOLO model?
HumanMessage: Reasoning: The answer provides basic steps.
Score: 6
AIMessage: Use transfer learning to fine-tune.
HumanMessage: Reasoning: More helpful but still incomplete.
Score: 8
AIMessage: Train the model using PyTorch with MMYOLO.
HumanMessage: Reasoning: Complete answer with code examples.
Score: 10

Best solution:
HumanMessage: How do I fine-tune a YOLO model?
HumanMessage: Reasoning: The answer provides basic steps.
Score: 6
AIMessage: Use transfer learning to fine-tune.
HumanMessage: Reasoning: More helpful but still incomplete.
Score: 8
AIMessage: Train the model using PyTorch with MMYOLO.
HumanMessage: Reasoning: Complete answer with code examples.
Score: 10

Child 1 UCT Score: 1.8774100225154746
Child 2 UCT Score: 1.2825546111576978

Tree height: 3

Reference: https://langchain-ai.github.io/langgraph/tutorials/lats/lats/
