当前位置: 首页 > news >正文

机器学习训练过程中回调函数常用的一些属性

在`stable-baselines3`中,回调函数(Callbacks)提供了许多有用的属性,这些属性可以帮助你在训练过程中访问和操作模型、环境以及训练状态。以下是一些常用的回调函数属性及其使用方法:

1.`self.model`

• 描述:当前训练的模型对象。

• 用途:可以调用模型的方法,例如保存模型、获取模型参数等。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:# 保存模型self.model.save("model_step_{}".format(self.n_calls))return True

2.`self.n_calls`

• 描述:回调函数被调用的次数。

• 用途:可以用来记录训练的进度,例如每1000步保存一次模型。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.n_calls % 1000 == 0:print(f"Step {self.n_calls}")return True

3.`self.num_timesteps`

• 描述:当前训练的总时间步数。

• 用途:可以用来记录训练的总进度,例如每10000步保存一次模型。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.num_timesteps % 10000 == 0:print(f"Total timesteps: {self.num_timesteps}")return True

4.`self.locals`

• 描述:一个字典,包含了当前训练过程中的局部变量。

• 用途:可以访问和操作训练过程中的各种变量,例如奖励、损失等。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:# 获取当前的奖励current_reward = self.locals.get('rewards', 0)print(f"Current reward: {current_reward}")return True

5.`self.globals`

• 描述:一个字典,包含了当前训练过程中的全局变量。

• 用途:可以访问和操作训练过程中的全局变量。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:# 获取全局变量global_var = self.globals.get('some_global_var', None)print(f"Global variable: {global_var}")return True

6.`self.logger`

• 描述:日志记录器对象,用于记录训练过程中的日志信息。

• 用途:可以记录日志,例如训练进度、奖励等。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:self.logger.record("train/reward", self.locals.get('rewards', 0))return True

7.`self.parent`

• 描述:父回调对象,如果有嵌套回调时使用。

• 用途:可以访问和操作父回调对象的属性和方法。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.parent:print(f"Parent callback: {self.parent}")return True

8.`self.verbose`

• 描述:日志详细程度,通常是一个整数(0,1,2)。

• 用途:可以根据`verbose`的值控制日志的详细程度。

class CustomCallback(BaseCallback):def __init__(self, verbose=0):super(CustomCallback, self).__init__(verbose)self.verbose = verbosedef _on_step(self) -> bool:if self.verbose >= 1:print(f"Step {self.n_calls}")if self.verbose >= 2:print(f"Current reward: {self.locals.get('rewards', 0)}")return True

9.`self.locals['ep_info_buffer']`

• 描述:一个缓冲区,存储了每个 episode 的信息,包括奖励和长度。

• 用途:可以用来计算平均奖励、平均长度等统计信息。

class CustomCallback(BaseCallback):def _on_rollout_end(self) -> None:# 获取当前的平均奖励current_mean_reward = self.locals['ep_info_buffer'].get_mean_reward()print(f"Rollout ended. Mean reward: {current_mean_reward}")return True

10.`self.model.get_parameters()`

• 描述:获取模型的参数。

• 用途:可以用来保存模型的参数,或者在训练过程中动态调整参数。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:# 获取模型的参数params = self.model.get_parameters()print(f"Model parameters: {params}")return True

11.`self.model.save()`

• 描述:保存模型到指定路径。

• 用途:可以在训练过程中定期保存模型。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.n_calls % 1000 == 0:self.model.save(f"model_step_{self.n_calls}")return True

12.`self.model.set_parameters()`

• 描述:设置模型的参数。

• 用途:可以在训练过程中动态加载预训练的参数。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.n_calls % 1000 == 0:# 加载预训练的参数self.model.set_parameters("pretrained_params.pkl")return True

13.`self.model.get_env()`

• 描述:获取模型的环境对象。

• 用途:可以用来访问和操作环境,例如获取环境的状态。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:env = self.model.get_env()print(f"Environment: {env}")return True

14.`self.model.learn()`

• 描述:继续训练模型。

• 用途:可以在训练过程中动态调整训练参数,例如学习率。

class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.n_calls % 1000 == 0:# 动态调整学习率self.model.lr_schedule = lambda x: 0.001self.model.learn(total_timesteps=1000, reset_num_timesteps=False)return True

http://www.dtcms.com/a/577013.html

相关文章:

  • [iOS] GCD - 线程与队列
  • DHTMLX Gantt v9.1 正式发布:聚焦易用性与灵活性,打造更高效的项目管理体验
  • 团队介绍网站模板网站开发学什么语言
  • [AI 应用平台] Dify 在金融、教育、医疗行业的典型应用场景
  • Kiro 安全最佳实践:守护代理式 IDE 的 “防火墙”
  • 【Go】--文件和目录的操作
  • Go 语言变量作用域
  • 23、【Ubuntu】【远程开发】内网穿透:SSH 反向隧道
  • 【Linux】不允许你还不会实现shell的部分功能
  • Jmeter+ant+Jenkins 接口自动化框架-利用ant工具批量跑指定目录下的Jmeter 脚本
  • 网站建设制作 企业站开发哪家好兰州又发现一例
  • LeetCode 刷题【146. LRU 缓存】
  • 网站建设 招标公告c2c的代表性的电商平台
  • RedisCluster客户端路由智能缓存
  • K8s从Docker到Containerd的迁移全流程实践
  • Rust语言高级技巧 - RefCell 是另外一个提供了内部可变性的类型,Cell 类型没办法制造出直接指向内部数据的指针,为什么RefCell可以呢?
  • 【Python后端API开发对比】FastAPI、主流框架Flask、Django REST Framework(DRF)及高性能框架Tornado
  • 计算机外设与CPU通信
  • 玩转Rust高级应用 如何编译器对于省略掉的生命周期,不使用“自动推理”策略呢?
  • Python全栈项目:基于Django的电子商务平台开发
  • 网站建设怎么开票网站设计网页设计公司
  • Python实现GPT自动问答与保存
  • 深度强化学习,用神经网络代替 Q-table
  • seo网站建设技巧电线电缆技术支持中山网站建设
  • supabase外键查询语句
  • 【linux端cursor CLI常用命令】
  • 表的增删改查
  • Git 工作区、暂存区和版本库
  • MIT-矩阵链相乘
  • Go语言实战:入门篇-5:函数、服务接口和Swagger UI