当前位置: 首页 > news >正文

使用 PyTorch 和 SwanLab 实时可视化模型训练

以下是一个使用 PyTorch 和 SwanLab 实现训练可视化监控的完整示例,以 MNIST 手写数字识别为例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import swanlab# 初始化 SwanLab 实验 (自动生成仪表盘)
swanlab.init(experiment_name="MNIST_CNN",description="Simple CNN on MNIST with SwanLab monitoring",config={"batch_size": 64,"epochs": 10,"learning_rate": 0.01,"model": "CNN"}
)# 1. 数据准备
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)train_loader = DataLoader(train_dataset, batch_size=swanlab.config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)# 2. 定义 CNN 模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout = nn.Dropout(0.25)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = nn.functional.relu(x)x = self.conv2(x)x = nn.functional.relu(x)x = nn.functional.max_pool2d(x, 2)x = self.dropout(x)x = torch.flatten(x, 1)x = self.fc1(x)x = nn.functional.relu(x)x = self.dropout(x)x = self.fc2(x)return nn.functional.log_softmax(x, dim=1)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=swanlab.config.learning_rate)# 3. 训练循环
def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = nn.functional.nll_loss(output, target)loss.backward()optimizer.step()# 实时记录每个batch的损失if batch_idx % 100 == 0:swanlab.log({"train_loss": loss.item()}, step=epoch * len(train_loader) + batch_idx)# 打印日志到控制台print(f"Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | Loss: {loss.item():.4f}")# 4. 测试函数
def test(epoch):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)# 记录epoch级别的指标swanlab.log({"test_loss": test_loss,"accuracy": accuracy,"epoch": epoch})print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n")# 5. 执行训练
for epoch in range(1, swanlab.config.epochs + 1):train(epoch)test(epoch)print("训练完成!请在 https://swanlab.cn 查看可视化结果")

关键说明:

  1. SwanLab 初始化

    swanlab.init() # 创建实验并设置跟踪参数
    
  2. 实时日志记录

    swanlab.log({"train_loss": loss.item()}) # 记录每个batch的损失
    
  3. 指标可视化

    swanlab.log({"accuracy": accuracy, "test_loss": test_loss}) # 记录测试指标
    

使用步骤:

  1. 安装依赖:
pip install torch torchvision swanlab
  1. 运行脚本:
python mnist_example.py
  1. 查看结果:
    • 终端会自动打印监控链接(如:SwanLab Experiment: https://swanlab.cn/[username]/MNIST_CNN/runs/[run_id]
    • 或在 SwanLab 官网 登录查看

仪表盘功能:

  1. 实时监控

    • 训练损失曲线(每100个batch更新)
    • 测试精度/损失曲线(每个epoch更新)
  2. 实验管理

    • 记录所有超参数(batch_size, lr等)
    • 保存实验配置和系统环境
    • 对比多次运行结果
  3. 自动分析

    • 训练过程动态可视化
    • 指标变化趋势分析
    • 性能指标汇总统计

通过这个示例,你可以实时:

  • 监控训练损失下降趋势
  • 观察模型在验证集的性能变化
  • 分析不同超参数对结果的影响
  • 比较多次实验的结果差异

SwanLab 会自动保存所有实验数据,即使训练中断也能恢复可视化结果。

相关文章:

  • 京津冀城市群13城市空间权重0-1矩阵
  • 亚矩阵云手机针对AdMob广告平台怎么进行多账号的广告风控
  • imgui绘制图像(c++)
  • 《单光子成像》第二章 预习2025.6.12
  • 如何在SOLIDWORKS工程图中添加材料明细表?
  • linux共享内存解析
  • ArkUI-X构建Android平台AAR及使用
  • 复现论文报错解决
  • 基于mapreduce的气候分析系统
  • QCoreApplication QApplication
  • vue2项目开发中遇到的小问题
  • vue3集成高德地图绘制轨迹地图
  • 分割任意组织:用于医学图像分割的单样本参考引导免训练自动点提示方法|文献速递-深度学习医疗AI最新文献
  • vanna多表关联的实验
  • 英一真题阅读单词笔记 10年
  • Meta发布V-JEPA 2世界模型及物理推理新基准,推动AI在物理世界中的认知与规划能力
  • RED DA认证-EN18031网络安全常见问题以及解答
  • supervisorctr命令简介
  • 思科交换机-路由器-配置命令-详细总结
  • Git 清理指南:如何从版本库中移除误提交的文件(保留本地文件)
  • 做英文网站需要多少/网络营销推广工作内容
  • 遂宁市住房与城乡建设厅网站/网络营销有什么
  • 做网站网页的工作怎么样/在线代理浏览网址
  • 一个公司做两个网站可以吗/广告投放网站平台
  • 公司的网站建设 交给谁做更好些/千锋教育学费
  • 美食网站怎样做蛋挞/百度一下网址是多少