当前位置: 首页 > news >正文

Python Day44

Task:
1.预训练的概念
2.常见的分类预训练模型
3.图像预训练模型的发展史
4.预训练的策略
5.预训练代码实战:resnet18


1. 预训练的概念

预训练(Pre-training)是指在大规模数据集上,先训练模型以学习通用的特征表示,然后将其用于特定任务的微调。这种方法可以显著提高模型在目标任务上的性能,减少训练时间和所需数据量。

核心思想:

  • 在大规模、通用的数据(如ImageNet)上训练模型,学习丰富的特征表示。
  • 将预训练模型应用于任务特定的细调(fine-tuning),使模型适应目标任务。

优势:

  • 提升模型性能
  • 缩短训练时间
  • 需要较少的标注数据
  • 提供良好的特征初始化

2. 常见的分类预训练模型

常见的分类预训练模型主要包括:

模型名称提出年份特色与应用
AlexNet2012标志深度学习重返计算机视觉的起点
VGG(VGG16/19)2014简洁结构,深层网络,广泛用于特征提取
ResNet(Residual Network)2015引入残差连接,解决深层网络退化问题
Inception(GoogLeNet)2014多尺度特征提取,复杂模块设计
DenseNet2017密集连接,加深网络而不增加参数
MobileNet2017轻量级模型,适合移动端应用
EfficientNet2019根据模型宽度、深度和分辨率优化设计

这些模型在ImageNet等大规模数据集上预训练,成为计算机视觉各种任务的基础。


3. 图像预训练模型的发展史

  1. AlexNet (2012)
    首次使用深度卷积神经网络大规模应用于ImageNet,显著提升分类效果。

  2. VGG系列 (2014)
    简单堆叠卷积和池化层,深度逐步增加,提高表现。

  3. GoogLeNet/Inception (2014)
    引入Inception模块,进行多尺度特征提取,有效提升效率。

  4. ResNet (2015)
    通过残差连接解决深层网络的退化问题,使网络深度大幅提升(如ResNet-50,ResNet-101等)。

  5. DenseNet (2017)
    特色是密集连接,增强特征传播,改善梯度流。

  6. MobileNet, EfficientNet (2017-2019)
    追求轻量级和高效率,适应移动端和资源有限场景。

总的趋势:

  • 从浅层逐步向深层网络发展
  • 引入残差、密集连接等结构解决深层网络训练难题
  • 注重模型效率与性能平衡

4. 预训练的策略

常用的预训练策略包括:

1. 直接使用预训练模型进行微调(Fine-tuning)

  • 加载预训练权重
  • 替换最后的分类层以适应新任务(如类别数不同)
  • 选择性冻结部分层(如只训练最后几层)或全部训练

2. 特征提取(Feature Extraction)

  • 使用预训练模型的固定特征提取器,从中提取特征
  • 在这些特征基础上训练简单的分类器(如SVM或线性层)

3. 逐层逐步微调(Layer-wise Fine-tuning)

  • 先冻结底层特征层,只训练高层
  • 再逐步解冻低层,进行全层微调

4. 迁移学习(Transfer Learning)

  • 利用预训练模型迁移到相似领域任务中
  • 通过微调适应不同数据分布和任务需求

5. 预训练代码实战:ResNet18

以下是基于PyTorch框架的ResNet18预训练模型加载和微调的示例代码:

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader# 1. 加载预训练ResNet18模型
model = models.resnet18(pretrained=True)# 2. 替换分类层以适应新任务(比如有10个类别)
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)# 3. 冻结前面层,只训练最后的全连接层(可选)
for param in model.parameters():param.requires_grad = False  # 冻结所有参数# 只训练最后一层参数
for param in model.fc.parameters():param.requires_grad = True# 4. 定义数据变换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])# 5. 加载数据集
train_dataset = ImageFolder('path_to_train_data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)val_dataset = ImageFolder('path_to_val_data', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)# 6. 设置优化器(只优化可训练参数)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
criterion = nn.CrossEntropyLoss()# 7. 训练环节
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)for epoch in range(10):model.train()total_loss = 0for images, labels in train_loader:images = images.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")# 8. 评估
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in val_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()total += labels.size(0)
print(f'Validation Accuracy: {100 * correct / total:.2f}%')

总结

  • 预训练是一种利用大规模数据学习通用特征,从而在目标任务中快速获得优秀表现的技术。
  • 常用的分类预训练模型包括ResNet、VGG、Inception等,发展经历了从浅层到深层、从视觉到效率的不断演变。
  • 预训练策略多样,适应不同场景,微调与特征提取是常用手段。
  • 实战中,可以利用PyTorch提供的模型接口快速加载预训练模型,并进行微调以满足具体需求。

相关文章:

  • 南京市建委网站下载中心建设工程招标襄阳seo
  • 赣州人才网招聘找工作长沙网站seo收费
  • 烟台 网站建设律师推广网站排名
  • 网站开发作用怎样做竞价推广
  • 淘宝网站上的图片是怎么做的优化设计四年级上册数学答案
  • wordpress建站图片效果百度公司招聘2022年最新招聘
  • 智慧园区数字孪生全链交付方案:降本增效30%,多案例实践驱动全周期交付
  • SQL进阶之旅 Day 16:特定数据库引擎高级特性
  • 华为OD最新机试真题-小明减肥-OD统一考试(B卷)
  • 华为OD最新机试真题-数组组成的最小数字-OD统一考试(B卷)
  • Python训练营---Day44
  • 今日科技热点速览
  • Android协程学习
  • 消息的幂等性
  • RAID磁盘阵列
  • Kafka存储机制核心优势剖析
  • 作为过来人,浅谈一下高考、考研、读博
  • 26考研 | 王道 | 计算机组成原理 | 四、指令系统
  • 如何搭建自动化测试框架?
  • 【leetcode】347. 前k个高频元素
  • 通过BUG(prvIdleTask、pxTasksWaitingTerminatio不断跳转问题)了解空闲函数(prvIdleTask)和TCB
  • 机器学习实验八--基于pca的人脸识别
  • LeetCode-70. 爬楼梯
  • 中国西部逐日1 km全天候地表温度数据集(TRIMS LST-TP;2000-2024)
  • GIC流协议接口
  • c++ Base58编码解码