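Strassen's Algorithm: A Detailed Walkthrough and Implementation
1. How the Algorithm Works
1.1 The Problem with Conventional Matrix Multiplication
Conventional matrix multiplication (three nested loops) runs in O(n³) time. For two n×n matrices it takes n³ scalar multiplications and n²(n-1) scalar additions: each of the n² output entries is a dot product of length n, which costs n multiplications and n-1 additions.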
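As a quick check of these counts, here is a minimal sketch that instruments the triple loop. The `OpCount` struct and the `counted_multiply` function are illustrative names that do not appear in the original code, and the sketch assumes a flat row-major layout with n ≥ 1.

```c
#include <assert.h>

/* Hypothetical helper type: tallies of scalar operations. */
typedef struct {
    long long mults;
    long long adds;
} OpCount;

/* Multiply two n x n row-major matrices with the triple loop,
 * counting scalar multiplications and additions along the way. */
OpCount counted_multiply(const int *A, const int *B, int *C, int n) {
    assert(n >= 1);
    OpCount ops = {0, 0};
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            /* The first product initializes the sum:
             * one multiplication, no addition. */
            int sum = A[i * n] * B[j];
            ops.mults++;
            /* Each remaining term costs one multiplication and one addition. */
            for (int k = 1; k < n; k++) {
                sum += A[i * n + k] * B[k * n + j];
                ops.mults++;
                ops.adds++;
            }
            C[i * n + j] = sum;
        }
    }
    return ops;
}
```

For n = 3 the counters come out to 27 multiplications and 18 additions, which matches n³ and n²(n-1).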
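1.2 The Core Idea of Strassen's Algorithm
Strassen's algorithm uses divide and conquer to reduce the time complexity of matrix multiplication to O(n^log₂7) ≈ O(n^2.807). The core idea is to split each matrix into blocks and compute the product with 7 block multiplications instead of 8, at the cost of extra block additions and subtractions.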
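The exponent follows from the recurrence for the running time. The short derivation below is a sketch under two assumptions: n is a power of 2, and adding or subtracting two (n/2)×(n/2) blocks costs Θ(n²). The symbol T(n) is introduced here for the running time on n×n inputs.

```latex
\begin{aligned}
\text{Strassen (7 products):}\quad
  T(n) &= 7\,T\!\left(\tfrac{n}{2}\right) + \Theta(n^2)
  \;\Longrightarrow\;
  T(n) = \Theta\!\left(n^{\log_2 7}\right) \approx \Theta\!\left(n^{2.807}\right) \\
\text{Block method (8 products):}\quad
  T(n) &= 8\,T\!\left(\tfrac{n}{2}\right) + \Theta(n^2)
  \;\Longrightarrow\;
  T(n) = \Theta\!\left(n^{\log_2 8}\right) = \Theta\!\left(n^{3}\right)
\end{aligned}
```

Both lines are the first case of the master theorem, since log₂7 and log₂8 both exceed 2. Saving one multiplication per level is what lowers the exponent.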
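1.3 Mathematical Basis
Matrix partitioning
Given two n×n matrices A and B (n a power of 2), split each into four (n/2)×(n/2) submatrices:
A = | A11 A12 |    B = | B11 B12 |
    | A21 A22 |        | B21 B22 |
The conventional block method computes:
C11 = A11×B11 + A12×B21
C12 = A11×B12 + A12×B22
C21 = A21×B11 + A22×B21
C22 = A21×B12 + A22×B22
This takes 8 block multiplications and 4 block additions.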
Strassen's 7 multiplications
Strassen defines 7 intermediate matrices:
- M1 = (A11 + A22) × (B11 + B22)
- M2 = (A21 + A22) × B11
- M3 = A11 × (B12 - B22)
- M4 = A22 × (B21 - B11)
- M5 = (A11 + A12) × B22
- M6 = (A21 - A11) × (B11 + B12)
- M7 = (A12 - A22) × (B21 + B22)
The blocks of the result are then assembled using only additions and subtractions:
C11 = M1 + M4 - M5 + M7
C12 = M3 + M5
C21 = M2 + M4
C22 = M1 - M2 + M3 + M6
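One way to convince yourself that the seven products and four combination formulas are correct is to apply them to 2×2 matrices of plain numbers, where every block is a single scalar. The sketch below does that and provides a direct eight-product version for comparison. The function names `strassen_2x2` and `direct_2x2` and the flat `{x11, x12, x21, x22}` layout are assumptions made for this illustration.

```c
#include <assert.h>

/* Multiply two 2 x 2 matrices stored as {x11, x12, x21, x22}
 * using Strassen's seven products. */
void strassen_2x2(const long a[4], const long b[4], long c[4]) {
    assert(a && b && c);
    long m1 = (a[0] + a[3]) * (b[0] + b[3]);  /* (A11 + A22)(B11 + B22) */
    long m2 = (a[2] + a[3]) * b[0];           /* (A21 + A22) B11 */
    long m3 = a[0] * (b[1] - b[3]);           /* A11 (B12 - B22) */
    long m4 = a[3] * (b[2] - b[0]);           /* A22 (B21 - B11) */
    long m5 = (a[0] + a[1]) * b[3];           /* (A11 + A12) B22 */
    long m6 = (a[2] - a[0]) * (b[0] + b[1]);  /* (A21 - A11)(B11 + B12) */
    long m7 = (a[1] - a[3]) * (b[2] + b[3]);  /* (A12 - A22)(B21 + B22) */

    c[0] = m1 + m4 - m5 + m7;  /* C11 */
    c[1] = m3 + m5;            /* C12 */
    c[2] = m2 + m4;            /* C21 */
    c[3] = m1 - m2 + m3 + m6;  /* C22 */
}

/* Reference result from the eight-product definition.
 * The output array c must not alias a or b. */
void direct_2x2(const long a[4], const long b[4], long c[4]) {
    assert(a && b && c);
    c[0] = a[0] * b[0] + a[1] * b[2];
    c[1] = a[0] * b[1] + a[1] * b[3];
    c[2] = a[2] * b[0] + a[3] * b[2];
    c[3] = a[2] * b[1] + a[3] * b[3];
}
```

With a = {1, 2, 3, 4} and b = {5, 6, 7, 8}, both functions produce {19, 22, 43, 50}. The identities never reorder a product of an A term and a B term, which is why they remain valid when the entries are themselves matrices.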
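1.4 Algorithm Flow Overview
Input: matrices A, B
Output: matrix C = A × B
1. If the matrices are small enough, use the conventional method
2. Otherwise split A and B into 4 submatrices each
3. Compute the 7 intermediate matrices M1-M7 (recursive calls)
4. Combine M1-M7 into the 4 submatrices of C
5. Merge the submatrices into the final result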
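2. Implementation Details
2.1 Boundary Handling
- When the matrix size is at or below a threshold (typically 64), switch back to the conventional algorithm
- For sizes that are not a power of 2: pad with zeros up to the next power of 2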
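The zero-padding step can be sketched as follows. This is an illustration under stated assumptions, not the article's own code: `next_pow2` and `pad_to_pow2` are hypothetical helper names, and the matrix is stored as one flat row-major array rather than as an array of row pointers.

```c
#include <assert.h>
#include <stdlib.h>

/* Smallest power of two that is >= n. The upper bound keeps
 * the doubling below from overflowing a 32-bit int. */
int next_pow2(int n) {
    assert(n >= 1 && n <= (1 << 30));
    int p = 1;
    while (p < n) {
        p *= 2;
    }
    return p;
}

/* Copy an n x n row-major matrix into the top-left corner of a
 * zero-filled m x m matrix, where m = next_pow2(n). Stores m in *m_out.
 * Returns NULL if allocation fails; the caller frees the result. */
int *pad_to_pow2(const int *src, int n, int *m_out) {
    int m = next_pow2(n);
    int *dst = calloc((size_t)m * (size_t)m, sizeof *dst);
    if (dst == NULL) {
        return NULL;
    }
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            dst[i * m + j] = src[i * n + j];
        }
    }
    *m_out = m;
    return dst;
}
```

After multiplying the padded matrices, the top-left n×n block of the product is the answer, because the extra zero rows and columns contribute nothing to those entries. Padding can nearly double the dimension in the worst case (for example 65 becomes 128), so it trades simplicity for extra work.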
2.2 Memory Optimization
- Avoid unnecessary matrix copies
- Use in-place operations to reduce memory allocations
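One common way to apply both points is to describe a submatrix as a view into the parent's storage instead of copying it, and to accumulate results in place. The sketch below illustrates the idea. The `MatView` type and the functions `quadrant` and `view_add_inplace` are hypothetical names, and the sketch assumes the matrix lives in a single contiguous row-major buffer.

```c
#include <assert.h>

/* Hypothetical view type: a square block inside a larger row-major
 * matrix. Consecutive rows of the block are `stride` elements apart. */
typedef struct {
    int *data;   /* pointer to the block's top-left element */
    int size;    /* the block is size x size */
    int stride;  /* row length of the underlying buffer */
} MatView;

/* Return quadrant (qi, qj), with qi and qj each 0 or 1, of `m`.
 * No elements are copied; the view aliases the parent's storage. */
MatView quadrant(MatView m, int qi, int qj) {
    assert(m.size % 2 == 0);
    assert((qi == 0 || qi == 1) && (qj == 0 || qj == 1));
    int h = m.size / 2;
    MatView q = { m.data + qi * h * m.stride + qj * h, h, m.stride };
    return q;
}

/* dst += src, element by element, with no temporary matrix. */
void view_add_inplace(MatView dst, MatView src) {
    assert(dst.size == src.size);
    for (int i = 0; i < dst.size; i++) {
        for (int j = 0; j < dst.size; j++) {
            dst.data[i * dst.stride + j] += src.data[i * src.stride + j];
        }
    }
}
```

With views, splitting a matrix into A11 through A22 costs no allocation, and a sum such as C12 = M3 + M5 could be accumulated directly into the matching quadrant of C rather than built in a separate matrix and copied back.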
3. Implementations in Multiple Languages
3.1 C Implementation
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Conventional matrix multiplication
void standard_multiply(int **A, int **B, int **C, int n) {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            C[i][j] = 0;
            for (int k = 0; k < n; k++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
}

// Matrix addition: C = A + B
void matrix_add(int **A, int **B, int **C, int n) {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
}

// Matrix subtraction: C = A - B
void matrix_sub(int **A, int **B, int **C, int n) {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
}

// Allocate an n x n matrix (allocation failures are not checked here)
int** allocate_matrix(int n) {
    int **matrix = (int**)malloc(n * sizeof(int*));
    for (int i = 0; i < n; i++) {
        matrix[i] = (int*)malloc(n * sizeof(int));
    }
    return matrix;
}

// Free an n x n matrix
void free_matrix(int **matrix, int n) {
    for (int i = 0; i < n; i++) {
        free(matrix[i]);
    }
    free(matrix);
}

// Core Strassen implementation.
// Above the threshold, n must be even at every level of the recursion
// (for example, a power of 2); pad with zeros otherwise.
void strassen_multiply(int **A, int **B, int **C, int n) {
    // Threshold: use the conventional algorithm when n <= 64
    if (n <= 64) {
        standard_multiply(A, B, C, n);
        return;
    }

    int new_size = n / 2;

    // Allocate the submatrices
    int **A11 = allocate_matrix(new_size);
    int **A12 = allocate_matrix(new_size);
    int **A21 = allocate_matrix(new_size);
    int **A22 = allocate_matrix(new_size);
    int **B11 = allocate_matrix(new_size);
    int **B12 = allocate_matrix(new_size);
    int **B21 = allocate_matrix(new_size);
    int **B22 = allocate_matrix(new_size);

    // Split A and B into quadrants
    for (int i = 0; i < new_size; i++) {
        for (int j = 0; j < new_size; j++) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + new_size];
            A21[i][j] = A[i + new_size][j];
            A22[i][j] = A[i + new_size][j + new_size];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][j + new_size];
            B21[i][j] = B[i + new_size][j];
            B22[i][j] = B[i + new_size][j + new_size];
        }
    }

    // Allocate the intermediate matrices and two scratch matrices
    int **M1 = allocate_matrix(new_size);
    int **M2 = allocate_matrix(new_size);
    int **M3 = allocate_matrix(new_size);
    int **M4 = allocate_matrix(new_size);
    int **M5 = allocate_matrix(new_size);
    int **M6 = allocate_matrix(new_size);
    int **M7 = allocate_matrix(new_size);
    int **temp1 = allocate_matrix(new_size);
    int **temp2 = allocate_matrix(new_size);

    // M1 = (A11 + A22) × (B11 + B22)
    matrix_add(A11, A22, temp1, new_size);
    matrix_add(B11, B22, temp2, new_size);
    strassen_multiply(temp1, temp2, M1, new_size);

    // M2 = (A21 + A22) × B11
    matrix_add(A21, A22, temp1, new_size);
    strassen_multiply(temp1, B11, M2, new_size);

    // M3 = A11 × (B12 - B22)
    matrix_sub(B12, B22, temp2, new_size);
    strassen_multiply(A11, temp2, M3, new_size);

    // M4 = A22 × (B21 - B11)
    matrix_sub(B21, B11, temp2, new_size);
    strassen_multiply(A22, temp2, M4, new_size);

    // M5 = (A11 + A12) × B22
    matrix_add(A11, A12, temp1, new_size);
    strassen_multiply(temp1, B22, M5, new_size);

    // M6 = (A21 - A11) × (B11 + B12)
    matrix_sub(A21, A11, temp1, new_size);
    matrix_add(B11, B12, temp2, new_size);
    strassen_multiply(temp1, temp2, M6, new_size);

    // M7 = (A12 - A22) × (B21 + B22)
    matrix_sub(A12, A22, temp1, new_size);
    matrix_add(B21, B22, temp2, new_size);
    strassen_multiply(temp1, temp2, M7, new_size);

    // Compute the submatrices of C
    int **C11 = allocate_matrix(new_size);
    int **C12 = allocate_matrix(new_size);
    int **C21 = allocate_matrix(new_size);
    int **C22 = allocate_matrix(new_size);

    // C11 = M1 + M4 - M5 + M7
    matrix_add(M1, M4, temp1, new_size);
    matrix_sub(temp1, M5, temp2, new_size);
    matrix_add(temp2, M7, C11, new_size);

    // C12 = M3 + M5
    matrix_add(M3, M5, C12, new_size);

    // C21 = M2 + M4
    matrix_add(M2, M4, C21, new_size);

    // C22 = M1 - M2 + M3 + M6
    matrix_sub(M1, M2, temp1, new_size);
    matrix_add(temp1, M3, temp2, new_size);
    matrix_add(temp2, M6, C22, new_size);

    // Merge the quadrants into C
    for (int i = 0; i < new_size; i++) {
        for (int j = 0; j < new_size; j++) {
            C[i][j] = C11[i][j];
            C[i][j + new_size] = C12[i][j];
            C[i + new_size][j] = C21[i][j];
            C[i + new_size][j + new_size] = C22[i][j];
        }
    }

    // Free all temporary matrices
    free_matrix(A11, new_size);
    free_matrix(A12, new_size);
    free_matrix(A21, new_size);
    free_matrix(A22, new_size);
    free_matrix(B11, new_size);
    free_matrix(B12, new_size);
    free_matrix(B21, new_size);
    free_matrix(B22, new_size);
    free_matrix(M1, new_size);
    free_matrix(M2, new_size);
    free_matrix(M3, new_size);
    free_matrix(M4, new_size);
    free_matrix(M5, new_size);
    free_matrix(M6, new_size);
    free_matrix(M7, new_size);
    free_matrix(temp1, new_size);
    free_matrix(temp2, new_size);
    free_matrix(C11, new_size);
    free_matrix(C12, new_size);
    free_matrix(C21, new_size);
    free_matrix(C22, new_size);
}

// Test function
int main() {
    int n = 4;&nbs