当前位置: 首页 > news >正文

机器学习训练过程中的回调函数BaseCallback

在Python编程中,特别是在深度学习和强化学习领域,`BaseCallback`通常是一个基类,用于定义回调函数的接口。回调函数是一种在训练过程中被调用的函数,用于执行一些特定的任务,比如记录日志、保存模型、调整学习率等。

from stable_baselines3.common.callbacks import BaseCallbackclass CyberTrainingCallback(BaseCallback):def __init__(self, verbose=0):super(CyberTrainingCallback, self).__init__(verbose)# 初始化一些变量,例如用于记录训练过程中的信息self.best_mean_reward = -float('inf')self.last_mean_reward = -float('inf')self.check_freq = 1000  # 每1000步检查一次self.save_path = None  # 保存模型的路径def _on_training_start(self) -> None:"""在训练开始时调用。"""# 可以在这里初始化一些变量或者打印一些信息print("Training is starting!")def _on_step(self) -> bool:"""在每一步训练时调用。"""# 每隔一定步数检查一次if self.n_calls % self.check_freq == 0:# 获取当前的平均奖励current_mean_reward = self.locals['rewards'].mean()print(f"Step {self.n_calls}: Mean reward = {current_mean_reward}")# 如果当前的平均奖励比之前的最好奖励高,则保存模型if current_mean_reward > self.best_mean_reward:self.best_mean_reward = current_mean_rewardif self.save_path is not None:self.model.save(self.save_path)print(f"Model saved to {self.save_path}")return True  # 返回True表示训练继续,返回False表示停止训练def _on_training_end(self) -> None:"""在训练结束时调用。"""# 可以在这里执行一些清理工作或者打印一些信息print("Training has ended!")# 使用回调
from stable_baselines3 import PPO# 创建一个模型
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)# 创建回调实例
callback = CyberTrainingCallback(check_freq=1000, save_path='./best_model')# 开始训练
model.learn(total_timesteps=10000, callback=callback)

代码解释:

• 定义回调类:

• `CyberTrainingCallback`继承自`BaseCallback`。

• 在`__init__`方法中,初始化了一些变量,例如`best_mean_reward`用于记录最好的平均奖励,`check_freq`用于设置检查频率,`save_path`用于设置保存模型的路径。


• 回调方法:

• `_on_training_start`:在训练开始时调用,可以在这里初始化一些变量或者打印一些信息。

• `_on_step`:在每一步训练时调用,可以在这里执行一些检查和操作。例如,每隔`check_freq`步检查一次当前的平均奖励,并在奖励提高时保存模型。

• `_on_training_end`:在训练结束时调用,可以在这里执行一些清理工作或者打印一些信息。


• 使用回调:

• 创建一个`PPO`模型。

• 创建`CyberTrainingCallback`的实例,并设置检查频率和保存路径。

• 在调用`model.learn`方法时,将回调实例传递给`callback`参数,这样在训练过程中就会调用回调方法。

二、常用回调函数

以下是一些常用的回调函数及其使用方法:

1.`_on_training_start()`
在训练开始时调用,可以在这里初始化一些变量或者打印一些信息。

class CustomCallback(BaseCallback):def _on_training_start(self) -> None:"""在训练开始时调用。"""print("Training is starting!")# 初始化一些变量self.best_mean_reward = -float('inf')self.save_path = './best_model'

2.`_on_rollout_start()`
在每个 rollout(即每个 episode 或者每个 batch 的采样过程)开始时调用。

class CustomCallback(BaseCallback):def _on_rollout_start(self) -> None:"""在每个 rollout 开始时调用。"""print("Rollout is starting!")

3.`_on_step()`
在每一步训练时调用,可以在这里执行一些每步的操作,例如记录日志、调整学习率等。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:"""在每一步训练时调用。"""# 每隔一定步数检查一次if self.n_calls % 1000 == 0:print(f"Step {self.n_calls}")return True  # 返回True表示训练继续,返回False表示停止训练

4.`_on_rollout_end()`
在每个 rollout 结束时调用,可以在这里执行一些在每个 rollout 结束后需要进行的操作,例如保存模型、记录日志等。

class CustomCallback(BaseCallback):def _on_rollout_end(self) -> None:"""在每个 rollout 结束时调用。"""# 获取当前的平均奖励current_mean_reward = self.locals['ep_info_buffer'].get_mean_reward()print(f"Rollout ended. Mean reward: {current_mean_reward}")# 如果当前的平均奖励比之前的最好奖励高,则保存模型if current_mean_reward > self.best_mean_reward:self.best_mean_reward = current_mean_rewardif self.save_path is not None:self.model.save(self.save_path)print(f"Model saved to {self.save_path}")

