当前位置: 首页 > news >正文

门控线性单元GLU (Gated Linear Unit)

文章目录

    • 门控线性单元GLU (Gated Linear Unit)
      • 函数表达式
      • 与 Swish 的对比
      • PyTorch 中的 GLU 实现
      • TensorFlow 中的 GLU 实现

门控线性单元GLU (Gated Linear Unit)

  • 论文

    https://arxiv.org/abs/1612.08083

  • 门控线性单元(GLU)最初在《Language Modeling with Gated Convolutional Networks》提出,设计灵感来自门控控制,通过引入门控操作来控制信息的流动,它巧妙地将线性变换门控机制结合起来,通过可学习的门控信号来控制信息流,可以看做是引入了一种动态的选择极值,以在模型中选择性地传递信息

  • GLU 的有效性来源于其直观的工作流程:

    1. 双线性变换: 首先,输入 x 会并行地进行两次独立的线性变换。一次生成“候选”输出 (xW+b),这是潜在需要传递的信息。
    2. 门控过滤: 同时,另一个线性变换 (xV+c) 的结果会通过 Sigmoid 函数生成一个介于 0 和 1 之间的门控信号。这个门像一个智能开关,决定了“候选”输出中每个维度的信息应该有多少被保留,有多少应该被抑制。
    3. 残差友好: 由于门控输出的平均值大约为 0.5,这使得 GLU 具有一种天然的“残差”特性,有助于缓解深度网络训练中的梯度消失问题

函数表达式

  • GLU函数
    GLU(x)=(xW+b)⊗σ(xV+c)\begin{aligned} \mathrm{GLU(x)}=(xW+b)\otimesσ(xV+c) \end{aligned} GLU(x)=(xW+b)σ(xV+c)

    其中

    • x∈Rdx \in \mathbb{R}^dxRd 为输入向量
    • W、V∈Rd×dW、V \in \mathbb{R}^{d \times d}WVRd×db、c∈Rdb、c \in \mathbb{R}^dbcRd 为可学习的权重矩阵与偏置向量
    • ⊗\otimes 表示逐元素乘积(哈达玛乘积)
    • σ(⋅)\sigma(\cdot)σ() 为 sigmoid 门控,将后面 xV+cxV+cxV+c 值压缩到 (0, 1) 区间,作为门控信号,决定信息通过比例

与 Swish 的对比

  • 与swish对比

    特性SwishGLU
    参数标量 β\betaβ(固定或可学习)全连接权重 WWW、偏置 bbb(可学习)
    门控方式输入自身经过 sigmoid 缩放输入经线性变换后再经 sigmoid 门控
    参数量每通道 0/1 个标量每通道 d+1d+1d+1 个参数
    计算复杂度低(一次 sigmoid)高(一次矩阵乘 + sigmoid)
    表达能力中等

PyTorch 中的 GLU 实现

  • 代码(以 nn.GLU 为例,针对通道维度切分)

    注意:使用官方的GLU函数,输出维度是减半的

    import torch
    import torch.nn as nntorch.manual_seed(1024)batch_size = 8
    seq_len = 64
    d_model = 512x = torch.randn(batch_size, seq_len, d_model)# 官方 GLU 沿指定维度将输入一分为二
    glu = nn.GLU(dim=-1)          # dim 指定切分维度
    out = glu(x)                  # 输出 [batch_size, seq_len//2, d_model]print("Input shape :", x.shape)
    print("Output shape:", out.shape)"""输出"""
    Input shape : torch.Size([8, 64, 512])
    Output shape: torch.Size([8, 64, 256])
    

    若输入通道维度(seq_len)为偶数,可直接使用 nn.GLU(dim=channel_dim),此时将输入均分两份:前一半做值、后一半做门控。

  • 自定义 GLU(任意线性映射 + 门控)

    注意:输出维度可以

    import torch
    import torch.nn as nn
    torch.manual_seed(1024)class GLU(nn.Module):def __init__(self, d_in, d_out):super().__init__()self.w1 = nn.Linear(d_in, d_out, bias=False)self.w2 = nn.Linear(d_in, d_out, bias=False)self.w3 = nn.Linear(d_out, d_in, bias=False)  # 可选:再投影回 d_indef forward(self, x):# x: [batch_size, seq_len, d_model]gate = torch.sigmoid(self.w2(x))   # [batch_size, seq_len, d_in]out  = self.w1(x) * gate           # [batch_size, seq_len, d_out]return self.w3(out)                # [batch_size, seq_len, d_in]# 使用示例
    batch_size = 8
    seq_len = 64
    d_model = 512
    d_ff = 4 * d_modelx = torch.randn(batch_size, seq_len, d_model)
    layer = GLU(d_in=d_model, d_out=d_ff)
    print(layer(x).shape)   # 维度不变"""输出"""
    torch.Size([8, 64, 512])
    

