当前位置: 首页 > news >正文

RMSNorm模块

目录

    • 代码
    • 代码解释
      • 1. 初始化方法 `__init__`
      • 2. 前向传播方法 `forward`
      • 3. 总结
      • 4. 使用场景
    • 可视化

代码

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * (
            x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        ).type_as(x)

代码解释

这段代码定义了一个自定义的PyTorch模块 RMSNorm,用于实现Root Mean Square Normalization (RMSNorm)。RMSNorm是一种归一化技术,类似于Layer Normalization,但它只对输入进行缩放,而不进行平移(即没有偏置项)。下面是代码的详细解释:

1. 初始化方法 __init__

def __init__(self, dim: int, eps: float):
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim))
  • dim: int: 输入特征的维度。
  • eps: float: 一个小常数,用于数值稳定性,避免除以零的情况。
  • self.weight: 一个可学习的参数,形状为 (dim,),初始化为全1的张量。这个参数用于对归一化后的输入进行缩放。

2. 前向传播方法 forward

def forward(self, x):
    return self.weight * (
        x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    ).type_as(x)
  • x: 输入张量,形状通常为 (batch_size, ..., dim)
  • x.pow(2): 对输入 x 的每个元素求平方。
  • x.pow(2).mean(-1, keepdim=True): 沿着最后一个维度(即特征维度 dim)计算平方的均值,并保持维度不变。结果形状为 (batch_size, ..., 1)
  • torch.rsqrt(...): 计算均方根的倒数(即1除以平方根),用于归一化。
  • x.float() * torch.rsqrt(...): 将输入 x 转换为浮点数后,乘以均方根的倒数,得到归一化后的结果。
  • .type_as(x): 将结果转换回与输入 x 相同的数据类型。
  • self.weight * (...): 最后,将归一化后的结果乘以可学习的权重 self.weight,进行缩放。

3. 总结

  • RMSNorm 通过对输入进行归一化,使得每个特征的均方根值为1,然后通过可学习的权重进行缩放。
  • 与LayerNorm不同,RMSNorm没有偏置项,只进行缩放操作。
  • eps 用于防止除以零的情况,增加数值稳定性。

4. 使用场景

RMSNorm通常用于深度学习模型中,特别是在Transformer架构中,作为LayerNorm的替代方案。它可以加速训练并提高模型的稳定性。

可视化

dim = 64
eps = 1e-5
m = RMSNorm(dim, eps)
x = torch.randn(32, 10, dim)  # 示例输入 (batch_size, seq_len, dim)


f = "rms_norm.onnx"  # 导出的 ONNX 文件名
torch.onnx.export(m, x, f)  # 模型  # 示例输入

https://netron.app/ 上打开 rms_norm.onnx

在这里插入图片描述

相关文章:

  • SQL-labs13-16闯关记录
  • LeetCode-Hot100-008无重复最长子串
  • 111. 二叉树的最小深度
  • ESP32之Flash操作
  • 数字人分身/123数字人/数字人直播
  • [51 单片机] --串口编程
  • 【华为OD机考】华为OD笔试真题解析(17)--打印文件
  • 2025-03-04 学习记录--C/C++-PTA 习题5-4 使用函数求素数和
  • 手动调整3DTiles倾斜模型的高度、位置、亮度
  • MWC 2025 | 紫光展锐联合移远通信推出全面支持R16特性的5G模组RG620UA-EU
  • HTML label 标签使用
  • 基于微信小程序的心理健康恢复系统+LW示例参考
  • 用DeepSeeker写小说构思 《故事大纲、主线、剧情风格》
  • 无人机遥控器无线传输技术解析!
  • 如何在随机振动分析中包括缓冲器
  • 【MySQL】与MongoDB的区别,字符集,三范式,存储引擎InnoDB、MyISAM
  • 【C++设计模式】第三篇:抽象工厂模式(Abstract Factory)
  • MySQL JOIN 与子查询深度对比:原理、性能陷阱与优化策略
  • 【C++学习篇】智能指针
  • 七、Redis 内存管理详解:模型、优化策略(LRU/LFU、对象共享)
  • 室内设计8年熬不起了/百度seo优化技术
  • 深圳有做网站公司/福州百度seo代理
  • 深圳装饰企业前50强/北京云无限优化
  • 电子商务网站软件建设的核心是/如何联系百度人工客服
  • wordpress主题d8/兰州快速seo整站优化招商
  • txt怎么做网站/3d建模培训学校哪家好