当前位置: 首页 > news >正文

PyTorch中的Flatten

在 PyTorch 中,Flatten 操作是将多维张量转换为一维向量的重要操作,常用于卷积神经网络(CNN)的全连接层之前。以下是 PyTorch 中实现 Flatten 的各种方法及其应用场景。

一、基本 Flatten 方法

1. 使用 torch.flatten() 函数

import torch

# 创建一个4D张量 (batch_size, channels, height, width)
x = torch.randn(32, 3, 28, 28)  # 32张28x28的RGB图像

# 展平整个张量
flattened = torch.flatten(x)  # 输出形状: [75264] (32*3*28*28)

# 从指定维度开始展平
flattened = torch.flatten(x, start_dim=1)  # 输出形状: [32, 2352] (保持batch维度)

2. 使用 nn.Flatten 层

import torch.nn as nn

flatten = nn.Flatten()  # 默认从第1维开始展平(保持batch维度)
x = torch.randn(32, 3, 28, 28)
output = flatten(x)  # 输出形状: [32, 2352]

 可以指定开始和结束维度:

flatten = nn.Flatten(start_dim=1, end_dim=2)
x = torch.randn(32, 3, 28, 28)
output = flatten(x)  # 输出形状: [32, 84, 28] (合并了第1和2维)

二、不同场景下的 Flatten 应用

1. CNN 中的典型用法

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 16, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(32 * 5 * 5, 10)  # 计算展平后的尺寸
        
    def forward(self, x):
        x = self.conv_layers(x)
        x = self.flatten(x)  # 形状从 [B, 32, 5, 5] 变为 [B, 800]
        x = self.fc(x)
        return x

 2. 手动计算展平后的尺寸

# 计算卷积层输出尺寸的辅助函数
def conv_output_size(input_size, kernel_size, stride=1, padding=0):
    return (input_size - kernel_size + 2 * padding) // stride + 1

# 计算经过多层卷积和池化后的尺寸
h, w = 28, 28  # 输入尺寸
h = conv_output_size(h, 3)  # conv1: 26
w = conv_output_size(w, 3)  # conv1: 26
h = conv_output_size(h, 2, 2)  # pool1: 13
w = conv_output_size(w, 2, 2)  # pool1: 13
h = conv_output_size(h, 3)  # conv2: 11
w = conv_output_size(w, 3)  # conv2: 11
h = conv_output_size(h, 2, 2)  # pool2: 5
w = conv_output_size(w, 2, 2)  # pool2: 5
print(f"展平后的特征数: {32 * h * w}")  # 32 * 5 * 5 = 800

三、高级用法

1. 部分展平

# 只展平图像空间维度,保留通道维度
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(start_dim=2)  # 形状: [32, 3, 784]

 2. 自定义 Flatten 层

class ChannelLastFlatten(nn.Module):
    """将通道维度移到最后的展平层"""
    def forward(self, x):
        # 输入形状: [B, C, H, W]
        x = x.permute(0, 2, 3, 1)  # [B, H, W, C]
        return x.reshape(x.size(0), -1)  # [B, H*W*C]

3. 展平特定维度

# 展平批量维度和通道维度
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(end_dim=1)  # 形状: [96, 28, 28] (32*3=96)

四、注意事项

  1. 维度计算:确保展平后的尺寸与全连接层的输入尺寸匹配

  2. 批量维度:通常保留第0维(batch维度)不被展平

  3. 内存连续性view()需要连续内存,必要时先调用contiguous()

  4. 替代方法x.view(x.size(0), -1)flatten(start_dim=1)的常见替代写法

五、性能比较

方法优点缺点
torch.flatten()官方推荐,可读性好
nn.Flatten()可作为网络层使用需要实例化对象
x.view()最简洁需要手动计算尺寸
x.reshape()自动处理内存连续性性能略低于view

六、示例代码

import torch
import torch.nn as nn

# 定义一个包含Flatten的完整模型
class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.flatten = nn.Flatten()
        self.classifier = nn.Sequential(
            nn.Linear(256 * 4 * 4, 1024),  # 假设输入图像是32x32
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 10)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x

# 使用示例
model = ImageClassifier()
input_tensor = torch.randn(16, 3, 32, 32)  # batch=16, 3通道, 32x32图像
output = model(input_tensor)
print(output.shape)  # 输出形状: [16, 10]

http://www.dtcms.com/a/113113.html

相关文章:

  • 【学习笔记】Transformers源码分析
  • LeetCode 2442:统计反转后的不同整数数量
  • 存储基石:深度解读Linux磁盘管理机制与文件系统实战
  • 联合、枚举、类型别名
  • Unity UGUI使用手册
  • 基于spring boot的外卖系统的设计与实现【如何写论文思路与真正写出论文】
  • (八)PMSM驱动控制学习---滑膜观测器
  • Pycharm 启动时候一直扫描索引/更新索引 Update index/Scanning files to index
  • Java学习总结-io流-其他流-全体系
  • Raft算法
  • hydra小记(一):深入理解 Hydra:instantiate() 与 get_class() 的区别
  • 【Linux】日志模块实现详解
  • Android学习总结之应用启动流程(从点击图标到界面显示)
  • Java面试黄金宝典35
  • python 重要易忘 语言基础
  • 使用MATIO库写入MATLAB结构体(struct)数据的示例程序
  • 医疗思维图与数智云融合:从私有云到思维图的AI架构迭代(代码版)
  • devbox加cursor编写项目到上线,不到10分钟
  • Day20 -自动化信息收集工具--ARL灯塔的部署
  • APP的兼容性测试+bug定位方法
  • AI 如何帮助我们提升自己,不被替代
  • Redis数据结构之List
  • 重生之我是去噪高手——diffusion model
  • 第三十章:Python-NetworkX库:创建、操作与研究复杂网络
  • 复古千禧Y2风格霓虹发光酸性镀铬金属短片音乐视频文字标题动画AE/PR模板
  • 15.1linux设备树下的platform驱动编写(知识)_csdn
  • 简单程序语言理论与编译技术·22 实现一个从AST到RISCV的编译器
  • HarmonyOS应用开发者高级-编程题-001
  • keil软件仿真
  • java高并发------线程的六种状态