D3QN + 优先经验回放(PER)实现全解析:从数据树到训练循环(附伪代码与流程图)
D3QN + 优先经验回放(PER)实现全解析:从数据树到训练循环(附伪代码与流程图)
这是一篇完整的工程向教程,目标是让读者在“不依赖额外资料”的情况下,从零搭建一套可运行、可解释、可扩展的 D3QN+PER 训练系统。全文分三大部分:
- SumTree(数据树):以 O(log N) 支持“按权取样”和“局部更新”,解决加权随机采样的效率瓶颈;
- 优先经验回放池(PER):用 TD 误差驱动优先级、以分段采样保证覆盖、以重要性采样权重(IS weights)抵消偏置;
- D3QN(Dueling + Double):以价值/优势分流提升 Q 估计效率,以双网络降低过估计偏差,并与 PER 的权重、优先级更新闭环联动。
附:文末给出“可直接照抄”的伪代码,正文中则穿插给出实现细节、数值稳定性建议与典型坑点;五张分页流程图 d3qn_pseudocode_Page1.png
~ d3qn_pseudocode_Page5.png
用于对照阅读。
一、SumTree:把“按权采样”降到 O(log N)
1.1 问题背景
在强化学习中,我们常常需要“按优先级”从经验池中抽样。设经验池大小为 NNN,每个样本有非负权重(或优先级)pip_ipi,其被抽到的概率应满足:
P(i)=pi∑k=1Npk. P(i) = \frac{p_i}{\sum_{k=1}^{N} p_k}. P(i)=∑k=1Npkpi.
朴素地做“累积和 + 线性扫描”复杂度是 O(N)\mathcal{O}(N)O(N),每次抽样太慢;每次更新 pip_ipi 后重建累积和也慢。我们需要一种既能快速更新、又能快速查询前缀和的数据结构。SumTree(也称 Segment Tree 的简化变体)正是为此而生。
1.2 树的布局与索引映射
SumTree 是完全二叉树,用数组存储:长度约 2⋅capacity−12\cdot \text{capacity}-12⋅capacity−1。最底层的叶子区间存放每个样本的优先级 pip_ipi,上方的非叶子结点存放其两个子结点之和。根结点存放整棵树的总和,即 ∑ipi\sum_i p_i∑ipi。
- 叶子到样本的映射:
data_idx = tree_idx - capacity + 1
; - 从根向下查找:若查询值 vvv 小于等于左子树和,则走左子树;否则减去左子树和并走右子树;如此直到叶子。
这样,给定区间 [0,total][0, \text{total}][0,total] 上均匀随机的 vvv,就能在 O(logN)\mathcal{O}(\log N)O(logN) 时间内定位被抽中的叶子(样本)。
1.3 三个核心操作
- Add(x, p):把新样本 x 写入当前指针位置对应的叶子,并设置优先级 p,然后自底向上回写父结点之和。指针环形前进支持覆盖最老样本。
- Update(i, p):已知叶子下标 i,用 Δ=pnew−pold\Delta = p_{new} - p_{old}Δ=pnew−pold 自底向上把所有祖先的值加上 Δ\DeltaΔ,保持区间和一致。
- GetLeaf(v):从根开始二分查找,直到叶子;返回
(tree_idx, priority, data_idx)
。这里 v 是在 [0,total][0, \text{total}][0,total] 的随机数,或分段内的随机数(见 PER 的均衡分段采样)。
复杂度:三个操作均为 O(logN)\mathcal{O}(\log N)O(logN);空间复杂度约为 2N2N2N(常数很小)。
1.4 示例(一步步走到目标叶子)
设 capacity=8,共 8 个叶子。若根结点总和为 SSS,我们先采样 v∼U(0,S)v\sim U(0,S)v∼U(0,S),比较 vvv 与左孩子和 SLS_LSL:
- 若 v≤SLv\le S_Lv≤SL,继续到左孩子;
- 否则 v←v−SLv\leftarrow v-S_Lv←v−S