当前位置: 首页 > news >正文

深度学习中的模型量化及实现示例

      深度学习模型的量化(quantization)是一种模型优化技术,旨在在不显著影响模型准确率的情况下,减少神经网络的计算负载和内存占用。深度学习中模型量化的主要任务是将神经网络中的高精度浮点数(如FP32)转换为低精度表示(如FP16、INT8),从而减少模型大小、内存占用、加快推理速度、降低功耗等。模型量化的本质是函数映射

      量化的基本思想:将权重和激活的高精度表示(通常是常规的32位浮点数)转换为较低精度的数据类型。

      量化类型

      1.训练后量化(Post-Training Quantization, PTQ):是在使用32位精度训练模型后应用,无需重新训练。在PTQ中,模型的权重(有时包括激活(activations))会从FP32转换为较低精度格式,例如INT8。PTQ有三种常见的类型:

      (1).动态量化:此方法仅将权重量化为INT8,并在推理过程中将激活保留为FP32。在推理过程中,激活会根据其值范围进行动态量化。

      (2).静态量化:此方法中,权重和激活均量化为INT8。这通常需要在部署之前使用样本数据集进行校准(calibration, 是量化过程中计算FP32值范围的步骤),以估计激活的动态范围。

      (3).混合精度:浮点16位量化,与INT8量化不同,它将权重和激活量化为FP16,从而保留了FP32的部分动态范围,同时提供更快的推理速度。

      2.量化感知训练(Quantization-Aware Training, QAT):在训练过程中模拟量化。模型学习使其权重适应较低精度的格式,与PTQ相比,准确度下降幅度更小。在QAT中,模型使用伪量化(fake quantization)进行训练,其中FP32值在前向传播过程中被四舍五入(rounded)以模拟低精度,但梯度和更新(gradients and updates)仍保留在FP32中。

      量化技术

      (1).均匀量化(uniform quantization):浮点值的整个范围被划分为相等的区间(interval),每个区间都用一个量化值表示。从FP32到INT8的转换涉及使用缩放因子和偏移量将值从浮点域映射到整数域。

      (2).非均匀量化(non-uniform quantization):将浮点值范围划分为大小不等的区间,从而为频繁出现或对模型性能至关重要的值提供更多表示(representation)。

      (3).最小-最大量化(min-max quantization):使用权重或激活的最小值和最大值来定义量化范围。然后,根据此范围将值线性缩放到整数域。

      (4).对数量化(logarithmic quantization):值基于对数尺度(logarithmic scale)进行量化,这可以为非常小或非常大的值提供更好的精度。

      量化过程:量化包含权重量化和激活量化两个主要部分。该过程通常遵循以下步骤:

      (1).选择范围:权重和激活值的范围可以通过静态或动态方式确定。在静态量化中,范围由校准数据集(calibration dataset)确定;而在动态量化中,范围在运行时确定。

      (2).缩放和零点:定义范围后,将值缩放到整数范围内。计算缩放因子和零点(即整数偏移量)。零点允许将有符号值(例如-1到1)映射到无符号整数范围(例如INT8的0到255)。

      (3).量化和反量化:在推理过程中,使用缩放因子和零点将浮点值量化为整数。在反量化过程中,使用相同的缩放因子将整数映射回浮点值。

      最常见的两种量化情况是: FP32 -> FP16和FP32 -> INT8

      PyTorch提供三种不同的量化模式:Eager模式量化、FX Graph模式量化、PyTorch 2导出量化。

      注:以上内容主要来自于:

      1. https://www.geeksforgeeks.org

      2. https://docs.pytorch.org/docs/stable/quantization.html

      以下是使用预训练的DenseNet二分类模型进行静态量化的实现代码:

      1. 依赖的模块如下所示:

import argparse
import colorama
import ast
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.ao.quantization as quantization
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
import copy
from PIL import Image
from pathlib import Path

      2. 支持的输入参数如下所示:

def parse_args():parser = argparse.ArgumentParser(description="model quantization")parser.add_argument("--task", required=True, type=str, choices=["quantize", "predict"], help="specify what kind of task")parser.add_argument("--src_model", type=str, help="source model name")parser.add_argument("--dst_model", type=str, help="quantized model name")parser.add_argument("--classes_number", type=int, default=2, help="classes number")parser.add_argument("--mean", type=str, help="the mean of the training set of images")parser.add_argument("--std", type=str, help="the standard deviation of the training set of images")parser.add_argument("--labels_file", type=str, help="one category per line, the format is: index class_name")parser.add_argument("--images_path", type=str, help="predict images path")parser.add_argument("--dataset_path", type=str, help="source dataset path")args = parser.parse_args()return args

      3.量化代码如下:

