PyTorch 实现 CIFAR - 10 图像分类
一、数据加载
借助 PyTorch 的 torchvision
工具加载 CIFAR - 10 数据集。预先将数据下载解压到当前目录的 data
目录,所以设置 download=False
。通过 transforms.Compose
对数据做预处理,把图像转成张量并标准化。分别加载训练集和测试集,用 DataLoader
按批次加载数据,还定义了包含 10 类物体的类别标签。另外,能随机获取部分训练数据,通过自定义函数显示图像并打印标签,直观查看数据情况。
二、构建网络
依据卷积神经网络架构示意图构建网络。创建 CNNNet
类继承 nn.Module
,在初始化方法里定义卷积层、池化层和全连接层等组件。forward
方法定义数据在网络中的前向传播路径,先经过卷积、激活、池化操作,再通过 view
调整张量形状,最后经过全连接层得到输出。同时,还能查看网络的参数总数以及网络结构,也可获取模型的前几层。
三、训练模型
训练时设置训练轮数为 10。在每一轮训练中,遍历训练数据加载器。先获取训练数据并将其移到合适的设备(CPU 或 GPU),然后将优化器的梯度清零。接着进行正向传播得到网络输出,计算损失值,再反向传播计算梯度,最后通过优化器更新网络参数。每经过一定数量的小批次,就显示当前的损失值,训练结束后提示 “Finished Training”。
四、模型评估
从测试数据加载器中获取数据,先显示测试图像并打印真实标签。然后将图像和标签移到相应设备,把图像输入训练好的网络得到输出。通过 torch.max
函数获取预测的类别,最后打印出预测的类别,以此评估模型在测试集上的表现。