免训练指标(Zero-Cost Proxies)
1. 什么是免训练指标(Zero-Cost Proxies,ZC proxies)?
免训练指标是一类 无需完整训练模型即可评估其性能的度量方法,主要用于提高 神经架构搜索(NAS) 的效率。
传统 NAS 需要训练候选架构来评估其性能,但训练消耗巨大,因此免训练指标提供了一种 基于模型本身特性(如梯度、参数分布)快速估计模型质量的方法。
核心思想:
只用一个小批量数据 计算某些统计量(如梯度、参数重要性、激活值分布),从而 近似衡量模型的好坏,而不需要完整训练整个模型。
2. 免训练指标的类别
免训练指标可以大致分为两类:
- 传统结构分析指标(如 SNIP、Synflow、Fisher)
- 基于知识蒸馏的指标(如 DisWOT)
(1)传统结构分析指标
这些方法通常通过计算 梯度、权重、Hessian 矩阵 等信息来评估模型的质量。
① SNIP(Single-shot Network Pruning)
- 计算梯度的重要性,衡量每个参数对损失函数的影响:
ρ s n i p = ∣ ∂ L ∂ W ⊙ W ∣ \rho_{snip} = \left| \frac{\partial \mathcal{L}}{\partial \mathcal{W}} \odot \mathcal{W} \right| ρsnip= ∂W∂L⊙W - 核心思想:如果去掉某个权重后损失变化较大,则该权重很重要。因此,可以用梯度信息估算整个网络的质量。
② Synflow
- 通过梯度流分析,避免层塌陷(layer collapse):
ρ s y n f l o w = ∂ L ∂ W ⊙ W \rho_{synflow} = \frac{\partial \mathcal{L}}{\partial \mathcal{W}} \odot \mathcal{W} ρsynflow=∂W∂L⊙W - 核心思想:确保不同层的梯度能够均匀流动,以保持架构的稳定性。
③ Fisher
- 计算激活梯度的平方和,用于通道剪枝:
ρ f i s h e r = ( ∂ L ∂ A A ) 2 \rho_{fisher} = \left( \frac{\partial \mathcal{L}}{\partial \mathcal{A}} \mathcal{A} \right)^2 ρfisher=(∂A∂LA)2 - 核心思想:通道(Channel)如果对梯度变化敏感,则在训练时影响更大,可以用它来衡量模型质量。
(2)基于知识蒸馏的指标
DisWOT(Distillation Without Training)
-
这是一种 基于知识蒸馏的免训练指标,通过计算 教师-学生模型的特征匹配误差 来评估网络质量:
ρ D i s W O T = D L 2 ( G ( [ A S , A T ] ) ) + D L 2 ( G ( [ F S , F T ] ) ) \rho_{DisWOT} = \mathcal{D}_{L2} (\mathcal{G}([AS,AT])) + \mathcal{D}_{L2} (\mathcal{G}([FS,FT])) ρDisWOT=DL2(G([AS,AT]))+DL2(G([FS,FT])) -
其中:
- ( AS, AT ) 是教师-学生模型的 激活图(Activation Maps)
- ( FS, FT ) 是教师-学生模型的 特征图(Feature Maps)
- ( \mathcal{D}_{L2} ) 计算的是 L2 距离(欧几里得距离),衡量特征匹配误差
-
核心思想:如果一个模型可以很好地模仿教师模型的特征分布(即 L2 误差小),则这个模型的质量更好。
3. 免训练指标如何用于 NAS
在 NAS 中,免训练指标可以用于:
- 快速评估候选架构
- 在搜索空间中 筛选掉性能较差的架构,减少训练计算量。
- 结合搜索算法优化架构
- 可以将 梯度信息(SNIP, Synflow) 或 知识蒸馏误差(DisWOT) 作为搜索目标,指导 NAS 选择更优的架构。
- 设计高效的蒸馏感知 NAS(DAS)
- 结合 DAS(Distillation-aware Architecture Search),让 NAS 选择对知识蒸馏更友好的模型,提高轻量化模型的性能。