当前位置: 首页 > news >正文

pytorch软件封装

封装代码,通过传入文件名,即可输出类别信息

上一章节,我们做了关于动物图像的分类,接下来我们把程序封装,然后进行预测。

单张图片的predict文件

predict.py

'''
    按着路径,导入单张图片做预测
'''
from torchvision.models import resnet18
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import cv2 as cv
import os
import numpy as np

'''
    加载图片与格式转化
'''

# 图片标准化
transform_BZ = transforms.Normalize(
    mean=[0.5062653, 0.46558657, 0.37899864],  # 取决于数据集
    std=[0.22566116, 0.20558165, 0.21950442]
)

img_size = 224
val_tf = transforms.Compose([  ##简单把图片压缩了变成Tensor模式
    transforms.ToPILImage(),  # 将numpy数组转换为PIL图像
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transform_BZ  # 标准化操作
])


def cv_imread(file_path):
    cv_img = cv.imdecode(np.fromfile(file_path, dtype=np.uint8), cv.IMREAD_COLOR)
    return cv_img


def predict(img_path):
    '''
        获取标签名字
    '''
    # # 增加类别标签
    # dir_names = []
    # for root, dirs, files in os.walk("dataset"):
    #     if dirs:
    #         dir_names = dirs
    # 将输出保存到exel中,方便后续分析
    label_names = ['cat', 'chicken', 'cow', 'dog', 'duck',
                   'goldfish', 'lion', 'pig', 'sheep',
                   'snake']

    # 指定设备
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Using {device} device")

    """
        加载模型
    """
    model = resnet18(weights=None)
    num_ftrs = model.fc.in_features  # 获取全连接层的输入
    model.fc = nn.Linear(num_ftrs, 10)  # 全连接层改为不同的输出
    torch_data = torch.load('./logs_resnet18_adam/best.pth',
                            map_location=torch.device(device))
    model.load_state_dict(torch_data)
    model.to(device)

    '''
        读取图片
    '''
    img = cv_imread(img_path)
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    img_tensor = val_tf(img)

    # 增加batch_size维度
    img_tensor = Variable(torch.unsqueeze(img_tensor, dim=0).float(),
                          requires_grad=False).to(device)

    '''
        数据输入与模型输出转换
    '''
    model.eval()
    with torch.no_grad():
        output_tensor = model(img_tensor)

        # 将输出通过softmax变为概率值
        output = torch.softmax(output_tensor, dim=1)

        # 输出可能性最大的那位
        pred_value, pred_index = torch.max(output, 1)

        # 将数据从cuda转回cpu
        if torch.cuda.is_available() == False:
            pred_value = pred_value.detach().cpu().numpy()
            pred_index = pred_index.detach().cpu().numpy()

        result = "预测类别为: " + str(label_names[pred_index[0]]) + " 可能性为: " + str(pred_value[0].item() * 100)[:5] + "%"
        return result


if __name__ == "__main__":
    img_path = r'dataset/cat/10.jpg'
    result = predict(img_path)
    print(result)

这里可以看出,我们用的cat数据集中的图片,预测出来的结果却是是cat,虽然可能性不是很高。

torch_data=torch.load('./logs_resnet18_adam/best.pth',map_location=torch.device(device))

