当前位置: 首页 > news >正文

从代码学习深度强化学习 - SAC PyTorch版

文章目录

  • 前言
  • SAC处理连续动作空间问题 (Pendulum-v1)
    • 核心代码实现
      • **工具函数与环境初始化**
      • **ReplayBuffer、网络结构与SAC算法**
      • **训练与结果**
  • SAC处理离散动作空间问题 (CartPole-v1)
    • 核心代码实现
      • **工具函数与环境初始化**
      • **ReplayBuffer、网络结构与SAC算法 (离散版)**
      • **训练与结果**
  • 总结


前言

在深度强化学习(DRL)的探索之旅中,我们不断寻求更高效、更稳定的算法来应对日益复杂的决策问题。传统的在线策略算法(On-policy)如A2C、PPO等,虽然在很多场景下表现优异,但其采样效率低下的问题也限制了它们在某些现实世界任务中的应用,尤其是在那些与环境交互成本高昂的场景中。

因此,离线策略(Off-policy)算法应运而生,它们能够利用历史数据(Replay Buffer)进行学习,极大地提高了数据利用率和学习效率。在众多离线策略算法中,Soft Actor-Critic(SAC)算法以其出色的稳定性和卓越的性能脱颖而出。

正如上图所述,与同为离线策略算法的DDPG相比,SAC在训练稳定性和收敛性方面表现更佳,对超参数的敏感度也更低。 SAC的前身是Soft Q-learning,它们都属于最大熵强化学习的范畴,即在最大化累积奖励的同时,也最大化策略的熵,从而鼓励智能体进行更充分的探索。 与Soft Q-learning不同,SAC引入了一个显式的策略函数(Actor),从而优雅地解决了在连续动作空间中求解困难的问题。 SAC学习的是一个随机策略,这使得它能够探索多模态的最优策略,并在复杂的环境中表现出更强的鲁棒性。

本篇博客将通过两个PyTorch实现的SAC代码示例,带您深入理解SAC算法的精髓。我们将分别探讨其在连续动作空间离散动作空间中的具体实现,并通过代码解析,让您对策略网络、价值网络、经验回放、软更新以及核心的熵正则化等概念有更直观的认识。

无论您是强化学习的初学者,还是希望深入了解SAC算法的实践者,相信通过本文的代码学习之旅,您都将有所收获。

完整代码:下载链接


SAC处理连续动作空间问题 (Pendulum-v1)

在连续控制任务中,SAC通过学习一个随机策略,输出动作的正态分布的均值和标准差,从而实现对连续动作的探索和决策。我们将以OpenAI Gym中的经典环境Pendulum-v1为例,这是一个典型的连续控制问题,智能体的目标是利用有限的力矩将摆杆竖立起来。

核心代码实现

以下是SAC算法在Pendulum-v1环境下的完整PyTorch实现。代码涵盖了工具函数、环境初始化、核心网络结构(ReplayBuffer、策略网络、Q值网络)、SAC算法主类以及训练和可视化的全过程。

工具函数与环境初始化

首先,我们定义一个moving_average函数用于平滑训练过程中的奖励曲线,以便更好地观察训练趋势。然后,我们初始化Pendulum-v1环境。

