当前位置: 首页 > wzjs >正文

网站进度条网站制作推广电话

网站进度条,网站制作推广电话,web界面设计,台州做网站优化目录 代码代码解释函数 1: precompute_pos_cis代码解析 函数 2: apply_rotary_emb1. unite_shape(pos_cis, x):调整 pos_cis 形状2. 复数化 xq 和 xk3. 应用旋转变换4. 返回结果 总结1. precompute_pos_cis(dim, end, theta)2. apply_rotary_emb(xq, xk, pos_cis) 示…

目录

    • 代码
    • 代码解释
    • 函数 1: `precompute_pos_cis`
      • 代码解析
    • 函数 2: `apply_rotary_emb`
      • 1. `unite_shape(pos_cis, x)`:调整 `pos_cis` 形状
      • 2. 复数化 `xq` 和 `xk`
      • 3. 应用旋转变换
      • 4. 返回结果
    • 总结
      • 1. `precompute_pos_cis(dim, end, theta)`
      • 2. `apply_rotary_emb(xq, xk, pos_cis)`
    • 示例
    • 广播机制的作用
      • 广播规则
        • 广播细节
      • 如果形状不兼容
      • 总结

代码

def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))t = torch.arange(end, device=freqs.device)  # type: ignorefreqs = torch.outer(t, freqs).float()  # type: ignorepos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64return pos_cisdef apply_rotary_emb(xq, xk, pos_cis):def unite_shape(pos_cis, x):ndim = x.ndimassert 0 <= 1 < ndimassert pos_cis.shape == (x.shape[1], x.shape[-1])shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]return pos_cis.view(*shape)xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))pos_cis = unite_shape(pos_cis, xq_)xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)return xq_out.type_as(xq), xk_out.type_as(xk)

代码解释

这段代码实现了 旋转位置编码(Rotary Position Embedding, RoPE),用于在 Transformer 自注意力机制中引入相对位置信息,提升模型的捕捉能力。下面详细拆解代码的核心逻辑。


函数 1: precompute_pos_cis

作用:预计算 旋转位置编码的相位因子,即 复数形式的角度信息

def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device)  # type: ignore freqs = torch.outer(t, freqs).float()  # type: ignore pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 return pos_cis 

代码解析

  1. 计算角频率 (freqs)

    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    
    • dim 是输入的特征维度(比如 dim=16)。
    • torch.arange(0, dim, 2)[: (dim // 2)] 生成 [0, 2, 4, ..., dim-2] 的索引,然后取前 dim//2 个。
    • 计算 freqs = 1.0 / (theta ** (index / dim)),即 指数衰减的频率(类似于 Sinusoidal PE)。
  2. 构造位置 t

    t = torch.arange(end, device=freqs.device)
    
    • t 是一个长度为 end(默认 32K)的整数序列 [0, 1, 2, ..., end-1],表示序列中的位置索引。
  3. 计算 tfreqs 的外积

    freqs = torch.outer(t, freqs).float()
    
    • torch.outer(t, freqs) 计算 时间 t 与角频率 freqs 的外积,得到一个矩阵:
      在这里插入图片描述

      这个矩阵表示每个位置 t_i 对应的角度。

  4. 转换为极坐标的复数表示

    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    
    • torch.polar(r, θ) 创建复数:
      在这里插入图片描述

    • pos_cis复数编码的旋转因子,后续会与查询 (xq) 和键 (xk) 相乘,实现旋转变换。


函数 2: apply_rotary_emb

作用:将预计算的旋转嵌入 pos_cis 应用于 xq(查询)和 xk(键),完成 旋转位置编码的实际应用

def apply_rotary_emb(xq, xk, pos_cis): def unite_shape(pos_cis, x): ndim = x.ndim assert 0 <= 1 < ndim assert pos_cis.shape == (x.shape[1], x.shape[-1]) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return pos_cis.view(*shape) 

1. unite_shape(pos_cis, x):调整 pos_cis 形状

作用:将 pos_cis 形状变换,以匹配 xq 的形状,方便广播计算。

  • xq.shape = (B, L, D)(批量、序列长度、特征维度)
  • pos_cis.shape = (L, D)
  • 目标形状:(1, L, D) 适配 xq
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
  • 只在 序列维 (L) 和最后一维 (D) 维持原尺寸,其余地方设置为 1,保证广播机制正确执行。

2. 复数化 xqxk

xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
  • 输入 xq, xk 形状: (B, L, D)
  • reshape(*xq.shape[:-1], -1, 2) 变成 (B, L, D/2, 2),把 D 维度拆成 两个部分,对应于 复数的实部和虚部
  • torch.view_as_complex() 将最后一维 (2,) 转换为 复数格式 (B, L, D/2)

3. 应用旋转变换

pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) 
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) 
  • xq_ * pos_cis 执行复数乘法,相当于对 xq_ 旋转 pos_cis 指定的角度。
  • torch.view_as_real() 将复数转换回实数 (B, L, D/2, 2)
  • .flatten(3) 重新恢复 (B, L, D) 形状。

4. 返回结果

return xq_out.type_as(xq), xk_out.type_as(xk)
  • 确保返回数据类型和 xq, xk 一致。

总结

