网站服务器空间不足最好的建站平台
在深度学习的实践中,我们常常会遇到这样的挑战:训练过程震荡、模型性能在验证集上“忽高忽低”,或者无论如何调参,性能总是在一个瓶颈期徘徊。今天,我们将深入探讨一个简单、有效且在众多SOTA(State-of-the-Art)模型中广泛使用的技术——指数移动平均(Exponential Moving Average, EMA),它就像一个“稳定器”,能有效解决上述问题,帮助我们获得泛化能力更强的模型。
一、EMA是什么?一个直观的理解
让我们先抛开复杂的定义,从一个熟悉的场景开始:股票价格。股票的每日价格波动很大,为了看清长期趋势,分析师们会使用“移动平均线”。EMA就是其中一种,它计算一个加权平均值,在这个平均值中,越近的数据点权重越高。
在深度学习中,我们应用同样思想,但对象不是价格,而是模型的权重(parameters)。在每个训练步,优化器(如Adam)都会更新一次模型权重,这个过程同样充满了因小批量数据随机性带来的“噪声”。EMA通过对近期的一系列模型权重进行平滑,创建一个“影子模型”,其权重变化更平稳,往往能代表一个泛化性能更好的解。
其核心更新公式非常简洁:
WEMA(t)=β⋅WEMA(t−1)+(1−β)⋅Wmodel(t)W_{EMA}^{(t)} = \beta \cdot W_{EMA}^{(t-1)} + (1 - \beta) \cdot W_{model}^{(t)}WEMA(t)=β⋅WEMA(t−1)+(1−β)⋅Wmodel(t)
其中:
- W_EMA(t)W\_{EMA}^{(t)}W_EMA(t) 是第 ttt 步更新后的EMA模型权重。
- W_EMA(t−1)W\_{EMA}^{(t-1)}W_EMA(t−1) 是上一步的EMA模型权重。
- W_model(t)W\_{model}^{(t)}W_model(t) 是第 ttt 步刚被优化器更新完的原始模型权重。
- beta\\betabeta 是衰减率(decay),一个接近1的常数(如0.999),它决定了历史权重的“记忆”有多长。
二、深入代码:一个动态衰减的EMA实现
下面这段PyTorch代码实现了一个更精巧的EMA策略,它不仅应用了EMA,还实现了一个动态的衰减率,让EMA在训练初期“学习”得更快,在后期则更趋于“稳定”。
import copy
import torch
from torch.nn.modules.batchnorm import _BatchNormclass EMAModel:"""Exponential Moving Average of models weights"""def __init__(self, model, update_after_step=0, inv_gamma=1.0, power=2/3, min_value=0.0, max_value=0.9999):"""@crowsonkb's notes on EMA Warmup:If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you planto train for a million or more steps."""# 注意:实践中通常传入 copy.deepcopy(model) 来避免引用问题self.averaged_model = model self.averaged_model.eval()self.averaged_model.requires_grad_(False)self.update_after_step = update_after_stepself.inv_gamma = inv_gammaself.power = powerself.min_value = min_valueself.max_value = max_valueself.decay = 0.0self.optimization_step = 0def get_decay(self, optimization_step):"""Compute the decay factor for the exponential moving average."""step = max(0, optimization_step - self.update_after_step - 1)value = 1 - (1 + step / self.inv_gamma) ** -self.powerif step <= 0:return 0.0return max(self.min_value, min(value, self.max_value))@torch.no_grad()def step(self, new_model):self.decay = self.get_decay(self.optimization_step)for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()):for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)):# 跳过BatchNorm层和非训练参数,直接复制if isinstance(module, _BatchNorm) or not param.requires_grad:ema_param.copy_(param.to(dtype=ema_param.dtype).data)else:# 应用EMA更新公式ema_param.mul_(self.decay)ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)self.optimization_step += 1
代码解析:
__init__
: 初始化一个averaged_model
,这是我们的“影子模型”。它被设为eval()
模式且不计算梯度,因为它只被动地接收更新。get_decay
: 这是这段代码的亮点。它没有使用固定的衰减率beta\\betabeta,而是根据训练步数optimization_step
动态计算。在训练初期,step
较小,decay
值也较小,EMA模型会快速跟上原始模型的步伐;随着训练深入,decay
值逐渐增大并趋近于max_value
,EMA模型变得越来越稳定。step(new_model)
: 这是核心更新函数。- 它在每次原始模型
new_model
完成梯度更新后被调用。 - 特殊处理:它会跳过BatchNorm层。因为BN层的
running_mean
和running_var
是数据集的统计量,对它们进行平滑可能会破坏其统计意义。所以选择直接复制。 - 核心更新:
ema_param.mul_(self.decay)
和ema_param.add_(..., alpha=1 - self.decay)
这两行代码完美地实现了我们前面提到的EMA公式。
- 它在每次原始模型
三、为什么EMA能提升模型性能?探究其背后原理
EMA之所以有效,根本原因在于它能帮助优化器找到更平坦的最小值(Flatter Minima)。
在深度学习的损失景观中,存在无数的局部最小值。这些最小值有“尖锐”和“平坦”之分:
- 尖锐最小值:像一个狭窄的深坑。模型在此处对训练数据拟合极好,但稍微挪动一下权重(即遇到测试数据),损失就会急剧上升。这通常是过拟合的标志。
- 平坦最小值:像一个宽阔的盆地。在此区域内,权重的小幅变动不会引起损失的剧烈变化,说明模型更具鲁棒性,泛化能力更强。
EMA通过以下方式帮助我们找到平坦的盆地:
- 平滑优化轨迹:SGD等优化器在训练中因为数据随机性而产生的“抖动”轨迹,其终点可能偶然落在一个尖锐的坑里。EMA通过对轨迹进行平均,其结果更有可能落在轨迹探索区域的中心,也就是那个宽阔盆地的中央。
- 近似模型集成(Ensemble):EMA可以被看作是一种在时间维度上的模型集成。它以极低的成本,融合了模型在训练末期多个时间点的“快照”,综合了它们的“知识”,从而得到一个更鲁棒的解。
四、EMA vs. 优化器:相辅相成的“兄弟”
一个常见的问题是:“Adam等优化器内部不也使用了动量的思想,也是一种EMA吗?它们能取代EMA的作用吗?”
答案是:思想类似,但作用对象和目的完全不同,它们是互补关系。
特性 | 优化器 (如Adam) | 模型参数EMA (EMAModel) |
---|---|---|
作用对象 | 梯度 (Gradients) | 模型参数 (Weights) |
目的 | 计算一个更优的更新方向和步长,加速和稳定训练过程 | 创建一个更平滑、更鲁棒的最终模型版本,提升泛化能力 |
一个比喻:优化器是驾驶员,它通过观察路况(梯度)和保持惯性(动量)来决定如何打方向盘和踩油门。而EMA是车上的GPS轨迹记录仪,它不负责驾驶,只负责记录车辆行驶过的路径,并计算出一条最平滑的平均路线。最终,我们想要的可能不是车辆停下的那个点,而是GPS记录的平均路线的终点。
五、实践指南:何时应该使用EMA?
推荐使用EMA的场景 ✅ | 可以暂不使用EMA的场景 ❌ |
---|---|
1. 冲击SOTA性能,追求极致泛化能力时。 | 1. 项目初期,进行快速原型验证和迭代时。 |
2. 训练过程不稳定,或使用小批量、噪声大的数据集时。 | 2. 计算/显存资源极其受限,无法容纳一个影子模型时。 |
3. 训练大模型(如Transformers)和进行长时间训练时。 | 3. 任务非常简单,模型很容易就收敛到一个好的解时。 |
4. 训练生成模型,如GANs和Diffusion Models,以稳定训练和提升质量。 |
结论
EMA是一种强大而简单有效的技术,它通过平滑模型权重,帮助我们找到泛化能力更强的“平坦最小值”。它并非训练的必需品,而更像是一个高性能增强包。在你的下一个项目中,当基线模型已经稳定,并希望进一步提升其性能时,不妨引入EMA,它很可能会给你带来意想不到的惊喜。