当前位置: 首页 > news >正文

神经网络量化-基础算法介绍

基本公式

首先,遵循如下基本公式进行量化

r=S(q-Z)

q为量化后的数据,例如如果量化为8bit,那么q是一个8bit的整数,S(scale)Z(zero-point)为量化参数,为常量,r是量化前的真实值,为实数。S为正实数,Zq数据类型相同,也是量化的,不同的矩阵和激活计算,我们使用不同的量化参数。那么,容易得到,对于给定的实数r,其量化公式为:

q=round\left ( \frac{r}{S}+Z \right )

矩阵乘法

我们来看下,如何基于上面的公式,完全通过量化整数的计算来实现矩阵乘法。考虑矩阵乘法r_3=r_1r_2,设矩阵的元素为r^{(i,j)}_{\alpha} ; {\alpha}=1,2,3; 1\leqslant i,j \leqslant N,那么量化公式变为

r_\alpha^{(i,j)}=S_\alpha(q_\alpha^{(i,j)}-Z_\alpha)

根据矩阵乘法定义我们有

S_3(q_3^{(i,k)}-Z_3)=\sum_{j=1}^{N}S_1(q_1^{(i,j)}-Z_1)S_2(q_2^{(j,k)}-Z_2)                               (1)

进一步可以得到

q_3^{(i,k)}=Z_3+M\sum_{j=1}^{N}(q_1^{(i,j)}-Z_1)(q_2^{(j,k)}-Z_2); M=\frac{S_1S_2}{S_3}                       (2)

我们可以看到唯一的非整数是M,根据经验M总是在(0,1)范围内,那么可以表示成规范化的形式:

M=2^{-n}M_0

n是非负的整数,M_0在[0.5,1)范围内,此时我们可以增加M_0的位数来将浮点乘法转换为定点乘法(int16或者int32),例如,首先扩大2^{16}倍,运算玩再缩小2^{16}倍,而缩小的运算可以直接通过移位或者截断来非常高效的实现,下面通过一个实例来说明:

#include <iostream>
#include <stdint.h>
#include <math.h>
int main() {
    float Mf = 0.239; // 浮点值M
    uint32_t Q = 123; // M要相乘的整数
    std::cout << "Real result is " << Mf * Q << std::endl;
    uint32_t shiftScale = pow(2,16); // 扩大2^16倍
    uint32_t M0 = shiftScale * Mf; //扩大后的M0
    std::cout << " M0 is " << M0 << std::endl;
    uint32_t result = M0 * Q;
    std::cout << "Quantize result is " << (result >> 16) << std::endl;
    std::cout << "Transform to real result is " << result / pow(2.0,16) << std::endl;
    return 0;
}

执行结果

Real result is 29.397
 M0 is 15663
Quantize result is 29 
Transform to real result is 29.3968

可以看到通过这种方式,我们可以得到小数点位之前整数位计算的正确性,而且低16位其实保存了有效的小数位结果(15~0,依次存:0.5,0.25, 0.125.....),如果我们能够高效的转换成浮点那么可以进一步提高精度,整数部分如果考虑四舍五入(否则值会统一像低位偏),量化结果可以表示为

uint32_t result = M0 * Q + pow(2,15);

零点的高效处理

公式(1)可以进一步简化为:

q_3^{(i,k)}=Z_3+M\left (NZ_1Z_2-Z_1a_2^{(k)}-Z_2a_1^{(i)}+ \sum_{j=1}^{N}q_1^{(i,j)}q_2^{(j,k)} \right )          (3)

其中,

a_2^{(k)}=\sum_{j=1}^{N}q_2^{(j,k)}, a_1^{(i)}=\sum_{j=1}^{N}q_1^{(i,j)},

可以看到,基于变换后的公式,主要计算量在\sum_{j=1}^{N}q_1^{(i,j)}q_2^{(j,k)},零点相关只需要通过两个累加来实现。

层融合

基于公式(3),我们可以进一步将偏置加激活函数层也加入到公式(3)进一步提升效率。\sum_{j=1}^{N}q_1^{(i,j)}q_2^{(j,k)}的输入是uint8输出位int32:

int32 += uint8 * uint8;

这样可以避免多次累加溢出的问题,如果想将偏置加加入到这个累加器,那么偏置向量需要取为int32类型量化数据类型,并且0为量化零点Z_{bias} = 0,最后其量化scale S_{bias}应与累加器一致,即S_{bias}=S_1S_2。这样拿到累加器的结果之后,还有3件事情要做:

缩小比例(scale down)到8-bit输出激活函数需要的scale;

截断到uint8;

执行激活函数生成8-bit输出激活。

参考文献:

Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference

相关文章:

  • Sidekick:你的 macOS 本地 AI 助手,畅享智能对话!
  • Kafka消息队列
  • 分享一个工具可以国内无限制访问GitHub(来源于GitHub开源项目)
  • 【3-14 STC-pair超级详细的解说】
  • linux(ubuntu)中Conda、CUDA安装Xinference报错ERROR: Failed to build (llama-cpp-python)
  • openharmony5.0中HDF驱动框架源码梳理-服务管理接口
  • Deny by project hooks setting ‘default‘: size of the file
  • Android自动化测试工具
  • tcpdump剖析:入门网络流量分析实战指南
  • 《Operating System Concepts》阅读笔记:p286-p308
  • 关于使用Visual Studio编码问题
  • 30天学习Java第四天——设计模式
  • RabbitMQ之旅(2)
  • Python----数据可视化(Pyecharts三:绘图二:涟漪散点图,K线图,漏斗图,雷达图,词云图,地图,柱状图折线图组合,时间线轮廓图)
  • 阿里云魔笔低代码应用开发平台快速搭建教程
  • 【C++】string类的相关成员函数以及string的模拟实现
  • leecode200.岛屿数量
  • Nginx快速上手
  • 【AI与大模型】解锁本地大模型的潜力:Ollama API 调用深度解析与实践指南
  • springboot常用注解
  • 泰山、华海、中路等山东险企综合成本率均超100%,承保业务均亏损
  • 现场丨在胡适施蛰存等手札与文献间,再读百年光华
  • 制造四十余年血腥冲突后,库尔德工人党为何自行解散?
  • 北京13日冰雹过后,已受理各险种报案近3万件
  • 俄官员说将适时宣布与乌克兰谈判代表
  • 西王食品连亏三年:主业齐“崩”,研发人员多为专科生