当前位置: 首页 > news >正文

Strassen矩阵乘法算法

Strassen算法是一种用于矩阵乘法的分治算法,由Volker Strassen在1969年提出。它通过减少递归乘法的次数(由8次降为7次)来降低传统矩阵乘法的时间复杂度。

1. 算法概述

传统矩阵乘法 vs Strassen算法

| 特性 | 传统方法 | Strassen算法 |
| --- | --- | --- |
| 时间复杂度 | O(n³) | O(n^log₂7) ≈ O(n^2.807) |
| 乘法次数 | 8次递归 | 7次递归 |
| 适用场景 | 小矩阵 | 大矩阵 |

2. 算法原理

分治策略

  1. 将每个n×n矩阵划分为4个(n/2)×(n/2)子矩阵

  2. 递归计算7个矩阵乘积

  3. 通过加法和减法组合这些乘积得到结果矩阵

关键公式

对于矩阵:

A = [A11 A12]    B = [B11 B12]
    [A21 A22]        [B21 B22]

传统计算:

C11 = A11×B11 + A12×B21
C12 = A11×B12 + A12×B22  
C21 = A21×B11 + A22×B21
C22 = A21×B12 + A22×B22

Strassen优化(7次乘法):

M1 = (A11 + A22) × (B11 + B22)
M2 = (A21 + A22) × B11
M3 = A11 × (B12 - B22)
M4 = A22 × (B21 - B11)
M5 = (A11 + A12) × B22
M6 = (A21 - A11) × (B11 + B12)
M7 = (A12 - A22) × (B21 + B22)

C11 = M1 + M4 - M5 + M7
C12 = M3 + M5
C21 = M2 + M4
C22 = M1 - M2 + M3 + M6

3. C语言完整实现

#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
// Matrix structure
/**
 * Dense integer matrix stored as an array of separately allocated rows.
 */
typedef struct {
    int** data; /* data[i] points to row i; an element is data[i][j] */
    int rows;   /* number of rows */
    int cols;   /* number of columns */
} Matrix;
// Create a matrix
/**
 * Allocates a rows x cols integer matrix with every element set to 0.
 *
 * @param rows Number of rows; must be non-negative.
 * @param cols Number of columns; must be non-negative.
 * @return A newly allocated matrix owned by the caller (release it with
 *         free_matrix), or NULL if a dimension is negative or an
 *         allocation fails.
 */
Matrix* create_matrix(int rows, int cols) {
    if (rows < 0 || cols < 0) {
        return NULL;
    }

    Matrix* mat = malloc(sizeof *mat);
    if (mat == NULL) {
        return NULL;
    }
    mat->rows = rows;
    mat->cols = cols;

    /* calloc with a zero count may return NULL even on success, so at least
     * one slot is requested; NULL then always means allocation failure. */
    mat->data = calloc(rows > 0 ? (size_t)rows : 1, sizeof *mat->data);
    if (mat->data == NULL) {
        free(mat);
        return NULL;
    }

    for (int i = 0; i < rows; i++) {
        /* calloc zero-fills the row, replacing the separate memset. */
        mat->data[i] = calloc(cols > 0 ? (size_t)cols : 1, sizeof *mat->data[i]);
        if (mat->data[i] == NULL) {
            /* Roll back the rows allocated so far so nothing leaks. */
            while (i-- > 0) {
                free(mat->data[i]);
            }
            free(mat->data);
            free(mat);
            return NULL;
        }
    }

    return mat;
}
// Free matrix memory
/**
 * Releases a matrix: every row buffer, then the row-pointer array, then the
 * Matrix object itself.
 *
 * @param mat Matrix to release; NULL is accepted and ignored.
 */
void free_matrix(Matrix* mat) {
    if (mat != NULL) {
        int** row_table = mat->data;
        const int row_count = mat->rows;
        int r = 0;

        while (r < row_count) {
            free(row_table[r]);
            r++;
        }
        free(row_table);
        free(mat);
    }
}
// Print a matrix
/**
 * Writes the matrix to stdout, one row per line with each element
 * right-aligned in a 6-character field, followed by one blank line.
 *
 * @param mat Matrix to print; NULL prints nothing.
 */
