PyTorch_点积运算
点积运算要求第一个矩阵 shape:(n, m),第二个矩阵 shape: (m, p), 两个矩阵点积运算shape为:(n,p)
- 运算符 @ 用于进行两个矩阵的点乘运算
- torch.mm 用于进行两个矩阵点乘运算,要求输入的矩阵为3维 (mm 代表 mat, mul)
- torch.bmm 用于批量进行矩阵点乘运算,要求输入的矩阵为3维 (b 代表 batch)
- torch.matmul 对进行点乘运算的两矩阵形状没有限定。
a. 对于输入都是二维的张量相当于 mm 运算
b. 对于输入都是三维的张量相当于 bmm 运算
c. 对数输入的shape不同的张量,对应的最后几个维度必须符合矩阵运算规则
代码
import torch
import numpy as np # 使用@运算符
def test01():# 形状为:3行2列 data1 = torch.tensor([[1,2], [3,4], [5,6]])# 形状为:2行2列data2 = torch.tensor([[5,6], [7,8]])data = data1 @ data2print(data) # 使用 mm 函数
def test02():# 要求输入的张量形状都是二维的# 形状为:3行2列 data1 = torch.tensor([[1,2], [3,4], [5,6]])# 形状为:2行2列data2 = torch.tensor([[5,6], [7,8]])data = torch.mm(data1, data2) print(data)print(data.shape)# 使用 bmm 函数
def test03():# 第一个维度:表示批次# 第二个维度:多少行# 第三个维度:多少列data1 = torch.randn(3, 4, 5)data2 = torch.randn(3, 5, 8)data = torch.bmm(data1, data2) print(data.shape)# 使用 matmul 函数
def test04():# 对二维进行计算data1 = torch.randn(4,5)data2 = torch.randn(5,8)print(torch.matmul(data1, data2).shape)# 对三维进行计算data1 = torch.randn(3, 4, 5)data2 = torch.randn(3, 5, 8)print(torch.matmul(data1, data2).shape)data1 = torch.randn(3, 4, 5)data2 = torch.randn(5, 8)print(torch.matmul(data1, data2).shape) if __name__ == "__main__":test04()