当前位置: 首页 > news >正文

【零基础学AI】第35讲:策略梯度方法 - 连续控制任务实战

在这里插入图片描述

本节课你将学到

  • 理解策略梯度与值函数的本质区别
  • 掌握REINFORCE和Actor-Critic算法
  • 使用PyTorch实现连续动作空间控制
  • 训练机械臂控制AI智能体

开始之前

环境要求

  • Python 3.8+
  • PyTorch 2.0+
  • MuJoCo物理引擎 (需要许可证)
  • Gymnasium[mujoco]
  • 必须使用GPU加速训练

前置知识

  • 深度Q网络(第34讲)
  • 概率与统计基础
  • 梯度下降原理(第23讲)

核心概念

策略梯度 vs 值函数方法

特性值函数方法 (如DQN)策略梯度方法
输出动作价值Q(s,a)动作概率分布π(a
动作空间离散有限连续/离散皆可
探索方式ε-greedy通过策略本身随机性
策略更新间接通过值函数直接优化策略参数

策略梯度定理

∇J(θ) = 𝔼[∇logπ(a|s) * Q(s,a)]

  • J(θ):策略性能指标
  • π(a|s):策略选择动作的概率
  • Q(s,a):动作价值函数
状态s
策略π
动作a
环境
奖励r+新状态s'
计算梯度
更新策略

连续动作空间处理

使用高斯分布参数化策略:

  • 均值μ:由神经网络输出
  • 标准差σ:可学习参数或固定值
  • 动作采样:a ~ N(μ, σ²)

代码实战

1. 环境配置(MuJoCo)

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal# 创建机械臂控制环境
env = gym.make('Reacher-v4', render_mode='rgb_array')
print("观测空间:", env.observation_space)
print("动作空间:", env.action_space)# 环境测试
state, _ = env.reset()
for _ in range(50):action = env.action_space.sample()  # 随机动作state, reward, done, _, _ = env.step(action)if done:break
env.close()# ⚠️ 常见错误1:MuJoCo许可证问题
# 解决方案:
# 1. 获取教育许可证(mujoco.org)
# 2. 使用替代环境如PyBullet

2. 策略网络实现

class PolicyNetwork(nn.Module):def __init__(self, state_dim, action_dim, hidden_size=256):super().__init__()self.fc1 = nn.Linear(state_dim, hidden_size)self.fc2 = nn.Linear(hidden_size, hidden_size)# 输出均值和标准差self.mean_head = nn.Linear(hidden_size, action_dim)self.log_std_head = nn.Linear(hidden_size, action_dim)# 初始化参数self.log_std_min = -20self.log_std_max = 2def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))mean = self.mean_head(x)log_std = self.log_std_head(x)log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)return mean, log_stddef act(self, state, deterministic=False):state = torch.FloatTensor(state).unsqueeze(0)mean, log_std = self.forward(state)std = log_std.exp()dist = Normal(mean, std)if deterministic:action = meanelse:action = dist.sample()# 限制动作范围action = torch.tanh(action)  # [-1, 1]return action.detach().numpy()[0]# 测试网络
policy = PolicyNetwork(env.observation_space.shape[0], env.action_space.shape[0])
test_action = policy.act(state)
print("生成的动作:", test_action)# ⚠️ 常见错误2:动作范围不匹配
# 确保输出动作在环境要求的范围内(如tanh激活)

3. REINFORCE算法实现

class REINFORCE:def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99):self.policy = PolicyNetwork(state_dim, action_dim)self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)self.gamma = gammaself.saved_log_probs = []self.rewards = []def select_action(self, state):state = torch.FloatTensor(state)mean, log_std = self.policy(state)std = log_std.exp()dist = Normal(mean, std)action = dist.sample()log_prob = dist.log_prob(action).sum()self.saved_log_probs.append(log_prob)return torch.tanh(action).detach().numpy()def update(self):R = 0policy_loss = []returns = []# 计算每个时间步的回报for r in reversed(self.rewards):R = r + self.gamma * Rreturns.insert(0, R)returns = torch.tensor(returns)returns = (returns - returns.mean()) / (returns.std() + 1e-9)for log_prob, R in zip(self.saved_log_probs, returns):policy_loss.append(-log_prob * R)self.optimizer.zero_grad()policy_loss = torch.stack(policy_loss).sum()policy_loss.backward()self.optimizer.step()# 清空缓存del self.rewards[:]del self.saved_log_probs[:]# 初始化
agent = REINFORCE(env.observation_space.shape[0],env.action_space.shape[0])

4. 训练循环

def train(num_episodes=1000):scores = []for i_episode in range(1, num_episodes+1):state, _ = env.reset()episode_reward = 0while True:action = agent.select_action(state)state, reward, done, _, _ = env.step(action)agent.rewards.append(reward)episode_reward += rewardif done:breakagent.update()scores.append(episode_reward)# 打印训练进度if i_episode % 10 == 0:avg_score = np.mean(scores[-10:])print(f"Episode {i_episode}, Avg Reward: {avg_score:.1f}")# 可视化演示if avg_score > -20:  # 当策略较好时展示demo_episode(agent.policy)return scoresdef demo_episode(policy, max_steps=200):state, _ = env.reset()frames = []for _ in range(max_steps):frames.append(env.render())action = policy.act(state, deterministic=True)state, _, done, _, _ = env.step(action)if done:breakenv.close()# 保存为GIFsave_frames_as_gif(frames, path='./output/reacher_demo.gif')# 开始训练
scores = train(num_episodes=500)

