
[Study Notes] PyTorch Reinforcement Learning

https://www.bilibili.com/video/BV1zC411h7B8


Table of Contents

  • [mcts] 01 MCTS basic concepts and principles (UCB), with two examples
  • [mcts] 02 MCTS from scratch (UCTNode, uct_search, pUCT, tree visualization)


[mcts] 01 MCTS basic concepts and principles (UCB), with two examples

https://github.com/chunhuizhang/personal_chatgpt/blob/main/tutorials/drl/mcts/mcts_01_intro_bascis.ipynb

  • Reference
    • Bandit based Monte-Carlo Planning: http://ggp.stanford.edu/readings/uct.pdf
from IPython.display import SVG, Image
import numpy as np

(figure: a root with N = 21 total simulations and three children with win/visit counts 7/10, 5/8, and 0/3)

C = np.sqrt(2)
# level-1 select: UCB1 for each child of a root with N = 21 total simulations
print(7/10 + C*np.sqrt(np.log(21)/10))  # child with 7 wins / 10 visits
print(5/8 + C*np.sqrt(np.log(21)/8))    # child with 5 wins / 8 visits
print(0/3 + C*np.sqrt(np.log(21)/3))    # child with 0 wins / 3 visits
  • MCTS

    • a statistical (Monte Carlo) tree
    • node: represents a state
    • edge: represents the state transition caused by an action
  • Select picks a leaf node

    • a "leaf node" here is a node with no children; e.g., the root node at the initial state has no children
    • selection is based on UCT (UCB1 vs. UCT)
      • "The main idea in this paper is to apply a particular bandit algorithm, UCB1 (UCB stands for Upper Confidence Bounds), for rollout-based Monte-Carlo planning. The new algorithm, called UCT (UCB applied to trees), is described in Section 2."
    • as the tree expands and gets updated, subsequent selects become a tree-traversal process
  • Expand & Simulate (rollout, random simulation)

    • expand: a leaf node spawns its children

      • for the initial root state in Go, expanding yields 19*19 = 361 possible children
    • simulate: the step that most embodies the Monte Carlo idea

      • at each node of the search tree, the algorithm runs many random simulations
      • in order to find a value
      • starting from the current state, play according to some policy (possibly uniformly random, possibly heuristic) until the game ends or a depth limit is reached; the outcomes of these simulations (win, lose, draw) are used to estimate the expected score from the current node
    • when is a rollout needed? when the node is brand new (never simulated)

      • new node => rollout ($n_i = 0$)
      • old node => expand (already updated via simulate -> backprop)
  • Backpropagate: propagate the node's update upward (to each parent in turn)

    • based on the value found by the rollout
    • every time a node finishes a simulate, the backprop walks the parent pointers and updates all the way up to the root
      • so on the next select pass, the UCB scores all have to be recomputed
    • a runnable sketch of the whole select/expand/simulate/backpropagate loop is given after the AlphaGo notes below

$$\text{UCB1}(s_i)=\frac{w_i}{n_i} + C\sqrt{\frac{\ln N_i}{n_i}}, \qquad \text{UCB1}(s_i)=\bar v_i + C\sqrt{\frac{\ln N_i}{n_i}}$$

  • $w_i$: number of wins

  • $n_i$: number of simulations (visits) of child $i$

  • $w_i/n_i$: the win rate in the game

  • $N_i$: the parent's number of simulations

  • $C=\sqrt 2$

  • the double E: exploitation vs. exploration

    • exploration: FOMO, fear of missing out
  • AlphaGo: deep learning + MCTS

    • policy net: input the board state, output move probabilities $\pi(a_i\mid s_i)$ => feeds into UCB
    • value net: input the board state, output the probability of winning => feeds into simulate
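To make the loop concrete before the worked examples, here is a minimal, self-contained sketch of one select/expand/simulate/backpropagate iteration. The `legal_actions(state)`, `step(state, action)`, and `rollout(state)` callables are hypothetical placeholders for an environment; none of them are part of the notebook.

import math
import random

class Node:
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.wins = 0    # w_i
        self.visits = 0  # n_i

    def ucb1(self, c=math.sqrt(2)):
        if self.visits == 0:
            return float('inf')  # unvisited children are selected first
        return self.wins / self.visits + c * math.sqrt(math.log(self.parent.visits) / self.visits)

def mcts_iteration(root, legal_actions, step, rollout):
    # 1. Select: descend by UCB1 until reaching a leaf (a node without children)
    node = root
    while node.children:
        node = max(node.children, key=lambda ch: ch.ucb1())
    # 2. Expand: an already-visited leaf grows its children ("old node => expand")
    if node.visits > 0:
        node.children = [Node(step(node.state, a), parent=node)
                         for a in legal_actions(node.state)]
        if node.children:
            node = random.choice(node.children)
    # 3. Simulate: random rollout from the leaf to a terminal value (e.g., 1 = win, 0 = loss)
    value = rollout(node.state)
    # 4. Backpropagate: update statistics along the path back to the root
    while node is not None:
        node.visits += 1
        node.wins += value
        node = node.parent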

Image source: https://www.youtube.com/watch?v=UXW2yZndl7U

A few examples of how this works in practice:

(figure: a root with two children, one visit each; the left child's total value is 20, the right's is 10)

# UCB1 with C = 2; each child has 1 visit and the root has N = 2
print('left', 20/1 + 2*np.sqrt(np.log(2)/1))
print('right', 10/1 + 2*np.sqrt(np.log(2)/1))
# left 21.665109222315394
# right 11.665109222315396

(figure: updated counts after another rollout: left 20 over 2 visits, right 10 over 1 visit, root N = 3)

# left child tree
print(20/2 + 2*np.sqrt(np.log(3)/2)) # 11.482303807367511
# right child tree
print(10/1 + 2*np.sqrt(np.log(3)/1)) # 12.09629414793641

# select right child 

(figure: updated counts: left Q = 20/2 = 10, right Q = 24/2 = 12, root N = 4)

# left child tree
print(10 + 2*np.sqrt(np.log(4)/2)) # 11.665109222315396

# right child tree
print(12 + 2*np.sqrt(np.log(4)/2)) # 13.665109222315396

# select right child 

An example in Go

  • white stones / black stones, the game of Go (AlphaGo)
    • 19*19 = 361
    • game tree: a minimax tree
  • here the root node's perspective is the black player
    • converting wins between black and white is simple: total simulations - black wins = white wins (e.g., 3 simulations with 1 black win means 2 white wins)


from graphviz import Digraph
from IPython.display import display

# a single root node labeled with its statistics w_i/n_i
graph = Digraph('mcts')
graph.node('s0', 'w_i/n_i', style='filled')
display(graph)

# the initial root: 0 wins / 0 visits
graph = Digraph('mcts')
graph.node('s0', '0/0', style='filled')
display(graph)

Define the UCB function:

def ucb(wi, ni, Ni, C=np.sqrt(2)):
    # UCB1: win rate plus exploration bonus
    return wi/ni + C*np.sqrt(np.log(Ni)/ni)

print('left', ucb(1, 1, 2))
print('right', ucb(0, 1, 2))

(figure: the root, black to move, with two children after one visit each: left 1/1, right 0/1)

# level 1
print('left', ucb(1, 2, 3))
print('right', ucb(0, 1, 3))
# choose left

# level 2: the counts switch to the number of white wins (minimax alternation)
print('left', ucb(1, 1, 2))

Output:

left 1.548147073968205
right 1.482303807367511
left 2.177410022515475


At actual decision (planning) time, we choose greedily by the win rate $w_i/n_i$ alone, without the exploration term.
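A minimal sketch of that final greedy step, reusing the 7/10, 5/8, 0/3 counts from the first example (the `wins`/`visits` array names are made up for illustration; part 02 below instead picks the most-visited child, the usual AlphaZero-style choice):

wins   = np.array([7, 5, 0])   # w_i for each child of the root
visits = np.array([10, 8, 3])  # n_i for each child

# at decision time the exploration bonus is dropped: act greedily on win rate
best_move = np.argmax(wins / visits)
print(best_move)  # 0: the 7/10 child has the highest win rate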


[mcts] 02 MCTS from scratch (UCTNode, uct_search, pUCT, tree visualization)

https://github.com/chunhuizhang/personal_chatgpt/blob/main/tutorials/drl/mcts/mcts_02_from_scartch.ipynb

  • Notes
    • nodes correspond to states $s$
    • edges refer to actions $a$
      • each edge transfers the environment from its parent state to its child state
        • state transition
    • game tree
      • alternating moves, a minimax setting: white's value $v$ is black's $-v$
        • if the current level is black to move (edges are black's actions), the next level is white to move (edges are white's actions)
    • UCT => pUCT: Q + U
      • early in the search, U dominates (more exploration)
      • later, Q becomes more important (less exploration, more exploitation)
    • training & inference
      • training: uct = Q + U (to select nodes)
      • inference: Q (the best move at the current state)
  • References
    • https://github.com/brilee/python_uct
    • https://www.moderndescartes.com/essays/deep_dive_mcts/
import collections
import numpy as np
import math
from IPython.display import Image
from tqdm.notebook import tqdm

Nodes and search


  • node: represents a game state, e.g., a board position in Go

  • root: the current state

    • MCTS planning decides how to choose the best move at the current state
  • leaf node: a terminal node or an unexplored node

  • edge: an action leading to another node

  • because backpropagation (backup) follows simulate (rollout/evaluate), each node must keep a pointer to its parent in addition to its children

name_id = 0

class UCTNode():
    def __init__(self, name, state, action, parent=None):
        self.name = name
        self.state = state
        self.action = action
        
        self.is_expanded = False
        
        # this node's own statistics live in its parent's arrays:
        # self.parent.child_total_value[self.action]
        # self.parent.child_number_visits[self.action]
        self.parent = parent  # Optional[UCTNode]
        
        self.children = {}  # Dict[action, UCTNode]
        self.child_priors = np.zeros([362], dtype=np.float32)
        # t_i: per-child total value
        self.child_total_value = np.zeros([362], dtype=np.float32)
        # n_i: per-child visit count
        self.child_number_visits = np.zeros([362], dtype=np.float32)
    
    
    # this node's visit count (the children's N_i), stored in the parent's array
    @property
    def number_visits(self):
        return self.parent.child_number_visits[self.action]

    @number_visits.setter
    def number_visits(self, value):
        self.parent.child_number_visits[self.action] = value
        
    # this node's total value (t_i), stored in the parent's array
    @property
    def total_value(self):
        return self.parent.child_total_value[self.action]

    @total_value.setter
    def total_value(self, value):
        self.parent.child_total_value[self.action] = value

    # pUCT
    # https://courses.cs.washington.edu/courses/cse599i/18wi/resources/lecture19/lecture19.pdf
    def child_Q(self) -> np.ndarray:
        return self.child_total_value / (1 + self.child_number_visits)


    def child_U(self) -> np.ndarray:
        return math.sqrt(self.number_visits) * (
            self.child_priors / (1 + self.child_number_visits))
    
    
    def best_child(self) -> int:
#         print(self.child_Q() + self.child_U())
        return np.argmax(self.child_Q() + self.child_U())
    
    # traversal
    def select_leaf(self):
        current = self
        while current.is_expanded:
            # pUCT
            best_action = current.best_child()
            current = current.maybe_add_child(best_action)
        return current

    def expand(self, child_priors):
        self.is_expanded = True
        self.child_priors = child_priors

    def maybe_add_child(self, action):
        global name_id
        if action not in self.children:
            # when adding a new child, switch the player identity (white => black, black => white)
            name_id += 1
            self.children[action] = UCTNode(
                name_id, self.state.play(action), action, parent=self)
        return self.children[action]

    def backup(self, value_estimate: float):
        current = self
        while current.parent is not None:
            current.number_visits += 1
            # note: the sign comes from the leaf's state.to_play and stays fixed
            # along the path (as in the referenced brilee implementation); a strict
            # per-level minimax backup would instead flip the sign at each level
            current.total_value += (value_estimate * self.state.to_play)
            current = current.parent

Q + U

# black and white alternate
# selection is based on UCT: Q + U
# edge: P (child priors)
# node: V (value)
# f_\theta => (p, v)
Image(url='https://www.moderndescartes.com/static/deep_dive_mcts/alphago_uct_diagram.png', width=700)


  • Ranking = Quality + Uncertainty (Q + U)
    • Quality: exploitation
    • Uncertainty: exploration
      • FOMO (fear of missing out)
      • P from policy network

$$Q=\frac{t_i}{1+n_i}, \qquad U=\sqrt{N_i}\times \frac{P}{1+n_i}$$
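A small numeric illustration (not from the notebook) of the earlier bullet that U dominates early in the search while Q takes over later, using exactly this Q/U definition:

def q_u(t, n, N, prior):
    q = t / (1 + n)                    # Q: exploitation, the mean value
    u = np.sqrt(N) * prior / (1 + n)   # U: exploration, weighted by the prior P
    return q, u

# early: few visits, U dominates
print(q_u(t=0.6, n=1, N=10, prior=0.5))        # Q = 0.30, U ≈ 0.79
# later: many visits, Q dominates
print(q_u(t=600, n=1000, N=10000, prior=0.5))  # Q ≈ 0.60, U ≈ 0.05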

Define the game state:

# alternating moves, a minimax setting: white's value v is black's -v
class GameState:
    def __init__(self, to_play=1):
        self.to_play = to_play
    def play(self, action):
        return GameState(to_play=-self.to_play)
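A quick sanity check of the stub (a `to_play` sign and a `play(action)` that flips it are the only pieces of state that UCTNode relies on):

s = GameState()
print(s.to_play)          # 1: black to move at the root
print(s.play(0).to_play)  # -1: after any action, white is to move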

Policy network and value network

  • Combining a policy network to guide the search direction with a value network to evaluate positions can significantly shrink the search tree and make the search more efficient.
    • The policy network learns effective move patterns and strategies from previous games, which amounts to injecting a large amount of "prior knowledge" (child_priors) into the search.
  • The value network gives a direct evaluation of the current position's chance of winning without playing out to the end of the game; this ability is crucial for reducing search depth and speeding up decisions.
class NeuralNet():
    @classmethod
    def evaluate(cls, game_state):
        # stand-in for (policy_network(state), value_network(state)):
        # policy_network(state) returns pi(a|s); value_network(state) returns v(s)
        return np.random.random([362]), np.random.random()
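Note the random priors above are not a normalized distribution. As a hedged sketch, a real policy head would return something like a softmax over its logits (the function name here is hypothetical, and the logits are still random stand-ins):

def evaluate_normalized(game_state):
    logits = np.random.random([362])                # stand-in for policy-head logits
    priors = np.exp(logits) / np.exp(logits).sum()  # softmax => a proper pi(a|s)
    value = np.random.random()                      # stand-in for v(s)
    return priors, value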

Finally, the UCT search algorithm:

# a sentinel parent for the root, so that number_visits/total_value work uniformly
class DummyNode(object):
    def __init__(self):
        self.parent = None
        self.child_total_value = collections.defaultdict(float)
        self.child_number_visits = collections.defaultdict(float)
def print_tree_level_width(root: UCTNode):
    if not root:
        return

    queue = collections.deque([(root, 0)])  # queue of (node, level) pairs
    current_level = 0
    level_nodes = []

    while queue:
        node, level = queue.popleft()  # take the next node and its level
        # when entering a new level, print the previous level's width and reset
        if level > current_level:
            print(f"Level {current_level} width: {len(level_nodes)}")
            level_nodes = [f'{node.action}']  # reset the current level's node list
            current_level = level
        else:
            level_nodes.append(f'{node.action}')

        # enqueue all children of the current node
        for child in node.children.values():
            queue.append((child, level + 1))

    # print the last level
    print(f"Level {current_level} width: {len(level_nodes)}")
def UCT_search(state, num_reads):
    # num_reads repeated simulations
    root = UCTNode(0, state, action=None, parent=DummyNode())
    for i in tqdm(range(num_reads)):
        # every read starts over from the root
        leaf = root.select_leaf()
        # child_priors: values in [0, 1]
        child_priors, value_estimate = NeuralNet().evaluate(leaf.state)
        leaf.expand(child_priors)
        leaf.backup(value_estimate)
#         print(i)
#         print_tree_level_width(root)
    # the chosen move is the most-visited root child, not the highest-Q one
    return root, np.argmax(root.child_number_visits)
    # return root, root.best_child()

Run the search algorithm:

num_reads = 100000
import time
tick = time.time()
root, _ = UCT_search(GameState(), num_reads)
tock = time.time()
print("Took %s sec to run %s times" % (tock - tick, num_reads))
import resource
print("Consumed %sKB memory" % resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
"""
  0%|          | 0/100000 [00:00<?, ?it/s]
Took 5.709271430969238 sec to run 100000 times
Consumed 758408KB memory
"""

Print the per-level widths of the tree:

print_tree_level_width(root)

# Level 0 width: 1
# Level 1 width: 360
# Level 2 width: 71329
# Level 3 width: 28310

Visualization with igraph (kept commented out in the notebook):

# import igraph as ig
# g = ig.Graph(directed=True)

# # dict tracking the nodes that have already been added
# nodes_dict = {}

# def add_nodes_and_edges(node, parent_id=None):
#     # add the current node (if it has not been added yet)
#     if node.name not in nodes_dict:
#         nodes_dict[node.name] = len(nodes_dict)
#         g.add_vertices(1)

#     current_id = nodes_dict[node.name]

#     # add an edge from the parent node to the current node
#     if parent_id is not None:
#         g.add_edges([(parent_id, current_id)])

#     # recurse into the children
#     for child in node.children.values():
#         add_nodes_and_edges(child, current_id)

# # add nodes and edges starting from the root
# add_nodes_and_edges(root)
# layout = g.layout("tree", root=[0])

# # set the node labels
# g.vs["label"] = list(nodes_dict.keys())

# # plot
# ig.plot(g, layout=layout, bbox=(300, 300), margin=20)
