当前位置: 首页 > news >正文

快速上手Pytorch Lighting框架 | 深度学习入门

快速上手Pytorch Lighting框架 | 深度学习入门

  • 前言
    • 参考官方文档
  • 介绍
  • 快速上手
    • 基本流程
    • 常用接口
      • LightningModule
        • \_\_init\_\_ & setup()
        • \*\_step()
        • configure_callbacks()
        • configure_optimizers()
        • load_from_checkpoint
      • Trainer
        • 常用参数
    • 可选接口
      • Loggers
        • TensorBoard Logger
      • Callbacks
        • EarlyStopping
        • ModelCheckpoint
        • ProgressBar

前言

本文将介绍一个深度学习的训练框架——Pytorch Lighting框架。首先会介绍Pytorch Lighting框架的特点,然后会聚焦于你使用该框架时一定会使用的那些接口,包括我个人学习该框架时的经验传授。

参考官方文档

  • Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 2.5.1.post0 documentation
  • Lightning in 15 minutes — PyTorch Lightning 2.5.1.post0 documentation
  • How to Organize PyTorch Into Lightning — PyTorch Lightning 2.5.1.post0 documentation

介绍

Pytorch Lightning是一个基于Pytorch的深度学习与机器学习的框架,它进一步封装Pytorch的接口,简化了深度学习训练代码的搭建过程,帮助用户能够关注于模型本身,而不需要再反复书写重复的训练代码。

Pytorch Lighting框架本质是对Pytorch的进一步封装,所以如果熟悉Pytorch框架,那么很容易上手Pytorch Lighting。结合官方文档以及个人使用体验,相比Pytorch,我认为Pytorch Lightning具有以下特点:

  • 代码复用性:Pytorch Lightning提供训练流程的所有接口,可以通过继承的方式,准备训练不同阶段的组件,从而在相似任务之间使用同一份代码。
  • 代码可读性:原本的Pytorch代码被进一步封装到框架中,让代码的聚合程度更高,训练流程更清晰,提高了代码的可读性。
  • 灵活性:通过框架类方法,可以根据需求定制特定环节的计算逻辑,精细控制训练的每个细节。
  • 可移植性:Pytorch Lightning的框架添加了自动检测训练设备的功能,同一份代码可以不仅在本地的CPU上训练,也可以通过远程服务器使用多GPU训练。
  • 自动化:框架集成了一些训练会用到的工具,比如日志输出、检查点记录等等。

更多详细内容,可以查阅官方文档介绍!

快速上手

基本流程

使用Lighting框架训练一个深度学习模型,遵循以下的流程:

  1. 安装Pytorch Lighting
  2. 定义Pytorch Lighting模块
  3. 定义数据集(生成样本迭代器)
  4. 配置训练器,训练模型
  5. 使用模型:包括测试模型或使用模型预测…
  6. 可视化训练过程

常用接口

上一小节,简单介绍了使用Pytorch Lighting框架的流程。其本质和普通的机器学习训练流程是一致的,如果只是简单的使用PL框架,几乎可以不输入多余的参数,就能直接开始训练,PL会帮助你完成大量的任务。同时框架提供了训练流程中每一步的对应接口,让用户可以根据需求,修改不同的细节。本小节中将具体介绍这些重要的接口,主要对应上述流程的第2步、第4步及第5步。

对于第3步,PL训练时需要迭代器类型的输入,可以手动生成样本迭代器,也可以使用Pytorch中的Dataloader等,此处将不再展开。

LightningModule

LightningModule是框架的核心部件,该类中提供关于训练的所有核心方法,涵盖6个方面:

  • 模型初始化:init & setup()
  • 训练循环:training_step()
  • 验证循环:validation_step()
  • 测试循环:test_step()
  • 预测循环:predict_step()
  • 优化器及学习率调整
__init__ & setup()

与类的基本使用方法相同,在LightingModule类的构造函数中,需要对类做必要的初始化,比如导入核心模型结构、优化器方法、损失函数类型等等。

