当前位置: 首页 > news >正文

Diffusion模型中时间t嵌入的方法

Diffusion模型中时间t嵌入的方法

class PositionalEmbedding(nn.Module):
    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # x * self.scale和emb外积
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

我们用 dim=128x=[10, 12, 16, 100] 来具体计算 PositionalEmbedding 的输出。


1. 设定参数

  • dim=128,意味着嵌入向量的维度是 128。
  • half_dim = dim // 2 = 64,所以我们需要计算 64 个频率因子的正弦和余弦值。
  • x = [10, 12, 16, 100] 是输入值。

2. 计算频率因子

emb = math.log(10000) / half_dim  # 计算缩放因子
emb = torch.exp(torch.arange(half_dim) * -emb)  # 生成 64 维的指数频率因子
  • math.log(10000) ≈ 9.2103
  • emb = torch.exp(torch.arange(64) * (-9.2103 / 64))
  • torch.arange(64) 生成 [0, 1, 2, ..., 63],然后乘以 -emb,再计算指数 exp,得到 64 个递减的频率因子。

3. 计算外积

emb = torch.outer(x * self.scale, emb)
  • 计算 x * self.scale,如果 scale=1.0,那么 x 仍然是 [10, 12, 16, 100]
  • emb 是一个 4 × 64 的矩阵,每一行表示 x[i] 乘以 emb 里的每个频率因子。

假设 emb(频率因子)前 5 个数是:

[1.0000, 0.9120, 0.8318, 0.7586, 0.6918, ...]

那么 x=10 这一行计算结果是:

[10.0000, 9.1200, 8.3180, 7.5860, 6.9180, ...]

4. 计算正弦和余弦

emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
  • 先对 embsin,然后取 cos,最后拼接,得到 4 × 128 的矩阵。

如果 sin(10.0000) ≈ -0.5440cos(10.0000) ≈ -0.8391,那么 x=10 这一行最终变成:

[-0.5440, 0.4120, 0.9890, 0.9870, -0.9912, ... | -0.8391, 0.9111, -0.1479, 0.1603, -0.1321, ...]

其中,前 64 维是 sin 计算结果,后 64 维是 cos 计算结果。


5. 最终输出

如果 x = [10, 12, 16, 100],输出 emb4 × 128 的矩阵:

tensor([
    [-0.5440,  0.4120,  0.9890,  0.9870, -0.9912, ..., -0.8391,  0.9111, -0.1479,  0.1603, -0.1321, ...],
    [-0.5366,  0.4576,  0.9941,  0.9891, -0.9954, ..., -0.8437,  0.9005, -0.1085,  0.1521, -0.1242, ...],
    [-0.5215,  0.5321,  0.9971,  0.9922, -0.9986, ..., -0.8524,  0.8804, -0.0563,  0.1423, -0.1113, ...],
    [-0.5064,  0.8658, -0.9813,  0.9989,  0.9924, ..., -0.8849,  0.7912,  0.1951,  0.0234, -0.9811, ...],
])
  • 每一行对应输入 x 的一个数的 128 维位置编码。
  • 其中前 64 维是 sin(x * 频率),后 64 维是 cos(x * 频率)
  • x=100 时,周期性更明显,因为 sincos 是周期函数,大的 x 会导致编码的模式周期性更强。

6. 总结

  • 这个位置编码会为 x 生成一个 128 维的向量,每个维度都由 sincos 计算得到。
  • x 变大时,周期性更明显。
  • 适用于 Transformer 或其他模型,以在输入数据中添加位置信息,使模型能够区分不同位置的输入数据。

相关文章:

  • Teaching Small Language Models Reasoning throughCounterfactual Distillation
  • Hive-07之企业级调优
  • Docker部署MySQL
  • Python:简单的爬虫程序,从web页面爬取图片与标题并保存MySQL
  • 大模型核心要素完全解析:从数字神经元到智能对话的奥秘
  • go语言数据类型
  • (50)[HGAME 2023 week2]before_main
  • AutoGen学习笔记系列(六)Tutorial - Termination
  • unittest框架 核心知识的系统复习及与pytest的对比
  • uniapp x 学习之 uts 语言快速入门
  • 【Embedding】何为Embedding?
  • 筑牢网络安全防线:守护您的数据安全
  • 单体架构、集群、分布式、微服务的区别!
  • Redis设计与实现-数据结构
  • Selenium遇到Exception自动截图
  • 【大模型学习】第八章 深入理解机器学习技术细节
  • 【前端】【vue-i18n】安装和使用全解
  • Redis Stream
  • Ubuntu20.04 在离线机器上安装 NVIDIA Container Toolkit
  • [项目]基于FreeRTOS的STM32四轴飞行器: 三.电源控制
  • 国税总局上海市税务局回应刘晓庆被举报涉嫌偷漏税:正依法依规办理
  • 缅甸内观冥想的历史漂流:从“人民鸦片”到东方灵修
  • 为何选择上海?两家外企提到营商环境、人才资源……
  • 沙青青评《通勤梦魇》︱“人机组合”的通勤之路
  • 专访|茸主:杀回UFC,只为给自己一个交代
  • 日本广岛大学一处拆迁工地发现疑似未爆弹