当前位置: 首页 > news >正文

基于CNN的FashionMNIST数据集识别3——模型验证

源码

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet



def test_data_process():
    test_data = FashionMNIST(root='./data',
                              train=False,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    test_dataloader = Data.DataLoader(dataset=test_data,
                                       batch_size=1,
                                       shuffle=True,
                                       num_workers=0)
    return test_dataloader


def test_model_process(model, test_dataloader):
    # 设定测试所用到的设备,有GPU用GPU没有GPU用CPU
    device = "cuda" if torch.cuda.is_available() else 'cpu'

    # 讲模型放入到训练设备中
    model = model.to(device)

    # 初始化参数
    test_corrects = 0.0
    test_num = 0

    # 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
    with torch.no_grad():
        for test_data_x, test_data_y in test_dataloader:
            # 将特征放入到测试设备中
            test_data_x = test_data_x.to(device)
            # 将标签放入到测试设备中
            test_data_y = test_data_y.to(device)
            # 设置模型为评估模式
            model.eval()
            # 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
            output= model(test_data_x)
            # 查找每一行中最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)
            # 如果预测正确,则准确度test_corrects加1
            test_corrects += torch.sum(pre_lab == test_data_y.data)
            # 将所有的测试样本进行累加
            test_num += test_data_x.size(0)

    # 计算测试准确率
    test_acc = test_corrects.double().item() / test_num
    print("测试的准确率为:", test_acc)




if __name__=="__main__":
    # 加载模型
    model = LeNet()
    model.load_state_dict(torch.load('best_model.pth'))
    # 加载测试数据
    test_dataloader = test_data_process()
    # 加载模型测试的函数
    test_model_process(model, test_dataloader)

源码讲解

当模型训练完毕后,我们得到的是一组最优的参数配置。

最后要做的就是验证这组参数的表现。

数据准备

def test_data_process():
    test_data = FashionMNIST(root='./data',
                              train=False,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    test_dataloader = Data.DataLoader(dataset=test_data,
                                       batch_size=1,
                                       shuffle=True,
                                       num_workers=0)
    return test_dataloader

测试数据的准备和训练数据的准备有明显不同:

  1. 测试数据集是将MINIST里所有的数据当做验证集。并且训练模式设置为false。
  2. dataloader里面的batch大小设置为1,也就是说对每个样本进行验证,不再存在“分批”的概念。

循环验证

    # 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
    with torch.no_grad():
        for test_data_x, test_data_y in test_dataloader:
            # 将特征放入到测试设备中
            test_data_x = test_data_x.to(device)
            # 将标签放入到测试设备中
            test_data_y = test_data_y.to(device)
            # 设置模型为评估模式
            model.eval()
            # 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
            output= model(test_data_x)
            # 查找每一行中最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)
            # 如果预测正确,则准确度test_corrects加1
            test_corrects += torch.sum(pre_lab == test_data_y.data)
            # 将所有的测试样本进行累加
            test_num += test_data_x.size(0)

和之前训练模型时的验证逻辑基本相同。只进行前向传播,将预测正确的样本个数进行累加。

代码运行

if __name__=="__main__":
    # 加载模型
    model = LeNet()
    model.load_state_dict(torch.load('best_model.pth'))
    # 加载测试数据
    test_dataloader = test_data_process()
    # 加载模型测试的函数
    test_model_process(model, test_dataloader)

在代码运行时,需要将参数载入到模型里,再进行验证。

相关文章:

  • D. C05.L08.贪心算法入门(一).课堂练习4.危险的实验(NHOI2015初中)
  • 清华大学102页PPT 《deepseek从入门到精通》
  • 使用Python脚本转换YOLOv5配置文件到https://github.com/ultralytics/ultralytics:一个详细的指南
  • 《道德经的现代智慧:解码生活与商业的底层逻辑1》
  • escape SQL中用法
  • 9-1. MySQL 性能分析工具的使用——last_query_cost,慢查询日志
  • 修改/etc/hosts并生效
  • 蓝禾,oppo,游卡,汤臣倍健,康冠科技,作业帮,高途教育25届春招内推
  • jmeter 接入deepseek 或者chatgpt
  • qt.qpa.fonts: Unable to open default EUDC font: “EUDC.TTE“
  • MATLAB中isletter函数用法
  • 爬虫与反爬-Ja3指纹风控(Just a moment...)处理方案及参数说明
  • 软件架构设计:软件工程
  • 【学习资料】嵌入式人工智能Embedded AI
  • SCSS——CSS的扩展和进化
  • java 单例模式(Lazy Initialization)实现遍历文件夹下所有excel文件且返回其运行时间
  • 【Java从入门到起飞】数组
  • Pycharm下载|附安装包+详细安装教程
  • 网卡驱动架构以及源码分析
  • 炫影智能轻云盒(智慧小盒)移动版SY910_RK3528芯片_2+8G_安卓9.0_免拆固件包
  • 自己怎么做视频网站/长春seo外包
  • 加工厂网站建设/信息流广告的特点
  • 最好的网站模板下载网站/seoul是什么品牌
  • wordpress 菜单跳转/1688关键词怎么优化
  • 杭州网站建设/微商推广哪家好
  • 互联网网站建设公司/百度seo点击排名优化