# utils"""
强化学习工具函数集
包含数据平滑处理功能
"""import torch
import numpy as np
def moving_average(data, window_size):"""计算移动平均值,用于平滑奖励曲线该函数通过滑动窗口的方式对时间序列数据进行平滑处理,可以有效减少数据中的噪声,使曲线更加平滑美观。常用于强化学习中对训练过程的奖励曲线进行可视化优化。参数:data (list): 原始数据序列,维度: [num_episodes]包含需要平滑处理的数值数据(如每轮训练的奖励值)window_size (int): 移动窗口大小,维度: 标量决定了平滑程度,窗口越大平滑效果越明显但也会导致更多的数据点丢失返回:list: 移动平均后的数据,维度: [len(data) - window_size + 1]返回的数据长度会比原数据少 window_size - 1 个元素这是因为需要足够的数据点来计算第一个移动平均值示例:>>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # 维度: [10]>>> smoothed = moving_average(data, 3)       # window_size = 3>>> print(smoothed)  # 输出: [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]  维度: [8]"""# 边界检查:如果数据长度小于窗口大小,直接返回原数据# 这种情况下无法计算移动平均值# data维度: [num_episodes], window_size维度: 标量if len(data) < window_size:return data# 初始化移动平均值列表# moving_avg维度: 最终为[len(data) - window_size + 1]moving_avg = []# 遍历数据,计算每个窗口的移动平均值# i的取值范围: 0 到 len(data) - window_size# 循环次数: len(data) - window_size + 1# 每次循环处理一个滑动窗口位置for i in range(len(data) - window_size + 1):# 提取当前窗口内的数据切片# window_data维度: [window_size]# 包含从索引i开始的连续window_size个元素# 例如:当i=0, window_size=3时,提取data[0:3]window_data = data[i:i + window_size]# 计算当前窗口内数据的算术平均值# np.mean(window_data)维度: 标量# 将平均值添加到结果列表中moving_avg.append(np.mean(window_data))# 返回移动平均后的数据列表# moving_avg维度: [len(data) - window_size + 1]return moving_avg
"""
强化学习环境初始化模块
用于创建和配置OpenAI Gym环境
"""import gym  # OpenAI Gym库,提供标准化的强化学习环境接口
import numpy as np  # 数值计算库,用于处理多维数组和数学运算# 定义环境名称
# env_name维度: 字符串标量
# 'Pendulum-v1'是一个连续控制任务,倒立摆环境
# 状态空间: 3维连续空间 (cos(theta), sin(theta), thetadot)
# 动作空间: 1维连续空间,范围[-2.0, 2.0]
env_name = 'Pendulum-v1'# 创建强化学习环境实例
# env维度: gym.Env对象
# 包含完整的环境状态、动作空间、奖励函数等信息
# 该环境支持reset()、step()、render()、close()等标准方法
env = gym.make(env_name)

ReplayBuffer、网络结构与SAC算法

这部分代码是SAC算法的核心。

  • ReplayBuffer: 经验回放池,用于存储和采样智能体的经验数据,打破数据相关性,提高学习效率。
  • PolicyNetContinuous: 策略网络(Actor),输入状态,输出动作分布的均值和标准差。这里使用了重参数化技巧(Reparameterization Trick),使得从策略分布中采样的过程可导,从而能够利用梯度进行端到端的训练。动作经过tanh函数激活并缩放到环境的动作边界内。
  • QValueNetContinuous: Q值网络(Critic),输入状态和动作,输出对应的Q值。SAC采用了双Q网络的技巧,即构建两个结构相同的Q网络,在计算目标Q值时取两者的较小值,以缓解Q值过高估计的问题。
  • SACContinuous: SAC算法的主类,整合了上述所有网络和组件。它实现了动作选择、目标Q值计算、网络参数的软更新以及策略和价值网络的更新逻辑。其中,温度参数α的学习和更新是SAC的核心之一,它通过最大化熵的目标自动调整,平衡探索与利用。
