当前位置: 首页 > news >正文

征程 6 J6E/M linear 双int16量化支持替代方案

1.背景简介

当发现使用 plugin 精度 debug 工具定位到是某个 linear 敏感时,示例如下:

op_name                                sensitive_type    op_type                                                                          L1  quant_dtype    flops
-------------------------------------  ---------------   -----------------------------  ----------------  -------------------------  -------  -------------  --------------
model.layernorm.rsqrt                  activation        <class 'horizon_plugin_pytorch.nn.qat.segment_lut.SegmentLUT'>              6.52537  qint16         0(0%)
model.linear2                          weight            <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'>                       5.02445  qint8          3072000(0.00%)
model.layernorm.var_mean.pre_mean      activation        <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>  3.1683   qint16         0(0%)

可以发现,model.linear2 weight 排在了前面,且是 int8 量化。

接下来看下 baseline_statistic.txt 与 analysis_statistic.txt,其中有 model.linear2 的 input、weight、output 的数值分布范围,示例如下:

| Op Name                            | Mod Name       | Attr     | Min            | Max            | Mean           | Var        | Shape                       |
|---------------------------------------------------------------------------------------------------------------------------------------------------------------
| torch.nn.modules.linear.Linear     | model.linear2  | input    | 0.0000000      | 15.4210167     | 4.0793311      | 0.2532279  | torch.Size([2, 100, 256])   |
| torch.nn.modules.linear.Linear     | model.linear2  | weight   | -41.6590347    | 31.2311363     | -0.0053362     | 0.4427260  | torch.Size([60, 256])       |
| torch.nn.modules.linear.Linear     | model.linear2  | bias     | -0.4426649     | 0.3714900      | 0.0053294      | 0.0112585  | torch.Size([60])            |
| torch.nn.modules.linear.Linear     | model.linear2  | output   | -32.0065079    | 5.7881856      | 0.4558742      | 3.8736136  | torch.Size([2, 100, 60])    |

解决方案:使用 int16 来量化这个敏感 linear 的 weight。

如果必须要求 linear input weight output 都是 int16 量化,怎么办呢?

2.知识基础

在 征程 6E/M 上,地平线 BPU 对 linear 支持的情况如下:

本文发布时是这样的

Description

可以看到:input 和 weight 不能同时为 int16。

3.Linear input weight both int16

对于 linear input 和 weight 均需要 int16 量化的情况,可使用 broadcast mul sum 来替代验证,无需重训 float。

异同简介:broadcast_mul_sum_replace_linear 在 float 层面可以等价替换 linear,但在量化方式上存在区别:Linear weight 是 per channel 量化,weight 作为 mul 输入时,是 per tensor 量化。一般情况下:weight int8 perchannel 变成 per tensor int16,精度是正向优化。

替换方案:在 float 训练完成后替换,然后进行 calib+qat。

class SmallModel(nn.Module):def __init__(self, linear2_weight, linear2_bias):super(SmallModel, self).__init__()# 第一个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 256]self.linear1 = nn.Linear(256, 256)self.layernorm = nn.LayerNorm(256)  # 对最后一维进行归一化self.relu = nn.ReLU()# 第二个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 60]# self.linear2 = nn.Linear(256, 60)self.linear2_weight = linear2_weightself.linear2_bias = linear2_bias# 第三个 Linear: 输入 [2, 100, 60] -> 输出 [2, 100, 60]self.linear3 = nn.Linear(60, 60)self.quant = QuantStub()self.dequant = DeQuantStub()self.quant_linear2_weight = QuantStub()self.quant_linear2_bias = QuantStub()def forward(self, x):x = self.quant(x)linear2_weight = self.quant_linear2_weight(self.linear2_weight)linear2_bias = self.quant_linear2_bias(self.linear2_bias)# 第一个 Linearx = self.linear1(x)  # [2, 100, 256]x = self.layernorm(x)  # [2, 100, 256]x = self.relu(x)  # [2, 100, 256]# 第二个 Linear# x = self.linear2(x)  # [2, 100, 60]# ===================================# 使用 broadcast mul + sum 替换linear# ===================================# 广播乘法:输入 [2, 100, 256] 与权重 [60, 256] 进行广播broadcast_mul = x.reshape(2, 100, 1, 256) * linear2_weight.reshape(1, 1, 60, 256)  # [2, 100, 60, 256]# 按最后一个维度求和:sum 操作模拟线性层的加权求和sum_output = broadcast_mul.sum(dim=-1)  # [2, 100, 60]# 加上偏置x = sum_output + linear2_bias  # [2, 100, 60]# 第三个 Linearx = self.linear3(x)x = self.dequant(x)return x

broadcast mul sum 替换方案,均支持 int16。

注意事项:如果 mul 的输出 绝大多数 数值都在 0 附近 -> MSE 校准受异常值影响较大 -> 输出 scale 非常大 -> 0 附近的大量小数值被舍入成 0 -> sum 和发生巨大偏差。