使用 PyTorch 加载一个保存的模型权重文件(best.pth),并将其映射到指定的设备(device

img_tensor=Variable(torch.unsqueeze(img_tensor,dim=0).float(),requires_grad=False).to(device)

将一个图像张量(img_tensor)进行处理,使其成为适合输入到神经网络模型中的格式,并将其移动到指定的设备(CPU 或 GPU)上

1. torch.unsqueeze(img_tensor, dim=0)
  • 作用:在张量的第 0 维(即最外层)添加一个维度。

  • 背景:神经网络模型通常期望输入数据是一个四维张量,形状为 [batch_size, channels, height, width]。如果 img_tensor 是一个三维张量(例如 [channels, height, width]),则需要在第 0 维添加一个维度,使其形状变为 [1, channels, height, width],其中 1 表示批量大小(batch_size)为 1。

2. .float()
  • 作用:将张量的数据类型转换为 float32

  • 背景:许多神经网络模型在训练和推理时使用 float32 数据类型。如果 img_tensor 的数据类型不是 float32,则需要显式转换。

3. Variable(..., requires_grad=False)
  • 作用:将张量封装为 Variable 对象,并设置 requires_grad 属性。

  • 背景

    • Variable 是 PyTorch 中的一个旧类,用于封装张量并支持自动求导。在较新的 PyTorch 版本中,Variable 已经与 Tensor 合并,因此这一步在现代代码中通常是多余的。

    • requires_grad=False 表示这个张量不需要计算梯度。这在推理阶段非常常见,因为输入数据不需要参与梯度计算。

转成ONNX,兼容各种设备

ONNX是什么?

ONNX(Open Neural Network Exchange)是一种开放的神经网络交换格式,它为深度学习模型提供了一种标准化的表示方式,使得模型可以在不同的深度学习框架之间进行转换和共享。

ONNX的作用是什么?

  • 模型转换:开发者可以将训练好的模型从一个框架(如PyTorch)转换为ONNX格式,然后在另一个框架(如TensorFlow)中加载和使用。这使得开发者可以在不同的框架之间灵活切换,利用不同框架的优势。

  • 模型部署:ONNX模型可以被导出到多种推理引擎,如ONNX Runtime。ONNX Runtime是一个高性能的推理引擎,支持多种硬件平台(如CPU、GPU、FPGA等),可以用于将模型部署到生产环境中。

  • 模型优化:通过ONNX,开发者可以对模型进行优化和量化等操作。例如,可以将模型从浮点数量化为整数,以提高模型的推理速度和降低存储需求。

import torch
from torch import nn
from torchvision.models import resnet18
# pip install onnx
# pip install onnxruntime

if __name__ == '__main__':

    # 指定设备
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")

    # 指定模型
    model = resnet18(pretrained=False)

    num_ftrs = model.fc.in_features    # 获取全连接层的输入
    model.fc = nn.Linear(num_ftrs, 10)  # 全连接层改为不同的输出

    # 模型加载权重
    torch_data = torch.load('logs_resnet18_pretrain/best.pth',
                            map_location=torch.device(device))

    model.load_state_dict(torch_data)
    model.to(device)

    # 创建一个示例输入
    dummy_input = torch.randn(1,3,224,224, device=device)
    # 指定输出文件路径
    onnx_file_path = "logs_resnet18_pretrain/model.onnx"

    # 导出onnx
    torch.onnx.export(model, dummy_input, onnx_file_path,
                      verbose=True,  # 屏幕中打印日志信息
                      input_names=['input'],
                      output_names=['output'])

    print("Model Exported Success")

Netron模型可视化

NETRON查看网络结构

如何下载可以看这篇文章网络可视化工具netron安装流程-CSDN博客

下载过后打开文件

ONNX单张图片预测

# -*- coding: utf-8 -*-
'''
    按着路径,导入单张图片做预测
'''
import onnxruntime as ort  # pip install onnxruntime onnx
import numpy as np
import torchvision.transforms as transforms
import cv2 as cv
import os


def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=1, keepdims=True)


def cv_imread(file_path):
    cv_img = cv.imdecode(np.fromfile(file_path, dtype=np.uint8), cv.IMREAD_COLOR)
    return cv_img


def predict(img_path):
    '''
        获取标签名字
    '''
    # dir_names = []
    # for root, dirs, files in os.walk("dataset"):
    #     if dirs:
    #         dir_names = dirs
    # label_names = dir_names

    label_names = ['cat', 'chicken', 'cow', 'dog', 'duck',
                   'goldfish', 'lion', 'pig', 'sheep',
                   'snake']

    '''
        加载图片与格式转化
    '''

    # 图片标准化
    transform_BZ = transforms.Normalize(
        mean=[0.5062653, 0.46558657, 0.37899864],  # 取决于数据集
        std=[0.22566116, 0.20558165, 0.21950442]
    )

    img_size = 224
    val_tf = transforms.Compose([  # 简单把图片压缩了变成Tensor模式
        transforms.ToPILImage(),  # 将numpy数组转换为PIL图像
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transform_BZ  # 标准化操作
    ])

    # 读取图片
    img = cv_imread(img_path)
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    img_tensor = val_tf(img)

    # 将图片转换为ONNX运行时所需的格式
    img_numpy = img_tensor.numpy()
    img_numpy = np.expand_dims(img_numpy, axis=0)  # 增加batch_size维度

    # 加载ONNX模型
    onnx_model_path = r'logs_resnet18_pretrain/model.onnx'  # 替换为ONNX模型的路径
    ort_session = ort.InferenceSession(onnx_model_path)

    # 运行ONNX模型
    outputs = ort_session.run(None, {'input': img_numpy})
    output = outputs[0]

    # 应用softmax
    probabilities = softmax(output)

    # 获得预测结果
    pred_index = np.argmax(probabilities, axis=1)
    pred_value = probabilities[0][pred_index[0]]

    result = "预测类别为: " + str(label_names[pred_index[0]]) + " 可能性为: " + str(pred_value * 100)[:5] + "%"
    return result


