当前位置: 首页 > news >正文

8.18 打卡 DAY 45 Tensorboard使用介绍

DAY 45: TensorBoard——让你的训练过程可视化

告别print,拥抱交互式监控

在之前的学习中,我们为了解训练过程,用了很多“手动”的辅助工具,比如打印loss、使用matplotlib绘制静态的准确率曲线等。这些方法虽然有效,但不够便捷,而且无法实时反馈。

今天,我们将学习一个强大的官方可视化工具——TensorBoard。它可以将枯燥的训练日志转化为直观、可交互的网页图表,让我们能够实时监控训练过程中的各项指标。这对于理解模型行为、快速定位问题、动态调整策略,甚至是在组会上进行成果汇报,都非常有帮助。

1. TensorBoard是什么?(发展历史与原理)
1.1 简介与发展

简单来说,TensorBoard是TensorFlow生态中官方的可视化工具(现在也已完美兼容PyTorch),就像是给模型训练过程装上了一个“监控仪表盘”。

  • 诞生 (2015年):随TensorFlow一同发布,旨在可视化复杂的训练过程。
  • 发展 (2016-2019年):功能不断丰富,增加了图像/音频可视化、直方图、多任务对比等核心功能。
  • 兼容PyTorch (2019年后):通过torch.utils.tensorboard模块,PyTorch可以无缝使用TensorBoard,使其成为一个通用的深度学习可视化工具。

今天,我们将聚焦于它最经典的几个功能:

  1. 保存和查看模型结构图 (Graph)
  2. 实时绘制训练/验证集的损失和准确率曲线 (Scalars)
  3. 可视化每一层权重的分布变化 (Histograms)
  4. 展示原始图像或模型的预测结果 (Images)
1.2 核心原理

