当前位置: 首页 > news >正文

从代码学习深度强化学习 - 模仿学习 PyTorch版

文章目录

  • 前言
  • 专家:PPO 智能体训练
    • 1. 训练 PPO 的辅助函数
    • 2. 环境设置
    • 3. PPO 智能体定义
    • 4. PPO 训练主流程
    • 5. 实例化与训练
  • 方法一:行为克隆 (Behavior Cloning, BC)
    • 1. 理论简介
    • 2. 采样专家数据
    • 3. BC 智能体实现与训练
  • 方法二:生成对抗模仿学习 (GAIL)
    • 1. 理论简介
    • 2. 判别器实现
    • 3. GAIL 算法实现
    • 4. GAIL 训练主流程
  • 总结

前言

在传统的强化学习(RL)框架中,奖励函数(Reward Function)的设计是至关重要的环节。它像一座灯塔,指引着智能体(Agent)在复杂的环境中学习最优策略。然而,在许多现实世界的复杂任务中,设计一个精确、有效且能避免智能体“钻空子”的奖励函数,本身就是一项巨大的挑战。例如,我们如何为自动驾驶汽车定义一个完美的奖励函数?仅仅奖励“到达目的地”可能导致鲁莽驾驶,而加入“遵守交规”、“保持平稳”等规则又会使函数变得异常复杂,稍有不慎就可能产生意想不到的负面效果。

这时候,模仿学习(Imitation Learning, IL) 提供了一个全新的视角。与其绞尽脑汁地设计奖励函数,我们不如直接向“专家”学习。假设我们有一个专家(可以是一个人类操作员,也可以是一个已经训练好的RL模型),它能够为任务提供一系列高质量的演示(demonstrations)。模仿学习的目标就是让智能体通过学习这些专家的“状态-动作”数据,来复现甚至超越专家的行为策略,而整个过程可以完全不依赖于环境提供的奖励信号。

目前,主流的模仿学习方法大致可以分为三类:

  1. 行为克隆 (Behavior Cloning, BC):最简单直接的方法,将模仿学习问题转化为一个监督学习问题。
  2. 逆强化学习 (Inverse Reinforcement Learning, IRL):尝试从专家的行为中反推出其背后潜在的奖励函数。
  3. 生成对抗模仿学习 (Generative Adversarial Imitation Learning, GAIL):借鉴了生成对抗网络(GAN)的思想,通过一个判别器来间接指导策略的学习。

本篇博客将聚焦于行为克隆(BC)和生成对抗模仿学习(GAIL),通过详细的代码实践,带领大家深入理解这两种主流模仿学习算法的原理与实现。我们将使用 PyTorch 框架,并在经典的 CartPole-v1 环境中完成所有实验。

完整代码:下载链接

在开始模仿之前,我们首先需要一个“专家”。因此,我们的第一步是使用强大的 PPO 算法训练一个表现优异的专家智能体。

专家:PPO 智能体训练

PPO (Proximal Policy Optimization) 是一种非常流行且效果稳健的强化学习算法。我们将首先用它来解决 CartPole-v1 问题,训练出一个能够持续获得高分的专家模型。这个模型将为我们后续的模仿学习算法提供高质量的专家数据。

1. 训练 PPO 的辅助函数

这里我们定义了两个工具函数:compute_advantage 用于计算广义优势估计(GAE),这是 PPO 算法稳定性的关键;moving_average 用于平滑奖励曲线,方便我们观察训练趋势。

