当前位置: 首页 > news >正文

PyTorch入门学习: 加载数据

1. 加载数据初认识

from torch.utils.data import Dataset
from PIL import Image
import osclass ReadData(Dataset):def __init__(self, root_dir, label_dir):self.root_dir = root_dirself.label_dir = label_dirself.image_path_dir = os.path.join(root_dir, label_dir)  # 拼接路径self.image_list = os.listdir(self.image_path_dir)  # 将数据存在列表中def __getitem__(self, idx):image_name = self.image_list[idx]image_item_path = os.path.join(self.image_path_dir, image_name)image = Image.open(image_item_path)label = self.label_dirreturn image, labelroot_dir = "D:\\pytorch_projects2\\dataset\\train"   
label_dir = "ants"
test_model = ReadData(root_dir, label_dir)img, label = test_model[1]img.show()
print(label)

运行结果:

2. TensorBoard的使用

(1)简单测试:

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("../logs")
for i in range(100):writer.add_scalar("y=2x",  2*i,i)  #  第一个参数:相当于标题; 第二个相当于y轴;第三个相当于x轴writer.close()

终端输入:tensorboard --logdir=logs

运行结果:

(2)写入图片:

import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("../logs_image")
img_path = "D:\\pytorch_projects2\\dataset\\train\\ants\\0013035.jpg"
img = Image.open(img_path)
img_array = np.array(img)   # 将PIL类型转换为np数组的形式
print(img_array.shape)writer.add_image("test", img_array, 1, dataformats='HWC')  # dataformats='HWC'表示图片的格式是高宽通道, 后面使用ToTensor()类型就不需要加这个了
writer.close()

运行结果:

3. Transforms的使用:

(1)

from PIL import Image
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriterimg_path = "../dataset/train/bees/16838648_415acd9e3f.jpg"
img = Image.open(img_path)writer = SummaryWriter("../logs_tranforms_img")image_tensor = transforms.ToTensor()  # 将类型转为tensor类型
img_output = image_tensor(img)writer.add_image("test", img_output, 1)
writer.close()

上述通过transforms.ToTensor()可以将类型转为tensor类型

运行结果:

(2)常见的transforms的使用:

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformsimg_path = "D:\\pytorch_projects2\\dataset\\train\\ants\\0013035.jpg"
img = Image.open(img_path)# ToTensor()
tensor = transforms.ToTensor()
img_tensor = tensor(img)writer = SummaryWriter("../useful_transforms_logs")
writer.add_image("original", img_tensor, 0)# normalize
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img_normalize = normalize(img_tensor)
writer.add_image("normalize", img_normalize, 0)# Resize
# print(img.size)
resize = transforms.Resize((512, 512))
img_resize = resize(img)   # 得到的仍是 PIL 类型的
# print(img_resize.size)
img_resize_toTensor = tensor(img_resize)
writer.add_image("resize", img_resize_toTensor, 0)# Compose
resize2 = transforms.Resize(512)
toCompose = transforms.Compose([resize2, tensor])
img_Compose = toCompose(img)
writer.add_image("resize", img_Compose, 1)# RandomCrop
randomCrop = transforms.RandomCrop(512)
for i in range(10):img_randomCrop = randomCrop(img)writer.add_image("randomCrop", tensor(img_randomCrop), i)writer.close()

4. torchvision中数据集的使用

import torchvision
from torch.utils.tensorboard import SummaryWriterdataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])train_set = torchvision.datasets.CIFAR10(root="../dataset_transforms", train=True, transform=dataset_transform)
test_set = torchvision.datasets.CIFAR10(root="../dataset_transforms", train=False, transform=dataset_transform)# print(train_set[0])
# img, target = train_set[0]
# print(img)
# print(target)
# print(train_set.classes[target])
# img.show()writer = SummaryWriter("../dataset_transforms_logs")
idx = 0
for img, target in train_set:writer.add_image("CIFAR10", img, idx)idx += 1writer.close()

运行结果:

5. dataloader的使用:

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriterdataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])test_dataset = torchvision.datasets.CIFAR10(root="../dataset_transforms", train=False, transform=dataset_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0, drop_last=False)writer = SummaryWriter("../dataloader_logs")step = 0
for data in test_loader:imgs, targets = datawriter.add_images("CIFAR10 Images", imgs, step)step += 1writer.close()

参数解释:

运行结果:

http://www.dtcms.com/a/617745.html

相关文章:

  • Reactor反应堆
  • 【C++】C++11:智能指针
  • 把网站做成手机版创意设计师
  • 条件前缀|同余优化|栈
  • 做淘客app要网站吗大数据精准营销策略
  • 对于数据结构:链式二叉树的超详细保姆级解析—中
  • 多模态大模型对齐陷阱:对比学习与指令微调的“内耗“问题及破解方案
  • 关键词解释:F1值(F1 Score)
  • 大语言模型入门指南:从科普到实战的技术笔记(2)
  • 【RL-LLM】Self-Rewarding Language Models
  • Redis学习笔记-List列表(2)
  • 区块链与以太坊基础:环境搭建与智能合约部署
  • 二维码怎么在网站上做推广微信商店小程序制作教程
  • 毕业设计可以做哪些网站电子商务网站建设前期规划方案
  • Linux 磁盘挂载管理
  • 智能体知识库核心技术解析与实践指南——从文件处理到智能输出的全链路架构v1.2
  • 【Java 基础】 2 面向对象 - 构造器
  • dw6做网站linux做网站服务器那个软件好
  • 生成式人工智能赋能教师专业发展的机制与障碍:基于教师能动性的质性研究
  • 无锡锡山区建设局网站北京网站定制建设
  • 【Word学习笔记】Word如何转高清PDF
  • 小程序地图导航,怎样实现用户体验更好
  • 下流式接入ai
  • PDF无法打印怎么解决?
  • 南宁市网站建设哪家好企业网站模板html
  • 华为数据中心CE系列交换机级联M-LAG配置示例
  • 【HarmonyOS】性能优化——组件的封装与复用
  • 低代码平台的性能优化:解决页面卡顿、加载缓慢问题
  • 开源工程笔记:gitcode/github与性能优化
  • 微页制作网站模板手机上自己做网站吗