当前位置: 首页 > news >正文

神经正切核(NTK):从梯度流到核方法的完整推导

神经正切核(NTK):从梯度流到核方法的完整推导

1. 动机:破解无限宽神经网络的训练“黑箱”

对于参数量庞大的神经网络,其训练过程本质是复杂的非凸优化问题,精确预测网络参数更新与输出演化的动态极为困难。

神经正切核(Neural Tangent Kernel, NTK) 的核心价值在于:在网络宽度趋于无穷大的理想化极限下,将非线性的神经网络训练动态严格线性化,从而用理论成熟的核方法(Kernel Methods) 解析深度学习的训练行为与泛化能力。

2. 数学设定与核心符号

设神经网络为参数化函数 f(x;θ)f(\mathbf{x}; \boldsymbol{\theta})f(x;θ),其中:

  • x∈Rdin\mathbf{x} \in \mathbb{R}^{d_{\text{in}}}xRdin:输入数据;
  • θ∈RP\boldsymbol{\theta} \in \mathbb{R}^PθRP:网络参数(总维度为 PPP);
  • f(x;θ)∈Rdoutf(\mathbf{x}; \boldsymbol{\theta}) \in \mathbb{R}^{d_{\text{out}}}f(x;θ)Rdout:网络输出。

训练相关定义

  • 训练集:D={(xi,yi)}i=1ND = \{(\mathbf{x}_i, \mathbf{y}_i)\}_{i=1}^ND={(xi,yi)}i=1NNNN 为样本数,yi\mathbf{y}_iyi 为标签);
  • 损失函数(以均方误差 MSE 为例):
    L(θ)=12N∑i=1N∥f(xi;θ)−yi∥22\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{2N} \sum_{i=1}^N \| f(\mathbf{x}_i; \boldsymbol{\theta}) - \mathbf{y}_i \|_2^2L(θ)=2N1i=1Nf(xi;θ)yi22
  • 梯度流(连续时间版梯度下降,暂忽略学习率 η\etaη):
    dθ(t)dt=−∇θL(θ(t))\frac{d\boldsymbol{\theta}(t)}{dt} = -\nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}(t))dtdθ(t)=θL(θ(t))
    ttt 为训练时间,∇θ\nabla_{\boldsymbol{\theta}}θ 表示对参数 θ\boldsymbol{\theta}θ 的梯度)。

3. 核心推导:从输出动态中“析出”NTK

我们不直接追踪参数 θ(t)\boldsymbol{\theta}(t)θ(t) 的变化,而是聚焦网络输出 f(xi;θ(t))f(\mathbf{x}_i; \boldsymbol{\theta}(t))f(xi;θ(t)) 的时间演化——这是连接神经网络与核方法的关键桥梁。

步骤 1:输出对时间的导数(链式法则)

网络输出 f(xi;θ(t))f(\mathbf{x}_i; \boldsymbol{\theta}(t))f(xi;θ(t)) 随时间的变化,可通过链式法则分解为“参数梯度”与“参数变化率”的乘积:
df(xi;θ(t))dt=∇θf(xi;θ(t))T⋅dθ(t)dt\frac{d f(\mathbf{x}_i; \boldsymbol{\theta}(t))}{dt} = \nabla_{\boldsymbol{\theta}} f(\mathbf{x}_i; \boldsymbol{\theta}(t))^{\mathsf{T}} \cdot \frac{d\boldsymbol{\theta}(t)}{dt}dtdf(xi;θ(t))=θf(xi;θ(t))Tdtdθ(t)
其中,∇θf(xi;θ(t))∈RP\nabla_{\boldsymbol{\theta}} f(\mathbf{x}_i; \boldsymbol{\theta}(t)) \in \mathbb{R}^Pθf(xi;θ(t))RP 是输出对所有参数的梯度向量(“敏感度向量”)。

步骤 2:代入梯度流方程

将梯度流 dθ(t)dt=−∇θL(θ(t))\frac{d\boldsymbol{\theta}(t)}{dt} = -\nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}(t))dtdθ(t)=θL(θ(t)) 代入上式,得到输出动态的核心方程:
df(xi;θ(t))dt=−∇θf(xi;θ(t))T⋅∇θL(θ(t))\frac{d f(\mathbf{x}_i; \boldsymbol{\theta}(t))}{dt} = - \nabla_{\boldsymbol{\theta}} f(\mathbf{x}_i; \boldsymbol{\theta}(t))^{\mathsf{T}} \cdot \nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}(t))dtdf(xi;θ(t))=θf(xi;θ(t))TθL(θ(t))

步骤 3:展开损失函数的梯度

