抱歉,您的浏览器无法访问本站
本页面需要浏览器支持(启用)JavaScript
了解详情 >

用 FFT 实现多项式乘法

一点前置知识

ee 是一个实数,其中 eπi=1e^{\pi i}=-1

复数

指形如 a+bia+bi 的数,其中 i=1i=\sqrt{-1}a,ba,b 为实数。

(a+bi)+(c+di)=(a+c)+(b+d)i(a+bi)(c+di)=(ac)+(bd)i(a+bi)×(c+di)=(acbd)+(ad+bc)i(a+bi)+(c+di)=(a+c)+(b+d)i\\ (a+bi)-(c+di)=(a-c)+(b-d)i\\ (a+bi)\times(c+di)=(ac-bd)+(ad+bc)i

单位根

如果有一个数 ωn\omega_n,满足 ωnn=1\omega_n^n=1,称这个数为 nn 次单位根。
根据代数基本定理,ωn\omega_n 有且只有 nn 个,而且显然各不相同。
如果一个单位根 xx,它的 1,2,,n1,2,\dots,n 次方分别为 nn 个单位根,我们称 xxnn 次本原单位根。
e2πine^{\frac {2\pi i}n} 是一个 nn 次本原单位根。为了方便,下文中 ωn=e2iπn\omega_n = e^{\frac {2i\pi}n}
单位根的一些性质(0k<n0 \le k < n

  • ωn0=ωnn=1\omega_n^0=\omega_n^n=1
  • ωnk=ωnk+n=ωnk+n2=ω2n2k\omega_n^k=\omega_n^{k+n}=-\omega_n^{k+\frac n2}=\omega_{2n}^{2k}
  • ωnk=cos(2kπn)+sin(2kπn)i\omega_n^k=\cos(\frac{2k\pi}{n})+\sin(\frac{2k\pi}{n})i
  • ωnk=ωnnk=cos(2kπn)sin(2kπn)i\omega_n^{-k}=\omega_n^{n-k}=\cos(\frac{2k\pi}{n})-\sin(\frac{2k\pi}{n})i

多项式的表示方式

系数表示法

对于一个 nn 次多项式可以表示为 A(x)=i=0naixiA(x)=\sum_{i=0}^{n}a_ix^i
注意,nn 次多项式有 n+1n+1 项。

点值表示法

对于一个集合 {(x0,A(x0)),(x1,A(x1)),,(xn,A(xn))}\{(x_0,A(x_0)),(x_1,A(x_1)),\dots, (x_n,A(x_n))\},可以确定一个 nn 次多项式 A(x)A(x)
点值表示法虽然不适合给人看,但是可以加快乘法。
nn 次多项式 G(x)G(x)mm 次多项式 H(x)H(x)
我们设 F(x)=G(x)H(x)F(x)=G(x)\cdot H(x),则 (a,F(a))=(a,G(a)H(a))(a, F(a)) = (a, G(a)\cdot H(a))
我们取 n+m+1n+m+1aa 就能得出 F(x)F(x)
而且 aa 可以随便选,那么我们选有一定关系的 aa,就能加快运算了。

快速傅里叶变换

快速傅里叶变换(FFT)的大致流程:

  1. 将两个相乘的多项式转化成点值表示法。
  2. 将点值相乘得到答案的点值表示法。
  3. 还原出答案。

这里,我们称操作 11 为 DFT,操作 33 为 IDFT。
我们用分治的思想,把大问题变为小问题,将时间复杂度降到 O(nlogn)O(n\log n)

DFT

我们有两种分割方式,DIT 和 DIF。

DIT(Decimation in Time,按时域抽取)

f(x)=i=0n1aixif(x)=\sum_{i=0}^{n-1}a_ix^i,其中 nn22 的整数次幂(原因等下会讲)。
我们根据 ii 的奇偶性进行分组。

f(x)=i=0n1aixi=(i=0n21a2ix2i)+(i=0n21a2i+1x2i+1)=(i=0n21a2ix2i)+x(i=0n21a2i+1x2i)\begin{aligned} f(x)&=\sum_{i=0}^{n-1}a_ix^i\\ &=(\sum_{i=0}^{\frac n2-1}a_{2i}x^{2i})+(\sum_{i=0}^{\frac n2-1}a_{2i+1}x^{2i+1})\\ &=(\sum_{i=0}^{\frac n2-1}a_{2i}x^{2i})+x(\sum_{i=0}^{\frac n2-1}a_{2i+1}x^{2i})\\ \end{aligned}

g(x)=i=0n21a2ixi,h(x)=i=0n21a2i+1xig(x) =\sum_{i=0}^{\frac n2-1}a_{2i}x^{i}, h(x)=\sum_{i=0}^{\frac n2-1}a_{2i+1}x^{i}f(x)=g(x2)+x×h(x2)f(x)=g(x^2)+x\times h(x^2)

ωnk\omega_n^kωnk+n2\omega_n^{k+\frac n2} 分别代入,由单位根的性质得:

f(ωnk)=g(ωn2k)+ωnkh(ωn2k)=g(ωn2k)+ωnkh(ωn2k)f(ωnk+n2)=g(ωn2k+n)+ωnk+n2h(ωn2k+n)=g(ωn2k)ωnkh(ωn2k)\begin{aligned} f(\omega_n^k)=g(\omega_n^{2k})+\omega_n^kh(\omega_n^{2k}) &=g(\omega_{\frac n2}^{k})+\omega_n^kh(\omega_{\frac n2}^k)\\ f(\omega_n^{k+\frac n2})=g(\omega_n^{2k+n})+\omega_n^{k+\frac n2}h(\omega_n^{2k+n}) &=g(\omega_{\frac n2}^{k})-\omega_n^kh(\omega_{\frac n2}^k) \end{aligned}

这样只需要代入一半的单位根的幂就可以了。对于 h(x)h(x)g(x)g(x) 显然可以继续递归。
即我们每次将 aia_iii 的奇偶性分组,ωnk\omega_n^k 前后分成两段。
因为每次都需要严格的将多项式分成相等长度的两部分,所以 nn 必须为 22 的整次幂。

DIF(Decimation in Frequency,按频域抽取)

这次,我们对将 aia_i 分成前后两段,ωnk\omega_n^k 按奇偶分组。

f(x)=i=0n1aixi=(i=0n21aixi)+i=0n21ai+n2xi+n2=(i=0n21aixi)+xn2i=0n21ai+n2xi\begin{aligned} f(x)&=\sum_{i=0}^{n-1}a_ix^i\\ &=(\sum_{i=0}^{\frac n2-1}a_ix^i)+\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}x^{i+\frac n2}\\ &=(\sum_{i=0}^{\frac n2-1}a_ix^i)+x^{\frac n2}\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}x^i \end{aligned}

