PyTorch入门学习: 加载数据
1. 加载数据初认识
from torch.utils.data import Dataset
from PIL import Image
import os


class ReadData(Dataset):
    """Map-style dataset that serves the images stored in one label folder.

    The name of the folder is used as the label of every image inside it.
    """

    def __init__(self, root_dir, label_dir):
        """Record the folder layout and list the image file names.

        Args:
            root_dir: Directory that contains one sub-folder per label.
            label_dir: Name of the sub-folder to read; also returned as the label.
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.image_path_dir = os.path.join(root_dir, label_dir)  # join into the folder path
        self.image_list = os.listdir(self.image_path_dir)  # file names kept in a list

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``."""
        image_name = self.image_list[idx]
        image_item_path = os.path.join(self.image_path_dir, image_name)
        image = Image.open(image_item_path)
        label = self.label_dir
        return image, label

    def __len__(self):
        """Return the number of files listed in the label folder.

        Needed so that ``len(dataset)`` and DataLoader samplers work.
        """
        return len(self.image_list)


root_dir = "D:\\pytorch_projects2\\dataset\\train"
label_dir = "ants"
test_model = ReadData(root_dir, label_dir)

img, label = test_model[1]
img.show()
print(label)
运行结果:


2. TensorBoard的使用
(1)简单测试:
from torch.utils.tensorboard import SummaryWriter

# Event files are written to ../logs, relative to the working directory.
writer = SummaryWriter("../logs")

for i in range(100):
    # 1st argument: the chart title (tag); 2nd: the y-axis value;
    # 3rd: the x-axis value (step).
    writer.add_scalar("y=2x", 2 * i, i)

writer.close()
终端输入(需在 logs 文件夹所在的目录下执行;代码中日志写入的是 ../logs,若在脚本所在目录执行,则应改用 --logdir=../logs):tensorboard --logdir=logs
运行结果:

(2)写入图片:
import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("../logs_image")

img_path = "D:\\pytorch_projects2\\dataset\\train\\ants\\0013035.jpg"
img = Image.open(img_path)
img_array = np.array(img)  # convert the PIL image into a NumPy array
print(img_array.shape)

# dataformats='HWC' declares the array layout as (height, width, channels).
# It is no longer needed once the image is converted with ToTensor().
writer.add_image("test", img_array, 1, dataformats='HWC')
writer.close()
运行结果:

3. Transforms的使用:
(1)
from PIL import Image
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

img_path = "../dataset/train/bees/16838648_415acd9e3f.jpg"
img = Image.open(img_path)

writer = SummaryWriter("../logs_tranforms_img")

image_tensor = transforms.ToTensor()  # transform that converts its input to a tensor
img_output = image_tensor(img)

writer.add_image("test", img_output, 1)
writer.close()

# As shown above, transforms.ToTensor() converts the image to the tensor type.
运行结果:

(2)常见的transforms的使用:
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

img_path = "D:\\pytorch_projects2\\dataset\\train\\ants\\0013035.jpg"
img = Image.open(img_path)

# ToTensor()
tensor = transforms.ToTensor()
img_tensor = tensor(img)

writer = SummaryWriter("../useful_transforms_logs")
writer.add_image("original", img_tensor, 0)

# Normalize
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img_normalize = normalize(img_tensor)
writer.add_image("normalize", img_normalize, 0)

# Resize
# print(img.size)
resize = transforms.Resize((512, 512))
img_resize = resize(img)  # the result is still a PIL image
# print(img_resize.size)
img_resize_toTensor = tensor(img_resize)
writer.add_image("resize", img_resize_toTensor, 0)

# Compose: apply resize2 and then ToTensor in a single call
resize2 = transforms.Resize(512)
toCompose = transforms.Compose([resize2, tensor])
img_Compose = toCompose(img)
# Logged under the same "resize" tag as above, as step 1.
writer.add_image("resize", img_Compose, 1)

# RandomCrop
randomCrop = transforms.RandomCrop(512)
for i in range(10):
    img_randomCrop = randomCrop(img)
    writer.add_image("randomCrop", tensor(img_randomCrop), i)

writer.close()
4. torchvision中数据集的使用
import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# download=True fetches CIFAR10 into `root` when it is not there yet; without it
# the constructor raises if the dataset has not been downloaded beforehand.
train_set = torchvision.datasets.CIFAR10(
    root="../dataset_transforms", train=True, transform=dataset_transform, download=True
)
test_set = torchvision.datasets.CIFAR10(
    root="../dataset_transforms", train=False, transform=dataset_transform, download=True
)

# Quick inspection of a single sample (uncomment to try):
# print(train_set[0])
# img, target = train_set[0]
# print(img)
# print(target)
# print(train_set.classes[target])
# img.show()

writer = SummaryWriter("../dataset_transforms_logs")

# Log every training image, using its position in the dataset as the step.
for idx, (img, target) in enumerate(train_set):
    writer.add_image("CIFAR10", img, idx)

writer.close()
运行结果:

5. dataloader的使用:
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# download=True fetches CIFAR10 into `root` when it is not there yet; without it
# the constructor raises if the dataset has not been downloaded beforehand.
test_dataset = torchvision.datasets.CIFAR10(
    root="../dataset_transforms", train=False, transform=dataset_transform, download=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=64, shuffle=False, num_workers=0, drop_last=False
)

writer = SummaryWriter("../dataloader_logs")

# One step per batch: each call logs the whole batch of images as a grid.
for step, data in enumerate(test_loader):
    imgs, targets = data
    writer.add_images("CIFAR10 Images", imgs, step)

writer.close()
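参数解释:batch_size=64 表示每次从数据集中取 64 个样本组成一个批次;shuffle=False 表示不打乱样本顺序;num_workers=0 表示只在主进程中加载数据;drop_last=False 表示最后不足一个批次的样本也会保留。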

运行结果:

