
Reinforcement Learning 1.3: The Deep Cross-Entropy Method

Learning materials link

In this section we extend the cross-entropy method (CEM) implementation to neural networks, training a multi-layer network to solve a simple game with a continuous state space.

As usual, we start by initializing the environment. The setup below makes sure RL environments that need a graphical display also run on Google Colab or a headless remote server, instead of crashing because there is no monitor.

import sys, os
if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):
    !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash
    !touch .setup_complete

# This code creates a virtual display to draw game images on.
# It will have no effect if your machine has a monitor.
if type(os.environ.get("DISPLAY")) is not str or len(os.environ.get("DISPLAY")) == 0:
    !bash ../xvfb start
    os.environ['DISPLAY'] = ':1'

Install dependencies

# Install gymnasium if you didn't
!pip install "gymnasium[toy_text,classic_control]"

Confirm that the CartPole environment runs and renders, and record its state and action dimensions for the training code that follows.

import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# if you see "<classname> has no attribute .env", remove .env or update gym
env = gym.make("CartPole-v0", render_mode="rgb_array").env
env.reset()

n_actions = env.action_space.n
state_dim = env.observation_space.shape[0]

plt.imshow(env.render())
print("state vector dim =", state_dim)
print("n_actions =", n_actions)

env.close()

There are four state dimensions and two actions (state vector dim = 4, n_actions = 2).

We use sklearn's MLPClassifier directly as the policy network. The workflow is:

  • Training
    Feed in "observed state → elite ('expert') action" pairs and run one gradient update, so the network imitates the elite sessions more and more.

  • Inference
    Given a batch of states, return one row per state with the probability distribution over actions, shaped [batch, n_actions].

Later, when the cross-entropy method picks actions, it samples according to this probability distribution.

from sklearn.neural_network import MLPClassifier

# Create a small feed-forward network; it is untrained and its weights are random.
agent = MLPClassifier(
    hidden_layer_sizes=(20, 20),
    activation="tanh",
)

# initialize agent to the dimension of state space and number of actions
# Run one pass on dummy data just so the network builds its internal structures
# (weight matrices, label binarizer, ...); we don't expect it to learn anything here.
agent.partial_fit(
    [env.reset()[0]] * n_actions,  # X
    range(n_actions),              # y
    classes=range(n_actions),      # explicitly tell the model there are n_actions classes
)
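Before wiring this into the game loop, it can help to sanity-check the inference side of the interface described above. The snippet below is only an illustrative check (the batch is built by resetting the environment a few times): predict_proba should return one probability row per state, shaped [batch, n_actions].

# Illustrative sanity check: predict_proba returns [batch, n_actions]
test_states = np.array([env.reset()[0] for _ in range(3)])   # a batch of 3 states
probs = agent.predict_proba(test_states)
print(probs.shape)        # expected: (3, n_actions), i.e. (3, 2) for CartPole
print(probs.sum(axis=1))  # each row is a probability distribution, so it sums to ~1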

Next, let the current neural-network policy play one full game and collect training data.

For each step:

  1. Feed the current state s to the agent and get a vector of action probabilities;
  2. Sample an action a according to those probabilities (exploration);
  3. Execute a and obtain the new state, the reward, and whether the episode ended;
  4. Append (s, a) to the lists and accumulate the reward.

Loop until the game ends or the step limit t_max is reached.

def generate_session(env, agent, t_max=1000):
    """
    Play a single game using agent neural network.
    Terminate when game finishes or after :t_max: steps
    """
    states, actions = [], []
    total_reward = 0

    s, _ = env.reset()

    for t in range(t_max):
        # use agent to predict a vector of action probabilities for state :s:
        probs = agent.predict_proba(np.array([s]))[0]  # returns shape (n_actions,)

        assert probs.shape == (env.action_space.n,), \
            "make sure probabilities are a vector (hint: np.reshape)"

        # sample proportionally to the probabilities, don't just take the most likely action
        a = np.random.choice(env.action_space.n, p=probs)

        new_s, r, terminated, truncated, _ = env.step(a)

        # record sessions like you did before
        states.append(s)
        actions.append(a)
        total_reward += r

        s = new_s
        if terminated or truncated:
            break

    return states, actions, total_reward

Run a quick test:

dummy_states, dummy_actions, dummy_reward = generate_session(env, agent, t_max=5)
print("states:", np.stack(dummy_states))
print("actions:", dummy_actions)
print("reward:", dummy_reward)


The deep cross-entropy method (deep CEM) follows exactly the same procedure as the tabular CEM from before, so that code can be copied over as-is. The only difference is that states are now float32 vectors instead of integer indices, so the network's input layer takes state_dim real numbers rather than a one-hot encoding or a table lookup.
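To make that "integer index vs. float32 vector" point concrete, here is a small side-by-side sketch. The tabular policy array and the observation values below are made up for illustration; only the agent call matches the code in this notebook.

# Tabular CEM (previous notebook): the state is an integer, so action
# probabilities come from a row lookup in a [n_states, n_actions] table.
toy_policy = np.full((16, 2), 0.5)       # hypothetical 16-state toy environment
probs_tabular = toy_policy[7]            # "state 7" -> its row of action probabilities

# Deep CEM (this notebook): the state is a real-valued vector, so the
# probabilities come from a forward pass through the neural network instead.
s = np.array([0.03, -0.02, 0.01, 0.04], dtype=np.float32)   # made-up CartPole-like observation
probs_network = agent.predict_proba(s[None, :])[0]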


def select_elites(states_batch, actions_batch, rewards_batch, percentile=50):
    """
    Select states and actions from games that have rewards >= percentile

    :param states_batch: list of lists of states, states_batch[session_i][t]
    :param actions_batch: list of lists of actions, actions_batch[session_i][t]
    :param rewards_batch: list of rewards, rewards_batch[session_i]

    :returns: elite_states, elite_actions, both 1D lists of states and respective
              actions from elite sessions

    Please return elite states and actions in their original order
    [i.e. sorted by session number and timestep within session]

    If you are confused, see examples below. Please don't assume that states are
    integers (they will become different later).
    """
    # copy-pasted implementation from the previous (tabular CEM) notebook
    reward_threshold = np.percentile(rewards_batch, percentile)

    elite_states = []
    elite_actions = []
    for session_idx, total_r in enumerate(rewards_batch):
        if total_r >= reward_threshold:
            elite_states.extend(states_batch[session_idx])
            elite_actions.extend(actions_batch[session_idx])

    return elite_states, elite_actions
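As a quick sanity check, here is what select_elites does on a hand-made batch; the three toy sessions and their rewards are invented for illustration. With percentile=50, only sessions whose total reward reaches the 50th percentile survive, and their states and actions come back as flat lists in the original order.

# Three toy sessions with total rewards 3, 4 and 5.
states_batch = [
    [1, 2, 3],        # session 0
    [4, 2, 0, 2],     # session 1
    [3, 1],           # session 2
]
actions_batch = [
    [0, 2, 4],
    [3, 2, 0, 1],
    [3, 3],
]
rewards_batch = [3, 4, 5]

elite_states, elite_actions = select_elites(
    states_batch, actions_batch, rewards_batch, percentile=50
)
print(elite_states)   # [4, 2, 0, 2, 3, 1]  -> sessions 1 and 2 pass the 50th percentile
print(elite_actions)  # [3, 2, 0, 1, 3, 3]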

Start training

Visualization code

from IPython.display import clear_output


def show_progress(rewards_batch, log, percentile, reward_range=[-990, +10]):
    """
    A convenience function that displays training progress.
    No cool math here, just charts.
    """
    mean_reward = np.mean(rewards_batch)
    threshold = np.percentile(rewards_batch, percentile)
    log.append([mean_reward, threshold])

    clear_output(True)
    print("mean reward = %.3f, threshold=%.3f" % (mean_reward, threshold))
    plt.figure(figsize=[8, 4])

    plt.subplot(1, 2, 1)
    plt.plot(list(zip(*log))[0], label="Mean rewards")
    plt.plot(list(zip(*log))[1], label="Reward thresholds")
    plt.legend()
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.hist(rewards_batch, range=reward_range)
    plt.vlines(
        [np.percentile(rewards_batch, percentile)],
        [0],
        [100],
        label="percentile",
        color="red",
    )
    plt.legend()
    plt.grid()

    plt.show()

Training loop

n_sessions = 100
percentile = 70
log = []

for i in range(100):
    # generate a batch of new sessions with the current policy
    sessions = [generate_session(env, agent, t_max=1000) for _ in range(n_sessions)]

    # states_batch, actions_batch, rewards_batch = map(np.array, zip(*sessions))
    states_batch, actions_batch, rewards_batch = zip(*sessions)
    rewards_batch = np.array(rewards_batch)  # only the total rewards need to become a 1D array

    # select elite sessions just like before
    elite_states, elite_actions = select_elites(states_batch, actions_batch, rewards_batch, percentile)

    # partial_fit the agent to predict elite_actions (y) from elite_states (X)
    agent.partial_fit(elite_states, elite_actions)

    show_progress(rewards_batch, log, percentile, reward_range=[0, np.max(rewards_batch)])

    if np.mean(rewards_batch) > 190:
        print("You Win! You may stop training now via KeyboardInterrupt.")
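After the mean reward clears the 190 bar, one simple way to double-check the trained policy is to run a few fresh evaluation episodes and look at the average return. This is just a sketch, not part of the original notebook; it reuses generate_session, so actions are still sampled stochastically.

# Evaluate the trained agent on a fresh batch of episodes (illustrative only).
eval_env = gym.make("CartPole-v0", render_mode="rgb_array")
eval_rewards = [generate_session(eval_env, agent, t_max=1000)[2] for _ in range(20)]
print("evaluation: mean reward = %.1f over %d episodes"
      % (np.mean(eval_rewards), len(eval_rewards)))
eval_env.close()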

Results
Wrap the environment with RecordVideo to record each episode as a video and save it to the ./videos folder.

# Record sessions
from gymnasium.wrappers import RecordVideo

with RecordVideo(
    env=gym.make("CartPole-v0", render_mode="rgb_array"),
    video_folder="./videos",
    episode_trigger=lambda episode_number: True,
) as env_monitor:
    sessions = [generate_session(env_monitor, agent) for _ in range(100)]

Embed the recorded .mp4 videos into the notebook so they play directly in the browser.

# Show video. This may not work in some setups. If it doesn't
# work for you, you can download the videos and view them locally.

from pathlib import Path
from base64 import b64encode
from IPython.display import HTML

video_paths = sorted([s for s in Path("videos").iterdir() if s.suffix == ".mp4"])
video_path = video_paths[-1]  # You can also try other indices

if "google.colab" in sys.modules:
    # https://stackoverflow.com/a/57378660/1214547
    with video_path.open("rb") as fp:
        mp4 = fp.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
else:
    data_url = str(video_path)

HTML(
    """
<video width="640" height="480" controls>
  <source src="{}" type="video/mp4">
</video>
""".format(data_url)
)

