神经网络之计算图分支节点
🌿 一、什么是计算图中的分支节点?
分支节点是指在计算图中,某个中间变量的值被多个后续节点所“依赖”或“使用”的情况。也就是说,一个节点的输出是多个节点的输入,就像一条河流分叉,数据“流”向多个方向。
✅ 举个简单例子:
假设我们有函数:
[
L = f(x, y) = (x + y)^2 + (x + y)^3
]
我们可以定义:
[
z = x + y
L = z^2 + z^3
]
这里的 z
被两个操作使用了:( z^2 ) 和 ( z^3 )。这个 z
就是一个分支节点,因为它的值被用于多个地方。
🔁 二、为什么分支节点在反向传播中重要?
在反向传播中,我们要计算每个变量对最终损失 (L) 的导数(梯度)。当一个变量像 z
这样被多个路径使用时,我们需要:
将所有路径上传回来的梯度累加起来!
也就是说:
[
\frac{dL}{dz} = \frac{dL_1}{dz} + \frac{dL_2}{dz}
]
这是反向传播中的一个关键规则,叫做:
“梯度汇聚(Gradient Accumulation)” 或 “梯度相加”
✍️ 三、上面的例子完整反向传播推导
我们再看刚才的例子:
[
L = z^2 + z^3
z = x + y
]
第一步:前向传播
- ( z = x + y )
- ( L = z^2 + z^3 )
第二步:反向传播(求梯度)
- ( \frac{dL}{dz} = \frac{d}{dz}(z^2) + \frac{d}{dz}(z^3) = 2z + 3z^2 )
注意,这里就是 分支节点的关键体现:两个路径的导数相加!
-
( \frac{dL}{dx} = \frac{dL}{dz} \cdot \frac{dz}{dx} = (2z + 3z^2) \cdot 1 )
-
同理,( \frac{dL}{dy} = (2z + 3z^2) \cdot 1 )
🔄 四、图示(计算图结构)
x y\ /\ /z = x + y/ \z^2 z^3| |+----+-----+|L = z^2 + z^3
- z 是分支节点
- 反向传播时,来自 ( z^2 ) 和 ( z^3 ) 两个分支的梯度都会合并回 z
🧠 五、编程实现中的处理(以 PyTorch 为例)
在自动微分框架中,如 PyTorch、TensorFlow:
- 每个变量在构建计算图时,记录自己被哪些操作使用;
- 反向传播时,每条路径上传回来的梯度都会自动累加(accumulate);
- 框架自动实现“分支节点”的梯度合并,不需要用户手动相加。
🧩 六、可能的误解和注意点
误解 | 正确做法 |
---|---|
每条路径上的梯度是互相独立的 | 错!必须累加所有路径上传回来的梯度 |
分支节点和参数节点一样对待 | 错!分支节点的梯度可能来自多个地方 |
不需要考虑重复使用的中间变量 | 错!一个变量只要被多次使用,就必须合并梯度 |
🧮 七、总结:分支节点的本质
特性 | 描述 |
---|---|
多路径使用 | 一个变量被多个操作所依赖 |
梯度累加 | 所有路径上传回的梯度需要加在一起 |
计算图核心结构 | 分支是计算图中非线性结构的核心 |
自动处理 | 在现代框架中已自动实现梯度累加 |