当前位置: 首页 > news >正文

深度学习Save Best、Early Stop

一、Save Best

今天的大模型,在训练过程中可能会终止,但是模型其实是可以接着练的,假设GPU挂了,可以接着训练,在原有的权重上,训练其实就是更新w,如果前面对w进行了存档,那么可以从存档的比较优秀的地方进行训练。

下面代码默认每500步保存权重,第二个参数是选择保存最佳权重

class SaveCheckpointsCallback:
    def __init__(self, save_dir, save_step=500, save_best_only=True):
        """
        Save checkpoints each save_epoch epoch. 
        We save checkpoint by epoch in this implementation.
        Usually, training scripts with pytorch evaluating model and save checkpoint by step.

        Args:
            save_dir (str): dir to save checkpoint
            save_epoch (int, optional): the frequency to save checkpoint. Defaults to 1.
            save_best_only (bool, optional): If True, only save the best model or save each model at every epoch.
        """
        self.save_dir = save_dir # 保存路径
        self.save_step = save_step # 保存步数
        self.save_best_only = save_best_only # 是否只保存最好的模型
        self.best_metrics = -1 # 最好的指标,指标不可能为负数,所以初始化为-1
        
        # mkdir
        if not os.path.exists(self.save_dir): # 如果不存在保存路径,则创建
            os.mkdir(self.save_dir)
        
    def __call__(self, step, state_dict, metric=None):
        if step % self.save_step > 0: #每隔save_step步保存一次
            return
        
        if self.save_best_only:
            assert metric is not None # 必须传入metric
            if metric >= self.best_metrics:
                # save checkpoints
                torch.save(state_dict, os.path.join(self.save_dir, "best.ckpt")) # 保存最好的模型,覆盖之前的模型,不保存step,只保存state_dict,即模型参数,不保存优化器参数
                # update best metrics
                self.best_metrics = metric
        else:
            torch.save(state_dict, os.path.join(self.save_dir, f"{step}.ckpt")) # 保存每个step的模型,不覆盖之前的模型,保存step,保存state_dict,即模型参数,不保存优化器参数

二、Early Stop

如果训练着验证集的准确率开始下降或者损失上升,就需要用到早停:

class EarlyStopCallback:
    def __init__(self, patience=5, min_delta=0.01):
        """

        Args:
            patience (int, optional): Number of epochs with no improvement after which training will be stopped.. Defaults to 5.
            min_delta (float, optional): Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute 
                change of less than min_delta, will count as no improvement. Defaults to 0.01.
        """
        self.patience = patience # 多少个step没有提升就停止训练
        self.min_delta = min_delta # 最小的提升幅度
        self.best_metric = -1
        self.counter = 0 # 计数器,记录多少个step没有提升
        
    def __call__(self, metric):
        if metric >= self.best_metric + self.min_delta:#用准确率
            # update best metric
            self.best_metric = metric
            # reset counter 
            self.counter = 0
        else: 
            self.counter += 1 # 计数器加1,下面的patience判断用到
            
    @property #使用@property装饰器,使得 对象.early_stop可以调用,不需要()
    def early_stop(self):
        return self.counter >= self.patience

三、Tensorboard

# TensorBoard 可视化

pip install tensorboard
训练过程中可以使用如下命令启动tensorboard服务。注意使用绝对路径,否则会报错

```shell
 tensorboard  --logdir="D:\PycharmProjects\pythondl\chapter_2_torch\runs" --host 0.0.0.0 --port 8848
```

相关文章:

  • Quadrotor-NMPC-Control 开源项目复现与问题记录
  • 03.06 QT
  • ComfyUI进阶教程核心要点与详解
  • 多模态模型在做选择题时,如何设置Prompt,如何精准定位我们需要的选项
  • 【Kubernetes 指南】基础入门——Kubernetes 基本概念(四)
  • Python在DevOps中的应用:自动化CI/CD管道的实现
  • 【电控笔记z29】扰动估测器DOB估测惯量J-摩擦系数B
  • 私有云基础架构与运维(一)
  • Mybatis中的设计模式
  • SpringBoot+Vue 多模块(子父工程)项目的注册登录及增删改查
  • 软件工程画图题
  • leetcode202 快乐数 哈希结构 集合
  • Ubuntu 安装docker docker-compose
  • 颠覆传统软件测试!Browser Use WebUI+DeepSeek:软件测试行业的革命性突破
  • 深入剖析Android Service:原理、生命周期与实战应用
  • Python中判断静态方法的六种方式
  • 物联网系统搭建
  • 【橘子golang】从golang来谈闭包
  • 【五.LangChain技术与应用】【29.LangChain Agent小案例1:智能代理的实战应用】
  • 6. 机器人实现远程遥控(具身智能机器人套件)
  • 湖南小企业网站建设怎么做/黑龙江最新疫情通报
  • 南京企业网站开发/澳门seo关键词排名
  • 古典家具公司网站模板/友情链接对网站的作用
  • 可以免费创建网站的软件/电商运营主要负责什么
  • 网页设计实验报告分析/aso关键字优化
  • 武汉手机网站建设咨询/自动外链工具