np.sum(e_x, axis=-1, keepdims=True)
是 NumPy 中用于求和的函数调用,在你的多头注意力代码中(尤其是 softmax 函数里)有特殊意义,我们可以拆解来看:
e_x
:输入数组(在 softmax 中是经过指数运算的数组,e_x = np.exp(...)
)。axis=-1
:指定求和的轴。-1
表示「最后一个轴」,这是一种灵活的写法,无论数组是 2 维、3 维还是更高维,都能准确选中最后一个维度。- 例如:若
e_x
是形状为 (batch_size, seq_len)
的 2 维数组,axis=-1
等价于 axis=1
(对每个样本的序列长度维度求和); - 若
e_x
是形状为 (batch_size, num_heads, seq_len)
的 3 维数组,axis=-1
等价于 axis=2
(对每个头的序列长度维度求和)。
keepdims=True
:保持求和后的维度不变(不压缩维度)。求和后,被求和的轴长度会变为 1,其他轴形状保持不变。
在 softmax 函数中,这个操作的目的是计算每个样本(或每个注意力头)的指数和,用于后续归一化:
def softmax(x):e_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) return e_x / np.sum(e_x, axis=-1, keepdims=True)
- 假设
e_x
形状为 (batch_size, seq_len)
(例如一批序列的注意力分数),np.sum(..., axis=-1, keepdims=True)
会得到形状为 (batch_size, 1)
的数组,每个元素是对应样本的 seq_len
个指数值之和。 - 由于
keepdims=True
,结果可以和原始 e_x
((batch_size, seq_len)
)通过广播机制进行除法,最终每个位置的输出都是「该位置指数值 / 所有位置指数和」,符合 softmax 归一化的定义。
假设有一个 2 维数组 e_x
(模拟 2 个样本,每个样本 3 个元素):
e_x = np.array([[1, 2, 3], [4, 5, 6]])
np.sum(e_x, axis=-1, keepdims=True)
的结果为:[[6], # 1+2+3[15]] # 4+5+6
形状为 (2, 1)
,与原始 e_x
((2, 3)
)广播后除法,即可得到每个元素的 softmax 概率。
简单说,这个操作的核心是按「最后一个维度」求和并保持形状,确保 softmax 能在正确的维度上进行归一化,同时避免因维度不匹配导致的运算错误。