当前位置: 首页 > news >正文

神经网络量化2-pytorch测试动态量化

本节我们将基于pytorch来实测量化的实现,pytorch基于quantize_per_tensor_dynamic函数可以实现动态量化,我们自己写个量化函数然后和pytorch对比来看其是如何实现的:

import torch
import numpy as np
import math
#自己写的动态量化函数,与pytorch自带quantize_tensor函数做对比

def quantize_tensor(array, num_bits=8):
    highB = array.max()
    lowB = array.min()
    rangeB = highB - lowB
    shiftDist = -(highB + lowB) / 2
    qmax = 2.**num_bits - 1.
    zero_point = shiftDist / rangeB * qmax;
    #if zero_point < 0:
    #    zero_point = zero_point - 1.0
    print(zero_point)
    zero_point = zero_point.floor().int()
    scale = rangeB / qmax
    q_x = array/scale + zero_point
    q_x = q_x.round().int()
    return q_x, zero_point, scale
    
x1 = torch.randn(1, 10, dtype=torch.float32)
x2 = torch.randn(10, 10, dtype=torch.float32)
xq1 = torch.quantize_per_tensor_dynamic(x1, dtype=torch.qint8, reduce_range = False)
xq2 = torch.quantize_per_tensor_dynamic(x2, dtype=torch.qint8, reduce_range = False)
print('************martrix value**************')
print(x1)
#calcScaleAndZero(x2.numpy())
print(x2)
print('************scale**********************')
scale1 = xq1.q_scale()
scale2 = xq2.q_scale()
print(scale1)
print(scale2)
print('************zero point**********************')
zpoint1 = xq1.q_zero_point()
zpoint2 = xq2.q_zero_point()
print(zpoint1)
print(zpoint2)
print('*************calc quant and zero point*********************')

q1, z1, s1 = quantize_tensor(x1)
print(q1, z1, s1)

q2, z2, s2 = quantize_tensor(x2)
print(q2, z2, s2)

xquant1 = xq1.int_repr().int()
xquant2 = xq2.int_repr().int()
print('************quant**********************')
print(xquant1)
print(xquant2)

print('************mult result****************')
multZ = torch.matmul(xquant1 - zpoint1, xquant2 - zpoint2)
print(multZ)
print(scale1)
print(scale2)
print('quant result:')
print(multZ*scale1*scale2)
realResult = torch.matmul(x1,x2)
print('real result')
print(realResult)

结果如下:

************martrix value**************
tensor([[ 1.2477,  0.0531,  0.7887, -1.9008,  0.0422,  0.0558,  2.1269, -0.5745,
         -1.1107, -0.9602]])
tensor([[-0.0498, -1.9346,  1.1775, -0.2848,  1.9393,  0.1473, -0.6528,  1.4783,
         -1.0426, -0.1134],
        [-1.4242, -1.1538, -1.0923,  0.7910, -0.8136,  0.2567,  0.7243,  2.5828,
         -0.5604,  0.1569],
        [ 0.4030,  0.2074,  1.6686, -0.0956,  1.3616, -0.1492,  1.0531, -0.6623,
         -1.1229, -1.9445],
        [-0.3417,  0.4932,  1.1417,  0.0104,  0.2803, -0.1214, -0.2549, -1.2193,
          0.8666, -0.9464],
        [-0.6474, -0.9055, -1.0907, -0.8223, -1.3726,  0.2854, -0.3068, -0.7960,
          0.3766,  0.9145],
        [ 0.4355,  0.3613, -1.0598,  0.8375, -0.6023,  0.6905, -0.4290,  0.7039,
          0.5284,  1.2257],
        [ 1.5872,  0.2304,  0.8338, -1.7823,  2.5621,  0.4503, -0.2524, -0.5032,
          1.1579, -0.4619],
        [-0.3367,  0.9936, -0.9854, -0.9287, -0.2374,  2.7017,  0.3184, -0.1240,
          0.8407, -1.0258],
        [-0.3044,  0.3404, -3.9793, -0.0676,  1.1238, -0.1845, -1.1807, -1.3403,
          0.3283,  1.3031],
        [ 0.5594, -0.8091,  0.5098,  0.7334, -0.1245,  0.5204, -0.0674, -0.6535,
          0.3256, -0.3021]])
************scale**********************
0.015794984967100852
0.026200167338053384
************zero point**********************
-8
24
*************calc quant and zero point*********************
tensor(-7.1556)
tensor([[  71,   -5,   42, -128,   -5,   -4,  127,  -44,  -78,  -69]],
       dtype=torch.int32) tensor(-8, dtype=torch.int32) tensor(0.0158)
