PyTorch中torch.eq()、torch.argmax()函数的详解和代码示例
下面对 PyTorch 中常用的两个函数 torch.eq()
和 torch.argmax()
的详解,包括 功能、参数说明、返回值、注意事项 以及 代码示例。
一、torch.eq(input, other)
功能:
按元素逐位判断 input == other
,返回 bool
类型的张量(值为 True
或 False
)。
参数:
参数 | 类型 | 含义 |
---|---|---|
input | Tensor | 被比较的第一个张量 |
other | Tensor 或数值 | 被比较的第二个张量或标量 |
要求
input
和other
的维度可广播(broadcastable)。
返回值:
一个 布尔型张量,形状与广播后的 input
相同。
示例 1:两个张量比较
import torcha = torch.tensor([1, 2, 3])
b = torch.tensor([1, 0, 3])
result = torch.eq(a, b)
print(result)
# 输出: tensor([True, False, True])
示例 2:张量与标量比较
a = torch.tensor([[3, 5], [5, 5]])
result = torch.eq(a, 5)
print(result)
# 输出: tensor([[False, True], [True, True]])
二、torch.argmax(input, dim=None, keepdim=False)
功能:
返回 input
张量在指定维度上最大值的索引(索引位置)。
参数:
参数 | 类型 | 说明 |
---|---|---|
input | Tensor | 输入张量 |
dim | int 或 None | 指定在哪个维度上返回最大值索引。默认 None ,表示返回 扁平化后的最大值索引。 |
keepdim | bool | 是否保持维度。默认为 False 。若为 True ,输出将保留原始维度大小为1。 |
返回值:
张量索引(int64),表示每个位置上最大值的索引。
示例 1:不指定 dim
(返回扁平化最大索引)
a = torch.tensor([[1, 9], [3, 7]])
idx = torch.argmax(a)
print(idx)
# 输出: tensor(1)(值 9 在展平后第1个位置)
示例 2:指定 dim=1
(按行取最大值索引)
a = torch.tensor([[1, 9], [3, 7]])
idx = torch.argmax(a, dim=1)
print(idx)
# 输出: tensor([1, 1]) (每行最大值索引)
示例 3:保留维度
a = torch.tensor([[1, 9], [3, 7]])
idx = torch.argmax(a, dim=1, keepdim=True)
print(idx)
# 输出: tensor([[1], [1]])
注意事项对比
函数名 | 返回类型 | 说明 |
---|---|---|
torch.eq() | BoolTensor | 返回逐元素相等的布尔张量 |
torch.argmax() | IntTensor | 返回指定维度或展平后最大值的位置索引 |
实战小技巧:
判断两个结果张量是否完全相同:
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
if torch.eq(a, b).all():print("完全相同")
多分类预测取最大概率类:
logits = torch.tensor([[0.1, 0.3, 0.6],[0.2, 0.5, 0.3]])
pred = torch.argmax(logits, dim=1)
print(pred) # 输出: tensor([2, 1])