
Reinforcement Learning with PPO: Classification Tasks

I. Definition

  1. Notes
  2. Example

II. Implementation

  1. Notes
    Reinforcement learning requires building an environment that fits your own business scenario, which the agent then learns from by interaction.
    1. The environment must implement its own reset() and step() methods. A classification episode needs only a single step, so reset() must not restart from sample id 0 every time; instead, an episode counter is incremented on each reset, which guarantees the environment cycles through all of the data.
    2. In step(), because every episode lasts exactly one step, the termination flag must be set: terminated = True. Note that after a terminated step, Stable-Baselines3's vectorized wrapper calls reset() automatically, so step() should not call reset() itself, or every other sample would be skipped. A short smoke test follows the class sketch below.
class ImprovedClassificationEnv(gym.Env):
    """Improved classification environment"""
    metadata = {'render.modes': ['human']}

    def __init__(self, X, y):
        super(ImprovedClassificationEnv, self).__init__()
        ...

    def _get_obs(self):
        """Build the augmented observation: features + normalized index + class prior."""
        sample_features = self.X[self.current_sample_idx].astype(np.float32)
        extra_info = np.array([
            self.current_sample_idx / self.num_samples,         # normalized sample index
            np.mean(self.y == self.y[self.current_sample_idx])  # fraction of same-class samples
        ], dtype=np.float32)
        return np.concatenate([sample_features, extra_info])

    def reset(self, seed=None, options=None):
        """Reset the environment state"""
        super().reset(seed=seed)
        # Cycle through all samples: advance the counter on every reset
        self.current_sample_idx = self.current_episode % self.num_samples
        self.current_episode += 1
        # Allow pinning a specific sample (used during evaluation)
        if options and 'sample_idx' in options:
            self.current_sample_idx = options['sample_idx']
        return self._get_obs(), {}

    def step(self, action):
        """Execute an action (classify the current sample)"""
        true_label = self.y[self.current_sample_idx]
        # Shaped reward function
        if action == true_label:
            reward = 2.0   # correct classification: fixed base reward
        else:
            reward = -1.0  # wrong classification: base penalty
            # Heavier penalty for larger errors (e.g. predicting class 2 for class 0)
            if abs(action - true_label) > 1:
                reward -= 0.5
        # Exploration bonus: +0.1 the first time a class is tried (the set stores
        # actions, so with only a few classes this size gate is almost always open)
        if len(self.visited_samples) < self.num_samples * 0.1:
            if action not in self.visited_samples:
                reward += 0.1
                self.visited_samples.add(action)
        # Always terminate: each sample gets exactly one classification decision
        terminated = True
        truncated = False
        info = {
            'true_label': true_label,
            'predicted_label': action,
            'correct': action == true_label,
            'sample_idx': self.current_sample_idx
        }
        # Return the current observation as the terminal observation. Do not call
        # self.reset() here: the vectorized wrapper auto-resets after termination,
        # and an extra reset would skip every other sample.
        return self._get_obs(), reward, terminated, truncated, info

    def render(self, mode='human'):
        if mode == 'human':
            print(f"Sample {self.current_sample_idx}: true label={self.y[self.current_sample_idx]}")
        return None
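As a quick sanity check, the snippet below (a minimal sketch, assuming the full class definition from the case study is in scope) runs a few manual reset/step cycles and confirms that each episode terminates after a single step while the sample index keeps advancing:

import numpy as np

# Hypothetical toy data just for the smoke test: 20 samples, 10 features, 3 classes.
X_demo = np.random.randn(20, 10).astype(np.float32)
y_demo = np.random.randint(0, 3, size=20)
env = ImprovedClassificationEnv(X_demo, y_demo)

for _ in range(3):
    obs, _ = env.reset()                # advances to the next sample
    action = env.action_space.sample()  # random class guess
    obs, reward, terminated, truncated, info = env.step(action)
    print(info['sample_idx'], info['correct'], reward, terminated)  # terminated is always True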
  2. Example
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import deque
import torch
import torch.nn as nn

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# 1. Generate synthetic data
def generate_classification_data(n_samples=10000, n_features=10, n_classes=3):
    """Generate a classification dataset"""
    print("Generating classification data...")
    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=8,
        n_redundant=2,
        n_classes=n_classes,
        n_clusters_per_class=1,
        random_state=42
    )
    # Standardize the features
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    # Stratified train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    print(f"Train shape: {X_train.shape}, test shape: {X_test.shape}")
    print(f"Class distribution - train: {np.bincount(y_train)}, test: {np.bincount(y_test)}")
    return X_train, X_test, y_train, y_test, scaler

# 2. Improved classification environment
class ImprovedClassificationEnv(gym.Env):
    """Improved classification environment"""
    metadata = {'render.modes': ['human']}

    def __init__(self, X, y):
        super(ImprovedClassificationEnv, self).__init__()
        self.X = X
        self.y = y
        self.num_samples, self.num_features = X.shape
        self.num_classes = len(np.unique(y))
        self.current_episode = 0
        # Action space: pick a class
        self.action_space = spaces.Discrete(self.num_classes)
        # Observation space: current sample features + extra information
        self.observation_space = spaces.Box(
            low=-5.0,
            high=5.0,
            shape=(self.num_features + 2,),  # two extra entries
            dtype=np.float32
        )
        self.current_sample_idx = None
        self.visited_samples = set()

    def _get_obs(self):
        """Build the augmented observation: features + normalized index + class prior."""
        sample_features = self.X[self.current_sample_idx].astype(np.float32)
        extra_info = np.array([
            self.current_sample_idx / self.num_samples,         # normalized sample index
            np.mean(self.y == self.y[self.current_sample_idx])  # fraction of same-class samples
        ], dtype=np.float32)
        return np.concatenate([sample_features, extra_info])

    def reset(self, seed=None, options=None):
        """Reset the environment state"""
        super().reset(seed=seed)
        # Cycle through all samples: advance the counter on every reset
        self.current_sample_idx = self.current_episode % self.num_samples
        self.current_episode += 1
        # Allow pinning a specific sample (used during evaluation)
        if options and 'sample_idx' in options:
            self.current_sample_idx = options['sample_idx']
        return self._get_obs(), {}

    def step(self, action):
        """Execute an action (classify the current sample)"""
        true_label = self.y[self.current_sample_idx]
        # Shaped reward function
        if action == true_label:
            reward = 2.0   # correct classification: fixed base reward
        else:
            reward = -1.0  # wrong classification: base penalty
            # Heavier penalty for larger errors (e.g. predicting class 2 for class 0)
            if abs(action - true_label) > 1:
                reward -= 0.5
        # Exploration bonus: +0.1 the first time a class is tried (the set stores
        # actions, so with only a few classes this size gate is almost always open)
        if len(self.visited_samples) < self.num_samples * 0.1:
            if action not in self.visited_samples:
                reward += 0.1
                self.visited_samples.add(action)
        # Always terminate: each sample gets exactly one classification decision
        terminated = True
        truncated = False
        info = {
            'true_label': true_label,
            'predicted_label': action,
            'correct': action == true_label,
            'sample_idx': self.current_sample_idx
        }
        # Return the current observation as the terminal observation. Do not call
        # self.reset() here: the vectorized wrapper auto-resets after termination,
        # and an extra reset would skip every other sample.
        return self._get_obs(), reward, terminated, truncated, info

    def render(self, mode='human'):
        if mode == 'human':
            print(f"Sample {self.current_sample_idx}: true label={self.y[self.current_sample_idx]}")
        return None
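Before wiring the environment into PPO, it can be worth validating it against the Gymnasium API. This is an optional sketch, not part of the original script; it uses Stable-Baselines3's built-in check_env utility on the training split returned by generate_classification_data():

from stable_baselines3.common.env_checker import check_env

# Optional sanity check: verifies the spaces, reset()/step() signatures and dtypes,
# and warns or raises if the custom environment violates the Gymnasium API.
# Note: with standardized features, a rare value outside the [-5, 5] Box bounds
# would be flagged here.
check_env(ImprovedClassificationEnv(X_train, y_train), warn=True)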
# 3. Improved evaluation function
def enhanced_evaluate_model(model, X_test, y_test, env_class):
    """Enhanced model evaluation"""
    print("\nEvaluating model performance...")
    eval_env = env_class(X_test, y_test)

    # Average episode reward across the whole test set
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=len(X_test))
    print(f"Mean reward: {mean_reward:.4f} +/- {std_reward:.4f}")

    # Accuracy and related metrics: pin each test sample via reset(options=...)
    all_predictions = []
    all_true_labels = []
    for i in range(len(X_test)):
        obs, _ = eval_env.reset(options={'sample_idx': i})
        action, _ = model.predict(obs, deterministic=True)
        all_predictions.append(int(action))
        all_true_labels.append(y_test[i])

    accuracy = accuracy_score(all_true_labels, all_predictions)
    print(f"Classification accuracy: {accuracy:.4f}")

    # Detailed per-class report
    print("\nDetailed classification report:")
    print(classification_report(all_true_labels, all_predictions))

    # Confusion matrix plot (disabled):
    # cm = confusion_matrix(all_true_labels, all_predictions)
    # plt.figure(figsize=(8, 6))
    # sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    # plt.title('Confusion Matrix')
    # plt.ylabel('True Label')
    # plt.xlabel('Predicted Label')
    # plt.tight_layout()
    # plt.savefig('confusion_matrix.png')
    # plt.show()

    return accuracy, mean_reward
# 4. Improved training callback
class ImprovedTrainingCallback(BaseCallback):
    """Improved training callback"""

    def __init__(self, eval_env, check_freq=1000, verbose=0):
        super(ImprovedTrainingCallback, self).__init__(verbose)
        self.eval_env = eval_env
        self.check_freq = check_freq
        self.best_accuracy = 0
        self.accuracies = []
        self.rewards = []

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Periodically evaluate the current model on the held-out set
            current_accuracy, mean_reward = enhanced_evaluate_model(
                self.model,
                self.eval_env.X,
                self.eval_env.y,
                ImprovedClassificationEnv
            )
            self.accuracies.append(current_accuracy)
            self.rewards.append(mean_reward)
            print(f"Timestep: {self.n_calls}")
            print(f"Accuracy: {current_accuracy:.4f}, mean reward: {mean_reward:.4f}")
            # Keep the best model so far
            if current_accuracy > self.best_accuracy:
                self.best_accuracy = current_accuracy
                print(f"New best model! Accuracy: {current_accuracy:.4f}")
                self.model.save("best_classifier_model")
        return True
# 5. Main function
def main():
    # Generate the data
    X_train, X_test, y_train, y_test, scaler = generate_classification_data()

    # Create the vectorized training environment
    env = make_vec_env(
        lambda: ImprovedClassificationEnv(X_train, y_train),
        n_envs=4,
        seed=42
    )

    # Create the PPO model
    model = PPO(
        "MlpPolicy",
        env,
        verbose=1,
        learning_rate=1e-4,  # smaller learning rate
        n_steps=1024,        # more steps per rollout
        batch_size=256,      # larger batch size
        n_epochs=20,         # more epochs per update
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.1,      # tighter clip range
        ent_coef=0.02,       # moderate entropy coefficient
        vf_coef=0.5,         # value-function coefficient
        max_grad_norm=0.5,   # gradient clipping
        tensorboard_log="./tensorboard_logs/",
    )

    # Create the evaluation environment and callback
    eval_env = ImprovedClassificationEnv(X_test, y_test)
    callback = ImprovedTrainingCallback(eval_env, check_freq=5000)

    # Baseline: evaluate the untrained model before training
    accuracy, mean_reward = enhanced_evaluate_model(model, X_test, y_test, ImprovedClassificationEnv)

    # Train the model
    print("\nStarting training...")
    model.learn(
        total_timesteps=200000,  # more training steps
        callback=callback,
        progress_bar=True,
        tb_log_name="ppo_classification"
    )

    # Save the final model
    model.save("ppo_classifier_final")
    print("Training finished; model saved")

    # Evaluate the trained model
    accuracy, mean_reward = enhanced_evaluate_model(model, X_test, y_test, ImprovedClassificationEnv)

    # Compare against classic supervised baselines
    from sklearn.linear_model import LogisticRegression
    from sklearn.ensemble import RandomForestClassifier

    print("\nComparison with classic supervised methods:")
    lr_model = LogisticRegression(random_state=42, max_iter=1000)
    lr_model.fit(X_train, y_train)
    lr_pred = lr_model.predict(X_test)
    lr_accuracy = accuracy_score(y_test, lr_pred)
    print(f"Logistic regression accuracy: {lr_accuracy:.4f}")

    rf_model = RandomForestClassifier(random_state=42, n_estimators=100)
    rf_model.fit(X_train, y_train)
    rf_pred = rf_model.predict(X_test)
    rf_accuracy = accuracy_score(y_test, rf_pred)
    print(f"Random forest accuracy: {rf_accuracy:.4f}")

    # Plot the performance comparison
    methods = ['PPO RL', 'Logistic Regression', 'Random Forest']
    accuracies = [accuracy, lr_accuracy, rf_accuracy]
    plt.figure(figsize=(10, 6))
    bars = plt.bar(methods, accuracies, color=['blue', 'green', 'orange'])
    plt.ylabel('Accuracy')
    plt.title('Classification Performance Comparison')
    plt.ylim(0, 1)
    for bar, acc in zip(bars, accuracies):
        plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                 f'{acc:.4f}', ha='center', va='bottom')
    plt.tight_layout()
    plt.savefig('performance_comparison.png')
    plt.show()

    # Summarize the results
    print("\nResult summary:")
    print(f"PPO RL accuracy: {accuracy:.4f}")
    print(f"Logistic regression accuracy: {lr_accuracy:.4f}")
    print(f"Random forest accuracy: {rf_accuracy:.4f}")

if __name__ == "__main__":
    main()
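After training, the saved policy can be reused for plain inference without an environment. Below is a minimal sketch, assuming the scaler returned by generate_classification_data() is still in scope; the two extra state entries the environment appends are filled with zeros as a neutral placeholder, since the true sample index and class prior are unknown for new data:

import numpy as np
from stable_baselines3 import PPO

# Load the final saved policy (PPO.load is the standard SB3 API).
model = PPO.load("ppo_classifier_final")

# Classify one hypothetical new sample: scale it with the training scaler,
# then pad with zeros for the two extra observation entries.
x_new = scaler.transform(np.random.randn(1, 10))[0].astype(np.float32)
obs = np.concatenate([x_new, np.zeros(2, dtype=np.float32)])
action, _ = model.predict(obs, deterministic=True)
print(f"Predicted class: {int(action)}")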