2k2\mid k,将 ωnk\omega_n^kωnk+1\omega_n^{k+1} 带入。

f(ωnk)=(i=0n21aiωnik)+ωnkn2i=0n21ai+n2ωnik=(i=0n21aiωnik)+i=0n21ai+n2ωnik=i=0n21(ai+ai+n2)ωnikf(ωnk+1)=(i=0n21aiωni(k+1))+ωn(k+1)n2i=0n21ai+n2ωni(k+1)=(i=0n21aiωni(k+1))i=0n21ai+n2ωni(k+1)=i=0n21ωni(aiai+n2)ωnik\begin{aligned} f(\omega_n^k)&=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{ik})+\omega_n^{\frac {kn}2}\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{ik}\\ &=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{ik})+\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{ik}\\ &=\sum_{i=0}^{\frac n2-1}(a_i+a_{i+\frac n2})\omega_n^{ik}\\ f(\omega_n^{k+1})&=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{i(k+1)})+\omega_n^{\frac {(k+1)n}2}\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{i(k+1)}\\ &=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{i(k+1)})-\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{i(k+1)}\\ &=\sum_{i=0}^{\frac n2-1}\omega_n^i(a_i-a_{i+\frac n2})\omega_n^{ik}\\ \end{aligned}

我们同样只需要带入一半即可。

IDFT