损失 L\mathcal{L}L 由所有样本的输出共同决定,其参数梯度需再次用链式法则拆解:
∇θL(θ(t))=∑j=1N∂L∂f(xj;θ(t))⋅∇θf(xj;θ(t))\nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}(t)) = \sum_{j=1}^N \frac{\partial \mathcal{L}}{\partial f(\mathbf{x}_j; \boldsymbol{\theta}(t))} \cdot \nabla_{\boldsymbol{\theta}} f(\mathbf{x}_j; \boldsymbol{\theta}(t))θL(θ(t))=j=1Nf(xj;θ(t))Lθf(xj;θ(t))
其中,∂L∂f(xj;θ(t))\frac{\partial \mathcal{L}}{\partial f(\mathbf{x}_j; \boldsymbol{\theta}(t))}f(xj;θ(t))L 是损失对第 jjj 个样本输出的梯度(对 MSE 而言,此值为 1N(f(xj;θ(t))−yj)\frac{1}{N}(f(\mathbf{x}_j; \boldsymbol{\theta}(t)) - \mathbf{y}_j)N1(f(xj;θ(t))yj))。

步骤 4:整合并识别核结构

将步骤 3 的结果代入步骤 2 的方程,整理后可得:
df(xi;θ(t))dt=−∑j=1N(∇θf(xi;θ(t))T⋅∇θf(xj;θ(t)))⏟核结构⋅∂L∂f(xj;θ(t))\frac{d f(\mathbf{x}_i; \boldsymbol{\theta}(t))}{dt} = - \sum_{j=1}^N \underbrace{\left( \nabla_{\boldsymbol{\theta}} f(\mathbf{x}_i; \boldsymbol{\theta}(t))^{\mathsf{T}} \cdot \nabla_{\boldsymbol{\theta}} f(\mathbf{x}_j; \boldsymbol{\theta}(t)) \right)}_{\text{核结构}} \cdot \frac{\partial \mathcal{L}}{\partial f(\mathbf{x}_j; \boldsymbol{\theta}(t))}dtdf(xi;θ(t))=j=1N核结构(θf(xi;θ(t))Tθf(xj;θ(t)))f(xj;θ(t))L

上式中,下划线部分即为 NTK 的核心定义——它由两个样本的“参数敏感度向量”的点积构成,仅依赖于输入 xi,xj\mathbf{x}_i, \mathbf{x}_jxi,xj 和参数 θ(t)\boldsymbol{\theta}(t)θ(t)

4. NTK 的正式定义与无限宽度极限

4.1 定义:神经正切核

