Real-Time Visualization of Model Training with PyTorch and SwanLab
Below is a complete example of monitoring training with real-time visualization using PyTorch and SwanLab, with MNIST handwritten digit recognition as the task:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import swanlab

# Initialize the SwanLab experiment (the dashboard is generated automatically)
swanlab.init(
    experiment_name="MNIST_CNN",
    description="Simple CNN on MNIST with SwanLab monitoring",
    config={
        "batch_size": 64,
        "epochs": 10,
        "learning_rate": 0.01,
        "model": "CNN"
    }
)

# 1. Data preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=swanlab.config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# 2. Define the CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=swanlab.config.learning_rate)

# 3. Training loop
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.functional.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # Log the training loss in real time (every 100 batches)
        if batch_idx % 100 == 0:
            swanlab.log({"train_loss": loss.item()}, step=epoch * len(train_loader) + batch_idx)
            # Print progress to the console
            print(f"Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | Loss: {loss.item():.4f}")

# 4. Test function
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    # Log epoch-level metrics
    swanlab.log({"test_loss": test_loss, "accuracy": accuracy, "epoch": epoch})
    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n")

# 5. Run the training
for epoch in range(1, swanlab.config.epochs + 1):
    train(epoch)
    test(epoch)

print("Training complete! View the visualization results at https://swanlab.cn")
```
Key points:

- SwanLab initialization: `swanlab.init()` creates the experiment and registers the tracked parameters (see the minimal sketch after this list).
- Real-time logging: `swanlab.log({"train_loss": loss.item()})` records the training loss inside the batch loop (every 100 batches in this example).
- Metric visualization: `swanlab.log({"accuracy": accuracy, "test_loss": test_loss})` records the test metrics once per epoch.
Usage steps:

- Install the dependencies: `pip install torch torchvision swanlab`
- Run the script: `python mnist_example.py`
- View the results:
  - The terminal automatically prints a monitoring link, for example:
    `SwanLab Experiment: https://swanlab.cn/[username]/MNIST_CNN/runs/[run_id]`
  - Alternatively, log in on the SwanLab website and open the experiment there
Dashboard features:

- Real-time monitoring:
  - Training loss curve (updated every 100 batches)
  - Test loss/accuracy curves (updated every epoch)
- Experiment management:
  - Records all hyperparameters (batch_size, lr, and so on)
  - Saves the experiment configuration and system environment
  - Compares results across multiple runs (see the sketch after this list)
- Automatic analysis:
  - Dynamic visualization of the training process
  - Trend analysis of how metrics change
  - Summary statistics for performance metrics
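To make the run-comparison view useful, you can launch several experiments with different hyperparameters from a single script. Below is a hedged sketch of that idea: the naming scheme and the synthetic accuracy curve are invented for illustration, and it assumes `swanlab.finish()` is available in your SwanLab version to close one run before the next begins.

```python
import math

import swanlab

# Sweep a few learning rates; each swanlab.init(...) creates a separate run,
# so the dashboard can overlay and compare their curves.
for lr in [0.1, 0.01, 0.001]:
    swanlab.init(
        experiment_name=f"MNIST_CNN_lr_{lr}",  # hypothetical naming scheme
        config={"learning_rate": lr, "epochs": 10},
    )

    for epoch in range(1, swanlab.config.epochs + 1):
        # Placeholder for the real train(epoch)/test(epoch) calls from the
        # full example; a synthetic curve is logged here just for illustration.
        fake_accuracy = 100 * (1 - math.exp(-epoch * lr * 10))
        swanlab.log({"accuracy": fake_accuracy, "epoch": epoch})

    # Assumption: swanlab.finish() ends the active run so the next loop
    # iteration can start a fresh one (mirrors similar tracking libraries).
    swanlab.finish()
```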
With this example you can, in real time:

- Monitor how the training loss decreases
- Watch how the model's performance on the test set evolves
- Analyze how different hyperparameters affect the results
- Compare the outcomes of multiple experiments
SwanLab automatically saves all experiment data, so the visualized results can still be recovered even if training is interrupted.