def _str2tuple(value):if not isinstance(value, tuple):value = ast.literal_eval(value) # str to tuplereturn valuedef _load_dataset(dataset_path, mean, std, batch_size):mean = _str2tuple(mean)std = _str2tuple(std)calibration_transform = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std), # RGB])calibration_dataset = ImageFolder(root=dataset_path+"/train", transform=calibration_transform)print(f"calibration dataset length: {len(calibration_dataset)}; classes: {calibration_dataset.class_to_idx}; number of categories: {len(calibration_dataset.class_to_idx)}")calibration_loader = DataLoader(calibration_dataset, batch_size, shuffle=True, num_workers=0)return calibration_loaderdef _calibrate(model, calibration_loader, num_batches=20):model.eval()with torch.no_grad():for i, (x, _) in enumerate(calibration_loader):x = x.to(torch.device("cpu"))_ = model(x)if i + 1 >= num_batches:breakdef quantize(src_model, device, classes_number, dataset_path, mean, std, dst_model):# load modelmodel = models.densenet121(weights=None)model.classifier = nn.Linear(model.classifier.in_features, classes_number)model.load_state_dict(torch.load(src_model, weights_only=False, map_location="cpu"))model.to(device)model.eval()# prepare quantization: fxqconfig_mapping = quantization.get_default_qconfig_mapping('x86')model_prepared = prepare_fx(copy.deepcopy(model), qconfig_mapping, example_inputs=torch.randn(1, 3, 224, 224))model_prepared.eval()# load datasetcalibration_loader = _load_dataset(dataset_path, mean, std, 4)# calibration_calibrate(model_prepared, calibration_loader)# quantize: INT8quantized_model = convert_fx(model_prepared)quantized_model.eval()# save modelscripted_model = torch.jit.script(quantized_model)scripted_model.save(dst_model)

      (1).静态量化必须要校准,目的是收集激活的每一层的动态范围,确定量化参数

      (2).保存和加载量化后的模型时使用:torch.jit.script/torch.jit.load;别使用torch.save/torch.load,会导致序列号问题

      执行结果如下图所示:

      4.预测代码如下:

def _get_images_list(images_path):image_names = []p = Path(images_path)for subpath in p.rglob("*"):if subpath.is_file():image_names.append(subpath)return image_namesdef _parse_labels_file(labels_file):classes = {}with open(labels_file, "r") as file:for line in file:idx_value = []for v in line.split(" "):idx_value.append(v.replace("\n", "")) # remove line breaks(\n) at the end of the lineassert len(idx_value) == 2, f"the length must be 2: {len(idx_value)}"classes[int(idx_value[0])] = idx_value[1]return classesdef predict(model_name, device, labels_file, images_path, mean, std):model = torch.jit.load(model_name, map_location="cpu")model.to(device)model.eval()mean = _str2tuple(mean)std = _str2tuple(std)image_names = _get_images_list(images_path)assert len(image_names) != 0, "no images found"classes = _parse_labels_file(labels_file)assert len(classes) != 0, "the number of categories can't be 0"model.eval()with torch.no_grad():for image_name in image_names:input_image = Image.open(image_name)preprocess = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std)])input_tensor = preprocess(input_image) # (c,h,w)input_batch = input_tensor.unsqueeze(0) # (1,c,h,w)input_batch = input_batch.to(device)output = model(input_batch)probabilities = torch.nn.functional.softmax(output[0], dim=0)max_value, max_index = torch.max(probabilities, dim=0)print(f"{image_name.name}\t{classes[max_index.item()]}\t{max_value.item():.4f}")

      注:预测时不需要把测试图像转换成INT8,模型内部的QuantStub和DeQuantStub会自动完成 FP32与INT8之间的转换

      执行结果如下图所示:

      5.入口函数如下所示:

if __name__ == "__main__":colorama.init(autoreset=True)args = parse_args()device = torch.device("cpu")if args.task == "quantize":quantize(args.src_model, device, args.classes_number, args.dataset_path, args.mean, args.std, args.dst_model)elif args.task == "predict":predict(args.dst_model, device, args.labels_file, args.images_path, args.mean, args.std)print(colorama.Fore.GREEN + "====== execution completed ======")

      GitHub:https://github.com/fengbingchun/NN_Test

http://www.dtcms.com/a/348987.html

相关文章:

  • 【RAGFlow代码详解-4】数据存储层
  • MySQL学习记录-基础知识及SQL语句
  • 【零代码】OpenCV C# 快速开发框架演示
  • 在 Docker 容器中查看 Python 版本
  • C语言第十二章自定义类型:结构体
  • LangChain RAG系统开发基础学习之文档切分
  • Python核心技术开发指南(016)——表达式
  • 多线程——认识Thread类和创建线程
  • 【记录】Docker|Docker镜像拉取超时的问题、推荐的解决办法及安全校验
  • FPGA时序分析(四)
  • asio的线程安全
  • 使用Cobra 完成CLI开发 (一)
  • 3.1 存储系统概述 (答案见原书 P149)
  • C++ string自定义类的实现
  • 【论文阅读 | arXiv 2025 | WaveMamba:面向RGB-红外目标检测的小波驱动Mamba融合方法】
  • 上科大解锁城市建模新视角!AerialGo:从航拍视角到地面漫步的3D城市重建
  • 深度剖析Spring AI源码(三):ChatClient详解,优雅的流式API设计
  • R60ABD1 串口通信实现
  • 在 Ubuntu 24.04 或 22.04 LTS 服务器上安装、配置和使用 Fail2ban
  • 【Qwen Image】蒸馏版与非蒸馏版 评测小结
  • 第3篇:配置管理的艺术 - 让框架更灵活
  • 多线程下单例如何保证
  • [身份验证脚手架] 前端认证与个人资料界面
  • 2025.8.18-2025.8.24第34周:有内耗有挣扎
  • Spring Cloud 快速通关之Sentinel
  • 遥感机器学习入门实战教程|Sklearn案例⑩:降维与分解(decomposition 模块)
  • [e3nn] 等变神经网络 | 线性层o3.Linear | 非线性nn.Gate
  • 动态规划--编译距离
  • AI代码生成器全面评测:六个月、500小时测试揭示最强开发助手
  • Redis 高可用篇