当前位置: 首页 > news >正文

LSTM、GRU 与 Transformer网络模型参数计算

参数计算公式对比

模型类型参数计算公式关键组成部分
LSTM4 × (embed_dim × hidden_size + hidden_size² + hidden_size)4个门控结构
GRU3 × (embed_dim × hidden_size + hidden_size² + hidden_size)3个门控结构
Transformer (Encoder)12 × embed_dim² + 9 × embed_dim × ff_dim + 14 × embed_dim多头注意力 + FFN
Transformer (Decoder)14 × embed_dim² + 9 × embed_dim × ff_dim + 15 × embed_dim多头注意力 + FFN + 掩码注意力

详细参数计算解析

1. LSTM 参数计算

LSTM 单元包含 4 个门控结构(输入门、遗忘门、候选单元、输出门)

Python

LSTM_params = 4 × (input_size × hidden_size +   # Wi, Wf, Wc, Wohidden_size × hidden_size +  # Ui, Uf, Uc, Uohidden_size)                 # bi, bf, bc, bo

简化公式LSTM_params ≈ 4 × hidden_size × (input_size + hidden_size + 1)

2. GRU 参数计算

GRU 单元包含 3 个门控结构(更新门、重置门、候选门)

GRU_params = 3 × (input_size × hidden_size +   # Wz, Wr, Whhidden_size × hidden_size +   # Uz, Ur, Uhhidden_size)                 # bz, br, bh

简化公式GRU_params ≈ 3 × hidden_size × (input_size + hidden_size + 1)

3. Transformer 参数计算

Transformer 由多层堆叠,每层包含:

  • 多头注意力机制(Multi-Head Attention)
  • 前馈神经网络(Feed-Forward Network)
  • 层归一化(LayerNorm)
  • 残差连接(Skip Connections)
单层参数分解:
# 多头注意力层
QKV_proj = 3 × embed_dim × embed_dim  # Wq, Wk, Wv
output_proj = embed_dim × embed_dim   # Wo
attention_params = 4 × embed_dim²# 前馈神经网络
FFN_params = 2 × (embed_dim × ff_dim + ff_dim × embed_dim) + (ff_dim + embed_dim)= 2 × embed_dim × ff_dim + 2 × ff_dim × embed_dim + ff_dim + embed_dim= 4 × embed_dim × ff_dim + ff_dim + embed_dim# 层归一化 (2个)
LayerNorm_params = 2 × 2 × embed_dim  # 每个LN有gamma和beta参数# 总单层参数
Encoder_layer = attention_params + FFN_params + LayerNorm_params= 4×embed_dim² + (4×embed_dim×ff_dim + ff_dim + embed_dim) + 4×embed_dim

完整 Transformer 参数公式

对于 N 层 Transformer:

其中:

  • d = embed_dim (嵌入维度)
  • d_ff = ff_dim (前馈网络隐藏层维度)
  • Embedding = vocab_size × embed_dim (词嵌入参数)

参数对比示例

假设配置:

  • 嵌入维度 (embed_dim) = 512
  • 隐藏层维度 (hidden_size) = 512
  • FFN 维度 (ff_dim) = 2048
  • 词表大小 (vocab_size) = 50000
  • LSTM/GRU 层数 = 1
  • Transformer 层数 = 6

参数计算结果:

模型参数计算总量占比
LSTM4 × (512×512 + 512² + 512) = 4×(262,144 + 262,144 + 512) = 2,100,3522.10M基准
GRU3 × (512×512 + 512² + 512) = 3×(262,144 + 262,144 + 512) = 1,574,4001.57M75%
Transformer Encoder6×(4×512² + 4×512×2048 + 2048 + 5×512) + 50000×512 = 6×(1,048,576 + 4,194,304 + 2048 + 2,560) + 25,600,000 = 6×5,247,488 + 25,600,000 = **57,084,928**57.1M27.2倍
Embedding层50000×512 = 25,600,00025.6M-

参数计算工具函数

def calculate_params(model_type, embed_dim, hidden_size=None, ff_dim=None, num_layers=1, vocab_size=None):params = 0if model_type == "LSTM":# LSTM参数计算params = 4 * (embed_dim * hidden_size + hidden_size**2 + hidden_size)elif model_type == "GRU":# GRU参数计算params = 3 * (embed_dim * hidden_size + hidden_size**2 + hidden_size)elif model_type == "Transformer-Encoder":# Transformer编码器参数计算per_layer = (4 * embed_dim**2) + (4 * embed_dim * ff_dim) + ff_dim + (5 * embed_dim)encoder_params = num_layers * per_layerembedding_params = vocab_size * embed_dimparams = encoder_params + embedding_paramselif model_type == "Transformer-Decoder":# Transformer解码器参数计算per_layer = (8 * embed_dim**2) + (4 * embed_dim * ff_dim) + ff_dim + (6 * embed_dim)decoder_params = num_layers * per_layerembedding_params = vocab_size * embed_dimparams = decoder_params + embedding_paramsreturn params# 示例使用
lstm_params = calculate_params("LSTM", embed_dim=512, hidden_size=512)
transformer_params = calculate_params("Transformer-Encoder", embed_dim=512, ff_dim=2048, num_layers=6, vocab_size=50000)

相关文章:

  • 1931. 用三种不同颜色为网格涂色
  • Spring Boot 集成 Apache Kafka 实战指南
  • Java面试复习:基础、并发、JVM及框架核心考点解析
  • 云零售新中枢:定制化“开源AI智能名片+S2B2C商城小程序”驱动的沉浸式触点进化论
  • Restormer: Efficient Transformer for High-Resolution Image Restoration 论文阅读
  • ✨通义万相2.1深度解析:AI视频生成引擎FLF2V-14B全流程指南(命令行参数+模型架构+数据流)
  • systemd[1]: Failed to start LSB: Bring up/down networking
  • [Java恶补day31] 21. 合并两个有序链表
  • 设计模式 | 单例模式
  • SpringBoot中使用表单数据有效性检验
  • FPGA基础 -- Verilog 命令行参数变量(Command-line Plusargs)
  • HarmonyOS应用开发——层叠布局
  • 2025高考志愿填报核心策略
  • 使用SRv6访问5G网络中的边缘应用
  • 【原创】【3】【辅助工具】基于视觉模型+FFmpeg+MoviePy实现短视频自动化二次编辑+多赛道
  • DAY33神经网络
  • /var/lib/docker/overlay2目录过大怎么办
  • c++IO类
  • Frida Hook Android App 点击事件实战指南:从进程识别到成功注入
  • 程序包androidx.fragment.app不存在 import androidx.fragment.app
  • 购物网站 购物车界面如何做/推广赚钱app
  • 南宁住房和城乡建设委员会网站/seo综合查询平台官网
  • 做视频网站需要什么条件/考研培训班集训营
  • 芜湖学校网站建设电话/最新的疫情最新消息
  • 房地产企业网站建设/微信营销推广软件
  • 郑州做网站找赢博科技/优化网站标题和描述的方法