混合精度参数说明及数据集相关参数(121)
混合精度参数说明
使用混合精度的目的是加快计算速度,但这可能会导致占用更多GPU内存;不过,这并不一定会导致峰值内存使用量增加(峰值内存过高会引发内存不足(OOM)错误)。你可以通过以下两个参数之一轻松配置混合精度。
- fp16:当
torch.cuda.is_bf16_supported()
(用于检测GPU是否支持BF16精度)的返回结果为False
(即GPU不支持BF16)时,会将FP16(半精度浮点数)作为计算精度。 - bf16:当
torch.cuda.is_bf16_supported()
的返回结果为True
(即GPU支持BF16)时,会将BF16(脑半精度浮点数)作为计算精度。
你也可以将这两个参数均设为False
(此为默认设置),此时模型会以全精度(FP32,单精度浮点数)进行计算。
值得注意的是,使用BF16进行混合精度计算可能会节省内存——因为训练器(trainer)会将多个网络层的数据类型转换为BF16(与往常一样,层归一化(layer norms)层是明显的例外,不会被转换)。当bf16
参数设为True
时,训练器类(trainer class)会调用下方的函数