void print_matrix(Matrix* mat) {
    if (mat == NULL) {
        return;
    }

    int r = 0;
    while (r < mat->rows) {
        int c = 0;
        while (c < mat->cols) {
            printf("%6d ", mat->data[r][c]);
            c++;
        }
        printf("\n");
        r++;
    }
    printf("\n");
}
// Initialize a matrix with random values
/**
 * Fills every element with rand() % max_value.
 *
 * For a positive max_value the stored values lie in [0, max_value - 1].
 * The function does not seed the generator; call srand() beforehand if a
 * different sequence is wanted on each run.
 *
 * @param mat       Matrix to fill; NULL is ignored.
 * @param max_value Modulus applied to rand(). A value of 0 would make the
 *                  modulo a division by zero, so in that case the matrix is
 *                  left unchanged.
 */
void init_matrix_random(Matrix* mat, int max_value) {
    if (mat == NULL || max_value == 0) {
        return;
    }

    for (int i = 0; i < mat->rows; i++) {
        for (int j = 0; j < mat->cols; j++) {
            mat->data[i][j] = rand() % max_value;
        }
    }
}
// Initialize a matrix with a specific value
/**
 * Sets every element of the matrix to the same value.
 *
 * @param mat   Matrix to fill; NULL is ignored.
 * @param value Value stored in each element.
 */
void init_matrix_value(Matrix* mat, int value) {
    if (mat != NULL) {
        int r = 0;
        while (r < mat->rows) {
            int c = 0;
            while (c < mat->cols) {
                mat->data[r][c] = value;
                c++;
            }
            r++;
        }
    }
}
// Matrix addition
/**
 * Computes the element-wise sum A + B into a newly created matrix.
 *
 * @param A Left operand.
 * @param B Right operand; must have the same dimensions as A.
 * @return A new matrix holding the sum (the caller releases it with
 *         free_matrix), or NULL if either operand is NULL or the dimensions
 *         differ.
 */
Matrix* matrix_add(Matrix* A, Matrix* B) {
    if (A == NULL || B == NULL) {
        return NULL;
    }
    if (A->rows != B->rows || A->cols != B->cols) {
        return NULL;
    }

    Matrix* sum = create_matrix(A->rows, A->cols);
    for (int r = 0; r < A->rows; r++) {
        for (int c = 0; c < A->cols; c++) {
            const int lhs = A->data[r][c];
            const int rhs = B->data[r][c];
            sum->data[r][c] = lhs + rhs;
        }
    }
    return sum;
}
// Matrix subtraction
/**
 * Computes the element-wise difference A - B into a newly created matrix.
 *
 * @param A Minuend.
 * @param B Subtrahend; must have the same dimensions as A.
 * @return A new matrix holding the difference (the caller releases it with
 *         free_matrix), or NULL if either operand is NULL or the dimensions
 *         differ.
 */
Matrix* matrix_subtract(Matrix* A, Matrix* B) {
    const int usable = (A != NULL) && (B != NULL) &&
                       (A->rows == B->rows) && (A->cols == B->cols);
    if (!usable) {
        return NULL;
    }

    Matrix* diff = create_matrix(A->rows, A->cols);
    int r = 0;
    while (r < A->rows) {
        int c = 0;
        while (c < A->cols) {
            diff->data[r][c] = A->data[r][c] - B->data[r][c];
            c++;
        }
        r++;
    }
    return diff;
}
// Extract a submatrix
/**
 * Copies a rectangular region of a matrix into a newly created matrix.
 *
 * @param mat       Source matrix.
 * @param start_row Row index of the region's top-left corner.
 * @param start_col Column index of the region's top-left corner.
 * @param rows      Number of rows to copy.
 * @param cols      Number of columns to copy.
 * @return A new rows x cols matrix (the caller releases it with
 *         free_matrix), or NULL if mat is NULL, any offset or size is
 *         negative, the region does not lie inside mat, or the result
 *         matrix could not be created.
 */
