P3803 【模板】多项式乘法(FFT)
题目描述
给定一个 nnn 次多项式 F(x)F(x)F(x),和一个 mmm 次多项式 G(x)G(x)G(x)。
请求出 F(x)F(x)F(x) 和 G(x)G(x)G(x) 的乘积。
输入格式
第一行两个整数 n,mn,mn,m。
接下来一行 n+1n+1n+1 个数字,从低到高表示 F(x)F(x)F(x) 的系数。
接下来一行 m+1m+1m+1 个数字,从低到高表示 G(x)G(x)G(x) 的系数。
输出格式
一行 n+m+1n+m+1n+m+1 个数字,从低到高表示 F(x)⋅G(x)F(x) \cdot G(x)F(x)⋅G(x) 的系数。
输入输出样例 #1
输入 #1
1 2
1 2
1 2 1
输出 #1
1 4 5 2
说明/提示
保证输入中的系数大于等于 000 且小于等于 999。
对于 100%100\%100% 的数据:1≤n,m≤1061 \le n, m \leq {10}^61≤n,m≤106。
思路
一、引言
快速傅立叶变换(FFT, Fast Fourier Transform)是一种高效计算离散傅立叶变换(DFT)的方法。
在多项式乘法、卷积、信号处理等领域中都有广泛应用。
二、DFT 定义
设多项式:
A(x)=a0+a1x+a2x2+⋯+an−1xn−1 A(x) = a_0 + a_1x + a_2x^2 + \cdots + a_{n-1}x^{n-1} A(x)=a0+a1x+a2x2+⋯+an−1xn−1
定义其离散傅立叶变换(DFT)为:
Ak=∑j=0n−1aj⋅ωnjk,k=0,1,2,…,n−1 A_k = \sum_{j=0}^{n-1} a_j \cdot \omega_n^{jk}, \quad k = 0,1,2,\dots,n-1 Ak=j=0∑n−1aj⋅ωnjk,k=0,1,2,…,n−1
其中:
ωn=e2πin \omega_n = e^{\frac{2\pi i}{n}} ωn=en2πi
为 nnn 次单位根。
三、逆变换 (IDFT)
由 DFT 的性质可以推出逆变换:
aj=1n∑k=0n−1Ak⋅ωn−jk a_j = \frac{1}{n} \sum_{k=0}^{n-1} A_k \cdot \omega_n^{-jk} aj=n1k=0∑n−1Ak⋅ωn−jk
四、核心思想:偶奇项分治
1. 偶奇项拆分
将 (A(x)A(x)A(x)) 拆分为偶项与奇项:
A(x)=A0(x2)+xA1(x2) A(x) = A_0(x^2) + xA_1(x^2) A(x)=A0(x2)+xA1(x2)
其中:
A0(x)=a0+a2x+a4x2+⋯ A_0(x) = a_0 + a_2x + a_4x^2 + \cdots A0(x)=a0+a2x+a4x2+⋯
A1(x)=a1+a3x+a5x2+⋯ A_1(x) = a_1 + a_3x + a_5x^2 + \cdots A1(x)=a1+a3x+a5x2+⋯
2. 在单位根处取值
令 (ωn=e2πi/n\omega_n = e^{2\pi i/n}ωn=e2πi/n),计算在单位根处的取值:
A(ωnk)=A0(ωn2k)+ωnkA1(ωn2k) A(\omega_n^k) = A_0(\omega_n^{2k}) + \omega_n^k A_1(\omega_n^{2k}) A(ωnk)=A0(ωn2k)+ωnkA1(ωn2k)
注意到:
ωn2k=ωn/2k \omega_n^{2k} = \omega_{n/2}^k ωn2k=ωn/2k
因此:
A(ωnk)=A0(ωn/2k)+ωnkA1(ωn/2k) A(\omega_n^k) = A_0(\omega_{n/2}^k) + \omega_n^k A_1(\omega_{n/2}^k) A(ωnk)=A0(ωn/2k)+ωnkA1(ωn/2k)
同理可得:
A(ωnk+n/2)=A0(ωn/2k)−ωnkA1(ωn/2k) A(\omega_n^{k + n/2}) = A_0(\omega_{n/2}^k) - \omega_n^k A_1(\omega_{n/2}^k) A(ωnk+n/2)=A0(ωn/2k)−ωnkA1(ωn/2k)
3. 分治递归结构
由上式我们可以看到,长度为 nnn 的 DFT 可拆为两个长度为 n/2n/2n/2 的 DFT:
{Ak=Ek+ωnkOk,[8pt]Ak+n/2=Ek−ωnkOk, \begin{cases} A_k = E_k + \omega_n^k O_k,[8pt] A_{k + n/2} = E_k - \omega_n^k O_k, \end{cases} {Ak=Ek+ωnkOk,[8pt]Ak+n/2=Ek−ωnkOk,
其中 (EkE_kEk)、(OkO_kOk) 分别表示偶数项与奇数项的 DFT 结果。
递推式的时间复杂度为:
T(n)=2T(n/2)+O(n)=O(nlogn) T(n) = 2T(n/2) + O(n) = O(n \log n) T(n)=2T(n/2)+O(n)=O(nlogn)
五、FFT 迭代实现
为了避免递归带来的额外开销,我们将其转化为迭代形式。
1. 位反转重排(Bit-reversal Permutation)
先将输入序列的下标按二进制反转排列。
例如 (n = 8) 时:
| 原下标 | 二进制 | 反转 | 新下标 |
|---|---|---|---|
| 0 | 000 | 000 | 0 |
| 1 | 001 | 100 | 4 |
| 2 | 010 | 010 | 2 |
| 3 | 011 | 110 | 6 |
| 4 | 100 | 001 | 1 |
| 5 | 101 | 101 | 5 |
| 6 | 110 | 011 | 3 |
| 7 | 111 | 111 | 7 |
这样在蝶形运算时每一层数据都能正确对应。
2. 蝶形合并(Butterfly Operation)
每一层合并时,计算如下:
{Ak′=Ak+w⋅Ak+len/2,[8pt]Ak+len/2′=Ak−w⋅Ak+len/2, \begin{cases} A_k' = A_k + w \cdot A_{k + len/2}, [8pt] A_{k + len/2}' = A_k - w \cdot A_{k + len/2}, \end{cases} {Ak′=Ak+w⋅Ak+len/2,[8pt]Ak+len/2′=Ak−w⋅Ak+len/2,
其中 (w=e2πi/lenw = e^{2\pi i / len}w=e2πi/len) 为旋转因子(逆变换时取共轭)。
3. 算法流程总结
- 确定长度:取最小的 n=2kn = 2^kn=2k 使 n≥nA+nBn \ge n_A + n_Bn≥nA+nB。
- 补零:将两个多项式系数扩展到长度 nnn。
- FFT 变换:分别对 A,BA,BA,B 做 FFT 得到点值形式。
- 点乘:对每个位置 iii 计算 A[i]×B[i]A[i] \times B[i]A[i]×B[i]。
- 逆 FFT:将乘积结果转回系数形式。
- 取整输出:四舍五入得到最终整数系数。
六、逆 FFT
逆 FFT 只需两步:
- 将旋转角度取负(即使用共轭单位根);
- 每个结果除以 (n)。
aj=1n∑k=0n−1Akωn−jk a_j = \frac{1}{n}\sum_{k=0}^{n-1}A_k \omega_n^{-jk} aj=n1k=0∑n−1Akωn−jk
七、复杂度分析
FFT 的递归关系为:
T(n)=2T(n/2)+O(n) T(n) = 2T(n/2) + O(n) T(n)=2T(n/2)+O(n)
解得:
T(n)=O(nlogn) T(n) = O(n \log n) T(n)=O(nlogn)
相比朴素的 O(n2)O(n^2)O(n2) 乘法,性能大幅提升。
八、FFT 多项式乘法完整示意
- 读入多项式系数;
- 扩展长度到 2 的幂;
- 分别做 FFT;
- 点值乘法;
- 逆 FFT;
- 输出系数。
九、总结
FFT 的核心思想是:
- 通过偶奇拆分降低复杂度;
- 通过单位根性质把指数分解;
- 通过bit-reversal + 蝶形合并实现高效迭代。
最终实现了从 (O(n2)O(n^2)O(n2)) 到 (O(nlogn)O(n\log n)O(nlogn)) 的突破。
题解
#include <bits/stdc++.h>
using namespace std;
const int N=4e6;
const double PI=acos(-1);struct cpx{double x, y;cpx operator+(const cpx& t)const{return {x+t.x, y+t.y};}cpx operator-(const cpx& t)const{return {x-t.x, y-t.y};}cpx operator*(const cpx& t)const{return {x*t.x-y*t.y, x*t.y+y*t.x};}
}A[N], B[N];
int R[N];void FFT(cpx A[],int n,int op){for(int i=0; i<n; ++i)R[i] = R[i/2]/2 + ((i&1)?n/2:0);for(int i=0; i<n; ++i)if(i<R[i]) swap(A[i],A[R[i]]); for(int i=2; i<=n; i<<=1){cpx w1({cos(2*PI/i),sin(2*PI/i)*op});for(int j=0; j<n; j+=i){cpx wk({1,0});for(int k=j; k<j+i/2; ++k){cpx x=A[k], y=A[k+i/2]*wk;A[k]=x+y; A[k+i/2]=x-y; wk=wk*w1;}}}
}
void solve(){int n,m; cin>>n>>m;for(int i=0; i<=n; i++)cin>>A[i].x;for(int i=0; i<=m; i++)cin>>B[i].x;for(m=n+m,n=1;n<=m;n<<=1);FFT(A,n,1); FFT(B,n,1);for(int i=0;i<n;++i)A[i]=A[i]*B[i];FFT(A,n,-1);for(int i=0;i<=m;++i){int res = (A[i].x/n+0.5);cout<<res<<' ';}cout<<endl;
}int main()
{ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);int t=1;// cin>>t;while(t--){solve();}return 0;
}
