深度学习批量矩阵乘法实战解析:torch.bmm
计算当前层隐藏状态与事实关注向量之间的点积,并通过维度调整实现批量处理。

举例说明
理解批量矩阵乘法的逻辑。
假设条件(简化维度方便计算):
batch_size = 2
(2个样本) seq_len = 3
(每个样本有3个token) hidden_dim = 2
(隐藏层维度为2)
1. 输入张量的形状与具体值
current_hidden
(原始形状 [batch_size, seq_len, hidden_dim] = [2, 3, 2]
):
假设其值为(每个元素代表一个token的隐藏状态):
current_hidden