检查模型配置参数
1.检查模型配置参数
print(model.config._attn_implementation) # 应输出"flash_attention_2"
- 验证CUDA设备状态
assert next(model.parameters()).is_cuda, "模型必须加载到CUDA设备"
- 查看安装日志
安装时若出现以下提示表示成功:
Successfully installed flash-attn-2.5.8
- 性能基准测试
对比启用前后的推理速度:
# 标准注意力
%%timeit
model.generate(inputs, max_new_tokens=200) # 假设耗时3.2秒
# Flash Attention 2
%%timeit
model.generate(inputs, max_new_tokens=200) # 应缩短至约1.1秒
- 检查注意力层类型
print(type(model.model.layers.self_attn))
# 正确应显示FlashAttention2层:<class 'transformers.models.llama.modeling_llama.LlamaFlashAttention2'>
- 监控显存占用
启用后长序列(4096 tokens)显存消耗应降低约40%