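Polynomial multiplication with FFT
Some prerequisites
$e$ is Euler's number, a real constant. Euler's formula states $e^{i\theta}=\cos\theta+i\sin\theta$; in particular $e^{\pi i}=-1$.
Complex numbers
A complex number has the form $a+bi$, where $i=\sqrt{-1}$ and $a,b$ are real. Addition, subtraction and multiplication work as follows: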
$$
\begin{aligned}
(a+bi)+(c+di)&=(a+c)+(b+d)i\\
(a+bi)-(c+di)&=(a-c)+(b-d)i\\
(a+bi)\times(c+di)&=(ac-bd)+(ad+bc)i
\end{aligned}
$$
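These three rules map directly to code. Below is a minimal sketch that uses a hand-rolled type. The name `Complex` is my own; the programs later in this article use a similar struct called `cp`. `std::complex<double>` from `<complex>` provides the same operations ready-made.

```cpp
#include <cassert>

// Minimal complex type implementing the three rules above (hypothetical name).
struct Complex {
    double re, im;  // a + bi is stored as {a, b}
    Complex operator+(const Complex& o) const { return {re + o.re, im + o.im}; }
    Complex operator-(const Complex& o) const { return {re - o.re, im - o.im}; }
    // (a+bi)(c+di) = (ac - bd) + (ad + bc)i
    Complex operator*(const Complex& o) const {
        return {re * o.re - im * o.im, re * o.im + im * o.re};
    }
    // Exact comparison: only meaningful for small, exactly representable values.
    bool operator==(const Complex& o) const { return re == o.re && im == o.im; }
};
```

For example, `Complex{1, 2} * Complex{3, 4}` gives $-5+10i$, and `Complex{0, 1} * Complex{0, 1}` gives $-1$.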
Roots of unity
A number $\omega$ with $\omega^n=1$ is called an $n$-th root of unity.
By the fundamental theorem of algebra, $x^n-1$ has exactly $n$ complex roots counted with multiplicity. These roots are pairwise distinct, because $x^n-1$ and its derivative $nx^{n-1}$ have no common root. So there are exactly $n$ different $n$-th roots of unity.
A root of unity $x$ is called a primitive $n$-th root of unity if its powers $x^1,x^2,\dots,x^n$ are exactly the $n$ different $n$-th roots of unity.
$e^{\frac{2\pi i}{n}}$ is a primitive $n$-th root of unity. For convenience, from now on $\omega_n=e^{\frac{2\pi i}{n}}$.
Some properties of roots of unity ($0\le k<n$; the term $-\omega_n^{k+\frac n2}$ requires $n$ to be even):
$\omega_n^0=\omega_n^n=1$
$\omega_n^k=\omega_n^{k+n}=-\omega_n^{k+\frac n2}=\omega_{2n}^{2k}$
$\omega_n^k=\cos\frac{2k\pi}{n}+i\sin\frac{2k\pi}{n}$
$\omega_n^{-k}=\omega_n^{n-k}=\cos\frac{2k\pi}{n}-i\sin\frac{2k\pi}{n}$
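These identities are easy to sanity-check numerically. The sketch below is only a check and is not part of the FFT. The helper names `omega`, `approx_eq` and `check_properties` are hypothetical. `omega(n, k)` computes $\omega_n^k=\cos\frac{2k\pi}n+i\sin\frac{2k\pi}n$ with `std::polar`.

```cpp
#include <cassert>
#include <cmath>
#include <complex>

using cd = std::complex<double>;

// ω_n^k = e^{2πik/n} = cos(2kπ/n) + i sin(2kπ/n); negative k is allowed.
cd omega(int n, int k) {
    const double PI = std::acos(-1.0);
    return std::polar(1.0, 2 * PI * k / n);
}

// Floating-point comparison with an absolute tolerance.
bool approx_eq(cd a, cd b, double eps = 1e-9) { return std::abs(a - b) < eps; }

// Checks the listed properties for every 0 <= k < n (n must be even).
bool check_properties(int n) {
    assert(n > 0 && n % 2 == 0);
    if (!approx_eq(omega(n, 0), 1.0) || !approx_eq(omega(n, n), 1.0)) return false;
    for (int k = 0; k < n; ++k) {
        cd w = omega(n, k), p = 1.0;
        for (int j = 0; j < n; ++j) p *= w;
        if (!approx_eq(p, 1.0)) return false;                          // (ω_n^k)^n = 1
        if (!approx_eq(w, omega(n, k + n))) return false;              // ω_n^k = ω_n^{k+n}
        if (!approx_eq(w, -omega(n, k + n / 2))) return false;         // ω_n^k = -ω_n^{k+n/2}
        if (!approx_eq(w, omega(2 * n, 2 * k))) return false;         // ω_n^k = ω_{2n}^{2k}
        if (!approx_eq(omega(n, -k), omega(n, n - k))) return false;   // ω_n^{-k} = ω_n^{n-k}
        if (!approx_eq(omega(n, -k), std::conj(w))) return false;      // = cos - i sin
    }
    return true;
}
```

For example, `omega(4, 1)` is $i$ up to rounding.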
Representations of a polynomial
Coefficient representation
A polynomial of degree $n$ can be written as $A(x)=\sum_{i=0}^{n}a_ix^i$.
Note that a polynomial of degree $n$ has $n+1$ coefficients.
Point-value representation
A set $\{(x_0,A(x_0)),(x_1,A(x_1)),\dots,(x_n,A(x_n))\}$ with pairwise distinct $x_i$ uniquely determines a polynomial $A(x)$ of degree at most $n$.
The point-value form is not meant for human eyes, but it makes multiplication fast.
Let $G(x)$ be a polynomial of degree $n$ and $H(x)$ a polynomial of degree $m$.
If $F(x)=G(x)\cdot H(x)$, then $(a,F(a))=(a,G(a)\cdot H(a))$. In point-value form, the product therefore costs one multiplication per point.
Since $F(x)$ has degree $n+m$, taking $n+m+1$ distinct values of $a$ determines $F(x)$.
The points $a$ can be chosen freely. Picking points that are related to each other in a useful way speeds up the computation.
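Here is a small sketch of the identity $F(a)=G(a)H(a)$ on integer polynomials. `multiply_naive` is the ordinary $O(nm)$ coefficient product. The check evaluates both sides at the $n+m+1$ points $a=0,1,\dots,n+m$. The helper names and the choice of sample points are illustrative assumptions; FFT uses powers of $\omega$ as the points instead.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Schoolbook product in coefficient form, O(nm): f[i + j] += g[i] * h[j].
std::vector<long long> multiply_naive(const std::vector<long long>& g,
                                      const std::vector<long long>& h) {
    assert(!g.empty() && !h.empty());
    std::vector<long long> f(g.size() + h.size() - 1, 0);
    for (std::size_t i = 0; i < g.size(); ++i)
        for (std::size_t j = 0; j < h.size(); ++j) f[i + j] += g[i] * h[j];
    return f;
}

// Horner's rule: p(x) = (...(p_d x + p_{d-1}) x + ...) x + p_0.
long long eval(const std::vector<long long>& p, long long x) {
    long long r = 0;
    for (std::size_t i = p.size(); i-- > 0;) r = r * x + p[i];
    return r;
}

// Checks (a, F(a)) = (a, G(a)·H(a)) at the n+m+1 points a = 0, 1, ..., n+m.
bool point_values_agree(const std::vector<long long>& g, const std::vector<long long>& h) {
    const std::vector<long long> f = multiply_naive(g, h);
    for (long long a = 0; a < static_cast<long long>(f.size()); ++a)
        if (eval(f, a) != eval(g, a) * eval(h, a)) return false;
    return true;
}
```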
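Fast Fourier transform
The overall procedure of the fast Fourier transform (FFT):
1. Convert both factors to point-value form.
2. Multiply the point values to get the point-value form of the product.
3. Recover the coefficients of the product.
Step 1 is called the DFT and step 3 the IDFT.
Divide and conquer turns each large problem into smaller ones, which brings the time complexity down to $O(n\log n)$.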
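Step 1 can be written down directly, before any speed-up: a length-$n$ DFT just evaluates the polynomial at $\omega_n^0,\omega_n^1,\dots,\omega_n^{n-1}$. The $O(n^2)$ sketch below computes exactly the vector that FFT produces in $O(n\log n)$, which makes it a convenient reference for testing. `dft_naive` and `all_close` are names I chose.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

using cd = std::complex<double>;

// Evaluates A(x) = sum_i a[i] x^i at x = ω_n^k for k = 0..n-1, with ω_n = e^{2πi/n}.
std::vector<cd> dft_naive(const std::vector<cd>& a) {
    const int n = static_cast<int>(a.size());
    const double PI = std::acos(-1.0);
    std::vector<cd> y(n);
    for (int k = 0; k < n; ++k) {
        const cd x = std::polar(1.0, 2 * PI * k / n);
        cd xi = 1.0;  // runs through x^0, x^1, ..., x^{n-1}
        for (int i = 0; i < n; ++i) {
            y[k] += a[i] * xi;
            xi *= x;
        }
    }
    return y;
}

// True if the two vectors agree entrywise up to eps.
bool all_close(const std::vector<cd>& u, const std::vector<cd>& v, double eps = 1e-9) {
    if (u.size() != v.size()) return false;
    for (std::size_t i = 0; i < u.size(); ++i)
        if (std::abs(u[i] - v[i]) > eps) return false;
    return true;
}
```

For $a=(1,2,3,4)$ we have $\omega_4=i$, and the result is $(10,\,-2-2i,\,-2,\,-2+2i)$.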
DFT
There are two ways to split the problem: DIT and DIF.
DIT (Decimation in Time)
Let $f(x)=\sum_{i=0}^{n-1}a_ix^i$, where $n$ is a power of $2$ (the reason is given below).
Group the terms by the parity of $i$:
$$
\begin{aligned}
f(x)&=\sum_{i=0}^{n-1}a_ix^i\\
&=(\sum_{i=0}^{\frac n2-1}a_{2i}x^{2i})+(\sum_{i=0}^{\frac n2-1}a_{2i+1}x^{2i+1})\\
&=(\sum_{i=0}^{\frac n2-1}a_{2i}x^{2i})+x(\sum_{i=0}^{\frac n2-1}a_{2i+1}x^{2i})
\end{aligned}
$$
Let $g(x)=\sum_{i=0}^{\frac n2-1}a_{2i}x^{i}$ and $h(x)=\sum_{i=0}^{\frac n2-1}a_{2i+1}x^{i}$. Then $f(x)=g(x^2)+x\,h(x^2)$.
For $0\le k<\frac n2$, substitute $\omega_n^k$ and $\omega_n^{k+\frac n2}$. The properties of roots of unity give $\omega_n^{2k}=\omega_{\frac n2}^k$, $\omega_n^{2k+n}=\omega_n^{2k}$ and $\omega_n^{k+\frac n2}=-\omega_n^k$, so:
$$
\begin{aligned}
f(\omega_n^k)&=g(\omega_n^{2k})+\omega_n^kh(\omega_n^{2k})\\
&=g(\omega_{\frac n2}^{k})+\omega_n^kh(\omega_{\frac n2}^k)\\
f(\omega_n^{k+\frac n2})&=g(\omega_n^{2k+n})+\omega_n^{k+\frac n2}h(\omega_n^{2k+n})\\
&=g(\omega_{\frac n2}^{k})-\omega_n^kh(\omega_{\frac n2}^k)
\end{aligned}
$$
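So only half of the powers of the root of unity need to be substituted. The values $g(\omega_{\frac n2}^k)$ and $h(\omega_{\frac n2}^k)$ for $0\le k<\frac n2$ are exactly the length-$\frac n2$ DFTs of $g$ and $h$, which can be computed by recursing on $g(x)$ and $h(x)$.
In other words, each level groups the $a_i$ by the parity of $i$ and splits the points $\omega_n^k$ into a first half and a second half.
Every step must split the polynomial into two parts of exactly equal length, so $n$ must be a power of $2$.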
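Here is a direct transcription of the DIT recurrence. It is a sketch that uses `std::vector<std::complex<double>>` instead of the `cp` arrays used later. `fft_dit`, `dft_naive` and `matches_naive` are hypothetical names. The function allocates new vectors at every level, so it shows the structure rather than the speed. `matches_naive` compares it with direct evaluation.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

using cd = std::complex<double>;

// Recursive DIT: f(x) = g(x^2) + x h(x^2), g = even-index and h = odd-index coefficients.
// For 0 <= k < n/2 and ω_n = e^{2πi/n}:
//   f(ω_n^k)       = g(ω_{n/2}^k) + ω_n^k h(ω_{n/2}^k)
//   f(ω_n^{k+n/2}) = g(ω_{n/2}^k) - ω_n^k h(ω_{n/2}^k)
std::vector<cd> fft_dit(const std::vector<cd>& a) {
    const int n = static_cast<int>(a.size());
    assert(n > 0 && (n & (n - 1)) == 0);  // n must be a power of two
    if (n == 1) return a;                 // f(x) = a_0 is constant
    std::vector<cd> even(n / 2), odd(n / 2);
    for (int i = 0; i < n / 2; ++i) {
        even[i] = a[2 * i];
        odd[i] = a[2 * i + 1];
    }
    const std::vector<cd> G = fft_dit(even), H = fft_dit(odd);  // length-n/2 DFTs
    const double PI = std::acos(-1.0);
    std::vector<cd> F(n);
    for (int k = 0; k < n / 2; ++k) {
        const cd t = std::polar(1.0, 2 * PI * k / n) * H[k];  // ω_n^k h(ω_{n/2}^k)
        F[k] = G[k] + t;
        F[k + n / 2] = G[k] - t;
    }
    return F;
}

// O(n^2) reference: direct evaluation at ω_n^k.
std::vector<cd> dft_naive(const std::vector<cd>& a) {
    const int n = static_cast<int>(a.size());
    const double PI = std::acos(-1.0);
    std::vector<cd> y(n);
    for (int k = 0; k < n; ++k)
        for (int i = 0; i < n; ++i)
            y[k] += a[i] * std::polar(1.0, 2 * PI * (static_cast<long long>(i) * k % n) / n);
    return y;
}

// Compares fft_dit with dft_naive on a fixed test input of length n.
bool matches_naive(int n) {
    std::vector<cd> a(n);
    for (int i = 0; i < n; ++i) a[i] = cd(i % 5, (3 * i) % 7);
    const std::vector<cd> x = fft_dit(a), y = dft_naive(a);
    for (int k = 0; k < n; ++k)
        if (std::abs(x[k] - y[k]) > 1e-6) return false;
    return true;
}
```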
DIF (Decimation in Frequency)
This time we split the $a_i$ into a first half and a second half, and group the points $\omega_n^k$ by the parity of $k$.
$$
\begin{aligned}
f(x)&=\sum_{i=0}^{n-1}a_ix^i\\
&=(\sum_{i=0}^{\frac n2-1}a_ix^i)+\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}x^{i+\frac n2}\\
&=(\sum_{i=0}^{\frac n2-1}a_ix^i)+x^{\frac n2}\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}x^i
\end{aligned}
$$
Let $k$ be even ($2\mid k$) and substitute $\omega_n^k$ and $\omega_n^{k+1}$. Since $\omega_n^{\frac n2}=-1$, we have $\omega_n^{\frac{kn}2}=1$ and $\omega_n^{\frac{(k+1)n}2}=-1$:
$$
\begin{aligned}
f(\omega_n^k)&=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{ik})+\omega_n^{\frac{kn}2}\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{ik}\\
&=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{ik})+\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{ik}\\
&=\sum_{i=0}^{\frac n2-1}(a_i+a_{i+\frac n2})\omega_n^{ik}\\
f(\omega_n^{k+1})&=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{i(k+1)})+\omega_n^{\frac{(k+1)n}2}\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{i(k+1)}\\
&=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{i(k+1)})-\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{i(k+1)}\\
&=\sum_{i=0}^{\frac n2-1}\omega_n^i(a_i-a_{i+\frac n2})\omega_n^{ik}
\end{aligned}
$$
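Since $k$ is even, $\omega_n^{ik}=\omega_{\frac n2}^{i\cdot\frac k2}$. Both sums are therefore length-$\frac n2$ DFTs, of the sequences $a_i+a_{i+\frac n2}$ and $\omega_n^i(a_i-a_{i+\frac n2})$ respectively. Again, only half of the points need to be substituted.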
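Below is a sketch of the DIF recurrence that works in place on a `std::complex<double>` array. The function names are my own. Each level sends even $k$ to the first half and odd $k$ to the second half. As a result, position $p$ of the output holds $f(\omega_n^{\text{rev}(p)})$, where $\text{rev}(p)$ reverses the $\log_2 n$ bits of $p$. `dif_output_is_bit_reversed` checks exactly that against direct evaluation.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

using cd = std::complex<double>;

// Recursive in-place DIF on f[0..n), ω_n = e^{2πi/n}, 0 <= i < n/2:
//   f[i]       <- a_i + a_{i+n/2}           (its length-n/2 DFT at j gives f(ω_n^{2j}))
//   f[i + n/2] <- ω_n^i (a_i - a_{i+n/2})  (its length-n/2 DFT at j gives f(ω_n^{2j+1}))
// then both halves are transformed recursively; the output is in bit-reversed order.
void fft_dif(cd* f, int n) {
    if (n == 1) return;
    const int half = n / 2;
    const double PI = std::acos(-1.0);
    for (int i = 0; i < half; ++i) {
        const cd x = f[i], y = f[i + half];
        f[i] = x + y;
        f[i + half] = (x - y) * std::polar(1.0, 2 * PI * i / n);
    }
    fft_dif(f, half);
    fft_dif(f + half, half);
}

// rev[p] = p with its log2(n) bits reversed.
std::vector<int> bit_reverse(int n) {
    std::vector<int> rev(n, 0);
    for (int i = 1; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0);
    return rev;
}

// Checks that position p of the DIF output holds f(ω_n^{rev[p]}).
bool dif_output_is_bit_reversed(int n) {
    assert(n > 0 && (n & (n - 1)) == 0);
    const double PI = std::acos(-1.0);
    std::vector<cd> a(n);
    for (int i = 0; i < n; ++i) a[i] = cd((i * i) % 11, i % 3);
    std::vector<cd> f = a;
    fft_dif(f.data(), n);
    const std::vector<int> rev = bit_reverse(n);
    for (int p = 0; p < n; ++p) {
        cd expect = 0;  // direct evaluation of f at ω_n^{rev[p]}
        for (int i = 0; i < n; ++i)
            expect += a[i] * std::polar(1.0, 2 * PI * (static_cast<long long>(i) * rev[p] % n) / n);
        if (std::abs(f[p] - expect) > 1e-6) return false;
    }
    return true;
}
```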
IDFT
Now consider how to convert back.
Let $F(x)=\sum_{i=0}^{n-1}a_ix^i$ be the final answer, and let $G(x)=\sum_{i=0}^{n-1}F(\omega_n^i)x^i$, the polynomial whose coefficients are the point values of $F$.
Conclusion 1: run the DFT on $G(x)$, but use $\omega_n^{-k}$ in place of $\omega_n^k$. Dividing every entry of the result by $n$ gives the coefficients of $F(x)$.
Conclusion 2: run the ordinary DFT on $G(x)$, then reverse the last $n-1$ entries. Dividing every entry by $n$ gives the coefficients of $F(x)$.
All of my code uses conclusion 1.
Proof of conclusion 1
$$
\begin{aligned}
G(\omega_n^{-k})&=\sum_{i=0}^{n-1}F(\omega_n^i)\omega_n^{-ki}\\
&=\sum_{i=0}^{n-1}\omega_n^{-ki}\sum_{j=0}^{n-1}a_j\omega_n^{ij}\\
&=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j\omega_n^{ij-ki}\\
&=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i
\end{aligned}
$$
Write $S(\omega_n^a)=\sum_{i=0}^{n-1}(\omega_n^a)^i$.
When $a\equiv 0\pmod n$, every term equals $1$, so $S(\omega_n^a)=n$. Otherwise $\omega_n^a\ne1$. Multiply by $\omega_n^a$ and subtract (the usual geometric-series trick):
$$
\begin{aligned}
S(\omega_n^a)&=\sum_{i=0}^{n-1}(\omega_n^a)^i\\
\omega_n^aS(\omega_n^a)&=\sum_{i=1}^{n}(\omega_n^a)^i\\
S(\omega_n^a)&=\frac{(\omega_n^a)^n-(\omega_n^a)^0}{\omega_n^a-1}=\frac{1-1}{\omega_n^a-1}=0
\end{aligned}
$$
That is,
$$
S(\omega_n^a)=\begin{cases}n,&n\mid a\\0,&n\nmid a\end{cases}
$$
Substitute back. For $0\le j,k<n$, $n\mid(j-k)$ holds only when $j=k$, so
$$
G(\omega_n^{-k})=\sum_{j=0}^{n-1}a_jS(\omega_n^{j-k})=na_k
$$
In summary: replace $\omega_n^k$ with $\omega_n^{-k}$, run the DFT on $G(x)$, then divide by $n$.
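Conclusion 1 can be checked numerically with a direct $O(n^2)$ transform that takes the sign of the exponent as a parameter. All names in this sketch are mine. `dft(a, +1)` evaluates at $\omega_n^k$. `idft_conclusion1` evaluates at $\omega_n^{-k}$ and divides by $n$. `roundtrip_error` reports how far the round trip lands from the original coefficients. The proof above never uses that $n$ is a power of two, so the check works for any length.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

using cd = std::complex<double>;

// Direct O(n^2) transform: y[k] = sum_i a[i] ω_n^{sign·ik}, ω_n = e^{2πi/n}.
// sign = +1 is the ordinary DFT; sign = -1 uses ω_n^{-k} as in conclusion 1.
std::vector<cd> dft(const std::vector<cd>& a, int sign) {
    const int n = static_cast<int>(a.size());
    const double PI = std::acos(-1.0);
    std::vector<cd> y(n);
    for (int k = 0; k < n; ++k)
        for (int i = 0; i < n; ++i)
            y[k] += a[i] * std::polar(1.0, sign * 2 * PI * (static_cast<long long>(i) * k % n) / n);
    return y;
}

// Conclusion 1: transform the point values with ω_n^{-k}, then divide by n.
std::vector<cd> idft_conclusion1(const std::vector<cd>& values) {
    assert(!values.empty());
    std::vector<cd> a = dft(values, -1);
    for (cd& c : a) c /= static_cast<double>(a.size());
    return a;
}

// Largest |a_i - recovered a_i| after the forward DFT and the conclusion-1 inverse.
double roundtrip_error(int n) {
    std::vector<cd> a(n);
    for (int i = 0; i < n; ++i) a[i] = cd(i % 4, -(i % 3));
    const std::vector<cd> back = idft_conclusion1(dft(a, +1));
    double err = 0;
    for (int i = 0; i < n; ++i) err = std::max(err, std::abs(back[i] - a[i]));
    return err;
}
```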
Proof of conclusion 2
$$
\begin{aligned}
G(\omega_n^k)&=\sum_{i=0}^{n-1}F(\omega_n^i)\omega_n^{ki}\\
&=\sum_{i=0}^{n-1}\omega_n^{ki}\sum_{j=0}^{n-1}a_j\omega_n^{ij}\\
&=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j\omega_n^{ij+ki}\\
&=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j+k})^i
\end{aligned}
$$
Write $S(\omega_n^a)=\sum_{i=0}^{n-1}(\omega_n^a)^i$ as before.
By the same argument,
$$
S(\omega_n^a)=\begin{cases}n,&n\mid a\\0,&n\nmid a\end{cases}
$$
Substitute back. For $0\le j,k<n$, $n\mid(j+k)$ holds only for $j=0$ when $k=0$, and only for $j=n-k$ when $k\ne0$:
$$
G(\omega_n^k)=\sum_{j=0}^{n-1}a_jS(\omega_n^{j+k})=\begin{cases}na_0,&k=0\\na_{n-k},&k\neq0\end{cases}
$$
So: run the ordinary DFT on $G(x)$, reverse the last $n-1$ entries, and finally divide by $n$.
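Conclusion 2 gives an inverse that never needs $\omega_n^{-1}$. Here is a sketch under the same assumptions as a plain direct check: an $O(n^2)$ transform and function names of my own choosing. It runs the ordinary transform, reverses entries $1$ to $n-1$ with `std::reverse`, and divides by $n$.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

using cd = std::complex<double>;

// Ordinary O(n^2) DFT: y[k] = sum_i a[i] ω_n^{ik}, ω_n = e^{2πi/n}.
std::vector<cd> dft(const std::vector<cd>& a) {
    const int n = static_cast<int>(a.size());
    const double PI = std::acos(-1.0);
    std::vector<cd> y(n);
    for (int k = 0; k < n; ++k)
        for (int i = 0; i < n; ++i)
            y[k] += a[i] * std::polar(1.0, 2 * PI * (static_cast<long long>(i) * k % n) / n);
    return y;
}

// Conclusion 2: ordinary DFT of the point values, reverse entries 1..n-1
// (entry 0 stays in place), then divide by n.
std::vector<cd> idft_conclusion2(const std::vector<cd>& values) {
    assert(!values.empty());
    std::vector<cd> a = dft(values);
    std::reverse(a.begin() + 1, a.end());
    for (cd& c : a) c /= static_cast<double>(a.size());
    return a;
}

// Largest |a_i - recovered a_i| after the forward DFT and the conclusion-2 inverse.
double roundtrip_error(int n) {
    std::vector<cd> a(n);
    for (int i = 0; i < n; ++i) a[i] = cd(i % 4, -(i % 3));
    const std::vector<cd> back = idft_conclusion2(dft(a));
    double err = 0;
    for (int i = 0; i < n; ++i) err = std::max(err, std::abs(back[i] - a[i]));
    return err;
}
```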
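Implementation
We use Luogu P3803 as the example. The input gives the degrees $n$ and $m$, followed by the $n+1$ and $m+1$ coefficients of the two polynomials; the output is the $n+m+1$ coefficients of their product.
First, a simple recursive implementation.
We start with DIT.
Code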
```cpp
#include <bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp {
    double x, y;
    cp(double x = 0, double y = 0) : x(x), y(y) {}
    cp operator+(const cp &i) const { return cp(x + i.x, y + i.y); }
    cp operator-(const cp &i) const { return cp(x - i.x, y - i.y); }
    cp operator*(const cp &i) const { return cp(x * i.x - y * i.y, y * i.x + x * i.y); }
    cp &operator*=(const cp &i) { return *this = *this * i; }
} tmp[N], a[N], b[N];
// on = 1: DFT; on = -1: IDFT by conclusion 1 (the division by len is done in main).
void FFT(cp *f, int n, int on) {
    if (n == 1) return;
    cp *fl = f, *fr = f + n / 2;
    for (int i = 0; i < n; i++) tmp[i] = f[i];
    for (int i = 0; i < n / 2; i++)  // even indices to the left half, odd to the right
        fl[i] = tmp[i * 2], fr[i] = tmp[i * 2 + 1];
    FFT(fl, n / 2, on), FFT(fr, n / 2, on);
    cp wn(cos(Pi2 / n), on * sin(Pi2 / n)), w(1, 0);
    for (int i = 0; i < n / 2; i++)
        tmp[i] = fl[i] + w * fr[i], tmp[i + n / 2] = fl[i] - w * fr[i], w *= wn;
    for (int i = 0; i < n; i++) f[i] = tmp[i];
}
int main() {
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for (int i = 0; i <= n; i++) cin >> a[i].x;
    for (int i = 0; i <= m; i++) cin >> b[i].x;
    int len = 1;
    while (len <= n + m) len <<= 1;  // smallest power of two > n + m
    FFT(a, len, 1), FFT(b, len, 1);
    for (int i = 0; i < len; i++) a[i] *= b[i];
    FFT(a, len, -1);
    // + 0.49 then truncation rounds non-negative results to the nearest integer
    for (int i = 0; i <= n + m; i++) cout << (long long)(a[i].x / len + 0.49) << " ";
}
```
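Copying the array back and forth at every level has a large constant factor. Instead, we can put every element at its final position in advance.
A table of the final positions shows that each element's final index is its position with the $\log_2 n$ binary digits reversed. This is because, at each level, the elements whose remaining lowest index bit is $0$ go to the first half, so the lowest bit of the original index becomes the highest bit of the final position. The table `rev` can be built in linear time: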
```cpp
for (int i = 0; i < n; i++)
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0);
```
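The recurrence can be compared against a naive bit-by-bit reversal. In this sketch, `make_rev` is the same loop wrapped in a function; `reverse_bits` and `rev_matches` are hypothetical helpers used only for the comparison. The recurrence works because of how $i$ splits into its higher bits and its lowest bit, $i=2\lfloor i/2\rfloor+(i\bmod 2)$. Shifting `rev[i >> 1]` right by one puts bits $1$ to $\log_2 n-1$ of $i$ in their correct places, and bit $0$ of $i$ becomes the top bit, worth `n >> 1`.

```cpp
#include <cassert>
#include <vector>

// Same recurrence as above: rev[i >> 1] >> 1 places bits 1..L-1 of i (L = log2 n),
// and bit 0 of i becomes the top bit, worth n >> 1.
std::vector<int> make_rev(int n) {
    assert(n > 0 && (n & (n - 1)) == 0);
    std::vector<int> rev(n, 0);
    for (int i = 1; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0);
    return rev;
}

// Direct reversal of the low `bits` bits of x, one bit at a time.
int reverse_bits(int x, int bits) {
    int r = 0;
    for (int b = 0; b < bits; ++b) r = (r << 1) | ((x >> b) & 1);
    return r;
}

// True if make_rev(2^bits) agrees with reverse_bits at every index.
bool rev_matches(int bits) {
    const std::vector<int> rev = make_rev(1 << bits);
    for (int i = 0; i < (1 << bits); ++i)
        if (rev[i] != reverse_bits(i, bits)) return false;
    return true;
}
```

For $n=8$, the table is $0,4,2,6,1,5,3,7$.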
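Also, in the butterfly, $fl_k$ sits at $f_k$ and $fr_k$ at $f_{k+\frac n2}$. The two results are written back to exactly those two positions, with no overlap between different $k$, so the same array can be shared and `tmp` is no longer needed.
The code then becomes: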
Recursive DIT
```cpp
#include <bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp {
    double x, y;
    cp(double x = 0, double y = 0) : x(x), y(y) {}
    cp operator+(const cp &i) const { return cp(x + i.x, y + i.y); }
    cp operator-(const cp &i) const { return cp(x - i.x, y - i.y); }
    cp operator*(const cp &i) const { return cp(x * i.x - y * i.y, y * i.x + x * i.y); }
    cp &operator*=(const cp &i) { return *this = *this * i; }
} a[N], b[N];
// In place; expects the input already permuted by rev.
void FFT(cp *f, int n, int on) {
    if (n == 1) return;
    int half = n >> 1;
    FFT(f, half, on), FFT(f + half, half, on);
    cp wn(cos(Pi2 / n), on * sin(Pi2 / n)), w(1, 0), z;
    for (int i = 0; i < half; i++)
        z = w * f[i + half], f[i + half] = f[i] - z, f[i] = f[i] + z, w *= wn;
}
int rev[N];
int main() {
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for (int i = 0; i <= n; i++) cin >> a[i].x;
    for (int i = 0; i <= m; i++) cin >> b[i].x;
    int len = 1;
    while (len <= n + m) len <<= 1;
    for (int i = 0; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    for (int i = 0; i < len; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int i = 0; i < len; i++) if (i < rev[i]) swap(b[i], b[rev[i]]);
    FFT(a, len, 1), FFT(b, len, 1);
    for (int i = 0; i < len; i++) a[i] *= b[i];
    for (int i = 0; i < len; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
    FFT(a, len, -1);
    for (int i = 0; i <= n + m; i++) cout << (long long)(a[i].x / len + 0.49) << " ";
}
```
In the same way, the recursion can be rewritten as an iterative loop.
Iterative DIT
```cpp
#include <bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp {
    double x, y;
    cp(double x = 0, double y = 0) : x(x), y(y) {}
    cp operator+(const cp &i) const { return cp(x + i.x, y + i.y); }
    cp operator-(const cp &i) const { return cp(x - i.x, y - i.y); }
    cp operator*(const cp &i) const { return cp(x * i.x - y * i.y, y * i.x + x * i.y); }
    cp &operator*=(const cp &i) { return *this = *this * i; }
} a[N], b[N];
int rev[N];
void FFT(cp *f, int n, int on) {
    for (int i = 0; i < n; i++) if (i < rev[i]) swap(f[i], f[rev[i]]);
    // p: current block length, len = p / 2: butterfly distance
    for (int len, p = 2; len = p >> 1, p <= n; p <<= 1) {
        cp wn = cp(cos(Pi2 / p), on * sin(Pi2 / p));
        for (int l = 0; l < n; l += p) {
            cp w = cp(1, 0), z;
            for (int i = l; i < l + len; i++)
                z = f[i + len] * w, f[i + len] = f[i] - z, f[i] = f[i] + z, w *= wn;
        }
    }
}
int main() {
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for (int i = 0; i <= n; i++) cin >> a[i].x;
    for (int i = 0; i <= m; i++) cin >> b[i].x;
    int len = 1;
    while (len <= n + m) len <<= 1;
    for (int i = 0; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    FFT(a, len, 1), FFT(b, len, 1);
    for (int i = 0; i < len; i++) a[i] *= b[i];
    FFT(a, len, -1);
    for (int i = 0; i <= n + m; i++) cout << (long long)(a[i].x / len + 0.49) << " ";
}
```
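DIF groups the outputs by the parity of $k$, so its result comes out in bit-reversed order. Here, rev has to be applied to the output instead.
Recursive DIF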
```cpp
#include <bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp {
    double x, y;
    cp(double x = 0, double y = 0) : x(x), y(y) {}
    cp operator+(const cp &i) const { return cp(x + i.x, y + i.y); }
    cp operator-(const cp &i) const { return cp(x - i.x, y - i.y); }
    cp operator*(const cp &i) const { return cp(x * i.x - y * i.y, y * i.x + x * i.y); }
    cp &operator*=(const cp &i) { return *this = *this * i; }
} a[N], b[N];
// Natural-order input, bit-reversed output: butterflies first, then recursion.
void FFT(cp *f, int n, int on) {
    if (n == 1) return;
    int half = n >> 1;
    cp wn(cos(Pi2 / n), on * sin(Pi2 / n)), w(1, 0), x, y;
    for (int i = 0; i < half; w *= wn, i++)
        x = f[i], y = f[i + half], f[i] = x + y, f[i + half] = (x - y) * w;
    FFT(f, half, on), FFT(f + half, half, on);
}
int rev[N];
int main() {
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for (int i = 0; i <= n; i++) cin >> a[i].x;
    for (int i = 0; i <= m; i++) cin >> b[i].x;
    int len = 1;
    while (len <= n + m) len <<= 1;
    for (int i = 0; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    FFT(a, len, 1), FFT(b, len, 1);
    for (int i = 0; i < len; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int i = 0; i < len; i++) if (i < rev[i]) swap(b[i], b[rev[i]]);
    for (int i = 0; i < len; i++) a[i] *= b[i];
    FFT(a, len, -1);
    for (int i = 0; i < len; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int i = 0; i <= n + m; i++) cout << (long long)(a[i].x / len + 0.49) << " ";
}
```
Iterative DIF
```cpp
#include <bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp {
    double x, y;
    cp(double x = 0, double y = 0) : x(x), y(y) {}
    cp operator+(const cp &i) const { return cp(x + i.x, y + i.y); }
    cp operator-(const cp &i) const { return cp(x - i.x, y - i.y); }
    cp operator*(const cp &i) const { return cp(x * i.x - y * i.y, y * i.x + x * i.y); }
    cp &operator*=(const cp &i) { return *this = *this * i; }
} a[N], b[N];
int rev[N];
void FFT(cp *f, int n, int on) {
    // p: current block length (n, n/2, ..., 2), len = p / 2: butterfly distance
    for (int len, p = n; len = p >> 1, p > 1; p >>= 1) {
        cp wn = cp(cos(Pi2 / p), on * sin(Pi2 / p));
        for (int l = 0; l < n; l += p) {
            cp w = cp(1, 0), x, y;
            for (int i = l; i < l + len; w *= wn, i++)
                x = f[i], y = f[i + len], f[i] = x + y, f[i + len] = (x - y) * w;
        }
    }
    for (int i = 0; i < n; i++) if (i < rev[i]) swap(f[i], f[rev[i]]);  // undo bit reversal
}
int main() {
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for (int i = 0; i <= n; i++) cin >> a[i].x;
    for (int i = 0; i <= m; i++) cin >> b[i].x;
    int len = 1;
    while (len <= n + m) len <<= 1;
    for (int i = 0; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    FFT(a, len, 1), FFT(b, len, 1);
    for (int i = 0; i < len; i++) a[i] *= b[i];
    FFT(a, len, -1);
    for (int i = 0; i <= n + m; i++) cout << (long long)(a[i].x / len + 0.49) << " ";
}
```
Optimizations
Simple optimizations
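Since $(a+bi)^2=a^2-b^2+2abi$, we can put the two polynomials into the real and imaginary parts of a single polynomial, $P(x)=A(x)+iB(x)$. Then $P(x)^2=A(x)^2-B(x)^2+2i\,A(x)B(x)$. After transforming $P$, squaring the point values and transforming back, half of the imaginary part is the answer. This takes two FFTs instead of three.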
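Below is a sketch of this trick with `std::complex<double>` and an iterative FFT. The FFT uses the same butterfly loop as the iterative DIT code above, but on a `std::vector` and with a `sign` parameter. The names `fft` and `multiply_packed` are my own. Only two transforms run: a forward one on $P=A+iB$ and an inverse one on $P^2$. The transform length only exceeds $\deg A+\deg B$, so the $A^2-B^2$ part may wrap around. That wrap-around stays in the real part, and the imaginary part $2AB$ is unaffected.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <utility>
#include <vector>

using cd = std::complex<double>;

// Iterative radix-2 DIT FFT in place; sign = +1 uses ω_n = e^{2πi/n}, sign = -1 uses ω_n^{-1}.
// The inverse direction does not divide by n.
void fft(std::vector<cd>& f, int sign) {
    const int n = static_cast<int>(f.size());
    assert(n > 0 && (n & (n - 1)) == 0);
    std::vector<int> rev(n, 0);
    for (int i = 1; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0);
    for (int i = 0; i < n; ++i)
        if (i < rev[i]) std::swap(f[i], f[rev[i]]);
    const double PI = std::acos(-1.0);
    for (int len = 2; len <= n; len <<= 1) {
        const cd wn = std::polar(1.0, sign * 2 * PI / len);
        for (int l = 0; l < n; l += len) {
            cd w = 1.0;
            for (int i = l; i < l + len / 2; ++i) {
                const cd z = f[i + len / 2] * w;
                f[i + len / 2] = f[i] - z;
                f[i] += z;
                w *= wn;
            }
        }
    }
}

// A in the real part, B in the imaginary part: P = A + iB, P^2 = A^2 - B^2 + 2i·AB.
std::vector<long long> multiply_packed(const std::vector<int>& A, const std::vector<int>& B) {
    assert(!A.empty() && !B.empty());
    const int need = static_cast<int>(A.size() + B.size()) - 1;  // coefficients of A·B
    int n = 1;
    while (n < need) n <<= 1;
    std::vector<cd> p(n);
    for (int i = 0; i < static_cast<int>(A.size()); ++i) p[i].real(A[i]);
    for (int i = 0; i < static_cast<int>(B.size()); ++i) p[i].imag(B[i]);
    fft(p, +1);
    for (cd& v : p) v *= v;  // point values of P^2
    fft(p, -1);
    std::vector<long long> res(need);
    for (int i = 0; i < need; ++i) res[i] = std::llround(p[i].imag() / (2.0 * n));
    return res;
}
```

With the P3803 sample, `multiply_packed({1, 2}, {1, 2, 1})` returns `{1, 4, 5, 2}`. `std::llround` is used here instead of `+ 0.49` so that negative coefficients also round correctly.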
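DIF needs rev at the end, and DIT needs rev at the start. If we use DIF for the DFT and DIT for the IDFT, the pointwise product simply happens in bit-reversed order, and the DIT step puts everything back in natural order. Then rev is not needed at all.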
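Here is a sketch of this combination with `std::vector<std::complex<double>>`; the function names are my own. `dif_forward` is the iterative DIF loop, which leaves its output in bit-reversed order. The pointwise product is taken in that order. `dit_inverse` is the iterative DIT loop without its initial rev swap, which turns bit-reversed input back into natural order. No `rev` table appears anywhere.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

using cd = std::complex<double>;

// Iterative DIF with ω_n = e^{2πi/n}: natural-order input, bit-reversed output.
void dif_forward(std::vector<cd>& f) {
    const int n = static_cast<int>(f.size());
    const double PI = std::acos(-1.0);
    for (int len = n; len >= 2; len >>= 1) {
        const cd wn = std::polar(1.0, 2 * PI / len);
        for (int l = 0; l < n; l += len) {
            cd w = 1.0;
            for (int i = l; i < l + len / 2; ++i) {
                const cd x = f[i], y = f[i + len / 2];
                f[i] = x + y;
                f[i + len / 2] = (x - y) * w;
                w *= wn;
            }
        }
    }
}

// Iterative DIT with ω_n^{-1} and no initial rev swap: bit-reversed input,
// natural-order output, every entry still multiplied by n.
void dit_inverse(std::vector<cd>& f) {
    const int n = static_cast<int>(f.size());
    const double PI = std::acos(-1.0);
    for (int len = 2; len <= n; len <<= 1) {
        const cd wn = std::polar(1.0, -2 * PI / len);
        for (int l = 0; l < n; l += len) {
            cd w = 1.0;
            for (int i = l; i < l + len / 2; ++i) {
                const cd z = f[i + len / 2] * w;
                f[i + len / 2] = f[i] - z;
                f[i] += z;
                w *= wn;
            }
        }
    }
}

std::vector<long long> multiply_no_rev(const std::vector<int>& A, const std::vector<int>& B) {
    assert(!A.empty() && !B.empty());
    const int need = static_cast<int>(A.size() + B.size()) - 1;
    int n = 1;
    while (n < need) n <<= 1;
    std::vector<cd> a(n), b(n);
    for (int i = 0; i < static_cast<int>(A.size()); ++i) a[i] = A[i];
    for (int i = 0; i < static_cast<int>(B.size()); ++i) b[i] = B[i];
    dif_forward(a);
    dif_forward(b);
    // Both spectra share the same bit-reversed order, so the pointwise product is unaffected.
    for (int i = 0; i < n; ++i) a[i] *= b[i];
    dit_inverse(a);
    std::vector<long long> res(need);
    for (int i = 0; i < need; ++i) res[i] = std::llround(a[i].real() / n);
    return res;
}
```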
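Recursion still has a large constant factor. We can use template recursion instead, where $n$ is a template parameter and every level becomes its own function with a compile-time length, like this:
Template recursion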
```cpp
#include <bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp {
    double x, y;
    cp(double x = 0, double y = 0) : x(x), y(y) {}
    cp operator+(const cp &i) const { return cp(x + i.x, y + i.y); }
    cp operator-(const cp &i) const { return cp(x - i.x, y - i.y); }
    cp operator*(const cp &i) const { return cp(x * i.x - y * i.y, y * i.x + x * i.y); }
    cp &operator*=(const cp &i) { return *this = *this * i; }
} a[N], b[N];
// n is a template parameter, so every recursion level is its own instantiation.
template <const int n> void DFT(cp *f) {
    const int half = n >> 1;
    DFT<half>(f), DFT<half>(f + half);
    cp wn(cos(Pi2 / n), sin(Pi2 / n)), w(1, 0), z;
    for (int i = 0; i < half; i++)
        z = w * f[i + half], f[i + half] = f[i] - z, f[i] = f[i] + z, w *= wn;
}
template <> void DFT<1>(cp *f) {}
template <> void DFT<0>(cp *f) {}
// Maps the runtime length n (a power of two up to 2^21) to the matching instantiation.
#define Case(x, fu) case x: fu<x>(f); break;
#define Runfft(x) switch (n) { \
    Case(1 << 1, x) Case(1 << 2, x) Case(1 << 3, x) Case(1 << 4, x) \
    Case(1 << 5, x) Case(1 << 6, x) Case(1 << 7, x) Case(1 << 8, x) \
    Case(1 << 9, x) Case(1 << 10, x) Case(1 << 11, x) Case(1 << 12, x) \
    Case(1 << 13, x) Case(1 << 14, x) Case(1 << 15, x) Case(1 << 16, x) \
    Case(1 << 17, x) Case(1 << 18, x) Case(1 << 19, x) Case(1 << 20, x) Case(1 << 21, x) }
int rev[N];
void rundft(cp *f, const int &n) {
    for (int i = 0; i < n; i++) if (i < rev[i]) swap(f[i], f[rev[i]]);
    Runfft(DFT);
}
template <const int n> void IDFT(cp *f) {
    const int half = n >> 1;
    IDFT<half>(f), IDFT<half>(f + half);
    cp wn(cos(Pi2 / n), -sin(Pi2 / n)), w(1, 0), z;
    for (int i = 0; i < half; i++)
        z = w * f[i + half], f[i + half] = f[i] - z, f[i] = f[i] + z, w *= wn;
}
template <> void IDFT<1>(cp *f) {}
template <> void IDFT<0>(cp *f) {}
void runidft(cp *f, const int &n) {
    for (int i = 0; i < n; i++) if (i < rev[i]) swap(f[i], f[rev[i]]);
    Runfft(IDFT);
}
int main() {
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for (int i = 0; i <= n; i++) cin >> a[i].x;
    for (int i = 0; i <= m; i++) cin >> b[i].x;
    int len = 1;
    while (len <= n + m) len <<= 1;
    for (int i = 0; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    rundft(a, len), rundft(b, len);
    for (int i = 0; i < len; i++) a[i] *= b[i];
    runidft(a, len);
    for (int i = 0; i <= n + m; i++) cout << (long long)(a[i].x / len + 0.49) << " ";
}
```
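Split radix
Splitting into two halves each time is called radix-2; likewise, there are radix-4 and radix-8.
Only the DIT derivation is given here, but the code covers both: its forward transform is DIF and its inverse is DIT. I trust the reader can work out the DIF version in the same way.
Split radix takes the second part of the radix-2 split and splits it once more in the same way: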
$$
\begin{aligned}
h(x)&=h_1(x^2)+xh_2(x^2)\\
f(x)&=g(x^2)+xh_1(x^4)+x^3h_2(x^4)
\end{aligned}
$$
Here $h_1$ holds the coefficients $a_{4i+1}$ and $h_2$ holds the coefficients $a_{4i+3}$. For $0\le k<\frac n4$, substitute $\omega_n^k$, $\omega_n^{k+\frac n4}$, $\omega_n^{k+\frac n2}$ and $\omega_n^{k+\frac{3n}4}$, using $\omega_n^{\frac n4}=i$:
$$
\begin{aligned}
f(\omega_n^k)&=g(\omega_n^{2k})+\omega_n^kh_1(\omega_n^{4k})+\omega_n^{3k}h_2(\omega_n^{4k})\\
f(\omega_n^{k+\frac n4})&=g(\omega_n^{2k+\frac n2})+(\omega_n^kh_1(\omega_n^{4k})-\omega_n^{3k}h_2(\omega_n^{4k}))i\\
f(\omega_n^{k+\frac n2})&=g(\omega_n^{2k})-\omega_n^kh_1(\omega_n^{4k})-\omega_n^{3k}h_2(\omega_n^{4k})\\
f(\omega_n^{k+\frac{3n}4})&=g(\omega_n^{2k+\frac n2})-(\omega_n^kh_1(\omega_n^{4k})-\omega_n^{3k}h_2(\omega_n^{4k}))i
\end{aligned}
$$
Here $g(\omega_n^{2k})$ and $g(\omega_n^{2k+\frac n2})$ are entries $k$ and $k+\frac n4$ of the length-$\frac n2$ DFT of $g$. Likewise, $h_1(\omega_n^{4k})$ and $h_2(\omega_n^{4k})$ are entry $k$ of the length-$\frac n4$ DFTs of $h_1$ and $h_2$. One level therefore costs one half-size and two quarter-size transforms. With $\omega_n^{-1}$ in place of $\omega_n$, as in the inverse transform, $\omega_n^{-\frac n4}=-i$, so the signs of the two $i$ terms swap.
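To see the recurrence on its own, here is an out-of-place recursive sketch of split-radix DIT with `std::complex<double>`. It follows the four formulas above literally. The function names `split_radix`, `dft_naive` and `matches_naive` are my own. It allocates new vectors at every level, so it shows the structure rather than the speed. Length 2 is handled as the plain radix-2 butterfly, since $\frac n4$ would be zero there.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

using cd = std::complex<double>;

// Out-of-place split-radix DIT with ω_n = e^{2πi/n}: f(x) = g(x^2) + x h1(x^4) + x^3 h2(x^4),
// g = even-index coefficients, h1 = indices 4i+1, h2 = indices 4i+3.
std::vector<cd> split_radix(const std::vector<cd>& a) {
    const int n = static_cast<int>(a.size());
    assert(n > 0 && (n & (n - 1)) == 0);  // length must be a power of two
    if (n == 1) return a;
    if (n == 2) return {a[0] + a[1], a[0] - a[1]};  // plain radix-2 butterfly
    const int q = n / 4;
    std::vector<cd> g(n / 2), h1(q), h2(q);
    for (int i = 0; i < n / 2; ++i) g[i] = a[2 * i];
    for (int i = 0; i < q; ++i) {
        h1[i] = a[4 * i + 1];
        h2[i] = a[4 * i + 3];
    }
    const std::vector<cd> G = split_radix(g), H1 = split_radix(h1), H2 = split_radix(h2);
    const double PI = std::acos(-1.0);
    const cd imag_unit(0.0, 1.0);
    std::vector<cd> F(n);
    for (int k = 0; k < q; ++k) {
        const cd A = std::polar(1.0, 2 * PI * k / n) * H1[k];      // ω_n^k h1(ω_n^{4k})
        const cd B = std::polar(1.0, 2 * PI * 3 * k / n) * H2[k];  // ω_n^{3k} h2(ω_n^{4k})
        F[k] = G[k] + (A + B);
        F[k + 2 * q] = G[k] - (A + B);
        F[k + q] = G[k + q] + imag_unit * (A - B);  // G[k + q] = g(ω_n^{2k+n/2})
        F[k + 3 * q] = G[k + q] - imag_unit * (A - B);
    }
    return F;
}

// O(n^2) reference: direct evaluation at ω_n^k.
std::vector<cd> dft_naive(const std::vector<cd>& a) {
    const int n = static_cast<int>(a.size());
    const double PI = std::acos(-1.0);
    std::vector<cd> y(n);
    for (int k = 0; k < n; ++k)
        for (int i = 0; i < n; ++i)
            y[k] += a[i] * std::polar(1.0, 2 * PI * (static_cast<long long>(i) * k % n) / n);
    return y;
}

// Compares split_radix with dft_naive on a fixed test input of length n.
bool matches_naive(int n) {
    std::vector<cd> a(n);
    for (int i = 0; i < n; ++i) a[i] = cd((7 * i) % 10, i % 4);
    const std::vector<cd> x = split_radix(a), y = dft_naive(a);
    for (int k = 0; k < n; ++k)
        if (std::abs(x[k] - y[k]) > 1e-6) return false;
    return true;
}
```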
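Split radix applies to FFTs of length $2^t$. It needs fewer arithmetic operations than radix-8 and is more flexible.
For a comparison of operation counts, see ooura's blog.
The iterative version is too complicated, so only the recursive version is given.
Code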
```cpp
#include <bits/stdc++.h>
using namespace std;
// Maps the runtime length n (a power of two up to 2^21) to the matching instantiation.
#define Case(x, fu) case x: fu<x>(f); break;
#define Runfft(x) switch (n) { \
    Case(1 << 1, x) Case(1 << 2, x) Case(1 << 3, x) Case(1 << 4, x) \
    Case(1 << 5, x) Case(1 << 6, x) Case(1 << 7, x) Case(1 << 8, x) \
    Case(1 << 9, x) Case(1 << 10, x) Case(1 << 11, x) Case(1 << 12, x) \
    Case(1 << 13, x) Case(1 << 14, x) Case(1 << 15, x) Case(1 << 16, x) \
    Case(1 << 17, x) Case(1 << 18, x) Case(1 << 19, x) Case(1 << 20, x) Case(1 << 21, x) }
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp {
    double x, y;
    cp(double x = 0, double y = 0) : x(x), y(y) {}
    cp operator+(const cp &i) const { return cp(x + i.x, y + i.y); }
    cp operator-(const cp &i) const { return cp(x - i.x, y - i.y); }
    cp operator*(const cp &i) const { return cp(x * i.x - y * i.y, y * i.x + x * i.y); }
    cp &operator+=(const cp &i) { return this->x += i.x, this->y += i.y, *this; }
    cp &operator-=(const cp &i) { return this->x -= i.x, this->y -= i.y, *this; }
    cp &operator*=(const cp &i) { return *this = *this * i; }
} a[N];
// Forward transform, split-radix DIF: butterflies first, then recurse on n/2, n/4, n/4.
// The output order is permuted; IDFT below uses the same layout, so no rev is needed.
template <const int n> void DFT(cp *f) {
    const int half = n >> 1, quar = n >> 2;
    cp w(1, 0), wn(cos(Pi2 / n), sin(Pi2 / n)), w3, x, y;
    cp *a1 = &f[0], *a2 = a1 + quar, *a3 = a2 + quar, *a4 = a3 + quar;
    for (int i = 0; i < quar; i++) {
        w3 = w * w * w, x = *a1 - *a3, y = *a2 - *a4, y = cp(y.y, -y.x);  // y = -i(a2 - a4)
        *a1 += *a3, *a2 += *a4, *a3 = (x - y) * w, *a4 = (x + y) * w3;
        a1++, a2++, a3++, a4++, w *= wn;
    }
    DFT<half>(f), DFT<quar>(f + half), DFT<quar>(f + half + quar);
}
template <> void DFT<2>(cp *f) { cp x = f[0], y = f[1]; f[0] = x + y, f[1] = x - y; }
template <> void DFT<1>(cp *f) {}
template <> void DFT<0>(cp *f) {}
void rundft(cp *f, const int &n) { Runfft(DFT); }
// Inverse transform without the division by n, split-radix DIT with ω_n^{-1}:
// recurse first, then butterflies.
template <const int n> void IDFT(cp *f) {
    const int half = n >> 1, quar = n >> 2;
    IDFT<half>(f), IDFT<quar>(f + half), IDFT<quar>(f + half + quar);
    cp wn(cos(Pi2 / n), -sin(Pi2 / n)), w(1, 0), w3, tmp1, tmp2, x, y;
    cp *a1 = &f[0], *a2 = a1 + quar, *a3 = a2 + quar, *a4 = a3 + quar;
    for (int i = 0; i < quar; i++) {
        w3 = w * w * w, tmp1 = w * *a3, tmp2 = w3 * *a4;
        x = tmp1 + tmp2, y = tmp1 - tmp2, y = cp(y.y, -y.x);  // y = -i(tmp1 - tmp2)
        *a4 = *a2 - y, *a3 = *a1 - x, *a2 += y, *a1 += x;
        a1++, a2++, a3++, a4++, w *= wn;
    }
}
template <> void IDFT<2>(cp *f) { cp x = f[0], y = f[1]; f[0] = x + y, f[1] = x - y; }
template <> void IDFT<1>(cp *f) {}
template <> void IDFT<0>(cp *f) {}
void runidft(cp *f, const int &n) { Runfft(IDFT); }
int main() {
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for (int i = 0; i <= n; i++) cin >> a[i].x;  // first polynomial in the real part
    for (int i = 0; i <= m; i++) cin >> a[i].y;  // second polynomial in the imaginary part
    int len = 1;
    while (len <= n + m) len <<= 1;
    rundft(a, len);
    for (int i = 0; i < len; i++) a[i] *= a[i];
    runidft(a, len);
    double inv = 0.5 / len;  // Im(P^2) = 2AB, plus the 1/len of the inverse
    for (int i = 0; i <= n + m; i++) cout << (long long)(a[i].y * inv + 0.49) << " ";
}
```
References
https://oi-wiki.org/math/poly/fft/
https://www.bilibili.com/opus/785022478912061446
https://www.luogu.com.cn/article/rj58c2eb
https://charleswu.site/archives/3065