当前位置: 首页 > news >正文

7.训练篇5-毕设

使用23w张数据集-vit-打算30轮-内存崩了-改为batch_size = 8  

 我准备用23w张数据集,太大了,这个用不了,所以

是否保留 .stack() 加载所有图片?情况建议
✅ 小数据集(<2w张,图像小)想加快速度可以用
❌ 大数据集(>5w张图)Colab / 本地内存有限❌ 不建议,容易爆 RAM
✅ 你正在用 Dataloader说明已动态加载不需要这段代码

网上经验

模型图片大小batch_size 安全值(Colab Pro)
ViT-B/16224×2248 非常安全(推荐)
ViT-B/16224×224⚠️ 16 可能会炸(尤其 A100/T4)
ViT-S/16224×22416~32 都行
ViT-Tiny / DeiT-Tiny224×22432~64 可尝试

什么是“骨架图”?

我们说“骨架图”,就是指:

  • 神经网络的“结构组成”

  • 包括:每一层的类型(如 Conv2d, Linear, Transformer 等)

  • 每层的参数维度(比如 Linear(768 → 29)

  • 模型的前向传递路线(从输入 → 输出)

ViT-B/16 模型骨架图包含:

模块名内容简介
conv_proj把图像分成 patch(切成小块),变成 768 维向量
encoder12 层 Transformer,每层包括 self-attention + MLP
heads线性分类层:将最终特征 [768] 映射到你要的类别(比如 29)

举个例子(完整流程):

如果你输入一张图片 img = [1, 3, 224, 224]:(1指batch_size)

  1. conv_proj 把它切成 16x16 的 patch(共 196 个 patch),每个 patch 映射为 768 维向量

  2. Transformer 对 768 的向量做注意力建模(12 层)

  3. 取出第一个“分类 token”的输出,传给 Linear(768 → num_classes)num_classes=29,这里

  4. 输出结果为 [1, num_classes],比如 [1, 29]

ViT 是一种用“文字处理的方式”来看图片的模型。

把图像当成一串“小块块”(Patch),就像文本中的“单词”,然后用 Transformer 来分析这些块的关系。

类比图像与文字:

文本(NLP)图像(ViT)
单词 Word图像小块 Patch
词向量Patch 向量(Embedding)
BERT 模型ViT 模型(结构几乎一样)
输入图像:[B, 3, 224, 224]表示你输入的是 batch_size = B 张 RGB 彩色图像,分辨率为 224x224。
   │
【步骤1】Conv2d 分块 → Patch Embedding(patch 大小为 16x16)
   │ 得到 patch 数量:224/16 * 224/16 = 196个 patch(再加1个分类Token)
   │ 每个 patch 映射为 768维向量
   ↓
总输入:[B, 197, 768] (197 = 196 patch + 1 cls_token)

【步骤2】加上位置编码(告诉模型每个 patch 的位置)
   ↓

【步骤3】12 层 Transformer 编码器(每层都包含以下结构):
       ├── LayerNorm
       ├── Multi-head Self Attention(观察所有 patch 之间的关系)
       ├── MLP(前馈网络:两个 Linear + GELU 激活)
       └── Residual(残差连接)
   ↓

【步骤4】取出第一个位置的输出(cls_token)
   ↓

【步骤5】传入全连接层(Linear(768 → 29)) → 输出分类结果
步骤模块输出 shape(假设 B=8)说明
输入图像img[8, 3, 224, 224]一批图像
Patch Embeddingconv_proj[8, 768, 14, 14]用卷积切成 14x14 个 patch,每个是 768 维向量
→ Flatten + permute.reshape()[8, 196, 768]展平为 patch 序列:14×14 = 196 个 patch
加 CLS tokencls_token + concat[8, 197, 768]加 1 个 [CLS] 向量在开头,共 197 个 token
加位置编码pos_embedding[8, 197, 768]给每个 patch 一个位置信息(加法)

Encoder Block × 12 层:

每层结构都一样,输入输出 shape 都是:

Layer input: [8, 197, 768] Layer output: [8, 197, 768]

说明:每层的输出仍然是 197 个 token(含CLS),每个 token 是 768 维特征。

最终输出阶段:

步骤模块输出 shape
分类 tokenx[:, 0, :][8, 768] → 取第1个CLS token
全连接层Linear(768 → num_classes)[8, 29](假设你要分29类)

使用的ViT-B/16 模型

名字含义
ViTVision Transformer(图像版的 Transformer)
BBase(中等模型大小,有 12 层 encoder)
16Patch 大小为 16×16 像素

使用的步骤,新手小白

阶段要做的事示例代码 / 解释
① 加载预训练模型使用 torchvision 的 vit_b_16✅ 一行代码就能加载
② 修改输出层替换为自己的分类数,比如 29 类model.heads.head = nn.Linear(768, 29)
③ 预处理图像必须是 224×224 大小,标准化transforms.Resize + Normalize
④ 训练模型和 ResNet 一样用 dataloader训练 epoch,记录 loss 和 acc
⑤ 保存 / 加载模型torch.save() + torch.load()保存好 .pth 文件
⑥ 预测一张图像图像 → Tensor → 模型预测用 softmax 和 argmax 得到分类结果
⑦ 可视化 attention(进阶)可选:叠图显示 ViT 看哪里了用 attention map(可视化热图)

 只是做一个手势识别任务(而不是 ImageNet 等通用视觉任务),完全没必要用到全部 23 万张数据,使用的是预训练的 ViT(pretrained=True),你只需要每类几百到上千张图像,就能训练出一个效果不错的模型。

用 ViT-B/16 训练 batch_size=8 的一轮(epoch)
在 A100 上 大约每 step 0.05 - 0.08 秒(视数据加载效率不同)

如果是23w张大概需要14h

原因说明
✅ ViT 已经在 ImageNet 上学过了它早就“学会看图”了,你只需要教它你手势的分类方式
✅ 手势分类是“小数据任务”一般只需要几十个类,图像也比较规范,模型很好学
✅ 23w 张图片训练成本高占用 GPU 时间大、调参慢、不适合原型验证

以29类手势为例

每类图片数总图片数适用阶段训练建议
1002,900快速验证快速调试训练流程,10分钟出结果
50014,500初始训练可达到不错效果
1,00029,000稳定训练精度较好,不容易过拟合
3,000+87,000+高精度训练适合微调完整 ViT,建议 batch_size 大一点
23万张实验冗余除非你做的是论文级 benchmark,否则不建议一开始就全用

 实际证明我前面想的不太对

29类-每类1500张-batch_size=32-训练轮数25

这个准确率

粘贴的,目前识别率可以,我在想是不是因为其他网络训练的数据集没有没那么多的原因


文章转载自:
http://bondman.ciuzn.cn
http://artificer.ciuzn.cn
http://actinin.ciuzn.cn
http://bugbane.ciuzn.cn
http://bombardier.ciuzn.cn
http://bronchiole.ciuzn.cn
http://behavior.ciuzn.cn
http://annotate.ciuzn.cn
http://calibre.ciuzn.cn
http://bristle.ciuzn.cn
http://beldame.ciuzn.cn
http://burnsides.ciuzn.cn
http://caporegime.ciuzn.cn
http://achordate.ciuzn.cn
http://apostle.ciuzn.cn
http://apoplexy.ciuzn.cn
http://biquinary.ciuzn.cn
http://cholangiography.ciuzn.cn
http://asphyxy.ciuzn.cn
http://barefaced.ciuzn.cn
http://arse.ciuzn.cn
http://childbed.ciuzn.cn
http://blobberlipped.ciuzn.cn
http://chrominance.ciuzn.cn
http://authenticator.ciuzn.cn
http://assibilation.ciuzn.cn
http://aphtha.ciuzn.cn
http://bangladeshi.ciuzn.cn
http://anybody.ciuzn.cn
http://alamo.ciuzn.cn
http://www.dtcms.com/a/111387.html

相关文章:

  • 数字内容体验的未来方向是什么?
  • 图形渲染中的定点数和浮点数
  • 智慧放羊如何实现?
  • Python设计模式:克隆模式
  • 音视频入门基础:RTP专题(21)——使用Wireshark分析海康网络摄像机RTSP的RTP流
  • LeetCode 1123.最深叶节点的最近公共祖先 题解
  • Runnable组件动态添加默认调用参数
  • 基于SpringBoot框架发生验证码
  • 【Docker项目实战】使用Docker部署MediaCMS内容管理系统
  • 脑影像分析软件推荐 | BCT(Brain Connectivity Toolbox)
  • c语言修炼秘籍 - - 禁(进)忌(阶)秘(技)术(巧)【第四式】自定义类型详解(结构体、枚举、联合)
  • Windows 11 听的见人声,但是听不见背景音乐或者听不见轻音乐等,可以这样设置
  • 【橘子大模型】Runnable和Chain以及串行和并行
  • STM32 HAL库 CANFD配置工具
  • 小程序API —— 58 自定义组件 - 创建 - 注册 - 使用组件
  • CExercise_04_1运算符_6 (扩展) 找出数组中只出现一次的唯二元素
  • 社会视频汇聚:构筑城市安全防线的智慧之眼
  • VirtualBox 配置双网卡(NAT + 桥接)详细步骤
  • 《微服务概念进阶》精简版
  • 新浪财经股票每天10点自动爬取
  • 免费送源码:Java+SSM+Android Studio 基于Android Studio游戏搜索app的设计与实现 计算机毕业设计原创定制
  • Springboot + Vue + WebSocket + Notification实现消息推送功能
  • 接口自动化学习四:全量字段校验
  • L1-100 四项全能(测试点1)
  • 计算机网络知识点汇总与复习——(三)数据链路层
  • 在VMware下Hadoop分布式集群环境的配置--基于Yarn模式的一个Master节点、两个Slaver(Worker)节点的配置
  • Leetcode 33 -- 二分查找 | 归约思想
  • 【YOLO系列(V5-V12)通用数据集-交通红黄绿灯检测数据集】
  • SpringBoot集成swagger和jwt
  • Flask学习笔记 - 模板渲染