TensorBoard的原理非常直观,可以概括为两步:

  1. 写入日志文件:在PyTorch代码中,我们创建一个SummaryWriter对象。在训练的各个阶段(如每个batch或每个epoch),我们调用writer的方法(如add_scalar, add_image),将需要监控的数据(如损失值、图片、权重分布等)连同一个“时间戳”(global_step)一同写入到一个本地的日志文件(.tfevents文件)中。

  2. 读取并网页展示:在终端启动TensorBoard服务,它会监控指定的日志目录。一旦有新的数据写入,TensorBoard就会读取这些日志文件,并启动一个本地网页服务(通常在http://localhost:6006),将数据动态地渲染成各种图表。

这种“存下来+画出来”的自动化流程,让我们摆脱了手动print和绘图的繁琐,可以随时通过刷新网页来查看最新的训练动态。


2. TensorBoard的常见操作

下面我们来看一下在PyTorch中如何使用TensorBoard的核心API。

2.1 初始化SummaryWriter与日志目录管理

这是使用TensorBoard的第一步,也是最关键的入口。

from torch.utils.tensorboard import SummaryWriter
import oslog_dir = 'runs/cifar10_experiment'# 自动避免日志目录重复,方便对比实验
if os.path.exists(log_dir):i = 1while os.path.exists(f"{log_dir}_{i}"):i += 1log_dir = f"{log_dir}_{i}"# 关键入口,所有的数据都通过这个writer对象写入
writer = SummaryWriter(log_dir) ```
这小段代码确保了每次运行都会创建一个新的日志文件夹(如`..._1`, `..._2`),这样不同实验的结果就不会混在一起,便于后续对比分析。##### **2.2 记录标量 (Scalars)**标量是指单个数值,这是最常用的功能,用于追踪损失、准确率、学习率等指标的变化。```python
# global_step 是x轴坐标,可以是batch数或epoch数
writer.add_scalar('Tag/Metric_Name', value, global_step)# 示例
writer.add_scalar('Train/Loss', epoch_train_loss, epoch)
writer.add_scalar('Test/Accuracy', epoch_test_acc, epoch)
  • Tag: 用于在TensorBoard界面中对图表进行分组,例如用TrainTest来区分训练集和测试集的指标。
2.3 可视化模型结构 (Graph)

这可以让我们清晰地看到网络的层次结构。

# 需要一个样本输入来“追踪”计算图
sample_images, _ = next(iter(train_loader))
writer.add_graph(model, sample_images.to(device))

这将在TensorBoard的GRAPHS选项卡中生成一个可交互的模型结构图。

2.4 可视化图像 (Image)

非常适合检查输入数据或分析模型的错误预测。

import torchvision# 将多张图片拼接成一个网格
img_grid = torchvision.utils.make_grid(images[:8]) # 取前8张# 'Tag'是这张图在TensorBoard中的标题
writer.add_image('Sample Training Images', img_grid)

这将在IMAGES选项卡中显示图片。

2.5 记录直方图 (Histogram)

用于观察模型参数(权重)和梯度的分布情况,是诊断梯度消失/爆炸等问题的利器。

# 通常我们会定期记录(比如每几百个batch)
for name, param in model.named_parameters():writer.add_histogram(f'weights/{name}', param, global_step)if param.grad is not None:writer.add_histogram(f'grads/{name}', param.grad, global_step)
2.6 启动TensorBoard

训练代码运行后,在你的项目文件夹的终端中执行以下命令:

tensorboard --logdir=runs
  • --logdir 参数指向你存放所有实验日志的根目录(在我们的例子中是runs)。

然后打开浏览器访问终端提示的URL(如 http://localhost:6006),即可看到可视化界面。


3. TensorBoard实战 (CIFAR-10)

在课程提供的代码中,我们分别在MLP和CNN模型上集成了TensorBoard。

效果展示 (非常适合拿去组会汇报撑页数!)

  1. SCALARS - 损失与准确率曲线
    可以直观地看到CNN(蓝色线)的损失下降更快,准确率提升更明显,而MLP(红色线)则较早地陷入了瓶颈。

  2. GRAPHS - 模型结构图
    清晰地展示了CNN模型中卷积、池化、全连接层的连接方式。

    3A%2F%2Fi.imgur.com%2FBf1oY3W.png&pos_id=img-lIPUbSiz-1755529766394)

  3. IMAGES - 错误预测样本
    通过观察模型预测错误的图片,可以分析模型的弱点。比如下图中,模型将一张真实的“卡车(truck)”错误地预测为了“汽车(car)”,这说明模型在区分这两种相似类别时还存在困难。

  4. HISTOGRAMS - 权重分布
    可以观察到随着训练进行,权重从初始的随机分布逐渐变得更加规整。


4. 作业与参考答案

作业:对DAY 44的resnet18在CIFAR-10上采用微调策略的训练过程,用TensorBoard进行监控。

参考答案:

这个作业的核心是将DAY 44的阶段式微调代码与今天学习的TensorBoard代码进行结合。

实现思路:

  1. 合并代码:以DAY 44的resnet18微调代码为主体。
  2. 添加TensorBoard初始化:在主函数main()的开头,添加SummaryWriter的初始化代码。
  3. 修改训练函数:在train_with_freeze_schedule函数中,添加writer.add_...系列方法,记录我们关心的指标。
  4. 传入writer对象:在主函数调用训练函数时,将初始化好的writer对象传进去。

关键代码修改点:

1. 主函数 main() 中:

def main():# ... (原有参数设置) ...# 1. 初始化TensorBoard Writerlog_dir = "runs/cifar10_resnet18_finetune"if os.path.exists(log_dir):version = 1while os.path.exists(f"{log_dir}_v{version}"):version += 1log_dir = f"{log_dir}_v{version}"writer = SummaryWriter(log_dir)print(f"TensorBoard 日志目录: {log_dir}")model = create_resnet18(pretrained=True, num_classes=10)# ... (原有优化器、损失函数等定义) ...# 2. 调用训练函数时,传入writerfinal_accuracy = train_with_freeze_schedule(model=model,# ... (其他参数) ...epochs=epochs,freeze_epochs=freeze_epochs,writer=writer  # <--- 新增)print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

2. 训练函数 train_with_freeze_schedule 中:

# 1. 修改函数签名,接收writer
def train_with_freeze_schedule(model, ..., epochs, freeze_epochs=5, writer=None):# ...# 2. (可选) 在训练开始前记录模型图和样本图像dataiter = iter(train_loader)images, _ = next(dataiter)writer.add_graph(model, images.to(device))writer.add_image('Sample Images', torchvision.utils.make_grid(images[:8]))for epoch in range(epochs):# ... (原有训练逻辑) ...# 3. 在每个epoch结束时,记录标量数据# 记录训练指标writer.add_scalar('Train/Epoch_Loss', epoch_train_loss, epoch)writer.add_scalar('Train/Epoch_Accuracy', epoch_train_acc, epoch)# 记录测试指标writer.add_scalar('Test/Epoch_Loss', epoch_test_loss, epoch)writer.add_scalar('Test/Epoch_Accuracy', epoch_test_acc, epoch)# 记录当前学习率writer.add_scalar('Train/Learning_Rate', optimizer.param_groups[0]['lr'], epoch)print(f"Epoch {epoch+1} 完成 | ...")# 4. 训练结束后关闭writerwriter.close()# 此时,原有的matplotlib绘图代码可以注释掉或删除了# plot_iter_losses(...)# plot_epoch_metrics(...)return epoch_test_acc

预期在TensorBoard中看到的结果:
启动tensorboard --logdir=runs后,你应该能在一个名为cifar10_resnet18_finetune的实验中看到:

  • SCALARS曲线:准确率曲线会在第5个epoch之后(即解冻后)出现一个显著的跃升,损失曲线则会急剧下降,这清晰地反映了微调策略中“解冻”步骤带来的巨大效果提升。学习率曲线也会在某个时刻(当满足ReduceLROnPlateau的条件时)出现下降。
  • 其他如图形、图像和直方图等,也会被正常记录。

@浙大疏锦行

http://www.dtcms.com/a/338395.html

相关文章:

  • Mysql——前模糊索引失效原因及解决方式
  • 深度强化学习之前:强化学习如何记录策略与价值?
  • Java面试题储备14: 使用aop实现全局日志打印
  • Nodejs学习
  • 【SkyWalking】单节点安装
  • Linux命令大全-rmdir命令
  • Java中的 “128陷阱“
  • vue从入门到精通:轻松搭建第一个vue项目
  • go语言条件语if …else语句
  • rem 响应式布局( rem 详解)
  • 鼠标右键没有“通过VSCode打开文件夹”
  • FreeRTOS【3-1】创建第一个多任务程序复习笔记
  • STM32驱动SG90舵机全解析:从PWM原理到多舵机协同控制
  • Sring框架-IOC篇
  • ​​Java核心知识体系与集合扩容机制深度解析​
  • JavaSE高级-02
  • JDBC的使用
  • 【Python】Python Socket 网络编程详解:从基础到实践​
  • Street Crafter 阅读笔记
  • IDEA创建项目
  • MYSQL中读提交的理解
  • MySQL新手教学
  • lesson41:MySQL数据库进阶实战:视图、函数与存储引擎全解析
  • springBoot启动报错问题汇总
  • OVS:ovn是如何支持组播的?
  • LwIP 核心流程总结
  • wishbone总线
  • thinkphp8:一、环境准备
  • c++26新功能—可观测检查点
  • torch.nn.Conv1d详解