《CAFE: Learning to Condense Dataset by Aligning Features》
通过对齐特征的数据集浓缩技术 -> CVPR2022
最先进的方法很大程度上依赖于通过匹配真实数据批次和合成数据批次之间的梯度来学习合成数据。尽管有直观的动机和有希望的结果,但这种基于梯度的方法本质上很容易过度拟合产生主导梯度的有偏差的样本集,因此缺乏对数据分布的全局监督。
在本文中,我们提出了一种通过对齐特征(CAFE)来压缩数据集的新颖方案,该方案明确地尝试保留真实特征分布以及所得合成集的判别能力,从而对各种架构具有强大的泛化能力。
使用梯度匹配存在哪些问题?
梯度匹配方法有两个潜在的问题。首先,由于深度神经网络的记忆效应,只有少量的硬样本或噪声在网络参数上产生主导梯度。因此,梯度匹配可能会忽略那些有代表性但简单的样本,而对那些困难的样本或噪声过度拟合。其次,这些产生大梯度的困难示例在不同的架构中可能会有所不同;因此,仅依赖梯度会对未见过的架构产生较差的泛化性能。
为了超越学习偏差并更好地捕获整个数据集分布,在本文中,我们提出了一种通过对齐特征来压缩数据集的新策略,称为 CAFE。
我们通过应用分布级监督来解释合成数据集和真实数据集之间的分布一致性。我们的方法通过匹配涉及所有中间层的特征,将注意力扩展到所有样本,从而提供更全面的分布表征,同时避免对硬样本或噪声样本的过度拟合。
方法
方法分为:逐层特征对齐模块、判别损失和动态双层优化模块
1. 逐层特征对齐模块
之前基于梯度的数据浓缩,产生了具有大梯度的样本,但是这些样本无法捕捉原始数据集的分布,更多是原始数据集中处于边缘的样本。因此对于没有见过的架构可能会有较差的表现。
为了解决这个问题,作者设计了Category-Wise Feature Averaging(CWFA)来解决这个问题。测量每个卷积层在原始数据集和浓缩数据集上的特征差异。
这一部分基本就是DM的拓展,DM就拿输出的前一层当做embding结果,这里直接拿每一层了。
真实数据输出的每一层的每一批输出,取平均值,生成数据输出的每一层的每一批输出,取平均,算一个损失,每一层损失加起来,Over~
这就是第一个损失:
注意,不包含最后一层输出层。
2. 判别损失
虽然层间特征对齐可以捕捉数据分布,但可能忽略判别能力。因此,作者引入了判别损失,确保合成数据能够有效区分不同类别。具体步骤如下:
- 特征中心计算:对合成数据的每类特征进行平均,得到特征中心:
F ˉ S L = [ f ˉ S 1 , L , f ˉ S 2 , L , … , f ˉ S K , L ] \bar{F}_S^L = [\bar{f}_S^{1,L}, \bar{f}_S^{2,L}, \dots, \bar{f}_S^{K,L}] FˉSL=[fˉS1,L,fˉS2,L,…,fˉSK,L]
其中, f ˉ S k , L \bar{f}_S^{k,L} fˉSk,L 是第 k k k类合成数据在最后一层的特征中心。 - 分类:将真实数据的特征 ( F_T^L ) 与合成数据的特征中心 ( \bar{F}_S^L ) 进行内积运算,得到分类结果:
O = F T L ⋅ ( F ˉ S L ) T O = F_T^L \cdot (\bar{F}_S^L)^T O=FTL⋅(FˉSL)T
其中, O ∈ R N ′ × K O \in \mathbb{R}^{N' \times K} O∈RN′×K 是分类得分, N ′ = K × N N' = K \times N N′=K×N 是真实数据的总数。 - 损失计算:计算交叉熵损失:
L d = − 1 N ′ ∑ i = 1 N ′ log p i L_d = -\frac{1}{N'} \sum_{i=1}^{N'} \log p_i Ld=−N′1i=1∑N′logpi
其中, p i p_i pi是真实数据 i i i 的预测概率,通过 softmax 函数计算得到。
3. 动态双层优化模块
CAFE 采用双层优化框架,交替更新合成数据和网络参数。具体步骤如下:
- 外层优化:更新合成数据
S
S
S,通过最小化总损失
L
total
=
L
f
+
β
L
d
L_{\text{total}} = L_f + \beta L_d
Ltotal=Lf+βLd:
S ← arg min S L total S \leftarrow \arg \min_S L_{\text{total}} S←argSminLtotal - 内层优化:更新网络参数
θ
\theta
θ,通过最小化合成数据上的交叉熵损失:
θ ← arg min θ J ( S , θ ) \theta \leftarrow \arg \min_\theta J(S, \theta) θ←argθminJ(S,θ)
其中, J ( S , θ ) J(S, \theta) J(S,θ) 是合成数据上的分类损失。