当前位置: 首页 > news >正文

Gemma2DecoderLayer 解析:Pre-FFW 和 Post-FFW LayerNorm 的作用

Gemma2DecoderLayer 解析:Pre-FFW 和 Post-FFW LayerNorm 的作用

1. 引言

在大规模 Transformer 模型(如 LLaMA、Gemma)中,层归一化(Layer Normalization,LN)是确保模型稳定性和训练收敛性的关键技术。Gemma2DecoderLayer 在前馈网络(Feedforward Network, FFW)部分引入了前归一化(Pre-FFW LayerNorm)后归一化(Post-FFW LayerNorm) 两种方式,以进一步优化训练和推理的稳定性。

本文将深入分析:

  • 什么是 Pre-FFW LayerNormPost-FFW LayerNorm
  • 为什么需要这两种归一化
  • 它们如何影响 Transformer 计算
  • Gemma2DecoderLayer 代码中的具体实现

2. Transformer 层结构

在标准 Transformer 解码器(Decoder Layer)中,每一层由自注意力(Self-Attention)和前馈网络(Feedforward, FFW)组成:

+------------------+
| Self-Attention  |
+------------------+
       ↓
+------------------+
| Feedforward (FFW) |
+------------------+

为了提高训练稳定性,通常会在自注意力和前馈网络的输入或输出处添加层归一化(LayerNorm),以确保分布稳定。

Gemma2DecoderLayer 中,FFW 归一化分为两种:

  1. Pre-FFW LayerNorm(前归一化)
  2. Post-FFW LayerNorm(后归一化)

3. 什么是 Pre-FFW LayerNorm 和 Post-FFW LayerNorm?

3.1 Pre-FFW LayerNorm(前归一化)

前归一化的思想是:在进入前馈网络(FFW)之前先进行层归一化,以稳定输入数据的分布。
Norm ( X ) = X − μ σ + ϵ \text{Norm}(X) = \frac{X - \mu}{\sigma + \epsilon} Norm(X)=σ+ϵXμ
FFW ( Norm ( X ) ) \text{FFW}(\text{Norm}(X)) FFW(Norm(X))
在代码中:

if self.pre_feedforward_layernorm is not None:
    hidden_states = self.pre_feedforward_layernorm(hidden_states)

作用:

  • 使 FFW 层的输入具有更稳定的分布,避免梯度爆炸或梯度消失。
  • Transformer 预归一化结构(Pre-Norm Transformer) 中常用,如 GPT-3 和 LLaMA。

3.2 Post-FFW LayerNorm(后归一化)

后归一化的思路是:在 FFW 计算完成后进行层归一化,确保前馈网络的输出分布稳定。
FFW ( X ) → Norm ( FFW ( X ) ) \text{FFW}(X) \rightarrow \text{Norm}(\text{FFW}(X)) FFW(X)Norm(FFW(X))
在代码中:

if self.post_feedforward_layernorm is not None:
    hidden_states = self.post_feedforward_layernorm(hidden_states)

作用:

  • 让前馈网络输出分布更稳定,使后续层的输入更具一致性。
  • 标准 Transformer(Post-Norm Transformer) 结构中常见,如原始的 BERT。

4. 为什么需要 Pre-FFW 和 Post-FFW?

Transformer 训练过程中,如果 LayerNorm 放置不当,可能会导致:

  • 梯度爆炸或梯度消失
  • 收敛速度变慢
  • 长文本任务中不稳定

4.1 为什么要使用 Pre-FFW LayerNorm?

在大规模 Transformer(如 GPT-3、LLaMA)中,Pre-LN(Pre-Norm Transformer) 比标准 Transformer 更稳定,因为:

  • 标准 Transformer(Post-LN) 在深度增加时,容易出现梯度消失问题,导致训练难以收敛。
  • Pre-LN 先归一化输入,使梯度更稳定,能更快收敛。

4.2 为什么还要 Post-FFW LayerNorm?

有些架构仍然保留 Post-FFW LayerNorm,原因:

  • Post-LN 可以让 FFW 的输出分布更稳定,避免梯度抖动。
  • 在推理阶段,Post-LN 可能有更好的表现,特别是在处理长文本时。

5. 代码解析

5.1 Gemma2DecoderLayer 的 LayerNorm 结构

