当前位置: 首页 > news >正文

PyTorch Lightning(训练评估框架)

PyTorch Lightning

  • 1、教程
  • 2、TensorBoardLogger和SummaryWriter
    • 1. **SummaryWriter**
    • 2. **TensorBoardLogger**
    • 3. 区别对比表
  • 3、 汇总
    • 1. PyTorch Lightning + TensorBoard + ModelCheckpoint + EarlyStopping
      • 核心代码示例:
    • 2. TensorFlow / Keras + TensorBoard + ModelCheckpoint + EarlyStopping
    • 3. Stable Baselines3 (强化学习)
    • 4. Huggingface Trainer(NLP)
    • 5. 结合Weights & Biases(W\&B)
    • 总结推荐

1、教程

官方:
https://lightning.ai/docs/pytorch/stable/
https://lightning.ai/docs/overview/getting-started

PyTorch Lightning:
在 GPU、TPU 等设备上对 AI 模型进行微调和预训练。专注于科学,而非工程。

其他:
https://blog.51cto.com/u_16175490/13322417
https://github.com/3017218062/Pytorch-Lightning-Learning/tree/master
https://evernorif.github.io/2024/01/19/Pytorch-Lightning%E5%BF%AB%E9%80%9F%E5%85%A5%E9%97%A8/

训练:

  • 超参数,
    https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html

  • 可视化,

  • https://lightning.ai/docs/pytorch/stable/visualize/logging_basic.html
    https://lightning.ai/docs/pytorch/stable/visualize/logging_intermediate.html
    https://lightning.ai/docs/pytorch/stable/levels/intermediate_level_10.html
    Tensorboard集成:https://lightning.ai/docs/pytorch/stable/visualize/logging_intermediate.html

  • 回调,
    https://lightning.ai/docs/pytorch/stable/levels/advanced_level_16.html

评估:
https://lightning.ai/docs/pytorch/stable/common/evaluation_basic.html
https://lightning.ai/docs/torchmetrics/stable/

数据:
https://lightning.ai/docs/pytorch/stable/levels/intermediate_level_9.html

模型:
https://lightning.ai/docs/pytorch/stable/levels/advanced_level_17.html

注意:PyTorch Lightning:不追求训练速度,建议用tensorboard替代它,PyTorch Lightning坑有点多(相对复杂)

2、TensorBoardLogger和SummaryWriter

SummaryWriter和TensorBoardLogger

参考:SummaryWriter
https://pytorch.ac.cn/tutorials/intermediate/tensorboard_tutorial.html
https://pytorch.ac.cn/tutorials/beginner/introyt/tensorboardyt_tutorial.html
https://docs.pytorch.ac.cn/docs/stable/tensorboard.html

SummaryWriterTensorBoardLogger 是两个在 PyTorch(尤其是 PyTorch Lightning)中常用的 TensorBoard 日志工具,但它们的定位和使用场景略有不同。


1. SummaryWriter

来源:torch.utils.tensorboard.SummaryWriter(PyTorch 原生提供)
作用:直接向 TensorBoard 写日志,属于底层 API。

特点

  • PyTorch 官方原生工具,不依赖 Lightning。
  • 需要你手动调用 add_scalar()add_histogram()add_graph() 等方法来记录数据。
  • 灵活性很高,但需要自己管理日志目录、日志周期等。
  • 适合 自己实现训练循环 时使用。

常用写法

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter(log_dir="runs/exp1")for step in range(100):loss = 0.1 * stepwriter.add_scalar("Loss/train", loss, step)writer.add_scalar("Accuracy/train", step * 0.01, step)writer.close()

运行:

tensorboard --logdir=runs

2. TensorBoardLogger

来源:pytorch_lightning.loggers.TensorBoardLogger(Lightning 提供)
作用:Lightning 封装的 TensorBoard 日志工具,是高层 API。

特点

  • 专门给 PyTorch Lightning 用的,能自动把训练中的 lossmetrics 等记录到 TensorBoard。
  • 不用手动写 add_scalar(),Lightning 在 trainer.fit() 时会自动调用。
  • 适合 用 Lightning 管理训练流程 的场景。
  • 可以结合 Lightning 的 log() 方法方便地记录。

常用写法

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLoggerlogger = TensorBoardLogger("lightning_logs", name="my_model")trainer = Trainer(max_epochs=10,logger=logger
)trainer.fit(model)

在 Lightning 的 LightningModule 中:

def training_step(self, batch, batch_idx):loss = self(batch)self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)return loss