Matrix* get_submatrix(Matrix* mat, int start_row, int start_col, int rows, int cols) {
    if (mat == NULL || start_row < 0 || start_col < 0 || rows < 0 || cols < 0) {
        return NULL;
    }
    /* Compare against the remaining extent rather than computing
     * start + size, which could overflow int for large arguments. */
    if (rows > mat->rows - start_row || cols > mat->cols - start_col) {
        return NULL;
    }

    Matrix* sub = create_matrix(rows, cols);
    if (sub == NULL) {
        return NULL;
    }

    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            sub->data[i][j] = mat->data[start_row + i][start_col + j];
        }
    }
    return sub;
}
// Write a submatrix
/**
 * Copies all of src into dest, placing src's top-left element at
 * (start_row, start_col).
 *
 * Nothing is written if either matrix is NULL, an offset is negative, or src
 * would not fit inside dest at that position.
 *
 * @param dest      Destination matrix, modified in place.
 * @param src       Source matrix, copied in full.
 * @param start_row Destination row of src's first row.
 * @param start_col Destination column of src's first column.
 */
void set_submatrix(Matrix* dest, Matrix* src, int start_row, int start_col) {
    if (dest == NULL || src == NULL || start_row < 0 || start_col < 0) {
        return;
    }
    /* Compare against the remaining extent rather than computing
     * start + size, which could overflow int for large arguments. */
    if (src->rows > dest->rows - start_row || src->cols > dest->cols - start_col) {
        return;
    }

    for (int i = 0; i < src->rows; i++) {
        for (int j = 0; j < src->cols; j++) {
            dest->data[start_row + i][start_col + j] = src->data[i][j];
        }
    }
}
// Find the smallest power of two greater than or equal to n
/**
 * Returns the smallest power of two that is greater than or equal to n.
 *
 * @param n Value to round up. Any n <= 1 yields 1.
 * @return The power of two, or 0 if no power of two >= n is representable
 *         as an int (doubling further would overflow).
 */
int next_power_of_two(int n) {
    int power = 1;

    while (power < n) {
        /* Doubling past INT_MAX / 2 would overflow a signed int, which is
         * undefined behavior; report "not representable" instead. */
        if (power > INT_MAX / 2) {
            return 0;
        }
        power *= 2;
    }
    return power;
}
// Pad a matrix to a power-of-two size
/**
 * Copies a matrix into the top-left corner of a new new_size x new_size
 * matrix.
 *
 * The elements outside the copied region are whatever create_matrix leaves
 * in a fresh matrix.
 *
 * @param mat      Source matrix.
 * @param new_size Side length of the square result; must be at least as
 *                 large as both dimensions of mat.
 * @return The new matrix (the caller releases it with free_matrix), or NULL
 *         if mat is NULL, new_size is too small to hold mat, or the result
 *         matrix could not be created.
 */
Matrix* pad_matrix(Matrix* mat, int new_size) {
    if (mat == NULL) {
        return NULL;
    }
    /* A smaller target would make the copy below write past the end of the
     * destination rows. */
    if (new_size < mat->rows || new_size < mat->cols) {
        return NULL;
    }

    Matrix* padded = create_matrix(new_size, new_size);
    if (padded == NULL) {
        return NULL;
    }

    for (int i = 0; i < mat->rows; i++) {
        for (int j = 0; j < mat->cols; j++) {
            padded->data[i][j] = mat->data[i][j];
        }
    }
    return padded;
}
// Extract the original-size submatrix from a padded matrix
/**
 * Copies the top-left original_rows x original_cols region of a matrix into
 * a newly created matrix.
 *
 * @param padded        Source matrix.
 * @param original_rows Number of rows to keep.
 * @param original_cols Number of columns to keep.
 * @return The new matrix (the caller releases it with free_matrix), or NULL
 *         if padded is NULL, a requested dimension is negative or exceeds
 *         the source, or the result matrix could not be created.
 */
