当前位置: 首页 > news >正文

python学习打卡day45

DAY 45 Tensorboard使用介绍

知识点回顾:

  1. tensorboard的发展历史和原理
  2. tensorboard的常见操作
  3. tensorboard在cifar上的实战:MLP和CNN模型

效果展示如下,很适合拿去组会汇报撑页数:

作业:对resnet18在cifar10上采用微调策略下,用tensorboard监控训练过程。

PS:

  1. tensorboard和torch版本存在一定的不兼容性,如果报错请新建环境尝试。
  2. tensorboard的代码还有有一定的记忆量,实际上深度学习的经典代码都是类似于八股文,看多了就习惯了,难度远远小于考研数学等需要思考的内容
  3. 实际上对目前的ai而言,你只需要先完成最简单的demo,然后让他给你加上tensorboard需要打印的部分即可。---核心是弄懂tensorboard可以打印什么信息,以及如何看可视化后的结果,把ai当成记忆大师用到的时候通过它来调取对应的代码即可。
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import datasets, transforms, models
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    from torch.utils.tensorboard import SummaryWriter
    import os
    import torchvision
    # 设置中文字体支持
    plt.rcParams["font.family"] = ["SimHei"]
    plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")# 1. 数据预处理(训练集增强,测试集标准化)
    train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])# 2. 加载CIFAR-10数据集
    train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform
    )test_dataset = datasets.CIFAR10(root='./data',train=False,transform=test_transform
    )# 3. 创建数据加载器
    batch_size = 64
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)log_dir = "runs/cifar10_resnet18_exp"
    if os.path.exists(log_dir):version = 1while os.path.exists(f"{log_dir}_v{version}"):version += 1log_dir = f"{log_dir}_v{version}"
    writer = SummaryWriter(log_dir)# 4. 定义ResNet18模型
    def create_resnet18(pretrained=True, num_classes=10):model = models.resnet18(pretrained=pretrained)# 修改最后一层全连接层in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)return model.to(device)# 5. 冻结/解冻模型层的函数
    def freeze_model(model, freeze=True):"""冻结或解冻模型的卷积层参数"""# 冻结/解冻除fc层外的所有参数for name, param in model.named_parameters():if 'fc' not in name:param.requires_grad = not freeze# 打印冻结状态frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)total_params = sum(p.numel() for p in model.parameters())if freeze:print(f"已冻结模型卷积层参数 ({frozen_params}/{total_params} 参数)")else:print(f"已解冻模型所有参数 ({total_params}/{total_params} 参数可训练)")return model# 6. 训练函数(支持阶段式训练)
    def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):"""前freeze_epochs轮冻结卷积层,之后解冻所有层进行训练"""train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []global_step=0# 初始冻结卷积层if freeze_epochs > 0:model = freeze_model(model, freeze=True)for epoch in range(epochs):# 解冻控制:在指定轮次后解冻所有层if epoch == freeze_epochs:model = freeze_model(model, freeze=False)# 解冻后调整优化器(可选)optimizer.param_groups[0]['lr'] = 1e-4  # 降低学习率防止过拟合model.train()  # 设置为训练模式running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 记录Iteration损失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 统计训练指标running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# ======================== TensorBoard 标量记录 ========================# 记录每个 batch 的损失、准确率batch_acc = 100. * correct_train / total_trainwriter.add_scalar('Train/Batch Loss', iter_loss, global_step)writer.add_scalar('Train/Batch Accuracy', batch_acc, global_step)# 记录学习率(可选)writer.add_scalar('Train/Learning Rate', optimizer.param_groups[0]['lr'], global_step)# 每 200 个 batch 记录一次参数直方图if (batch_idx + 1) % 200 == 0:for name, param in model.named_parameters():writer.add_histogram(f'Weights/{name}', param, global_step)if param.grad is not None:writer.add_histogram(f'Gradients/{name}', param.grad, global_step)global_step += 1# 每100批次打印进度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 单Batch损失: {iter_loss:.4f}")# 计算 epoch 级指标epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# ======================== TensorBoard  epoch 标量记录 ========================writer.add_scalar('Train/Epoch Loss', epoch_train_loss, epoch)writer.add_scalar('Train/Epoch Accuracy', epoch_train_acc, epoch)# 测试阶段model.eval()correct_test = 0total_test = 0test_loss = 0.0wrong_images = []  # 存储错误预测样本(用于可视化)wrong_labels = []wrong_preds = []with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()# 收集错误预测样本(用于可视化)wrong_mask = (predicted != target)if wrong_mask.sum() > 0:wrong_batch_images = data[wrong_mask][:8].cpu()  # 最多存8张wrong_batch_labels = target[wrong_mask][:8].cpu()wrong_batch_preds = predicted[wrong_mask][:8].cpu()wrong_images.extend(wrong_batch_images)wrong_labels.extend(wrong_batch_labels)wrong_preds.extend(wrong_batch_preds)epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# ======================== TensorBoard 测试集记录 ========================writer.add_scalar('Test/Epoch Loss', epoch_test_loss, epoch)writer.add_scalar('Test/Epoch Accuracy', epoch_test_acc, epoch)# 记录历史数据train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 可视化错误预测样本if wrong_images:wrong_img_grid = torchvision.utils.make_grid(wrong_images)writer.add_image('错误预测样本', wrong_img_grid, epoch)wrong_text = [f"真实: {classes[wl]}, 预测: {classes[wp]}" for wl, wp in zip(wrong_labels, wrong_preds)]writer.add_text('错误预测标签', '\n'.join(wrong_text), epoch)# 更新学习率调度器scheduler.step(epoch_test_loss)print(f'Epoch {epoch+1}/{epochs} 完成 | 测试准确率: {epoch_test_acc:.2f}%')writer.close()return epoch_test_acc
    # (可选)CIFAR-10 类别名
    classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 主函数:训练模型
    def main():# 参数设置epochs = 30  # 总训练轮次freeze_epochs = 5  # 冻结卷积层的轮次learning_rate = 1e-3  # 初始学习率weight_decay = 1e-4  # 权重衰减# 创建ResNet18模型(加载预训练权重)model = create_resnet18(pretrained=True, num_classes=10)# 定义优化器和损失函数optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)criterion = nn.CrossEntropyLoss()# 定义学习率调度器scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)# 开始训练(前5轮冻结卷积层,之后解冻)final_accuracy = train_with_freeze_schedule(model=model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,epochs=epochs,freeze_epochs=freeze_epochs)print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")print("训练后执行: tensorboard --logdir=runs 查看可视化")# # 保存模型# torch.save(model.state_dict(), 'resnet18_cifar10_finetuned.pth')# print("模型已保存至: resnet18_cifar10_finetuned.pth")if __name__ == "__main__":main()

