当前位置: 首页 > news >正文

深度学习——基于 ResNet18 的图像分类训练

PyTorch 基于 ResNet18 的图像分类训练与验证全流程解析


一、项目概述

本文实现了一个基于 PyTorch 框架的图像分类模型,使用 ResNet18 作为预训练骨干网络(Backbone),并在其基础上进行迁移学习(Transfer Learning)。整个流程涵盖了:

  • 数据预处理与增强

  • 自定义 Dataset 与 DataLoader

  • 模型微调与参数冻结

  • 训练与验证循环

  • 学习率调度策略(ReduceLROnPlateau)

该项目的核心目标是利用已有的强大视觉特征提取网络(ResNet18)对新的小规模数据集进行分类任务,从而提升训练效率与模型性能。


二、模型部分解析:ResNet18 微调机制

import torch
import torchvision.models as models
from torch import nn, optimresnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

这里加载了 torchvision.models 中的预训练 ResNet18 模型,其权重参数来自 ImageNet 大规模数据集的训练结果。

接着,冻结网络的所有参数,防止在训练过程中被更新:

for param in resnet_model.parameters():print(param)param.requires_grad = False

原理说明:

  • 迁移学习的关键思路在于“保留特征提取层”。

  • 早期卷积层学习的是通用特征(如边缘、纹理),可直接用于新任务。

  • 仅需微调后几层或分类头层(fc层),可显著减少训练量。

然后替换最后一层全连接层(fc层):

in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)

🔹 原始 ResNet18 的输出为 1000 类(ImageNet)。
🔹 这里改为 20 类,适配自定义数据集。

最后仅选择需要更新的参数:

params_to_update = []
for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)

这意味着优化器只会更新新加入的全连接层参数。


三、数据预处理与增强(Data Augmentation)

数据增强可提升模型泛化能力,代码中定义了两种处理策略:

from torchvision import transforms

1. 训练集增强 data_transforms['train']

包含大量随机性增强操作:

transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),transforms.CenterCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
])

📘 数据增强效果:

  • 旋转翻转灰度转换可提升模型在多视角条件下的鲁棒性。

  • 归一化操作确保输入分布与预训练模型保持一致。

2. 验证集预处理 data_transforms['valid']

transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
])

验证集通常不使用随机增强,以保持结果的可重复性和客观性。


四、自定义 Dataset 与 DataLoader

1. Dataset 类定义

from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as npclass food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.transform = transformself.imgs = []self.labels = []with open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)self.labels.append(label)

该类通过文本文件(如 train2.txt)加载图片路径和标签。

每一行格式为:

image_path label

2. 数据访问接口

def __len__(self):return len(self.imgs)def __getitem__(self, index):image = Image.open(self.imgs[index])if self.transform:image = self.transform(image)label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))return image, label

⚙️ Dataset 必备方法:

  • __len__():返回数据集大小。

  • __getitem__():返回一条样本及其标签。

3. 数据加载器 DataLoader

training_data = food_dataset('./train2.txt', transform=data_transforms['train'])
test_data = food_dataset('./test2.txt', transform=data_transforms['valid'])train_dataloader = DataLoader(training_data, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)

✅ DataLoader 的作用:

  • 自动打包 batch

  • 支持多线程加载(num_workers

  • 支持数据打乱(shuffle)


五、训练环境与优化器设置

1. 自动选择设备

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

同时兼容:

  • NVIDIA GPU (CUDA)

  • Apple M1/M2 GPU (MPS)

  • CPU

2. 定义损失函数与优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update, lr=0.001)

使用 交叉熵损失函数 处理多分类任务,优化器为 Adam

3. 学习率调度器

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor=0.5,patience=3,verbose=True
)

🔹 当验证集 Loss 连续 3 轮未改善时,学习率减半。
🔹 可有效防止过拟合与梯度振荡。


六、训练与验证循环

1. 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()if batch_size_num % 100 == 0:print(f"loss: {loss.item():>7f} [number:{batch_size_num}]")batch_size_num += 1

🔁 每 100 个 batch 打印一次 loss。

2. 测试函数

def Test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test result:\n Accuracy:{100*correct}%, Avg loss:{test_loss}")return test_loss

🧮 评估指标:

  • Accuracy(准确率)

  • Avg loss(平均验证损失)


七、主训练循环

epochs = 50
for t in range(epochs):print(f"---------------\nepoch {t+1}")train(train_dataloader, model, loss_fn, optimizer)val_loss = Test(test_dataloader, model, loss_fn)scheduler.step(val_loss)
print("Done!")

共进行 50 轮训练,每轮包括:

  1. 模型训练

  2. 验证集测试

  3. 根据验证集 loss 调整学习率

💡 随着 epoch 增加,loss 应逐渐下降,accuracy 提升。


八、总结

模块作用特点
ResNet18特征提取主干使用 ImageNet 预训练权重
Dataset读取图片与标签支持 transform 自动增强
DataLoader批量化输入shuffle 提升训练效果
train()前向传播与反向传播更新梯度
Test()模型评估计算平均损失与准确率
ReduceLROnPlateau学习率调整自动降低学习率防止过拟合

http://www.dtcms.com/a/496537.html

相关文章:

  • 西安公司建一个网站需要多少钱广告设计公司合同
  • Linux:11.线程概念与控制
  • 恋家网邯郸房产网站排名优化服务公司
  • 婚纱网站建设需求分析国外模板wordpress
  • 南阳理工网站建设专项培训网站建设方案
  • 便携气象站具备完整的气象观测能力
  • 杭州倍世康 做网站网站怎样制作 优帮云
  • 一级A视网站 一级做爰片网站建设类有哪些岗位
  • 永兴县网站建设服务商什么是网站设计与运营
  • Google Landmarks Dataset v2 (GLDv2):面向实例级识别与检索的500万图像,200k+类别大规模地标识别基准
  • 个人域名做企业网站企业seo的措施有哪些
  • 网站开发验收流程图网站建设合同详细
  • 上海做网站比较好的公司有哪些wordpress两栏响应式主题
  • 【Altium Designer实战操作】对网络端口名称采用全中文命名的可行性及其相关隐患研究
  • 可视化NS-3安装踩坑记录
  • 怎么看别人的网站有没有做301电商网站建设哪家公司好
  • 河北省建设注册中心网站html5网络公司网站模板
  • 4.类和对象(上)
  • 高端手机“探花”之争,AI会成为“胜负手”吗?
  • 福建省住房和城乡建设厅官方网站做网站挂广告赚多少
  • 门户网站直接登录系统wordpress吐槽源码
  • 网站中的flash龙岩做网站怎么做
  • Cucumber + Playwright framework based on javascript
  • 关于电子商务网站建设的论文广告传媒公司介绍
  • 武威建设厅网站建设企业网站进去无法显示
  • 一台网站服务器多少钱网站建设业务员培训
  • 网站建设如何赚钱wordpress主题设置选择
  • SVN 非页面操作 锁定单个cell
  • 布恩网站删除西峡微网站开发
  • 网站后台图片模板wordpress 手机发文