当前位置: 首页 > news >正文

【深度学习】SOFT Top-k:用最优传输解锁可微的 Top-k 操作

【深度学习】SOFT Top-k:用最优传输解锁可微的 Top-k 操作

文章目录

  • 【深度学习】SOFT Top-k:用最优传输解锁可微的 Top-k 操作
    • 1 引言
    • 2 Top-k的“微分困境”
    • 3 核心思想:当Top-k遇上最优传输(OT)
      • 3.1 将Top-k问题参数化为OT问题
      • 3.2 通过熵正则化实现平滑
    • 4 SOFT Top-k 算子详解
    • 5 应用与效果
    • 6 总结
    • 参考文献

1 引言

在机器学习的世界里,top-k 操作无处不在。无论是在推荐系统中筛选出用户最可能喜欢的 k 个商品,在 k-NN 算法中寻找 k 个最近的邻居 ,还是在自然语言处理的束搜索(Beam Search)中保留 k 个最有可能的序列,top-k 都扮演着至关重要的角色。

然而,这个基础且强大的操作有一个“阿喀琉斯之踵”:它通常是不可微的。这意味着无法将它像普通层一样直接嵌入到深度神经网络中,然后使用梯度下降法进行端到端的训练。这极大地限制了许多新颖模型的设计。

为了绕开这个障碍,研究者们通常采用“两阶段训练”等妥协方案:先用一个代理损失(如交叉熵)训练特征提取网络,然后再将提取的特征用于 top-k 相关的任务。这种做法导致了训练目标和最终任务之间的不一致,往往会损害模型的性能表现。

那么,有没有办法让 top-k 变得可导呢?来自 Google 和佐治亚理工学院的研究者们在论文 Differentiable Top-k Operator with Optimal Transport 中提出了一种名为 SOFT (Scalable Optimal transport-based differenTiable) top-k 的算子,基于“最优传输”(Optimal Transport)的思想解决了这个问题。

2 Top-k的“微分困境”

为什么常规的 top-k 操作是不可导的?

  1. 算法实现的障碍top-k 的标准算法,如冒泡排序或快速选择(QuickSelect),都涉及到大量的索引交换和比较操作。这些基于逻辑和顺序的操作,其梯度要么无法定义,要么处处为零,无法为梯度下降提供有效信息。

  2. 数学本质的非连续性:从数学角度看,top-k 可以被视为一个映射,它将一组输入分数 x 映射到一个由 0 和 1 组成的指示向量 A(1代表该元素属于top-k,0则不属于)。这个映射是分段常数函数,因此是非连续的。

让我们以一个最简单的 top-1 例子来说明(即找出两个数 x1,x2x_1, x_2x1,x2 中较大的一个)。指示向量 A1A_1A1 的值(表示 x1x_1x1 是否为最大值)关于 x1x_1x1 的函数图像如下:

图1

图1 左侧为标准top-k算子,右侧为SOFT top-k算子。可以看出标准算子的输出是突变的,而SOFT算子是平滑的

x1<x2x_1 < x_2x1<x2 时,A1=0A_1=0A1=0。当 x1x_1x1 刚刚超过 x2x_2x2 时,A1A_1A1 会从 0 瞬间跳变到 1 。在 x1=x2x_1 = x_2x1=x2 这个点,函数是不可导的;而在其他所有地方,它的导数都是 0 。这样的梯度对于模型训练是毫无用处的。

3 核心思想:当Top-k遇上最优传输(OT)

既然直接微分此路不通,换一个思路:top-k 问题重构为一个最优传输(Optimal Transport, OT)问题

3.1 将Top-k问题参数化为OT问题

最优传输旨在以最低的“运输成本”将一个概率分布的“质量”转移到另一个概率分布上。将 top-k 重新定义为这样一个运输问题:

  • 源分布 μ\muμ:我们有 n 个输入分数 {xi}i=1n{\{x_i\}}_{i=1}^n{xi}i=1n,我们将它们看作 n 个质量均为 1/n1/n1/n 的源点。
  • 目标分布 ν\nuν:我们设立两个目标点 0,1{0, 1}0,1。我们希望将 k 个单位的质量运到点 0(代表“top-k集合”),剩下的 n-k 个单位的质量运到点 1(代表“非top-k集合”)。
  • 运输成本 CCC:从源点 xix_ixi 运输到目标点 yjy_jyj (其中yj∈0,1y_j \in {0, 1}yj0,1) 的成本定义为它们之间的欧氏距离的平方,即 Cij=(xi−yj)2C_{ij} = (x_i - y_j)^2Cij=(xiyj)2