对于神经网络 f(x;θ)f(\mathbf{x}; \boldsymbol{\theta})f(x;θ),两个输入 x\mathbf{x}xx′\mathbf{x}'x 之间的 神经正切核 定义为:
Θ(x,x′;θ)=∇θf(x;θ)T⋅∇θf(x′;θ)=∑p=1P∂f(x;θ)∂θp⋅∂f(x′;θ)∂θp\Theta(\mathbf{x}, \mathbf{x}'; \boldsymbol{\theta}) = \nabla_{\boldsymbol{\theta}} f(\mathbf{x}; \boldsymbol{\theta})^{\mathsf{T}} \cdot \nabla_{\boldsymbol{\theta}} f(\mathbf{x}'; \boldsymbol{\theta}) = \sum_{p=1}^P \frac{\partial f(\mathbf{x}; \boldsymbol{\theta})}{\partial \theta_p} \cdot \frac{\partial f(\mathbf{x}'; \boldsymbol{\theta})}{\partial \theta_p}Θ(x,x;θ)=θf(x;θ)Tθf(x;θ)=p=1Pθpf(x;θ)θpf(x;θ)

若将所有训练样本的 NTK 组合为矩阵,可得到 NTK 格拉姆矩阵(Gram Matrix) Θ(t)∈RN×N\mathbf{\Theta}(t) \in \mathbb{R}^{N \times N}Θ(t)RN×N,其 (i,j)(i,j)(i,j) 元素为 Θ(xi,xj;θ(t))\Theta(\mathbf{x}_i, \mathbf{x}_j; \boldsymbol{\theta}(t))Θ(xi,xj;θ(t))

4.2 关键定理:无限宽度下的 NTK 特性

根据 Jacot 等人(2018)的开创性研究,对于采用标准初始化(如高斯初始化)的常见网络(全连接、卷积等),当所有隐藏层宽度 m→∞m \to \inftym 时:

  1. 初始化时收敛到确定核:初始 NTK Θ(0)\mathbf{\Theta}(0)Θ(0) 依概率收敛到一个与具体参数无关的确定性核矩阵 KNTK\mathbf{K}_{\text{NTK}}KNTK
  2. 训练中保持恒定:训练过程中参数变化量相对于初始值可忽略(“懒惰训练”,Lazy Training),导致 Θ(t)≈Θ(0)=KNTK\mathbf{\Theta}(t) \approx \mathbf{\Theta}(0) = \mathbf{K}_{\text{NTK}}Θ(t)Θ(0)=KNTK 对所有 t≥0t \geq 0t0 成立。

5. 线性化动态与最终解

在无限宽度极限下,NTK 成为常数矩阵,这使得原本复杂的非线性输出动态退化为常系数线性常微分方程(ODE)

矩阵形式的动态方程

f(t)=[f(x1;θ(t)),…,f(xN;θ(t))]T\mathbf{f}(t) = [f(\mathbf{x}_1; \boldsymbol{\theta}(t)), \dots, f(\mathbf{x}_N; \boldsymbol{\theta}(t))]^{\mathsf{T}}f(t)=[f(x1;θ(t)),,f(xN;θ(t))]T(输出向量),g(t)=[∂L∂f(x1),…,∂L∂f(xN)]T\mathbf{g}(t) = [\frac{\partial \mathcal{L}}{\partial f(\mathbf{x}_1)}, \dots, \frac{\partial \mathcal{L}}{\partial f(\mathbf{x}_N)}]^{\mathsf{T}}g(t)=[f(x1)L,,f(xN)L]T(损失对输出的梯度向量),则输出动态可写为:
df(t)dt=−ηKNTKg(t)\frac{d\mathbf{f}(t)}{dt} = - \eta \mathbf{K}_{\text{NTK}} \mathbf{g}(t)dtdf(t)=ηKNTKg(t)

对 MSE 损失的解析解

代入 MSE 的梯度 g(t)=1N(f(t)−y)\mathbf{g}(t) = \frac{1}{N}(\mathbf{f}(t) - \mathbf{y})g(t)=N1(f(t)y),方程变为:
df(t)dt=−ηNKNTK(f(t)−y)\frac{d\mathbf{f}(t)}{dt} = - \frac{\eta}{N} \mathbf{K}_{\text{NTK}} (\mathbf{f}(t) - \mathbf{y})dtdf(t)=NηKNTK(f(t)y)

这是标准的一阶线性 ODE,其解析解为:
f(t)=y+exp⁡(−ηtNKNTK)(f(0)−y)\mathbf{f}(t) = \mathbf{y} + \exp\left(-\frac{\eta t}{N} \mathbf{K}_{\text{NTK}}\right) (\mathbf{f}(0) - \mathbf{y})f(t)=y+exp(NηtKNTK)(f(0)y)

关键结论:此解与核回归(Kernel Regression) 的训练动态完全一致——无限宽神经网络的训练等价于一个使用 KNTK\mathbf{K}_{\text{NTK}}KNTK 作为核函数的经典核机器。

http://www.dtcms.com/a/415650.html

相关文章:

  • 想在浏览器里跑 AI?TensorFlow.js 硬件支持完全指南
  • 安徽省城乡住房建设厅网站沧县官厅网站建设
  • 网站开发北京虚拟主机做网站教程
  • WSL 安装方法(简单全面)
  • 京东100道GO面试题及参考答案(上)
  • 网站被挂黑链怎么处理深圳宝安网站建设公司推荐
  • h5网站模板下载wordpress加速访问
  • 语言大模型(LLM)与自然语言处理(NLP)
  • 如何构建网站重庆中技互联
  • QML学习笔记(十五)QML的信号处理器(MouseArea)
  • php 微信 网站建设无限观影次数的app软件
  • 苏州网站建设数据网络WordPress支付宝登录
  • opcode - Claude Code 图形化工具集
  • 淮南招聘网站建设全球域名注册平台
  • VsCode配置Claude Code-Windows
  • 网站建设台词精品课程网站设计说明范文
  • 手写MyBatis第78弹:装饰器模式在MyBatis二级缓存中的应用:从LRU到防击穿的全方案实现
  • 山西网站开发二次开发拍卖网站功能需求文档
  • 中文简洁网站设计图wordpress 导航菜单设置
  • JavaWeb-Ajax、监听器、过滤器及对应案例和jstl补充
  • 如何自己免费建网站做最优秀的自己演讲视频网站
  • 文件包含与下载漏洞
  • centos7.9下安装freeswitch-1.10.5.-release详细教程(极其简单)
  • 慢慢来做网站多少钱互联网保险经纪公司十大排名
  • 【开题答辩全过程】以 Springboot大学英语四、六级学习系统开题为例,包含答辩的问题和答案
  • php网站开发有什么软件男人女人做性关系网站
  • 网站访客qq获取原理南昌网站开发机构
  • 获取淘宝商品视频API接口解析:通过商品链接url获取商品视频item_video
  • k8s node 节点加入 matser 错误 cannot construct envvars
  • 做个自己的网站需要多少钱桂林网络平台开发公司