深度学习中的模型量化及实现示例
深度学习模型的量化(quantization)是一种模型优化技术,旨在在不显著影响模型准确率的情况下,减少神经网络的计算负载和内存占用。深度学习中模型量化的主要任务是将神经网络中的高精度浮点数(如FP32)转换为低精度表示(如FP16、INT8),从而减少模型大小、内存占用、加快推理速度、降低功耗等。模型量化的本质是函数映射。
量化的基本思想:将权重和激活的高精度表示(通常是常规的32位浮点数)转换为较低精度的数据类型。
量化类型:
1.训练后量化(Post-Training Quantization, PTQ):是在使用32位精度训练模型后应用,无需重新训练。在PTQ中,模型的权重(有时包括激活(activations))会从FP32转换为较低精度格式,例如INT8。PTQ有三种常见的类型:
(1).动态量化:此方法仅将权重量化为INT8,并在推理过程中将激活保留为FP32。在推理过程中,激活会根据其值范围进行动态量化。
(2).静态量化:此方法中,权重和激活均量化为INT8。这通常需要在部署之前使用样本数据集进行校准(calibration, 是量化过程中计算FP32值范围的步骤),以估计激活的动态范围。
(3).混合精度:浮点16位量化,与INT8量化不同,它将权重和激活量化为FP16,从而保留了FP32的部分动态范围,同时提供更快的推理速度。
2.量化感知训练(Quantization-Aware Training, QAT):在训练过程中模拟量化。模型学习使其权重适应较低精度的格式,与PTQ相比,准确度下降幅度更小。在QAT中,模型使用伪量化(fake quantization)进行训练,其中FP32值在前向传播过程中被四舍五入(rounded)以模拟低精度,但梯度和更新(gradients and updates)仍保留在FP32中。
量化技术
(1).均匀量化(uniform quantization):浮点值的整个范围被划分为相等的区间(interval),每个区间都用一个量化值表示。从FP32到INT8的转换涉及使用缩放因子和偏移量将值从浮点域映射到整数域。
(2).非均匀量化(non-uniform quantization):将浮点值范围划分为大小不等的区间,从而为频繁出现或对模型性能至关重要的值提供更多表示(representation)。
(3).最小-最大量化(min-max quantization):使用权重或激活的最小值和最大值来定义量化范围。然后,根据此范围将值线性缩放到整数域。
(4).对数量化(logarithmic quantization):值基于对数尺度(logarithmic scale)进行量化,这可以为非常小或非常大的值提供更好的精度。
量化过程:量化包含权重量化和激活量化两个主要部分。该过程通常遵循以下步骤:
(1).选择范围:权重和激活值的范围可以通过静态或动态方式确定。在静态量化中,范围由校准数据集(calibration dataset)确定;而在动态量化中,范围在运行时确定。
(2).缩放和零点:定义范围后,将值缩放到整数范围内。计算缩放因子和零点(即整数偏移量)。零点允许将有符号值(例如-1到1)映射到无符号整数范围(例如INT8的0到255)。
(3).量化和反量化:在推理过程中,使用缩放因子和零点将浮点值量化为整数。在反量化过程中,使用相同的缩放因子将整数映射回浮点值。
最常见的两种量化情况是: FP32 -> FP16和FP32 -> INT8
PyTorch提供三种不同的量化模式:Eager模式量化、FX Graph模式量化、PyTorch 2导出量化。
注:以上内容主要来自于:
1. https://www.geeksforgeeks.org
2. https://docs.pytorch.org/docs/stable/quantization.html
以下是使用预训练的DenseNet二分类模型进行静态量化的实现代码:
1. 依赖的模块如下所示:
import argparse
import colorama
import ast
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.ao.quantization as quantization
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
import copy
from PIL import Image
from pathlib import Path
2. 支持的输入参数如下所示:
def parse_args():parser = argparse.ArgumentParser(description="model quantization")parser.add_argument("--task", required=True, type=str, choices=["quantize", "predict"], help="specify what kind of task")parser.add_argument("--src_model", type=str, help="source model name")parser.add_argument("--dst_model", type=str, help="quantized model name")parser.add_argument("--classes_number", type=int, default=2, help="classes number")parser.add_argument("--mean", type=str, help="the mean of the training set of images")parser.add_argument("--std", type=str, help="the standard deviation of the training set of images")parser.add_argument("--labels_file", type=str, help="one category per line, the format is: index class_name")parser.add_argument("--images_path", type=str, help="predict images path")parser.add_argument("--dataset_path", type=str, help="source dataset path")args = parser.parse_args()return args
3.量化代码如下:
def _str2tuple(value):if not isinstance(value, tuple):value = ast.literal_eval(value) # str to tuplereturn valuedef _load_dataset(dataset_path, mean, std, batch_size):mean = _str2tuple(mean)std = _str2tuple(std)calibration_transform = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std), # RGB])calibration_dataset = ImageFolder(root=dataset_path+"/train", transform=calibration_transform)print(f"calibration dataset length: {len(calibration_dataset)}; classes: {calibration_dataset.class_to_idx}; number of categories: {len(calibration_dataset.class_to_idx)}")calibration_loader = DataLoader(calibration_dataset, batch_size, shuffle=True, num_workers=0)return calibration_loaderdef _calibrate(model, calibration_loader, num_batches=20):model.eval()with torch.no_grad():for i, (x, _) in enumerate(calibration_loader):x = x.to(torch.device("cpu"))_ = model(x)if i + 1 >= num_batches:breakdef quantize(src_model, device, classes_number, dataset_path, mean, std, dst_model):# load modelmodel = models.densenet121(weights=None)model.classifier = nn.Linear(model.classifier.in_features, classes_number)model.load_state_dict(torch.load(src_model, weights_only=False, map_location="cpu"))model.to(device)model.eval()# prepare quantization: fxqconfig_mapping = quantization.get_default_qconfig_mapping('x86')model_prepared = prepare_fx(copy.deepcopy(model), qconfig_mapping, example_inputs=torch.randn(1, 3, 224, 224))model_prepared.eval()# load datasetcalibration_loader = _load_dataset(dataset_path, mean, std, 4)# calibration_calibrate(model_prepared, calibration_loader)# quantize: INT8quantized_model = convert_fx(model_prepared)quantized_model.eval()# save modelscripted_model = torch.jit.script(quantized_model)scripted_model.save(dst_model)
(1).静态量化必须要校准,目的是收集激活的每一层的动态范围,确定量化参数
(2).保存和加载量化后的模型时使用:torch.jit.script/torch.jit.load;别使用torch.save/torch.load,会导致序列号问题
执行结果如下图所示:
4.预测代码如下:
def _get_images_list(images_path):image_names = []p = Path(images_path)for subpath in p.rglob("*"):if subpath.is_file():image_names.append(subpath)return image_namesdef _parse_labels_file(labels_file):classes = {}with open(labels_file, "r") as file:for line in file:idx_value = []for v in line.split(" "):idx_value.append(v.replace("\n", "")) # remove line breaks(\n) at the end of the lineassert len(idx_value) == 2, f"the length must be 2: {len(idx_value)}"classes[int(idx_value[0])] = idx_value[1]return classesdef predict(model_name, device, labels_file, images_path, mean, std):model = torch.jit.load(model_name, map_location="cpu")model.to(device)model.eval()mean = _str2tuple(mean)std = _str2tuple(std)image_names = _get_images_list(images_path)assert len(image_names) != 0, "no images found"classes = _parse_labels_file(labels_file)assert len(classes) != 0, "the number of categories can't be 0"model.eval()with torch.no_grad():for image_name in image_names:input_image = Image.open(image_name)preprocess = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std)])input_tensor = preprocess(input_image) # (c,h,w)input_batch = input_tensor.unsqueeze(0) # (1,c,h,w)input_batch = input_batch.to(device)output = model(input_batch)probabilities = torch.nn.functional.softmax(output[0], dim=0)max_value, max_index = torch.max(probabilities, dim=0)print(f"{image_name.name}\t{classes[max_index.item()]}\t{max_value.item():.4f}")
注:预测时不需要把测试图像转换成INT8,模型内部的QuantStub和DeQuantStub会自动完成 FP32与INT8之间的转换
执行结果如下图所示:
5.入口函数如下所示:
if __name__ == "__main__":colorama.init(autoreset=True)args = parse_args()device = torch.device("cpu")if args.task == "quantize":quantize(args.src_model, device, args.classes_number, args.dataset_path, args.mean, args.std, args.dst_model)elif args.task == "predict":predict(args.dst_model, device, args.labels_file, args.images_path, args.mean, args.std)print(colorama.Fore.GREEN + "====== execution completed ======")
GitHub:https://github.com/fengbingchun/NN_Test