当前位置: 首页 > news >正文

PCGrad解决多任务冲突

论文解读:"Gradient Surgery for Multi-Task Learning"

1. 论文标题直译
  • Gradient Surgery: 梯度手术
  • for Multi-Task Learning: 应用于多任务学习

合在一起就是:为多任务学习量身定制的梯度手术。这个名字非常形象地概括了它的核心思想。

2. 它要解决的核心问题:多任务学习中的“梯度冲突”

想象一下,你正在训练一个AI模型来开一辆车,它需要同时完成两个任务:

  • 任务A: 识别红绿灯(要求模型关注图像上方的颜色区域)。
  • 任务B: 保持在车道线内(要求模型关注图像下方的白色线条)。

在训练时,模型会根据任务A的错误计算出一个梯度 g_A,根据任务B的错误计算出另一个梯度 g_B。梯度本质上是告诉模型参数“应该朝哪个方向更新才能做得更好”。

问题来了: 如果某次更新中,g_A 说“参数应该向东调整”,而 g_B 恰好说“参数应该向西调整”,那么把它们简单相加(g_A + g_B)的结果可能接近于零,模型几乎学不到任何东西。

更常见的情况是,g_A 想让参数向东走,g_B 想让参数向西北走。它们的合力会是一个“折衷”的方向,这个方向可能对两个任务都不是最优的,甚至可能提升一个任务的性能却损害了另一个。

这种现象就叫做梯度冲突 (Gradient Conflict) 或 负迁移 (Negative Transfer)。这是多任务学习中一个长期存在的痛点,它会导致训练不稳定,模型性能难以提升。

3. PCGrad 的解决方案:“梯度手术”

PCGrad (Projected Gradient Descent) 提出了一种非常聪明的解决方案,就像一个外科医生一样,在更新模型参数之前,先对这些相互冲突的梯度做一次“手术”。

手术流程如下:

第1步:分别计算每个任务的梯度 和传统方法不同,它不把所有损失加起来,而是为每个任务的损失 loss_Aloss_B... 单独计算梯度 g_Ag_B...

第2步:诊断是否存在“冲突” PCGrad 遍历所有梯度对(如 g_A 和 g_B),并通过计算它们的点积 (dot product) 来判断它们是否冲突。

  • 如果 dot(g_A, g_B) > 0: 说明两个梯度的夹角小于90度,它们大方向一致,是“盟友”。无需手术
  • 如果 dot(g_A, g_B) < 0: 说明两个梯度的夹角大于90度,它们的方向是“敌对”的。诊断为冲突,需要手术!

第3步:执行“手术”——投影和矫正 当检测到 g_A 和 g_B 冲突时,PCGrad 会执行以下操作:

  1. 投影 (Project):将梯度 g_A 投影到梯度 g_B 的方向上,得到一个分量 proj_B(g_A)。这个分量可以被理解为 g_A 中与 g_B “正面冲突”的那一部分。
  2. 矫正 (Correct):从原始梯度 g_A 中减去这个冲突分量:g_{A_{new}} = g_A - proj_B(g_A)

手术效果: 经过手术后的新梯度 g_{A_{new}} 与 g_B 变成了正交的(夹角为90度)。这意味着,g_{A_{new}} 的更新方向中,已经完全剔除了与 g_B 直接对抗的部分。它只保留了对自己有益,且不伤害对方的部分。

PCGrad 会对所有发生冲突的梯度对都执行这个“手术”。

第4步:合并与更新 将所有经过“手术”矫正后的新梯度相加,得到最终的、和谐的、没有内斗的梯度,然后用这个梯度去更新模型参数。

4. TensorFlow 实现中的 PCGrad

你在代码中看到的 PCGrad 通常是一个优化器包装器 (Optimizer Wrapper)。它的用法一般是这样的:

  1. 首先,定义一个基础的优化器,比如 Adam。

    base_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
  2. 然后,用 PCGrad 包装它

    from .PCGrad import PCGrad
    optimizer = PCGrad(base_optimizer)
  3. 在训练循环中,用法会稍有不同。 你不再是计算一个总的 loss 然后调用 apply_gradients。而是:

    # 1. 分别计算每个任务的 loss
    loss_A = compute_loss_A(y_true_A, y_pred_A)
    loss_B = compute_loss_B(y_true_B, y_pred_B)
    list_of_losses = [loss_A, loss_B]# 2. PCGrad 优化器会接管梯度的计算和矫正
    # 这一步是 PCGrad 内部实现的,它会:
    #   - 为每个 loss 计算梯度
    #   - 执行梯度手术
    #   - 返回最终的梯度
    # 通常会通过一个自定义的 train_step 来实现
    final_gradients = optimizer.get_gradients(list_of_losses, model.trainable_variables)# 3. 应用经过手术后的梯度
    optimizer.apply_gradients(zip(final_gradients, model.trainable_variables))