Matrix* unpad_matrix(Matrix* padded, int original_rows, int original_cols) {
    if (padded == NULL || original_rows < 0 || original_cols < 0) {
        return NULL;
    }
    /* Asking for more than the source holds would read out of bounds. */
    if (original_rows > padded->rows || original_cols > padded->cols) {
        return NULL;
    }

    Matrix* result = create_matrix(original_rows, original_cols);
    if (result == NULL) {
        return NULL;
    }

    for (int i = 0; i < original_rows; i++) {
        for (int j = 0; j < original_cols; j++) {
            result->data[i][j] = padded->data[i][j];
        }
    }
    return result;
}
// Traditional matrix multiplication (used for the base case and small matrices)
/**
 * Multiplies two matrices with the straightforward triple loop.
 *
 * @param A Left operand.
 * @param B Right operand; its row count must equal A's column count.
 * @return A new A->rows x B->cols matrix holding the product (the caller
 *         releases it with free_matrix), or NULL if either operand is NULL
 *         or the inner dimensions do not match.
 */
Matrix* traditional_multiply(Matrix* A, Matrix* B) {
    if (A == NULL || B == NULL) {
        return NULL;
    }
    if (A->cols != B->rows) {
        return NULL;
    }

    Matrix* product = create_matrix(A->rows, B->cols);
    for (int r = 0; r < A->rows; r++) {
        for (int c = 0; c < B->cols; c++) {
            /* Dot product of row r of A with column c of B. */
            int acc = 0;
            int k = 0;
            while (k < A->cols) {
                acc += A->data[r][k] * B->data[k][c];
                k++;
            }
            product->data[r][c] = acc;
        }
    }
    return product;
}
// Recursive implementation of Strassen matrix multiplication
/**
 * Recursive core of Strassen's algorithm.
 *
 * The operands are split into quadrants, seven half-size products M1..M7 are
 * computed recursively, and the four result quadrants are assembled from
 * them:
 *   C11 = M1 + M4 - M5 + M7      C12 = M3 + M5
 *   C21 = M2 + M4                C22 = M1 - M2 + M3 + M6
 *
 * @param A Left operand.
 * @param B Right operand.
 * @return A new matrix holding A x B (the caller releases it with
 *         free_matrix), or NULL if an operand is NULL or an intermediate
 *         matrix could not be produced.
 */