self.pre_feedforward_layernorm = (
    RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    if config.use_pre_ffw_norm
    else None
)
self.post_feedforward_layernorm = (
    RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    if config.use_post_ffw_norm
    else None
)
  • 如果 use_pre_ffw_norm=True,则启用 Pre-FFW LayerNorm
  • 如果 use_post_ffw_norm=True,则启用 Post-FFW LayerNorm
  • 可以灵活配置是否使用 Pre-LN 或 Post-LN。

5.2 前馈网络部分

residual = hidden_states
if self.pre_feedforward_layernorm is not None:
    hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.post_feedforward_layernorm is not None:
    hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
  • 先检查是否有 Pre-FFW LN,如果有,归一化 hidden_states
  • 进入 MLP 前馈网络
  • 检查是否有 Post-FFW LN,如果有,再归一化 hidden_states
  • 残差连接(Residual Connection) 保持信息流稳定。

6. Pre-LN 和 Post-LN 的对比

归一化方式计算公式优点缺点使用场景
Pre-FFW LayerNormNorm(X) -> FFW(X)稳定梯度,收敛快影响表达能力GPT-3, LLaMA
Post-FFW LayerNormFFW(X) -> Norm(X)保持分布一致性可能导致深度梯度消失BERT, T5

7. 总结

  • Transformer 计算中 LayerNorm 影响模型稳定性和训练收敛速度。
  • Pre-FFW LayerNorm 先归一化输入,适用于深层网络,避免梯度消失(Pre-Norm Transformer)。
  • Post-FFW LayerNorm 归一化输出,保持输出分布稳定,适用于推理任务
  • Gemma2DecoderLayer 结合 Pre-FFW 和 Post-FFW,提供更灵活的归一化方式,可以根据不同任务需求调整归一化策略。

🚀 理解 LayerNorm 的作用,对于优化 Transformer 训练至关重要!

附录

源代码:

class Gemma2DecoderLayer(nn.Module):
    def __init__(
        self,
        config: gemma_config.GemmaConfig,
        attn_type: gemma_config.AttentionType,
    ):
        super().__init__()
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            attn_logit_softcapping=config.attn_logit_softcapping,
            query_pre_attn_scalar=config.query_pre_attn_scalar,
            head_dim=config.head_dim,
            quant=config.quant,
            attn_type=attn_type,
            sliding_window_size=config.sliding_window_size,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = (
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if config.use_pre_ffw_norm
            else None
        )
        self.post_feedforward_layernorm = (
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if config.use_post_ffw_norm
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        if self.pre_feedforward_layernorm is not None:
            hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if self.post_feedforward_layernorm is not None:
            hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

后记

2025年2月24日16点26分于上海,在GPT 4o大模型辅助下完成。

相关文章:

  • 【论文笔记-ECCV 2024】AnyControl:使用文本到图像生成的多功能控件创建您的艺术作品
  • VSCode+PlatformIO报错 找不到头文件
  • Zabbix告警分析新纪元:本地DeepSeek大模型实现智能化告警分析
  • 深度学习-133-LangGraph之应用实例(二)使用面向过程和面向对象的两种编程方式构建带记忆的聊天机器人
  • C#问题解决方案 --- 生成软件hash,生成文件hash
  • git merge -s ours ...的使用方法
  • 数据安全_笔记系列10:数据分类分级与保护策略详解
  • threejs:射线拾取封装
  • 计算机毕业设计 ——jspssm518Springboot 的影视影院订票选座管理系统
  • unity使用PICO Neo3开发,XR环境配置
  • 异常(2)
  • Java高频面试之SE-23
  • 27.[前端开发-JavaScript基础]Day04-函数基本使用-递归-变量作用域-函数式编程
  • 结构型模式 - 代理模式 (Proxy Pattern)
  • 利用python进行数据分析(重点、易忘点)---第八章数据规整:聚合、合并和重塑
  • Linux查看和处理文件内容
  • 【网络编程】网络套接字和使用案例
  • 数学与计算生物学:生物系统的数学建模
  • vs code默认主题修改配置
  • 大白话JavaScript如何深拷贝一个对象或数组?JSON.parse (JSON.stringify ()) 这种方法有什么局限性?
  • 网站建设百度推广开户/seo关键词排名优化软件
  • 政府行业网站建设方案/网站产品怎么优化
  • 做公众号首图的设计网站/seo去哪学
  • 旅游 网站建设目标/广告服务平台
  • 网站建设具体需求/免费b站推广网站
  • pjax wordpress/软件网站关键词优化