当前位置: 首页 > news >正文

强化学习入门:从零开始实现Dueling DQN

在之前的文章中,已经依次介绍了

  • 《深度强化学习入门:从零开始实现DQN》
  • 《深度强化学习入门:从零开始实现DDQN》

并通过gymnasium中的 CartPole 环境从零开始实现了智能体的训练与测试。本文将继续这一系列,聚焦于 Dueling DQN(对抗网络结构),它是对 DQN 的进一步改进,主要解决 DQN 在价值估计上的冗余与模糊问题。


一、核心问题:传统 DQN 的价值模糊性

在标准 DQN 中,网络直接输出每个动作的 Q 值,这带来两个问题:

  1. 状态价值与动作价值混淆

    • 例如在自动驾驶场景:
      • 状态价值(V):当前道路的安全性(所有动作共享)
      • 动作优势(A):左转/右转的额外收益
    • DQN 无法区分这两类信息
  2. 冗余学习

    # 当状态价值相同时,DQN 仍需为每个动作重复学习:
    Q(直行) = 道路安全值 + 直行优势
    Q(左转) = 道路安全值 + 左转优势
    Q(右转) = 道路安全值 + 右转优势
    

    → 网络浪费容量学习重复的状态特征


二、Dueling DQN 的提出

Dueling DQN 的核心思想是:将 Q 值分解为状态价值和动作优势两部分

数学公式:

Q(s,a)=V(s)+(A(s,a)−1∣A∣∑a′A(s,a′))Q(s,a) = V(s) + \Big(A(s,a) - \frac{1}{|\mathcal{A}|}\sum_{a'} A(s,a')\Big) Q(s,a)=V(s)+(A(s,a)A1aA(s,a))

其中:

  • V(s)V(s)V(s):状态价值函数(标量)
  • A(s,a)A(s,a)A(s,a):动作优势函数(向量)

为什么需要减去均值?

如果直接写成:

Q(s,a)=V(s)+A(s,a)Q(s,a) = V(s) + A(s,a) Q(s,a)=V(s)+A(s,a)

就会导致 VVVAAA 之间存在无穷多解,学习过程不稳定。例如:

  • V′(s)=V(s)+cV'(s) = V(s) + cV(s)=V(s)+c
  • A′(s,a)=A(s,a)−cA'(s,a) = A(s,a) - cA(s,a)=A(s,a)c

它们可以得到相同的 Q(s,a)Q(s,a)Q(s,a)
解决办法:在 A(s,a)A(s,a)A(s,a) 上做归一化,减去平均值或最大值。
Dueling dqn的详细内容可以看我之前的文章 深度强化学习Dueling DQN,本文重点分享代码。


代码结构(参考自腾讯开悟平台)

📦 项目根目录
├── 📂 agent_dueling_dqn        # Dueling DQN智能体核心模块
│   ├── 📂 algorithm            # 算法实现目录
│      └── 📄 __init__.py
│      └── 📄 algorithm.py      # 算法核心逻辑,包含经验回放、采样、更新网络等方法
│   ├── 📂 conf                 # 配置管理目录
│      └── 📄 __init__.py
│      └── 📄 conf.py           # 参数配置**文件,集中管理模型结构、训练参数、路径等,便于调参
│   ├── 📂 feature              # 特征处理目录
│      └── 📄 __init__.py
│      └── 📄 monitor.py        # 训练过程监控模块,负责记录奖励、损失等指标并可实现可视化
│      └── 📄 processor.py      # 数据预处理模块,负责对环境状态进行标准化、转换Tensor等操作
│   ├── 📂 model                # 神经网络模型目录
│      └── 📄 __init__.py
│      └── 📄 model.py          # 网络模型定义文件
│   ├── 📂 workflow             # 工作流目录
│      └── 📄 __init__.py
│      └── 📄 train_workflow.py # 训练工作流,封装了与环境交互、训练循环等完整流程
│   ├── 📄 __init__.py
│   └── 📄 agent.py             # 智能体接口,提供选择动作、学习、保存/加载模型等功能
├── 📂 env                      # 环境管理目录
│   ├── 📄 __init__.py
│   └── 📄 envManager.py        # 环境管理器,封装Gymnasium环境的创建、重置、步进等操作
└── 📄 train_test.py            # 主程序入口,用于启动训练或测试模式的脚本