if __name__ == "__main__":
    img_path = r'dataset/cat/10.jpg'
    result = predict(img_path)
    print(result)

这个没什么好讲的,就是可以直接封装成了一个onnx,可以不用安装pytorch库

PyQt5做预测模型

接下来先请大家准备一些库,看一看下面这篇文章PyCharm配置外部工具PyQtDesigner、PyUIC、Pyrcc_pycharm外部工具-CSDN博客

我把所有的文件封装了一下,大家要记得改一改路径

main_one_thread,py

# -*- coding: utf-8 -*-
from mainwindow import Ui_MainWindow
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog
import sys
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtWidgets import *
from predict封装 import predict


class UiMain(QMainWindow, Ui_MainWindow):
    def __init__(self, parent=None):
        super(UiMain, self).__init__(parent)
        self.setupUi(self)
        self.fileBtn.clicked.connect(self.loadImage)


    # 打开文件功能
    def loadImage(self):
        self.fname, _ = QFileDialog.getOpenFileName(self, '请选择图片','.','图像文件(*.jpg *.jpeg *.png)')
        if self.fname:
            print(self.fname)
            self.Infolabel.setText("文件打开成功\n"+self.fname)
            jpg = QtGui.QPixmap(self.fname).scaled(self.Imglabel.width(),
                                                   self.Imglabel.height())

            self.Imglabel.setPixmap(jpg)

            result = predict(self.fname)
            self.Infolabel.setText(result)

        else:
            # print("打开文件失败")
            self.Infolabel.setText("打开文件失败")


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ui = UiMain()
    ui.show()
    sys.exit(app.exec_())

运行结果

但是这个文件如果打包,别人不一定能用

main_one_thread_onnx.py

from mainwindow import Ui_MainWindow
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog
import sys
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtWidgets import *
from predict_onnx import predict


class UiMain(QMainWindow, Ui_MainWindow):
    def __init__(self, parent=None):
        super(UiMain, self).__init__(parent)
        self.setupUi(self)
        self.fileBtn.clicked.connect(self.loadImage)


    # 打开文件功能
    def loadImage(self):
        self.fname, _ = QFileDialog.getOpenFileName(self, '请选择图片','.','图像文件(*.jpg *.jpeg *.png)')
        if self.fname:
            print(self.fname)
            self.Infolabel.setText("文件打开成功\n"+self.fname)
            jpg = QtGui.QPixmap(self.fname).scaled(self.Imglabel.width(),
                                                   self.Imglabel.height())

            self.Imglabel.setPixmap(jpg)

            result = predict(self.fname)
            self.Infolabel.setText(result)

        else:
            # print("打开文件失败")
            self.Infolabel.setText("打开文件失败")


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ui = UiMain()
    ui.show()
    sys.exit(app.exec_())
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.dtcms.com/a/124057.html

相关文章:

  • Spring基本概念
  • 模拟-与-现实协同训练:基于视觉机器人操控的简单方法
  • Netty之ChannelOutboundBuffer详解与实战
  • 虚拟dom工作原理以及渲染过程
  • Ruoyi-vue plus 5.2.2 flowble 结束节点异常错误
  • 基于CNN-BiLSTM-GRU的深度Q网络(Deep Q-Network,DQN)求解移动机器人路径规划,MATLAB代码
  • 30天学Java第八天——设计模式
  • mmrotate训练自己的数据(记录)
  • 使用多进程和 Socket 接收解析数据并推送到 Kafka 的高性能架构
  • 使用js创建img加载阿里云oss图片跨域的问题
  • opencv常用边缘检测算子示例
  • Java 并发-newFixedThreadPool
  • Java——接口扩展
  • 记录一下移动端uView动态表单校验
  • 安装npm install element-plus --save报错
  • OpenCV 图形API(24)图像滤波-----双边滤波函数bilateralFilter()
  • 随机森林与决策树
  • 什么是虚拟线程?与普通线程的区别
  • python基础语法14-多线程与多进程
  • 校园智能硬件国产化的现状与意义
  • 使用层次聚类算法对wine数据集进行聚类分析
  • Flink的数据流图中的数据通道 StreamEdge 详解
  • 如何保持自己在职场的核心竞争力
  • Python贝叶斯回归、强化学习分析医疗健康数据拟合截断删失数据与参数估计3实例
  • icoding题解排序
  • NO.87十六届蓝桥杯备战|动态规划-完全背包|疯狂的采药|Buying Hay|纪念品(C++)
  • x265 编码器中运动搜索 ME 方法对比实验
  • C++基础精讲-03
  • 苍穹外卖总结
  • 【Web API系列】WebSocketStream API 深度实践:构建高吞吐量实时应用的流式通信方案