当前位置: 首页 > news >正文

高精度矢量内积计算方法 (单精度浮点, 超长矢量)

高精度矢量内积计算方法 (单精度浮点, 超长矢量)

对于单精度浮点类型的超长矢量(超过1亿元素)内积计算,累加误差确实是一个重要问题。以下是几种减少误差的方法:

1. Kahan求和算法

这是最常用的补偿求和算法,可以有效减少累加误差:

float kahan_dot_product(const float* a, const float* b, size_t n) {
    float sum = 0.0f;
    float c = 0.0f; // 补偿变量
    
    for (size_t i = 0; i < n; i++) {
        float product = a[i] * b[i];
        float y = product - c;
        float t = sum + y;
        c = (t - sum) - y;
        sum = t;
    }
    
    return sum;
}

2. 成对求和 (Pairwise Summation)

将求和过程组织为二叉树形式,减少大数吃小数的机会:

float pairwise_dot(const float* a, const float* b, size_t start, size_t end) {
    if (end - start <= 128) { // 基础块大小可调整
        float sum = 0.0f;
        for (size_t i = start; i < end; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }
    
    size_t mid = start + (end - start) / 2;
    float left = pairwise_dot(a, b, start, mid);
    float right = pairwise_dot(a, b, mid, end);
    return left + right;
}

3. 多精度累加器

使用双精度或更高精度累加器:

float high_precision_dot(const float* a, const float* b, size_t n) {
    double sum = 0.0;
    
    for (size_t i = 0; i < n; i++) {
        sum += (double)a[i] * (double)b[i];
    }
    
    return (float)sum;
}

4. 分块并行计算

对于超长矢量,可以结合并行计算和误差控制:

float parallel_dot(const float* a, const float* b, size_t n, int num_threads) {
    std::vector<double> partial_sums(num_threads, 0.0);
    
    #pragma omp parallel for num_threads(num_threads)
    for (int t = 0; t < num_threads; t++) {
        size_t start = t * n / num_threads;
        size_t end = (t + 1) * n / num_threads;
        double local_sum = 0.0;
        
        for (size_t i = start; i < end; i++) {
            local_sum += (double)a[i] * (double)b[i];
        }
        
        partial_sums[t] = local_sum;
    }
    
    double total = 0.0;
    for (double s : partial_sums) {
        total += s;
    }
    
    return (float)total;
}

实际应用建议

  1. 对于GPU计算:使用Kahan求和或成对求和,因为双精度运算在GPU上可能性能较差
  2. 对于CPU计算:优先考虑双精度累加器方法,简单且有效
  3. 混合精度:可以考虑使用融合乘加(FMA)指令结合补偿算法
  4. 内存访问优化:确保数据连续访问,减少缓存未命中

误差分析

对于1亿个单精度数的内积:

  • 普通累加:相对误差可能达到1e-4量级
  • Kahan求和:误差可降低到1e-8量级
  • 双精度累加:误差可降低到1e-16量级

选择哪种方法取决于您的精度要求和计算环境限制。


测试代码

#include <cmath>
#include <iostream>
#include <iomanip>
#include <omp.h>
#include <cstdlib>
#include <immintrin.h>

using namespace std;

//编译: g++ -DAVX2 -DAVX512F -std=c++17 -Ofast -march=native -fopenmp test_dot.cpp

/********************************************************
   <x,y>内积计算用double双精度做累加类型,保证数值稳定
**********************************************************/

template<typename F,int P=0>
F dot(int n, const F *x, const F *y)
{
    if constexpr (P==0)
    {
        double s_time=omp_get_wtime();
		
		//累加用单精度
        F s=0;
        for(int i=0; i<n; i++)
        {
            s+=x[i]*y[i];
        }

        double e_time=omp_get_wtime();
        cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;
        return s;
    }
    else if constexpr (P==1)
    {
        double s_time=omp_get_wtime();
		
		//累加用双精度,乘法用单精度
        double s=0;
        for(int i=0; i<n; i++)
        {
            s+=x[i]*y[i];
        }

        double e_time=omp_get_wtime();
        cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;
        return s;
    }
    else if constexpr (P==2)
    {
        double s_time=omp_get_wtime();
		
		//累加用双精度,乘法用双精度
        double s=0;
        for(int i=0; i<n; i++)
        {
            double a=x[i];
            double b=y[i];

            s+=a*b;
        }
        double e_time=omp_get_wtime();
        cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;
        return s;
    }
#ifdef AVX2 
    else if constexpr(P==3)
    {
		static_assert(is_same_v<F,float>);
		
        /********************************************************************
           OpenMP多线程,AVX2计算<x,x>,<x,y>
			  累加和乘法都用双精度	
        *********************************************************************/
        double s_time=omp_get_wtime();

		__m256d xx_sum,xy_sum;
        xx_sum=xy_sum=_mm256_setzero_pd();

        #pragma omp parallel
        {
            __m256d sum_xx=_mm256_setzero_pd();
            __m256d sum_xy=_mm256_setzero_pd();

            #pragma omp for
            for(int i=0; i<n; i+=8)
            {
                __m256 x8=_mm256_loadu_ps(x+i);
                __m256 y8=_mm256_loadu_ps(y+i);

                __m128 lo,hi;
				__m256d t1,t2,t3,t4;
				
                lo=_mm256_extractf128_ps(x8,0);
                hi=_mm256_extractf128_ps(x8,1);

                t1=_mm256_cvtps_pd(lo);
                t2=_mm256_cvtps_pd(hi);

                lo=_mm256_extractf128_ps(y8,0);
                hi=_mm256_extractf128_ps(y8,1);

                t3=_mm256_cvtps_pd(lo);
                t4=_mm256_cvtps_pd(hi);
#if 0				
                sum_xx=_mm256_add_pd(sum_xx,_mm256_mul_pd(t1,t1));
                sum_xx=_mm256_add_pd(sum_xx,_mm256_mul_pd(t2,t2));
				
                sum_xy=_mm256_add_pd(sum_xy,_mm256_mul_pd(t1,t3));
                sum_xy=_mm256_add_pd(sum_xy,_mm256_mul_pd(t2,t4));
#else
				/**********************************
			                    FMA
				***********************************/
                sum_xx=_mm256_fmadd_pd(t1,t1,sum_xx);
                sum_xx=_mm256_fmadd_pd(t2,t2,sum_xx);
			
                sum_xy=_mm256_fmadd_pd(t1,t3,sum_xy);
                sum_xy=_mm256_fmadd_pd(t2,t4,sum_xy);
#endif
            }
            #pragma omp critical
            {
                xx_sum=_mm256_add_pd(xx_sum,sum_xx);
                xy_sum=_mm256_add_pd(xy_sum,sum_xy);
            }
        }
        double tmp[4];

        _mm256_storeu_pd(tmp,xy_sum);
        double xy=tmp[0]+tmp[1]+tmp[2]+tmp[3];

        _mm256_storeu_pd(tmp,xx_sum);
        double xx=tmp[0]+tmp[1]+tmp[2]+tmp[3];

        for(int i=n&~7; i<n; i++)
        {
            double a=x[i];
            double b=y[i];
            xx+=a*a;
            xy+=a*b;
        }

        double e_time=omp_get_wtime();
        cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;
        cout<<"xx="<<xx<<endl;

        return xy;
    }//P==3
#endif 	
#ifdef AVX512F 	
	else if constexpr(P==4)
	{
		static_assert(is_same_v<F,float>);
		
        /********************************************************************
           OpenMP多线程,AVX512F计算<x,x>,<x,y>
			  累加和乘法都用双精度	
        *********************************************************************/
        double s_time=omp_get_wtime();
        __m512d xx_sum=_mm512_setzero_pd();
        __m512d xy_sum=_mm512_setzero_pd();
        #pragma omp parallel
        {
            __m512d sum_xx=_mm512_setzero_pd();
            __m512d sum_xy=_mm512_setzero_pd();

            #pragma omp for
            for(int i=0; i<n; i+=16)
            {
                __m512 x16=_mm512_loadu_ps(x+i);
                __m512 y16=_mm512_loadu_ps(y+i);

                __m256 lo,hi;
				__m512d t1,t2,t3,t4;
				
                lo=_mm512_extractf32x8_ps(x16,0);
                hi=_mm512_extractf32x8_ps(x16,1);

                t1=_mm512_cvtps_pd(lo);
                t2=_mm512_cvtps_pd(hi);

                lo=_mm512_extractf32x8_ps(y16,0);
                hi=_mm512_extractf32x8_ps(y16,1);

                t3=_mm512_cvtps_pd(lo);
                t4=_mm512_cvtps_pd(hi);
#if 0				
                sum_xx=_mm512_add_pd(sum_xx,_mm512_mul_pd(t1,t1));
                sum_xx=_mm512_add_pd(sum_xx,_mm512_mul_pd(t2,t2));
				
                sum_xy=_mm512_add_pd(sum_xy,_mm512_mul_pd(t1,t3));
                sum_xy=_mm512_add_pd(sum_xy,_mm512_mul_pd(t2,t4));
#else
				/***********************************
			                    FMA
				************************************/
                sum_xx=_mm512_fmadd_pd(t1,t1,sum_xx);
                sum_xx=_mm512_fmadd_pd(t2,t2,sum_xx);
				
                sum_xy=_mm512_fmadd_pd(t1,t3,sum_xy);
                sum_xy=_mm512_fmadd_pd(t2,t4,sum_xy);
#endif
            }
            #pragma omp critical
            {
                xx_sum=_mm512_add_pd(xx_sum,sum_xx);
                xy_sum=_mm512_add_pd(xy_sum,sum_xy);
            }
        }
        double tmp[8];

        _mm512_storeu_pd(tmp,xy_sum);
        double xy=tmp[0]+tmp[1]+tmp[2]+tmp[3]+tmp[4]+tmp[5]+tmp[6]+tmp[7];

        _mm512_storeu_pd(tmp,xx_sum);
        double xx=tmp[0]+tmp[1]+tmp[2]+tmp[3]+tmp[4]+tmp[5]+tmp[6]+tmp[7];

        for(int i=n&~15; i<n; i++)
        {
            double a=x[i];
            double b=y[i];
            xx+=a*a;
            xy+=a*b;
        }

        double e_time=omp_get_wtime();
        cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;
        cout<<"xx="<<xx<<endl;

        return xy;
	}
#endif 

	return(0);
	
}

const int N=50000000;

void test()
{
    using FLOAT=float;

    FLOAT *x=new FLOAT[N];
    FLOAT *y=new FLOAT[N];

    for(int i=0; i<N; i++)
    {
        FLOAT t=0.001*sqrtf(FLOAT(i));
        FLOAT s=sqrt(sqrt(FLOAT(i)));
        x[i]=(rand()<RAND_MAX/2)?t:-t;
        y[i]=(rand()<RAND_MAX/2)?s:-s;
    }
    cout<<setprecision(15)<<endl;

    cout<<(double)dot<FLOAT,0>(N,x,y)<<endl;
    cout<<(double)dot<FLOAT,1>(N,x,y)<<endl;
    cout<<(double)dot<FLOAT,2>(N,x,y)<<endl;

#ifdef AVX2	
    cout<<(double)dot<FLOAT,3>(N,x,y)<<endl;
#endif 

#ifdef AVX512F	
	cout<<(double)dot<FLOAT,4>(N,x,y)<<endl;
#endif 

}

int main(int argc, char **argv)
{
    test();
    return(0);
}

 

相关文章:

  • hashtable遍历的方法有哪些
  • uname
  • SpringBoot洗衣店订单管理系统设计与实现
  • LeetCode 3047 求交集区域内的最大正方形面积
  • VScode连接CentOS 7.6虚拟机
  • B站左神算法课学习笔记(P8):贪心
  • Python函数一(五)
  • 算法 | 基于蜘蛛蜂优化算法求解带时间窗的车辆路径问题研究(附matlab代码)
  • ZKmall开源商城:基于Spring Boot 3的高效后端架构设计与实践
  • 三维点云数据的哈希快速查找方法
  • linux驱动学习(十五)之ioctl
  • 软件工程面试题(三十)
  • 【Android】界面布局-相对布局RelativeLayout-例子
  • 网络基础二
  • linux专题3-----禁止SSH的密码登录
  • 论文阅读笔记——RDT-1B: A DIFFUSION FOUNDATION MODEL FOR BIMANUAL MANIPULATION
  • R 语言科研绘图第 36 期 --- 饼状图-基础
  • 大厂不再招测试?软件测试左移开发合理吗?
  • C 语言排序算法:从基础到进阶的全面解析一、引言
  • Deep Reinforcement Learning for Robotics翻译解读