考虑怎么变回去。
G(x)=i=0n1F(ωni)xiG(x)=\sum_{i=0}^{n-1}F(\omega_n^i)x^i,其中 F(x)F(x) 为最终的答案。
结论一:对 G(x)G(x) 做 DFT,但用 ωnk\omega_n^{-k} 代替 ωnk\omega_n^k,结果的每一项除以 nn 后为 F(x)F(x)
结论二:对 G(x)G(x) 做 DFT,然后将后 n1n-1 项翻转,结果的每一项除以 nn 后为 F(x)F(x)

我代码中均使用结论一。

结论一证明

G(ωnk)=i=0n1F(ωni)ωnki=i=0n1ωnkij=0n1ajωnij=i=0n1j=0n1ajωnijki=j=0n1aji=0n1(ωnjk)i\begin{aligned}G(\omega_n^{-k})&=\sum_{i=0}^{n-1}F(\omega_n^i)\omega_n^{-ki}\\&=\sum_{i=0}^{n-1}\omega_n^{-ki}\sum_{j=0}^{n-1}a_j\omega_n^{ij}\\&=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j\omega_n^{ij-ki}\\&=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i\end{aligned}

S(ωna)=i=0n1(ωna)iS(\omega_n^a)=\sum_{i=0}^{n-1}(\omega_n^a)^i

a0(modn)a\equiv 0 \pmod n 时,S(ωna)=nS(\omega_n^a)=n
否则,我们错位相减

S(ωna)=i=0n1(ωna)iωnaS(ωna)=i=1n(ωna)iS(ωna)=(ωna)n(ωna)0ωna1=0\begin{aligned}S(\omega_n^a)&=\sum_{i=0}^{n-1}(\omega_n^a)^i\\\omega_n^aS(\omega_n^a)&=\sum_{i=1}^{n}(\omega_n^a)^i\\S(\omega_n^a)&=\frac{(\omega_n^a)^n-(\omega_n^a)^0}{\omega_n^a-1}=0\\\end{aligned}

也就是说

S(ωna)={n,na0,naS(\omega_n^a)=\begin{cases}n,&{n\mid a}\\0,&{n\nmid a}\end{cases}

那么代回原式
G(ωnk)=j=0n1ajS(ωnjk)=nakG(\omega_n^{-k})=\sum_{j=0}^{n-1}a_jS(\omega_n^{j-k})=na_k

综上所述,将 ωnk\omega_n^k 换成 ωnk\omega_n^{-k}G(x)G(x) 跑一遍 DFT,然后除以 nn 即可。

结论二证明

G(ωnk)=i=0n1F(ωni)ωnki=i=0n1ωnkij=0n1ajωnij=i=0n1j=0n1ajωnij+ki=j=0n1aji=0n1(ωnj+k)i\begin{aligned}G(\omega_n^k)&=\sum_{i=0}^{n-1}F(\omega_n^i)\omega_n^{ki}\\&=\sum_{i=0}^{n-1}\omega_n^{ki}\sum_{j=0}^{n-1}a_j\omega_n^{ij}\\&=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j\omega_n^{ij+ki}\\&=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j+k})^i\end{aligned}

S(ωna)=i=0n1(ωna)iS(\omega_n^a)=\sum_{i=0}^{n-1}(\omega_n^a)^i

同上得知

