完整的模型验证（测试）流程
CIFAR-10 图像分类测试

单张图片测试代码：
import torch
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d, Sequential, Conv2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
from PIL import Image

# Load the test image. Opening it in a `with` block closes the underlying file
# handle once the RGB copy has been made (convert() returns a new image), so
# the file is not left open for the rest of the script.
image_path = './image/jinx.jpg'
with Image.open(image_path) as img:
    # Force 3 channels: the network's first Conv2d takes 3 input channels,
    # while some images are grayscale or carry an alpha channel.
    image = img.convert('RGB')
print(image)

# Resize to the fixed 32x32 input size the network was built for, then convert
# the PIL image to a float tensor (C, H, W) with values in [0, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image = transform(image)
print(image.shape)

# Create the network model
# NOTE: torch.load below unpickles a whole nn.Module, so this class has to be
# defined here under the same name (and presumably the same module path, e.g.
# __main__) as in the script that saved the checkpoint — TODO confirm.
class Chenxi(nn.Module):
    """Small CNN producing 10 class scores (used here for CIFAR-10).

    Three Conv2d(kernel 5, padding 2) + MaxPool2d(2) stages, then Flatten and
    two Linear layers. ``Linear(1024, 64)`` implies inputs of shape
    (N, 3, 32, 32): the 5x5/padding-2 convolutions keep the spatial size, the
    three 2x poolings shrink 32x32 to 4x4, and 64 * 4 * 4 = 1024.
    ``forward`` returns raw scores of shape (N, 10); no softmax is applied.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # The attribute name `model1` must match the pickled checkpoint, since
        # forward() looks it up on the restored instance.
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model1(x)
        return x


# Load the previously trained model (a full pickled module, not a state_dict).
# map_location remaps tensors saved from a GPU onto the CPU when CUDA is not
# available, so the same script runs on both kinds of machine.
# weights_only=False is required on PyTorch >= 2.6 (where the default became
# True) to unpickle a full module. Only load checkpoints you trust: unpickling
# can execute arbitrary code.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('chenxi_0.pth', map_location=device, weights_only=False)
print(model)

# Add the leading batch dimension the network takes: (N, 3, 32, 32) with N=1.
image = torch.reshape(image, (1, 3, 32, 32))
# Put the input on whatever device the model's weights live on, instead of
# calling .cuda() unconditionally, so inference also works without a GPU.
image = image.to(next(model.parameters()).device)

model.eval()  # switch the model to evaluation mode before inference
with torch.no_grad():  # gradients are not needed for prediction
    output = model(image)
print(output)
# Example output:
# tensor([[-0.7223, 0.4807, -0.1583, 0.2034, -0.0316, 0.4585, 0.3231, 0.1076,
#          -0.4046, 0.2497]], device='cuda:0')

# Index of the highest score along dim 1, i.e. the predicted class.
print(output.argmax(1))
# tensor([1], device='cuda:0')