当前位置: 首页 > news >正文

【强化学习】DQN 算法

目录

一、引言

二、CartPole 环境

三、DQN

(一)经验回放

(二)目标网络

四、DQN 代码实践

五、总结


一、引言

在 Q-learning 算法中,我们以矩阵的方式建立了一张存储每个状态下所有动作Q值的表格。表格中的每一个动作价值Q(s,a)表示在状态s下选择动作a然后继续遵循某一策略预期能够得到的期望回报。然而,这种用表格存储动作价值的做法只在环境的状态和动作都是离散的,并且空间都比较小的情况下适用,我们之前进行代码实战的几个环境都是如此(如悬崖漫步)。当状态或者动作数量非常大的时候,这种做法就不适用了。例如,当状态是一张 RGB 图像时,假设图像大小是210\times 160\times 3,此时一共有256^{(210\times 160\times 3)}种状态,在计算机中存储这个数量级的Q值表格是不现实的。更甚者,当状态或者动作连续的时候,就有无限个状态动作对,我们更加无法使用这种表格形式来记录各个状态动作对的Q值。

对于这种情况,我们需要用函数拟合的方法来估计Q值,即将这个复杂的Q值表格视作数据,使用一个参数化的函数Q_\theta来拟合这些数据。很显然,这种函数拟合的方法存在一定的精度损失,因此被称为近似方法。我们今天要介绍的 DQN 算法便可以用来解决连续状态下离散动作的问题。

二、CartPole 环境

以图中所示的所示的车杆(CartPole)环境为例,它的状态值就是连续的,动作值是离散的。

在车杆环境中,有一辆小车,智能体的任务是通过左右移动保持车上的杆竖直,若杆的倾斜度数过大,或者车子离初始位置左右的偏离程度过大,或者坚持时间到达 200 帧,则游戏结束。智能体的状态是一个维数为 4 的向量,每一维都是连续的,其动作是离散的,动作空间大小为 2,详情参见表 1 和表 2。在游戏中每坚持一帧,智能体能获得分数为 1 的奖励,坚持时间越长,则最后的分数越高,坚持 200 帧即可获得最高的分数。

表1 CartPole环境的状态空间

维度意义最小值最大值
0车的位置-2.42.4
1车的速度-InfInf
2杆的角度~ -41.8°~ 41.8°
3杆尖端的速度-InfInf

表2 CartPole环境的动作空间

标号动作
0向左移动小车
1向右移动小车

三、DQN

现在我们想在类似车杆的环境中得到动作价值函数Q(s,a),由于状态每一维度的值都是连续的,无法使用表格记录,因此一个常见的解决方法便是使用函数拟合(function approximation)的思想。由于神经网络具有强大的表达能力,因此我们可以用一个神经网络来表示函数Q。若动作是连续(无限)的,神经网络的输入是状态s和动作a,然后输出一个标量,表示在状态s下采取动作a能获得的价值。若动作是离散(有限)的,除了可以采取动作连续情况下的做法,我们还可以只将状态s输入到神经网络中,使其同时输出每一个动作的Q值。通常 DQN(以及 Q-learning)只能处理动作离散的情况,因为在函数Q的更新过程中有max_a这一操作。假设神经网络用来拟合函数Q的参数是\omega,即每一个状态s下所有可能动作aQ值我们都能表示为Q_\omega (s,a)。我们将用于拟合函数Q函数的神经网络称为Q 网络,如图所示。

那么 Q 网络的损失函数是什么呢?我们先来回顾一下 Q-learning 的更新规则:

上述公式用时序差分(temporal difference,TD)学习目标来增量式更新Q(s,a),也就是说要使Q(s,a)和 TD 目标靠近。于是,对于一组数据,我们可以很自然地将 Q 网络的损失函数构造为均方误差的形式:

至此,我们就可以将 Q-learning 扩展到神经网络形式——深度 Q 网络(deep Q network,DQN)算法。由于 DQN 是离线策略算法,因此我们在收集数据的时候可以使用一个\epsilon-贪婪策略来平衡探索与利用,将收集到的数据存储起来,在后续的训练中使用。DQN 中还有两个非常重要的模块——经验回放目标网络,它们能够帮助 DQN 取得稳定、出色的性能。

(一)经验回放

在一般的有监督学习中,假设训练数据是独立同分布的,我们每次训练神经网络的时候从训练数据中随机采样一个或若干个数据来进行梯度下降,随着学习的不断进行,每一个训练数据会被使用多次。在原来的 Q-learning 算法中,每一个数据只会用来更新一次Q值。为了更好地将 Q-learning 和深度神经网络结合,DQN 算法采用了经验回放(experience replay)方法,具体做法为维护一个回放缓冲区,将每次从环境中采样得到的四元组数据(状态、动作、奖励、下一状态)存储到回放缓冲区中,训练 Q 网络的时候再从回放缓冲区中随机采样若干数据来进行训练。这么做可以起到以下两个作用。

(1)使样本满足独立假设。在 MDP 中交互采样得到的数据本身不满足独立假设,因为这一时刻的状态和上一时刻的状态有关。非独立同分布的数据对训练神经网络有很大的影响,会使神经网络拟合到最近训练的数据上。采用经验回放可以打破样本之间的相关性,让其满足独立假设。

(2)提高样本效率。每一个样本可以被使用多次,十分适合深度神经网络的梯度学习。

(二)目标网络

DQN 算法最终更新的目标是让Q_\omega (s,a)逼近,由于 TD 误差目标本身就包含神经网络的输出,因此在更新网络参数的同时目标也在不断地改变,这非常容易造成神经网络训练的不稳定性。为了解决这一问题,DQN 便使用了目标网络(target network)的思想:既然训练过程中 Q 网络的不断更新会导致目标不断发生改变,不如暂时先将 TD 目标中的 Q 网络固定住。为了实现这一思想,我们需要利用两套 Q 网络。

(1)原来的训练网络Q_\omega (s,a),用于计算原来的损失函数中的Q_\omega (s,a)Q_{\omega ^-}(s,a)项,并且使用正常梯度下降方法来进行更新。

(2) 目标网络,用于计算原先损失函数中的项,其中\omega ^-表示目标网络中的参数。如果两套网络的参数随时保持一致,则仍为原先不够稳定的算法。为了让更新目标更稳定,目标网络并不会每一步都更新。具体而言,目标网络使用训练网络的一套较旧的参数,训练网络Q_\omega (s,a)在训练中的每一步都会更新,而目标网络的参数每隔C步才会与训练网络同步一次,即\omega ^-\leftarrow \omega。这样做使得目标网络相对于训练网络更加稳定。

综上所述,DQN 算法的具体流程如下:

四、DQN 代码实践

接下来,我们就正式进入 DQN 算法的代码实践环节。我们采用的测试环境是 CartPole-v0,其状态空间相对简单,只有 4 个变量,因此网络结构的设计也相对简单:采用一层 128 个神经元的全连接并以 ReLU 作为激活函数。当遇到更复杂的诸如以图像作为输入的环境时,我们可以考虑采用深度卷积神经网络。

从 DQN 算法开始,我们先实现rl_utils库,它包含一些函数,如绘制移动平均曲线、计算优势函数等,不同的算法可以一起使用这些函数。

rl_utils.py中的Python代码如下:

from tqdm import tqdm
import numpy as np
import torch
import collections
import randomclass ReplayBuffer:def __init__(self, capacity):self.buffer = collections.deque(maxlen=capacity) def add(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): transitions = random.sample(self.buffer, batch_size)state, action, reward, next_state, done = zip(*transitions)return np.array(state), action, reward, np.array(next_state), done def size(self): return len(self.buffer)def moving_average(a, window_size):cumulative_sum = np.cumsum(np.insert(a, 0, 0)) middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_sizer = np.arange(1, window_size-1, 2)begin = np.cumsum(a[:window_size-1])[::2] / rend = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]return np.concatenate((begin, middle, end))def train_on_policy_agent(env, agent, num_episodes):return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state = env.reset()done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done, _ = env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)return return_listdef train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size):return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0state = env.reset()done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done, _ = env.step(action)replay_buffer.add(state, action, reward, next_state, done)state = next_stateepisode_return += rewardif replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)return return_listdef compute_advantage(gamma, lmbda, td_delta):td_delta = td_delta.detach().numpy()advantage_list = []advantage = 0.0for delta in td_delta[::-1]:advantage = gamma * lmbda * advantage + deltaadvantage_list.append(advantage)advantage_list.reverse()return torch.tensor(advantage_list, dtype=torch.float)

DQN 的Python代码如下:

import random
import gymnasium as gym
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils# 经验回放池
class ReplayBuffer:''' 经验回放池 '''def __init__(self, capacity):self.buffer = collections.deque(maxlen=capacity)  # 队列,先进先出def add(self, state, action, reward, next_state, done):  # 将数据加入bufferself.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size):  # 从buffer中采样数据transitions = random.sample(self.buffer, batch_size)state, action, reward, next_state, done = zip(*transitions)return np.array(state), action, reward, np.array(next_state), donedef size(self):  # 目前buffer中数据的数量return len(self.buffer)# Q网络
class Qnet(torch.nn.Module):''' 只有一层隐藏层的Q网络 '''def __init__(self, state_dim, hidden_dim, action_dim):super(Qnet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))  # 隐藏层使用ReLU激活函数return self.fc2(x)# DQN算法
class DQN:''' DQN算法 '''def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma,epsilon, target_update, device):self.action_dim = action_dimself.q_net = Qnet(state_dim, hidden_dim,self.action_dim).to(device)  # Q网络# 目标网络self.target_q_net = Qnet(state_dim, hidden_dim,self.action_dim).to(device)# 使用Adam优化器self.optimizer = torch.optim.Adam(self.q_net.parameters(),lr=learning_rate)self.gamma = gamma  # 折扣因子self.epsilon = epsilon  # epsilon-贪婪策略self.target_update = target_update  # 目标网络更新频率self.count = 0  # 计数器,记录更新次数self.device = devicedef take_action(self, state):  # epsilon-贪婪策略采取动作if np.random.random() < self.epsilon:action = np.random.randint(self.action_dim)else:state = torch.tensor([state], dtype=torch.float).to(self.device)action = self.q_net(state).argmax().item()return actiondef update(self, transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)q_values = self.q_net(states).gather(1, actions)  # Q值# 下个状态的最大Q值max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)  # TD误差目标dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))  # 均方误差损失函数self.optimizer.zero_grad()  # 梯度清零dqn_loss.backward()  # 反向传播self.optimizer.step()if self.count % self.target_update == 0:self.target_q_net.load_state_dict(self.q_net.state_dict())  # 更新目标网络self.count += 1# 训练参数
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")# 环境设置
env_name = 'CartPole-v1'
env = gym.make(env_name)
# 设置随机种子(Gymnasium的种子设置方式)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
env.reset(seed=0)  # 在reset中设置环境种子,替代原env.seed()replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,target_update, device)return_list = []
for i in range(10):with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes / 10)):episode_return = 0state, _ = env.reset(seed=0)  # Gymnasium的reset返回(observation, info)terminated, truncated = False, False  # 区分终止和截断while not (terminated or truncated):  # 终止或截断都视为 episode 结束action = agent.take_action(state)# Gymnasium的step返回(observation, reward, terminated, truncated, info)next_state, reward, terminated, truncated, _ = env.step(action)replay_buffer.add(state, action, reward, next_state, terminated or truncated)  # 合并终止状态state = next_stateepisode_return += reward# 当buffer数据足够时进行训练if replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s,'actions': b_a,'next_states': b_ns,'rewards': b_r,'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)# 绘制结果
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DQN on {}'.format(env_name))
plt.savefig('DQN on {}_image1.png'.format(env_name))
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DQN on {}'.format(env_name))
plt.savefig('DQN on {}_image2.png'.format(env_name))
plt.show()