S(ωna)={n,na0,naS(\omega_n^a)=\begin{cases}n,&{n\mid a}\\0,&{n\nmid a}\end{cases}

带回原式得:

G(ωnk)=j=0n1ajS(ωnj+k)={na0,k=0nank,k0G(\omega_n^k)=\sum_{j=0}^{n-1}a_jS(\omega_n^{j+k})=\begin{cases}na_0,&{k=0}\\na_{n-k},&{k\neq0}\end{cases}

所以,直接对 G(x)G(x) 跑一遍 DFT,然后将后 n1n-1 位翻转,最后除以 nn 即可。

实现

我们用 洛谷P3803 为示例。
先给出一个用递归的简单实现。
我们先以 DIT 为例。

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
double x, y;
cp(double x = 0, double y = 0):x(x), y(y){}
cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
cp&operator*=(const cp&i){return *this = *this * i;}
}tmp[N], a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
void FFT(cp*f, int n, int on){//on = 1 表示 dft,on = -1 为 idft
if(n == 1)return;
cp *fl = f, *fr = f + n / 2;
for(int i = 0; i < n; i++)tmp[i] = f[i];
for(int i = 0; i < n / 2; i++)
fl[i] = tmp[i * 2], fr[i] = tmp[i * 2 + 1];
FFT(fl, n / 2, on), FFT(fr, n / 2, on);
cp wn(cos(Pi2 / n), on * sin(Pi2 / n)), w(1, 0);
for (int i = 0; i < n / 2; i++)
tmp[i] = fl[i] + w * fr[i],
tmp[i + n / 2] = fl[i] - w * fr[i], w *= wn;
for(int i = 0; i < n; i++)f[i] = tmp[i];
}
int main(){
cin.tie(0)->sync_with_stdio(0);
int n, m;
cin >> n >> m;
for(int i = 0; i <= n; i++)cin >> a[i].x;
for(int i = 0; i <= m; i++)cin >> b[i].x;
int len = 1;
while(len <= n + m)len <<= 1;
FFT(a, len, 1), FFT(b, len, 1);
for(int i = 0; i < len; i++)a[i] *= b[i];
FFT(a, len, -1);
for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}

每次我们都要反复复制,常数太大了。那么提前把每个元素放在最后的位置上,就好了。
我们打表可以发现,把每个位置的二进制翻转得到的数就是最后的下标。

1
2
3
for(int i = 0; i < n; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0);
//递推预处理

另外 flkfl_k 复制在 fkf_kfrkfr_k 复制在 fk+n2f_{k+\frac n2},没有重叠,那么可以直接共用。
那么可以变成

递归 DIT
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
double x, y;
cp(double x = 0, double y = 0):x(x), y(y){}
cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
cp&operator*=(const cp&i){return *this = *this * i;}
}a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
void FFT(cp*f, int n, int on){//on = 1 表示 dft,on = -1 为 idft
if(n == 1)return;
int half = n >> 1;
FFT(f, half, on), FFT(f + half, half, on);
cp wn(cos(Pi2 / n), on * sin(Pi2 / n)), w(1, 0), z;
for (int i = 0; i < half; i++)
z = w * f[i + half], f[i + half] = f[i] - z, f[i] = f[i] + z, w *= wn;
}
int rev[N];
int main(){
cin.tie(0)->sync_with_stdio(0);
int n, m;
cin >> n >> m;
for(int i = 0; i <= n; i++)cin >> a[i].x;
for(int i = 0; i <= m; i++)cin >> b[i].x;
int len = 1;
while(len <= n + m)len <<= 1;
for(int i = 0; i < len; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
for(int i = 0; i < len; i++)if(i < rev[i])swap(a[i], a[rev[i]]);
for(int i = 0; i < len; i++)if(i < rev[i])swap(b[i], b[rev[i]]);
FFT(a, len, 1), FFT(b, len, 1);
for(int i = 0; i < len; i++)a[i] *= b[i];
for(int i = 0; i < len; i++)if(i < rev[i])swap(a[i], a[rev[i]]);
FFT(a, len, -1);
for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}

同样的,我们可以改写为循环迭代。

迭代 DIT
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
double x, y;
cp(double x = 0, double y = 0):x(x), y(y){}
cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
cp&operator*=(const cp&i){return *this = *this * i;}
}a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
int rev[N];
void FFT(cp*f, int n, int on){//on = 1 表示 dft,on = -1 为 idft
for(int i = 0; i < n; i++)if(i < rev[i])swap(f[i], f[rev[i]]);
for(int len, p = 2; len = p >> 1, p <= n; p <<= 1){//枚举区间长度
cp wn = cp(cos(Pi2 / p), on * sin(Pi2 / p));
for(int l = 0; l < n; l += p){//枚举区间左端点
cp w = cp(1, 0), z;
for(int i = l; i < l + len; i++)
z = f[i + len] * w, f[i + len] = f[i] - z, f[i] = f[i] + z, w *= wn;
}
}
}
int main(){
cin.tie(0)->sync_with_stdio(0);
int n, m;
cin >> n >> m;
for(int i = 0; i <= n; i++)cin >> a[i].x;
for(int i = 0; i <= m; i++)cin >> b[i].x;
int len = 1;
while(len <= n + m)len <<= 1;
for(int i = 0; i < len; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
FFT(a, len, 1), FFT(b, len, 1);
for(int i = 0; i < len; i++)a[i] *= b[i];
FFT(a, len, -1);
for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}

由于 DIF 是按 kk 的奇偶性分组,所以我们要给答案做 rev。

递归 DIF
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
double x, y;
cp(double x = 0, double y = 0):x(x), y(y){}
cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
cp&operator*=(const cp&i){return *this = *this * i;}
}a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
void FFT(cp*f, int n, int on){//on = 1 表示 dft,on = -1 为 idft
if(n == 1)return;
int half = n >> 1;
cp wn(cos(Pi2 / n), on * sin(Pi2 / n)), w(1, 0), x, y;
for (int i = 0; i < half; w *= wn, i++)
x = f[i], y = f[i + half], f[i] = x + y, f[i + half] = (x - y) * w;
FFT(f, half, on), FFT(f + half, half, on);
}
int rev[N];
int main(){
cin.tie(0)->sync_with_stdio(0);
int n, m;
cin >> n >> m;
for(int i = 0; i <= n; i++)cin >> a[i].x;
for(int i = 0; i <= m; i++)cin >> b[i].x;
int len = 1;
while(len <= n + m)len <<= 1;
for(int i = 0; i < len; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
FFT(a, len, 1), FFT(b, len, 1);
for(int i = 0; i < len; i++)if(i < rev[i])swap(a[i], a[rev[i]]);
for(int i = 0; i < len; i++)if(i < rev[i])swap(b[i], b[rev[i]]);
for(int i = 0; i < len; i++)a[i] *= b[i];
FFT(a, len, -1);
for(int i = 0; i < len; i++)if(i < rev[i])swap(a[i], a[rev[i]]);
for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}
迭代 DIF
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
double x, y;
cp(double x = 0, double y = 0):x(x), y(y){}
cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
cp&operator*=(const cp&i){return *this = *this * i;}
}a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
int rev[N];
void FFT(cp*f, int n, int on){//on = 1 表示 dft,on = -1 为 idft
for(int len, p = n; len = p >> 1, p > 1; p >>= 1){//从大到小枚举
cp wn = cp(cos(Pi2 / p), on * sin(Pi2 / p));
for(int l = 0; l < n; l += p){//枚举区间左端点
cp w = cp(1, 0), x, y;
for(int i = l; i < l + len; w *= wn, i++)
x = f[i], y = f[i + len], f[i] = x + y, f[i + len] = (x - y) * w;
}
}
for(int i = 0; i < n; i++)if(i < rev[i])swap(f[i], f[rev[i]]);
}
int main(){
cin.tie(0)->sync_with_stdio(0);
int n, m;
cin >> n >> m;
for(int i = 0; i <= n; i++)cin >> a[i].x;
for(int i = 0; i <= m; i++)cin >> b[i].x;
int len = 1;
while(len <= n + m)len <<= 1;
for(int i = 0; i < len; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
FFT(a, len, 1), FFT(b, len, 1);
for(int i = 0; i < len; i++)a[i] *= b[i];
FFT(a, len, -1);
for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}

优化

简单优化

众所周知 (a+bi)2=a2b2+2abi(a+bi)^2 = a^2-b^2+2abi,故我们可以将两个多项式分别放在同一多项式的实部和虚部,FFT 后再计算结果的平方,虚部的一半就是答案。

DIF 需要最后 rev,DIT 需要开始时 rev,我们用 DIF 做 DFT,DIT 做 IDFT。就可以避免用 rev。

递归的常数太大了,我们可以用模版递归,像这样。

模版递归
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
double x, y;
cp(double x = 0, double y = 0):x(x), y(y){}
cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
cp&operator*=(const cp&i){return *this = *this * i;}
}a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
template<const int n>void DFT(cp*f){//要分开写常数小
const int half = n >> 1;
DFT<half>(f), DFT<half>(f + half);
cp wn(cos(Pi2 / n), sin(Pi2 / n)), w(1, 0), z;
for (int i = 0; i < half; i++)
z = w * f[i + half], f[i + half] = f[i] - z, f[i] = f[i] + z, w *= wn;
}
template<>void DFT<1>(cp*f){}template<>void DFT<0>(cp*f){}//特化
#define Case(x, fu) case x: fu<x>(f);break;
#define Runfft(x) switch(n){\
Case(1<<1,x)Case(1<<2,x)Case(1<<3,x)Case(1<<4,x)\
Case(1<<5,x)Case(1<<6,x)Case(1<<7,x)Case(1<<8,x)\
Case(1<<9,x)Case(1<<10,x)Case(1<<11,x)Case(1<<12,x)\
Case(1<<13,x)Case(1<<14,x)Case(1<<15,x)Case(1<<16,x)\
Case(1<<17,x)Case(1<<18,x)Case(1<<19,x)Case(1<<20,x)Case(1<<21,x)}
int rev[N];
void rundft(cp *f,const int& n){
for(int i = 0; i < n; i++)if(i < rev[i])swap(f[i], f[rev[i]]);
Runfft(DFT);
}
template<const int n>void IDFT(cp*f){
const int half = n >> 1;
IDFT<half>(f), IDFT<half>(f + half);
cp wn(cos(Pi2 / n), -sin(Pi2 / n)), w(1, 0), z;
for (int i = 0; i < half; i++)
z = w * f[i + half], f[i + half] = f[i] - z, f[i] = f[i] + z, w *= wn;
}
template<>void IDFT<1>(cp*f){}template<>void IDFT<0>(cp*f){}//特化
void runidft(cp *f,const int& n){
for(int i = 0; i < n; i++)if(i < rev[i])swap(f[i], f[rev[i]]);
Runfft(IDFT);
}
int main(){
cin.tie(0)->sync_with_stdio(0);
int n, m;
cin >> n >> m;
for(int i = 0; i <= n; i++)cin >> a[i].x;
for(int i = 0; i <= m; i++)cin >> b[i].x;
int len = 1;
while(len <= n + m)len <<= 1;
for(int i = 0; i < len; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
rundft(a, len), rundft(b, len);
for(int i = 0; i < len; i++)a[i] *= b[i];
runidft(a, len);
for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}

分裂基

我们称每次一分为二为基2,相应的我们也有基4,基8。
这里只介绍 DIT 的做法,但两种代码都会给出,相信读者能想明白 DIF 的做法。
分裂基每次把基2的第 22 项用同样办法拆开,变成:

h(x)=h1(x2)+xh2(x2)f(x)=g(x2)+xh1(x4)+x3h2(x4)h(x) = h_1(x^2) + xh_2(x^2)\\ f(x) = g(x^2) + xh_1(x^4) + x^3h_2(x^4)

我们分别将 ωnk,ωnk+n4,ωnk+n2,ωnk+3n4\omega_n^k,\omega_n^{k+\frac n4},\omega_n^{k+\frac n2},\omega_n^{k+\frac{3n}4} 带入。

f(ωnk)=g(ωn2k)+ωnkh1(ωn4k)+ωn3kh2(ωn4k)f(ωnk+n4)=g(ωn2k+n2)(ωnkh1(ωn4k)ωn3kh2(ωn4k))if(ωnk+n2)=g(ωn2k)ωnkh1(ωn4k)ωn3kh2(ωn4k)f(ωnk+3n4)=g(ωn2k+n2)+(ωnkh1(ωn4k)ωn3kh2(ωn4k))i\begin{aligned} f(\omega_n^k)&=g(\omega_n^{2k})+\omega_n^kh_1(\omega_n^{4k})+\omega_n^3kh_2(\omega_n^{4k})\\ f(\omega_n^{k+\frac n4})&=g(\omega_n^{2k+\frac n2})-(\omega_n^kh_1(\omega_n^{4k})-\omega_n^{3k}h_2(\omega_n^{4k}))i\\ f(\omega_n^{k+\frac n2})&=g(\omega_n^{2k})-\omega_n^kh_1(\omega_n^{4k})-\omega_n^{3k}h2(\omega_n^{4k})\\ f(\omega_n^{k+\frac {3n}4})&=g(\omega_n^{2k+\frac n2})+(\omega_n^kh_1(\omega_n^{4k})-\omega_n^{3k}h_2(\omega_n^{4k}))i\\ \end{aligned}

分裂基适用于序列长度为 2n2^n 的 FFT,并且是运算次数比基8更少,并且更灵活。
运算次数的比较见 ooura 的博客
由于迭代版本过于复杂,所以只给出递归版本。

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include<bits/stdc++.h>
using namespace std;
#define Case(x, fu) case x: fu<x>(f);break;
#define Runfft(x) switch(n){\
Case(1<<1,x)Case(1<<2,x)Case(1<<3,x)Case(1<<4,x)\
Case(1<<5,x)Case(1<<6,x)Case(1<<7,x)Case(1<<8,x)\
Case(1<<9,x)Case(1<<10,x)Case(1<<11,x)Case(1<<12,x)\
Case(1<<13,x)Case(1<<14,x)Case(1<<15,x)Case(1<<16,x)\
Case(1<<17,x)Case(1<<18,x)Case(1<<19,x)Case(1<<20,x)Case(1<<21,x)}
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;