影响范围:mul 后面跟着 sigmoid 或 add+sigmoid 时影响很大。

解决方案:mul 输出设置 fixed scale 为 7/32767,因为 sigmoid 并不需要太大的输入,而 mul 的输出分布需要小 scale。

4.全流程示例

从表中可以看到,在 linear 需要 int16 量化的场景,input/output int16 对应的 latency 最短,其次是 weight output int16 input int8,最差的是三者都需要 int16,针对这三种情况,下面分别提供完整的例子供参考。

信息描述

Description

注意:非完全等价,仅作为参考

4.1 示例代码

import torch
from horizon_plugin_pytorch import set_march, March
set_march(March.NASH_M)
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.hbdk4 import export
from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver
from horizon_plugin_pytorch.dtype import qint8, qint16
from torch.quantization import DeQuantStub
import torch.nn as nn
from horizon_plugin_pytorch.quantization import hbdk4 as hb4
from hbdk4.compiler import convert, save, hbm_perf, visualize, compileimport torch
import torch.nn as nn# 定义网络结构
class SmallModel(nn.Module):def __init__(self, linear2_weight, linear2_bias):super(SmallModel, self).__init__()# 第一个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 256]self.linear1 = nn.Linear(256, 256)self.layernorm = nn.LayerNorm(256)  # 对最后一维进行归一化self.relu = nn.ReLU()# 第二个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 60]# self.linear2 = nn.Linear(256, 60)self.linear2_weight = linear2_weightself.linear2_bias = linear2_bias# 第三个 Linear: 输入 [2, 100, 60] -> 输出 [2, 100, 60]self.linear3 = nn.Linear(60, 60)self.quant = QuantStub()self.dequant = DeQuantStub()self.quant_linear2_weight = QuantStub()self.quant_linear2_bias = QuantStub()def forward(self, x):x = self.quant(x)linear2_weight = self.quant_linear2_weight(self.linear2_weight)linear2_bias = self.quant_linear2_bias(self.linear2_bias)# 第一个 Linearx = self.linear1(x)  # [2, 100, 256]x = self.layernorm(x)  # [2, 100, 256]x = self.relu(x)  # [2, 100, 256]# 第二个 Linear# x = self.linear2(x)  # [2, 100, 60]# ===================================# 使用 broadcast mul + sum 替换linear# ===================================# 广播乘法:输入 [2, 100, 256] 与权重 [60, 256] 进行广播broadcast_mul = x.reshape(2, 100, 1, 256) * linear2_weight.reshape(1, 1, 60, 256)  # [2, 100, 60, 256]# 按最后一个维度求和:sum 操作模拟线性层的加权求和sum_output = broadcast_mul.sum(dim=-1)  # [2, 100, 60]# 加上偏置x = sum_output + linear2_bias  # [2, 100, 60]# 第三个 Linearx = self.linear3(x)x = self.dequant(x)return xfloat_ckpt_path = "model_path/float-checkpoint.ckpt" 
float_state_dict = torch.load(float_ckpt_path)
# 遍历 OrderedDict,查找包含 "linear2" 的键
for key, value in float_state_dict.items():# if "linear2" in key:#     print(f"Key: {key}, Value: {value.shape}")if key == "linear2.weight":linear2_weight = valueif key == "linear2.bias":linear2_bias = value# example_input = torch.randn(2, 100, 256)
file_path = "random_data.pt"
example_input = torch.load(file_path)
model = SmallModel(linear2_weight, linear2_bias)
missing_keys, unexpected_keys = model.load_state_dict(float_state_dict, strict=False)
print("missing_keys & unexpected_keys:", missing_keys, '\n', unexpected_keys)# 前向传播
output = model(example_input)
print("float输出数据:", output)
torch.save(output, "model_path/6_model_float_output.pt")
print("输入形状:", example_input.shape)
print("输出形状:", output.shape)# A global march indicating the target hardware version must be setted before prepare qat.
set_march(March.NASH_M)calib_model = prepare(model.eval(), example_input,qconfig_setter=(calibration_8bit_weight_16bit_act_qconfig_setter,),)calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calib_model(example_input)calib_model.eval()        
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
calib_out = calib_model(example_input)
print("calib输出数据:", calib_out)
qat_bc = export(calib_model, example_input)
hb_quantized_model = convert(qat_bc, March.NASH_M)

4.2 比较替代方案的输出一致性

  • linear2 weight input output int16