tensor(24.3810)
tensor([[  22,  -50,   69,   13,   98,   30,   -1,   80,  -16,   20],
        [ -30,  -20,  -18,   54,   -7,   34,   52,  123,    3,   30],
        [  39,   32,   88,   20,   76,   18,   64,   -1,  -19,  -50],
        [  11,   43,   68,   24,   35,   19,   14,  -23,   57,  -12],
        [  -1,  -11,  -18,   -7,  -28,   35,   12,   -6,   38,   59],
        [  41,   38,  -16,   56,    1,   50,    8,   51,   44,   71],
        [  85,   33,   56,  -44,  122,   41,   14,    5,   68,    6],
        [  11,   62,  -14,  -11,   15,  127,   36,   19,   56,  -15],
        [  12,   37, -128,   21,   67,   17,  -21,  -27,   37,   74],
        [  45,   -7,   43,   52,   19,   44,   21,   -1,   36,   12]],
       dtype=torch.int32) tensor(24, dtype=torch.int32) tensor(0.0262)
************quant**********************
tensor([[  71,   -5,   42, -128,   -5,   -4,  127,  -44,  -78,  -69]],
       dtype=torch.int32)
tensor([[  22,  -50,   69,   13,   98,   30,   -1,   80,  -16,   20],
        [ -30,  -20,  -18,   54,   -7,   34,   52,  123,    3,   30],
        [  39,   32,   88,   20,   76,   18,   64,   -1,  -19,  -50],
        [  11,   43,   68,   24,   35,   19,   14,  -23,   57,  -12],
        [  -1,  -11,  -18,   -7,  -28,   35,   12,   -6,   38,   59],
        [  41,   38,  -16,   56,    1,   50,    8,   51,   44,   71],
        [  85,   33,   56,  -44,  122,   41,   14,    5,   68,    6],
        [  11,   62,  -14,  -11,   15,  127,   36,   19,   56,  -15],
        [  12,   37, -128,   21,   67,   17,  -21,  -27,   37,   74],
        [  45,   -7,   43,   52,   19,   44,   21,   -1,   36,   12]],
       dtype=torch.int32)
************mult result****************
tensor([[ 10245,  -7079,  16232, -10362,  17634,  -1202,   2760,  11839,  -6065,
          -3179]], dtype=torch.int32)
0.015794984967100852
0.026200167338053384
quant result:
tensor([[ 4.2397, -2.9295,  6.7173, -4.2881,  7.2975, -0.4974,  1.1422,  4.8993,
         -2.5099, -1.3156]])
real result
tensor([[ 4.1968, -2.9492,  6.7218, -4.2829,  7.2829, -0.5282,  1.1583,  4.8998,
         -2.5157, -1.3109]])

可以通过我们给的自定义函数quantize_tensor看出动态量化的原理,最后我们比较了量化矩阵相乘后的结果,可以看到我们可以保证小数点后一位的精度,当然精度取决于我们的scale的大小,越小精度越高,当然随之而来的要求数据的取值范围要小,否则会出现溢出的情况。

相关文章:

  • FPGA-流水灯
  • vulhub/joker 靶机----练习攻略
  • 基于Java(Springboot+Gradle+Mybatis+templeaf 框架)+Mysql构建的(Web)校园二手平台系统
  • on-policy对比off-policy
  • 微服务的网关配置
  • 厨卫行业供应链产销协同前中后大平台现状需求分析报告+P120(120页PPT)(文末有下载方式)
  • Java面试黄金宝典2
  • LeetCode BFS解决FloodFill算法
  • 无需刷机、root,畅享原生安卓的丝滑体验。
  • 智能提示语链分析平台技术解析
  • 动态库、静态库、导入库
  • 人事档案管理系统基于Spring BootSSM
  • 268.数组美丽值求和
  • 【C++】函数next_permutation
  • 生成式AI红队测试:如何有效评估大语言模型
  • 基于FPGA频率、幅度、相位可调的任意函数发生器(DDS)实现
  • zabbix统计闲置资产
  • HTML课后实践
  • 代码随想录 Day 45 | 【第九章 动态规划part 08】121. 买卖股票的最佳时机、122.买卖股票的最佳时机II、123.买卖股票的最佳时机III
  • SPI驱动(九) -- SPI_Master驱动程序
  • 对谈|“大礼议”:嘉靖皇帝的礼法困境与权力博弈
  • 陕西榆林:全力推进榆林学院升格榆林大学
  • 车载抬头显示爆发在即?业内:凭借市场和产业链优势,国内供应商实现反超
  • 菲律宾中期选举初步结果出炉,杜特尔特家族多人赢得地方选举
  • KPL“王朝”诞生背后:AG和联赛一起迈向成熟
  • 山西临汾哪吒主题景区回应雕塑被指抄袭:造型由第三方公司设计