setup(_trainer_, _pl_module_, _stage_)

setup()本质是一个回调函数,功能也是对类进行初始化设置,一般用于不同的训练阶段(predict,test,…)。调用该接口可以在不同的阶段采用不同的初始化策略。

*_step()

在不同的阶段的循环步中,可以部署期望的任务结果。除了基本的前馈计算、反向传播等操作,可以添加日志输出、指标收集等等。比如在train阶段,只获取loss指标;在test阶段,同时获取loss指标、acc指标等。

configure_callbacks()

通过重写该方法,可以定制训练所需的回调函数。当模型被调用的时候,比如执行test()的时候,框架会自动调用这些回调函数。
如果与Trainer中的回调函数表有冲突时,框架会优先使用此处的回调函数配置。

configure_optimizers()

该方法下,可以配置训练过程中使用的优化器类型以及具体的学习率。在常规模型的训练中,只会配置一个优化器,那么返回值就是单个优化器。如果是GANs或其他需要多个优化器的模型,支持返回多个迭代器,但是需要手动进行模型优化,即需要配置optimizer_step()方法。

load_from_checkpoint
load_from_checkpoint(_checkpoint_path_, _map_location=None_, _hparams_file=None_, _**kwargs_)

一般在测试阶段会需要调用该函数,用一个已经训练好的模型来初始化LightingModule类。checkpoint_path是训练好的模型的.ckpt文件存储位置,PL框架也支持传入URL,或一个类。

TIPS: 如果构造函数传入超参数,记得在构造函数中调用调用self.save_hyperparameters()。这样框架才会自动保存这些超参数到.ckpt文件中。否则如果训练、测试阶段分开进行时,需要重新导入模型,则需要准备.yaml文件,或超参数列表,才能正确的初始化模型。

Trainer

如果完成了LightningModule的配置,直接实例化一个训练器Trainer,便可以直接开始训练,默认生成的Train可以自动的帮助你完成所有训练任务:

model = MyLightningModule()trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)

模型完成训练后,单独调用test()、validate()方法,对模型进行测试、验证。如果有特殊的训练、测试、验证需求,可以在实例化Trainer的时候进行配置。

常用参数
  • accelerator & devices::
    该参数是PL框架的特点之一,只需要实例化不同的Trainer就可以实现在不同硬件设备下的训练。也可以不指定参数,框架会自动匹配对应设备完成训练。
accelerator = ["cpu"] ["gpu"] ["tpu"] ["hpu"] ["auto"]
devices = [number of devices] ["auto"]
  • callbacks:: 传入单个回调类或回调列表。当传入的是列表时,框架会自动根据顺序逐个调用回调类。如果在PL框架中重写了configure_callbacks()方法,则以框架中的回调类优先。
  • max_epochs:: 最大的训练周期。
  • enable_progress_bar:: 是否显示进度条,默认将会为True。
  • logger:: 传入一个Loggers的实例,默认会使用TensorBoard Logger。设置为False则会禁用日志功能。
  • log_every_n_steps:: 日志记录的步长
  • strategy:: 训练策略,如ddp, fsdp等。
  • limit_train_batches:: 限制训练时的batch数量,一般在调试时使用。传入一个数字,当数字小于1时按比例计算【0.25,则使用Dataloader总数的25%的batch】;当数字大于1时按个数计算【5,则使用5个batch】

可选接口

Loggers

在PL中,继承自基类Logger有多种log格式可选,比如MLflow Logger,CSV logger,TensorBoard Logger等等。可以根据自己的需要,使用不同的日志记录形式。此处着重介绍TensorBoard Logger。

TensorBoard Logger

调用该类,日志将会以tensorboard格式进行记录,训练结束后可以可视化看到训练过程。

TensorBoardLogger(_save_dir_, _name='lightning_logs'_, _version=None_, _log_graph=False_, _default_hp_metric=True_, _prefix=''_, _sub_dir=None_, _**kwargs_)