程序运行结果如下:

Iteration 0: 100%|██████████| 50/50 [00:00<00:00, 445.45it/s, episode=50, return=11.300]
Iteration 1: 100%|██████████| 50/50 [00:01<00:00, 49.07it/s, episode=100, return=45.200]
Iteration 2: 100%|██████████| 50/50 [00:13<00:00,  3.68it/s, episode=150, return=198.900]
Iteration 3: 100%|██████████| 50/50 [00:28<00:00,  1.76it/s, episode=200, return=445.800]
Iteration 4: 100%|██████████| 50/50 [00:33<00:00,  1.51it/s, episode=250, return=471.800]
Iteration 5: 100%|██████████| 50/50 [00:26<00:00,  1.87it/s, episode=300, return=435.800]
Iteration 6: 100%|██████████| 50/50 [00:24<00:00,  2.03it/s, episode=350, return=500.000]
Iteration 7: 100%|██████████| 50/50 [00:26<00:00,  1.86it/s, episode=400, return=486.300]
Iteration 8: 100%|██████████| 50/50 [00:20<00:00,  2.41it/s, episode=450, return=360.600]
Iteration 9: 100%|██████████| 50/50 [00:19<00:00,  2.52it/s, episode=500, return=245.600]


五、总结

本文介绍了深度Q网络(DQN)算法在连续状态空间下的应用。针对传统Q-learning无法处理连续状态的问题,DQN采用神经网络拟合Q值函数,并引入经验回放和目标网络两大关键技术来提升稳定性。通过CartPole平衡杆环境的实验验证,DQN能够有效学习连续状态下的最优策略。实验结果显示,经过500回合训练后,智能体能够在环境中获得接近满分的表现,证明了DQN在处理连续状态空间问题上的有效性。文章详细阐述了算法原理、网络架构和具体实现代码,为理解深度强化学习提供了一个典型范例。

http://www.dtcms.com/a/593003.html

相关文章:

  • 大模型-详解 Vision Transformer (ViT) (2
  • 学习react第一天
  • 2025年电子会计档案管理软件深度介绍及厂商推荐
  • io_uring 避坑指南
  • (附源码)基于Spring boot的校园志愿服务管理系统的设计与实现
  • deepseek回答 如何用deepseek训练出一个我的思路
  • 3ds Max材质高清参数设置:10分钟提升渲染真实感
  • MyBatis 插件
  • 甘肃省城乡住房建设厅网站首页微商软件自助商城
  • 一文掌握,kanass安装与配置
  • C# ASP.NET MVC 数据验证实战:View 层双保险(Html.ValidationMessageFor + jQuery Validate)
  • 工信部 网站 邮箱内容管理系统做网站
  • arcgis用累计值进行分级
  • 生理学实验系统 生理学实验系统软件 集成化生物信号采集与处理系统生物信号采集处理系统 生理机能实验处理系统
  • 环境变量与程序地址空间
  • Node.js的主要应用场景和简单例子
  • 做视频解析网站是犯法的么360优化大师
  • 大网站cn域名淘宝店铺装修模板免费下载
  • VBA即用型代码手册:利用函数保存为PDF文件UseFunctionSaveAsPDF
  • JPA 的说明和使用
  • MyBatis使用LocalDateTime会报错
  • web网页开发,在线财务管理系统,基于Idea,html,css,jQuery,java,ssm,mysql。
  • 2025汉化idea创建JSP项目
  • 如何高效处理日常 PDF 文档?
  • LeetCode 2342.数位和相等数对的最大和
  • 企业网站建设需了解什么软文投放平台有哪些?
  • pink老师html5+css3day07
  • 各个手机芯片型号
  • [Qt学习笔记]Qt5.15.2版本安装及调整组件
  • C语言--文件读写函数的使用,对文件读写知识有了更深的了解。