基于 PyTorch 的动态量化实测
本节我们基于 PyTorch 来实测量化的实现。PyTorch 提供的 torch.quantize_per_tensor_dynamic 函数可以实现动态量化;我们自己写一个量化函数,再与 PyTorch 的结果做对比,来看它是如何实现的:
import torch
import numpy as np
import math
# Hand-written dynamic quantization routine, used for comparison against
# PyTorch's built-in torch.quantize_per_tensor_dynamic.
def quantize_tensor(array, num_bits=8):
    """Quantize a float tensor to signed ``num_bits``-wide integers.

    The scale and zero point are derived from the tensor's own minimum and
    maximum so that the observed value range is spread over the
    ``2**num_bits - 1`` available integer steps.

    Args:
        array: Non-empty floating-point ``torch.Tensor`` to quantize.
        num_bits: Bit width of the target signed integer range (default 8).

    Returns:
        A tuple ``(q_x, zero_point, scale)``: ``q_x`` is an int32 tensor with
        the same shape as ``array``, ``zero_point`` is a 0-dim int32 tensor
        and ``scale`` is a 0-dim floating-point tensor, chosen so that
        ``(q_x - zero_point) * scale`` approximates ``array``.
    """
    high = array.max()
    low = array.min()
    value_range = high - low
    levels = 2.0 ** num_bits - 1.0
    q_min = -(2 ** (num_bits - 1))
    q_max = 2 ** (num_bits - 1) - 1

    if value_range == 0:
        # Constant tensor: the min/max formulas below would divide by zero
        # and produce inf/NaN. Use the value's magnitude as the scale (1 for
        # an all-zero tensor) with a zero offset, so dequantization is exact.
        magnitude = high.abs()
        scale = torch.where(magnitude > 0, magnitude, torch.ones_like(magnitude))
        zero_point = torch.zeros_like(high)
    else:
        scale = value_range / levels
        # Offset that moves the midpoint of [low, high] onto integer zero.
        zero_point = -(high + low) / 2 / value_range * levels

    # Show the fractional zero point before it is floored to an integer.
    print(zero_point)
    zero_point = zero_point.floor().int()

    # torch.round rounds halves to even, so the largest value can land one
    # step above q_max (e.g. 127.5 -> 128); clamp to stay in range.
    q_x = (array / scale + zero_point).round().clamp(q_min, q_max).int()
    return q_x, zero_point, scale


x1 = torch.randn(1, 10, dtype=torch.float32)
# Second operand of the matrix product compared at the end of the script.
x2 = torch.randn(10, 10, dtype=torch.float32)

# Reference result: PyTorch's built-in dynamic per-tensor quantization.
xq1 = torch.quantize_per_tensor_dynamic(x1, dtype=torch.qint8, reduce_range=False)
xq2 = torch.quantize_per_tensor_dynamic(x2, dtype=torch.qint8, reduce_range=False)

print('************martrix value**************')
print(x1)
print(x2)

# Scale chosen by PyTorch for each tensor.
print('************scale**********************')
scale1 = xq1.q_scale()
scale2 = xq2.q_scale()
print(scale1)
print(scale2)

# Zero point chosen by PyTorch for each tensor.
print('************zero point**********************')
zpoint1 = xq1.q_zero_point()
zpoint2 = xq2.q_zero_point()
print(zpoint1)
print(zpoint2)

# Same quantities computed by the hand-written quantize_tensor.
print('*************calc quant and zero point*********************')
q1, z1, s1 = quantize_tensor(x1)
print(q1, z1, s1)
q2, z2, s2 = quantize_tensor(x2)
print(q2, z2, s2)

# Integer representation stored inside PyTorch's quantized tensors, widened
# to int32 so the zero-point subtraction and the matmul below are not done
# in 8-bit arithmetic.
xquant1 = xq1.int_repr().int()
xquant2 = xq2.int_repr().int()
print('************quant**********************')
print(xquant1)
print(xquant2)

# Multiply in the integer domain, then rescale by both scales to get back
# to the floating-point domain.
print('************mult result****************')
multZ = torch.matmul(xquant1 - zpoint1, xquant2 - zpoint2)
print(multZ)
print(scale1)
print(scale2)
print('quant result:')
print(multZ * scale1 * scale2)