部分结果图像展示 

 

@浙大疏锦行 

相关文章:

  • jmeter之导出接口
  • Java建造者模式(Builder Pattern)详解与实践
  • C#学习第28天:内存缓存和对象池化
  • 给网站添加live2d看板娘
  • 终极数据结构详解:从理论到实践
  • 我用AI降低AI率:一次“用魔法打败魔法”的实验
  • 装载机防撞系统:智能守护,筑牢作业现场人员安全防线
  • unity Addressable的基本使用
  • 今日行情明日机会——20250606
  • 5.3 Spring Boot整合JPA
  • 手机如何防止ip关联?3种低成本方案
  • 淘酒屋(香港)控股公司副总裁郭钢出席2025西班牙葡萄酒行业论坛
  • 从混乱到秩序:探索管理系统如何彻底改变工作流程
  • 区块链技术相关
  • Three.js中AR实现详解并详细介绍基于图像标记模式AR生成的详细步骤
  • 二叉树基础全解:存储方式、遍历原理与查找树对比
  • MySQL用户和授权
  • Verilog编程技巧01——如何编写三段式状态机
  • C++ 设计模式 《小明的奶茶加料风波》
  • gitlab CI/CD本地部署配置
  • html5的网站设计/西安seo排名扣费
  • 泰州泛亚信息做网站怎么样/深圳关键词推广优化
  • 廊坊网站搜索优化/外链吧怎么使用
  • 南宁电商网络推广/优化seo哪家好
  • 零遁nas做网站/网络营销策划书封面
  • 做企业网站专用词/怎么样做推广最有效