Matrix* strassen_multiply_recursive(Matrix* A, Matrix* B) {
    if (A == NULL || B == NULL) {
        return NULL;
    }

    int n = A->rows;

    /* Base case: small matrices use the traditional algorithm (the threshold
     * can be tuned). The same path is taken when the operands are not n x n
     * or n is odd, because the quadrant split below is only valid for equal
     * square operands with an even side length. */
    if (n <= 64 || n % 2 != 0 || A->cols != n || B->rows != n || B->cols != n) {
        return traditional_multiply(A, B);
    }

    int half = n / 2;

    /* Split both operands into four quadrants. */
    Matrix* A11 = get_submatrix(A, 0, 0, half, half);
    Matrix* A12 = get_submatrix(A, 0, half, half, half);
    Matrix* A21 = get_submatrix(A, half, 0, half, half);
    Matrix* A22 = get_submatrix(A, half, half, half, half);
    Matrix* B11 = get_submatrix(B, 0, 0, half, half);
    Matrix* B12 = get_submatrix(B, 0, half, half, half);
    Matrix* B21 = get_submatrix(B, half, 0, half, half);
    Matrix* B22 = get_submatrix(B, half, half, half, half);

    /* Operand sums and differences. They are kept in named variables so they
     * can be freed; passing them inline to the recursive call leaks them. */
    Matrix* S1 = matrix_add(A11, A22);
    Matrix* S2 = matrix_add(B11, B22);
    Matrix* S3 = matrix_add(A21, A22);
    Matrix* S4 = matrix_subtract(B12, B22);
    Matrix* S5 = matrix_subtract(B21, B11);
    Matrix* S6 = matrix_add(A11, A12);
    Matrix* S7 = matrix_subtract(A21, A11);
    Matrix* S8 = matrix_add(B11, B12);
    Matrix* S9 = matrix_subtract(A12, A22);
    Matrix* S10 = matrix_add(B21, B22);

    /* The seven Strassen products. */
    Matrix* M1 = strassen_multiply_recursive(S1, S2);
    Matrix* M2 = strassen_multiply_recursive(S3, B11);
    Matrix* M3 = strassen_multiply_recursive(A11, S4);
    Matrix* M4 = strassen_multiply_recursive(A22, S5);
    Matrix* M5 = strassen_multiply_recursive(S6, B22);
    Matrix* M6 = strassen_multiply_recursive(S7, S8);
    Matrix* M7 = strassen_multiply_recursive(S9, S10);

    /* Result quadrants; T1..T4 are the intermediate sums of C11 and C22. */
    Matrix* T1 = matrix_add(M1, M4);
    Matrix* T2 = matrix_subtract(T1, M5);
    Matrix* C11 = matrix_add(T2, M7);
    Matrix* C12 = matrix_add(M3, M5);
    Matrix* C21 = matrix_add(M2, M4);
    Matrix* T3 = matrix_add(M1, M3);
    Matrix* T4 = matrix_subtract(T3, M2);
    Matrix* C22 = matrix_add(T4, M6);

    /* Assemble the result only if every quadrant was produced; otherwise
     * report failure rather than return a partially filled matrix. */
    Matrix* result = NULL;
    if (C11 != NULL && C12 != NULL && C21 != NULL && C22 != NULL) {
        result = create_matrix(n, n);
        set_submatrix(result, C11, 0, 0);
        set_submatrix(result, C12, 0, half);
        set_submatrix(result, C21, half, 0);
        set_submatrix(result, C22, half, half);
    }

    /* Release every intermediate matrix (free_matrix ignores NULL). */
    Matrix* scratch[] = {
        A11, A12, A21, A22, B11, B12, B21, B22,
        S1, S2, S3, S4, S5, S6, S7, S8, S9, S10,
        M1, M2, M3, M4, M5, M6, M7,
        T1, T2, T3, T4,
        C11, C12, C21, C22
    };
    for (size_t i = 0; i < sizeof scratch / sizeof scratch[0]; i++) {
        free_matrix(scratch[i]);
    }

    return result;
}
// Main entry point for Strassen matrix multiplication
/**
 * Multiplies two matrices of compatible dimensions with Strassen's
 * algorithm.
 *
 * Operands that are not already square with a power-of-two side length are
 * copied into padded square matrices of that size, multiplied, and the
 * A->rows x B->cols region of the padded product is returned.
 *
 * @param A Left operand.
 * @param B Right operand; its row count must equal A's column count.
 * @return A new matrix holding A x B (the caller releases it with
 *         free_matrix), or NULL if an operand is NULL, the inner dimensions
 *         do not match, or a working matrix could not be produced.
 */
Matrix* strassen_multiply(Matrix* A, Matrix* B) {
    if (A == NULL || B == NULL || A->cols != B->rows) {
        return NULL;
    }

    /* Largest of the three distinct dimensions (B->rows equals A->cols). */
    int max_dim = (A->rows > A->cols) ? A->rows : A->cols;
    if (B->cols > max_dim) {
        max_dim = B->cols;
    }

    int padded_size = next_power_of_two(max_dim);
    /* Defensive: a padded size smaller than the largest dimension cannot
     * hold the operands. */
    if (padded_size < max_dim) {
        return NULL;
    }

    /* Already square with a power-of-two side: multiply directly. */
    if (padded_size == A->rows && padded_size == A->cols &&
        padded_size == B->rows && padded_size == B->cols) {
        return strassen_multiply_recursive(A, B);
    }

    /* Otherwise pad both operands, multiply, and cut the result back down. */
    Matrix* A_padded = pad_matrix(A, padded_size);
    Matrix* B_padded = pad_matrix(B, padded_size);
    Matrix* result = NULL;

    if (A_padded != NULL && B_padded != NULL) {
        Matrix* result_padded = strassen_multiply_recursive(A_padded, B_padded);
        result = unpad_matrix(result_padded, A->rows, B->cols);
        free_matrix(result_padded);
    }

    free_matrix(A_padded);
    free_matrix(B_padded);
    return result;
}

5. 优化版本(减少内存分配)

Matrix* strassen_multiply_optimized(Matrix* A, Matrix* B);

/* Builds the half x half matrix Q1 + Q2 (or Q1 - Q2 when subtract is
 * non-zero), where Q1 and Q2 are the quadrants of src whose top-left corners
 * are (r1, c1) and (r2, c2). The quadrants are read in place, so no
 * intermediate copies are allocated. Returns NULL if the result matrix could
 * not be created. */
