当前位置: 首页 > news >正文

pytorch 模型测试

在使用 PyTorch 进行模型测试时,一般包含加载测试数据、加载训练好的模型、进行推理以及评估模型性能等步骤。以下为你详细介绍每个步骤及对应的代码示例。

1. 导入必要的库

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

2. 加载测试数据

假设我们使用的是 CIFAR - 10 数据集作为示例,你需要定义数据预处理的转换操作,然后加载测试数据集。

# 定义数据预处理的转换操作
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

# 类别标签
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

3. 定义模型结构

如果你已经有训练好的模型,这一步可以跳过。但为了完整性,这里给出一个简单的卷积神经网络(CNN)示例。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

4. 加载训练好的模型

假设你已经将训练好的模型保存为 cifar_net.pth 文件,现在可以加载它。

# 加载模型
net.load_state_dict(torch.load('cifar_net.pth'))

5. 进行推理和评估

在测试阶段,我们需要将模型设置为评估模式,然后遍历测试数据集,对每个样本进行推理,并计算模型的准确率。

# 将模型设置为评估模式
net.eval()

correct = 

相关文章:

  • 刷题记录10
  • 下载谷歌浏览器(Chrome)
  • HttpServletRequest 和 HttpServletResponse 不同JDK版本的引入
  • 23种设计模式之单例模式(Singleton Pattern)【设计模式】
  • 【三.大模型实战应用篇】【4.智能学员辅导系统:docx转PDF的自动化流程】
  • 基于springboot的丢失儿童的基因比对系统(源码+lw+部署文档+讲解),源码可白嫖!
  • SFP28(25 Gigabit Small Form-factor Pluggable)详解
  • STM32-FOC-SDK包含以下关键知识点
  • 算法基础 -- 字符串哈希的基本概念和数学原理分析
  • Linux常用指令学习笔记
  • 以1.7K深圳小区房价为例,浙大GIS实验室使用注意力机制挖掘地理情景特征,提升空间非平稳回归精度
  • 蓝桥与力扣刷题(蓝桥 k倍区间)
  • JavaScript 系列之:事件
  • 使用Docker搭建Oracle Database 23ai Free并扩展MAX_STRING_SIZE的完整指南
  • C++基础算法:模拟
  • Redis 哨兵模式
  • 本地部署大数据集群前置准备
  • Java中常见的设计模式
  • Qt信号与槽机制
  • 调用的子组件中使用v-model绑定数据以及使用@调用方法
  • 新时代,新方志:2025上海地方志论坛暨理论研讨会举办
  • 美官方将使用华为芯片视作违反美出口管制行为,外交部回应
  • 绿景中国地产:洛杉矶酒店出售事项未能及时披露纯属疏忽,已采取补救措施
  • 古巴外长谴责美国再次将古列为“反恐行动不合作国家”
  • 金正恩观摩朝鲜人民军各兵种战术综合训练
  • 经济日报整版聚焦:上海构建法治化营商环境,交出高分答卷