高精度矢量内积计算方法 (单精度浮点, 超长矢量)
高精度矢量内积计算方法 (单精度浮点, 超长矢量)
对于单精度浮点类型的超长矢量(超过1亿元素)内积计算,累加误差确实是一个重要问题。以下是几种减少误差的方法:
1. Kahan求和算法
这是最常用的补偿求和算法,可以有效减少累加误差:
float kahan_dot_product(const float* a, const float* b, size_t n) {
float sum = 0.0f;
float c = 0.0f; // 补偿变量
for (size_t i = 0; i < n; i++) {
float product = a[i] * b[i];
float y = product - c;
float t = sum + y;
c = (t - sum) - y;
sum = t;
}
return sum;
}
2. 成对求和 (Pairwise Summation)
将求和过程组织为二叉树形式,减少大数吃小数的机会:
float pairwise_dot(const float* a, const float* b, size_t start, size_t end) {
if (end - start <= 128) { // 基础块大小可调整
float sum = 0.0f;
for (size_t i = start; i < end; i++) {
sum += a[i] * b[i];
}
return sum;
}
size_t mid = start + (end - start) / 2;
float left = pairwise_dot(a, b, start, mid);
float right = pairwise_dot(a, b, mid, end);
return left + right;
}
3. 多精度累加器
使用双精度或更高精度累加器:
float high_precision_dot(const float* a, const float* b, size_t n) {
double sum = 0.0;
for (size_t i = 0; i < n; i++) {
sum += (double)a[i] * (double)b[i];
}
return (float)sum;
}
4. 分块并行计算
对于超长矢量,可以结合并行计算和误差控制:
float parallel_dot(const float* a, const float* b, size_t n, int num_threads) {
std::vector<double> partial_sums(num_threads, 0.0);
#pragma omp parallel for num_threads(num_threads)
for (int t = 0; t < num_threads; t++) {
size_t start = t * n / num_threads;
size_t end = (t + 1) * n / num_threads;
double local_sum = 0.0;
for (size_t i = start; i < end; i++) {
local_sum += (double)a[i] * (double)b[i];
}
partial_sums[t] = local_sum;
}
double total = 0.0;
for (double s : partial_sums) {
total += s;
}
return (float)total;
}
实际应用建议
- 对于GPU计算:使用Kahan求和或成对求和,因为双精度运算在GPU上可能性能较差
- 对于CPU计算:优先考虑双精度累加器方法,简单且有效
- 混合精度:可以考虑使用融合乘加(FMA)指令结合补偿算法
- 内存访问优化:确保数据连续访问,减少缓存未命中
误差分析
对于1亿个单精度数的内积:
- 普通累加:相对误差可能达到1e-4量级
- Kahan求和:误差可降低到1e-8量级
- 双精度累加:误差可降低到1e-16量级
选择哪种方法取决于您的精度要求和计算环境限制。
测试代码
#include <cmath>
#include <iostream>
#include <iomanip>
#include <omp.h>
#include <cstdlib>
#include <immintrin.h>
using namespace std;
//编译: g++ -DAVX2 -DAVX512F -std=c++17 -Ofast -march=native -fopenmp test_dot.cpp
/********************************************************
<x,y>内积计算用double双精度做累加类型,保证数值稳定
**********************************************************/
template<typename F,int P=0>
F dot(int n, const F *x, const F *y)
{
if constexpr (P==0)
{
double s_time=omp_get_wtime();
//累加用单精度
F s=0;
for(int i=0; i<n; i++)
{
s+=x[i]*y[i];
}
double e_time=omp_get_wtime();
cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;
return s;
}
else if constexpr (P==1)
{
double s_time=omp_get_wtime();
//累加用双精度,乘法用单精度
double s=0;
for(int i=0; i<n; i++)
{
s+=x[i]*y[i];
}
double e_time=omp_get_wtime();
cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;
return s;
}
else if constexpr (P==2)
{
double s_time=omp_get_wtime();
//累加用双精度,乘法用双精度
double s=0;
for(int i=0; i<n; i++)
{
double a=x[i];
double b=y[i];
s+=a*b;
}
double e_time=omp_get_wtime();
cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;
return s;
}
#ifdef AVX2
else if constexpr(P==3)
{
static_assert(is_same_v<F,float>);
/********************************************************************
OpenMP多线程,AVX2计算<x,x>,<x,y>
累加和乘法都用双精度
*********************************************************************/
double s_time=omp_get_wtime();
__m256d xx_sum,xy_sum;
xx_sum=xy_sum=_mm256_setzero_pd();
#pragma omp parallel
{
__m256d sum_xx=_mm256_setzero_pd();
__m256d sum_xy=_mm256_setzero_pd();
#pragma omp for
for(int i=0; i<n; i+=8)
{
__m256 x8=_mm256_loadu_ps(x+i);
__m256 y8=_mm256_loadu_ps(y+i);
__m128 lo,hi;
__m256d t1,t2,t3,t4;
lo=_mm256_extractf128_ps(x8,0);
hi=_mm256_extractf128_ps(x8,1);
t1=_mm256_cvtps_pd(lo);
t2=_mm256_cvtps_pd(hi);
lo=_mm256_extractf128_ps(y8,0);
hi=_mm256_extractf128_ps(y8,1);
t3=_mm256_cvtps_pd(lo);
t4=_mm256_cvtps_pd(hi);
#if 0
sum_xx=_mm256_add_pd(sum_xx,_mm256_mul_pd(t1,t1));
sum_xx=_mm256_add_pd(sum_xx,_mm256_mul_pd(t2,t2));
sum_xy=_mm256_add_pd(sum_xy,_mm256_mul_pd(t1,t3));
sum_xy=_mm256_add_pd(sum_xy,_mm256_mul_pd(t2,t4));
#else
/**********************************
FMA
***********************************/
sum_xx=_mm256_fmadd_pd(t1,t1,sum_xx);
sum_xx=_mm256_fmadd_pd(t2,t2,sum_xx);
sum_xy=_mm256_fmadd_pd(t1,t3,sum_xy);
sum_xy=_mm256_fmadd_pd(t2,t4,sum_xy);
#endif
}
#pragma omp critical
{
xx_sum=_mm256_add_pd(xx_sum,sum_xx);
xy_sum=_mm256_add_pd(xy_sum,sum_xy);
}
}
double tmp[4];
_mm256_storeu_pd(tmp,xy_sum);
double xy=tmp[0]+tmp[1]+tmp[2]+tmp[3];
_mm256_storeu_pd(tmp,xx_sum);
double xx=tmp[0]+tmp[1]+tmp[2]+tmp[3];
for(int i=n&~7; i<n; i++)
{
double a=x[i];
double b=y[i];
xx+=a*a;
xy+=a*b;
}
double e_time=omp_get_wtime();
cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;
cout<<"xx="<<xx<<endl;
return xy;
}//P==3
#endif
#ifdef AVX512F
else if constexpr(P==4)
{
static_assert(is_same_v<F,float>);
/********************************************************************
OpenMP多线程,AVX512F计算<x,x>,<x,y>
累加和乘法都用双精度
*********************************************************************/
double s_time=omp_get_wtime();
__m512d xx_sum=_mm512_setzero_pd();
__m512d xy_sum=_mm512_setzero_pd();
#pragma omp parallel
{
__m512d sum_xx=_mm512_setzero_pd();
__m512d sum_xy=_mm512_setzero_pd();
#pragma omp for
for(int i=0; i<n; i+=16)
{
__m512 x16=_mm512_loadu_ps(x+i);
__m512 y16=_mm512_loadu_ps(y+i);
__m256 lo,hi;
__m512d t1,t2,t3,t4;
lo=_mm512_extractf32x8_ps(x16,0);
hi=_mm512_extractf32x8_ps(x16,1);
t1=_mm512_cvtps_pd(lo);
t2=_mm512_cvtps_pd(hi);
lo=_mm512_extractf32x8_ps(y16,0);
hi=_mm512_extractf32x8_ps(y16,1);
t3=_mm512_cvtps_pd(lo);
t4=_mm512_cvtps_pd(hi);
#if 0
sum_xx=_mm512_add_pd(sum_xx,_mm512_mul_pd(t1,t1));
sum_xx=_mm512_add_pd(sum_xx,_mm512_mul_pd(t2,t2));
sum_xy=_mm512_add_pd(sum_xy,_mm512_mul_pd(t1,t3));
sum_xy=_mm512_add_pd(sum_xy,_mm512_mul_pd(t2,t4));
#else
/***********************************
FMA
************************************/
sum_xx=_mm512_fmadd_pd(t1,t1,sum_xx);
sum_xx=_mm512_fmadd_pd(t2,t2,sum_xx);
sum_xy=_mm512_fmadd_pd(t1,t3,sum_xy);
sum_xy=_mm512_fmadd_pd(t2,t4,sum_xy);
#endif
}
#pragma omp critical
{
xx_sum=_mm512_add_pd(xx_sum,sum_xx);
xy_sum=_mm512_add_pd(xy_sum,sum_xy);
}
}
double tmp[8];
_mm512_storeu_pd(tmp,xy_sum);
double xy=tmp[0]+tmp[1]+tmp[2]+tmp[3]+tmp[4]+tmp[5]+tmp[6]+tmp[7];
_mm512_storeu_pd(tmp,xx_sum);
double xx=tmp[0]+tmp[1]+tmp[2]+tmp[3]+tmp[4]+tmp[5]+tmp[6]+tmp[7];
for(int i=n&~15; i<n; i++)
{
double a=x[i];
double b=y[i];
xx+=a*a;
xy+=a*b;
}
double e_time=omp_get_wtime();
cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;
cout<<"xx="<<xx<<endl;
return xy;
}
#endif
return(0);
}
const int N=50000000;
void test()
{
using FLOAT=float;
FLOAT *x=new FLOAT[N];
FLOAT *y=new FLOAT[N];
for(int i=0; i<N; i++)
{
FLOAT t=0.001*sqrtf(FLOAT(i));
FLOAT s=sqrt(sqrt(FLOAT(i)));
x[i]=(rand()<RAND_MAX/2)?t:-t;
y[i]=(rand()<RAND_MAX/2)?s:-s;
}
cout<<setprecision(15)<<endl;
cout<<(double)dot<FLOAT,0>(N,x,y)<<endl;
cout<<(double)dot<FLOAT,1>(N,x,y)<<endl;
cout<<(double)dot<FLOAT,2>(N,x,y)<<endl;
#ifdef AVX2
cout<<(double)dot<FLOAT,3>(N,x,y)<<endl;
#endif
#ifdef AVX512F
cout<<(double)dot<FLOAT,4>(N,x,y)<<endl;
#endif
}
int main(int argc, char **argv)
{
test();
return(0);
}