机器学习训练过程中回调函数常用的一些属性
在`stable-baselines3`中,回调函数(Callbacks)提供了许多有用的属性,这些属性可以帮助你在训练过程中访问和操作模型、环境以及训练状态。以下是一些常用的回调函数属性及其使用方法:
1.`self.model`
• 描述:当前训练的模型对象。
• 用途:可以调用模型的方法,例如保存模型、获取模型参数等。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:# 保存模型self.model.save("model_step_{}".format(self.n_calls))return True2.`self.n_calls`
• 描述:回调函数被调用的次数。
• 用途:可以用来记录训练的进度,例如每1000步保存一次模型。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.n_calls % 1000 == 0:print(f"Step {self.n_calls}")return True3.`self.num_timesteps`
• 描述:当前训练的总时间步数。
• 用途:可以用来记录训练的总进度,例如每10000步保存一次模型。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.num_timesteps % 10000 == 0:print(f"Total timesteps: {self.num_timesteps}")return True4.`self.locals`
• 描述:一个字典,包含了当前训练过程中的局部变量。
• 用途:可以访问和操作训练过程中的各种变量,例如奖励、损失等。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:# 获取当前的奖励current_reward = self.locals.get('rewards', 0)print(f"Current reward: {current_reward}")return True5.`self.globals`
• 描述:一个字典,包含了当前训练过程中的全局变量。
• 用途:可以访问和操作训练过程中的全局变量。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:# 获取全局变量global_var = self.globals.get('some_global_var', None)print(f"Global variable: {global_var}")return True6.`self.logger`
• 描述:日志记录器对象,用于记录训练过程中的日志信息。
• 用途:可以记录日志,例如训练进度、奖励等。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:self.logger.record("train/reward", self.locals.get('rewards', 0))return True7.`self.parent`
• 描述:父回调对象,如果有嵌套回调时使用。
• 用途:可以访问和操作父回调对象的属性和方法。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.parent:print(f"Parent callback: {self.parent}")return True8.`self.verbose`
• 描述:日志详细程度,通常是一个整数(0,1,2)。
• 用途:可以根据`verbose`的值控制日志的详细程度。
class CustomCallback(BaseCallback):def __init__(self, verbose=0):super(CustomCallback, self).__init__(verbose)self.verbose = verbosedef _on_step(self) -> bool:if self.verbose >= 1:print(f"Step {self.n_calls}")if self.verbose >= 2:print(f"Current reward: {self.locals.get('rewards', 0)}")return True9.`self.locals['ep_info_buffer']`
• 描述:一个缓冲区,存储了每个 episode 的信息,包括奖励和长度。
• 用途:可以用来计算平均奖励、平均长度等统计信息。
class CustomCallback(BaseCallback):def _on_rollout_end(self) -> None:# 获取当前的平均奖励current_mean_reward = self.locals['ep_info_buffer'].get_mean_reward()print(f"Rollout ended. Mean reward: {current_mean_reward}")return True10.`self.model.get_parameters()`
• 描述:获取模型的参数。
• 用途:可以用来保存模型的参数,或者在训练过程中动态调整参数。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:# 获取模型的参数params = self.model.get_parameters()print(f"Model parameters: {params}")return True11.`self.model.save()`
• 描述:保存模型到指定路径。
• 用途:可以在训练过程中定期保存模型。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.n_calls % 1000 == 0:self.model.save(f"model_step_{self.n_calls}")return True12.`self.model.set_parameters()`
• 描述:设置模型的参数。
• 用途:可以在训练过程中动态加载预训练的参数。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.n_calls % 1000 == 0:# 加载预训练的参数self.model.set_parameters("pretrained_params.pkl")return True13.`self.model.get_env()`
• 描述:获取模型的环境对象。
• 用途:可以用来访问和操作环境,例如获取环境的状态。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:env = self.model.get_env()print(f"Environment: {env}")return True14.`self.model.learn()`
• 描述:继续训练模型。
• 用途:可以在训练过程中动态调整训练参数,例如学习率。
class CustomCallback(BaseCallback):def _on_step(self) -> bool:if self.n_calls % 1000 == 0:# 动态调整学习率self.model.lr_schedule = lambda x: 0.001self.model.learn(total_timesteps=1000, reset_num_timesteps=False)return True