当前位置: 首页 > news >正文

从代码学习深度强化学习 - DDPG PyTorch版

文章目录

  • 前言
  • DDPG 算法简介
  • 环境介绍:倒立摆
  • 代码实践
    • 1. 辅助函数与环境设置
    • 2. 核心组件 (Actor, Critic, ReplayBuffer)
    • 3. DDPG 智能体
    • 4. 训练循环
    • 5. 运行代码与结果展示
  • 总结


前言

欢迎来到深度强化学习(DRL)的世界!如果您对如何让AI在模拟环境中(比如游戏或机器人控制)学习复杂任务感到好奇,那么您来对地方了。深度强化学习结合了深度学习强大的感知能力和强化学习决策能力,使其能够解决传统方法难以应对的复杂问题。

在众多DRL算法中,DDPG(Deep Deterministic Policy Gradient)是一个非常经典且强大的算法,尤其擅长处理连续动作空间的任务,例如控制机器臂的关节角度或赛车的方向盘。

本篇博客旨在通过一个完整的PyTorch代码实例,带您深入浅出地理解DDPG算法的内部工作原理。我们将一起从零开始,构建DDPG智能体的每一个核心组件,并在经典的“倒立摆(Inverted Pendulum)”环境中进行训练和测试。无论您是DRL初学者还是希望加深对DDPG理解的实践者,相信都能从中获益。

完整代码:下载链接


DDPG 算法简介

DDPG(Deep Deterministic Policy Gradient)主要是为连续动作空间设计的强化学习算法,因此它不直接适用于离散动作空间。DDPG使用的是确定性策略(deterministic policy),这意味着它输出的是一个确定的动作值,而不是动作的概率分布。

与DDPG相对的是随机性策略(stochastic policy),它会输出一个动作的概率分布,然后从这个分布中采样一个动作。对于离散动作空间的问题,我们通常建议使用例如 DQN(Deep Q-Network)或 A3C(Asynchronous Advantage Actor-Critic)等算法。

DDPG属于一种**Actor-Critic(演员-评论家)**框架的算法。它的核心思想是:

  • Actor (演员):一个策略网络,负责根据当前的状态(state)选择一个最佳的动作(action)。
  • Critic (评论家):一个Q值网络,负责评估演员选择的动作在当前状态下有多好,即预测未来的总回报(Q值)。

通过这种方式,Critic会“指导”Actor如何调整策略以输出能获得更高Q值的动作,从而实现学习。

在这里插入图片描述

图1:DDPG 中的 Actor 网络和 Critic 网络,以倒立摆环境为例

DDPG巧妙地融合了DQN算法中的两个关键技术,以提高训练的稳定性和效率:

  1. 经验回放(Replay Buffer): DDPG是一个**离策略(off-policy)**算法。它将智能体与环境交互的经验(状态、动作、奖励、下一状态)存储在一个经验池中。在训练时,它会从池中随机采样一批数据进行学习,这打破了数据之间的相关性,使得训练更加稳定。
  2. 目标网络(Target Networks): DDPG为Actor和Critic网络都创建了一个对应的“目标网络”。在计算目标Q值时,使用的是目标网络,这为学习提供了一个稳定的目标,避免了Q值估计的剧烈波动。目标网络的更新采用**软更新(soft update)**方式,即每次只将主网络的一小部分权重更新到目标网络,公式为:target_weights = τ * local_weights + (1 - τ) * target_weights。这与DQN中每隔一段时间直接“硬”复制权重的方式不同,使得更新过程更加平滑。

为了在确定性策略中实现探索(exploration),DDPG在Actor网络输出的动作上加入了一些随机噪声,从而让智能体有机会尝试当前策略之外的动作,发现可能更优的策略。

环境介绍:倒立摆

在本次实践中,我们选择OpenAI Gym中的经典控制环境——倒立摆(Inverted Pendulum)
在这里插入图片描述

图2:Pendulum环境示意图

这个任务的目标非常直观:通过施加一个力矩(torque),让一个初始位置随机的摆杆尽可能地保持竖直向上。

这是一个典型的连续控制问题,因为它的状态和动作都是连续的:

  • 状态空间 (State Space): 一个3维的连续向量,包含了摆杆角度的余弦、正弦以及角速度。
    标号名称最小值最大值
    0cos θ-1.01.0
    1sin θ-1.01.0
    2θ̇-8.08.0
  • 动作空间 (Action Space): 一个1维的连续值,代表施加的力矩。
    标号动作最小值最大值
    0力矩-2.02.0

环境会根据摆杆的状态(离竖直位置越近、摆动越慢)和所施加的力矩大小来计算每一步的奖励。我们的智能体需要学习一个最优策略来最大化整个回合的累积奖励。

代码实践

接下来,让我们一步步构建和训练我们的DDPG智能体。代码注释中包含了对每个变量和操作的详细解释,包括其维度和作用。

1. 辅助函数与环境设置

首先,我们定义一个用于平滑奖励曲线的辅助函数,并设置好我们的Gym环境。平滑处理可以帮助我们更清晰地观察智能体学习的总体趋势。