1. precompute_pos_cis(dim, end, theta)

  • 生成旋转角度信息,以 复数 形式存储位置编码。
  • 公式
    [
    e^{i\theta} = \cos(\theta) + i\sin(\theta)
    ]

2. apply_rotary_emb(xq, xk, pos_cis)

  • 将旋转编码应用到查询 xq 和键 xk
  • 具体步骤:
    1. 调整 pos_cis 形状 适配 xq
    2. 转换 xq, xk 为复数(实部 + 虚部)。
    3. 执行复数乘法(旋转操作)。
    4. 转换回实数,恢复 (B, L, D) 形状。

最终,这个过程增强了注意力机制,使其具有相对位置编码能力,提升了 Transformer 在长序列建模中的表现。


示例

import torch# 设置超参数
dim = 16  # 特征维度
seq_len = 10  # 序列长度
batch_size = 2  # 批量大小
head_dim = dim // 2  # 假设 head_dim = dim / 2# 预计算旋转嵌入
pos_cis = precompute_pos_cis(dim, end=seq_len)# 生成模拟的 xq 和 xk
xq = torch.randn(batch_size, seq_len, dim)  # (B, L, D)
xk = torch.randn(batch_size, seq_len, dim)  # (B, L, D)# 打印原始维度
print(f"原始 xq 维度: {xq.shape}")
print(f"原始 xk 维度: {xk.shape}")
print(f"预计算 pos_cis 维度: {pos_cis.shape}")
原始 xq 维度: torch.Size([2, 10, 16])
原始 xk 维度: torch.Size([2, 10, 16])
预计算 pos_cis 维度: torch.Size([10, 8])
# 应用旋转嵌入
xq_out, xk_out = apply_rotary_emb(xq, xk, pos_cis)# 打印转换后的维度
print(f"转换后 xq 维度: {xq_out.shape}")
print(f"转换后 xk 维度: {xk_out.shape}")
转换后 xq 维度: torch.Size([2, 10, 8, 2])
转换后 xk 维度: torch.Size([2, 10, 8, 2])

广播机制的作用

xq_ = torch.randn(2, 10,16) 
pos_cis= torch.randn(1, 10,16) 
print((xq_ * pos_cis).shape)
torch.Size([2, 10, 16])

广播规则

当进行 xq_ * pos_cis 这样的逐元素乘法时,PyTorch 会根据广播规则(broadcasting rules)来扩展 pos_cis 使其匹配 xq_ 的形状。

广播细节
  1. xq_ 形状是 (2, 10, 16)
  2. pos_cis 形状是 (1, 10, 16)

根据广播规则:

  • 1 可以扩展成 2(沿 batch 维度),变成 (2, 10, 16)
  • 1016 维度匹配,不需要改变

所以 pos_cis 自动扩展(2, 10, 16),然后进行逐元素相乘,最终结果形状仍然是 (2, 10, 16)

如果形状不兼容

如果 pos_cis(10, 16),它的 shape 实际上是 (10, 16) → (1, 10, 16),仍然可以广播。
但是如果 pos_cis(10, 1), 它会变成 (1, 10, 1),那么广播到 (2, 10, 16)1 会扩展到 16,仍然能进行逐元素运算。
如果 pos_cis(3, 10, 16),和 xq_2 不匹配,就会报错。

总结

pos_cis 形状 (1, 10, 16) 通过广播扩展到 (2, 10, 16),因此 xq_ * pos_cis 的结果仍然是 (2, 10, 16)

http://www.dtcms.com/wzjs/288225.html

相关文章:

  • 网站被篡改处理如何做网站推广的策略
  • 新塘17网站一起做网店官网软文的概念是什么
  • 专业3合1网站建设公司如何开发网站
  • 深圳做营销网站公司网站快速排名优化哪家好
  • 如何实现网站的纯静态化互联网销售包括哪些
  • 网站建设与运营推广的回报材料今日头条新闻
  • 海外网站测速广点通广告投放平台登录
  • 聚商网络营销公司服务内容seo免费诊断联系方式
  • 做网站优化有什么方法windows优化大师好不好
  • 视觉差网站制作百度投诉中心电话
  • 视频直播系统源码十堰seo排名公司
  • 全网营销型网站建设模板seo顾问阿亮
  • 做下载网站好不好做汕头seo全网营销
  • 深圳民治做网站腾讯企业qq
  • 江阴响应式网站开发公关公司一般收费标准
  • 平台网站如何做推广方案设计郴州网络推广外包公司
  • 邢台网站优化定制如何策划一个营销方案
  • 成都网站建设多少费用网络营销的六大特征
  • 华艺网站开发微信朋友圈广告投放
  • 网站统计工具有哪些谷歌外贸平台推广需要多少钱
  • 做网站的图哪来百度文库官网首页
  • 做汽配的都上什么网站营销技巧和营销方法
  • 熊岳网站怎么做自己做网络推广怎么做
  • 帮别人做网站市场价石家庄新闻网头条新闻
  • 建设网站需要什么设施建网站seo
  • 做内贸哪个网站找客户软文推广300字
  • 网站开发合作协议合同范本福建百度开户
  • 重庆观音桥房价东莞关键词优化推广
  • 电子商务网站建设与实践5118关键词工具
  • 智能建站系统哪个好百度seo排名优化如何