当前位置: 首页 > news >正文

ResNet 迁移学习---加速深度学习模型训练

一、迁移学习介绍

迁移学习是一种高效的机器学习策略,它利用已在大规模数据集上训练好的模型,针对新任务进行微调。这种方法不仅能大幅加快模型训练速度,还能显著提升模型性能,即便在新任务数据稀缺时,也能有出色表现。

其核心步骤如下:

  1. 选模型与层:挑选在如 ImageNet 等大规模图像数据集上预训练的模型(像 VGG、ResNet 系列),再依据新数据集特点,确定需微调的层。若任务是边缘检测这类低级特征提取,浅层模型层更合适;若是分类这类高级特征任务,则选更深层模型。
  2. 冻预训练参数:固定预训练模型的权重,只训练新增层或微调部分层,防止预训练模型因新数据集数据量少而过拟合。
  3. 训新增层:在预训练模型参数冻结的情况下,训练新增层,让新模型适配新任务,以此提升性能。
  4. 微调预训练层:新增层训练好后,解冻部分已训练的层并将其作为微调对象,进一步提高模型在新数据集上的表现。
  5. 评估与测试:训练完成后,用测试集评估模型。若性能不佳,可调整超参数或更改微调层。

二、项目背景与技术选型

1. 为什么选择迁移学习?

在图像分类任务中,从零开始训练一个深度卷积神经网络需要大量的标注数据和计算资源。而迁移学习(Transfer Learning)通过利用在大规模数据集(如 ImageNet)上预训练好的模型参数,只需少量数据和计算资源就能实现较好的分类效果,特别适合中小型数据集的分类任务。

2. 模型选择:ResNet-18

ResNet(Residual Network)是 2015 年提出的深度残差网络,由微软实验室的何凯明等学者提出,曾斩获当年 ImageNet 竞赛分类任务、目标检测等多项第一名,还在 COCO 数据集中的目标检测、图像分割任务中拔得头筹。

传统卷积神经网络存在诸多问题,它由卷积层和池化层叠加而成,但随着层数增加,会出现梯度消失(每层误差梯度小于 1,反向传播时网络越深梯度越趋近 0)、梯度爆炸(每层误差梯度大于 1,反向传播时网络越深梯度越大)以及退化问题(网络加深后性能不升反降)。而 ResNet 通过残差连接解决了深层网络训练中的梯度消失问题。ResNet-18 作为其中的轻量级模型,拥有 18 层网络结构,在保证分类精度的同时,具有较快的训练和推理速度,非常适合部署在资源有限的环境中。

3. 开发环境

  • 深度学习框架:PyTorch 2.0+
  • 图像处理库:PIL、TorchVision
  • 计算资源:支持 CUDA 的 GPU(推荐)、MPS(Apple Silicon)或 CPU

三、完整代码解析

1. 导入必要库

首先导入项目所需的所有库,包括 PyTorch 核心库、数据加载与预处理库、图像处理库以及预训练模型库。

import torch
from torch.utils.data import DataLoader, Dataset  # 数据加载与数据集定义
from PIL import Image  # 图像读取
from torchvision import transforms  # 图像预处理
import numpy as np  # 数值计算
from torch import nn  # 神经网络模块
from torchvision import models  # 预训练模型库

2. 迁移学习:ResNet-18 模型改造

这一步是迁移学习的核心,我们需要对预训练的 ResNet-18 模型进行微调,使其适应 20 类食物分类任务。

(1)加载预训练模型
# 加载ResNet-18预训练模型(使用ImageNet数据集上的权重)
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
  • weights=models.ResNet18_Weights.DEFAULT:指定使用最新的预训练权重,确保模型性能。
(2)冻结预训练层参数

为了保留预训练模型在大规模数据集上学到的特征提取能力,我们先冻结除最后一层(全连接层)以外的所有参数,只训练自定义的全连接层。

for param in resnet_model.parameters():  # 遍历模型所有参数param.requires_grad = False  # 冻结参数,不计算梯度
(3)修改全连接层