网络结构实现

相较于标准 DQN,Dueling DQN 只是在网络的最后增加了 两条分支:一个估计 V(s)V(s)V(s),一个估计 A(s,a)A(s,a)A(s,a)

import torch
import torch.nn as nn
import torch.nn.functional as Fclass DuelingDQN(nn.Module):def __init__(self, state_dim, action_dim):super().__init__()# 共享特征提取层self.feature_layer = nn.Sequential(nn.Linear(state_dim, 128),nn.ReLU(),nn.Linear(128, 128),nn.ReLU())# 状态价值分支 V(s)self.V_branch = nn.Sequential(nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, 1))# 动作优势分支 A(s,a)self.A_branch = nn.Sequential(nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, action_dim))def forward(self, x):features = self.feature_layer(x)V = self.V_branch(features)  # [batch_size, 1]A = self.A_branch(features)  # [batch_size, action_dim]# 组合 Q 值:V + (A - mean(A))Q = V + (A - A.mean(dim=1, keepdim=True))return Q

只需在原 DQN 框架中替换网络部分,其他流程(经验回放、训练循环、目标网络更新)均保持不变。


完整代码请看这里


文章转载自:

http://zkDH3oxg.djcbt.cn
http://oJo9SyBr.djcbt.cn
http://qKBYk1Ga.djcbt.cn
http://Ps2ygwcB.djcbt.cn
http://2gI8qXKw.djcbt.cn
http://JlVn7tq4.djcbt.cn
http://6DhA7q11.djcbt.cn
http://MA531hHI.djcbt.cn
http://TVOmDJic.djcbt.cn
http://o1BUEBWZ.djcbt.cn
http://nr81Dt3D.djcbt.cn
http://h1Bq6H8S.djcbt.cn
http://KgPRazW4.djcbt.cn
http://OtMjKmof.djcbt.cn
http://bnN492Ar.djcbt.cn
http://09ZOncWh.djcbt.cn
http://2PJMuErW.djcbt.cn
http://O0XDz088.djcbt.cn
http://O9G9NORw.djcbt.cn
http://eMfHuy3B.djcbt.cn
http://vsz4jPI9.djcbt.cn
http://DztHwlma.djcbt.cn
http://BEXqbsnP.djcbt.cn
http://l69Mt4pJ.djcbt.cn
http://07uQc3hU.djcbt.cn
http://WSbWOiFG.djcbt.cn
http://yNRFXa8r.djcbt.cn
http://B0IOoUGS.djcbt.cn
http://vFczMSk6.djcbt.cn
http://H8AEAeTd.djcbt.cn
http://www.dtcms.com/a/372008.html

相关文章:

  • 做事总是三分钟热度怎么办
  • 图像形态学
  • C++运算符重载——函数调用运算符 ()
  • 分布式系统——分布式数据库的高扩展性保证
  • C++ 并发编程:异步任务
  • 四、神经网络的学习(中)
  • OPENPPP2 —— IP标准校验和算法深度剖析:从原理到SSE2优化实现
  • 梅花易数:从入门到精通
  • 计算机⽹络及TCP⽹络应⽤程序开发
  • 单点登录1(SSO知识点)
  • 嵌入式学习---(ARM)
  • 嵌入式学习day44-硬件—ARM体系架构
  • 《数据结构全解析:栈(数组实现)》
  • Linux系统资源监控脚本
  • PHP中各种超全局变量使用的过程
  • C++-类型转换
  • [GDOUCTF 2023]doublegame
  • 系统资源监控与邮件告警
  • 1706.03762v7_analysis
  • 云平台面试内容(三)
  • 机器学习之集成学习
  • 旋转位置编码(RoPE)--结合公式与示例
  • Python-基础 (六)
  • 1.12 Memory Profiler Package - Summary
  • 【面试题】C++系列(一)
  • Hadoop(九)
  • 关于npm的钩子函数
  • 旋转数字矩阵 od
  • Matlab:基于遗传算法优化 PID 控制器的完整实现与解析
  • JBoltAI需求分析大师:基于SpringBoot的大模型智能需求文档生成解决方案