【理论推导】互信息与InfoNCE损失:从公式推导理解对比学习的本质
核心结论:InfoNCE损失函数通过对比正负样本,实际上是在最大化变量间互信息的下界,即
I(X;Y)≥log(N)−LInfoNCEI(X;Y) \geq \log(N) - \mathcal{L}_{\text{InfoNCE}} I(X;Y)≥log(N)−LInfoNCE
优化InfoNCE损失等价于隐式地最大化互信息,从而使模型能够捕捉数据中的本质特征和依赖关系。
本文给出了以上公式的逐步理论推导。
相关结论可参考:Oord, Aaron van den, Yazhe Li, and Oriol Vinyals. “Representation learning with contrastive predictive coding.” arxiv preprint arxiv:1807.03748 (2018).
文章目录
- 核心结论
- 一个直观的例子:图像-文本匹配
- 基础知识:互信息与对比学习
- 互信息的定义
- 为什么不能直接优化互信息,而是优化 InfoNCE ?
- InfoNCE损失函数(Information Noise Contrastive Estimation)
- 理论推导1:最优得分函数的形式
- 理论推导2:infoNCE 损失的最优值
- 理论推导3:infoNCE 的上界
- 理论推导4:证明互信息的下界
- 总结
核心结论
InfoNCE损失函数在对比学习中的有效性,本质上源于其对互信息(Mutual Information)的下界估计。当我们用InfoNCE训练模型时,实际上是在最大化两个变量之间的互信息——这个值衡量的是"知道一个变量后,另一个变量的不确定性减少了多少"。具体来说,对于N个样本(1个正样本和N-1个负样本),InfoNCE损失 LInfoNCE\mathcal{L}_{\text{InfoNCE}}LInfoNCE 满足:
I(X;Y)≥log(N)−LInfoNCEI(X;Y) \geq \log(N) - \mathcal{L}_{\text{InfoNCE}} I(X;Y)≥log(N)−LInfoNCE
这意味着,最小化InfoNCE损失等价于最大化互信息的下界。负样本数量N越大,这个下界越紧,模型捕捉变量间依赖关系的能力越强。
一个直观的例子:图像-文本匹配
想象你在训练一个模型,让它将图片和对应的文本描述拉近,不匹配的图文对推远。
- 正样本:一张"金毛犬在草地上奔跑"的图片 + 对应的正确描述
- 负样本:同一张图片 + "一辆红色汽车在街道上行驶"等N-1个无关描述
InfoNCE损失迫使模型学会区分"匹配"与"不匹配"的图文对。从信息论角度看,模型其实是在学习"给定图片后,正确描述的可预测性"——这正是互信息要衡量的东西。当模型能轻易从N个描述中挑出正确那个时,说明图片和文本之间的互信息很高。
基础知识:互信息与对比学习
互信息的定义
互信息 I(X;Y)I(X;Y)I(X;Y) 量化两个随机变量之间的统计依赖性。对于离散变量,定义为:
I(X;Y)=DKL(P(X,Y)∥P(X)P(Y))=∑x,yP(x,y)logP(x,y)P(x)P(y)=EP(x,y)[logP(y∣x)P(y)]I(X;Y) =D_{\text{KL}}\left(P(X,Y) \| P(X)P(Y)\right)= \sum_{x,y} P(x,y) \log \frac{P(x,y)}{P(x)P(y)} = \mathbb{E}_{P(x,y)}\left[\log \frac{P(y|x)}{P(y)}\right] I(X;Y)=DKL(P(X,Y)∥P(X)P(Y))=x,y∑P(x,y)logP(x)P(y)P(x,y)=EP(x,y)[logP(y)P(y∣x)]
符号说明:
- XXX, YYY:两个随机变量
- DKLD_{\text{KL}}DKL:KL散度,衡量两个分布的差异
- P(x,y)P(x,y)P(x,y) 是联合分布,P(x)P(y)P(x)P(y)P(x)P(y) 是边缘分布的乘积。
直观上,互信息度量了知道一个变量后,另一个变量不确定性的减少量。如果XXX和YYY独立,则 P(x,y)=P(x)P(y)P(x,y) = P(x)P(y)P(x,y)=P(x)P(y),此时 I(X;Y)=0I(X;Y)=0I(X;Y)=0;如果它们完全相关,则互信息达到最大值。
例子:在表示学习中,我们希望学习到的特征表示(X)和标签或另一模态数据(Y)之间有高互信息,表示模型捕捉了关键信息。
为什么不能直接优化互信息,而是优化 InfoNCE ?
直接计算互信息的问题:
- 分布未知:真实世界的数据分布 P(x,y)P(x,y)P(x,y) 通常是未知的
- 计算不可行:归一化常数(配分函数)在复杂模型中难以计算
InfoNCE(Information Noise-Contrastive Estimation)通过对比正负样本来绕过这两个问题。它不需要直接估计概率分布,而是通过一个得分函数 f(x,y)f(x,y)f(x,y) 来区分联合分布样本和边缘分布样本。
InfoNCE损失函数(Information Noise Contrastive Estimation)
对比学习的基本设定:给定一个正样本对 (x,y)(x, y)(x,y) 和 N−1N-1N−1 个负样本 yi∼P(Y)y_i \sim P(Y)yi∼P(Y),其中
- 正样本对:(x,y)∼P(x,y)(x, y) \sim P(x,y)(x,y)∼P(x,y)(例如,匹配的图文对)
- 负样本:从 P(y)P(y)P(y) 中独立采样 N−1N-1N−1 个 yiy_iyi(与x无关的样本)
InfoNCE损失定义为:
LInfoNCE=−E[logexp(f(x,y))exp(f(x,y))+∑i=1N−1exp(f(x,yi))]\mathcal{L}_{\text{InfoNCE}} = -\mathbb{E}\left[\log \frac{\exp(f(x,y))}{\exp(f(x,y)) + \sum_{i=1}^{N-1} \exp(f(x,y_i))}\right] LInfoNCE=−E[logexp(f(x,y))+∑i=1N−1exp(f(x,yi))exp(f(x,y))]
符号说明:
- (x,y)(x, y)(x,y):正样本对,来自联合分布P(X,Y)P(X,Y)P(X,Y)
- yiy_iyi:负样本,来自边缘分布P(Y)P(Y)P(Y)
- NNN:总样本数(1个正样本 + N−1N-1N−1个负样本)
- f(x,y)f(x,y)f(x,y):得分函数,衡量x和y的匹配程度(常用神经网络输出的向量点积)
- 分母:正样本得分与所有负样本得分之和,形成softmax概率
直观上,InfoNCE损失鼓励模型给正样本对更高的分数,给负样本对更低的分数。目标是,最大化正样本的相对得分,即让模型将正样本识别为"最匹配"的
理论推导1:最优得分函数的形式
首先,我们证明当得分函数f(x,y)f(x,y)f(x,y)达到最优时,其形式为:
f(x,y)=logP(y∣x)P(y)f(x,y) = \log \frac{P(y|x)}{P(y)} f(x,y)=logP(y)P(y∣x)
推导过程:
InfoNCE本质上是一个二分类问题,模型需要判断样本是来自联合分布P(X,Y)P(X,Y)P(X,Y)(正样本)还是边缘分布P(X)P(Y)P(X)P(Y)P(X)P(Y)(负样本)。
对于二分类问题,最优判别器在样本来自真实分布(正样本)的概率为:
D(x,y)=P(y∣x)P(y∣x)+(N−1)P(y)D(x,y) = \frac{P(y|x)}{P(y|x) + (N-1)P(y)} D(x,y)=P(y∣x)+(N−1)P(y)P(y∣x)
在InfoNCE中,判别器由softmax实现:
D(x,y)=exp(f(x,y))exp(f(y))+∑i=1N−1exp(f(x,yi))D(x,y) = \frac{\exp(f(x,y))}{\exp(f(y)) + \sum_{i=1}^{N-1} \exp(f(x,y_i))} D(x,y)=exp(f(y))+∑i=1N−1exp(f(x,yi))exp(f(x,y))
令以上两个表达式相等
exp(f(x,y))exp(f(x,y))+∑i=1N−1exp(f(x,yi))=P(y∣x)P(y∣x)+(N−1)P(y)\frac{\exp(f(x,y))}{\exp(f(x,y)) + \sum_{i=1}^{N-1} \exp(f(x,y_i))} = \frac{P(y|x)}{P(y|x) + (N-1)P(y)} exp(f(x,y))+∑i=1N−1exp(f(x,yi))exp(f(x,y))=P(y∣x)+(N−1)P(y)P(y∣x)
通过代数推导可得:
exp(f(x,y))⋅(N−1)P(y)=P(y∣x)⋅∑i=1N−1exp(f(x,yi))\exp(f(x,y)) \cdot (N-1)P(y) = P(y|x) \cdot \sum_{i=1}^{N-1} \exp(f(x,y_i)) exp(f(x,y))⋅(N−1)P(y)=P(y∣x)⋅i=1∑N−1exp(f(x,yi))
代数推导过程:
exp(f(x,y))exp(f(x,y))+∑i=1N−1exp(f(x,yi))=P(y∣x)P(y∣x)+(N−1)P(y)\frac{\exp(f(x,y))}{\exp(f(x,y)) + \sum_{i=1}^{N-1} \exp(f(x,y_i))} = \frac{P(y|x)}{P(y|x) + (N-1)P(y)} exp(f(x,y))+∑i=1N−1exp(f(x,yi))exp(f(x,y))=P(y∣x)+(N−1)P(y)P(y∣x)
交叉相乘并整理:exp(f(x,y))⋅[P(y∣x)+(N−1)P(y)]=P(y∣x)⋅[exp(f(x,y))+∑i=1N−1exp(f(x,yi))]\exp(f(x,y)) \cdot \left[ P(y|x) + (N-1)P(y) \right] = P(y|x) \cdot \left[ \exp(f(x,y)) + \sum_{i=1}^{N-1} \exp(f(x,y_i)) \right] exp(f(x,y))⋅[P(y∣x)+(N−1)P(y)]=P(y∣x)⋅[exp(f(x,y))+i=1∑N−1exp(f(x,yi))]
展开后消去相同项:exp(f(x,y))⋅(N−1)P(y)=P(y∣x)⋅∑i=1N−1exp(f(x,yi))\exp(f(x,y)) \cdot (N-1)P(y) = P(y|x) \cdot \sum_{i=1}^{N-1} \exp(f(x,y_i)) exp(f(x,y))⋅(N−1)P(y)=P(y∣x)⋅i=1∑N−1exp(f(x,yi))
假设负样本数量足够多(N→∞N \to \inftyN→∞),求和项可近似为期望:(把 N−1N-1N−1 除到右边)
exp(f(x,y))=P(y∣x)P(y)⋅Ey′∼P(y)[exp(f(x,y′))]\exp(f(x,y)) = \frac{P(y|x)}{P(y)} \cdot \mathbb{E}_{y' \sim P(y)}[\exp(f(x,y'))] exp(f(x,y))=P(y)P(y∣x)⋅Ey′∼P(y)[exp(f(x,y′))]
两边取对数,并令归一化常数 C(x)=Ey′[exp(f(x,y′))]C(x) = \mathbb{E}_{y'}[\exp(f(x,y'))]C(x)=Ey′[exp(f(x,y′))],可得:
f(x,y)=logP(y∣x)P(y)+logC(x)f(x,y) = \log \frac{P(y|x)}{P(y)} + \log C(x) f(x,y)=logP(y)P(y∣x)+logC(x)
在InfoNCE的对称结构中,C(x)C(x)C(x)会被softmax分母中的求和项抵消,不影响损失优化,因此有效最优解为:
f(x,y)=logP(y∣x)P(y)\boxed{f(x,y) = \log \frac{P(y|x)}{P(y)}} f(x,y)=logP(y)P(y∣x)
意义:最优得分函数直接估计了密度比(density ratio),这正是互信息定义中的核心项。
理论推导2:infoNCE 损失的最优值
将最优得分函数 f(x,y)=logP(y∣x)P(y)f(x,y) = \log \frac{P(y|x)}{P(y)}f(x,y)=logP(y)P(y∣x) 代入InfoNCE损失:
LInfoNCE=E[−logexp(logP(y∣x)P(y))exp(logP(y∣x)P(y))+∑i=1N−1exp(logP(yi∣x)P(yi))]\mathcal{L}_{\text{InfoNCE}} = \mathbb{E}\left[-\log \frac{\exp\left(\log \frac{P(y|x)}{P(y)}\right)}{\exp\left(\log \frac{P(y|x)}{P(y)}\right) + \sum_{i=1}^{N-1} \exp\left(\log \frac{P(y_i|x)}{P(y_i)}\right)}\right]\\ LInfoNCE=E−logexp(logP(y)P(y∣x))+∑i=1N−1exp(logP(yi)P(yi∣x))exp(logP(y)P(y∣x))
代数变换:
LInfoNCE=E[−logP(y∣x)P(y)P(y∣x)P(y)+∑i=1N−1P(yi∣x)P(yi)]=E[logP(y∣x)P(y)+∑i=1N−1P(yi∣x)P(yi)P(y∣x)P(y)]\mathcal{L}_{\text{InfoNCE}} = \mathbb{E}\left[-\log \frac{\frac{P(y|x)}{P(y)}}{\frac{P(y|x)}{P(y)} + \sum_{i=1}^{N-1} \frac{P(y_i|x)}{P(y_i)}}\right] = \mathbb{E}\left[\log \frac{\frac{P(y|x)}{P(y)} + \sum_{i=1}^{N-1} \frac{P(y_i|x)}{P(y_i)}}{\frac{P(y|x)}{P(y)}}\right] LInfoNCE=E−logP(y)P(y∣x)+∑i=1N−1P(yi)P(yi∣x)P(y)P(y∣x)=ElogP(y)P(y∣x)P(y)P(y∣x)+∑i=1N−1P(yi)P(yi∣x)
分子分母同乘 P(y)P(y)P(y):
LInfoNCE=E[log(1+∑i=1N−1P(yi∣x)P(yi)⋅P(y)P(y∣x))](4.1)\mathcal{L}_{\text{InfoNCE}} = \mathbb{E}\left[\log \left(1 + \sum_{i=1}^{N-1} \frac{P(y_i|x)}{P(y_i)} \cdot \frac{P(y)}{P(y|x)}\right)\right] \quad (4.1) LInfoNCE=E[log(1+i=1∑N−1P(yi)P(yi∣x)⋅P(y∣x)P(y))](4.1)
注意到,负样本 yiy_iyi 是从边缘分布 P(y)P(y)P(y) 独立采样的。根据大数定律,当 NNN 足够大时:
∑i=1N−1P(yi∣x)P(yi)≈(N−1)⋅Ey′∼P(y)[P(y′∣x)P(y′)]\sum_{i=1}^{N-1} \frac{P(y_i|x)}{P(y_i)} \approx (N-1) \cdot \mathbb{E}_{y' \sim P(y)}\left[\frac{P(y'|x)}{P(y')}\right] i=1∑N−1P(yi)P(yi∣x)≈(N−1)⋅Ey′∼P(y)[P(y′)P(y′∣x)]
计算这个期望:
Ey′∼P(y)[P(y′∣x)P(y′)]=∑y′P(y′)⋅P(y′∣x)P(y′)=∑y′P(y′∣x)=1\mathbb{E}_{y' \sim P(y)}\left[\frac{P(y'|x)}{P(y')}\right] = \sum_{y'} P(y') \cdot \frac{P(y'|x)}{P(y')} = \sum_{y'} P(y'|x) = 1 Ey′∼P(y)[P(y′)P(y′∣x)]=y′∑P(y′)⋅P(y′)P(y′∣x)=y′∑P(y′∣x)=1
因此,式(4.1)简化为:
LInfoNCE=E[log(1+(N−1)⋅P(y)P(y∣x))](4.2)\mathcal{L}_{\text{InfoNCE}} = \mathbb{E}\left[\log \left(1 + (N-1) \cdot \frac{P(y)}{P(y|x)}\right)\right] \quad (4.2) LInfoNCE=E[log(1+(N−1)⋅P(y∣x)P(y))](4.2)
这即是 infoNCE 损失的最优值
理论推导3:infoNCE 的上界
以上已得
LInfoNCE=E[log(1+(N−1)⋅P(y)P(y∣x))]\mathcal{L}_{\text{InfoNCE}} = \mathbb{E}\left[\log \left(1 + (N-1) \cdot \frac{P(y)}{P(y|x)}\right)\right] LInfoNCE=E[log(1+(N−1)⋅P(y∣x)P(y))]
定义辅助变量Z=P(y)P(y∣x)Z = \frac{P(y)}{P(y|x)}Z=P(y∣x)P(y),其期望为:
E[Z]=EP(x,y)[P(y)P(y∣x)]=EP(x)[∑yP(y∣x)⋅P(y)P(y∣x)]=EP(x)[1]=1\mathbb{E}[Z] = \mathbb{E}_{P(x,y)}\left[\frac{P(y)}{P(y|x)}\right] = \mathbb{E}_{P(x)}\left[\sum_{y} P(y|x) \cdot \frac{P(y)}{P(y|x)}\right] = \mathbb{E}_{P(x)}[1] = 1 E[Z]=EP(x,y)[P(y∣x)P(y)]=EP(x)[y∑P(y∣x)⋅P(y∣x)P(y)]=EP(x)[1]=1
应用Jensen不等式(因对数函数为凹函数):
E[log(1+(N−1)Z)]≤logE[1+(N−1)Z]=log(1+(N−1)E[Z])=logN\mathbb{E}\left[\log \left(1 + (N-1)Z\right)\right] \leq \log \mathbb{E}\left[1 + (N-1)Z\right] = \log(1 + (N-1)\mathbb{E}[Z]) = \log N E[log(1+(N−1)Z)]≤logE[1+(N−1)Z]=log(1+(N−1)E[Z])=logN
因此:
LInfoNCE≤logN\mathcal{L}_{\text{InfoNCE}} \leq \log N LInfoNCE≤logN
理论推导4:证明互信息的下界
现在我们需要证明:
I(X;Y)≥log(N)−LInfoNCEI(X;Y) \geq \log(N) - \mathcal{L}_{\text{InfoNCE}} I(X;Y)≥log(N)−LInfoNCE
只需证
I(X;Y)+LInfoNCE−log(N)≥0I(X;Y)+\mathcal{L}_{\text{InfoNCE}}-\log(N)\ge0 I(X;Y)+LInfoNCE−log(N)≥0
其中
I(X;Y)=E[logP(y∣x)P(y)]LInfoNCE=E[log(1+(N−1)⋅P(y)P(y∣x))]I(X;Y) = \mathbb{E}\left[\log \frac{P(y|x)}{P(y)}\right]\\ \mathcal{L}_{\text{InfoNCE}} = \mathbb{E}\left[\log \left(1 + (N-1) \cdot \frac{P(y)}{P(y|x)}\right)\right] I(X;Y)=E[logP(y)P(y∣x)]LInfoNCE=E[log(1+(N−1)⋅P(y∣x)P(y))]
两项求和可得
I(X;Y)+LInfoNCE=E[log(P(y∣x)P(y)+N−1)]I(X;Y)+\mathcal{L}_{\text{InfoNCE}}=\mathbb{E}\left[\log \left(\frac{P(y|x)}{P(y)} + N-1\right)\right] I(X;Y)+LInfoNCE=E[log(P(y)P(y∣x)+N−1)]
进一步
I(X;Y)+LInfoNCE−log(N)=E[log(P(y∣x)NP(y)+N−1N)]I(X;Y)+\mathcal{L}_{\text{InfoNCE}}-\log(N)=\mathbb{E}\left[\log \left(\frac{P(y|x)}{NP(y)} + \frac{N-1}{N}\right)\right] I(X;Y)+LInfoNCE−log(N)=E[log(NP(y)P(y∣x)+NN−1)]
利用凹函数性质:log(ax+by)≥alog(x)+blog(y)\log(ax+by)\ge a\log(x)+b\log(y)log(ax+by)≥alog(x)+blog(y),这里 a+b=1a+b=1a+b=1,a,b>0a,b>0a,b>0,可得
log(P(y∣x)NP(y)+N−1N)≥1NlogP(y∣x)P(y)+N−1Nlog(1)=1NlogP(y∣x)P(y)\log \left(\frac{P(y|x)}{NP(y)} + \frac{N-1}{N}\right)\ge \frac{1}{N}\log\frac{P(y|x)}{P(y)}+\frac{N-1}{N}\log(1)=\frac{1}{N}\log\frac{P(y|x)}{P(y)} log(NP(y)P(y∣x)+NN−1)≥N1logP(y)P(y∣x)+NN−1log(1)=N1logP(y)P(y∣x)
因此,上式可写为
I(X;Y)+LInfoNCE−log(N)≥E[1NlogP(y∣x)P(y)]≥0I(X;Y)+\mathcal{L}_{\text{InfoNCE}}-\log(N)\ge \mathbb{E}\left[\frac{1}{N}\log\frac{P(y|x)}{P(y)}\right]\ge 0 I(X;Y)+LInfoNCE−log(N)≥E[N1logP(y)P(y∣x)]≥0
整理得:
I(X;Y)≥log(N)−LInfoNCEI(X;Y) \geq \log(N) - \mathcal{L}_{\text{InfoNCE}} I(X;Y)≥log(N)−LInfoNCE
至此,互信息的下界得证。
总结
对比学习(如SimCLR、CLIP、MoCo)的成功,根本上是因为它们通过InfoNCE损失隐式地最大化了不同视图或模态间的互信息。模型被迫学习那些在不同变换下保持不变的特征——这些特征恰好承载了数据的核心信息。
从下界公式可以看出:
- log(N)−LInfoNCE\log(N) - \mathcal{L}_{\text{InfoNCE}}log(N)−LInfoNCE 的值越大,说明模型学到的表示中 XXX 和 YYY 的互信息越高。
- 负样本数 NNN 越大,log(N)\log(N)log(N) 项越大,下界越紧。但边际效应递减(log函数特性)。