import numpy as npdef moving_average(data, window_size):"""计算移动平均值,用于平滑奖励曲线参数:data (list): 原始数据序列 [num_episodes,] - 每个episode的奖励值window_size (int): 移动窗口大小 - 用于平滑的数据点数量返回:list: 移动平均后的数据 [len(data) - window_size + 1,] - 平滑后的奖励序列示例:原始数据: [1, 2, 3, 4, 5], 窗口大小: 3返回: [2.0, 3.0, 4.0] (即 [1+2+3]/3, [2+3+4]/3, [3+4+5]/3)"""# 如果数据长度小于窗口大小,直接返回原数据if len(data) < window_size:return data# 初始化移动平均值列表moving_avg = []  # [len(data) - window_size + 1,]# 滑动窗口计算移动平均值for i in range(len(data) - window_size + 1):# 提取当前窗口内的数据window_data = data[i:i + window_size]  # [window_size,] - 当前窗口的数据# 计算窗口内数据的平均值并添加到结果列表moving_avg.append(np.mean(window_data))  # scalar - 当前窗口的平均值return moving_avg
# 导入必要的库
import gym  # OpenAI Gym强化学习环境库
import numpy as np  # 数值计算库,用于处理数组和矩阵运算# 定义环境名称
env_name = 'Pendulum-v1'  # str - 环境名称,倒立摆连续控制任务# 创建强化学习环境
env = gym.make(env_name)  # gym.Env - Gym环境对象
# env.observation_space: Box(3,) - 观测空间维度为3,包含[cos(theta), sin(theta), thetadot]
# env.action_space: Box(1,) - 动作空间维度为1,连续动作范围[-2.0, 2.0]
# env.reward_range: (-inf, inf) - 奖励范围,理论上无限制

2. 核心组件 (Actor, Critic, ReplayBuffer)

现在,我们来定义DDPG的三个核心组件:经验回放池(ReplayBuffer)、策略网络(Actor)和Q值网络(Critic)。

# 导入必要的库
import collections  # 用于双向队列deque
import random  # 用于随机采样
import numpy as np  # 数值计算库
import torch  # PyTorch深度学习框架
import torch.nn.functional as F  # PyTorch函数库,包含激活函数等# 经验回放缓存
class ReplayBuffer:"""经验回放缓存,用于存储和采样经验数据"""def __init__(self, capacity):"""初始化经验回放缓存参数:capacity (int): 缓存容量 - 最大存储的经验数量"""self.buffer = collections.deque(maxlen=capacity)  # deque - 双向队列,自动维护最大长度def add(self, state, action, reward, next_state, done):"""向经验回放缓存添加一条经验参数:state (array): 当前状态 [state_dim,] - 环境的当前观测action (array): 执行的动作 [action_dim,] - 智能体选择的动作reward (float): 获得的奖励 scalar - 环境反馈的奖励值next_state (array): 下一个状态 [state_dim,] - 执行动作后的新观测done (bool): 是否结束 scalar - 回合是否终止的标志"""self.buffer.append((state, action, reward, next_state, done))  # tuple - 存储一条完整的经验def sample(self, batch_size):"""从经验回放缓存中采样一批经验参数:batch_size (int): 采样批次大小 - 需要采样的经验数量返回:state (ndarray): 状态批次 [batch_size, state_dim] - 采样的状态数组action (ndarray): 动作批次 [batch_size, action_dim] - 采样的动作数组reward (ndarray): 奖励批次 [batch_size,] - 采样的奖励数组next_state (ndarray): 下一状态批次 [batch_size, state_dim] - 采样的下一状态数组done (ndarray): 结束标志批次 [batch_size,] - 采样的结束标志数组"""transitions = random.sample(self.buffer, batch_size)  # list - 随机采样的经验列表state, action, reward, next_state, done = zip(*transitions)  # tuple - 解包经验元组return np.array(state), np.array(action), np.array(reward), np.array(next_state), np.array(done)def size(self):"""返回当前经验回放缓存中数据的数量返回:int: 缓存中经验的数量 scalar - 当前存储的经验条数"""return len
http://www.dtcms.com/a/278087.html

相关文章:

  • [Python 基础课程]列表
  • 【DataLoader的使用】
  • 力扣 hot100 Day43
  • Actor-Critic重要性采样原理
  • java valueOf方法
  • 【算法】贪心算法入门
  • SwiftUI 7 新 WebView:金蛇出洞,网页江湖换新天
  • 一些git命令
  • 若依框架集成阿里云OSS实现文件上传优化
  • 对于muduo我自己的理解
  • UniHttp生命周期钩子与公共参数实战:打造智能天气接口客户端
  • flask校园学科竞赛管理系统-计算机毕业设计源码12876
  • SPSSPRO:数据分析市场SaaS挑战者的战略分析
  • JAVA并发——什么是AQS?
  • Mapbox GL初探
  • 【unitrix】 5.0 第二套类型级二进制数基本结构体(types2.rs)
  • 16.使用ResNet网络进行Fashion-Mnist分类
  • css如何同时给元素设置背景和背景图?
  • 每日算法刷题Day47:7.13:leetcode 复习完滑动窗口一章,用时2h30min
  • 说实话,统计分析用Python这5个第三方库就够了
  • AutoLabor-ROS-Python 学习记录——第一章 ROS概述与环境搭建
  • PortsSwiggerLab: SSRF with blacklist-based input filter
  • JS进阶-day1 作用域解构箭头函数
  • Spring AI 项目实战(十六):Spring Boot + AI + 通义万相图像生成工具全栈项目实战(附完整源码)
  • NO.5数据结构串和KMP算法|字符串匹配|主串与模式串|KMP|失配分析|next表
  • pthread_mutex_unlock函数的概念和用法
  • 大规模电商系统分库分表实战经验分享
  • NFSV4锁机制(三)
  • 编程技术杂谈2.0
  • DVWA靶场通关笔记-XSS DOM(High级别)