"""
SAC (Soft Actor-Critic) 算法实现
用于连续动作空间的强化学习智能体
"""import torch  # PyTorch深度学习框架
import torch.nn as nn  # 神经网络模块
import torch.nn.functional as F  # 神经网络功能函数
import numpy as np  # 数值计算库
import random  # 随机数生成库
import collections  # 集合数据类型模块
from torch.distributions import Normal  # 正态分布类class ReplayBuffer:"""经验回放缓冲区类用于存储和采样智能体的经验数据"""def __init__(self, capacity):"""初始化经验回放缓冲区参数:capacity (int): 缓冲区容量,维度: 标量"""# 使用双端队列作为缓冲区存储结构# self.buffer维度: deque,最大长度为capacity# 存储格式: (state, action, reward, next_state, done)self.buffer = collections.deque(maxlen=capacity)def add(self, state, action, reward, next_state, done):"""向缓冲区添加一条经验参数:state (np.array): 当前状态,维度: [state_dim]action (float): 执行的动作,维度: 标量reward (float): 获得的奖励,维度: 标量next_state (np.array): 下一个状态,维度: [state_dim]done (bool): 是否结束,维度: 标量布尔值"""# 将经验元组添加到缓冲区# 元组维度: (state[state_dim], action[1], reward[1], next_state[state_dim], done[1])self.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size):"""从缓冲区随机采样一批经验参数:batch_size (int): 批次大小,维度: 标量返回:tuple: 包含状态、动作、奖励、下一状态、完成标志的元组state (np.array): 状态批次,维度: [batch_size, state_dim]action (tuple): 动作批次,维度: [batch_size]reward (tuple): 奖励批次,维度: [batch_size]next_state (np.array): 下一状态批次,维度: [batch_size, state_dim]done (tuple): 完成标志批次,维度: [batch_size]"""# 随机采样batch_size个经验# transitions维度: list,长度为batch_sizetransitions = random.sample(self.buffer, batch_size)# 将经验元组解包并转置# 每个元素的维度: state[batch_size个state_dim], action[batch_size], 等等state, action, reward, next_state, done = zip(*transitions)# 将状态转换为numpy数组便于后续处理# state维度: [batch_size, state_dim]# next_state维度: [batch_size, state_dim]return np.array(state), action, reward, np.array(next_state), donedef size(self):"""返回缓冲区当前大小返回:int: 缓冲区大小,维度: 标量"""return len(self.buffer)class PolicyNetContinuous(torch.nn.Module):"""连续动作空间的策略网络输出动作的均值和标准差,用于生成随机策略"""def __init__(self, state_dim, hidden_dim, action_dim, action_bound):"""初始化策略网络参数:state_dim (int): 状态空间维度,维度: 标量hidden_dim (int): 隐藏层维度,维度: 标量action_dim (int): 动作空间维度,维度: 标量action_bound (float): 动作边界值,维度: 标量"""super(PolicyNetContinuous, self).__init__()# 第一个全连接层:状态到隐藏层# 输入维度: [batch_size, state_dim]# 输出维度: [batch_size, hidden_dim]self.fc1 = torch.nn.Linear(state_dim, hidden_dim)# 输出动作均值的全连接层# 输入维度: [batch_size, hidden_dim]# 输出维度: [batch_size, action_dim]self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)# 输出动作标准差的全连接层# 输入维度: [batch_size, hidden_dim]# 输出维度: [batch_size, action_dim]self.fc_std = torch.nn.Linear(hidden_dim, action_dim)# 动作边界值,用于缩放tanh输出# action_bound维度: 标量self.action_bound = action_bounddef forward(self, x):"""前向传播函数参数:x (torch.Tensor): 输入状态,维度: [batch_size, state_dim]返回:tuple: 包含动作和对数概率的元组action (torch.Tensor): 输出动作,维度: [batch_size, action_dim]log_prob (torch.Tensor): 动作对数概率,维度: [batch_size, action_dim]"""# 第一层激活# x维度: [batch_size, state_dim] -> [batch_size, hidden_dim]x = F.relu(self.fc1(x))# 计算动作均值# mu维度: [batch_size, action_dim]mu = self.fc_mu(x)# 计算动作标准差,使用softplus确保为正值# std维度: [batch_size, action_dim]std = F.softplus(self.fc_std(x))# 创建正态分布对象# dist维度: Normal分布对象,参数维度均为[batch_size, action_dim]dist = Normal(mu, std)# 重参数化采样,确保梯度可以反向传播# normal_sample维度: [batch_size, action_dim]normal_sample = dist.rsample()  # rsample()是重参数化采样"""重参数化采样是一种用于训练神经网络生成模型(Generative Models)的技术,特别是在概率编码器-解码器框架中常见,例如变分自编码器(Variational Autoencoder,VAE)。这种技术的目的是将采样过程通过神经网络的可导操作,使得模型可以被端到端地训练。在普通的采样过程中,由于采样操作本身是不可导的,传统的梯度下降方法无法直接用于训练神经网络。为了解决这个问题,引入了重参数化技巧。`dist.rsample()` 是重参数化采样的一部分。这里的重参数化指的是将采样操作重新参数化为可导的操作,使得梯度能够通过网络反向传播。通过这种方式,可以有效地训练生成模型,尤其是概率生成模型。在正态分布的情况下,传统的采样操作是直接从标准正态分布中抽取样本,然后通过线性变换得到最终的样本。而重参数化采样则通过在标准正态分布上进行采样,并通过神经网络产生的均值和标准差进行变换,使得采样操作变为可导的。这有助于在训练过程中通过梯度下降来优化网络参数。"""# 计算采样点的对数概率密度# log_prob维度: [batch_size, action_dim]log_prob = dist.log_prob(normal_sample)# 使用tanh函数将动作限制在[-1, 1]范围内# action维度: [batch_size, action_dim]action = torch.tanh(normal_sample)# 计算tanh_normal分布的对数概率密度# 根据变换的雅可比行列式调整概率密度# 避免数值不稳定性,添加小常数1e-7# log_prob维度: [batch_size, action_dim]log_prob = log_prob - torch.log(1 - torch.tanh(action).pow(2) + 1e-7)# 将动作缩放到实际的动作边界范围# action维度: [batch_size, action_dim]action = action * self.action_boundreturn action, log_probclass QValueNetContinuous(torch.nn.Module):"""连续动作空间的Q值网络输入状态和动作,输出对应的Q值"""def __init__(self, state_dim, hidden_dim, action_dim):"""初始化Q值网络参数:state_dim (int): 状态空间维度,维度: 标量hidden_dim (int): 隐藏层维度,维度: 标量action_dim (int): 动作空间维度,维度: 标量"""super(QValueNetContinuous, self).__init__()# 第一个全连接层:拼接状态和动作后的输入层# 输入维度: [batch_size, state_dim + action_dim]# 输出维度: [batch_size, hidden_dim]self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)# 第二个隐藏层# 输入维度: [batch_size, hidden_dim]# 输出维度: [batch_size, hidden_dim]self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)# 输出层:输出Q值# 输入维度: [batch_size, hidden_dim]# 输出维度: [batch_size, 1]self.fc_out = torch.nn.Linear(hidden_dim, 1)def forward(self, x, a):"""前向传播函数参数:x (torch.Tensor): 输入状态,维度: [batch_size, state_dim]a (torch.Tensor): 输入动作,维度: [batch_size, action_dim]返回:torch.Tensor: Q值,维度: [batch_size, 1]"""# 将状态和动作拼接作为网络输入# cat维度: [batch_size, state_dim + action_dim]cat = torch.cat([x, a], dim=1)# 第一层激活# x维度: [batch_size, state_dim + action_dim] -> [batch_size, hidden_dim]x = F.relu(self.fc1(cat))# 第二层激活# x维度: [batch_size, hidden_dim] -> [batch_size, hidden_dim]x = F.relu(self.fc2(x))# 输出Q值# 返回值维度: [batch_size, 1]return self.fc_out(x)class SACContinuous:"""SAC (Soft Actor-Critic) 算法实现类处理连续动作空间的强化学习问题SAC 使用两个 Critic 网络来使 Actor 的训练更稳定,而这两个 Critic 网络在训练时则各自需要一个目标价值网络。因此,SAC 算法一共用到 5 个网络,分别是一个策略网络、两个价值网络和两个目标价值网络。"""def __init__(self, state_dim, hidden_dim, action_dim, action_bound,actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,device):"""初始化SAC算法参数:state_dim (int): 状态空间维度,维度: 标量hidden_dim (int): 隐藏层维度,维度: 标量action_dim (int): 动作空间维度,维度: 标量action_bound (float): 动作边界值,维度: 标量actor_lr (float): 策略网络学习率,维度: 标量critic_lr (float): 价值网络学习率,维度: 标量alpha_lr (float): 温度参数学习率,维度: 标量target_entropy (float): 目标熵值,维度: 标量tau (float): 软更新参数,维度: 标量gamma (float): 折扣因子,维度: 标量device (torch.device): 计算设备,维度: 设备对象"""# 策略网络:输出动作分布# self.actor维度: PolicyNetContinuous对象self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim,action_bound).to(device)# 第一个Q网络:评估状态-动作价值# self.critic_1维度: QValueNetContinuous对象self.critic_1 = QValueNetContinuous(state_dim, hidden_dim,action_dim).to(device)# 第二个Q网络:评估状态-动作价值# self.critic_2维度: QValueNetContinuous对象self.critic_2 = QValueNetContinuous(state_dim, hidden_dim,action_dim).to(device)# 第一个目标Q网络:用于计算目标Q值# self.target_critic_1维度: QValueNetContinuous对象self.target_critic_1 = QValueNetContinuous(state_dim,hidden_dim, action_dim).to(device)# 第二个目标Q网络:用于计算目标Q值# self.target_critic_2维度: QValueNetContinuous对象self.target_critic_2 = QValueNetContinuous(state_dim,hidden_dim, action_dim).to(device)# 令目标Q网络的初始参数和Q网络一样self.target_critic_1.load_state_dict(self.critic_1.state_dict())self.target_critic_2.load_state_dict(self.critic_2.state_dict())# 策略网络优化器# self.actor_optimizer维度: torch.optim.Adam对象self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr
http://www.dtcms.com/a/287086.html