3. 区别对比表

特性SummaryWriterTensorBoardLogger
所属PyTorch 原生PyTorch Lightning
灵活性高(完全手动记录)中(自动记录 + Lightning API)
是否自动记录训练指标❌ 需要手动调用✅ 自动记录
是否依赖 Lightning
使用场景自己写训练循环用 Lightning 管理训练
API 例子writer.add_scalar()self.log() / logger

总结

  • 如果你用原生 PyTorch 训练 → 用 SummaryWriter,自己控制日志写法。
  • 如果你用 PyTorch Lightning → 用 TensorBoardLogger,自动记录更方便。
  • 其实 TensorBoardLogger 底层就是用 SummaryWriter 实现的,只是 Lightning 帮你做了封装和调用。

如果你愿意,我可以帮你画一个 SummaryWriter vs TensorBoardLogger 的调用流程图,这样能更直观看清它们的关系。你需要吗?

3、 汇总

详细参考下面笔记:
https://blog.csdn.net/weixin_44986037/article/details/150425539?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/151153156?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/151225587?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/150427299?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/148544185?spm=1001.2014.3001.5501

原文链接:https://blog.csdn.net/weixin_44986037/article/details/150425539
建议用针对性框架:
Stable Baselines3 : 强化学习
Huggingface Trainer :NLP及Transformer模型
PyTorch Lightning:不追求训练速度,建议用tensorboard替代它,PyTorch Lightning坑有点多

提问:现成的ai模型训练评估框架,包括tersorboard查看训练数据、筛选模型、评估效果

提问:

简单概述主流现成的AI模型训练+评估框架,包含:训练过程中的数据监控(比如通过 TensorBoard 可视化损失、指标)模型筛选(如自动保存最优模型、早停)评估效果(自动计算指标,方便比较模型)

你想要的是一个现成的AI模型训练+评估框架,包含:

  • 训练过程中的数据监控(比如通过 TensorBoard 可视化损失、指标)
  • 模型筛选(如自动保存最优模型、早停)
  • 评估效果(自动计算指标,方便比较模型)

下面给你推荐几个主流且功能完整的框架/工具组合,覆盖你需求的这几个点:


1. PyTorch Lightning + TensorBoard + ModelCheckpoint + EarlyStopping

  • 训练监控:Lightning内置对TensorBoard的支持,训练时自动记录loss、metric到TensorBoard。
  • 模型筛选:内置 ModelCheckpoint,可自动保存指标最优模型(如val_loss最小、准确率最高)。
  • 早停:内置 EarlyStopping,当验证指标不再提升时自动停止训练。
  • 评估效果:支持训练、验证、测试阶段指标计算和记录。

核心代码示例:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStoppingcheckpoint_callback = ModelCheckpoint(monitor='val_loss',    # 监控验证损失mode='min',           # 取最小值对应模型save_top_k=1,verbose=True,
)early_stop_callback = EarlyStopping(monitor='val_loss',patience=5,verbose=True,mode='min'
)trainer = pl.Trainer(max_epochs=50,callbacks=[checkpoint_callback, early_stop_callback],logger=pl.loggers.TensorBoardLogger('tb_logs/')
)trainer.fit(model, train_dataloader, val_dataloader)

启动后,在终端运行 tensorboard --logdir=tb_logs/ 即可实时查看训练过程。


2. TensorFlow / Keras + TensorBoard + ModelCheckpoint + EarlyStopping

  • TensorFlow内置支持TensorBoard,训练时会自动写入日志。
  • ModelCheckpoint可保存验证指标最优模型。
  • EarlyStopping自动停止训练避免过拟合。

示例:

callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./logs'),tf.keras.callbacks.ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True),tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
]model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=50, callbacks=callbacks)

3. Stable Baselines3 (强化学习)

  • 自带 EvalCallback 用于训练过程中自动评估模型表现。
  • CheckpointCallback 用于保存模型。
  • 训练日志可通过 TensorBoard 监控。

示例:

from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallbackeval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',log_path='./logs/', eval_freq=10000, deterministic=True)
checkpoint_callback = CheckpointCallback(save_freq=5000, save_path='./logs/')model.learn(total_timesteps=100000, callback=[eval_callback, checkpoint_callback])

运行 tensorboard --logdir=./logs 可查看指标变化。


4. Huggingface Trainer(NLP)

  • 内置日志记录(支持TensorBoard、WandB等)
  • 自动保存最优模型和早停(EarlyStoppingCallback)
  • 评估指标自动计算和记录