static Matrix* strassen_opt_combine(Matrix* src, int r1, int c1, int r2, int c2,
                                    int half, int subtract) {
    Matrix* out = create_matrix(half, half);
    if (out == NULL) {
        return NULL;
    }
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            int a = src->data[r1 + i][c1 + j];
            int b = src->data[r2 + i][c2 + j];
            out->data[i][j] = subtract ? a - b : a + b;
        }
    }
    return out;
}

/* Multiplies lhs by rhs recursively and then frees both operands, so
 * temporaries can be passed inline without leaking. */
static Matrix* strassen_opt_product(Matrix* lhs, Matrix* rhs) {
    Matrix* product = strassen_multiply_optimized(lhs, rhs);
    free_matrix(lhs);
    free_matrix(rhs);
    return product;
}

/**
 * Strassen multiplication with fewer temporary matrices.
 *
 * Operand sums and differences are built directly from the quadrants of A
 * and B, every temporary is released as soon as its product is computed, and
 * the four result quadrants are written straight into the result matrix:
 *   C11 = M1 + M4 - M5 + M7      C12 = M3 + M5
 *   C21 = M2 + M4                C22 = M1 - M2 + M3 + M6
 *
 * @param A Left operand.
 * @param B Right operand; its row count must equal A's column count.
 * @return A new matrix holding A x B (the caller releases it with
 *         free_matrix), or NULL if an operand is NULL, the inner dimensions
 *         do not match, or an intermediate matrix could not be produced.
 */
Matrix* strassen_multiply_optimized(Matrix* A, Matrix* B) {
    if (A == NULL || B == NULL || A->cols != B->rows) {
        return NULL;
    }

    int n = A->rows;

    /* Base case: small matrices use the traditional algorithm. The same path
     * is taken for non-square operands or an odd side length, for which the
     * quadrant split below is not valid. */
    if (n <= 64 || n % 2 != 0 || A->cols != n || B->cols != n) {
        return traditional_multiply(A, B);
    }

    int half = n / 2;
    Matrix* M[7]; /* the seven Strassen products M1..M7 */

    /* M1 = (A11 + A22) x (B11 + B22) */
    M[0] = strassen_opt_product(
        strassen_opt_combine(A, 0, 0, half, half, half, 0),
        strassen_opt_combine(B, 0, 0, half, half, half, 0));

    /* M2 = (A21 + A22) x B11 */
    M[1] = strassen_opt_product(
        strassen_opt_combine(A, half, 0, half, half, half, 0),
        get_submatrix(B, 0, 0, half, half));

    /* M3 = A11 x (B12 - B22) */
    M[2] = strassen_opt_product(
        get_submatrix(A, 0, 0, half, half),
        strassen_opt_combine(B, 0, half, half, half, half, 1));

    /* M4 = A22 x (B21 - B11) */
    M[3] = strassen_opt_product(
        get_submatrix(A, half, half, half, half),
        strassen_opt_combine(B, half, 0, 0, 0, half, 1));

    /* M5 = (A11 + A12) x B22 */
    M[4] = strassen_opt_product(
        strassen_opt_combine(A, 0, 0, 0, half, half, 0),
        get_submatrix(B, half, half, half, half));

    /* M6 = (A21 - A11) x (B11 + B12) */
    M[5] = strassen_opt_product(
        strassen_opt_combine(A, half, 0, 0, 0, half, 1),
        strassen_opt_combine(B, 0, 0, 0, half, half, 0));

    /* M7 = (A12 - A22) x (B21 + B22) */
    M[6] = strassen_opt_product(
        strassen_opt_combine(A, 0, half, half, half, half, 1),
        strassen_opt_combine(B, half, 0, half, half, half, 0));

    /* Only assemble a result if every product was obtained. */
    int all_ok = 1;
    for (int k = 0; k < 7; k++) {
        if (M[k] == NULL) {
            all_ok = 0;
        }
    }

    Matrix* result = all_ok ? create_matrix(n, n) : NULL;
    if (result != NULL) {
        for (int i = 0; i < half; i++) {
            for (int j = 0; j < half; j++) {
                int m1 = M[0]->data[i][j];
                int m2 = M[1]->data[i][j];
                int m3 = M[2]->data[i][j];
                int m4 = M[3]->data[i][j];
                int m5 = M[4]->data[i][j];
                int m6 = M[5]->data[i][j];
                int m7 = M[6]->data[i][j];

                result->data[i][j] = m1 + m4 - m5 + m7;               /* C11 */
                result->data[i][j + half] = m3 + m5;                  /* C12 */
                result->data[i + half][j] = m2 + m4;                  /* C21 */
                result->data[i + half][j + half] = m1 + m3 - m2 + m6; /* C22 */
            }
        }
    }

    /* Release the products (free_matrix ignores NULL). */
    for (int k = 0; k < 7; k++) {
        free_matrix(M[k]);
    }
    return result;
}

