当前位置: 首页 > news >正文

【人工智能】Python中的迁移学习:使用预训练模型进行分类任务

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门!

解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界

目录

  1. 迁移学习概述
  2. 环境准备与数据预处理
  3. 使用Keras实现迁移学习
  4. 使用PyTorch实现迁移学习
  5. 模型评估与结果分析
  6. 迁移学习技巧与最佳实践
  7. 应用场景与总结

1. 迁移学习概述

迁移学习(Transfer Learning)是机器学习中的一种技术,通过将在一个任务上训练好的模型参数迁移到另一个相关任务中,从而加速模型训练过程并提升模型性能。在计算机视觉领域,常用的预训练模型(如VGG16、ResNet、Inception等)已经在ImageNet数据集上经过充分训练,可以直接用于特征提取或微调(Fine-tuning)。

迁移学习的优势:

  • 节省训练时间:预训练模型已学习通用特征
  • 降低数据需求:适合小样本场景
  • 提升模型性能:利用已有知识提升新任务表现

典型应用场景:

  • 医学影像分类
  • 卫星图像识别
  • 工业缺陷检测
  • 自然场景物体识别

2. 环境准备与数据预处理

2.1 环境配置

# 安装必要库(Keras版本)
!pip install tensorflow keras numpy pandas matplotlib scikit-learn

2.2 数据准备

假设我们使用Kaggle的猫狗分类数据集(包含25000张训练图像)

import os
import numpy as np
from keras.preprocessing.image import ImageDataGenerator

# 数据集路径配置
train_dir = '/path/to/train'
validation_dir = '/path/to/validation'
test_dir = '/path/to/test'

# 图像预处理参数
img_width, img_height = 224, 224  # 匹配预训练模型输入尺寸
batch_size = 32
num_classes = 2  # 猫和狗分类

# 数据增强配置
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')

validation_datagen = ImageDataGenerator(rescale=1./255)

# 创建数据生成器
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')

validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')

3. 使用Keras实现迁移学习

3.1 加载预训练模型

from keras.applications import VGG16
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D, Dropout

# 加载VGG16模型(不包括顶层)
base_model = VGG16(
    weights='imagenet', 
    include_top=False, 
    input_shape=(img_width, img_height, 3))

# 冻结卷积基
for layer in base_model.layers:
    layer.trainable = False

# 添加自定义顶层
x = base_model.output
x = GlobalAveragePooling2D()(x)  # 全局平均池化
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x)  # 防止过拟合
predictions = Dense(num_classes, activation='softmax')(x)

# 构建完整模型
model = Model(inputs=base_model.input, outputs=predictions)

相关文章:

  • 【前端】CSS 备忘清单(超级详细!)
  • 内核进程调度队列(linux的真实调度算法) ─── linux第13课
  • 【经验分享】Ubuntu vmware虚拟机存储空间越来越小问题(已解决)
  • Spring IoC配置(xml+组件类的生命周期方法)
  • 精准汇报:以明确答复助力高效工作
  • 网络原理----TCP/IP(3)
  • 解决:org.springframework.web.multipart.support.MissingServletRequestPartException
  • 小练习之配置本地yum源和ssh服务
  • Uniapp使用大疆SDK打包离线原生插件
  • Cherno C++ P61 C++当中的命名空间
  • K8S学习之基础五:k8s中node节点亲和性
  • Nginx1.19.2不适配OPENSSL3.0问题
  • DeepSeek 助力 Vue3 开发:打造丝滑的时间选择器(Time Picker)
  • 17.9 LangSmith Tracing 深度实战:构建透明可观测的大模型应用
  • 蓝桥杯刷题(Cows in a Skyscraper G,炮兵阵营)
  • ffmpeg源码编译支持cuda
  • STM32-GPIO详解
  • 主时钟与虚拟时钟约束
  • 【UCB CS 61B SP24】Lecture 19 20: Hashing Hashing II 学习笔记
  • YOLOv11融合YOLOv12中的R-ELAN结构
  • 国台办:民进党当局所谓“对等尊严”,就是企图改变两岸同属一中
  • 孙磊已任中国常驻联合国副代表、特命全权大使
  • “五一”假期预计全社会跨区域人员流动量超14亿人次
  • 住房和城乡建设部办公厅主任李晓龙已任部总工程师
  • 国泰海通合并后首份业绩报告出炉:一季度净利润增逾391%
  • 解放日报头版聚焦“人民城市”:共建共享展新卷