直观上,为了最小化总运输成本,模型会自发地将那些离 0 最近的 k 个 xix_ixi(也就是 k 个最小的数)运输到目标点 0,将其余离 1 最近的 n-k 个 xix_ixi(也就是 n-k 个最大的数)运输到目标点 1。

这个 OT 问题的解,即“最优运输方案” Γ∗\Gamma^*Γ, 是一个 n×2n \times 2n×2 的矩阵。Γi,1∗\Gamma^*_{i,1}Γi,1 代表从 xix_ixi 运到 0 的质量。可以证明,这个运输方案恰好可以用来表示 top-k 的结果。

3.2 通过熵正则化实现平滑

虽然 OT 的框架很优雅,但标准 OT 问题的解对于输入的变化仍然不具备可微性。

关键的第二步来了:为 OT 问题引入熵正则化(Entropy Regularization)。具体来说,在最小化运输成本的同时,我们还希望最大化运输方案 Γ\GammaΓ 的熵 H(Γ)H(\Gamma)H(Γ)。目标函数变为:

Γ∗,ϵ=arg⁡min⁡Γ⟨C,Γ⟩+ϵH(Γ)\Gamma^{*,\epsilon} = \arg\min_{\Gamma} \langle C, \Gamma \rangle + \epsilon H(\Gamma)Γ,ϵ=argΓminC,Γ+ϵH(Γ)
其中 ϵ\epsilonϵ 是一个大于0的正则化系数。

图2

图2 不同平滑参数 epsilon 下的SOFT算子输出

熵正则化就像一个“平滑器”。它使得运输方案 Γ∗,ϵ\Gamma^{*,\epsilon}Γ,ϵ 不再是“非此即彼”的硬分配,而是变成了一个“软”的、模糊的分配。每个输入 xix_ixi 都会将一部分质量分配给 0,一部分分配给 1。最终,top-k 的指示向量 AϵA^\epsilonAϵ 中的值不再是刚性的 0 或 1,而是位于 (0, 1) 之间的平滑数值。这个熵正则化最优传输(EOT)问题的解 AϵA^\epsilonAϵ 对于输入分数 XXX 是可微的

4 SOFT Top-k 算子详解

基于上述思想,SOFT top-k 算子诞生了。它的工作流程分为前向传播和反向传播。

注:关于前向传播(Forward Pass)与反向传播(Backward Pass)这两个概念的介绍,可以参见我的这一篇文章:【深度学习】一文彻底搞懂前向传播(Forward Pass)与反向传播(Backward Pass)。

  • 前向传播:给定输入分数 XXX,通过高效的 Sinkhorn 算法 来求解熵正则化 OT 问题,从而计算出平滑后的 top-k 指示向量 AϵA^\epsilonAϵ

  • 反向传播:计算 AϵA^\epsilonAϵ 相对于输入 XXX 的雅可比矩阵(梯度)。如果直接对 Sinkhorn 算法的迭代过程应用自动微分,会占用巨大的内存。因此,作者们利用了 EOT 问题的 KKT (Karush-Kuhn-Tucker) 最优性条件,通过隐式微分技术,推导出了计算雅可比矩阵的解析表达式。这种方法在计算上十分高效,其时间和空间复杂度仅为 O(n)\mathcal{O}(n)O(n)

  • 超参数 ϵ\epsilonϵ 的权衡ϵ\epsilonϵ 控制着近似的程度。

    • ϵ\epsilonϵ:近似偏差小,结果更接近真实的 top-k,但函数更“陡峭”,平滑效果弱。
    • ϵ\epsilonϵ:平滑效果好,梯度计算更稳定,但近似偏差大,可能影响模型性能。
    • 有趣的是,偏差也和第 k 个元素与第 k+1 个元素之间的差距(gap)有关。如果差距很大,即使 ϵ\epsilonϵ 较大,偏差也可以很小。

此外,该框架还能被轻松扩展为 Sorted SOFT Top-k 算子,用于需要对 top-k 结果进行排序的场景,例如 Beam Search 。

5 应用与效果

