当前位置: 首页 > news >正文

【AutoFormer 源码理解】 conv1d

为什么将嵌入维度移动到中间位置

将输入张量从 [B, L, D] 转置为 [B, D, L] 的主要原因是为了符合 PyTorch 中卷积层的输入格式要求:

PyTorch 卷积层的输入格式

TokenEmbedding 类中,使用了 nn.Conv1d 进行卷积操作。对于 PyTorch 中的一维卷积层:

  • 要求输入格式[批次大小, 输入通道数, 序列长度][B, C, L]
  • 实际输入格式:原始时间序列为 [B, L, D],其中 D 是特征维度

维度重排的目的

  1. 符合卷积要求

    • 将特征维度 D 移到第二个位置,作为卷积的输入通道
    • 将序列长度 L 移到第三个位置,作为卷积操作的滑动维度
  2. 卷积操作方向

    • 这样转置后,卷积核就会沿着时间维度(序列长度)滑动
    • 对每个特征通道同时应用相同的卷积操作
  3. 恢复原始排列

    • 卷积后得到 [B, d_model, L]
    • 再通过 transpose(1, 2) 转回 [B, L, d_model],以符合 Transformer 结构的输入要求

这种设计使模型能够有效捕获时间序列中的局部模式,同时保持与后续 Transformer 层的兼容性。

一维卷积(1D Conv)在TokenEmbedding中的输出计算过程

在TokenEmbedding类中,使用了一维卷积将输入特征从c_in维映射到d_model维。下面详细解释其计算过程:

卷积参数

self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                           kernel_size=3, padding=padding, padding_mode='circular', bias=False)
# 例如: Conv1d(7, 512, kernel_size=(3,), stride=(1,), padding=(1,), bias=False, padding_mode=circular)

输出尺寸计算公式

一维卷积的输出长度计算公式:

L_out = (L_in + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1

计算过程

假设输入x形状为[B, L, D],其中D等于c_in

  1. 维度重排

    x.permute(0, 2, 1)  # 从[B, L, D]变为[B, D, L]
    
  2. 卷积操作

    • 输入:[B, c_in, L]
    • 权重矩阵:[d_model, c_in, kernel_size]
    • 卷积核在序列上滑动,每个位置产生d_model个输出特征
    • 针对每个输出通道j:
      output[b,j,i] = Σ(input[b,c,i+k] * weight[j,c,k])
      
      其中b是批次索引,c遍历所有输入通道,k遍历卷积核位置
  3. 输出形状

    • 使用适当的padding(1或2)保持序列长度L不变
    • 卷积后形状:[B, d_model, L]
  4. 最终转置

    .transpose(1, 2)  # 从[B, d_model, L]变为[B, L, d_model]
    

实际意义

这个卷积操作的实际意义是:

  • 捕获输入时间序列中的局部模式
  • 将原始特征维度映射到模型的嵌入维度
  • 通过循环填充(circular padding)处理时间序列的边界,认为时间序列是周期性的
  • 为每个时间步生成一个维度为d_model的特征表示

通过这种方式,输入序列的每个时间步都被转换成一个更丰富的表示,作为Transformer模型的输入。## 实际意义

这个卷积操作的实际意义是:

  • 捕获输入时间序列中的局部模式
  • 将原始特征维度映射到模型的嵌入维度
  • 通过循环填充(circular padding)处理时间序列的边界,认为时间序列是周期性的
  • 为每个时间步生成一个维度为d_model的特征表示

通过这种方式,输入序列的每个时间步都被转换成一个更丰富的表示,作为Transformer模型的输入。

相关文章:

  • 【蓝桥杯】省赛:缴纳过路费(并查集)
  • 虚拟定位 1.2.0.2 | 虚拟定位,上班打卡,校园跑步模拟
  • AI幻觉时代:避坑指南与技术反思
  • 机器学习扫盲系列(2)- 深入浅出“反向传播”-1
  • 粗粒度和细粒度指的是什么?
  • 回顾Transformer,并深入讲解替代方案Mamba原理(图解)
  • 【6. 系统调用】
  • 异常(11)
  • 解决QT_Debug 调试信息不输出问题
  • Navigation页面导航的使用
  • 无SIM卡时代即将来临?eSIM才是智联未来?
  • ChatBI 的技术演进与实践挑战:衡石科技如何通过 DeepSeek 实现商业落地
  • arthas基础命令
  • Forward Looking Radar Imaging by Truncated Singular Value Decomposition 论文阅读
  • K8S快速部署
  • CSP-J/S冲奖第18天:真题解析
  • Matlab 汽车主动悬架LQR控制器设计与仿真
  • 使用DeepSeek,优化斐波那契数函数,效果相当不错
  • 什么是有限元力学?分而治之,将复杂问题转化为可计算的数学模型
  • 设计模式-适配器模式
  • 火车站员工迟到,致出站门未及时开启乘客被困?铁路部门致歉
  • 文学花边|对话《借命而生》原著作者石一枫:我给剧打90分
  • 左娅︱悼陈昊
  • 做街坊们的“健康管家”,她把专科护理服务送上门
  • “80后”李灿已任重庆市南川区领导,此前获公示拟提名为副区长人选
  • 75万采购防火墙实为299元路由器?重庆三峡学院发布终止公告:出现违法违规行为