当前位置: 首页 > news >正文

打卡Day45

使用PyTorch在CIFAR-10数据集上微调ResNet18，并用TensorBoard监控训练过程

1. 环境准备

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import numpy as np
import os

2. 数据预处理与加载

# Data augmentation (training set only) and normalization.
# These are the commonly used CIFAR-10 per-channel mean/std values (not the
# ImageNet statistics); the same values are applied to training and test data.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # pad by 4 px, then randomly crop back to 32x32
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Load CIFAR-10 into ./data (download=True fetches it if it is not there yet).
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# Only the training data is shuffled; evaluation order does not affect metrics.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=2)

3. 模型准备(ResNet18微调)

# Target device: first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Start from ImageNet-pretrained ResNet-18 and adapt it to 32x32 CIFAR-10 images.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour
# of the `weights=` argument; kept as-is for compatibility with older versions.
model = torchvision.models.resnet18(pretrained=True)

# The stock stem targets 224x224 inputs; swap in a 3x3, stride-1 convolution
# for 32x32 images. This layer is newly initialised, so the pretrained conv1
# weights are not used.
model.conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)
# Remove the stem max-pool so the small feature maps are not downsampled early.
model.maxpool = nn.Identity()

# New classification head for the 10 CIFAR-10 classes.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=10)

model = model.to(device)

4. 训练配置

# Classification loss.
criterion = nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

# Scale the learning rate by 0.1 every 30 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# TensorBoard writer; event files are written under this run directory.
writer = SummaryWriter(log_dir='runs/resnet18_cifar10_finetune')

5. 训练循环(集成TensorBoard日志)

def train(epoch):
    """Run one training epoch over ``train_loader`` and log it to TensorBoard.

    Uses the module-level ``model``, ``train_loader``, ``optimizer``,
    ``criterion``, ``device`` and ``writer``.

    Args:
        epoch: epoch index. It is the TensorBoard step for epoch-level scalars
            and the base of the global step for batch-level scalars.

    Returns:
        ``(accuracy_percent, mean_batch_loss)`` over the whole epoch.
    """
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = len(train_loader)

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        _, predicted = outputs.max(1)
        batch_total = targets.size(0)
        batch_correct = predicted.eq(targets).sum().item()

        # Running totals for the epoch-level metrics.
        train_loss += batch_loss
        total += batch_total
        correct += batch_correct

        # Batch-level scalars every 100 batches. The accuracy is computed for
        # this batch only, so it matches the per-batch loss logged with it
        # (previously the running epoch accuracy was logged under this tag).
        if batch_idx % 100 == 0:
            global_step = epoch * num_batches + batch_idx
            writer.add_scalar('Training/Loss (batch)', batch_loss, global_step)
            writer.add_scalar('Training/Accuracy (batch)', 100. * batch_correct / batch_total, global_step)

    # Epoch-level scalars.
    avg_loss = train_loss / num_batches
    acc = 100. * correct / total
    writer.add_scalar('Training/Loss (epoch)', avg_loss, epoch)
    writer.add_scalar('Training/Accuracy (epoch)', acc, epoch)
    print(f'Epoch: {epoch} | Train Loss: {avg_loss:.3f} | Acc: {acc:.2f}%')
    return acc, avg_loss


def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and log the results to TensorBoard.

    Uses the module-level ``model``, ``test_loader``, ``criterion``,
    ``optimizer`` (for the current learning rate), ``device`` and ``writer``.

    Args:
        epoch: epoch index, used as the TensorBoard step.

    Returns:
        ``(accuracy_percent, mean_batch_loss)`` over the test set.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Validation metrics.
    avg_loss = test_loss / len(test_loader)
    acc = 100. * correct / total
    writer.add_scalar('Validation/Loss', avg_loss, epoch)
    writer.add_scalar('Validation/Accuracy', acc, epoch)
    # Learning rate of the first parameter group.
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)
    print(f'Test Loss: {avg_loss:.3f} | Acc: {acc:.2f}%')
    return acc, avg_loss
# Main training loop: train and evaluate each epoch, step the LR schedule, and
# save the weights with the best validation accuracy.
# The __main__ guard matters because the DataLoaders use worker processes
# (num_workers=2). Where workers are started with "spawn" (e.g. Windows,
# macOS), each worker re-imports this module, and an unguarded loop would try
# to run training again inside the worker.
if __name__ == "__main__":
    best_acc = 0.0  # must be defined before the first comparison below
    try:
        for epoch in range(100):
            train_acc, train_loss = train(epoch)
            test_acc, test_loss = test(epoch)
            scheduler.step()  # one scheduler step per epoch
            # Save a checkpoint whenever validation accuracy improves.
            if test_acc > best_acc:
                best_acc = test_acc
                torch.save(model.state_dict(), 'best_model.pth')
    finally:
        # Flush and close TensorBoard event files, even if training is interrupted.
        writer.close()
http://www.dtcms.com/a/232673.html

相关文章:

  • 渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止
  • 2025年渗透测试面试题总结-ali 春招内推电话1面(题目+回答)
  • RKNN3588上部署 RTDETRV2
  • 全球IP归属地查询接口如何用C#进行调用?
  • 使用SSH tunnel访问内网的MySQL
  • 【JS进阶】ES5 实现继承的几种方式
  • python项目如何创建docker环境
  • OpenCV 图像像素的逻辑操作
  • React Hooks 指南:何时使用 useEffect ?
  • OPenCV CUDA模块目标检测----- HOG 特征提取和目标检测类cv::cuda::HOG
  • 阿里云域名怎么绑定
  • 概述侧边导航的作用与价值
  • 结合Jenkins、Docker和Kubernetes等主流工具,部署Spring Boot自动化实战指南
  • 06.最长连续序列
  • (头歌作业)-6.5 幻方(project)
  • Web后端基础(Maven基础)
  • 8.axios Http网络请求库(1)
  • 源码编译 Cas Server 4/5/6/7
  • 无人机军用与民用技术对比分析
  • App使用webview套壳引入h5(三)——解决打包为app后在安卓机可物理返回但是在苹果手机无法测滑返回的问题
  • Uniapp 二维码生成与解析完整教程
  • win32相关(远程线程和远程线程注入)
  • Spring框架学习day7--SpringWeb学习(概念与搭建配置)
  • 深度解析地质灾害风险普查:RS与GIS技术在泥石流、滑坡灾害中的应用,ArcGIS数据管理、空间数据转换、专题地图制作、DEM分析及实战案例分析
  • 实用对比图软件推荐:快速呈现信息差异
  • opencv-4.8.1到 sln
  • Excel数据分析:基础
  • Tensorrt python api 10.11.0笔记
  • 红花UGT鉴定与特征分析-文献精读142
  • 本地部署大模型实战:使用AIStarter一键安装Ollama+OpenWeb教程(含最新版本更新指南)