神经网络之链式法则
一、什么是链式法则?
链式法则是微积分中用于处理复合函数求导的基本规则。
🧠 基本思想:
当一个变量依赖另一个变量,而另一个变量又依赖另一个变量时,总的变化率是每一步变化率的乘积。
🎯 经典形式(单变量):
设:
z=f(y),y=g(x)⇒z=f(g(x))z = f(y), \quad y = g(x) \Rightarrow z = f(g(x)) z=f(y),y=g(x)⇒z=f(g(x))
那么链式法则是:
dzdx=dzdy⋅dydx\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx} dxdz=dydz⋅dxdy
即:
“总导数 = 每层局部导数的乘积”
二、链式法则中:何时是乘法?何时是加法?
链式法则的核心操作是乘法,但在计算图(如神经网络)中,有些情况却需要加法。这两种形式其实不是矛盾,而是出现在不同的结构中。
我们来分别说明:
✅ 1. 使用乘法的场景:函数嵌套
📌 结构:
x→y=g(x)→z=f(y)⇒z=f(g(x))x \rightarrow y = g(x) \rightarrow z = f(y) \Rightarrow z = f(g(x)) x→y=g(x)→z=f(y)⇒z=f(g(x))
🧮 梯度计算:
dzdx=dzdy⋅dydx\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx} dxdz=dydz⋅dxdy
✅ 特点:
- 每个函数的输出作为下一个函数的输入
- 形成一条链
- 每一步都“放大”或“缩小”输入的变化
- 所以梯度乘起来
📎 举例:
z=sin(x2)⇒dzdx=cos(x2)⋅2xz = \sin(x^2) \Rightarrow \frac{dz}{dx} = \cos(x^2) \cdot 2x z=sin(x2)⇒dxdz=cos(x2)⋅2x
✅ 2. 使用加法的场景:分支结构(多路径传播)
📌 结构:
x→f(x)=a,x→g(x)=b,L=a+b⇒L=f(x)+g(x)x \rightarrow f(x) = a,\quad x \rightarrow g(x) = b,\quad L = a + b \Rightarrow L = f(x) + g(x) x→f(x)=a,x→g(x)=b,L=a+b⇒L=f(x)+g(x)
🧮 梯度计算:
dLdx=dadx+dbdx\frac{dL}{dx} = \frac{da}{dx} + \frac{db}{dx} dxdL=dxda+dxdb
✅ 特点:
- 一个变量 xxx 同时用于多个地方(有多个“输出”分支)
- 每条路径都有独立的影响
- 梯度在反向传播时合流,进行加法
📎 举例:
L=x2+sin(x)⇒dLdx=2x+cos(x)L = x^2 + \sin(x) \Rightarrow \frac{dL}{dx} = 2x + \cos(x) L=x2+sin(x)⇒dxdL=2x+cos(x)
三、为什么是乘法?为什么是加法?
🔷 为什么是乘法(链式结构)?
源于导数的定义:
dfdx=limΔx→0ΔfΔx\frac{df}{dx} = \lim_{\Delta x \to 0} \frac{\Delta f}{\Delta x} dxdf=Δx→0limΔxΔf
如果:
x→y→zx \rightarrow y \rightarrow z x→y→z
则:
Δz=dzdy⋅Δy=dzdy⋅dydx⋅Δx\Delta z = \frac{dz}{dy} \cdot \Delta y = \frac{dz}{dy} \cdot \frac{dy}{dx} \cdot \Delta x Δz=dydz⋅Δy=dydz⋅dxdy⋅Δx
所以:
dzdx=dzdy⋅dydx\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx} dxdz=dydz⋅dxdy
变化被层层放大/缩小 → 所以是乘法
🔶 为什么是加法(分支结构)?
当一个变量影响多个路径时,每条路径都会对最终输出产生一部分影响。
例如:
L=f(x)+g(x)L = f(x) + g(x) L=f(x)+g(x)
那么:
ΔL=Δf+Δg=f′(x)Δx+g′(x)Δx=(f′(x)+g′(x))Δx\Delta L = \Delta f + \Delta g = f'(x)\Delta x + g'(x)\Delta x = (f'(x) + g'(x)) \Delta x ΔL=Δf+Δg=f′(x)Δx+g′(x)Δx=(f′(x)+g′(x))Δx
所以:
dLdx=f′(x)+g′(x)\frac{dL}{dx} = f'(x) + g'(x) dxdL=f′(x)+g′(x)
多个路径独立贡献 → 所以是加法
✅ 总结表格:链式法则中的乘法 vs 加法
项目 | 使用场景 | 结构形式 | 导数公式 | 原因 |
---|---|---|---|---|
🔗 乘法 | 函数嵌套 | z=f(g(x))z = f(g(x))z=f(g(x)) | dzdx=f′(g(x))⋅g′(x)\frac{dz}{dx} = f'(g(x)) \cdot g'(x)dxdz=f′(g(x))⋅g′(x) | 变化层层放大 |
➕ 加法 | 分支结构 | L=f(x)+g(x)L = f(x) + g(x)L=f(x)+g(x) | dLdx=f′(x)+g′(x)\frac{dL}{dx} = f'(x) + g'(x)dxdL=f′(x)+g′(x) | 多路径贡献叠加 |
✅ 一句话总结
链式法则的本质是“乘法”传播,但当一个变量影响多个路径时,每条路径的梯度会在反向传播时加和** —— 所以在分支结构中使用“加法”。**