为了验证 SOFT top-k 的威力,作者在三个典型的应用场景中进行了实验。

  1. k-NN 图像分类
    通过将 SOFT top-k 算子整合进 k-NN 分类器,作者们实现了一个可以端到端训练的神经网络 k-NN 模型。在 MNIST 和 CIFAR-10 数据集上的实验结果表明,该方法显著优于传统的两阶段训练方法和其他可微 top-k 的基线模型。

    算法MNISTCIFAR10
    kNN+pretrained CNN98.4%91.1%
    CE+CNN99.0%91.3%
    kNN+Softmax k times99.3%92.2%
    kNN+SOFT Top-k99.4%92.6%
表1:k-NN分类准确率对比
  1. 机器翻译中的束搜索 (Beam Search)
    在序列生成任务中,训练和推理之间的不一致(也称作“暴露偏差”)是一个长期存在的问题。作者将 Sorted SOFT Top-k 算子应用于 Beam Search 过程,使得搜索过程本身可以被整合到训练循环中。在 WMT’14 英法翻译任务上,该方法取得了约 0.9 BLEU 值的提升。

  2. 机器翻译中的 Top-k Attention
    传统的 Soft Attention 机制会考虑所有源端词语,可能导致注意力分散和冗余。作者使用 SOFT top-k 来选择性地关注最重要的 k 个源端词语,从而实现了一种稀疏的注意力机制。

图3

图3 Top-k Attention可视化(示意图),注意力变得稀疏且对齐清晰

在 WMT’16 英德翻译任务上,这种稀疏注意力机制也带来了约 0.8 BLEU 值的提升。

6 总结

top-k 操作的不可微性一直是深度学习领域的一个痛点。论文 Differentiable Top-k Operator with Optimal Transport 通过将 top-k 与熵正则化的最优传输理论相结合,提出了一种名为 SOFT top-k 的优雅解决方案:将离散的 top-k 选择问题转化为一个连续平滑的最优传输问题,实现了完全可微,并且具有高效、可扩展的前向和后向传播算法,并解锁了将 top-k 依赖的操作(如 k-NN、Beam Search)进行端到端训练的能力,在多个任务上取得了显著的效果提升。


参考文献

  1. Xie, Y., Dai, H., Chen, M., Dai, B., Zhao, T., Zha, H., Wei, W., & Pfister, T. (2020). Differentiable top-k operator with optimal transport. Advances in Neural Information Processing Systems, 33
http://www.dtcms.com/a/302898.html

相关文章:

  • (二)Eshop(RabbitMQ手动)
  • 如何 5 分钟给英语视频加上中文字幕?
  • 2025.7.28总结
  • 学术论文写作心得笔记:如何避免“论文像实验报告”
  • 关于sql面试积累
  • [Linux]线程池
  • 【深度学习新浪潮】基于文字生成3D城市景观的算法有哪些?
  • 前端实现PDF在线预览的8种技术方案对比与实战
  • 软件设计师-知识点记录
  • WAIC 2025深度解析:当“养虎”警示遇上机器人拳击赛
  • 构建你的专属区块链:深入了解 Polkadot SDK
  • Java序列化与反序列化
  • 从零开始学习Dify-基于MCP的智能旅行规划助手下(九)
  • 02_FOC学习之-闭环位置控制
  • #Datawhale 组队学习#强化学习Task5
  • C# 基于halcon的视觉工作流-章24-矩形查找
  • SpringBoot数学实例:高等数学实战
  • 学习嵌入式的第三十四天-数据结构-(2025.7.28)数据库
  • Linux选择题2
  • Leaflet简介、初步了解
  • 分布式IO详解:2025年分布式无线远程IO采集控制方案选型指南
  • Java学习-----JVM的垃圾回收算法
  • 分布式IO选型指南:2025年分布式无线远程IO品牌及采集控制方案详解
  • OpenGL为什么要用4X4矩阵
  • 构建 P2P 网络与分布式下载系统:从底层原理到安装和功能实现
  • 分布式高可用架构核心:复制、冗余与生死陷阱——从主从灾难到无主冲突的避坑指南
  • 文件夹隐藏精灵 for Win的文件隐私管理痛点
  • 中国汽车能源消耗量(2010-2024年)
  • 点击事件的防抖和节流
  • 【硬件-笔试面试题】硬件/电子工程师,笔试面试题-42,(知识点:D触发器,D锁存器,工作原理,区别)