PyTorch L2范数详解与应用
torch.norm 是什么
torch.norm(dot_product, p=2, dim=-1)
是 PyTorch 中用于计算张量 L2 范数的函数,
1. 各参数解析
dot_product
:输入张量,在代码中形状为[batch_size, seq_len]
(每个元素是 token 隐藏状态与关注向量的点积)。p=2
:指定计算L2 范数(欧几里得范数),公式为:对于向量[x₁, x₂, ..., xₙ]
,L2 范数 =√(x₁² + x₂² + ... + xₙ²)
。dim=-1
:指定计算范数的维度。-1
表示“最后一个维度”,在[batch_size, seq_len]
中即seq_len
维度(序列长度维度)。
2. 计算逻辑(结合代码上下文)
假设 dot_product
的形状为 [2, 3]
(batch_size=2
,seq_len=3