论文略读: CUT YOUR LOSSES IN LARGE-VOCABULARY LANGUAGE MODELS
ICLR 2025 oral
- 随着语言模型(LLMs)的规模不断增长,其词表规模也随之扩大
- 这导致训练过程中内存占用极度向一个层次倾斜:即交叉熵损失计算中的最后一层
- 在计算交叉熵损失时,需要构造一个logit 矩阵,其每个条目对应输入 token 与词表中每个词项之间的得分
- 对于小模型而言,这一操作所占内存甚至比整个 LLM 的其余部分还高出一个数量级
- ——>论文提出了 Cut Cross-Entropy(CCE),一种在不将完整 logits 写入全局内存的情况下计算交叉熵损失的方法
- 仅计算目标 token 的 logit,并通过“按需计算”的方式完成 log-sum-exp 操作
- 实现了一个自定义 kernel,在闪存(flash memory)中完成矩阵乘法和 log-sum-exp 的归约操作,从而使得交叉熵计算的全局内存占用几乎可以忽略不计。
- 以 Gemma 2(20 亿参数)模型为例,CCE 将损失计算的内存占用从 24 GB 降至 1 MB,将整个分类头(classifier head)在训练时的内存消耗从 28 GB 降至 1 GB。