6. 编译和运行

gcc -o strassen strassen.c -lm
./strassen

7. 算法分析

时间复杂度比较

| 算法 | 时间复杂度 | 乘法次数 | 实际性能 |
| --- | --- | --- | --- |
| 传统算法 | O(n³) | n³ | 小矩阵优秀 |
| Strassen算法 | O(n^log₂7) ≈ O(n^2.807) | ~n^2.807 | 大矩阵优秀 |

空间复杂度

  • Strassen算法:O(n²) + 递归栈空间

  • 传统算法:O(n²)

实际应用考虑

  1. 阈值选择:小矩阵使用传统算法,大矩阵使用Strassen

  2. 内存开销:Strassen需要更多临时存储

  3. 数值稳定性:Strassen可能在浮点运算中有精度问题

  4. 缓存友好性:传统算法通常有更好的缓存性能

关键要点总结:

  1. 分治策略:将大问题分解为小问题

  2. 乘法次数优化:7次乘法代替8次

  3. 递归基准:小矩阵时切换到传统算法

  4. 内存管理:仔细处理矩阵划分和组合

  5. 边界处理:处理非2的幂次矩阵大小

Strassen算法虽然理论复杂度更低,但由于常数因子较大和内存开销,在实际应用中通常只在矩阵很大时才比传统算法有优势。

http://www.dtcms.com/a/537031.html

相关文章:

  • 网站开发新闻怎么写相应式手机网站建设
  • [C++][windows]C++类成员函数默认参数和成员变量初始化问题
  • Vue 动态路由复制标签页失效?彻底解决新标签页路由空白问题
  • 扁平化网站特效张家港网站建设培训班
  • 【GaussDB】深入剖析Insert Select慢的定位全过程
  • 面向智能体与大语言模型的 AI 基础设施:选项、工具与优化
  • 招商网站建设服务商湖南专业网站建设服务
  • 从0到1:易趋驱动产品研发项目全流程管理效能跃升
  • 巴彦淖尔市百家姓网站建设文昌市规划建设管理局网站
  • JAX 高性能机器学习的新选择 - 从NumPy到变换编译
  • 能盈利的网站网站首页description标签
  • Geoserver修行记-安装CSS插件避坑
  • O(1) 时间获取最小值的巧妙设计——力扣155.最小栈
  • 韩国网站建设wordpress安装博客步骤
  • dbpystream webapi: 一次clickhouse数据从系统盘迁至数据盘的尝试
  • 大数据-136 - ClickHouse 集群 表引擎详解 选型实战:TinyLog/Log/StripeLog/Memory/Merge
  • 高效的项目构建和优化之前端构建工具
  • 网站建设公司宣传文案如何通过cpa网站做推广
  • windows环境,设置git 默认提交信息
  • 电商平台网站建设合同宁波seo优化报价多少
  • 哪里找人做网站系统设计
  • 做一个网站需要多少钱大概费用商贸有限公司注销流程
  • OpenVLA-OFT+ 在真实世界 ALOHA 机器人任务中的应用
  • 网站调用字体四网合一网站建设
  • 网站优化包括整站优化吗公司管理体系
  • Spring—Springboot篇
  • 《拆解一封网络信:HTTP 报文详解》
  • wordpress仿站网桌子seo关键词
  • 如何判断服务器是否遭受攻击?
  • DGX A100服务器常见故障解析与维修攻略