示例:

from transformers import Trainer, TrainingArguments, EarlyStoppingCallbacktraining_args = TrainingArguments(output_dir='./results',evaluation_strategy='epoch',save_strategy='epoch',logging_dir='./logs',logging_steps=10,load_best_model_at_end=True,
)trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)trainer.train()

然后使用TensorBoard查看 ./logs 文件夹。


5. 结合Weights & Biases(W&B)

(即WandB)

  • 支持各种框架(PyTorch、TensorFlow、Lightning、SB3、HF等)
  • 自动记录训练日志、模型、超参
  • 支持在线实时可视化和对比多个实验
  • 支持模型版本管理和筛选

官网:https://wandb.ai/site


总结推荐

框架/工具训练监控(TensorBoard)自动保存最优模型早停评估指标简单易用适用范围
PyTorch LightningPyTorch通用
TensorFlow / KerasTensorFlow/Keras
Stable Baselines3部分强化学习
Huggingface TrainerNLP及Transformer模型
Weights & Biases ((WandB))各类深度学习框架

如果你告诉我你用的是哪种框架,我可以帮你写具体的训练+评估+TensorBoard配置示例,或者帮你搭建完整的训练流水线。你看?


文章转载自:

http://86K9aZVD.tscsd.cn
http://WdWBWK8X.tscsd.cn
http://mR7aN0uk.tscsd.cn
http://q3za9uzx.tscsd.cn
http://Ol2mMlif.tscsd.cn
http://tzzL15V1.tscsd.cn
http://Kl9msVJF.tscsd.cn
http://oN9XBw2T.tscsd.cn
http://YnDOtfWX.tscsd.cn
http://KOvVD9PY.tscsd.cn
http://WCyNhueJ.tscsd.cn
http://AU24XmmZ.tscsd.cn
http://xoKKUxSa.tscsd.cn
http://kPs1OByu.tscsd.cn
http://ZHeRK5wM.tscsd.cn
http://NGc36UVv.tscsd.cn
http://kSBLdLOo.tscsd.cn
http://SaoI8HQF.tscsd.cn
http://P3KzI55C.tscsd.cn
http://yemhvDSd.tscsd.cn
http://PPg72xHW.tscsd.cn
http://xo5qOszM.tscsd.cn
http://UzLPC74A.tscsd.cn
http://WN7EjxWg.tscsd.cn
http://T102nUJz.tscsd.cn
http://6kzSMOyh.tscsd.cn
http://Ztz5S3iB.tscsd.cn
http://43OaFiGF.tscsd.cn
http://5zj61eKC.tscsd.cn
http://xGysgiX0.tscsd.cn
http://www.dtcms.com/a/371804.html

相关文章:

  • Python进程,线程
  • java设计模式二、工厂
  • Claude Code核心功能操作指南
  • Python Mysql
  • Ansible 角色使用指南
  • 【c++】从三个类的设计看软件架构的哲学思考
  • 695章:使用Scrapy框架构建分布式爬虫
  • X448 算法签名验签流程深度解析及代码示例
  • 基于Apache Flink Stateful Functions的事件驱动微服务架构设计与实践指南
  • 算法题(201):传球游戏
  • 【JavaEE】(23) 综合练习--博客系统
  • 第五十四天(SQL注入数据类型参数格式JSONXML编码加密符号闭合复盘报告)
  • Kotlin 协程之 突破 Flow 限制:Channel 与 Flow 的结合之道
  • RabbitMQ 确认机制
  • DrissionPage 优化天猫店铺商品爬虫:现代化网页抓取技术详解
  • 腾讯云服务器 监控系统 如何查看服务器的并发数量?
  • Qt---对话框QDialog
  • 5G NR-NTN协议学习系列:NR-NTN介绍(1)
  • 9.7需求
  • 43. 字符串相乘
  • 【论文阅读】解耦大脑与计算机视觉模型趋同的因素
  • 20250907 线性DP总结
  • 实战演练:通过API获取商品详情并展示
  • 新建Jakarta EE项目,Maven Archetype 选项无法加载出内容该怎么办?
  • 单层石墨烯及其工业化制备技术
  • 监控系统|实验
  • Jmeter快速安装配置全指南
  • 深入理解 IP 地址:概念、分类与日常应用
  • 高速公路监控录像车辆类型检测识别数据集:8类,6k+图像,yolo标注
  • 现代C++(C++17/20)特性详解