总结

方面解释
它是什么?PCGrad 是一种优化策略,而非损失函数或模型架构。
解决什么问题?解决多任务学习中的梯度冲突 (Gradient Conflict) 问题。
核心思想?梯度手术 (Gradient Surgery):在更新模型前,先检测并消除梯度之间的冲突部分。
如何实现?通过向量投影,将冲突的梯度分量从原始梯度中移除,使它们变得正交
最终效果?1. 训练过程更稳定。 2. 避免了任务间的“内耗”,有助于所有任务性能的同步提升。

因此,当你看到代码中使用了 PCGrad,就可以立刻明白:这个项目正在处理一个多任务学习的场景,并且使用了一种相当先进的技术来确保不同任务能够“和平共处”,协同进步。


文章转载自:

http://MSW5hc0i.qrcxh.cn
http://MLK8mHpP.qrcxh.cn
http://HwcJqLjc.qrcxh.cn
http://vA4B7MKF.qrcxh.cn
http://b9OdFYRB.qrcxh.cn
http://HlpHcxb4.qrcxh.cn
http://22qJz6fW.qrcxh.cn
http://A83ptTZo.qrcxh.cn
http://gEZf06lK.qrcxh.cn
http://pajwpgfi.qrcxh.cn
http://nYsybLPE.qrcxh.cn
http://TGDbIdWA.qrcxh.cn
http://zX58eNSe.qrcxh.cn
http://tRM9tC2K.qrcxh.cn
http://oQkg4T0F.qrcxh.cn
http://MbEB3Ekw.qrcxh.cn
http://uHRqfHra.qrcxh.cn
http://ovXrVsYj.qrcxh.cn
http://Nf37IX8F.qrcxh.cn
http://pVIMUoYB.qrcxh.cn
http://8G6QWIdQ.qrcxh.cn
http://wweA54ui.qrcxh.cn
http://jMywdt3g.qrcxh.cn
http://Bt4UDEcD.qrcxh.cn
http://TvI5w4wD.qrcxh.cn
http://9GZnWoIf.qrcxh.cn
http://73r6i6VG.qrcxh.cn
http://hueynQKp.qrcxh.cn
http://VRGlB41C.qrcxh.cn
http://Lxhz4xMr.qrcxh.cn
http://www.dtcms.com/a/386643.html

相关文章:

  • 第十一章:游戏玩法和屏幕特效-Gameplay and ScreenEffects《Unity Shaders and Effets Cookbook》
  • Choerodon UI V1.6.7发布!为 H-ZERO 开发注入新动能
  • 科教共融,具创未来!节卡助力第十届浦东新区机器人创新应用及技能竞赛圆满举行
  • 食品包装 AI 视觉检测技术:原理、优势与数据应用解析
  • 【深度学习计算机视觉】05:多尺度目标检测之FPN架构详解与PyTorch实战
  • 从工业革命到人工智能:深度学习的演进与核心概念解析
  • [Emacs list使用及配置]
  • DQN在稀疏奖励中的局限性
  • 为何需要RAII——从“手动挡”到“自动挡”的进化
  • 第五课、Cocos Creator 中使用 TypeScript 基础介绍
  • 09MYSQL视图:安全高效的虚拟表
  • R 语言本身并不直接支持 Python 中 f“{series_matrix}.txt“ 这样的字符串字面量格式化(f-string)语法 glue函数
  • 【AI论文】AgentGym-RL:通过多轮强化学习训练大语言模型(LLM)智能体以实现长期决策制定
  • Win11本地jdk1.8和jdk17双版本切换运行方法
  • vue3 使用print.js打印el-table全部数据
  • Vue 3 + TypeScript + 高德地图 | 实战:多车轨迹回放(点位驱动版)
  • [vue]创建表格并实现筛选和增删改查功能
  • JVM-运行时内存
  • 后缀树跟字典树的区别
  • LanceDB向量数据库
  • RabbitMQ 异步化抗洪实战
  • 《Java集合框架核心解析》
  • 二维码生成器
  • OSI七层模型
  • 【原创·极简新视角剖析】【组局域网】设备在同一局域网的2个条件
  • 第8课:高级检索技术:HyDE与RAG-Fusion原理与DeepSeek实战
  • Windows 命令行:路径的概念,绝对路径
  • 异常检测在网络安全中的应用
  • 【ubuntu】ubuntu 22.04 虚拟机中扩容操作
  • 【数值分析】05-绪论-章节课后1-7习题及答案