当前位置: 首页 > news >正文

【AutoFormer 源码理解】

时间嵌入选择代码解析

这行代码根据embed_type参数动态选择使用哪种时间嵌入方法,是一个条件赋值语句(三元表达式):

self.temporal_embedding = TemporalEmbedding(...) if embed_type != 'timeF' else TimeFeatureEmbedding(...)

解析

  • 条件判断: embed_type != 'timeF'

    • 检查嵌入类型是否不等于’timeF’
  • 如果条件为真(embed_type不是’timeF’):

    • 使用TemporalEmbedding类,这是一种基于查找表的离散时间特征嵌入
    • 适用于将时间信息(月份、日期、星期几等)作为离散类别处理
  • 如果条件为假(embed_type是’timeF’):

    • 使用TimeFeatureEmbedding类,这是一种基于线性映射的连续时间特征嵌入
    • 适用于将时间特征作为连续数值处理

两种嵌入方法的区别

  1. TemporalEmbedding:

    • 使用查找表(Embedding表)存储每个时间值的嵌入向量
    • 将时间视为离散类别(如第1月、第2月…)
    • 输出是各时间特征嵌入的和
  2. TimeFeatureEmbedding:

    • 使用线性层(nn.Linear)直接映射时间特征
    • 将时间视为连续特征
    • 输入是归一化的时间特征,如一年中的相对位置

这种设计允许模型根据不同的数据特性和实验需求灵活选择最适合的时间表示方法。

TimeFeatureEmbedding 类解析

TimeFeatureEmbedding 类是一个专为连续时间特征设计的嵌入层,与离散类别的 TemporalEmbedding 不同。

形状和操作分析

初始化

def __init__(self, d_model, embed_type='timeF', freq='h'):
    # freq_map定义不同数据频率下使用的时间特征维度
    freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
    # 根据频率确定输入维度
    d_inp = freq_map[freq]
    # 创建无偏置的线性层将d_inp维度映射到d_model维度
    self.embed = nn.Linear(d_inp, d_model, bias=False)

前向传播

def forward(self, x):
    # 输入x形状: [B, L, d_inp] - B是批次大小, L是序列长度, d_inp是时间特征数量
    # 线性变换后输出形状: [B, L, d_model]
    return self.embed(x)

关键特点

  1. 连续特征映射:

    • 直接对时间特征进行线性变换,而不是像TemporalEmbedding那样进行查表操作
    • 适用于连续的、已归一化的时间特征
  2. 频率相关输入维度:

    • 根据不同的时间序列频率(freq)确定输入维度
    • 例如,小时级数据使用4个特征,分钟级使用5个特征
  3. 形状转换:

    • 输入: [B, L, d_inp]
    • 线性映射: W·x 其中 W 是形状为 [d_inp, d_model] 的权重矩阵
    • 输出: [B, L, d_model]

这种设计使模型可以直接处理连续的时间特征编码,比如周期性的正弦/余弦表示,而不需要将时间离散化为类别。

http://www.dtcms.com/a/73936.html

相关文章:

  • 从“自习室令牌”到线程同步:探秘锁与条件变量
  • 基于Python的tkinter开发的一个工具,解析图片文件名并将数据自动化导出为Excel文件
  • 深度学习pytorch笔记:TCN
  • 从零开始使用 **Taki + Node.js** 实现动态网页转静态网站的完整代码方案
  • 谈谈 TypeScript 中的联合类型(union types)和交叉类型(intersection types),它们的应用场景是什么?
  • 代码随想录算法训练营第34天 | 62.不同路径 63. 不同路径 II 整数拆分 不同的二叉搜索树 (跳过)
  • linux(centos8)下编译ffmpeg
  • HCIA-PPP
  • 每天五分钟深度学习PyTorch:循环神经网络RNN的计算以及维度信息
  • 大数据 Spark 技术简介
  • TLSR8355F128芯片特色解析
  • Linux中的epoll简单使用案例
  • 视频转音频, 音频转文字
  • 通过socket实现文件上传和下载功能
  • 信息系统运行管理员教程5--信息系统数据资源维护
  • PAT甲级(Advanced Level) Practice 1023 Have Fun with Numbers
  • LeetCode 1005. K 次取反后最大化的数组和 java题解
  • C语言 —— 此去经年梦浪荡魂音 - 深入理解指针(卷二)
  • SpringBoot3+Druid+MybatisPlus多数据源支持,通过@DS注解配置Service/Mapper/Entity使用什么数据源
  • Windows11 新机开荒(二)电脑优化设置
  • C++ 类和对象 友元 内部类 this指针 默认成员函数 初始化列表……
  • Pandas DataFrame:数据分析的利器
  • 14 结构体
  • WebSocket和长轮询
  • 【操作系统】Ch6 文件系统
  • 【最后203篇系列】015 几种消息队列的思考
  • ORA-00600错误的深度剖析:如何避免与解决?
  • 蓝桥杯宝石,考察数学。考察公式推导能力
  • 设计模式(行为型)-命令模式
  • 【MySQL】MySQL数据存储机制之存储引擎