TensorFlow 中的 GLU 实现

  • 代码(tf.keras 自定义层)

    import tensorflow as tfclass GLU(tf.keras.layers.Layer):"""典型 Transformer-FFN 中的 GLU 层:GLU(x) = (x W_gate) ⊙ σ(x W_up)  再投影回 d_in,维度不变"""def __init__(self, d_in, d_out, **kwargs):super().__init__(**kwargs)self.d_in = d_inself.d_out = d_out# 两路线性映射self.w_gate = tf.keras.layers.Dense(d_out, use_bias=False)self.w_up   = tf.keras.layers.Dense(d_out, use_bias=False)self.w_down = tf.keras.layers.Dense(d_in, use_bias=False)def call(self, x):gate = tf.nn.sigmoid(self.w_gate(x))   # [batch_size, seq_len, d_out]up   = self.w_up(x)                    # [batch_size, seq_len, d_out]return self.w_down(gate * up)          # [batch_size, seq_len, d_in]# 使用示例
    batch_size = 8
    seq_len = 64
    d_model = 512
    d_ff = 4 * d_modelx = tf.random.normal([batch_size, seq_len, d_model])
    glu = GLU(d_in=d_model, d_out=d_ff)
    print(glu(x).shape)   # (4, 64, 512)  维度不变"""输出"""
    (8, 64, 512)
    

http://www.dtcms.com/a/286051.html

相关文章:

  • ApplicationContext 事件发布与监听机制详解
  • 反射机制的登录系统
  • PHP 8.0 升级到 PHP 8.1
  • 创建型模式
  • 基于 HT 的 3D 可视化智慧矿山开发实现
  • 从一开始的网络攻防(四):XSS
  • hadoop(服务器伪分布式搭建)
  • FastAdmin后台登录地址变更原理与手动修改方法-后台入口机制原理解析-优雅草卓伊凡
  • Hadoop安全机制深度剖析:Kerberos认证与HDFS ACL细粒度权限控制
  • 《Web安全之深度学习实战》读书笔记总结
  • AI赋能轮胎安全:基于YOLO11的智能裂纹检测系统
  • 基于springboot+vue+mysql的智慧社区设计与实现(源码+论文+开题报告)
  • Docker Swarm 集群使用记录
  • Matlab打开慢、加载慢的解决办法
  • 免费的一些工具收集
  • 【Oracle】centos7离线静默安装oracle11g(p13390677_112040)
  • Hive 向量化执行引擎 Vectorized Execution 常见 NPE 报错分析及解决
  • 全球天气预报5天(经纬度版)免费API接口教程
  • Python绘制数据(二)
  • JAVA面试宝典 -《微服务治理:从链路追踪到熔断》
  • 某邮生活旋转验证码识别
  • 算法竞赛备赛——【图论】求最短路径——小结
  • 前端之CSS
  • MyBatis之关联查询
  • WEB安全架构
  • Tomcat及Nginx部署使用
  • DevExpress WinForms v25.1 亮点:AI驱动的语义搜索、模板库更新
  • RPC 与 Feign 的区别笔记
  • SQLite 数据库字段类型-详细说明,数据类型详细说明。
  • 服务器mysql数据的简单备份脚本