当前位置: 首页 > news >正文

大模型参数量与计算量(FLOPs)估算方法

一、 模型参数量 (Parameters)

参数量指的是模型中所有需要学习和更新的权重 (Weights)偏置 (Biases) 的总数。它决定了模型的大小(例如存储为 .pt 文件的大小)。

1. 核心计算方法

总参数量 = 所有层参数量之和

常见层的参数量计算:

层类型计算公式解释
全连接层 (Linear / Dense)Pfc=(I+1)×OP_{fc} = (I + 1) \times OPfc=(I+1)×OI: 输入特征数,O: 输出特征数。+1 是偏置项 (bias)。
卷积层 (Conv2d)Pconv=(Kw×Kh×Cin+1)×CoutP_{conv} = (K_w \times K_h \times C_{in} + 1) \times C_{out}Pconv=(Kw×Kh×Cin+1)×CoutK_w, K_h: 卷积核宽高。
C_in: 输入通道数。
C_out: 输出通道数。
+1 是每个输出通道的偏置。
层归一化 (LayerNorm)Pln=2×CP_{ln} = 2 \times CPln=2×CC: 归一化的特征维度。对应可学习的 gammabeta 参数。
多头自注意力 (MHSA)Pmhsa=4×(E×E)×HP_{mhsa} = 4 \times (E \times E) \times HPmhsa=4×(E×E)×HE: 嵌入维度 (Embedding Dim)。
H: 头数 (Heads)。
(详见下方详解)
2. 详细示例:Transformer 参数量估算

以一个简化的 Transformer 编码块为例:

  1. 多头自注意力层 (MHSA)

    • 其参数量主要来自四个线性投影层(Q, K, V 和输出投影)。
    • 每个投影层的形状为 (E, E),其中 E 是模型嵌入维度。
    • 参数量为:4×(E×E)4 \times (E \times E)4×(E×E)
    • 如果有偏置,再加 4 * E
  2. 前馈神经网络层 (FFN)

    • 通常是两个线性层:第一个扩大维度(E -> 4E),第二个缩小回原维度(4E -> E)。
    • 参数量为:(E×4E+4E)+(4E×E+E)=8E2+5E(E \times 4E + 4E) + (4E \times E + E) = 8E^2 + 5E(E×4E+4E)+(4E×E+E)=8E2+5E
  3. 层归一化 (LayerNorm)

    • 每个 LayerNorm 层的参数量为 2 * E(gamma 和 beta)。
    • 一个块通常有 2 个 LayerNorm 层,共 4E
  4. 一个 Transformer 块的参数量
    Pblock=Pmhsa+Pffn+Pln=(4E2)+(8E2+5E)+(4E)=12E2+9EP_{block} = P_{mhsa} + P_{ffn} + P_{ln} = (4E^2) + (8E^2 + 5E) + (4E) = 12E^2 + 9EPblock=Pmhsa+Pffn+Pln=(4E2)+(8E2+5E)+(4E)=12E2+9E

  5. 整个模型参数量

    • 加上词嵌入层((Vocab_Size, E))和输出层(通常与词嵌入层共享权重)。
    • 总参数量 ≈ V×E+N×(12E2+9E)V \times E + N \times (12E^2 + 9E)V×E+N×(12E2+9E)
    • N 是 Transformer 块的层数,V 是词表大小。

举例:GPT-2 Small (124M)

  • E=768, N=12, V=50257
  • Ptotal≈50257×768+12×(12×7682+9×768)≈124MP_{total} \approx 50257 \times 768 + 12 \times (12 \times 768^2 + 9 \times 768) \approx 124MPtotal50257×768+12×(12×7682+9×768)124M

二、 计算量 (FLOPs)

FLOPs (Floating Point Operations) 是浮点运算次数,用来衡量模型执行一次前向传播所需的计算复杂度。它直接影响模型的运行速度和能耗。