# utils
# 根据用拿到的reward和评价网络对下一个状态的估值之和与评价网络对当前状态的估值的TD误差计算优势
# 这个函数的主要目的是使用 GAE 方法计算 Advantage,其中 Advantage 被用于更新策略网络的参数。 
# GAE 通过权衡近期优势和未来优势的重要性,提供了一种更加折中的方法
# utils
# 广义优势估计(Generalized Advantage Estimation,GAE)"""
强化学习工具函数集
包含广义优势估计(GAE)和数据平滑处理功能
"""import torch
import numpy as npdef compute_advantage(gamma, lmbda, td_delta):"""计算广义优势估计(Generalized Advantage Estimation,GAE)GAE是一种在强化学习中用于减少策略梯度方差的技术,通过对时序差分误差进行指数加权平均来估计优势函数,平衡偏差和方差的权衡。参数:gamma (float): 折扣因子,维度: 标量取值范围[0,1],决定未来奖励的重要性lmbda (float): GAE参数,维度: 标量  取值范围[0,1],控制偏差-方差权衡lmbda=0时为TD(0)单步时间差分,lmbda=1时为蒙特卡洛方法用采样到的奖励-状态价值估计td_delta (torch.Tensor): 时序差分误差序列,维度: [时间步数]包含每个时间步的TD误差值返回:torch.Tensor: 广义优势估计值,维度: [时间步数]与输入td_delta维度相同的优势函数估计数学公式:A_t^GAE(γ,λ) = Σ_{l=0}^∞ (γλ)^l * δ_{t+l}其中 δ_t = r_t + γV(s_{t+1}) - V(s_t) 是TD误差"""# 将PyTorch张量转换为NumPy数组进行计算# td_delta维度: [时间步数] -> [时间步数]td_delta = td_delta.detach().numpy() # 因为A用来求g的,需要梯度,防止梯度向下传播# 初始化优势值列表,用于存储每个时间步的优势估计# advantage_list维度: 最终为[时间步数]advantage_list = []# 初始化当前优势值,从序列末尾开始反向计算# advantage维度: 标量advantage = 0.0# 从时间序列末尾开始反向遍历TD误差# 反向计算是因为GAE需要利用未来的信息# delta维度: 标量(td_delta中的单个元素)for delta in td_delta[::-1]:  # [::-1]实现序列反转# GAE递归公式:A_t = δ_t + γλA_{t+1}# gamma * lmbda * advantage: 来自未来时间步的衰减优势值# delta: 当前时间步的TD误差# advantage维度: 标量advantage = gamma * lmbda * advantage + delta# 将计算得到的优势值添加到列表中# advantage_list维度: 逐步增长到[时间步数]advantage_list.append(advantage)# 由于是反向计算,需要将结果列表反转回正确的时间顺序# advantage_list维度: [时间步数](时间顺序已恢复)advantage_list.reverse()# 将NumPy列表转换回PyTorch张量并返回# 返回值维度: [时间步数]return torch.tensor(advantage_list, dtype=torch.float)def moving_average(data, window_size):"""计算移动平均值,用于平滑奖励曲线该函数通过滑动窗口的方式对时间序列数据进行平滑处理,可以有效减少数据中的噪声,使曲线更加平滑美观。常用于强化学习中对训练过程的奖励曲线进行可视化优化。参数:data (list): 原始数据序列,维度: [num_episodes]包含需要平滑处理的数值数据(如每轮训练的奖励值)window_size (int): 移动窗口大小,维度: 标量决定了平滑程度,窗口越大平滑效果越明显但也会导致更多的数据点丢失返回:list: 移动平均后的数据,维度: [len(data) - window_size + 1]返回的数据长度会比原数据少 window_size - 1 个元素这是因为需要足够的数据点来计算第一个移动平均值示例:>>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # 维度: [10]>>> smoothed = moving_average(data, 3)       # window_size = 3>>> print(smoothed)  # 输出: [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]  维度: [8]"""# 边界检查:如果数据长度小于窗口大小,直接返回原数据# 这种情况下无法计算移动平均值# data维度: [num_episodes], window_size维度: 标量if len(data) < window_size:return data# 初始化移动平均值列表# moving_avg维度: 最终为[len(data) - window_size + 1]moving_avg = []# 遍历数据,计算每个窗口的移动平均值# i的取值范围: 0 到 len(data) - window_size# 循环次数: len(data) - window_size + 1# 每次循环处理一个滑动窗口位置for i in range(len(data) - window_size + 1):# 提取当前窗口内的数据切片# window_data维度: [window_size]# 包含从索引i开始的连续window_size个元素# 例如:当i=0, window_size=3时,提取data[0:3]window_data = data[i:i + window_size]# 计算当前窗口内数据的算术平均值# np.mean(window_data)维度: 标量# 将平均值添加到结果列表中moving_avg.append(np.mean(window_data))# 返回移动平均后的数据列表# moving_avg维度: [len(data) - window_size + 1]return moving_avg

2. 环境设置

我们使用 gym 库来创建 CartPole-v1 环境。

#env
"""
强化学习环境初始化模块
用于创建和配置OpenAI Gym环境
"""import gym# 环境配置
# 定义要使用的强化学习环境名称
# CartPole-v1是经典的平衡杆控制问题:
# - 状态空间:4维连续空间(车位置、车速度、杆角度、杆角速度)
# - 动作空间:2维离散空间(向左推车、向右推车)
# - 目标:保持杆子平衡尽可能长的时间
# env_name维度: 标量(字符串)
env_name = 'CartPole-v1'# 创建强化学习环境实例
# gym.make()函数根据环境名称创建对应的环境对象
# 该环境对象包含了状态空间、动作空间、奖励函数等定义
# env维度: gym.Env对象(包含状态空间[4]和动作空间[2]的环境实例)
# env.observation_space.shape: (4,) - 观测状态维度
# env.action_space.n: 2 - 离散动作数量
env = gym.make(env_name)

