当前位置: 首页 > news >正文

python打卡day39

图像数据与显存

知识点回顾

  1. 图像数据的格式:灰度和彩色数据
  2. 模型的定义
  3. 显存占用的4种地方
    1. 模型参数+梯度参数
    2. 优化器参数
    3. 数据批量所占显存
    4. 神经元输出中间状态
  4. batchisize和训练的关系

作业:今日代码较少,理解内容即可

在 PyTorch 中,图像数据的形状通常遵循 (通道数, 高度, 宽度) 的格式(即 Channel First 格式),这与常见的 (高度, 宽度, 通道数)(Channel Last,如 NumPy 数组)不同。---注意顺序关系,

注意点:

  1. 如果用matplotlib库来画图,需要转换下顺序 image = np.transpose(image.numpy(), (1, 2, 0)
  2. 模型输入通常需要批次维度(Batch Size),形状变为 (批次大小, 通道数, 高度, 宽度)。例如,批量输入 10 张 MNIST 图像时,形状为 (10, 1, 28, 28)

对于图像数据集比如MNIST构建神经网络来训练的话,比起之前的结构化数据多了一个展平操作:

# 定义两层MLP神经网络
class MLP(nn.Module):def __init__(self, input_size=784, hidden_size=128, num_classes=10):super().__init__()self.flatten = nn.Flatten()  # 将28x28的图像展平为784维向量self.layer1 = nn.Linear(input_size, hidden_size)  # 第一层:784个输入,128个神经元self.relu = nn.ReLU()  # 激活函数self.layer2 = nn.Linear(hidden_size, num_classes)  # 第二层:128个输入,10个输出(对应10个数字类别)def forward(self, x):x = self.flatten(x)  # 展平图像x = self.layer1(x)   # 第一层线性变换x = self.relu(x)     # 应用ReLU激活函数x = self.layer2(x)   # 第二层线性变换,输出logitsreturn x# 初始化模型
model = MLP()

MLP的输入层要求输入是一维向量,但 MNIST 图像是二维结构(28×28 像素),形状为 [1, 28, 28](通道 × 高 × 宽)。nn.Flatten() 展平操作将二维图像 “拉成” 一维向量(784=28×28 个元素),使其符合全连接层的输入格式

在面对数据集过大的情况下,由于无法一次性将数据全部加入到显存中,所以采取了分批次加载这种方式。所以实际应用中,输入图像还存在batch_size这一维度,但在PyTorch中,模型定义和输入尺寸的指定不依赖于batch_size,无论设置多大的batch_size,模型结构和输入尺寸的写法都是不变的,batch_size是在数据加载阶段定义的(之前提过这是DataLoader的参数)

那么显存设置多少合适呢?如果设置的太小,那么每个batch_size的训练不足以发挥显卡的能力,浪费计算资源;如果设置的太大,会出现OOM(out of memory)显存一般被以下内容占用:

  1. 模型参数与梯度:模型的权重和对应的梯度会占用显存,尤其是深度神经网络(如 Transformer、ResNet 等),一个 1 亿参数的模型(如 BERT-base),单精度(float32)参数占用约 400MB(1e8×4Byte),加上梯度则翻倍至 800MB(每个权重参数都有其对应的梯度)
  2. 部分优化器(如 Adam)会为每个参数存储动量(Momentum)和平方梯度(Square Gradient),进一步增加显存占用(通常为参数大小的 2-3 倍)
  3. 其他开销
  • 单张图像尺寸:1×28×28(通道×高×宽),归一化转换为张量后为float32类型,显存占用:1×28×28×4 Byte = 3,136 Byte ≈ 3 KB
  • 批量数据占用:batch_size × 单张图像占用,例如batch_size=64时,数据占用为64×3 KB ≈ 192 KB

对于batch_size的设置,大规模数据时,通常从16开始测试,然后逐渐增加,确保代码运行正常且不报错,直到出现内存不足(OOM)报错或训练效果下降,此时选择略小于该值的 batch_size。训练时候搭配 nvidia-smi 监控显存占用,合适的 batch_size = 硬件显存允许的最大值 × 0.8(预留安全空间),并通过训练效果验证调整

@浙大疏锦行

相关文章:

  • MySQL入门笔记
  • mac电脑安装nvm
  • 一个超简易的RMAN备份并保留到异地的方案,仅适用于小规模环境
  • k8s上运行的mysql、mariadb数据库的备份记录
  • IT选型指南:电信行业需要怎样的服务器?
  • uniapp分包配置,uniapp设置subPackages
  • SpringIOC中Bean生命周期
  • TeleAI发布TeleChat2.5及T1正式版,双双开源上线魔乐社区!
  • 如何应对客户对项目进度的过度干预
  • 低代码——表单生成器Form Generator详解(二)——从JSON配置项到动态渲染表单渲染
  • 若依框架 账户管理 用户分配界面解读
  • SpringBoot使用MQTT协议简述
  • 十、【核心功能篇】项目与模块管理:前端页面开发与后端 API 联调实战
  • 华为OD机试真题——阿里巴巴找黄金宝箱(II)(2025A卷:100分)Java/python/JavaScript/C/C++/GO最佳实现
  • leetcode450.删除二叉搜索树中的节点:迭代法巧用中间节点应对多场景删除
  • Oracle的NVL函数
  • MCP协议开发规范
  • 第八章 Wireshark工具的安装与使用
  • 数据治理是什么意思?数据治理平台有哪些?
  • JDBC 核心执行流程详解
  • 网站推广思路/广州网页制作
  • 手机网站建设推广软文/百度平台电话多少
  • 交换机可以做网站跳转吗/百度关键词竞价价格
  • 用html表格做的网站/微信公众号平台官网
  • 做网站维护是什么岗位/全面网络推广营销策划
  • wordpress 加载速度优化/什么是seo营销