# Floating-point reference for the same product.
realResult = torch.matmul(x1, x2)
print('real result')
print(realResult)
结果如下:
************martrix value**************
tensor([[ 1.2477, 0.0531, 0.7887, -1.9008, 0.0422, 0.0558, 2.1269, -0.5745,-1.1107, -0.9602]])
tensor([[-0.0498, -1.9346, 1.1775, -0.2848, 1.9393, 0.1473, -0.6528, 1.4783,-1.0426, -0.1134],[-1.4242, -1.1538, -1.0923, 0.7910, -0.8136, 0.2567, 0.7243, 2.5828,-0.5604, 0.1569],[ 0.4030, 0.2074, 1.6686, -0.0956, 1.3616, -0.1492, 1.0531, -0.6623,-1.1229, -1.9445],[-0.3417, 0.4932, 1.1417, 0.0104, 0.2803, -0.1214, -0.2549, -1.2193,0.8666, -0.9464],[-0.6474, -0.9055, -1.0907, -0.8223, -1.3726, 0.2854, -0.3068, -0.7960,0.3766, 0.9145],[ 0.4355, 0.3613, -1.0598, 0.8375, -0.6023, 0.6905, -0.4290, 0.7039,0.5284, 1.2257],[ 1.5872, 0.2304, 0.8338, -1.7823, 2.5621, 0.4503, -0.2524, -0.5032,1.1579, -0.4619],[-0.3367, 0.9936, -0.9854, -0.9287, -0.2374, 2.7017, 0.3184, -0.1240,0.8407, -1.0258],[-0.3044, 0.3404, -3.9793, -0.0676, 1.1238, -0.1845, -1.1807, -1.3403,0.3283, 1.3031],[ 0.5594, -0.8091, 0.5098, 0.7334, -0.1245, 0.5204, -0.0674, -0.6535,0.3256, -0.3021]])
************scale**********************
0.015794984967100852
0.026200167338053384
************zero point**********************
-8
24
*************calc quant and zero point*********************
tensor(-7.1556)
tensor([[ 71, -5, 42, -128, -5, -4, 127, -44, -78, -69]],dtype=torch.int32) tensor(-8, dtype=torch.int32) tensor(0.0158)
tensor(24.3810)
tensor([[ 22, -50, 69, 13, 98, 30, -1, 80, -16, 20],[ -30, -20, -18, 54, -7, 34, 52, 123, 3, 30],[ 39, 32, 88, 20, 76, 18, 64, -1, -19, -50],[ 11, 43, 68, 24, 35, 19, 14, -23, 57, -12],[ -1, -11, -18, -7, -28, 35, 12, -6, 38, 59],[ 41, 38, -16, 56, 1, 50, 8, 51, 44, 71],[ 85, 33, 56, -44, 122, 41, 14, 5, 68, 6],[ 11, 62, -14, -11, 15, 127, 36, 19, 56, -15],[ 12, 37, -128, 21, 67, 17, -21, -27, 37, 74],[ 45, -7, 43, 52, 19, 44, 21, -1, 36, 12]],dtype=torch.int32) tensor(24, dtype=torch.int32) tensor(0.0262)
************quant**********************
tensor([[ 71, -5, 42, -128, -5, -4, 127, -44, -78, -69]],dtype=torch.int32)
tensor([[ 22, -50, 69, 13, 98, 30, -1, 80, -16, 20],[ -30, -20, -18, 54, -7, 34, 52, 123, 3, 30],[ 39, 32, 88, 20, 76, 18, 64, -1, -19, -50],[ 11, 43, 68, 24, 35, 19, 14, -23, 57, -12],[ -1, -11, -18, -7, -28, 35, 12, -6, 38, 59],[ 41, 38, -16, 56, 1, 50, 8, 51, 44, 71],[ 85, 33, 56, -44, 122, 41, 14, 5, 68, 6],[ 11, 62, -14, -11, 15, 127, 36, 19, 56, -15],[ 12, 37, -128, 21, 67, 17, -21, -27, 37, 74],[ 45, -7, 43, 52, 19, 44, 21, -1, 36, 12]],dtype=torch.int32)
************mult result****************
tensor([[ 10245, -7079, 16232, -10362, 17634, -1202, 2760, 11839, -6065,-3179]], dtype=torch.int32)
0.015794984967100852
0.026200167338053384
quant result:
tensor([[ 4.2397, -2.9295, 6.7173, -4.2881, 7.2975, -0.4974, 1.1422, 4.8993,-2.5099, -1.3156]])
real result
tensor([[ 4.1968, -2.9492, 6.7218, -4.2829, 7.2829, -0.5282, 1.1583, 4.8998,-2.5157, -1.3109]])
通过自定义函数 quantize_tensor 可以看出动态量化的原理:根据张量的最大值和最小值计算 scale 和 zero point,再把浮点数映射为整数。最后我们比较了量化矩阵相乘的结果与浮点矩阵相乘的结果,可以看到二者大致能在小数点后一位上保持一致。精度取决于 scale 的大小:scale 越小,精度越高;但 scale 越小,要求数据的取值范围也越小,否则会出现溢出的情况。