struct cp{
double x, y;
cp(double x = 0, double y = 0):x(x), y(y){}
cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
cp&operator+=(const cp&i){return this->x += i.x, this->y += i.y, *this;}
cp&operator-=(const cp&i){return this->x -= i.x, this->y -= i.y, *this;}
cp&operator*=(const cp&i){return *this = *this * i;}
}a[N];//实现复数类,x 代表实部,y 代表虚部。

template<const int n>void DFT(cp*f){//DIF
const int half = n >> 1, quar = n >> 2;
cp w(1, 0), wn(cos(Pi2 / n), sin(Pi2 / n)), w3, x, y;
cp *a1 = &f[0], *a2 = a1 + quar, *a3 = a2 + quar, *a4 = a3 + quar;
for(int i = 0; i < quar; i++){
w3 = w * w * w, x = *a1 - *a3, y = *a2 - *a4, y = cp(y.y, -y.x);
*a1 += *a3, *a2 += *a4, *a3 = (x - y) * w, *a4 = (x + y) * w3;
a1++, a2++, a3++, a4++, w *= wn;
}
DFT<half>(f), DFT<quar>(f + half), DFT<quar>(f + half + quar);
}
template<>void DFT<2>(cp*f){cp x = f[0], y = f[1];f[0] = x + y, f[1] = x - y;}
template<>void DFT<1>(cp*f){}template<>void DFT<0>(cp*f){}//特化
void rundft(cp *f,const int& n){Runfft(DFT);}

