当前位置: 首页 > news >正文

position embedding

文章目录

  • 1. 四种position embedding
  • 2. pytorch 源码[后续整理]

【因比较忙,后续整理】

1. 四种position embedding

Position Embedding
1. Transformer
1.1 1d absolute
1.2 sin/cos constant
1.3
2. Vision Transformer
2.1 1d absolute
2.2 trainable
3. Swin Transformer
3.1 2d relative bias
3.2 trainable
4. Masked AutoEncoder
4.1 2d absolute
4.2 sin/cos constant

2. pytorch 源码[后续整理]

import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False)


# ---------------------------------------------------------------------------------
# transformer constant sin/cos embedding position
def create_1d_absolute_sincos_embeddings(n_pos_vec, dim):
    # n_pos_vec : torch.arange(n_pos,dtype=torch.float)
    assert dim % 2 == 0, "wrong dimension"
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)

    omega = torch.arange(dim // 2, dtype=torch.float)
    omega /= dim / 2.0
    omega = 1.0 / (10000 ** omega)

    out = n_pos_vec[:, None] @ omega[None, :]

    emb_sin = torch.sin(out)
    emb_cos = torch.cos(out)

    position_embedding[:, 0::2] = emb_sin
    position_embedding[:, 1::2] = emb_cos

    return position_embedding


# ---------------------------------------------------------------------------------

# ---------------------------------------------------------------------------------
# 2. 1d absolute trainable embedding
def create_1d_absolute_trainable_embeddings(n_pos_vec, dim):
    # n_pos_vec : torch.arange(n_pos,dtype=torch.float)
    position_embedding = nn.Embedding(n_pos_vec.numel(), dim)
    nn.init.constant_(position_embedding.weight, 0.0)
    return position_embedding


# 3. 2d relative bias trainable embedding
def create_2d_relative_bias_trainable_embeddings(n_heads, height, width, dim):
    # width=5,-->torch.arange(5)=[0,1,2,3,4]--> bias=[-4,-3,-2,-1,0,1,2,3,4]=2*width-1
    # width=5,-->torch.arange(5)=[0,1,2,3,4]--> bias=[-4,-3,-2,-1,0,1,2,3,4]=2*width-1
    ps_height = (2 * height - 1) * (2 * width - 1)
    ps_width = n_heads
    position_embedding = nn.Embedding(ps_height, ps_width)
    nn.init.constant_(position_embedding.weight, 0.0)

    def get_relative_position_index(height, width):
        coords = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width)))  # [2,height,width]
        coords_flatten = torch.flatten(coords, 1)  # [2,height*width]
        relative_coords_bias = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2,height*width,height*width]
        relative_coords_bias[0, :, :] += height - 1
        relative_coords_bias[1, :, :] += width - 1

        # A:2d,B:1d,B[i*cos+j] = A[i,j]
        relative_coords_bias[0, :, :] *= relative_coords_bias[1, :, :].max() + 1
        return relative_coords_bias.sum(0)  # [height*width,height*width]

    relative_position_bias = get_relative_position_index(height, width)
    bias_embedding = position_embedding(torch.flatten(relative_position_bias)).reshape(height * width, height * width,
                                                                                       n_heads)
    bias_embedding = bias_embedding.permute(2, 0, 1).unsqueeze(0)
    return bias_embedding


if __name__ == "__main__":
    run_code = 0
    n_pos = 4
    dim = 4
    n_pos_vec = torch.arange(n_pos, dtype=torch.float)
    pe = create_1d_absolute_sincos_embeddings(n_pos_vec, dim)
    print(f"pe=\n{pe}")
    my_n_heads = 3
    my_height = 4
    my_width = 5
    my_dim = 6

    result = create_2d_relative_bias_trainable_embeddings(n_heads=my_n_heads, height=my_height, width=my_width,
                                                          dim=my_dim)
    print(f"result=\n{result}")

相关文章:

  • 【测试报告】论坛系统
  • 语言解码双生花:人类经验与AI算法的镜像之旅
  • 树状数组模板
  • 【redis】哨兵:搭建主从/哨兵节点详解和细节
  • 【WebGIS教程2】Web服务与地理空间服务解析
  • Java:JDK8 新特性:Lambda表达式
  • Vulnhub-Thales通关攻略
  • 第30周Java分布式入门 ThreadLocal
  • 无法打开... .exe进行写入 解决方法
  • vue中defineModel简化defineProps和defineEmits的用法
  • KofamKOALA:KEGG本地化注释
  • 无线安灯按钮盒汽车零部件工厂的故障告警与人员调度专家
  • 【干货,实战经验】nginx缓存问题
  • 程序员英语口语练习笔记
  • python dict转换成json格式
  • 深入解析Flink Kafka Connector的分布式流数据采集架构与底层实现
  • 2025最新版Ubuntu Server版本Ubuntu 24.04.2 LTS下载与安装-详细教程,细致到每一步都有说明
  • SAP 获取RFC的WSDL文件
  • react项目中当组件渲染的时候如何执行接口
  • 侯捷 C++ 课程学习笔记:现代 C++ 中的移动语义与完美转发深度解析
  • 昆山公司网站建设/重庆seo推广运营
  • 网站开发时什么时间适合创建视图/企业网站建设报价表
  • 企业网站建设排名官网/建站平台哪个比较权威
  • 太原网站建设公司招聘/网站设计公司排行
  • 搭建正规网站/公司网站建设服务机构
  • 深圳营销培训班/丁的老头seo博客