当前位置: 首页 > news >正文

强化学习PPO/DDPG算法学习记录

强化学习PPO算法详解

核心思想

  • 直接学习一个策略函数pi(a|s), 在状态s下要输出的动作a的概率分布(离散情况下是每个action的概率,连续情况下是mean和std指定的高斯分布)
  • 策略梯度算法如果更新的步长太大,一次更新就能毁掉整个策略,所以通过一个裁剪函数,防止新策略和旧策略差距太大,保证“小幅且安全”。
  • 多步更新:在收集一批数据后,用小批量数据对策略进行多次epochs更新,提高样本效率。

参考代码:

这个PPO实现代码有什么问题?里面好像没有用到critic网络的结果:import torch
import torch.nn as nn
import torch.optim as optimclass PPO(nn.Module):def __init__(self, state_dim, action_dim):super(PPO, self).__init__()self.actor = nn.Sequential(nn.Linear(state_dim, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, action_dim),nn.Softmax(dim=-1))self.critic = nn.Sequential(nn.Linear(state_dim, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, 1))def forward(self, state):return self.actor(state), self.critic(state)class PPOAgent:def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, epsilon=0.2):self.ppo = PPO(state_dim, action_dim)self.optimizer = optim.Adam(self.ppo.parameters(), lr=lr)self.gamma = gammaself.epsilon = epsilondef update(self, states, actions, rewards, next_states, dones):states = torch.FloatTensor(states)actions = torch.LongTensor(actions)rewards = torch.FloatTensor(rewards)next_states = torch.FloatTensor(next_states)dones = torch.FloatTensor(dones)# 1. 首先,用当前的策略网络(不计算梯度)计算旧概率 old_probswith torch.no_grad():old_probs, old_state_values = self.ppo(states)old_probs = old_probs.gather(1, actions.unsqueeze(1)).squeeze(1)# 计算下一个状态的价值_, next_state_values = self.ppo(next_states)# 计算价值目标:如果done了,next_state_value就是0value_targets = rewards + self.gamma * next_state_values.squeeze(1) * (1 - dones)# 计算优势函数 A(s,a) = value_target - old_state_valueadvantages = value_targets - old_state_values.squeeze(1)# 通常会对advantages进行标准化,以减少方差advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)for _ in range(10):  # 多次更新# 2. 用当前策略网络计算新概率和状态价值new_probs, state_values = self.ppo(states)new_probs = new_probs.gather(1, actions.unsqueeze(1)).squeeze(1)# 3. 计算重要性采样比率ratio = new_probs / old_probs# 4. 计算Clipped Surrogate Losssurr1 = ratio * advantagessurr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantagesactor_loss = -torch.min(surr1, surr2).mean()# 5. 计算Critic Loss (MSE between value_targets and current value estimates)critic_loss = nn.MSELoss()(state_values.squeeze(1), value_targets)# 6. 总损失loss = actor_loss + 0.5 * critic_lossself.optimizer.zero_grad()loss.backward()self.optimizer.step()def get_action(self, state):state = torch.FloatTensor(state)probs, _ = self.ppo(state)return torch.multinomial(probs, 1).item()

PPO是在线策略,输出的是概率,更新稳健,是策略网络的集大成者。

DDPG算法详解

核心改进

相比于DQN,DDPG的核心改进在于:

  • DQN的动作空间是离散的,例如上下左右开火等,而DDPG的动作空间是连续的
  • DQN输出的是Q值,然后选择最大的对应的动作,DDPG直接输出动作
  • DQN是value-based,而DDPG是Policy- based
  • DQN通常只有一个Q网络,DDPG要有Actor和critic两个网络

一句话总结:DQN是为了解决离散控制问题,DDPG主要是针对连续控制领域的,是DQN和策略网络的结合。
参考代码(原文章:Deep Reinforcement Learning (DRL) 算法在 PyTorch 中的实现与应用):

