当前位置: 首页 > news >正文

七、CV_模型微调

七、模型微调

1.微调

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型(预训练模型)。——修改预训练模型使他适合你的任务,最重要的是修改输出层
  2. 创建一个新的神经网络模块,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的

  • 当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力
  • 根据当前任务数据集的大小来确定微调的网络层
    • 在数据集较小时,隐藏层的参数可以不进行微调
    • 在数据集较大时,可以将隐藏层划开,里面的参数也可以进行变化

2.热狗识别案例

将基于一个小数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张热狗或者其他事物的图像。我们将使用微调得到的模型来识别一张图像中是否包含热狗

import tensorflow as tf
import numpy as np

(1)获取数据集

  • batch_size在读取数据和模型训练的时候均可以进行设置

通过以下方法读取图像文件,该方法以文件夹路径为参数,生成经过图像增强后的结果,并产生batch数据:

flow_from_directory(self,directory, # 目标文件夹路径,对于每一个类对应一个子文件夹,# 该文件夹中任何JPG,PNG,BNP,PPM的图片都可以读取 target_size = (256, 256), # 默认为(256,256),图像将被resize成该尺寸color_mode = 'rgb',classes = None,class_mode = 'categorical',batch_size = 32, # 默认为32shuffle = True, # 是否打乱数据,默认为Trueseed = None,save_to_dir = None)

创建两个tf.keras.preprocessing.image.ImageDataGenerator示例来分别读取训练数据集和测试数据集中的所有图像文件。将训练集图片全部处理为高宽均为224像素的输入。此外,我们对RGB三个颜色通道的数值做标准化。

注意:

class_modelabel 的形状含义
"binary"(32,)二分类(每个样本是 0 或 1)
"categorical"(32, 2)独热编码([1, 0] 或 [0, 1])
"sparse"(32,)多类整数标签(类似 binary)
None无标签仅返回图像,无监督学习时使用

(2)模型构建与训练

实例化预训练数据集(tf.keras.appilcation)------>模型调整(调整输出层,并设置层是否可训练)

  • 我们使用在ImageNet数据集上的预训练模型的ResNet-50作为源模型。这里指定weights = 'imagenet’来自动下载并加载预训练的模型参数。
  • Keras应用程序(keras.applications)是具有预先训练权值的固定框架,该类封装了很多重量级的网络架构

实现时实例化模型架构:

  • 利用tf.keras中的application实现迁移学习
tf.keras.application.ResNet50(include_top = True,  # 是否包含顶层的全连接层(默认为True)weights = 'imagenet', # None代表随机初始化,'imagenet'代表加载在ImageNet上预训练的权重input_tensor = None, # 如果你已经用 tf.keras.Input() 创建了输入层,这里可以传入它;# 一般用于自定义模型结构input_shape = None, # 可选,输入尺寸元组,仅当include_top = False时有效,否则输入形状必须是(224,224,3)(channels_last格式)# 或(3,224,224)(channels_first格式)。它必须为3个输入通道,且高宽必须不小于32pooling = None, # 当 include_top=False 时,是否添加全局池化classes = 1000,**kwargs
)
  • include_top
    • include_top = True, 模型会包含原始 ResNet50 在 ImageNet 上训练的最后三层全连接分类头(avg_poolfc1000 → softmax 输出 1000 类)
    • include_top = False, 就不会包含这些顶层结构,适合迁移学习时接上你自己的分类层。
  • pooling
    • 如果为 None:输出为卷积特征图(feature map),形状类似 (batch, 7, 7, 2048)
    • 'avg':加一层 GlobalAveragePooling2D,输出为 (batch, 2048)
    • 'max':加一层 GlobalMaxPooling2D,输出为 (batch, 2048)
  • classes(输出类别数量)
    • 只有当 **include_top=True** 时有效
    • 用于设置最终全连接层的输出维度。

在该案例中使用resNet50预训练模型架构模型:

# 加载预训练模型
ResNet50 = tf.keras.applications.ResNet50(weights = 'imagenet', input_shape = (224, 224, 3))
# 设置所有层不可训练
for layer in ResNet50.layers:layer.trainable = False# 设置模型
net = tf.keras.models.Squential()
# 预训练模型
net.add(ResNet50)
# 展开
net.add(tf.keras.layers.Flatten())
# 二分类的全连接层
net.add(tf.keras.layers.Dense(2, activation = 'softmax'))

接下来使用之前定义好的ImageGenerator将训练集图片送入ResNet50进行训练

# 模型编译:指定优化器,损失函数,评价指标
net.compile(optimizer = 'adam',loss = 'categorical_crossentropy',metrics = ['accuracy']
)# 模型训练:指定数据,每一个epoch中只运行10个迭代,指定验证数据集
history = net.fit(train_data_gen = True,steps_per_epoch = 10,epochs = 3,validation_data = test_data_gen,  # 验证集validation_step = 10
)
http://www.dtcms.com/a/324536.html

相关文章:

  • SpringBoot学习日记(三)
  • P1152 欢乐的跳
  • 从零开始实现Qwen3(MOE架构)
  • C语言基础05——指针
  • Pinia 状态管理库
  • Redis - 使用 Redis HyperLogLog 进行高效基数统计
  • 无人机集群协同三维路径规划,采用梦境优化算法(DOA)实现,Matlab代码
  • strace的常用案例
  • 基于Qt/QML 5.14和YOLOv8的工业异常检测Demo:冲压点智能识别
  • VSCODE+GDB+QEMU调试内核
  • 为 Prometheus 告警规则增加 UI 管理能力
  • 力扣经典算法篇-47-Pow(x, n)(快速幂思路)
  • 每日算法刷题Day60:8.10:leetcode 队列5道题,用时2h
  • Java Stream流详解:从基础语法到实战应用
  • 安装1panel之后如何通过nginx代理访问
  • Linux系统编程Day11 -- 进程属性和常见进程
  • 智慧社区(十一)——Spring Boot 实现 Excel 导出、上传与数据导入全流程详解
  • Langchain调用MCP服务和工具
  • MySQL的逻辑架构和SQL执行的流程:
  • 正确使用SQL Server中的Hint(10)—Hint简介与Hint分类及语法(1)
  • Spring Boot + SSH 客户端:在浏览器中执行远程命令
  • 深入理解 Java 中的线程池:原理、参数与最佳实践
  • 【密码学】8. 密码协议
  • 金融机构在元宇宙中的业务开展与创新路径
  • 【教学类-29-06】20250809灰色门牌号-黏贴版(6层*5间层2间)题目和答案(剪贴卡片)
  • 使用Python调用OpenAI的function calling源码
  • Pytorch深度学习框架实战教程-番外篇02-Pytorch池化层概念定义、工作原理和作用
  • ROS2 QT 多线程功能包设计
  • PHP项目运行
  • (LeetCode 每日一题) 869. 重新排序得到 2 的幂 (哈希表+枚举)