I. Definition
- Notes
- Example
II. Implementation
- Notes
Reinforcement learning requires building an environment that matches your own business scenario, so that the agent has something to interact with. Two points deserve attention (a minimal sketch follows the list):
1. The environment must implement its own reset() and step() functions. Because a classification episode needs only a single step, reset() must not restart the sample index from zero each time; instead, the index advances by one on every reset, which guarantees the agent eventually sees all of the data.
2. Note that in step(), since each episode lasts exactly one step, step() must return terminated = True.
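A minimal sketch of just these two points, stripped down from the full environment in the example below (it assumes X and y are NumPy arrays):

import gymnasium as gym
import numpy as np
from gymnasium import spaces

class OneStepClassificationEnv(gym.Env):
    """Sketch: one classification decision per episode."""

    def __init__(self, X, y):
        super().__init__()
        self.X, self.y = X, y
        self.num_samples = len(X)
        self.episode = 0  # monotonically increasing episode counter
        self.action_space = spaces.Discrete(len(np.unique(y)))
        self.observation_space = spaces.Box(
            -np.inf, np.inf, shape=(X.shape[1],), dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # Point 1: cycle through the data instead of restarting at index 0
        self.idx = self.episode % self.num_samples
        self.episode += 1
        return self.X[self.idx].astype(np.float32), {}

    def step(self, action):
        reward = 1.0 if action == self.y[self.idx] else -1.0
        # Point 2: a single decision per episode, so terminate immediately
        return self.X[self.idx].astype(np.float32), reward, True, False, {}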
- Example
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import BaseCallback
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
import torch
np.random.seed(42)
torch.manual_seed(42)
def generate_classification_data(n_samples=10000, n_features=10, n_classes=3):
    """Generate a synthetic classification dataset."""
    print("Generating classification data...")
    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=8,
        n_redundant=2,
        n_classes=n_classes,
        n_clusters_per_class=1,
        random_state=42,
    )
    # Standardize features so the Box observation bounds below are reasonable
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    print(f"Train shape: {X_train.shape}, test shape: {X_test.shape}")
    print(f"Class distribution - train: {np.bincount(y_train)}, test: {np.bincount(y_test)}")
    return X_train, X_test, y_train, y_test, scaler
class ImprovedClassificationEnv(gym.Env):
    """Classification environment: one sample per episode, one decision per step."""
    metadata = {'render.modes': ['human']}

    def __init__(self, X, y):
        super().__init__()
        self.X = X
        self.y = y
        self.num_samples, self.num_features = X.shape
        self.num_classes = len(np.unique(y))
        self.current_episode = 0
        self.action_space = spaces.Discrete(self.num_classes)
        # Features are standardized, so [-5, 5] covers virtually all values
        self.observation_space = spaces.Box(
            low=-5.0, high=5.0,
            shape=(self.num_features + 2,), dtype=np.float32)
        self.current_sample_idx = None
        self.visited_samples = set()  # tracks actions tried so far (see step)

    def _get_obs(self):
        """Build the observation for the current sample."""
        sample_features = self.X[self.current_sample_idx].astype(np.float32)
        extra_info = np.array([
            self.current_sample_idx / self.num_samples,          # progress through the dataset
            np.mean(self.y == self.y[self.current_sample_idx]),  # frequency of this sample's class
        ], dtype=np.float32)                                     # (constant when classes are balanced)
        return np.concatenate([sample_features, extra_info])

    def reset(self, seed=None, options=None):
        """Reset the environment: advance to the next sample (point 1)."""
        super().reset(seed=seed)
        self.current_sample_idx = self.current_episode % self.num_samples
        self.current_episode += 1
        if options and 'sample_idx' in options:
            self.current_sample_idx = options['sample_idx']
        return self._get_obs(), {}

    def step(self, action):
        """Take an action (classify the current sample)."""
        true_label = self.y[self.current_sample_idx]
        reward = 2.0 if action == true_label else -1.0
        # Extra penalty when the predicted class index is far from the true one
        if abs(action - true_label) > 1:
            reward -= 0.5
        # Small one-time exploration bonus the first time each action is tried
        if len(self.visited_samples) < self.num_samples * 0.1:
            if action not in self.visited_samples:
                reward += 0.1
                self.visited_samples.add(action)
        terminated = True  # point 2: every episode ends after this single step
        truncated = False
        info = {
            'true_label': true_label,
            'predicted_label': action,
            'correct': action == true_label,
            'sample_idx': self.current_sample_idx,
        }
        # Return the final observation and let the wrapper call reset() itself;
        # calling self.reset() here would make auto-resetting vectorized envs
        # advance the sample index twice and skip every other sample.
        return self._get_obs(), reward, terminated, truncated, info

    def render(self, mode='human'):
        if mode == 'human':
            print(f"Sample {self.current_sample_idx}: true label={self.y[self.current_sample_idx]}")
        return None
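# Optional sanity check while developing the environment (a suggestion, not part
# of the original script): Stable-Baselines3 ships an env checker that validates
# the reset()/step() API against the Gymnasium spec. Uncomment to run once:
# from stable_baselines3.common.env_checker import check_env
# check_env(ImprovedClassificationEnv(np.random.randn(8, 10).astype(np.float32),
#                                     np.random.randint(0, 3, size=8)))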
def enhanced_evaluate_model(model, X_test, y_test, env_class):
    """Evaluate the policy: mean reward plus standard classification metrics."""
    print("\nEvaluating model performance...")
    eval_env = env_class(X_test, y_test)
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=len(X_test))
    print(f"Mean reward: {mean_reward:.4f} +/- {std_reward:.4f}")
    all_predictions = []
    all_true_labels = []
    # Force the environment onto each test sample in turn via reset(options=...)
    for i in range(len(X_test)):
        obs, _ = eval_env.reset(options={'sample_idx': i})
        action, _ = model.predict(obs, deterministic=True)
        all_predictions.append(int(action))
        all_true_labels.append(y_test[i])
    accuracy = accuracy_score(all_true_labels, all_predictions)
    print(f"Classification accuracy: {accuracy:.4f}")
    print("\nDetailed classification report:")
    print(classification_report(all_true_labels, all_predictions))
    return accuracy, mean_reward
class ImprovedTrainingCallback(BaseCallback):
    """Periodically evaluate on the held-out set and keep the best checkpoint."""

    def __init__(self, eval_env, check_freq=1000, verbose=0):
        super().__init__(verbose)
        self.eval_env = eval_env
        self.check_freq = check_freq
        self.best_accuracy = 0
        self.accuracies = []
        self.rewards = []

    def _on_step(self) -> bool:
        # n_calls counts vectorized steps, so with n_envs=4 each call covers 4 env steps
        if self.n_calls % self.check_freq == 0:
            current_accuracy, mean_reward = enhanced_evaluate_model(
                self.model, self.eval_env.X, self.eval_env.y,
                ImprovedClassificationEnv)
            self.accuracies.append(current_accuracy)
            self.rewards.append(mean_reward)
            print(f"Timestep: {self.n_calls}")
            print(f"Accuracy: {current_accuracy:.4f}, mean reward: {mean_reward:.4f}")
            if current_accuracy > self.best_accuracy:
                self.best_accuracy = current_accuracy
                print(f"New best model! Accuracy: {current_accuracy:.4f}")
                self.model.save("best_classifier_model")
        return True
def main():
    X_train, X_test, y_train, y_test, scaler = generate_classification_data()
    env = make_vec_env(
        lambda: ImprovedClassificationEnv(X_train, y_train),
        n_envs=4,
        seed=42,
    )
    model = PPO(
        "MlpPolicy",
        env,
        verbose=1,
        learning_rate=1e-4,
        n_steps=1024,
        batch_size=256,
        n_epochs=20,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.1,
        ent_coef=0.02,
        vf_coef=0.5,
        max_grad_norm=0.5,
        tensorboard_log="./tensorboard_logs/",
    )
    eval_env = ImprovedClassificationEnv(X_test, y_test)
    callback = ImprovedTrainingCallback(eval_env, check_freq=5000)

    # Baseline evaluation before any training
    accuracy, mean_reward = enhanced_evaluate_model(model, X_test, y_test, ImprovedClassificationEnv)

    print("\nStarting training...")
    model.learn(
        total_timesteps=200000,
        callback=callback,
        progress_bar=True,
        tb_log_name="ppo_classification",
    )
    model.save("ppo_classifier_final")
    print("Training finished, model saved")

    accuracy, mean_reward = enhanced_evaluate_model(model, X_test, y_test, ImprovedClassificationEnv)

    # Compare against standard supervised baselines
    from sklearn.linear_model import LogisticRegression
    from sklearn.ensemble import RandomForestClassifier
    print("\nComparison with traditional supervised methods:")
    lr_model = LogisticRegression(random_state=42, max_iter=1000)
    lr_model.fit(X_train, y_train)
    lr_pred = lr_model.predict(X_test)
    lr_accuracy = accuracy_score(y_test, lr_pred)
    print(f"Logistic regression accuracy: {lr_accuracy:.4f}")
    rf_model = RandomForestClassifier(random_state=42, n_estimators=100)
    rf_model.fit(X_train, y_train)
    rf_pred = rf_model.predict(X_test)
    rf_accuracy = accuracy_score(y_test, rf_pred)
    print(f"Random forest accuracy: {rf_accuracy:.4f}")

    methods = ['PPO RL', 'Logistic Regression', 'Random Forest']
    accuracies = [accuracy, lr_accuracy, rf_accuracy]
    plt.figure(figsize=(10, 6))
    bars = plt.bar(methods, accuracies, color=['blue', 'green', 'orange'])
    plt.ylabel('Accuracy')
    plt.title('Classification Performance Comparison')
    plt.ylim(0, 1)
    for bar, acc in zip(bars, accuracies):
        plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                 f'{acc:.4f}', ha='center', va='bottom')
    plt.tight_layout()
    plt.savefig('performance_comparison.png')
    plt.show()

    print("\nResult analysis:")
    print(f"PPO RL accuracy: {accuracy:.4f}")
    print(f"Logistic regression accuracy: {lr_accuracy:.4f}")
    print(f"Random forest accuracy: {rf_accuracy:.4f}")

if __name__ == "__main__":
    main()
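If training is interrupted, the best checkpoint written by the callback can be evaluated on its own. A minimal sketch, assuming the script above runs in the same session (or is importable) and that the best_classifier_model.zip file saved by ImprovedTrainingCallback exists on disk:

from stable_baselines3 import PPO

# Load the checkpoint saved by ImprovedTrainingCallback (assumes the file exists)
model = PPO.load("best_classifier_model")

# Regenerate the same deterministic split (random_state=42 in the helper above)
X_train, X_test, y_train, y_test, scaler = generate_classification_data()

accuracy, mean_reward = enhanced_evaluate_model(
    model, X_test, y_test, ImprovedClassificationEnv)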