相关文章:

  • 消息队列与信号量:System V 进程间通信的基础
  • 【机器学习深度学习】为什么要将模型转换为 GGUF 格式?
  • win10连接鼠标自动关闭触摸板/win10关闭触摸板(笔记本)
  • 路由器的Serial 串口理解
  • 移除debian升级后没用的垃圾
  • 爬虫逆向之JS混淆案例(全国招标公告公示搜索引擎 type__1017逆向)
  • AJAX概述
  • Unity 3D碰撞器
  • C语言—深入理解指针(详)
  • Eureka 和 Nacos
  • 医疗AI与融合数据库的整合:挑战、架构与未来展望(下)
  • Acrobat SDK 核心架构、应用
  • 2025年最新秋招java后端面试八股文+场景题
  • Linux操作系统之线程(三)
  • 动态规划算法的欢乐密码(三):简单多状态DP问题(上)
  • VBA 运用LISTBOX插件,选择多个选项,并将选中的选项回车录入当前选中的单元格
  • 【Linux系统】进程控制
  • 高效适配多分辨率!Unity动态UI缩放工具 Resize Pro 免费分享
  • 用户中心项目实战(springboot+vue快速开发管理系统)
  • Window延迟更新10000天配置方案
  • 【逻辑回归】MAP - Charting Student Math Misunderstandings
  • PostgreSQL ORDER BY 语句详解
  • bash方式启动模型训练
  • tkinter绘制组件(45)——导航栏
  • EP01:【Python 第一弹】基础入门知识
  • 国产电科金仓数据库:融合进化,智领未来
  • C++进阶课程第4期——动态规划
  • FastAPI遇上GraphQL:异步解析器如何让API性能飙升?
  • C++中的list(1)
  • c#中ArrayList和List的常用方法