5.`_on_training_end()`
在训练结束时调用,可以在这里执行一些清理工作或者打印一些信息。

class CustomCallback(BaseCallback):def _on_training_end(self) -> None:"""在训练结束时调用。"""print("Training has ended!")

6.`CheckPointCallback`
用于在训练过程中定期保存模型。

from stable_baselines3.common.callbacks import CheckpointCallback# 创建一个 CheckpointCallback 实例
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./checkpoints')# 使用回调
model.learn(total_timesteps=10000, callback=checkpoint_callback)

7.`EvalCallback`
用于在训练过程中定期评估模型的性能,并根据评估结果保存模型。

from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env# 创建一个评估环境
eval_env = make_vec_env('CartPole-v1', n_envs=5)# 创建一个 EvalCallback 实例
eval_callback = EvalCallback(eval_env, best_model_save_path='./best_model', log_path='./eval_logs', eval_freq=1000)# 使用回调
model.learn(total_timesteps=10000, callback=eval_callback)

8.`StopTrainingOnRewardThreshold`
用于在训练过程中,当模型的平均奖励达到某个阈值时停止训练。

from stable_baselines3.common.callbacks import StopTrainingOnRewardThreshold# 创建一个 StopTrainingOnRewardThreshold 实例
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)# 使用回调
model.learn(total_timesteps=10000, callback=stop_callback)

9.`EveryNTimesteps`
用于在每N个时间步调用另一个回调函数。

from stable_baselines3.common.callbacks import EveryNTimesteps# 创建一个自定义回调
class CustomCallback(BaseCallback):def _on_step(self) -> bool:print(f"Step {self.n_calls}")return True# 创建一个 EveryNTimesteps 实例
callback = EveryNTimesteps(n_steps=1000, callback=CustomCallback())# 使用回调
model.learn(total_timesteps=10000, callback=callback)

10.`CallbackList`
用于将多个回调函数组合在一起,同时使用多个回调。

from stable_baselines3.common.callbacks import CallbackList# 创建多个回调
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./checkpoints')
eval_callback = EvalCallback(eval_env, best_model_save_path='./best_model', log_path='./eval_logs', eval_freq=1000)
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)# 将多个回调组合在一起
callback = CallbackList([checkpoint_callback, eval_callback, stop_callback])# 使用回调
model.learn(total_timesteps=10000, callback=callback)

http://www.dtcms.com/a/578028.html

相关文章:

  • Cordys CRM正式开源,AI驱动客户关系管理加速演进
  • 河北省 建设执业注册中心网站长沙 汽车 网站建设
  • 手机如何定位:从时间差到地图上的“小蓝点”
  • Rust : Send、Sync与现实世界的映射
  • PHP推荐权重算法以及分页
  • 做软件赚钱的网站有哪些淘宝客seo推广教程
  • 企业网站制作建设建设通app官方下载
  • 【FAQ】HarmonyOS SDK 闭源开放能力 — Form Kit
  • 学习:JavaScript(8)
  • Docker的host网络模式
  • HORIBA 新型便携式废气测量系统技术解析
  • 建设自有网站需要什么杭州网站建设设计公司哪家好
  • 常州网站建设方案维护小皮搭建本地网站
  • 静态路由-等价路由、浮动路由配置
  • 37-38 for循环
  • 【SSM 框架 | day27 spring MVC】
  • H618-配置静态IP
  • 全面解析网站建设及报价高端网约车有哪些平台
  • 商城 静态网站模板wordpress 作品集插件
  • 论文分享 |用线性复杂度实现Transformer级性能的递归网络新范式
  • 12_FastMCP 2.x 中文文档之FastMCP高级功能:图标详解
  • 打工人日报#20251106
  • 在Windows上通过WSL体验openEuler:打造高效的AI开发环境
  • ERP和WMS系统有什么区别吗?ERP系统能代替WMS仓储管理系统吗?
  • 我在造一个编程语言,叫 Free
  • 石家庄做网站那家好做推广的公司一般都叫什么
  • 论文分享 | AlexNet:点燃深度学习革命的“一把火”
  • 拉普拉斯算子及散度
  • 前端FAQ: 如何使⽤Web Workers来提⾼⻚⾯性能?
  • 怎么建设淘客自己的网站_品牌形象网站建设