当前位置: 首页 > news >正文

结合MAML算法元强化学习

这篇文章介绍了一个基于元强化学习(Meta-RL)的Atari游戏智能体框架。该框架采用改进的Actor-Critic网络结构,结合MAML元学习算法,实现了在多个Atari游戏任务上的快速适应能力。系统包含四大核心组件:1)改进的Actor-Critic网络,支持参数共享和元学习;2)MAML元学习器,实现内层任务适应和外层元更新;3)专业的Atari游戏预处理器;4)完整的元强化学习训练框架。实验表明,该框架能在Breakout等游戏中有效学习,并展示了向Pong等新游戏的迁移学习潜力。文章详细介绍了网络架构、训练流程和实现细节,为深度强化学习在复杂游戏环境中的应用提供了实用方案。

import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
import numpy as np
from collections import deque
import random
import cv2
import os
import time
from typing import List, Tuple, Dict, Any
import warnings# 忽略特定警告
warnings.filterwarnings("ignore", category=UserWarning, module="gymnasium")class ImprovedActorCritic(nn.Module):"""改进的Actor-Critic网络,支持元学习"""def __init__(self, state_dim: Tuple[int, int, int], action_dim: int, hidden_dim: int = 512):"""初始化Actor-Critic网络参数:state_dim: 状态维度 (通道, 高度, 宽度)action_dim: 动作空间维度hidden_dim: 隐藏层维度"""super().__init__()self.state_dim = state_dimself.action_dim = action_dim# 共享的特征提取层self.feature_net = nn.Sequential(nn.Conv2d(state_dim[0], 32, kernel_size=8, stride=4),nn.ReLU(),nn.Conv2d(32, 64, kernel_size=4, stride=2),nn.ReLU(),nn.Conv2d(64, 64, kernel_size=3, stride=1),nn.ReLU(),nn.Flatten())# 测试特征维度with torch.no_grad():dummy_input = torch.zeros(1, *state_dim)feature_dim = self.feature_net(dummy_input).shape[1]# Actor和Critic头self.actor = nn.Sequential(nn.Linear(feature_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, action_dim))self.critic = nn.Sequential(nn.Linear(feature_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, 1))# 初始化权重self.apply(self._init_weights)def _init_weights(self, module):"""初始化网络权重"""if isinstance(module, (nn.Linear, nn.Conv2d)):nn.init.orthogonal_(module.weight, gain=np.sqrt(2))nn.init.constant_(module.bias, 0.0)def forward(self, x: torch.Tensor, params: Dict[str, torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:"""前向传播参数:x: 输入状态params: 自定义参数(用于元学习)返回:action_logits: 动作logitsvalue: 状态价值"""if params is None:features = self.feature_net(x)action_logits = self.actor(features)value = self.critic(features)else:# 使用提供的参数进行计算(用于元学习)features = self._apply_func_with_params(x, self.feature_net, params, 'feature_net')action_logits = self._apply_func_with_params(features, self.actor, params, 'actor')value = self._apply_func_with_params(features, self.critic, params, 'critic')return action_logits, valuedef _apply_func_with_params(self, x: torch.Tensor, module: nn.Module, params: Dict[str, torch.Tensor], prefix: str) -> torch.Tensor:"""使用给定的参数应用模块计算参数:x: 输入张量module: 要应用的模块params: 参数字典prefix: 参数前缀返回:计算结果"""# 简化实现:实际应用中需要根据网络结构递归处理if isinstance(module, nn.Sequential):result = xfor i, layer in enumerate(module):layer_name = f"{prefix}.{i}"if isinstance(layer, nn.Linear):weight = params[f"{layer_name}.weight"]bias = params[f"{layer_name}.bias"]result = torch.relu(torch.nn.functional.linear(result, weight, bias))elif isinstance(layer, nn.Conv2d):weight = params[f"{layer_name}.weight"]bias = params.get(f"{layer_name}.bias", None)result = torch.relu(torch.nn.functional.conv2d(result, weight, bias, stride=layer.stride, padding=layer.padding))else:result = layer(result)return resultreturn module(x)def get_params_dict(self) -> Dict[str, torch.Tensor]:"""获取当前参数的字典"""return {name: param.clone() for name, param in self.named_parameters()}class MAMLMetaLearner:"""MAML元学习器"""def __init__(self, model: nn.Module, inner_lr: float = 0.01, meta_lr: float = 0.001):"""初始化MAML元学习器参数:model: 要优化的模型inner_lr: 内层学习率meta_lr: 元学习率"""self.model = modelself.inner_lr = inner_lrself.meta_optimizer = optim.Adam(model.parameters(), lr=meta_lr)def inner_update(self, task_experiences: List[Tuple], num_steps: int = 1) -> Dict[str, torch.Tensor]:"""内层更新:在单个任务上快速适应参数:task_experiences: 任务经验列表num_steps: 内层更新步数返回:适应后的参数"""# 克隆当前参数作为快速参数的起点fast_parameters = {k: v.clone() for k, v in self.model.named_parameters()}for step in range(num_steps):task_loss = 0# 随机采样一批经验batch = random.sample(task_experiences, min(32, len(task_experiences)))# 计算任务特定损失for experience in batch:states, actions, rewards, next_states, dones = experience# 计算策略损失action_logits, values = self.model(states, fast_parameters)probs = torch.softmax(action_logits, dim=-1)log_probs = torch.log_softmax(action_logits, dim=-1)# 计算优势函数with torch.no_grad():_, next_values = self.model(next_states, fast_parameters)targets = rewards + 0.99 * next_values * (1 - dones)advantages = targets - values# 策略梯度损失policy_loss = -(log_probs.gather(1, actions) * advantages.detach()).mean()# 价值函数损失value_loss = nn.MSELoss()(values, targets)# 熵正则化entropy = -(probs * log_probs).sum(-1).mean()total_loss = policy_loss + 0.5 * value_loss - 0.01 * entropytask_loss += total_loss# 在内层循环中计算梯度并更新快速参数grads = torch.autograd.grad(task_loss, list(fast_parameters.values()), create_graph=True, allow_unused=True)# 更新快速参数for (name, param), grad in zip(fast_parameters.items(), grads):if grad is not None:fast_parameters[name] = param - self.inner_lr * gradreturn fast_parametersdef meta_update(self, tasks_batch: List[List[Tuple]]) -> float:"""外层元更新参数:tasks_batch: 任务批次,每个任务包含一组经验返回:元损失值"""meta_loss = 0for task_experiences in tasks_batch:# 内层适应adapted_params = self.inner_update(task_experiences)# 在适应后的参数上计算元损失task_meta_loss = 0# 使用不同的经验批次计算元损失if len(task_experiences) > 1:meta_batch = random.sample(task_experiences, min(16, len(task_experiences)))else:meta_batch = task_experiencesfor experience in meta_batch:states, actions, rewards, next_states, dones = experience# 使用适应后的参数计算损失action_logits, values = self.model(states, adapted_params)probs = torch.softmax(action_logits, dim=-1)log_probs = torch.log_softmax(action_logits, dim=-1)with torch.no_grad():_, next_values = self.model(next_states, adapted_params)targets = rewards + 0.99 * next_values * (1 - dones)advantages = targets - valuespolicy_loss = -(log_probs.gather(1, actions) * advantages.detach()).mean()value_loss = nn.MSELoss()(values, targets)entropy = -(probs * log_probs).sum(-1).mean()task_meta_loss += policy_loss + 0.5 * value_loss - 0.01 * entropymeta_loss += task_meta_loss / len(meta_batch)# 元优化步骤self.meta_optimizer.zero_grad()meta_loss.backward()torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)self.meta_optimizer.step()return meta_loss.item()class AtariPreprocessor:"""专业的Atari游戏预处理"""def __init__(self, frame_size: Tuple[int, int] = (84, 84)):"""初始化预处理器参数:frame_size: 帧大小 (宽度, 高度)"""self.frame_size = frame_sizeself.frame_buffer = deque(maxlen=4)def preprocess(self, frame: np.ndarray) -> np.ndarray:"""预处理单帧参数:frame: 输入帧 (RGB格式)返回:预处理后的帧 (灰度, 缩放, 归一化)"""# 1. 转换为灰度图gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)# 2. 调整大小resized = cv2.resize(gray, self.frame_size, interpolation=cv2.INTER_AREA)# 3. 归一化normalized = resized.astype(np.float32) / 255.0return normalizeddef reset(self):"""重置预处理器状态"""self.frame_buffer.clear()for _ in range(4):self.frame_buffer.append(np.zeros(self.frame_size, dtype=np.float32))def update(self, frame: np.ndarray):"""更新帧缓冲区参数:frame: 新帧"""processed_frame = self.preprocess(frame)self.frame_buffer.append(processed_frame)def get_state(self) -> np.ndarray:"""获取当前状态 (堆叠的4帧)返回:状态数组 (4, 高度, 宽度)"""return np.stack(self.frame_buffer, axis=0)class EnhancedMetaRLFramework:"""增强的元强化学习框架"""def __init__(self, env_name: str = "Breakout-v4", meta_batch_size: int = 4):"""初始化元强化学习框架参数:env_name: 环境名称meta_batch_size: 元批次大小"""self.env_name = env_nameself.meta_batch_size = meta_batch_size# 创建环境并获取维度信息self.env = self._create_environment()self.state_dim = (4, 84, 84)  # 堆叠的4帧,每帧84x84self.action_dim = self.env.action_space.n# 初始化模型和元学习器self.model = ImprovedActorCritic(self.state_dim, self.action_dim)self.meta_learner = MAMLMetaLearner(self.model)# 初始化预处理器self.preprocessor = AtariPreprocessor((84, 84))# 训练统计self.episode_rewards = []self.meta_losses = []# 设备选择self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.model.to(self.device)print(f"使用设备: {self.device}")def _create_environment(self) -> gym.Env:"""创建环境并进行包装返回:包装后的环境"""try:# 尝试标准环境名称env = gym.make(self.env_name, render_mode="rgb_array")print(f"成功创建环境: {self.env_name}")except gym.error.Error:# 尝试备选环境名称alternative_envs = ["Breakout-v4","BreakoutNoFrameskip-v4","ALE/Breakout-v5","BreakoutDeterministic-v4","Breakout-ram-v4"]for alt_env in alternative_envs:try:env = gym.make(alt_env, render_mode="rgb_array")print(f"成功创建备选环境: {alt_env}")breakexcept gym.error.Error:continueelse:raise RuntimeError("无法创建任何Atari环境")# 应用标准Atari预处理包装器env = gym.wrappers.ResizeObservation(env, (84, 84))env = gym.wrappers.GrayScaleObservation(env)env = gym.wrappers.FrameStack(env, 4)return envdef collect_experience(self, num_episodes: int = 1, render: bool = False) -> List[Tuple]:"""收集训练经验参数:num_episodes: 收集经验的episode数量render: 是否渲染环境返回:经验列表"""experiences = []total_steps = 0for episode in range(num_episodes):episode_experiences = []state, _ = self.env.reset()self.preprocessor.reset()done = Falseepisode_reward = 0step_count = 0while not done:if render:self.env.render()time.sleep(0.02)# 将状态转换为张量state_tensor = torch.FloatTensor(np.array(state)).unsqueeze(0).to(self.device)# 选择动作with torch.no_grad():action_logits, _ = self.model(state_tensor)action_probs = torch.softmax(action_logits, dim=-1)action = torch.multinomial(action_probs, 1).item()# 执行动作next_state, reward, terminated, truncated, _ = self.env.step(action)done = terminated or truncatedepisode_reward += rewardstep_count += 1# 存储经验next_state_tensor = torch.FloatTensor(np.array(next_state)).unsqueeze(0).to(self.device)reward_tensor = torch.FloatTensor([reward]).to(self.device)done_tensor = torch.FloatTensor([float(done)]).to(self.device)action_tensor = torch.LongTensor([action]).to(self.device)experience = (state_tensor, action_tensor, reward_tensor, next_state_tensor, done_tensor)episode_experiences.append(experience)state = next_stateexperiences.extend(episode_experiences)self.episode_rewards.append(episode_reward)total_steps += step_countprint(f"Episode {episode + 1}, 奖励: {episode_reward}, 步数: {step_count}")print(f"收集了 {len(experiences)} 条经验, 总步数: {total_steps}")return experiencesdef train_meta_epoch(self, num_tasks: int = 8) -> float:"""训练一个元epoch参数:num_tasks: 任务数量返回:平均元损失"""# 收集多个任务的经验tasks_experiences = []for task_idx in range(num_tasks):print(f"\n收集任务 {task_idx + 1}/{num_tasks} 的经验...")task_exp = self.collect_experience(num_episodes=1)tasks_experiences.append(task_exp)# 分组进行元批次更新meta_losses = []for i in range(0, num_tasks, self.meta_batch_size):batch_tasks = tasks_experiences[i:i + self.meta_batch_size]if not batch_tasks:continuemeta_loss = self.meta_learner.meta_update(batch_tasks)meta_losses.append(meta_loss)print(f"元批次 {i//self.meta_batch_size + 1}, 损失: {meta_loss:.4f}")avg_meta_loss = np.mean(meta_losses) if meta_losses else 0.0self.meta_losses.append(avg_meta_loss)return avg_meta_lossdef evaluate(self, num_episodes: int = 5, render: bool = False) -> Tuple[float, float]:"""评估当前策略性能参数:num_episodes: 评估的episode数量render: 是否渲染环境返回:平均奖励, 奖励标准差"""total_rewards = []for episode in range(num_episodes):state, _ = self.env.reset()self.preprocessor.reset()done = Falseepisode_reward = 0while not done:if render:self.env.render()time.sleep(0.02)state_tensor = torch.FloatTensor(np.array(state)).unsqueeze(0).to(self.device)with torch.no_grad():action_logits, _ = self.model(state_tensor)action = torch.argmax(action_logits, dim=-1).item()next_state, reward, terminated, truncated, _ = self.env.step(action)done = terminated or truncatedepisode_reward += rewardstate = next_statetotal_rewards.append(episode_reward)print(f"评估 Episode {episode + 1}, 奖励: {episode_reward}")mean_reward = np.mean(total_rewards)std_reward = np.std(total_rewards)return mean_reward, std_rewarddef save_model(self, filepath: str):"""保存模型参数:filepath: 文件路径"""torch.save({'model_state_dict': self.model.state_dict(),'meta_optimizer_state_dict': self.meta_learner.meta_optimizer.state_dict(),'episode_rewards': self.episode_rewards,'meta_losses': self.meta_losses}, filepath)print(f"模型保存至 {filepath}")def load_model(self, filepath: str):"""加载模型参数:filepath: 文件路径"""if not os.path.exists(filepath):print(f"文件不存在: {filepath}")returncheckpoint = torch.load(filepath, map_location=self.device)self.model.load_state_dict(checkpoint['model_state_dict'])self.meta_learner.meta_optimizer.load_state_dict(checkpoint['meta_optimizer_state_dict'])self.episode_rewards = checkpoint.get('episode_rewards', [])self.meta_losses = checkpoint.get('meta_losses', [])print(f"模型从 {filepath} 加载")def main():"""主训练函数"""# 初始化元RL框架print("初始化元强化学习框架...")framework = EnhancedMetaRLFramework("Breakout-v4")# 训练参数num_meta_epochs = 100eval_interval = 10save_interval = 20print("\n开始元强化学习训练...")for epoch in range(num_meta_epochs):print(f"\n=== 元Epoch {epoch + 1}/{num_meta_epochs} ===")# 元训练meta_loss = framework.train_meta_epoch(num_tasks=8)print(f"元损失: {meta_loss:.4f}")# 定期评估if (epoch + 1) % eval_interval == 0:print("\n执行评估...")mean_reward, std_reward = framework.evaluate(num_episodes=3)print(f"Epoch {epoch + 1} 评估结果:")print(f"平均奖励: {mean_reward:.2f} ± {std_reward:.2f}")# 保存检查点if (epoch + 1) % save_interval == 0:checkpoint_path = f"meta_rl_checkpoint_epoch{epoch+1}.pth"framework.save_model(checkpoint_path)# 最终评估print("\n=== 最终评估 ===")mean_reward, std_reward = framework.evaluate(num_episodes=5, render=True)print(f"最终性能: {mean_reward:.2f} ± {std_reward:.2f}")# 保存最终模型framework.save_model("final_meta_rl_model.pth")return frameworkif __name__ == "__main__":try:print("=" * 50)print("元强化学习框架 - Atari游戏训练")print("=" * 50)# 检查依赖print("\n检查依赖...")try:import ale_pyprint("ale-py 已安装")except ImportError:print("警告: ale-py 未安装,某些Atari环境可能无法正常工作")# 运行主训练trained_framework = main()# 演示迁移学习到新任务print("\n=== 迁移学习演示 ===")print("注意: 实际迁移学习需要在新任务上微调模型")print("这里仅演示在新环境上的初始性能")# 创建新环境try:new_env = gym.make("Pong-v4", render_mode="rgb_array")new_env = gym.wrappers.ResizeObservation(new_env, (84, 84))new_env = gym.wrappers.GrayScaleObservation(new_env)new_env = gym.wrappers.FrameStack(new_env, 4)# 在新环境上评估print("\n在Pong-v4上评估迁移性能...")state, _ = new_env.reset()done = Falsetotal_reward = 0while not done:new_env.render()time.sleep(0.02)state_tensor = torch.FloatTensor(np.array(state)).unsqueeze(0).to(trained_framework.device)with torch.no_grad():action_logits, _ = trained_framework.model(state_tensor)action = torch.argmax(action_logits, dim=-1).item()next_state, reward, terminated, truncated, _ = new_env.step(action)done = terminated or truncatedtotal_reward += rewardstate = next_stateprint(f"Pong-v4初始奖励: {total_reward}")except gym.error.Error as e:print(f"无法创建Pong环境: {e}")print("迁移学习演示跳过")except Exception as e:print(f"\n训练过程中出现错误: {str(e)}")print("\n确保已安装所有必要的依赖:")print("pip install gymnasium[accept-rom-license] ale-py opencv-python torch")print("\n如果遇到Atari环境问题,请尝试:")print("1. 接受ROM许可证: python -m ale.import.roms <ROM文件路径>")print("2. 检查环境名称是否正确")