完整项目

项目结构:

lesson_35_policy_gradient/
├── envs/
│   ├── mujoco_wrapper.py  # 环境预处理
│   └── utils.py           # 辅助函数
├── networks/
│   ├── policy.py          # 策略网络
│   └── value.py           # 价值网络(AC用)
├── agents/
│   ├── reinforce.py       # REINFORCE算法
│   └── actor_critic.py    # Actor-Critic算法
├── configs/
│   └── hyperparams.yaml   # 超参数配置
├── train.py               # 训练脚本
├── eval.py                # 评估脚本
├── requirements.txt       # 依赖列表
└── README.md              # 项目说明

requirements.txt

gymnasium[mujoco]==0.28.1
torch==2.0.1
numpy==1.24.3
matplotlib==3.7.1
imageio==2.31.1
mujoco==2.3.3

configs/hyperparams.yaml

# REINFORCE参数
policy_lr: 0.0003
gamma: 0.99
hidden_size: 256# 环境参数
env_name: "Reacher-v4"
max_episode_steps: 200# 训练参数
num_episodes: 1000
log_interval: 10

运行效果

训练过程输出

Episode 10, Avg Reward: -45.2
Episode 20, Avg Reward: -32.5
...
Episode 500, Avg Reward: -12.3

常见问题

Q1: 训练初期奖励极低且不提升

解决方案:

  1. 增大初始探索(增大策略输出的标准差)
  2. 使用课程学习(从简单任务开始)
  3. 检查奖励函数设计是否合理

Q2: 策略收敛到局部最优

可能原因:

  1. 学习率过高导致震荡(尝试减小到0.0001)
  2. 探索不足(保持适度的策略随机性)
  3. 网络容量不足(增加隐藏层大小)

Q3: 如何应用到真实机器人?

调整建议:

  1. 使用仿真到真实迁移技术(Sim2Real)
  2. 添加域随机化(随机化物理参数)
  3. 考虑安全约束(限制动作幅度)

课后练习

  1. Actor-Critic改进
    实现带基线函数的Actor-Critic算法,比较与REINFORCE的性能差异

  2. PPO算法实现
    实现近端策略优化(PPO)的clip目标函数,提高训练稳定性

  3. 多任务学习
    修改网络结构使同一策略能完成多个MuJoCo任务

  4. 分层策略
    实现高层策略和底层控制器的分层强化学习架构


扩展阅读

  1. Policy Gradient经典论文
  2. OpenAI Spinning Up教程
  3. DeepMind控制论文集

下节预告:第36讲将深入探讨多智能体强化学习,实现博弈对抗AI!

http://www.dtcms.com/a/272659.html

相关文章:

  • Swift 图论实战:DFS 算法解锁 LeetCode 323 连通分量个数
  • 快速搭建服务器,fetch请求从服务器获取数据
  • ReentrantLock 与 Synchronized 的区别
  • 给MySQL做定时备份,一天3次
  • method_name字段是什么
  • 单片机基础(STM32-DAY2(GPIO))
  • Linux驱动06 --- UDP
  • 飞书AI技术体系
  • web 系统对接飞书三方登录完整步骤实战使用示例
  • 低温冷启动 高温热启动
  • OpenCV 图像进阶处理:特征提取与车牌识别深度解析
  • 醋酸镨:闪亮的稀土宝藏,掀开科技应用新篇章
  • Spring IoC 如何注入一些简单的值(比如配置文件里的字符串、数字)?
  • 【文献阅读】Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data
  • MyBatis 使用教程及插件开发
  • 自动驾驶环境感知:天气数据采集与融合技术实战
  • AI-Sphere-Butler项目语音切换数字人管家形象功能老是开发不成功。
  • Oracle 数据库管理与维护实战指南(用户权限、备份恢复、性能调优)
  • 深度学习与图像处理案例 │ 基于深度学习的自动驾驶小车
  • GitHub上优秀的开源播放器项目介绍及优劣对比
  • 申请注册苹果iOS企业级开发者证书需要公司拥有什么规模条件
  • Nacos的基本功能以及使用Feign进行微服务间的通信
  • 【网络编程】 TCP 协议栈的知识汇总
  • ZW3D 二次开发-创建圆柱体
  • Qt cannot find C:\WINDOWS\TEMP\cctVBBgu: Invalid argument
  • QT5使用cmakelists引入Qt5Xlsx库并使用
  • 达梦数据库不兼容 SQL_NO_CACHE 报错解决方案
  • C++交叉编译工具链制作以及QT交叉编译环境配置
  • 生产环境CI/CD流水线构建与优化实践指南
  • 医院多部门协同构建知识库-指南库-预测模型三维网络路径研究