神经正切核(NTK):从梯度流到核方法的完整推导
神经正切核(NTK):从梯度流到核方法的完整推导
1. 动机:破解无限宽神经网络的训练“黑箱”
对于参数量庞大的神经网络,其训练过程本质是复杂的非凸优化问题,精确预测网络参数更新与输出演化的动态极为困难。
神经正切核(Neural Tangent Kernel, NTK) 的核心价值在于:在网络宽度趋于无穷大的理想化极限下,将非线性的神经网络训练动态严格线性化,从而用理论成熟的核方法(Kernel Methods) 解析深度学习的训练行为与泛化能力。
2. 数学设定与核心符号
设神经网络为参数化函数 f(x;θ)f(\mathbf{x}; \boldsymbol{\theta})f(x;θ),其中:
- x∈Rdin\mathbf{x} \in \mathbb{R}^{d_{\text{in}}}x∈Rdin:输入数据;
- θ∈RP\boldsymbol{\theta} \in \mathbb{R}^Pθ∈RP:网络参数(总维度为 PPP);
- f(x;θ)∈Rdoutf(\mathbf{x}; \boldsymbol{\theta}) \in \mathbb{R}^{d_{\text{out}}}f(x;θ)∈Rdout:网络输出。
训练相关定义
- 训练集:D={(xi,yi)}i=1ND = \{(\mathbf{x}_i, \mathbf{y}_i)\}_{i=1}^ND={(xi,yi)}i=1N(NNN 为样本数,yi\mathbf{y}_iyi 为标签);
- 损失函数(以均方误差 MSE 为例):
L(θ)=12N∑i=1N∥f(xi;θ)−yi∥22\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{2N} \sum_{i=1}^N \| f(\mathbf{x}_i; \boldsymbol{\theta}) - \mathbf{y}_i \|_2^2L(θ)=2N1i=1∑N∥f(xi;θ)−yi∥22; - 梯度流(连续时间版梯度下降,暂忽略学习率 η\etaη):
dθ(t)dt=−∇θL(θ(t))\frac{d\boldsymbol{\theta}(t)}{dt} = -\nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}(t))dtdθ(t)=−∇θL(θ(t))
(ttt 为训练时间,∇θ\nabla_{\boldsymbol{\theta}}∇θ 表示对参数 θ\boldsymbol{\theta}θ 的梯度)。
3. 核心推导:从输出动态中“析出”NTK
我们不直接追踪参数 θ(t)\boldsymbol{\theta}(t)θ(t) 的变化,而是聚焦网络输出 f(xi;θ(t))f(\mathbf{x}_i; \boldsymbol{\theta}(t))f(xi;θ(t)) 的时间演化——这是连接神经网络与核方法的关键桥梁。
步骤 1:输出对时间的导数(链式法则)
网络输出 f(xi;θ(t))f(\mathbf{x}_i; \boldsymbol{\theta}(t))f(xi;θ(t)) 随时间的变化,可通过链式法则分解为“参数梯度”与“参数变化率”的乘积:
df(xi;θ(t))dt=∇θf(xi;θ(t))T⋅dθ(t)dt\frac{d f(\mathbf{x}_i; \boldsymbol{\theta}(t))}{dt} = \nabla_{\boldsymbol{\theta}} f(\mathbf{x}_i; \boldsymbol{\theta}(t))^{\mathsf{T}} \cdot \frac{d\boldsymbol{\theta}(t)}{dt}dtdf(xi;θ(t))=∇θf(xi;θ(t))T⋅dtdθ(t)
其中,∇θf(xi;θ(t))∈RP\nabla_{\boldsymbol{\theta}} f(\mathbf{x}_i; \boldsymbol{\theta}(t)) \in \mathbb{R}^P∇θf(xi;θ(t))∈RP 是输出对所有参数的梯度向量(“敏感度向量”)。
步骤 2:代入梯度流方程
将梯度流 dθ(t)dt=−∇θL(θ(t))\frac{d\boldsymbol{\theta}(t)}{dt} = -\nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}(t))dtdθ(t)=−∇θL(θ(t)) 代入上式,得到输出动态的核心方程:
df(xi;θ(t))dt=−∇θf(xi;θ(t))T⋅∇θL(θ(t))\frac{d f(\mathbf{x}_i; \boldsymbol{\theta}(t))}{dt} = - \nabla_{\boldsymbol{\theta}} f(\mathbf{x}_i; \boldsymbol{\theta}(t))^{\mathsf{T}} \cdot \nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}(t))dtdf(xi;θ(t))=−∇θf(xi;θ(t))T⋅∇θL(θ(t))
步骤 3:展开损失函数的梯度
损失 L\mathcal{L}L 由所有样本的输出共同决定,其参数梯度需再次用链式法则拆解:
∇θL(θ(t))=∑j=1N∂L∂f(xj;θ(t))⋅∇θf(xj;θ(t))\nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}(t)) = \sum_{j=1}^N \frac{\partial \mathcal{L}}{\partial f(\mathbf{x}_j; \boldsymbol{\theta}(t))} \cdot \nabla_{\boldsymbol{\theta}} f(\mathbf{x}_j; \boldsymbol{\theta}(t))∇θL(θ(t))=j=1∑N∂f(xj;θ(t))∂L⋅∇θf(xj;θ(t))
其中,∂L∂f(xj;θ(t))\frac{\partial \mathcal{L}}{\partial f(\mathbf{x}_j; \boldsymbol{\theta}(t))}∂f(xj;θ(t))∂L 是损失对第 jjj 个样本输出的梯度(对 MSE 而言,此值为 1N(f(xj;θ(t))−yj)\frac{1}{N}(f(\mathbf{x}_j; \boldsymbol{\theta}(t)) - \mathbf{y}_j)N1(f(xj;θ(t))−yj))。
步骤 4:整合并识别核结构
将步骤 3 的结果代入步骤 2 的方程,整理后可得:
df(xi;θ(t))dt=−∑j=1N(∇θf(xi;θ(t))T⋅∇θf(xj;θ(t)))⏟核结构⋅∂L∂f(xj;θ(t))\frac{d f(\mathbf{x}_i; \boldsymbol{\theta}(t))}{dt} = - \sum_{j=1}^N \underbrace{\left( \nabla_{\boldsymbol{\theta}} f(\mathbf{x}_i; \boldsymbol{\theta}(t))^{\mathsf{T}} \cdot \nabla_{\boldsymbol{\theta}} f(\mathbf{x}_j; \boldsymbol{\theta}(t)) \right)}_{\text{核结构}} \cdot \frac{\partial \mathcal{L}}{\partial f(\mathbf{x}_j; \boldsymbol{\theta}(t))}dtdf(xi;θ(t))=−j=1∑N核结构(∇θf(xi;θ(t))T⋅∇θf(xj;θ(t)))⋅∂f(xj;θ(t))∂L
上式中,下划线部分即为 NTK 的核心定义——它由两个样本的“参数敏感度向量”的点积构成,仅依赖于输入 xi,xj\mathbf{x}_i, \mathbf{x}_jxi,xj 和参数 θ(t)\boldsymbol{\theta}(t)θ(t)。
4. NTK 的正式定义与无限宽度极限
4.1 定义:神经正切核
对于神经网络 f(x;θ)f(\mathbf{x}; \boldsymbol{\theta})f(x;θ),两个输入 x\mathbf{x}x 与 x′\mathbf{x}'x′ 之间的 神经正切核 定义为:
Θ(x,x′;θ)=∇θf(x;θ)T⋅∇θf(x′;θ)=∑p=1P∂f(x;θ)∂θp⋅∂f(x′;θ)∂θp\Theta(\mathbf{x}, \mathbf{x}'; \boldsymbol{\theta}) = \nabla_{\boldsymbol{\theta}} f(\mathbf{x}; \boldsymbol{\theta})^{\mathsf{T}} \cdot \nabla_{\boldsymbol{\theta}} f(\mathbf{x}'; \boldsymbol{\theta}) = \sum_{p=1}^P \frac{\partial f(\mathbf{x}; \boldsymbol{\theta})}{\partial \theta_p} \cdot \frac{\partial f(\mathbf{x}'; \boldsymbol{\theta})}{\partial \theta_p}Θ(x,x′;θ)=∇θf(x;θ)T⋅∇θf(x′;θ)=p=1∑P∂θp∂f(x;θ)⋅∂θp∂f(x′;θ)
若将所有训练样本的 NTK 组合为矩阵,可得到 NTK 格拉姆矩阵(Gram Matrix) Θ(t)∈RN×N\mathbf{\Theta}(t) \in \mathbb{R}^{N \times N}Θ(t)∈RN×N,其 (i,j)(i,j)(i,j) 元素为 Θ(xi,xj;θ(t))\Theta(\mathbf{x}_i, \mathbf{x}_j; \boldsymbol{\theta}(t))Θ(xi,xj;θ(t))。
4.2 关键定理:无限宽度下的 NTK 特性
根据 Jacot 等人(2018)的开创性研究,对于采用标准初始化(如高斯初始化)的常见网络(全连接、卷积等),当所有隐藏层宽度 m→∞m \to \inftym→∞ 时:
- 初始化时收敛到确定核:初始 NTK Θ(0)\mathbf{\Theta}(0)Θ(0) 依概率收敛到一个与具体参数无关的确定性核矩阵 KNTK\mathbf{K}_{\text{NTK}}KNTK;
- 训练中保持恒定:训练过程中参数变化量相对于初始值可忽略(“懒惰训练”,Lazy Training),导致 Θ(t)≈Θ(0)=KNTK\mathbf{\Theta}(t) \approx \mathbf{\Theta}(0) = \mathbf{K}_{\text{NTK}}Θ(t)≈Θ(0)=KNTK 对所有 t≥0t \geq 0t≥0 成立。
5. 线性化动态与最终解
在无限宽度极限下,NTK 成为常数矩阵,这使得原本复杂的非线性输出动态退化为常系数线性常微分方程(ODE)。
矩阵形式的动态方程
令 f(t)=[f(x1;θ(t)),…,f(xN;θ(t))]T\mathbf{f}(t) = [f(\mathbf{x}_1; \boldsymbol{\theta}(t)), \dots, f(\mathbf{x}_N; \boldsymbol{\theta}(t))]^{\mathsf{T}}f(t)=[f(x1;θ(t)),…,f(xN;θ(t))]T(输出向量),g(t)=[∂L∂f(x1),…,∂L∂f(xN)]T\mathbf{g}(t) = [\frac{\partial \mathcal{L}}{\partial f(\mathbf{x}_1)}, \dots, \frac{\partial \mathcal{L}}{\partial f(\mathbf{x}_N)}]^{\mathsf{T}}g(t)=[∂f(x1)∂L,…,∂f(xN)∂L]T(损失对输出的梯度向量),则输出动态可写为:
df(t)dt=−ηKNTKg(t)\frac{d\mathbf{f}(t)}{dt} = - \eta \mathbf{K}_{\text{NTK}} \mathbf{g}(t)dtdf(t)=−ηKNTKg(t)
对 MSE 损失的解析解
代入 MSE 的梯度 g(t)=1N(f(t)−y)\mathbf{g}(t) = \frac{1}{N}(\mathbf{f}(t) - \mathbf{y})g(t)=N1(f(t)−y),方程变为:
df(t)dt=−ηNKNTK(f(t)−y)\frac{d\mathbf{f}(t)}{dt} = - \frac{\eta}{N} \mathbf{K}_{\text{NTK}} (\mathbf{f}(t) - \mathbf{y})dtdf(t)=−NηKNTK(f(t)−y)
这是标准的一阶线性 ODE,其解析解为:
f(t)=y+exp(−ηtNKNTK)(f(0)−y)\mathbf{f}(t) = \mathbf{y} + \exp\left(-\frac{\eta t}{N} \mathbf{K}_{\text{NTK}}\right) (\mathbf{f}(0) - \mathbf{y})f(t)=y+exp(−NηtKNTK)(f(0)−y)
关键结论:此解与核回归(Kernel Regression) 的训练动态完全一致——无限宽神经网络的训练等价于一个使用 KNTK\mathbf{K}_{\text{NTK}}KNTK 作为核函数的经典核机器。