对比学习 | 软标签损失计算
软标签损失计算
-
为什么要用软标签来计算损失?
-
传统 one-hot 的问题:
-
只关注唯一正确的类别,对其他类完全忽略;
-
训练时会产生过大的梯度惩罚,导致模型过拟合或训练不稳定;
-
无法表达“相似但不完全相同”的关系,例如:
- 图像中“猫”和“狸花猫”;
- 语义上“银行”和“金融机构”;
- 图文匹配中“图片”和多个描述句子之间的相似度差异。
- 图像中“猫”和“狸花猫”;
-
-
-
如何计算软标签?
sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) # 创建一个全0的矩阵 sim_targets.fill_diagonal_(1) # 对角线上设置为1 sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets # 构建预测标签(相似度)和one-hot标签的中和,称为软标签 sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
其中,
sim_targets
是one-hot
标签(真实标签);- `sim_i2t_
是相似度,对相似度进行
softmax`的得到每个样本和其他样本的相似概率(预测标签);
alpha
是预测标签的占比;alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
集合真实标签和预测标签各自的占比,得到软标签,这个软标签会代替原来的真实标签进行交叉熵计算;
-
对比学习中如何利用软标签计算损失?
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() # 手动实现交叉熵loss loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() loss_ita = (loss_i2t+loss_t2i)/2 # 计算对比损失
其中,
log_softmax(sim_i2t, dim=1)
作为预测概率分布,sim_i2t_targets
作为目标概率分布;