3. PPO 智能体定义

PPO 采用 Actor-Critic 架构。PolicyNet (Actor) 负责输出动作概率,ValueNet (Critic) 负责评估状态的价值。PPO 类将这两者整合起来,并实现了 PPO 的核心更新逻辑。

# 智能体,主要实现网络定义,网络更新,更新算法,根据状态做出动作"""
PPO(Proximal Policy Optimization)算法实现
包含策略网络、价值网络和PPO智能体的完整定义
"""import torch
import torch.nn.functional as F
import numpy as npclass PolicyNet(torch.nn.Module):"""策略网络(Actor Network)用于输出动作概率分布,指导智能体如何选择动作"""def __init__(self, state_dim, hidden_dim, action_dim):"""初始化策略网络参数:state_dim (int): 状态空间维度,维度: 标量对于CartPole-v1环境,state_dim=4hidden_dim (int): 隐藏层神经元数量,维度: 标量控制网络的表达能力action_dim (int): 动作空间维度,维度: 标量对于CartPole-v1环境,action_dim=2"""super(PolicyNet, self).__init__()# 第一层全连接层:状态输入 -> 隐藏层# 输入维度: [batch_size, state_dim] -> 输出维度: [batch_size, hidden_dim]self.fc1 = torch.nn.Linear(state_dim, hidden_dim)# 第二层全连接层:隐藏层 -> 动作概率# 输入维度: [batch_size, hidden_dim] -> 输出维度: [batch_size, action_dim]self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):"""前向传播过程参数:x (torch.Tensor): 输入状态,维度: [batch_size, state_dim]返回:torch.Tensor: 动作概率分布,维度: [batch_size, action_dim]每行为一个状态对应的动作概率分布,概率和为1"""# 第一层 + ReLU激活函数# x维度: [batch_size, state_dim] -> [batch_size, hidden_dim]x = F.relu(self.fc1(x))# 第二层 + Softmax激活函数,输出概率分布# x维度: [batch_size, hidden_dim] -> [batch_size, action_dim]# dim=1表示在第1维(动作维度)上进行softmax,确保每行概率和为1return F.softmax(self.fc2(x), dim=1)class ValueNet(torch.nn.Module):"""价值网络(Critic Network)用于估计状态价值函数V(s),评估当前状态的好坏"""def __init__(self, state_dim, hidden_dim):"""初始化价值网络参数:state_dim (int): 状态空间维度,维度: 标量对于CartPole-v1环境,state_dim=4hidden_dim (int): 隐藏层神经元数量,维度: 标量控制网络的表达能力"""super(ValueNet, self).__init__()# 第一层全连接层:状态输入 -> 隐藏层# 输入维度: [batch_size, state_dim] -> 输出维度: [batch_size, hidden_dim]self.fc1 = torch.nn.Linear(state_dim, hidden_dim)# 第二层全连接层:隐藏层 -> 状态价值(标量)# 输入维度: [batch_size, hidden_dim] -> 输出维度: [batch_size, 1]self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):"""前向传播过程参数:x (torch.Tensor): 输入状态,维度: [batch_size, state_dim]返回:torch.Tensor: 状态价值估计,维度: [batch_size, 1]每行为一个状态对应的价值估计"""# 第一层 + ReLU激活函数# x维度: [batch_size, state_dim] -> [batch_size, hidden_dim]x = F.relu(self.fc1(x))# 第二层,输出状态价值(无激活函数,可以输出负值)# x维度: [batch_size, hidden_dim] -> [batch_size, 1]return self.fc2(x)class PPO:"""PPO(Proximal Policy Optimization)算法实现采用截断方式防止策略更新过大,确保训练稳定性"""def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,lmbda, epochs, eps, gamma, device):"""初始化PPO智能体参数:state_dim (int): 状态空间维度,维度: 标量hidden_dim (int): 隐藏层神经元数量,维度: 标量action_dim (int): 动作空间维度,维度: 标量actor_lr (float): Actor网络学习率,维度: 标量critic_lr (float): Critic网络学习率,维度: 标量lmbda (float): GAE参数λ,维度: 标量,取值范围[0,1]epochs (int): 每次更新的训练轮数,维度: 标量eps (float): PPO截断参数ε,维度: 标量,通常取0.1-0.3gamma (float): 折扣因子γ,维度: 标量,取值范围[0,1]device (torch.device): 计算设备(CPU或GPU),维度: 标量"""# 初始化Actor网络(策略网络)# 网络参数维度:fc1权重[state_dim, hidden_dim], fc2权重[hidden_dim, action_dim]self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)# 初始化Critic网络(价值网络)# 网络参数维度:fc1权重[state_dim, hidden_dim], fc2权重[hidden_dim, 1]self.critic = ValueNet(state_dim, hidden_dim).to(device)# 初始化Actor网络优化器# 优化器管理Actor网络的所有参数self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)# 初始化Critic网络优化器# 优化器管理Critic网络的所有参数self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)# 存储算法超参数self.gamma = gamma      # 折扣因子,维度: 标量self.lmbda = lmbda      # GAE参数,维度: 标量self.epochs = epochs    # 一条序列的数据用来训练的轮数,维度: 标量self.eps = eps          # PPO中截断范围的参数,维度: 标量self.device = device    # 计算设备,维度: 标量def take_action(self, state):"""根据当前状态选择动作参数:state (list/np.array): 当前状态,维度: [state_dim]例如CartPole-v1中为[位置, 速度, 角度, 角速度]返回:int: 选择的动作索引,维度: 标量对于CartPole-v1,返回0(向左)或1(向右)"""# 将状态转换为批次格式,添加batch维度# state维度: [state_dim] -> [1, state_dim]state = np.array([state])# 转换为PyTorch张量并移动到指定设备# state维度: [1, state_dim]state = torch.tensor(state, dtype=torch.float).to(self.device)# 通过Actor网络获取动作概率分布# probs维度: [1, action_dim],每个元素表示对应动作的概率probs = self.actor(state)# 创建分类分布对象,用于从概率分布中采样# action_dist: 分类分布对象,基于probs概率分布action_dist = torch.distributions.Categorical(probs)# 从概率分布中采样一个动作# action维度: [1],包含采样得到的动作索引action = action_dist.sample()# 返回动作的具体数值(去除张量包装)# 返回值维度: 标量(整数)return action.item()def update(self, transition_dict):"""使用收集的经验数据更新Actor和Critic网络参数:transition_dict (dict): 包含经验数据的字典,包含以下键值对:'states': 状态序列,维度: [序列长度, state_dim]'actions': 动作序列,维度: [序列长度]'rewards': 奖励序列,维度: [序列长度]'next_states': 下一状态序列,维度: [序列长度, state_dim]'dones': 终止标志序列,维度: [序列长度]"""# 提取并转换状态数据# states维度: [序列长度, state_dim]states = np.array(transition_dict['states'])states = torch.tensor(states, dtype=torch.float).to(self.device)# 提取并转换动作数据,调整为列向量格式# actions维度: [序列长度] -> [序列长度, 1]actions 
http://www.dtcms.com/a/317672.html

相关文章:

  • 【数据库】MySQL详解:关系型数据库的王者
  • MySQL和Navicat Premium的安装
  • stm32项目(22)——基于stm32的智能病房监护系统
  • Python面试题及详细答案150道(01-15) -- 基础语法篇
  • 代数——第6章——对称性(Michael Artin)
  • vue3 find 数组查找方法
  • CPP网络编程-异步sever
  • FPGA学习笔记——VGA彩条显示
  • python:非常流行和重要的Python机器学习库scikit-learn 介绍
  • STM32学习笔记3-GPIO输入部分
  • WMS及UI渲染底层原理学习
  • 【STM32 LWIP配置】STM32H723ZG + Ethernet +LWIP 配置 cubemx
  • 无人机图传的得力助手:5G 便携式多卡高清视频融合终端的协同应用
  • 中宇联5G云宽带+4G路由器:解锁企业办公高效协同与门店体验升级
  • 图解 Claude Code 子智能体 Sub-agent
  • [ java GUI ] 图形用户界面
  • 【软考系统架构设计师备考笔记4】 - 英语语法一篇通
  • ctfshow_vip题目限免-----SVN漏洞,git泄露
  • Git Cherry-Pick 指南
  • Leetcode——菜鸟笔记1
  • Git 分支管理:从新开发分支迁移为主分支的完整指南
  • 鸿蒙app 开发中 全局弹窗类的封装 基于PromptAction
  • C#之基础语法
  • 机器学习之朴素贝叶斯
  • Suno API V5模型 php源码 —— 使用灵感模式进行出创作
  • 基于PHP的论坛社交网站系统开发与设计
  • 排序算法详解
  • 媒体资产管理系统和OCR文字识别的结合
  • Ethereum: L1 与 L2 的安全纽带, Rollups 技术下的协作与区别全解析
  • 解决启动docker报错Cannot connect to the Docker daemon问题