反量化的详细过程
代码整体功能
void int8_weight_to_float(void *weight, void *scale_list, void *out, int n, int m)
-
功能:把 int8 权重矩阵按照每行的 scale 转换成 FP32 矩阵。
-
输入参数:
weight
:指向 int8 权重数组,大小n * m
。scale_list
:指向每行 scale 的数组(float),大小n
。out
:指向输出 FP32 权重数组,大小n * m
。n
:行数m
:列数
逐行讲解
1️⃣ 指针类型转换
char *w = (char*)weight;
float *scales = (float*)scale_list;
float *o = (float*)out;
weight
、scale_list
、out
都是void*
,即 不确定类型的指针。- 通过
(char*)
和(float*)
转换成实际类型,方便访问元素。 w[i*m + j]
→ 访问 int8 元素(每个 1 字节)scales[i]
→ 访问浮点数 scale(每个 4 字节)o[i*m + j]
→ 写入浮点数结果(每个 4 字节)
💡 小结:类型转换让编译器知道如何正确地按元素大小读取和写入内存。
2️⃣ OpenMP 并行循环
#pragma omp parallel for
for(int i = 0; i < n; i++)
#pragma omp parallel for
:使用 OpenMP 多线程并行处理 行循环。- 每个线程处理不同的行
i
,加速矩阵反量化。 - 内层循环(列循环)在单线程内完成。
3️⃣ 取当前行的 scale
float scale = scales[i]; // 先取当前行的 scale
- 每行有一个 scale 值,用来恢复 int8 权重到 FP32。
- 取出当前行的 scale,减少内层循环访问内存次数(优化性能)。
4️⃣ 列循环,取 int8 权重并反量化
for(int j = 0; j < m; j++)
{char int8_val = w[i * m + j]; // 取出 int8 权重float fp32_val = int8_val * scale; // 反量化o[i * m + j] = fp32_val; // 写入输出
}
-
访问 int8 元素:
w[i*m + j]
- 一维数组索引映射二维矩阵:第
i
行、第j
列 → 索引i*m + j
- 一维数组索引映射二维矩阵:第
-
反量化:
int8_val * scale
- int8 权重(-128 ~ 127)乘以对应行的 scale → 得到浮点权重
-
写入输出数组:
o[i*m + j] = fp32_val
- 输出是 FP32 类型的矩阵
💡 注意:
- 这里假设 按行的 scale(per-row scale),常用于量化神经网络权重。
- int8 和 FP32 存储大小不同(1 vs 4 字节),所以需要按类型访问。
5️⃣ 内存布局与索引
假设矩阵:
n = 2, m = 3
weight:
w00 w01 w02
w10 w11 w12
- 在内存中是一维连续存储:
[w00, w01, w02, w10, w11, w12]
- 访问
(i,j)
→i*m + j
- 输出矩阵
out
同理
6️⃣ 并行与安全性
- 每行独立计算,所以使用
#pragma omp parallel for
是安全的。 - 不会出现多线程写同一地址的冲突。
- 这是 典型的 per-row int8 → FP32 转换。
7️⃣ 总结
-
输入:int8 权重 + 每行 scale
-
输出:FP32 权重矩阵
-
方法:
- 逐行取 scale
- 逐列取 int8 → 乘 scale → 写 FP32
- 支持 OpenMP 并行加速
-
内存处理:
- 指针类型转换
(char*)
和(float*)
- 一维数组索引映射二维矩阵
- 指针类型转换
#include <stdio.h>
#include <stdlib.h>void extract_int8_weight_to_float(void *weight, void *scale_list, void *out, int n, int m)
{char *w = (char*)weight;float *scales = (float*)scale_list;float *o = (float*)out;#pragma omp parallel forfor(int i = 0; i < n; i++){float scale = scales[i]; // 先取当前行的 scalefor(int j = 0; j < m; j++){char int8_val = w[i * m + j]; // 取出 int8 权重float fp32_val = int8_val * scale; // 反量化o[i * m + j] = fp32_val; // 写入输出}}
}
int main()
{int n = 2, m = 3;// 测试 int8 权重char weight[6] = { -128, -64, 0, 64, 127, 32 };// 每行 scalefloat scales[2] = { 0.1f, 0.01f };// 输出 FP32 数组float out[6] = {0};extract_int8_weight_to_float(weight, scales, out, n, m);// 打印结果for(int i = 0; i < n; i++){for(int j = 0; j < m; j++){printf("%6.3f ", out[i * m + j]);}printf("\n");}return 0;
}