ResNet-18 的默认全连接层输出为 1000 类(对应 ImageNet 的 1000 个类别),我们需要将其修改为 20 类(对应食物分类任务的类别数)。

# 获取原全连接层的输入特征数
in_features = resnet_model.fc.in_features
# 替换全连接层,输出维度为20
resnet_model.fc = nn.Linear(in_features, 20)
(4)指定需要更新的参数

由于我们只训练新的全连接层,需要筛选出requires_grad=True的参数(即新全连接层的参数),用于后续优化器配置。

params_to_update = []  # 存储需要更新的参数
for param in resnet_model.parameters():if param.requires_grad == True:  # 只保留需要梯度更新的参数params_to_update.append(param)

3. 图像预处理:数据增强与标准化

图像预处理是提升模型泛化能力的关键步骤。针对训练集和验证集,我们需要设计不同的预处理策略:

  • 训练集:加入数据增强(旋转、翻转、颜色抖动等),增加数据多样性,防止过拟合。
  • 验证集:仅进行尺寸调整和标准化,确保评估的客观性。
data_transforms = {'train':  # 训练集预处理transforms.Compose([transforms.Resize([300, 300]),  # 调整尺寸为300x300transforms.RandomRotation(45),  # 随机旋转±45度transforms.CenterCrop(224),  # 中心裁剪为224x224(ResNet输入尺寸)transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转(概率0.5)transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转(概率0.5)transforms.ColorJitter(  # 颜色抖动(亮度、对比度、饱和度、色调)brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),transforms.ToTensor(),  # 转换为Tensor(维度:C×H×W,数值归一化到[0,1])# 标准化(使用ImageNet的均值和标准差,与预训练模型一致)transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid':  # 验证集预处理transforms.Compose([transforms.Resize([224, 224]),  # 直接调整为224x224transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

关键说明

  • 标准化使用 ImageNet 的均值和标准差,是因为预训练模型是在该标准化后的图像上训练的,确保输入分布一致。
  • 训练集的随机变换(如旋转、翻转)能有效扩充数据,提升模型对不同角度、光照条件的适应能力。

4. 自定义数据集类:加载食物图像数据

PyTorch 的Dataset类需要自定义实现,用于读取图像路径和标签,并应用预处理。这里假设我们的数据集路径和标签存储在train.txttest.txt中,每行格式为 “图像路径 标签”(如./food/apple/1.jpg 0)。

class food_dataset(Dataset):def __init__(self, file_path, transform=None):"""初始化数据集:param file_path: 存储图像路径和标签的文本文件路径:param transform: 图像预处理函数"""self.file_path = file_pathself.imgs = []  # 存储图像路径self.labels = []  # 存储图像标签self.transform = transform# 读取文本文件,解析图像路径和标签with open(file_path, 'r') as f:# 按行读取,去除空格和换行符,分割路径和标签samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):"""返回数据集总样本数"""return len(self.imgs)def __getitem__(self, idx):"""根据索引获取单个样本(图像+标签)"""# 读取图像(PIL格式)image = Image.open(self.imgs[idx])# 应用预处理if self.transform:image = self.transform(image)# 处理标签(转换为int64类型,适配PyTorch交叉熵损失)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

5. 数据加载器:批量读取数据

DataLoader类用于将Dataset对象转换为批量数据,支持 shuffle(打乱数据)、多线程加载等功能,提升训练效率。

# 初始化训练集和验证集
training_data = food_dataset(file_path='./train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='./test.txt', transform=data_transforms['valid'])# 初始化数据加载器
train_dataloader = DataLoader(training_data, batch_size=64,  # 批量大小(根据GPU内存调整,如32、64)shuffle=True    # 训练集打乱数据,提升泛化能力
)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True    # 验证集可打乱,不影响评估结果
)

6. 设备配置:自动选择计算设备

PyTorch 支持 CPU、CUDA(NVIDIA GPU)和 MPS(Apple Silicon GPU),我们通过代码自动选择最优设备,提升训练速度。

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")  # 打印当前使用的设备

7. 训练配置:损失函数、优化器与学习率调度器

(1)模型部署到设备
model = resnet_model.to(device)  # 将模型参数迁移到指定设备
(2)损失函数

使用交叉熵损失(CrossEntropyLoss),适用于多分类任务,且内置了 Softmax 函数,无需手动添加。

loss_fn = nn.CrossEntropyLoss()
(3)优化器

选择 Adam 优化器,对学习率不敏感,收敛速度快,仅优化之前筛选出的params_to_update(即新全连接层参数)。

optimizer = torch.optim.Adam(params_to_update, lr=0.005)  # 初始学习率0.005
(4)学习率调度器

使用ReduceLROnPlateau调度器,当验证集准确率不再提升时,自动降低学习率,帮助模型在后期稳定收敛。

# 基于验证集准确率(max)调整,连续3个epoch无提升则降低学习率(乘以0.5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.5)

8. 训练函数:模型训练逻辑

训练函数负责模型的前向传播、损失计算、反向传播和参数更新,同时打印训练过程中的损失信息。

def train(dataloader, model, loss_fn, optimizer):model.train()  # 设置模型为训练模式(启用Dropout、BatchNorm更新等)batch_size_num = 1  # 记录当前批次号for X, y in dataloader:  # 遍历所有批次# 将数据迁移到指定设备X, y = X.to(device), y.to(device)# 前向传播:计算模型预测值pred = model(X)# 计算损失loss = loss_fn(pred, y)# 反向传播:清空梯度→计算梯度→更新参数optimizer.zero_grad()  # 清空上一轮梯度loss.backward()  # 反向传播计算梯度optimizer.step()  # 优化器更新参数# 每64个批次打印一次损失loss = loss.item()  # 提取损失值(脱离计算图)if batch_size_num % 64 == 0:print(f"loss: {loss:>7f} [number: {batch_size_num}]")batch_size_num += 1

9. 验证函数:模型评估与最优模型保存

验证函数在每个 epoch 结束后评估模型在验证集上的准确率和平均损失,并保存准确率最高的模型(避免过拟合,保留最优模型)。

best_acc = 0  # 记录最优验证集准确率def test(dataloader, model, loss_fn):global best_acc  # 引用全局变量,更新最优准确率size = len(dataloader.dataset)  # 验证集总样本数num_batches = len(dataloader)  # 验证集总批次数model.eval()  # 设置模型为评估模式(禁用Dropout、固定BatchNorm等)test_loss, correct = 0, 0  # 累计验证损失和正确预测数# 禁用梯度计算(评估阶段无需反向传播,节省内存和时间)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)# 累计损失(每个批次的损失相加)test_loss += loss_fn(pred, y).item()# 累计正确预测数(预测类别与真实类别一致的样本数)correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batches  # 平均损失 = 总损失 / 批次数correct /= size  # 准确率 = 正确数 / 总样本数print(f"Test result:\n Accuracy: {(100 * correct):.2f}%, Avg loss: {test_loss:.4f}")# 保存最优模型(准确率高于当前最优时更新)if correct > best_acc:best_acc = correcttorch.save(model.state_dict(), 'best.pth')  # 保存模型参数到best.pthreturn correct  # 返回当前验证集准确率,用于学习率调度

10. 主训练循环:多轮训练与调度

主循环控制训练的轮数(epochs),每轮先调用train函数训练模型,再调用test函数评估模型,最后根据验证准确率调整学习率。

epochs = 20  # 训练轮数(可根据实际情况调整,如30、50)
acc_s = []  # 存储每轮验证准确率(可选,用于后续可视化)
loss_s = []  # 存储每轮验证损失(可选,用于后续可视化)for t in range(epochs):print(f"Epoch {t + 1}\n...............")# 训练模型train(train_dataloader, model, loss_fn, optimizer)# 评估模型,获取验证准确率val_acc = test(test_dataloader, model, loss_fn)# 根据验证准确率调整学习率scheduler.step(val_acc)print("Done!")  # 训练结束


文章转载自:

http://jaiqM4pa.pqmjs.cn
http://9XiSRjIP.pqmjs.cn
http://A6Xn0JqI.pqmjs.cn
http://GfpXEHMA.pqmjs.cn
http://AZX0MMuz.pqmjs.cn
http://1YOthbv2.pqmjs.cn
http://HFCjdEp4.pqmjs.cn
http://CjjPGb9j.pqmjs.cn
http://4p22DqJw.pqmjs.cn
http://UIgYqPhp.pqmjs.cn
http://7OvIVJIV.pqmjs.cn
http://akespeuo.pqmjs.cn
http://QZtfU5KL.pqmjs.cn
http://oWhXWV4d.pqmjs.cn
http://9ofDiSNe.pqmjs.cn
http://995Ynag5.pqmjs.cn
http://kv4FPJPV.pqmjs.cn
http://MnJXnLji.pqmjs.cn
http://eQ3t6zH0.pqmjs.cn
http://SZAS2byZ.pqmjs.cn
http://djhWhOpX.pqmjs.cn
http://Qe82AQgD.pqmjs.cn
http://eGH256hE.pqmjs.cn
http://PSYG9MHM.pqmjs.cn
http://E1iz0NtU.pqmjs.cn
http://BVWjDIwR.pqmjs.cn
http://Hue5LmWB.pqmjs.cn
http://NMBUs695.pqmjs.cn
http://feQscUfZ.pqmjs.cn
http://MtN0cE5L.pqmjs.cn
http://www.dtcms.com/a/368396.html

相关文章:

  • Django REST framework:SimpleRouter 使用指南
  • Vue3 频率范围输入失焦自动校验实现
  • 删除元素(不是删除而是覆盖)快慢指针 慢指针是覆盖位置,快指针找元素
  • 代码随想录算法训练营第三天| 链表理论基础 203.移除链表元素 707.设计链表 206.反转链表
  • 结合机器学习的Backtrader跨市场交易策略研究
  • 前端开发vscode插件 - live server
  • 码农的“必修课”:深度解析Rust的所有权系统(与C++内存模型对比)
  • 【Python基础】 17 Rust 与 Python 运算符对比学习笔记
  • 云手机可以息屏挂手游吗?
  • 会话管理巅峰对决:Spring Web中Cookie-Session、JWT、Spring Session + Redis深度秘籍
  • 腾讯云大模型训练平台
  • iPhone17全系优缺点分析,加持远程控制让你的手机更好用!
  • 数据泄露危机逼近:五款电脑加密软件为企业筑起安全防线
  • 阿里云vs腾讯云按量付费服务器
  • DocuAI深度测评:自动文档生成工具如何高效产出规范API文档与数据库表结构文档?
  • React JSX 语法讲解
  • 工厂办公环境如何实现一台服务器多人共享办公
  • 从 0 到 1 学 sed 与 awk:Linux 文本处理的两把 “瑞士军刀”
  • VNC连接服务器实现远程桌面-针对官方给的链接已经失效问题
  • 【Web】理解CSS媒体查询
  • 编写前端发布脚本
  • 无密码登录与设备信任:ABP + WebAuthn/FIDO2
  • 消息队列-ubutu22.04环境下安装
  • Vue3源码reactivity响应式篇之EffectScope
  • 从Java全栈到前端框架:一位程序员的实战之路
  • 【Java实战㉖】深入Java单元测试:JUnit 5实战指南
  • 【AI论文】Robix:一种面向机器人交互、推理与规划的统一模型
  • C++(Qt)软件调试---bug排查记录(36)
  • yolov8部署在一台无显卡的电脑上,实时性强方案
  • Alibaba Cloud Linux 3 安装Docker