http://www.dtcms.com/a/523993.html

相关文章:

  • 重组蛋白表达的几种类型介绍
  • STM32之TM1638数码管及键盘驱动
  • Windows 10 安装 Docker Desktop
  • 数据的存储
  • GJOI 10.20/10.22 题解
  • Linux:权限(完结)|权限管理|修改权限chmod chown charp|文件类型|拓展
  • (一)仓库创建与配置 - .git 目录的结构与作用
  • Office 2010 64位 补丁 officesp2010-kb2687455 安装步骤详解(附安装包)
  • 建免费网站建设银行网站能不能注销卡
  • springboot中的怎么用JUnit进行测试的?
  • LeetCode:695. 岛屿的最大面积
  • 传奇手游可以使用云手机挂机搬砖吗
  • 2025 OSCAR丨与创新者同频!Apache RocketMQ 邀您共赴开源之约
  • Dify配置本地部署的音频识别模型
  • C# .NET Core中Chart图表绘制与PDF导出
  • 相机拍照的图片怎么做网站呀国内互联网公司排名
  • 微信怎么建设自己网站在单机安装wordpress
  • 实验-Vlan基础
  • Windows CMD 常用命令:7 大核心模块速查指南(附实战场景)
  • OCR国内外证件识别接口调用指南-身份证文字识别
  • 使用acme.sh创建自己的第一个https证书
  • Galera Cluster部署
  • 【Flink实战】升级HDFS,对Flink SQL(On Yarn模式)的功能兼容性验证
  • LangChain 表达式语言:SQL 数据库查询链
  • 通辽网站网站建设网站卖东西怎么做
  • 免费个人网站建设大全有什么建设网站的书籍
  • 电脑控制DFPlayer Mini MP3播放音乐
  • Day10:Python实现Excel自动汇总
  • 网站建设 美食站点网站设计确认函
  • 新买的笔记本电脑为什么风扇声音一直很大?怎样解决?