当前位置: 首页 > news >正文

基于PyTorch的图像分类特征提取与模型训练文档

概述

本代码实现了一个基于PyTorch的图像特征提取与分类模型训练流程。核心功能包括:

  1. 使用预训练ResNet18模型进行图像特征提取

  2. 将提取的特征保存为标准化格式

  3. 基于提取的特征训练分类模型

代码结构详解 

1. 库导入

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset
import numpy as np
import os
from ml.model_trainer import ModelTrainer
  • 关键库说明

    • torch:PyTorch核心库

    • torch.nn:神经网络模块

    • torchvision:计算机视觉专用模块

    • numpy:数值计算库

    • os:文件系统操作

    • ModelTrainer:自定义模型训练类(需另行实现)

2. 特征提取器类(FeatureExtractor)

初始化方法 __init__
def __init__(self):self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.model = torchvision.models.resnet18(weights='IMAGENET1K_V1')self.model = nn.Sequential(*list(self.model.children())[:-1])self.model = self.model.to(self.device).eval()self.transform = transforms.Compose([...])
  • 功能说明

    • 设备检测:自动选择GPU/CPU

    • 模型加载:使用ImageNet预训练的ResNet18

    • 模型修改:移除最后的全连接层(保留卷积特征提取器)

    • 预处理设置:标准化图像尺寸和颜色空间

特征提取方法 extract_features
def extract_features(self, data_dir):full_dataset = datasets.ImageFolder(...)loader = DataLoader(...)features = []labels = []with torch.no_grad():for inputs, targets in loader:inputs = inputs.to(self.device)outputs = self.model(inputs)features.append(outputs.squeeze().cpu().numpy())labels.append(targets.numpy())features = np.concatenate(...)labels = np.concatenate(...)return features, labels, full_dataset.classes
  • 关键参数

    • data_dir:包含分类子目录的图像数据集路径

    • batch_size=32:平衡内存使用与处理效率

    • num_workers=4:多线程数据加载

  • 处理流程

    1. 创建ImageFolder数据集

    2. 使用DataLoader批量加载

    3. 禁用梯度计算加速推理

    4. 特征维度压缩(squeeze)

    5. 设备间数据传输(GPU->CPU)

    6. 合并所有批次数据

3. 主执行流程

参数配置
DATA_DIR = "/home/.../data"  # 实际数据路径
SAVE_PATH = "./features.npz"  # 特征保存路径
特征提取与保存 
extractor = FeatureExtractor()
if not os.path.exists(SAVE_PATH):features, labels, classes = extractor.extract_features(DATA_DIR)np.savez(SAVE_PATH, features=features, labels=labels, classes=classes)
else:data = np.load(SAVE_PATH)features = data['features']labels = data['labels']
  • 文件结构

    • features: [N_samples, 512] 的特征矩阵

    • labels: [N_samples] 的标签数组

    • classes: 类别名称列表

模型训练与保存
X, y = features, labels
trainer = ModelTrainer()
model = trainer.train_model(X, y)
joblib.dump(model, 'pest_classifier.pkl')

 

  • 假设条件

    • ModelTrainer需实现训练逻辑(如SVM、随机森林等)

    • 默认使用全部数据进行训练(建议实际添加数据分割)

技术细节说明

1. 图像预处理流程

2. 特征维度分析

  • ResNet18最后层输出:512维特征向量

  • 假设1000张图像:

    • 原始图像:1000×3×224×224 (约150MB)

    • 提取特征:1000×512 (约2MB) → 显著降维

3. 性能优化策略

  • GPU加速:自动检测CUDA设备

  • 批量处理:32张/批平衡效率与内存

  • 缓存机制:避免重复特征提取

  • 梯度禁用:减少内存消耗

 

 

 

 

 

相关文章:

  • MapReduce的shuffle过程详解
  • 【C++初阶】--- 模板进阶
  • 将infinigen功能集成到UE5--在ue里面写插件(python和c++)
  • 在Mybatis中写sql的常量应用
  • Redis Sentinel 和 Redis Cluster 各自的原理、优缺点及适用场景是什么?
  • 同一个路由器接口eth0和ppp0什么不同?
  • springboot中有关数据库信息转换的处理
  • Opencv中图像深度(Depth)和通道数(Channels)区别
  • MySQL事务隔离级别的实现原理MVCC
  • 51c自动驾驶~合集37
  • 「国产嵌入式仿真平台:高精度虚实融合如何终结Proteus时代?」——从教学实验到低空经济,揭秘新一代AI赋能的产业级教学工具
  • 夜族觉醒 服务搭建 异地联机 保姆教程 流畅不卡顿
  • 【linux网络】网络基础概念
  • 流量守门员:接口限流艺术
  • 软件设计师-软考知识复习(2)
  • vue3+flex动态的绘制蛇形时间轴
  • Python小程序:上班该做点摸鱼的事情
  • vue3+Nest.js项目 部署阿里云
  • 字节跳动社招面经 —— BSP驱动工程师(4)
  • vue.js中的一些事件修饰符【前端】
  • 国有六大行一季度合计净赚超3444亿,不良贷款余额均上升
  • 书业观察|一本书的颜值革命:从毛边皮面到爆火的刷边书
  • 仲裁法修订草案二审稿拟增加规定规制虚假仲裁
  • 伊朗外长:美伊谈判进展良好,讨论了很多技术细节
  • 证监会发布上市公司信披豁免规定:明确两类豁免范围、规定三种豁免方式
  • 上海发布一组人事任免信息:钱晓、翁轶丛任市数据局副局长