SNN论文阅读——Apprenticeship-Inspired Elegance
Apprenticeship-Inspired Elegance: Synergistic Knowledge Distillation Empowers Spiking Neural Networks for Efficient Single-Eye Emotion Recognition
- 训练时使用了强度帧和时间帧,经过学生网络模仿后,推理时可以只使用强度帧。学生网络十分轻量化,并且具有相当准确率。
总损失函数 (Total Loss Function)
公式 (1):
Ltotal =(1−α)∗LCls+α∗LHCKD+α∗LTCKD
\mathcal{L}_{\text {total }}=(1-\alpha) * \mathcal{L}_{\mathrm{Cls}}+\alpha * \mathcal{L}_{\mathrm{HCKD}}+\alpha * \mathcal{L}_{\mathrm{TCKD}}
Ltotal =(1−α)∗LCls+α∗LHCKD+α∗LTCKD
这个公式定义了学生网络(Student Network)训练的总目标。它是一个加权和,包含三个部分:
-
LCls\mathcal{L}_{\mathrm{Cls}}LCls (分类损失): 权重为 (1−α)(1 - α)(1−α)
- 目的:确保学生网络自己能够学会正确分类,其预测结果逼近真实标签。
- 作用:这是模型学习的基础,保证学生网络不完全依赖教师网络,而是具备基本的判别能力。
-
LHCKD\mathcal{L}_{\mathrm{HCKD}}LHCKD (命中一致性知识蒸馏损失): 权重为 ααα
- 目的:在教师预测正确的时间步上,强制学生网络的预测分布与教师网络的预测分布保持一致。
- 作用:精细化学习。让学生网络重点模仿教师网络“做对了的时候是怎么想的”,学习其最自信、最正确的推理逻辑。
-
LTCKD\mathcal{L}_{\mathrm{TCKD}}LTCKD (时间一致性知识蒸馏损失): 权重为 ααα
- 目的:在所有时间步上(无论教师预测对错),强制学生网络的预测分布与教师网络的预测分布保持一致。
- 作用:全局性学习。让学生网络全面模仿教师网络的整体行为模式,包括其不确定性、犹豫过程等,从而更好地理解任务的时序 dynamics。
-
权重参数 α\alphaα:
- 初始值为 0.50.50.5,之后每训练30个周期(epoch)增加 0.10.10.1。
- 设计理念:这是一个课程学习(Curriculum Learning) 策略。
- 训练初期:α\alphaα 较小,(1−α)(1-α)(1−α) 较大。模型更侧重于打好基础(优化 LCls\mathcal{L}_{\mathrm{Cls}}LCls),先学会基本的分类任务。
- 训练中后期:α\alphaα 逐渐增大。模型的基础分类能力已经稳定,逐渐将学习重点转移到模仿教师(优化 LHCKD\mathcal{L}_{\mathrm{HCKD}}LHCKD 和 LTCKD\mathcal{L}_{\mathrm{TCKD}}LTCKD)上来,蒸馏变得越来越重要。
分类损失 (Classification Loss)
公式 (2):
LCls=1T∑t=1TLCE(Ostu(t),y)
\mathcal{L}_{\mathrm{Cls}}=\frac{1}{T} \sum_{t=1}^{T} \mathcal{L}_{C E}\left(O_{s t u}(t), y\right)
LCls=T1t=1∑TLCE(Ostu(t),y)
这个公式回答了“如何有效地训练SNN?”这个核心问题。
- Ostu(t)O_{stu}(t)Ostu(t): 学生网络分类器在第 ttt 个时间步的突触前输入。
- 关键理解:这指的是脉冲神经元在发放脉冲之前、全连接层的输出(即膜电位 VVV 在与阈值比较之前的数值????)。这是一个连续值。(注意:突出前输入和当前膜电位的值是两个概念。在一般实现中全连接层的输出会直接输入相应的神经元,也就是突触前输入,但膜电位是突触前输入和前一刻膜电位衰减共同计算得到的,不要混淆。不看代码实现尚不清楚是哪种实现)
- LCE\mathcal{L}_{CE}LCE: 标准交叉熵损失函数。
- yyy: 真实标签的 one-hot 向量。
- TTT: 总时间步数。
工作方式:计算每一个时间步的预测结果 (Ostu(t)O_{stu}(t)Ostu(t)) 与真实标签 (yyy) 之间的交叉熵损失,然后对所有时间步的损失取平均。
为什么这样做?—— 采用 TET 解决 SNN 训练难题
论文解释了这个设计背后的深刻原因:
-
SNN的训练难题:
- 不可微性:脉冲发放函数是阶跃函数,不可微,无法直接使用梯度下降法。
- 代理梯度(SG)的局限性:虽然用代理梯度(如矩形函数、sigmoid函数的导数等)可以近似并允许反向传播,但这会使得SNN的损失 landscapes(损失曲面)与ANN的完全不同。SG方法常常无法找到最优解,导致训练精度不理想。
-
解决方案:时序高效训练(TET):
- 本文采用了 [Deng et al., 2022] 提出的 TET(Temporal Efficient Training) 方法。
- 核心思想:优化每个时间步的输出,而非常见的只优化最终时间步或平均输出的策略。
- 好处:
- 提供更多梯度:每个时间步都产生损失和梯度,为网络提供了更丰富、更频繁的优化信号,有助于补偿代理梯度带来的信息损失。
- 引导至更平坦的极小值:多时间步的优化有助于模型收敛到损失曲面中更平坦的区域(flatter minima)。这类解通常具有更好的泛化能力(generalizability),即在新数据上表现更鲁棒。
- 增强时间鲁棒性:使网络对时间序列长度的变化不敏感,具有“更稳健的时间扩展性”。
与SEEN方法的对比
论文特意强调,这与SEEN方法([Zhang et al., 2023])不同。SEEN是直接优化积分后的膜电位(即对多个时间步的膜电位求平均,如公式14:R=σ(1n∑Ot)R=\sigma(\frac{1}{n} \sum O^t)R=σ(n1∑Ot)),而本文的方法是优化每一个时间步的突触前输入 Ostu(t)O_{stu}(t)Ostu(t)。
- SEEN:先融合时间信息,再计算一次损失。梯度信号较弱。
- 本文方法(TET):每个时间步都计算损失。梯度信号更强、更密集,能更有效地指导网络参数的更新。
命中一致性知识蒸馏(Hit Consistency Knowledge Distillation, HCKD)
目标:让学生网络在预测正确的时候,其预测的置信度分布和教师网络在预测正确的时候的分布尽可能一致。
直觉:不仅要让学生学会“答对”,还要让学生学会“像学霸(教师)那样去答题”。即,学霸对正确答案非常确信时,学生也应该非常确信;学霸对正确答案只是略微确信时,学生也应该表现出类似的略微确信。这是一种精细化的模仿。
公式解析
公式 (3): 损失计算
LHCKD=LMSE(Sstu,Stea)
\mathcal{L}_{\mathrm{HCKD}} = \mathcal{L}_{\mathrm{MSE}}\left(S_{s t u}, S_{t e a}\right)
LHCKD=LMSE(Sstu,Stea)
- LMSE\mathcal{L}_{\mathrm{MSE}}LMSE:均方误差(Mean Squared Error) 损失函数。它计算两个信号之间差异的平方和,对大的差异给予更大的惩罚,能有效地拉近两个分布的距离。
- SstuS_{stu}Sstu:学生网络的“正确预测信号”。这是一个代表学生网络在正确时间步的平均预测分布的向量。
- SteaS_{tea}Stea:教师网络的“正确预测信号”。这是一个代表教师网络在正确时间步的平均预测分布的向量。
作用:这个公式的核心是最小化学生和教师在他们都做对题目时的“想法差异”。
公式 (4) & (5): 信号生成
Sstu=1Cstu∑cstu=1CstuOstu(cstu)Stea=1Ctea∑ctea=1CteaOtea(ctea)
\begin{align*}
S_{s t u} & =\frac{1}{C_{s t u}} \sum_{c_{s t u}=1}^{C_{s t u}} O_{s t u}\left(c_{s t u}\right) \\
S_{t e a} & =\frac{1}{C_{t e a}} \sum_{c_{t e a}=1}^{C_{t e a}} O_{t e a}\left(c_{t e a}\right)
\end{align*}
SstuStea=Cstu1cstu=1∑CstuOstu(cstu)=Ctea1ctea=1∑CteaOtea(ctea)
这两个公式定义了如何从一系列预测中提炼出 SstuS_{stu}Sstu 和 SteaS_{tea}Stea。
-
CstuC_{stu}Cstu / CteaC_{tea}Ctea: 在一个训练样本(一段序列)的 T 个时间步中,学生网络/教师网络预测正确的时间步的数量。
- 例如,对于一段有10个时间步的序列,学生可能在其中7个时间步预测正确,那么 Cstu=7C_{stu}=7Cstu=7。
-
Ostu(cstu)O_{stu}(c_{stu})Ostu(cstu) / Otea(ctea)O_{tea}(c_{tea})Otea(ctea): 在第 cstuc_{stu}cstu 个/第 cteac_{tea}ctea 个正确时间步上,学生网络/教师网络分类层的输出(即Softmax前的logits或之后的概率分布)。它体现了网络在该时间步的“原始想法”。
-
求和与平均:将所有正确时间步的输出 O(...)O(...)O(...) 相加并求平均。
- 为什么要求平均? 因为一段序列中可能有多个时间步都预测正确,但每个时间步的“确信度”可能不同。求平均可以得到一个综合的、代表性的“正确预测模式”。
特殊情况处理
当 Cstu/Ctea=0C_{stu} / C_{tea} = 0Cstu/Ctea=0 时:
- 情况:在一段序列中,学生网络或教师网络没有任何一个时间步预测正确。
- 处理:将其对应的信号 SstuS_{stu}Sstu 或 SteaS_{tea}Stea 的值设置为 1/Nc1/N_c1/Nc。NcN_cNc 是情绪类别的总数(例如7类)。
- 为什么这样做?:
- 1/Nc1/N_c1/Nc 代表一种均匀分布,即“完全不确定”或“随机猜测”的状态。
- 这是一种平滑处理,防止分母为零出现数学错误,并为模型提供一个有意义的学习目标(即使完全猜错,目标也是趋向于均匀分布,而不是一个无意义的零向量)。
- 它确保了损失函数的数值稳定性。
时间一致性知识蒸馏(Temporal Consistency Knowledge Distillation, TCKD)
目标:强制学生网络在每一个时间步(无论预测对错),其预测分布都与教师网络在该时间步的预测分布保持一致。
直觉:不仅要让学生学会模仿学霸(教师)答对题时的思路(HCKD),还要全程实时模仿学霸的整个思考过程。学霸在解题过程中可能有犹豫、有试探、有修正,这些动态的、时序上的模式本身也蕴含着丰富的知识。TCKD要求学生“连学霸的思考过程也要复刻”。
公式解析
公式 (6): 损失计算
LTCKD=1T∑t=1TLMSE(Ostu(t),Otea(t))
\mathcal{L}_{\mathrm{TCKD}}=\frac{1}{T} \sum_{t=1}^{T} \mathcal{L}_{\mathrm{MSE}}\left(O_{s t u}(t), O_{t e a}(t)\right)
LTCKD=T1t=1∑TLMSE(Ostu(t),Otea(t))
- TTT: 一个训练样本(一段序列)的总时间步数。
- Ostu(t)O_{stu}(t)Ostu(t): 学生网络在第t个时间步的分类器输出(突触前输入或概率分布)。
- Otea(t)O_{tea}(t)Otea(t): 教师网络在第t个时间步的分类器输出。
- LMSE\mathcal{L}_{\mathrm{MSE}}LMSE: 均方误差损失,用于计算两个网络在同一时间步的输出差异。
- 1T∑t=1T\frac{1}{T} \sum_{t=1}^{T}T1∑t=1T: 对所有T个时间步上计算出的MSE损失取平均值。
作用:这个公式的核心是最小化学生和教师在所有时间步上的“瞬时状态差异”,让学生网络的整个输出序列(轨迹)都与教师网络的序列尽可能相似。
时间一致性损失(TCKD) 的核心作用是对齐学生和教师模型的“内部动态过程”。它与HCKD一起,共同构成了论文所提出的协同知识蒸馏策略:
- HCKD 确保学生能学会教师的成功结论。
- TCKD 确保学生能学会教师得出该结论的完整思考过程。