# 简单行、列索引操作 (simple row/column indexing)
import torch
import numpy as np
def test01():
    """Demonstrate basic row/column indexing and slicing on a 2-D tensor.

    Seeds torch's global RNG first so the printed tensors are the same on
    every run. Prints each indexing result and returns None.
    """
    # Without a fixed seed every run prints different numbers, so the
    # indexing results cannot be compared against an expected output.
    torch.manual_seed(0)
    # 4x5 tensor of random integers drawn from [0, 10) (high is exclusive).
    data = torch.randint(0, 10, [4, 5])
    print(data)
    print(data[0])       # first row, shape (5,)
    print(data[:, 0])    # first column, shape (4,)
    print(data[1, 2])    # single element at row 1, column 2 (0-d tensor)
    print(data[:3, 2])   # column 2 of the first three rows, shape (3,)
    print(data[:3, :2])  # first three rows x first two columns, shape (3, 2)
def test02():
    """Demonstrate integer-list (advanced) indexing on a 2-D tensor.

    Seeds torch's global RNG first so the printed tensors are the same on
    every run. Prints each indexing result and returns None.
    """
    # Fixed seed so the selected elements can be checked against the
    # printed source tensor from run to run.
    torch.manual_seed(0)
    data = torch.randint(0, 10, [4, 5])
    print(data)
    # Two equal-length index lists are paired element-wise: this picks
    # (0, 0), (2, 1) and (3, 2), giving a 1-D tensor of three values.
    print(data[[0, 2, 3], [0, 1, 2]])
    # A (3, 1) row index broadcasts against the (3,) column index, giving a
    # 3x3 result: rows 0, 2, 3 crossed with columns 0, 1, 2.
    print(data[[[0], [2], [3]], [0, 1, 2]])


if __name__ == "__main__":
    test02()
# 布尔索引 (boolean indexing)
import torch
import numpy as np
def test01():
    """Show boolean-mask indexing on a seeded 4x5 integer tensor.

    Prints the tensor, then an element-wise comparison mask and the elements
    it selects. Then prints row and column selection, each driven by a 1-D
    mask. Returns None.
    """
    torch.manual_seed(0)
    values = torch.randint(0, 10, [4, 5])
    print(values)

    elem_mask = values > 3
    print(elem_mask)             # boolean tensor with the same shape as values
    print(values[elem_mask])     # 1-D tensor of the elements where mask is True

    row_mask = values[:, 1] > 6  # one flag per row, taken from column 1
    print(values[row_mask])      # whole rows whose flag is True

    col_mask = values[1] > 3     # one flag per column, taken from row 1
    print(values[:, col_mask])   # whole columns whose flag is True
def test02():
    """Slice a seeded 3x4x5 integer tensor at index 0 along each axis.

    Prints the full tensor, then the index-0 slice of dim 0, dim 1 and
    dim 2 in that order (shapes (4, 5), (3, 5) and (3, 4)). Returns None.
    """
    torch.manual_seed(0)
    cube = torch.randint(0, 10, [3, 4, 5])
    print(cube)
    # select(dim, 0) is equivalent to putting integer 0 at position `dim`
    # and full slices everywhere else, e.g. select(1, 0) == cube[:, 0, :].
    for dim in range(cube.dim()):
        print(cube.select(dim, 0))


if __name__ == "__main__":
    test02()