
Implementing a Game AI with the Deep Q-Network (DQN) Algorithm


Table of Contents

  1. Introduction
  2. Theoretical Foundations of DQN
  3. Environment Setup
  4. DQN Model Implementation
  5. Experience Replay
  6. Training Procedure
  7. Experimental Results and Analysis
  8. Optimizations and Improvements
  9. Conclusion
  10. References

Introduction

The Deep Q-Network (DQN), proposed by DeepMind in 2013, combines deep learning with reinforcement learning. By approximating the Q-value function with a deep neural network, it can handle problems with high-dimensional state spaces. DQN marked a major breakthrough in deep reinforcement learning and laid the groundwork for many later algorithms such as Double DQN and Dueling DQN.

In this report we implement the DQN algorithm end to end and test it on the CartPole and Atari game environments from OpenAI Gym. We explain every component of the algorithm in detail, including the Q-network architecture, experience replay, and the target network, and provide a complete code implementation.

Theoretical Foundations of DQN

Reinforcement Learning Basics

Reinforcement learning is a branch of machine learning concerned with how an agent learns an optimal policy by interacting with its environment over a sequence of actions. The core concepts are (a minimal interaction-loop sketch follows this list):

  • State: the current situation of the environment
  • Action: an operation the agent can perform
  • Reward: the environment's feedback to the agent's action
  • Policy: a mapping from states to actions
  • Value function: an estimate of the long-term value of a state or state-action pair
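To make these concepts concrete, here is a minimal sketch of the agent-environment interaction loop. It uses a random policy and the CartPole-v1 environment purely for illustration, and assumes the classic (pre-0.26) Gym API that the rest of this post uses.

import gym

# Minimal agent-environment loop: observe a state, pick an action,
# receive a reward, repeat until the episode ends.
env = gym.make('CartPole-v1')
state = env.reset()                                  # state: the environment's current situation
done = False
total_reward = 0.0
while not done:
    action = env.action_space.sample()               # action: here drawn from a random policy
    state, reward, done, info = env.step(action)     # reward: the environment's feedback
    total_reward += reward
env.close()
print(f"Episode return under a random policy: {total_reward}")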

Q-Learning

Q-learning is a value-based reinforcement learning algorithm that finds the optimal policy by learning an action-value function Q(s, a). The Q-function gives the expected cumulative reward obtained by taking action a in state s and then following policy π.

The Q-learning update rule is:
Q(s_t, a_t) ← Q(s_t, a_t) + α[r_t + γ max_a Q(s_{t+1}, a) - Q(s_t, a_t)]

where α is the learning rate and γ is the discount factor.
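As a concrete illustration, here is a minimal sketch of this update in the tabular setting; the state/action counts and hyperparameters below are arbitrary placeholders, not values from any particular environment.

import numpy as np

# Tabular Q-learning: move Q(s, a) toward the TD target
# r + gamma * max_a' Q(s', a') by a step of size alpha.
n_states, n_actions = 16, 4
alpha, gamma = 0.1, 0.99
Q = np.zeros((n_states, n_actions))

def q_learning_update(s, a, r, s_next, done):
    td_target = r if done else r + gamma * np.max(Q[s_next])
    Q[s, a] += alpha * (td_target - Q[s, a])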

Deep Q-Networks

When the state space is large or continuous, tabular Q-learning no longer scales. DQN instead approximates the Q-function with a deep neural network Q(s, a; θ) parameterized by θ.

DQN introduces two key innovations (the training objective they combine into is given after this list):

  1. Experience replay: the agent's transitions (s_t, a_t, r_t, s_{t+1}) are stored in a replay buffer and sampled at random for training, which breaks the correlation between consecutive samples.
  2. Target network: a separate network computes the target Q-values; its parameters are copied from the main network at fixed intervals, which stabilizes training.
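Putting the two together, DQN minimizes the following loss, where θ⁻ denotes the periodically copied target-network parameters and D is the replay buffer:

L(θ) = E_{(s, a, r, s') ~ D} [ (r + γ max_{a'} Q(s', a'; θ⁻) - Q(s, a; θ))² ]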

Environment Setup

First, install the required dependencies:

pip install gym
pip install gym[atari]
pip install gym[accept-rom-license]
pip install opencv-python
pip install tensorflow

Or, to use PyTorch instead:

pip install torch torchvision

We use TensorFlow as the deep learning framework, but the code can easily be ported to PyTorch.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import gym
import random
from collections import deque
import matplotlib.pyplot as plt
import cv2

DQN Model Implementation

Q-Network Architecture

For a simple environment such as CartPole, a small fully connected network is enough:

def create_q_network_simple(state_dim, action_dim):
    model = keras.Sequential([
        layers.Dense(24, activation='relu', input_shape=state_dim),
        layers.Dense(24, activation='relu'),
        layers.Dense(action_dim, activation='linear')
    ])
    return model

For Atari games the input is an image, so we use a convolutional neural network:

def create_q_network_atari(input_shape, action_dim):
    # Input shape is (84, 84, 4): a stack of four 84x84 grayscale frames
    model = keras.Sequential([
        layers.Conv2D(32, 8, strides=4, activation='relu', input_shape=input_shape),
        layers.Conv2D(64, 4, strides=2, activation='relu'),
        layers.Conv2D(64, 3, strides=1, activation='relu'),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dense(action_dim, activation='linear')
    ])
    return model

Preprocessing

For Atari games the raw frames must be preprocessed:

def preprocess_atari(frame):
    # Convert to grayscale
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    # Resize to 84x84
    frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)
    return frame


def stack_frames(stacked_frames, frame, is_new_episode):
    frame = preprocess_atari(frame)
    if is_new_episode:
        # Clear the frame stack and repeat the first frame four times
        stacked_frames = deque([np.zeros((84, 84), dtype=np.uint8) for _ in range(4)], maxlen=4)
        for _ in range(4):
            stacked_frames.append(frame)
    else:
        stacked_frames.append(frame)
    # Stack along the channel dimension
    stacked_state = np.stack(stacked_frames, axis=2)
    return stacked_state, stacked_frames

Experience Replay

Experience replay is a key component of DQN; by storing transitions and sampling them at random, it breaks the correlation between consecutive samples:

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def store(self, state, action, reward, next_state, done):
        experience = (state, action, reward, next_state, done)
        self.buffer.append(experience)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
        return states, actions, rewards, next_states, dones

    def size(self):
        return len(self.buffer)

Training Procedure

The DQN Agent Class

class DQNAgent:
    def __init__(self, state_dim, action_dim, learning_rate=0.001, gamma=0.99,
                 epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995,
                 buffer_capacity=10000, batch_size=64, target_update_freq=1000,
                 is_atari=False):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.is_atari = is_atari

        # Create the Q-network and the target network
        if is_atari:
            self.q_network = create_q_network_atari(state_dim, action_dim)
            self.target_network = create_q_network_atari(state_dim, action_dim)
        else:
            self.q_network = create_q_network_simple(state_dim, action_dim)
            self.target_network = create_q_network_simple(state_dim, action_dim)

        self.optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

        # Initialize the target network with the main network's weights
        self.target_network.set_weights(self.q_network.get_weights())

        # Create the replay buffer
        self.replay_buffer = ReplayBuffer(buffer_capacity)

        # Track the number of training steps
        self.train_step = 0

    def act(self, state, training=True):
        # Epsilon-greedy action selection
        if training and np.random.rand() <= self.epsilon:
            return random.randrange(self.action_dim)
        state = np.expand_dims(state, axis=0)
        q_values = self.q_network.predict(state, verbose=0)
        return np.argmax(q_values[0])

    def update_epsilon(self):
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def train(self):
        if self.replay_buffer.size() < self.batch_size:
            return 0

        # Sample a batch from the replay buffer
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        # Current Q-values
        current_q = self.q_network(states)

        # Target Q-values from the target network
        next_q = self.target_network(next_states)
        max_next_q = np.max(next_q, axis=1)

        target_q = current_q.numpy()
        for i in range(self.batch_size):
            if dones[i]:
                target_q[i, actions[i]] = rewards[i]
            else:
                target_q[i, actions[i]] = rewards[i] + self.gamma * max_next_q[i]

        # Train the network (reduce to a scalar loss before differentiating)
        with tf.GradientTape() as tape:
            q_values = self.q_network(states)
            loss = tf.reduce_mean(keras.losses.MSE(target_q, q_values))
        grads = tape.gradient(loss, self.q_network.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.q_network.trainable_variables))

        # Decay epsilon
        self.update_epsilon()

        # Periodically update the target network
        self.train_step += 1
        if self.train_step % self.target_update_freq == 0:
            self.target_network.set_weights(self.q_network.get_weights())

        return loss.numpy()

    def save(self, filepath):
        self.q_network.save_weights(filepath)

    def load(self, filepath):
        self.q_network.load_weights(filepath)
        self.target_network.set_weights(self.q_network.get_weights())

Training Loop

def train_dqn(env_name, is_atari=False, num_episodes=1000, max_steps=1000,
              render=False, save_freq=100, model_path='dqn_model.h5'):
    # Create the environment
    env = gym.make(env_name)
    if is_atari:
        # The AtariPreprocessing wrapper already grayscales and resizes frames to 84x84,
        # so no extra preprocess_atari call is needed in the loop below
        env = gym.wrappers.AtariPreprocessing(env, scale_obs=True)
        state_dim = (84, 84, 4)
        stacked_frames = deque([np.zeros((84, 84), dtype=np.float32) for _ in range(4)], maxlen=4)
    else:
        state_dim = env.observation_space.shape
        if len(state_dim) == 0:  # discrete observation space
            state_dim = (1,)
    action_dim = env.action_space.n

    # Create the agent
    agent = DQNAgent(state_dim, action_dim, is_atari=is_atari)

    # Track rewards and losses
    episode_rewards = []
    episode_losses = []

    for episode in range(num_episodes):
        state = env.reset()
        if is_atari:
            # Initialize the frame stack by repeating the first frame
            for _ in range(4):
                stacked_frames.append(state)
            state = np.stack(stacked_frames, axis=2)

        total_reward = 0
        total_loss = 0
        step_count = 0

        for step in range(max_steps):
            if render:
                env.render()

            # Select an action
            action = agent.act(state)

            # Execute the action
            next_state, reward, done, _ = env.step(action)

            if is_atari:
                # Append the new (already preprocessed) frame to the stack
                stacked_frames.append(next_state)
                next_state = np.stack(stacked_frames, axis=2)

            # Store the transition
            agent.replay_buffer.store(state, action, reward, next_state, done)

            # Train the network
            loss = agent.train()
            if loss:
                total_loss += loss

            state = next_state
            total_reward += reward
            step_count += 1

            if done:
                break

        # Record statistics
        episode_rewards.append(total_reward)
        if step_count > 0:
            episode_losses.append(total_loss / step_count)
        else:
            episode_losses.append(0)

        # Print progress
        if episode % 10 == 0:
            avg_reward = np.mean(episode_rewards[-10:])
            print(f"Episode: {episode}, Reward: {total_reward}, "
                  f"Avg Reward (last 10): {avg_reward:.2f}, Epsilon: {agent.epsilon:.3f}")

        # Save the model
        if episode % save_freq == 0:
            agent.save(model_path)

    # Close the environment
    env.close()

    # Plot the results
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(episode_rewards)
    plt.title('Episode Rewards')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.subplot(1, 2, 2)
    plt.plot(episode_losses)
    plt.title('Episode Losses')
    plt.xlabel('Episode')
    plt.ylabel('Average Loss')
    plt.tight_layout()
    plt.savefig('training_results.png')
    plt.show()

    return agent, episode_rewards, episode_losses

Experimental Results and Analysis

Testing on CartPole

# Train on the CartPole environment
cartpole_agent, rewards, losses = train_dqn(
    env_name='CartPole-v1',
    is_atari=False,
    num_episodes=300,
    max_steps=500,
    render=False,
    save_freq=50,
    model_path='dqn_cartpole.h5'
)

Testing on Atari Pong

# Train on the Atari Pong environment
pong_agent, pong_rewards, pong_losses = train_dqn(
    env_name='PongNoFrameskip-v4',
    is_atari=True,
    num_episodes=1000,
    max_steps=10000,
    render=False,
    save_freq=100,
    model_path='dqn_pong.h5'
)

Analysis of Results

On CartPole, DQN typically learns to balance the pole within 100-200 episodes. The reward rises steadily during training and eventually reaches the maximum possible value of 500.

On Atari Pong, training takes much longer, usually several thousand episodes. The agent progresses from playing completely at random to being able to beat the simple built-in opponent.
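To check what a trained agent has actually learned, it also helps to run it greedily with exploration switched off. The evaluate helper below is not part of the training code above; it is a small sketch for the CartPole case, reusing the gym and numpy imports from earlier and the classic Gym API.

def evaluate(agent, env_name='CartPole-v1', episodes=10):
    # Run the trained agent greedily and report the average return
    env = gym.make(env_name)
    returns = []
    for _ in range(episodes):
        state = env.reset()
        done, total = False, 0.0
        while not done:
            action = agent.act(state, training=False)   # greedy action, no epsilon-exploration
            state, reward, done, _ = env.step(action)
            total += reward
        returns.append(total)
    env.close()
    return np.mean(returns)

# Example: print(f"Average greedy return: {evaluate(cartpole_agent):.1f}")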

Optimizations and Improvements

Double DQN

Double DQN addresses DQN's tendency to overestimate Q-values by letting the main network select the next action while the target network evaluates it.
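In formula form, the bootstrap target becomes

y_t = r_t + γ Q(s_{t+1}, argmax_a Q(s_{t+1}, a; θ); θ⁻)

where θ are the main-network parameters and θ⁻ the target-network parameters. The agent below overrides only the train method accordingly: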

class DoubleDQNAgent(DQNAgent):
    def train(self):
        if self.replay_buffer.size() < self.batch_size:
            return 0

        # Sample a batch from the replay buffer
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        # Current Q-values
        current_q = self.q_network(states)

        # Double DQN target: the main network selects the next action,
        # the target network evaluates it
        next_q_main = self.q_network(next_states).numpy()
        next_actions = np.argmax(next_q_main, axis=1)
        next_q_target = self.target_network(next_states).numpy()
        max_next_q = next_q_target[np.arange(self.batch_size), next_actions]

        target_q = current_q.numpy()
        for i in range(self.batch_size):
            if dones[i]:
                target_q[i, actions[i]] = rewards[i]
            else:
                target_q[i, actions[i]] = rewards[i] + self.gamma * max_next_q[i]

        # Train the network
        with tf.GradientTape() as tape:
            q_values = self.q_network(states)
            loss = tf.reduce_mean(keras.losses.MSE(target_q, q_values))
        grads = tape.gradient(loss, self.q_network.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.q_network.trainable_variables))

        # Decay epsilon
        self.update_epsilon()

        # Periodically update the target network
        self.train_step += 1
        if self.train_step % self.target_update_freq == 0:
            self.target_network.set_weights(self.q_network.get_weights())

        return loss.numpy()

Dueling DQN

Dueling DQN decomposes the Q-value into a state-value term and an advantage term.
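The aggregation used in the implementation subtracts the mean advantage so that the decomposition is identifiable:

Q(s, a) = V(s) + (A(s, a) - mean_{a'} A(s, a'))

The network builders and agent below implement exactly this aggregation: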

def create_dueling_q_network_atari(input_shape, action_dim):
    inputs = layers.Input(shape=input_shape)
    # Convolutional layers
    x = layers.Conv2D(32, 8, strides=4, activation='relu')(inputs)
    x = layers.Conv2D(64, 4, strides=2, activation='relu')(x)
    x = layers.Conv2D(64, 3, strides=1, activation='relu')(x)
    x = layers.Flatten()(x)
    x = layers.Dense(512, activation='relu')(x)

    # Split into a value stream and an advantage stream
    value_stream = layers.Dense(1, activation='linear')(x)
    advantage_stream = layers.Dense(action_dim, activation='linear')(x)

    # Combine the two streams
    q_values = value_stream + (advantage_stream - tf.reduce_mean(advantage_stream, axis=1, keepdims=True))
    return keras.Model(inputs=inputs, outputs=q_values)


def create_dueling_q_network_simple(state_dim, action_dim):
    # Dueling network for simple (low-dimensional) environments
    inputs = layers.Input(shape=state_dim)
    x = layers.Dense(24, activation='relu')(inputs)
    x = layers.Dense(24, activation='relu')(x)
    value_stream = layers.Dense(1, activation='linear')(x)
    advantage_stream = layers.Dense(action_dim, activation='linear')(x)
    q_values = value_stream + (advantage_stream - tf.reduce_mean(advantage_stream, axis=1, keepdims=True))
    return keras.Model(inputs=inputs, outputs=q_values)


class DuelingDQNAgent(DQNAgent):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace both networks with the dueling architecture; the builders are
        # called twice so the target network gets its own, separate weights
        if self.is_atari:
            self.q_network = create_dueling_q_network_atari(self.state_dim, self.action_dim)
            self.target_network = create_dueling_q_network_atari(self.state_dim, self.action_dim)
        else:
            self.q_network = create_dueling_q_network_simple(self.state_dim, self.action_dim)
            self.target_network = create_dueling_q_network_simple(self.state_dim, self.action_dim)
        # Initialize the target network with the main network's weights
        self.target_network.set_weights(self.q_network.get_weights())

Prioritized Experience Replay

Prioritized experience replay samples transitions in proportion to the magnitude of their TD error, which improves sample efficiency.
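The sampling probability and importance-sampling correction implemented below follow the proportional variant:

P(i) = p_i^α / Σ_k p_k^α,    w_i = (N · P(i))^(-β) / max_j w_j

where p_i is transition i's priority (its absolute TD error plus a small constant), α controls how strongly priorities bias sampling, and β is annealed toward 1 over training to correct that bias: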

class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.pos = 0
        self.size = 0

    def store(self, state, action, reward, next_state, done):
        max_priority = np.max(self.priorities) if self.size > 0 else 1.0
        if self.size < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
            self.size += 1
        else:
            self.buffer[self.pos] = (state, action, reward, next_state, done)
        # New transitions get the current maximum priority
        self.priorities[self.pos] = max_priority
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        if self.size == 0:
            return None

        priorities = self.priorities[:self.size]
        probs = priorities ** self.alpha
        probs /= np.sum(probs)

        indices = np.random.choice(self.size, batch_size, p=probs)
        samples = [self.buffer[i] for i in indices]

        # Importance-sampling weights
        total = self.size
        weights = (total * probs[indices]) ** (-self.beta)
        weights /= np.max(weights)
        self.beta = min(1.0, self.beta + self.beta_increment)

        states, actions, rewards, next_states, dones = map(np.array, zip(*samples))
        return indices, states, actions, rewards, next_states, dones, weights

    def update_priorities(self, indices, priorities):
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority + 1e-5  # avoid zero priority

    def __len__(self):
        return self.size


class PrioritizedDQNAgent(DQNAgent):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the uniform replay buffer with the prioritized one
        self.replay_buffer = PrioritizedReplayBuffer(kwargs.get('buffer_capacity', 10000))

    def train(self):
        if len(self.replay_buffer) < self.batch_size:
            return 0

        # Sample from the prioritized replay buffer
        sample_result = self.replay_buffer.sample(self.batch_size)
        if sample_result is None:
            return 0
        indices, states, actions, rewards, next_states, dones, weights = sample_result
        weights = weights.astype(np.float32)

        # Current Q-values
        current_q = self.q_network(states)

        # Target Q-values
        next_q = self.target_network(next_states)
        max_next_q = np.max(next_q, axis=1)

        target_q = current_q.numpy()

        # TD errors, used to update priorities
        td_errors = np.zeros(self.batch_size)
        for i in range(self.batch_size):
            if dones[i]:
                target_q[i, actions[i]] = rewards[i]
            else:
                target_q[i, actions[i]] = rewards[i] + self.gamma * max_next_q[i]
            td_errors[i] = abs(target_q[i, actions[i]] - current_q.numpy()[i, actions[i]])

        # Update priorities
        self.replay_buffer.update_priorities(indices, td_errors)

        # Train the network with importance-sampling weights
        with tf.GradientTape() as tape:
            q_values = self.q_network(states)
            # Per-sample Huber loss, weighted by the importance-sampling weights
            per_sample_loss = keras.losses.Huber(reduction=keras.losses.Reduction.NONE)(target_q, q_values)
            weighted_loss = tf.reduce_mean(per_sample_loss * weights)
        grads = tape.gradient(weighted_loss, self.q_network.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.q_network.trainable_variables))

        # Decay epsilon
        self.update_epsilon()

        # Periodically update the target network
        self.train_step += 1
        if self.train_step % self.target_update_freq == 0:
            self.target_network.set_weights(self.q_network.get_weights())

        return weighted_loss.numpy()

Conclusion

In this report we implemented the DQN algorithm and several of its improved variants. Starting from the basics of reinforcement learning, we explained DQN's core components in detail, including the Q-network, experience replay, and the target network. We implemented the basic DQN algorithm and tested it on the CartPole and Atari game environments.

We also implemented three improvements to DQN: Double DQN (which mitigates Q-value overestimation), Dueling DQN (which separates state value from advantage), and prioritized experience replay (which improves sample efficiency). Each improves the performance and stability of the original DQN to varying degrees.

DQN is a cornerstone of deep reinforcement learning. Although many more advanced algorithms have since appeared, DQN and its variants remain an excellent entry point for understanding the field. With the code in this report, readers can gain a solid understanding of how DQN works and build on it for further experiments and research.

References

  1. Mnih, V., et al. (2015). Human-level control through deep reinforcement learning. Nature, 518(7540), 529-533.
  2. Van Hasselt, H., Guez, A., & Silver, D. (2016). Deep reinforcement learning with double Q-learning. In Proceedings of the AAAI Conference on Artificial Intelligence.
  3. Wang, Z., et al. (2016). Dueling network architectures for deep reinforcement learning. In International Conference on Machine Learning.
  4. Schaul, T., et al. (2015). Prioritized experience replay. arXiv preprint arXiv:1511.05952.
  5. Sutton, R. S., & Barto, A. G. (2018). Reinforcement learning: An introduction. MIT Press.

Note: the code above is a simplified implementation; real applications will likely require tuning and modification for the specific environment and problem. Fully training on Atari games requires substantial compute and time and is best run on a GPU.

