当前位置：首页 > news > 正文

打卡Day45

使用 PyTorch 在 CIFAR-10 数据集上微调 ResNet18，并用 TensorBoard 监控训练过程

1. 环境准备

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import numpy as np
import os

2. 数据预处理与加载

# Per-channel normalization statistics shared by the train and test transforms.
# NOTE(review): these are the commonly quoted CIFAR-10 channel statistics, not
# ImageNet's (the original comment called them ImageNet statistics).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training-time augmentation: random 32x32 crop from a 4-pixel padded image,
# random horizontal flip, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation transform: no augmentation, identical normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Load the datasets (downloaded into ./data if not already present).
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# Only the training loader shuffles; evaluation order is fixed.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=2)

3. 模型准备（ResNet18 微调）

# Load an ImageNet-pretrained ResNet18 and adapt it to CIFAR-10.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of `weights=`;
# fall back to the old flag on releases that predate the weights enum.
_resnet18_weights = getattr(torchvision.models, "ResNet18_Weights", None)
if _resnet18_weights is not None:
    model = torchvision.models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:
    model = torchvision.models.resnet18(pretrained=True)

# Swap the first conv for a 3x3, stride-1 conv so 32x32 inputs keep their
# spatial resolution (the stock stem targets 224x224 inputs).
# NOTE: this new layer is freshly initialized; the pretrained conv1 weights are discarded.
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()  # remove the initial max-pool for the same reason

# Replace the classifier head: CIFAR-10 has 10 classes.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

# Run on the first GPU when one is available, otherwise on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

4. 训练配置

# Classification loss over the model's logits, averaged across each batch.
criterion = nn.CrossEntropyLoss(reduction='mean')

# SGD with momentum and L2 weight decay over every model parameter.
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

# Multiply the learning rate by 0.1 every 30 scheduler steps
# (the training loop steps it once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# TensorBoard writer; event files are written under this log directory.
writer = SummaryWriter(log_dir='runs/resnet18_cifar10_finetune')

5. 训练循环（集成 TensorBoard 日志）

def train(epoch):
    """Train ``model`` for one epoch over ``train_loader`` and log to TensorBoard.

    Args:
        epoch: Zero-based epoch index. Used as the step for epoch-level scalars
            and to derive a global step for the periodic batch-level scalars.

    Returns:
        Tuple ``(acc, avg_loss)``: training accuracy in percent and the mean
        per-sample training loss for the epoch.
    """
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = targets.size(0)
        _, predicted = outputs.max(1)
        batch_correct = predicted.eq(targets).sum().item()

        # criterion returns the batch mean; weighting by batch size keeps the
        # epoch average exact when the final batch is smaller than the rest.
        train_loss += batch_loss * batch_size
        correct += batch_correct
        total += batch_size

        # Batch-level logging every 100 batches. Accuracy is for this batch
        # only, matching the per-batch loss logged alongside it.
        if batch_idx % 100 == 0:
            global_step = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Training/Loss (batch)', batch_loss, global_step)
            writer.add_scalar('Training/Accuracy (batch)', 100. * batch_correct / batch_size, global_step)

    # Epoch-level logging.
    avg_loss = train_loss / total
    acc = 100. * correct / total
    writer.add_scalar('Training/Loss (epoch)', avg_loss, epoch)
    writer.add_scalar('Training/Accuracy (epoch)', acc, epoch)
    print(f'Epoch: {epoch} | Train Loss: {avg_loss:.3f} | Acc: {acc:.2f}%')
    return acc, avg_loss


def test(epoch):
    """Evaluate ``model`` on ``test_loader`` without gradients and log to TensorBoard.

    Args:
        epoch: Zero-based epoch index, used as the TensorBoard step.

    Returns:
        Tuple ``(acc, avg_loss)``: validation accuracy in percent and the mean
        per-sample validation loss.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            batch_size = targets.size(0)
            # Weight the batch-mean loss by batch size for an exact per-sample average.
            test_loss += loss.item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += batch_size

    # Log validation results.
    avg_loss = test_loss / total
    acc = 100. * correct / total
    writer.add_scalar('Validation/Loss', avg_loss, epoch)
    writer.add_scalar('Validation/Accuracy', acc, epoch)
    # Log the current learning rate (read before the caller steps the scheduler).
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)
    print(f'Test Loss: {avg_loss:.3f} | Acc: {acc:.2f}%')
    return acc, avg_loss
# Main training loop.
# The __main__ guard stops DataLoader worker processes (num_workers=2) from
# re-entering the loop when they re-import this module under the "spawn"
# start method (the default on Windows and macOS).
if __name__ == '__main__':
    # Must be initialized before the first comparison; it was previously undefined (NameError).
    best_acc = 0.0
    for epoch in range(100):
        train(epoch)
        test_acc, _ = test(epoch)
        scheduler.step()
        # Keep the checkpoint with the highest validation accuracy seen so far.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'best_model.pth')
    # Flush and close the TensorBoard event file once training is finished.
    writer.close()

相关文章:

  • 渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止
  • 2025年渗透测试面试题总结-ali 春招内推电话1面(题目+回答)
  • RKNN3588上部署 RTDETRV2
  • 全球IP归属地查询接口如何用C#进行调用?
  • 使用SSH tunnel访问内网的MySQL
  • 【JS进阶】ES5 实现继承的几种方式
  • python项目如何创建docker环境
  • OpenCV 图像像素的逻辑操作
  • React Hooks 指南:何时使用 useEffect ?
  • OPenCV CUDA模块目标检测----- HOG 特征提取和目标检测类cv::cuda::HOG
  • 阿里云域名怎么绑定
  • 概述侧边导航的作用与价值
  • 结合Jenkins、Docker和Kubernetes等主流工具,部署Spring Boot自动化实战指南
  • 06.最长连续序列
  • (头歌作业)-6.5 幻方(project)
  • Web后端基础(Maven基础)
  • 8.axios Http网络请求库(1)
  • 源码编译 Cas Server 4/5/6/7
  • 无人机军用与民用技术对比分析
  • App使用webview套壳引入h5(三)——解决打包为app后在安卓机可物理返回但是在苹果手机无法测滑返回的问题
  • 怎么做猫的静态网站/网站优化推广外包
  • 网站权重值在较长时间内是一定的页面优化/福州百度网站排名优化
  • 网站不备案做优化/推广app大全
  • 昆山做网站的公司/免费发帖的平台有哪些
  • 大连哪家公司做网站比较好/新开传奇网站
  • 做招聘网站多少钱/最新足球赛事