第4节 大模型推理内存与计算优化
前言
大模型推理过程中,内存占用和计算效率是两大核心瓶颈。即使通过并行策略实现了多设备协同,仍需针对性优化内存使用和计算流程,才能充分发挥硬件性能。本节从基础概念入手,逐步介绍内存管理、量化、算子优化等关键技术,帮助理解如何在有限资源下提升推理效率。
一、KV缓存:推理中的“内存黑洞”
1. 什么是KV缓存?
在Transformer模型生成token的过程中(尤其是 autoregressive 生成,即逐词生成),每一步都需要用到前序所有token的Key(K) 和Value(V) 张量(来自注意力层计算)。如果每次生成新token都重新计算所有历史K和V,会导致计算量呈指数级增长(例如生成1000个token需要计算1+2+…+1000≈50万次注意力)。
为避免重复计算,推理系统会将每一步的K和V缓存下来,后续步骤直接复用。这部分缓存就称为KV缓存。
- 举例:生成第5个token时,只需计算第5个token的Query(Q),然后与缓存的前4个token的K、V计算注意力,无需重新计算前4个token的K、V。
2. KV缓存的内存问题
KV缓存的内存占用随序列长度(生成的token数)线性增长,计算公式为:
KV缓存总大小 = 序列长度 × 隐藏层维度 × 头数 × 2(K和V) × 数据类型字节数
以70B模型(隐藏层维度8192,头数96,FP16精度)为例:
- 1K token:1024 × 8192 × 96 × 2 × 2字节 ≈ 31GB
- 32K token:32768 × 8192 × 96 × 2 × 2字节 ≈ 1000GB(1TB)
长序列推理时,KV缓存的内存占用会远超模型参数本身(70B模型FP16参数仅140GB),成为真正的“内存黑洞”。
二、KV缓存的分布式管理
1. 块化存储:像管理硬盘一样管理缓存
如果将KV缓存视为一整块连续内存,长序列会导致内存碎片(例如128K token的缓存可能被分成多段零散空间)。解决办法是块化存储:
-
将KV缓存按固定大小(如16/64 token)分成“块”,每个块独立管理;
-
用元数据记录块的位置(哪个设备)、所属序列、有效token数等信息。
-
举例:64 token/块时,128K token的序列会被分成2000个块,每个块独立存储,减少碎片。
2. 跨设备共享:分布式缓存池
单机内存无法容纳长序列KV缓存时,需要将块分散到多设备(GPU/CPU),形成分布式缓存池:
-
全局块索引表:记录每个块所在的设备,类似“地址簿”;
-
按需访问:计算时通过索引表找到块的位置,跨设备读取(如GPU0需要某块时,从GPU1或CPU内存中读取)。
-
优势:突破单设备内存限制,支持128K+ token长序列。
3. 动态分配与驱逐策略
当缓存池满了,新序列需要空间时,需“驱逐”旧块:
- 贪心策略:优先保留连续块(减少碎片);
- 注意力权重策略:驱逐低注意力分数的块(对当前生成影响小);
- 超时策略:驱逐长时间未使用的块(如30秒无访问)。
4. 代码示例:简单的块化KV缓存管理
class KVCacheBlock:def __init__(self, block_size=64, dtype=torch.float16):self.block_size = block_size # 每块64 tokenself.data = torch.zeros((block_size, 8192, 96), # (token数, 隐藏层维度, 头数)dtype=dtype,device="cuda")self.used = 0 # 已使用的token数(如30表示用了30个位置)class KVCacheManager:def __init__(self, max_blocks=1000):self.blocks = [KVCacheBlock() for _ in range(max_blocks)] # 缓存池self.free_blocks = list(range(max_blocks)) # 空闲块索引self.seq_blocks = {} # 记录序列占用的块:{seq_id: [block_idx1, block_idx2]}def allocate(self, seq_id, need_tokens):"""为序列分配块(需容纳need_tokens个token)"""need_blocks = (need_tokens + 63) // 64 # 向上取整(如65 token需2块)if len(self.free_blocks) < need_blocks:self.evict(need_blocks - len(self.free_blocks)) # 驱逐旧块alloc_idx = self.free_blocks[:need_blocks]self.free_blocks = self.free_blocks[need_blocks:]self.seq_blocks[seq_id] = alloc_idxreturn alloc_idxdef evict(self, num):"""驱逐num个块(简单起见,驱逐最早分配的)"""for _