【深度学习-pytorch】mnist数字识别
文章目录
- 1. 数据集介绍
1. 数据集介绍
MNIST数据集是torchvision内置的数据集之一,主要包括手写体数字的图片及相应标注
- 下载及加载数据集
from torchvision import datasets
dataset = datasets.MNIST(root='./data', train=True, download=True)
len(dataset), dataset[0]
(60000, (<PIL.Image.Image image mode=L size=28x28>, 5))
- 查看单条数据
import matplotlib.pyplot as plt
plt.figure(figsize=(1,2))
plt.grid(False)
plt.xticks([])
plt.yticks([])
plt.imshow(dataset[0][0], cmap=plt.cm.binary)
plt.xlabel(dataset[0][1])
plt.show()
- 将数据集封装为数据加载器
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoaderdataset = MNIST(root='./data', train=True, transform=ToTensor(), download=True
)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
for X, Y in dataloader:print(X.shape, Y.shape)print(X, Y)break
torch.Size([2, 1, 28, 28]) torch.Size([2])
tensor([[[[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],...,[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.]]],[[[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],...,[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.]]]]) tensor([2, 1])