机器学习训练过程中的回调函数BaseCallback
在Python编程中,特别是在深度学习和强化学习领域,`BaseCallback`通常是一个基类,用于定义回调函数的接口。回调函数是一种在训练过程中被调用的函数,用于执行一些特定的任务,比如记录日志、保存模型、调整学习率等。
from stable_baselines3.common.callbacks import BaseCallbackclass CyberTrainingCallback(BaseCallback):def __init__(self, verbose=0):super(CyberTrainingCallback, self).__init__(verbose)# 初始化一些变量,例如用于记录训练过程中的信息self.best_mean_reward = -float('inf')self.last_mean_reward = -float('inf')self.check_freq = 1000 # 每1000步检查一次self.save_path = None # 保存模型的路径def _on_training_start(self) -> None:"""在训练开始时调用。"""# 可以在这里初始化一些变量或者打印一些信息print("Training is starting!")def _on_step(self) -> bool:"""在每一步训练时调用。"""# 每隔一定步数检查一次if self.n_calls % self.check_freq == 0:# 获取当前的平均奖励current_mean_reward = self.locals['rewards'].mean()print(f"Step {self.n_calls}: Mean reward = {current_mean_reward}")# 如果当前的平均奖励比之前的最好奖励高,则保存模型if current_mean_reward > self.best_mean_reward:self.best_mean_reward = current_mean_rewardif self.save_path is not None:self.model.save(self.save_path)print(f"Model saved to {self.save_path}")return True # 返回True表示训练继续,返回False表示停止训练def _on_training_end(self) -> None:"""在训练结束时调用。"""# 可以在这里执行一些清理工作或者打印一些信息print("Training has ended!")# 使用回调
from stable_baselines3 import PPO# 创建一个模型
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)# 创建回调实例
callback = CyberTrainingCallback(check_freq=1000, save_path='./best_model')# 开始训练
model.learn(total_timesteps=10000, callback=callback)代码解释:
• 定义回调类:
• `CyberTrainingCallback`继承自`BaseCallback`。
• 在`__init__`方法中,初始化了一些变量,例如`best_mean_reward`用于记录最好的平均奖励,`check_freq`用于设置检查频率,`save_path`用于设置保存模型的路径。
• 回调方法:• `_on_training_start`:在训练开始时调用,可以在这里初始化一些变量或者打印一些信息。
• `_on_step`:在每一步训练时调用,可以在这里执行一些检查和操作。例如,每隔`check_freq`步检查一次当前的平均奖励,并在奖励提高时保存模型。
• `_on_training_end`:在训练结束时调用,可以在这里执行一些清理工作或者打印一些信息。
• 使用回调:• 创建一个`PPO`模型。
• 创建`CyberTrainingCallback`的实例,并设置检查频率和保存路径。
• 在调用`model.learn`方法时,将回调实例传递给`callback`参数,这样在训练过程中就会调用回调方法。
二、常用回调函数
以下是一些常用的回调函数及其使用方法:
1.`_on_training_start()`
在训练开始时调用,可以在这里初始化一些变量或者打印一些信息。
class CustomCallback(BaseCallback):def _on_training_start(self) -> None:"""在训练开始时调用。"""print("Training is starting!")# 初始化一些变量self.best_mean_reward = -float('inf')self.save_path = './best_model'2.`_on_rollout_start()`
在每个 rollout(即每个 episode 或者每个 batch 的采样过程)开始时调用。
class CustomCallback(BaseCallback):def _on_rollout_start(self) -> None:"""在每个 rollout 开始时调用。"""print("Rollout is starting!")3.`_on_step()`
在每一步训练时调用,可以在这里执行一些每步的操作,例如记录日志、调整学习率等。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:"""在每一步训练时调用。"""# 每隔一定步数检查一次if self.n_calls % 1000 == 0:print(f"Step {self.n_calls}")return True # 返回True表示训练继续,返回False表示停止训练4.`_on_rollout_end()`
在每个 rollout 结束时调用,可以在这里执行一些在每个 rollout 结束后需要进行的操作,例如保存模型、记录日志等。
class CustomCallback(BaseCallback):def _on_rollout_end(self) -> None:"""在每个 rollout 结束时调用。"""# 获取当前的平均奖励current_mean_reward = self.locals['ep_info_buffer'].get_mean_reward()print(f"Rollout ended. Mean reward: {current_mean_reward}")# 如果当前的平均奖励比之前的最好奖励高,则保存模型if current_mean_reward > self.best_mean_reward:self.best_mean_reward = current_mean_rewardif self.save_path is not None:self.model.save(self.save_path)print(f"Model saved to {self.save_path}")5.`_on_training_end()`
在训练结束时调用,可以在这里执行一些清理工作或者打印一些信息。
class CustomCallback(BaseCallback):def _on_training_end(self) -> None:"""在训练结束时调用。"""print("Training has ended!")6.`CheckPointCallback`
用于在训练过程中定期保存模型。
from stable_baselines3.common.callbacks import CheckpointCallback# 创建一个 CheckpointCallback 实例
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./checkpoints')# 使用回调
model.learn(total_timesteps=10000, callback=checkpoint_callback)7.`EvalCallback`
用于在训练过程中定期评估模型的性能,并根据评估结果保存模型。
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env# 创建一个评估环境
eval_env = make_vec_env('CartPole-v1', n_envs=5)# 创建一个 EvalCallback 实例
eval_callback = EvalCallback(eval_env, best_model_save_path='./best_model', log_path='./eval_logs', eval_freq=1000)# 使用回调
model.learn(total_timesteps=10000, callback=eval_callback)8.`StopTrainingOnRewardThreshold`
用于在训练过程中,当模型的平均奖励达到某个阈值时停止训练。
from stable_baselines3.common.callbacks import StopTrainingOnRewardThreshold# 创建一个 StopTrainingOnRewardThreshold 实例
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)# 使用回调
model.learn(total_timesteps=10000, callback=stop_callback)9.`EveryNTimesteps`
用于在每N个时间步调用另一个回调函数。
from stable_baselines3.common.callbacks import EveryNTimesteps# 创建一个自定义回调
class CustomCallback(BaseCallback):def _on_step(self) -> bool:print(f"Step {self.n_calls}")return True# 创建一个 EveryNTimesteps 实例
callback = EveryNTimesteps(n_steps=1000, callback=CustomCallback())# 使用回调
model.learn(total_timesteps=10000, callback=callback)10.`CallbackList`
用于将多个回调函数组合在一起,同时使用多个回调。
from stable_baselines3.common.callbacks import CallbackList# 创建多个回调
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./checkpoints')
eval_callback = EvalCallback(eval_env, best_model_save_path='./best_model', log_path='./eval_logs', eval_freq=1000)
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)# 将多个回调组合在一起
callback = CallbackList([checkpoint_callback, eval_callback, stop_callback])# 使用回调
model.learn(total_timesteps=10000, callback=callback)