注意:FLOPs 是 计算量,单位是 FLOP。而设备的计算能力是 算力,单位是 FLOPS (Floating Point Operations Per Second),即“每秒浮点运算次数”。理论耗时 = FLOPs / FLOPS

1. 核心计算方法

总FLOPs ≈ 前向传播中所有矩阵运算的FLOPs之和

常见层的 FLOPs 计算:

层类型计算公式解释
全连接层 (Linear)Ffc=B×(2×I−1)×O≈2×B×I×OF_{fc} = B \times (2 \times I - 1) \times O \\ \approx 2 \times B \times I \times OFfc=B×(2×I1)×O2×B×I×OB: 批次大小 (Batch Size)。
```2 * I * O````是核心计算量(乘加运算算作2次操作)。偏置计算量 B * O 通常可忽略。
卷积层 (Conv2d)Fconv≈2×B×Hout×Wout×Cin×Kw×Kh×CoutF_{conv} \approx 2 \times B \times H_{out} \times W_{out} \times C_{in} \times K_w \times K_h \times C_{out}Fconv2×B×Hout×Wout×Cin×Kw×Kh×CoutH_out, W_out: 输出特征图的高宽。
核心是每个输出像素点上的 K*K*C_in 次乘加运算,再乘以输出通道数 C_out
多头自注意力 (MHSA)Fmhsa≈4×B×S2×E+2×B×S×E2F_{mhsa} \approx 4 \times B \times S^2 \times E \\ + 2 \times B \times S \times E^2Fmhsa4×B×S2×E+2×B×S×E2S: 序列长度 (Sequence Length)。
(详见下方详解)
2. 详细示例:Transformer FLOPs 估算

同样以一个 Transformer 编码块为例,假设输入张量形状为 (B, S, E)

  1. 多头自注意力层 (MHSA)

    • Q, K, V 投影:三个线性层,每个的 FLOPs 为 2 * B * S * E * E。总共 6 * B * S * E^2
    • Q·K^T 矩阵乘法(B, H, S, E/H)(B, H, E/H, S) 相乘。FLOPs 为 2 * B * H * S * (E/H) * S = 2 * B * S^2 * E
    • Attention·V 矩阵乘法(B, H, S, S)(B, H, S, E/H) 相乘。FLOPs 为 2 * B * H * S * S * (E/H) = 2 * B * S^2 * E
    • 输出投影:一个线性层,FLOPs 为 2 * B * S * E * E
    • 总计Fmhsa≈8×B×S×E2+4×B×S2×EF_{mhsa} \approx 8 \times B \times S \times E^2 + 4 \times B \times S^2 \times EFmhsa8×B×S×E2+4×B×S2×E
  2. 前馈神经网络层 (FFN)

    • 两个线性层:E -> 4E4E -> E
    • FLOPs 为:2×B×S×E×4E+2×B×S×4E×E=16×B×S×E22 \times B \times S \times E \times 4E + 2 \times B \times S \times 4E \times E = 16 \times B \times S \times E^22×B×S×E×4E+2×B×S×4E×E=16×B×S×E2
  3. 层归一化 (LayerNorm)

    • 计算量远小于线性层,通常可以忽略不计。
  4. 一个 Transformer 块的 FLOPs
    Fblock≈Fmhsa+Fffn≈24×B×S×E2+4×B×S2×EF_{block} \approx F_{mhsa} + F_{ffn} \approx 24 \times B \times S \times E^2 + 4 \times B \times S^2 \times EFblockFmhsa+Fffn24×B×S×E2+4×B×S2×E

  5. 整个模型 FLOPs

    • 加上词嵌入层(可忽略)和输出层(2 * B * S * E * V,很大!)。
    • 总 FLOPs ≈ N×(24×B×S×E2+4×B×S2×E)+2×B×S×E×VN \times (24 \times B \times S \times E^2 + 4 \times B \times S^2 \times E) + 2 \times B \times S \times E \times VN×(24×B×S×E2+4×B×S2×E)+2×B×S×E×V

观察

  • FLOPs 与批次大小 B、序列长度 S 直接相关。
  • S 很大时(长文本),S2S^2S2 项会占主导,这就是为什么 Transformer 计算量随序列长度增长非常快。
  • E 很大时(大模型),E2E^2E2 项会占主导。

三、 使用工具自动计算 (推荐)

手动计算复杂且易错,强烈推荐使用现有工具库。

1. 使用 thop (Torch-OpCounter)
import torch
import torch.nn as nn
from thop import profile, clever_format# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(1000, 500)self.fc2 = nn.Linear(500, 100)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xmodel = SimpleModel()
dummy_input = torch.randn(1, 1000) # (B, I)# 使用thop计算FLOPs和参数量
flops, params = profile(model, inputs=(dummy_input, ))
flops, params = clever_format([flops, params], "%.3f") # 格式化输出print(f"FLOPs: {flops}")
print(f"Params: {params}")
2. 使用 ptflops
pip install ptflops
from ptflops import get_model_complexity_infomacs, params = get_model_complexity_info(model, (1000, ), as_strings=True, print_per_layer_stat=True)
print(f"MACs: {macs}")
print(f"Params: {params}")

注意ptflops 报告的 MACs (Multiply-ACCumulate operations) 是乘加运算次数。1 MAC = 2 FLOPs。所以 FLOPs = 2 * MACs


总结与关键点

指标参数量 (Params)计算量 (FLOPs)
含义模型的大小,静态模型的计算复杂度,动态
决定因素模型结构(宽度、深度)模型结构 + 输入数据 (B, S)
影响内存占用、磁盘空间运行速度、能耗
单位M (百万), B (十亿)GFLOPs, TFLOPs
工具thop, ptflops, torchinfothop, ptflops, fvcore

重要提示

  1. FLOPs ≠ 速度:FLOPs 是理论计算量。实际速度还受内存带宽 (Memory Bandwidth)并行度硬件架构等因素严重影响。一个低FLOPs但高内存访问的模型可能比高FLOPs但计算密集的模型更慢。
  2. 估算工具的结果是近似的,不同工具的计算规则可能略有不同,但对于对比模型设计已经足够。
  3. 对于 训练 来说,FLOPs 大约是 前向传播的 3 倍(因为还有反向传播和梯度更新)。
http://www.dtcms.com/a/360307.html

相关文章:

  • [WUSTCTF2020]B@se1
  • 后向投影合成孔径辐射源定位方法(一)
  • Linux-数据库
  • MVC模式学习
  • 物种多样性与物种丰富度
  • 制造业生产线连贯性动作识别系统开发
  • 使用 Claude Code 与 Remotion 制作自定义动画视频的完整教程
  • 代码分析之符号执行技术
  • 多人协作开发指南二
  • 简化对齐训练:用明文对比数据SFT替代复杂DPO
  • 8针脚的1.8寸IIC接口的TFT彩屏的八个引脚都需要使用吗?
  • 【编号186】中国劳动统计年鉴(1991-2023)
  • LeetCode 2570.合并两个二维数组
  • 超越关键词:RAG系统如何破解用户查询的“模糊密码”
  • BLE广播与扫描
  • 嵌入式C学习笔记之预编译
  • Redis面试重点-2
  • Coze源码分析-工作空间-项目开发-前端源码
  • 在Windows系统Docker中使用wsl2、容器、windows文件路径三种不同挂载方式的区别和性能差异
  • ceph对象存储-存储池-用户认证
  • @Value注解的底层原理(一)
  • Day18 (前端:JavaScript基础阶段)
  • 数据结构 04(线性:双向链表)
  • Ansible 临时命令与常用模块实操指南
  • Cartographer中的gflag与lua文件
  • 国庆福建霞浦游
  • 阿里云创建自己的博客,部署wordpress
  • Java学习笔记-IO流(更新中...)
  • 嵌入式C学习笔记之链表
  • kkfileview自建cdn引入