当前位置: 首页 > news >正文

对比学习 | 软标签损失计算

软标签损失计算

  1. 为什么要用软标签来计算损失?

    • 传统 one-hot 的问题:

      • 只关注唯一正确的类别,对其他类完全忽略;

      • 训练时会产生过大的梯度惩罚,导致模型过拟合或训练不稳定;

      • 无法表达“相似但不完全相同”的关系,例如:

        • 图像中“猫”和“狸花猫”;
          • 语义上“银行”和“金融机构”;
          • 图文匹配中“图片”和多个描述句子之间的相似度差异。
  2. 如何计算软标签?

    sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) # 创建一个全0的矩阵
    sim_targets.fill_diagonal_(1) # 对角线上设置为1          sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets # 构建预测标签(相似度)和one-hot标签的中和,称为软标签
    sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
    

    其中,

    • sim_targetsone-hot标签(真实标签);
    • `sim_i2t_

是相似度,对相似度进行softmax`的得到每个样本和其他样本的相似概率(预测标签);

  • alpha是预测标签的占比;
  • alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets集合真实标签和预测标签各自的占比,得到软标签,这个软标签会代替原来的真实标签进行交叉熵计算;
  1. 对比学习中如何利用软标签计算损失?

    loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() # 手动实现交叉熵loss
    loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() loss_ita = (loss_i2t+loss_t2i)/2 # 计算对比损失
    

    其中,log_softmax(sim_i2t, dim=1)作为预测概率分布,sim_i2t_targets作为目标概率分布;

http://www.dtcms.com/a/292906.html

相关文章:

  • 安科瑞工商业光储充新能源电站ACCU-100M微电网协调控制器
  • MyBatis-Plus 分页实战
  • 目前主流的AI深度学习框架对Windows和Linux的支持哪个更好
  • 单细胞转录组学+空间转录组的整合及思路
  • 一个不起眼的问题,导致插件加载失败
  • python中 tqdm ,itertuples 是什么
  • 学习软件测试的第十九天
  • ​Eyeriss 架构中的访存行为解析(腾讯元宝)
  • Java学习----Redis集群
  • SHAP的升级版:可解释性框架Alibi的相关介绍(一)
  • L1与L2正则化:核心差异全解析
  • RabbitMQ03——面试题
  • DOM/事件高级
  • haprox七层代理
  • 医院如何实现节能降耗?
  • <另一种思维:语言模型如何展现人类的时间认知>读后总结
  • 【上市公司变量测量】Python+FactSet Revere全球供应链数据库,测度供应链断裂与重构变量——丁浩员等(2024)《经济研究》复现
  • Day28| 122.买卖股票的最佳时机 II、55. 跳跃游戏、45.跳跃游戏 II、1005.K次取反后最大化的数组和
  • Spring AI Alibaba + JManus:从架构原理到生产落地的全栈实践——一篇面向 Java 架构师的 20 分钟深度阅读
  • MSTP实验
  • 深入理解 Qt 中的 QImage 与 QPixmap:底层机制、差异、优化策略全解析
  • 集训Demo5
  • 代码检测SonarQube+Git安装和规范
  • 从FDTD仿真到光学神经网络:机器学习在光子器件设计中的前沿应用工坊
  • Matlab学习笔记:界面使用
  • 【数据结构初阶】--栈和队列(二)
  • CanOpen--SDO 数据帧分析
  • vscode不识别vsix结尾的插件怎么解决?
  • sysbench对linux服务器上mysql8.0版本性能压测
  • Thinkphp8使用Jwt生成与验证Token