template<const int n>void IDFT(cp*f){//DIT
const int half = n >> 1, quar = n >> 2;
IDFT<half>(f), IDFT<quar>(f + half), IDFT<quar>(f + half + quar);
cp wn(cos(Pi2 / n), -sin(Pi2 / n)), w(1, 0), w3, tmp1, tmp2, x, y;
cp *a1 = &f[0], *a2 = a1 + quar, *a3 = a2 + quar, *a4 = a3 + quar;
for (int i = 0; i < quar; i++){
w3 = w * w * w, tmp1 = w * *a3, tmp2 = w3 * *a4;
x = tmp1 + tmp2, y = tmp1 - tmp2, y = cp(y.y, -y.x);
*a4 = *a2 - y, *a3 = *a1 - x, *a2 += y, *a1 += x;
a1++, a2++, a3++, a4++, w *= wn;
}
}
template<>void IDFT<2>(cp*f){cp x = f[0], y = f[1];f[0] = x + y, f[1] = x - y;}
template<>void IDFT<1>(cp*f){}template<>void IDFT<0>(cp*f){}//特化
void runidft(cp *f,const int& n){Runfft(IDFT);}

int main(){
cin.tie(0)->sync_with_stdio(0);
int n, m;
cin >> n >> m;
for(int i = 0; i <= n; i++)cin >> a[i].x;
for(int i = 0; i <= m; i++)cin >> a[i].y;
int len = 1;
while(len <= n + m)len <<= 1;
rundft(a, len);
for(int i = 0; i < len; i++)a[i] *= a[i];
runidft(a, len);
double inv = 0.5 / len;
for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].y * inv + 0.49) << " ";
}

参考资料

https://oi-wiki.org/math/poly/fft/
https://www.bilibili.com/opus/785022478912061446
https://www.luogu.com.cn/article/rj58c2eb
https://charleswu.site/archives/3065