import torch
import torch.nn as nn
import torch.optim as optimclass Actor(nn.Module):def __init__(self, state_dim, action_dim, max_action):super(Actor, self).__init__()self.fc1 = nn.Linear(state_dim, 400)self.fc2 = nn.Linear(400, 300)self.fc3 = nn.Linear(300, action_dim)self.max_action = max_actiondef forward(self, state):a = torch.relu(self.fc1(state))a = torch.relu(self.fc2(a))return self.max_action * torch.tanh(self.fc3(a))class Critic(nn.Module):def __init__(self, state_dim, action_dim):super(Critic, self).__init__()self.fc1 = nn.Linear(state_dim + action_dim, 400)self.fc2 = nn.Linear(400, 300)self.fc3 = nn.Linear(300, 1)def forward(self, state, action):q = torch.cat([state, action], 1)q = torch.relu(self.fc1(q))q = torch.relu(self.fc2(q))return self.fc3(q)class DDPGAgent:def __init__(self, state_dim, action_dim, max_action, lr=1e-4, gamma=0.99, tau=0.001):self.actor = Actor(state_dim, action_dim, max_action)self.actor_target = Actor(state_dim, action_dim, max_action)self.actor_target.load_state_dict(self.actor.state_dict())self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)self.critic = Critic(state_dim, action_dim)self.critic_target = Critic(state_dim, action_dim)self.critic_target.load_state_dict(self.critic.state_dict())self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)self.gamma = gammaself.tau = taudef select_action(self, state):state = torch.FloatTensor(state.reshape(1, -1))return self.actor(state).cpu().data.numpy().flatten()def update(self, replay_buffer, batch_size=100):# 从经验回放中采样state, action, next_state, reward, done = replay_buffer.sample(batch_size)# 计算目标Q值target_Q = self.critic_target(next_state, self.actor_target(next_state))target_Q = reward + (1 - done) * self.gamma * target_Q.detach()# 更新Criticcurrent_Q = self.critic(state, action)critic_loss = nn.MSELoss()(current_Q, target_Q)self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()# 更新Actoractor_loss = -self.critic(state, self.actor(state)).mean()self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# 软更新目标网络for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
http://www.dtcms.com/a/363550.html

相关文章:

  • 图像编码之摄像机的H264 分块编码的含义是什么,以分块编码(tile)192X192为例子说明,好处与缺点分别是什么
  • Day19(前端:JavaScript基础阶段)
  • Linux笔记14——shell编程基础-8
  • 解决戴尔笔记本电脑键盘按键部分失灵
  • 未来工厂雏形:基于Three.js的自主演进式数字孪生系统设计
  • Qwen3-Reranker-0.6B 模型结构
  • Coze平台指南(2):开发环境的搭建与配置
  • Cisco FMC利用sftp Server拷贝文件方法
  • Ubuntu中配置JMmeter工具
  • 从零开始:用代码解析区块链的核心工作原理
  • Ubuntu 24.04 服务器配置MySQL 8.0.42 三节点集群(一主两从架构)安装部署配置教程
  • 软件设计师——软件工程学习笔记
  • 矩阵scaling预处理介绍
  • AI代码生成神器终极对决:CodeLlama vs StarCoder vs Codex,谁才是开发者的「最佳拍档」?
  • STM32CUBEMX配置LAN8720a实现UDP通信
  • 【C++游记】红黑树
  • 嵌入式C语言之链表冒泡排序
  • Java基础第9天总结(可变参数、Collections、斗地主)
  • 深入浅出数据库事务:从原理到实践,解决 Spring 事务与外部进程冲突问题
  • github下载的文件内容类似文件哈希和存储路径原因
  • Kafka 分层存储(Tiered Storage)从 0 到 1 的配置、调优与避坑
  • Vue3 实现自定义指令点击空白区域关闭下拉框
  • 【51单片机】【protues仿真】 基于51单片机智能电子秤系统
  • 工业界实战之数据存储格式与精度
  • 嵌入式解谜日志-网络编程
  • 浏览器面试题及详细答案 88道(56-66)
  • MySQL查询limit 0,100和limit 10000000,100有什么区别?
  • 敏捷规模化管理工具实战指南:如何实现跨团队依赖可视化?
  • 数据库驱动改造加密姓名手机号证件号邮箱敏感信息
  • web自动化测试(selenium)