重要的参数是save_dir,name,version。因为这将决定日志的保存位置:save_dir/name/version。在不同的训练阶段可以实例化不同的logger,就可以将不同的阶段的日志放置在不同路径,方便分析研究。

构建好Logger的实例后,作为参数传入到Trainer中即可,以下是官方文档中的例子:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLoggerlogger = TensorBoardLogger("tb_logs", name="my_model")
trainer = Trainer(logger=logger)

Callbacks

EarlyStopping

通过该类配置训练早停的策略。

EarlyStopping(_monitor_, _min_delta=0.0_, _patience=3_, _verbose=False_, _mode='min'_, _strict=True_, _check_finite=True_, _stopping_threshold=None_, _divergence_threshold=None_, _check_on_train_epoch_end=None_, _log_rank_zero_only=False_)
  • monitor:: 监视指标。
  • patience:: 传入一个整数n。默认情况下,每个epoch后都会检查指标的数值,当指标n次检查都一样时会触发早停。
  • mode:: 可选max或min模式:max模式下,指标不再增长时会触发早停;min模式下,指标不再下降时会触发早停。
ModelCheckpoint

通过该类配置模型保存的保存策略。

ModelCheckpoint(_dirpath=None_, _filename=None_, _monitor=None_, _verbose=False_, _save_last=None_, _save_top_k=1_, _save_weights_only=False_, _mode='min'_, _auto_insert_metric_name=True_, _every_n_train_steps=None_, _train_time_interval=None_, _every_n_epochs=None_, _save_on_train_epoch_end=None_, _enable_version_counter=True_)
  • dirpath & filename:: 模型文件将存储为dirpath/filename。
  • monitor:: 评价指标,需要搭配save_top_k选项一起使用。
  • save_top_k:: 传入一个整数n,指定保存模型的数量。
    1. n为0,不会保存模型。
    2. n为-1,会保存所有检查点时的模型。
    3. n大于2,模型会保存指标最好的n个模型。
ProgressBar

通过继承该类,重写成员方法,以按需求定制进度条的形式。

  • get_metrics :: 可以从基类获得所有指标,然后返回想要显示的指标的字典
  • print:: 定制进度条的输出样式。原文提到without breaking the progress bar.,应该是要注意输出的方式,比如不能重新刷新屏幕缓冲区?

相关文章:

  • 经济体制1
  • 网络基础入门第6-7集(抓包技术)
  • 含铜废水循环利用体系
  • 【RAG】indexing 中的 Hierarchical Indexing(分层索引)
  • 手写 Vue 源码 === 依赖清理机制详解
  • Arm核的Ubuntu系统上安装Qt
  • 系统网络运维基础:Linux与Windows实践指南(带电子书资料)
  • Qt 通过控件按钮实现hello world + 命名规范(7)
  • 课外活动:简单了解原生测试框架Unittest前置后置的逻辑
  • Vue.js Watch 侦听器:深入理解与应用
  • 低代码云MES、轻量级部署、让智造更简单
  • 【AI入门】CherryStudio入门5:创建知识库,对接Obsidian 笔记
  • 特殊版本,官宣永久免费
  • 摄像头模组AF、OIS模组
  • 308.旅行终点站
  • 援外培训项目冈比亚数字政府能力建设研修班莅临麒麟信安参观考察
  • Linux基础(最常用基本命令)
  • 移动端返回指定页面
  • Linux命令行参数注入详解
  • MacBook M2芯片 Sequoia15.4.1 安装免费版VMware Fusion 13.6.3版本
  • 新修订的《婚姻登记条例》明起施行,领证不用户口本了
  • 山东14家城商行中,仅剩枣庄银行年营业收入不足10亿
  • 上海:企业招用高校毕业生可享受1500元/人一次性扩岗补助
  • 江苏省人社厅党组书记、厅长王斌接受审查调查
  • 国家税务总局泰安市税务局:山东泰山啤酒公司欠税超536万元
  • 泽连斯基称与特朗普通话讨论停火事宜