float输出数据: tensor([[[-0.3016,  0.1338, -0.5251,  ..., -0.0551, -0.2093, -0.0308],[-0.1969, -0.0131, -0.3287,  ...,  0.3234, -0.0869, -0.0637],[-0.3056,  0.1478, -0.2673,  ...,  0.2355, -0.3487,  0.0134],...,[-0.3990, -0.0389, -0.1686,  ..., -0.0046, -0.4131,  0.0482],[-0.1059,  0.2431, -0.1886,  ...,  0.0787, -0.3454,  0.0231],[-0.2134, -0.1071, -0.0575,  ...,  0.3434, -0.1661,  0.2248]]],grad_fn=<ViewBackward0>)calib输出数据: tensor([[[-0.3038,  0.1370, -0.5269,  ..., -0.0571, -0.2111, -0.0296],[-0.1975, -0.0111, -0.3280,  ...,  0.3215, -0.0884, -0.0637],[-0.3052,  0.1488, -0.2677,  ...,  0.2348, -0.3479,  0.0132],...,[-0.3988, -0.0393, -0.1662,  ..., -0.0055, -0.4117,  0.0484],[-0.1058,  0.2442, -0.1890,  ...,  0.0780, -0.3447,  0.0240],[-0.2142, -0.1061, -0.0587,  ...,  0.3422, -0.1657,  0.2255]]],grad_fn=<ViewBackward0>)
  • broadcast mul sum int16
float输出数据: tensor([[[-0.3016,  0.1338, -0.5251,  ..., -0.0551, -0.2093, -0.0308],[-0.1969, -0.0131, -0.3287,  ...,  0.3234, -0.0869, -0.0637],[-0.3056,  0.1478, -0.2673,  ...,  0.2355, -0.3487,  0.0134],...,[-0.3990, -0.0389, -0.1686,  ..., -0.0046, -0.4131,  0.0482],[-0.1059,  0.2431, -0.1886,  ...,  0.0787, -0.3454,  0.0231],[-0.2134, -0.1071, -0.0575,  ...,  0.3434, -0.1661,  0.2248]]],grad_fn=<ViewBackward0>)
calib输出数据: tensor([[[-0.3038,  0.1370, -0.5269,  ..., -0.0571, -0.2111, -0.0296],[-0.1975, -0.0111, -0.3280,  ...,  0.3215, -0.0884, -0.0637],[-0.3051,  0.1487, -0.2678,  ...,  0.2349, -0.3478,  0.0132],...,[-0.3988, -0.0392, -0.1662,  ..., -0.0055, -0.4117,  0.0484],[-0.1058,  0.2442, -0.1890,  ...,  0.0780, -0.3447,  0.0240],[-0.2142, -0.1061, -0.0586,  ...,  0.3423, -0.1657,  0.2255]]],grad_fn=<ViewBackward0>)...,[-0.3988, -0.0392, -0.1662,  ..., -0.0055, -0.4117,  0.0484],[-0.1058,  0.2442, -0.1890,  ...,  0.0780, -0.3447,  0.0240],[-0.2142, -0.1061, -0.0586,  ...,  0.3423, -0.1657,  0.2255]]],grad_fn=<ViewBackward0>)

相关文章:

  • 野火鲁班猫(arrch64架构debian)从零实现用MobileFaceNet算法进行实时人脸识别(四)安装RKNN Toolkit2
  • Java—— IO流 第三期
  • 基于 AMDXCVU47P HBM2 FPGA 的 2 路 100G 光纤 PCIe 高性能计算加速卡
  • redis Pub/Sub 简介 -16 (PUBLISH、SUBSCRIBE、PSUBSCRIBE)
  • Linux 强制访问控制深度解析:机制、比较与战略部署
  • 【VLNs篇】05:TGS-在无地图室外环境中使用视觉语言模型进行轨迹生成和选择
  • 基于FPGA控制电容阵列与最小反射算法的差分探头优化设计
  • dlib库的人脸检测案例实现
  • Gitee PPM:智能化项目管理如何重塑软件工厂的未来格局
  • 计算机网络 第三章:运输层(二)
  • 5G 网络寻呼的信令及 IE 信息分析
  • C#对集合进行分组IGroupingout TKey, out TElement>
  • day19-20-四剑客-find-grep-sed-awk
  • C# 大文件分割
  • TensorFlow简介与使用指南
  • 学习笔记:黑马程序员JavaWeb开发教程(2025.4.11)
  • 计算机网络 第三章:运输层(三)
  • 解决自签名证书HTTPS告警:强制使用SHA-256算法生成证书
  • 微软CTO:AI已经“能力过剩”,行业需要努力缩小模型能力与实际产品交付之间的差距
  • AUTOSAR AP 入门0:AUTOSAR_EXP_PlatformDesign.pdf
  • 找网络公司建网站每年收维护费/必应搜索推广
  • 政府门户网站内容建设/业务网站制作
  • 电子商务网站建设的目的是开展网络营销/网站seo搜索
  • 网站换域名能换不/软文范例大全800字
  • com网站域名可以是汉字吗/今日新闻快报
  • 长沙搜索排名优化公司/seo营销优化软件