UOJ 34: 多项式乘法(FFT模板题)

时间:2023-03-09 08:54:20
UOJ 34: 多项式乘法(FFT模板题)

关于FFT

这个博客的讲解超级棒

http://blog.miskcoo.com/2015/04/polynomial-multiplication-and-fast-fourier-transform

算法导论上的讲解也不错

模板就是抄一抄别人的啦

首先是递归版本

 #include <cstdio>
#include <complex>
#include <cmath>
#include <vector>
using namespace std;

// Recursive radix-2 FFT, used to multiply two polynomials read from stdin.

const double pi = acos(-1.0);
// Capacity of the coefficient buffers: a transform of up to 2^18 points plus
// one spare slot.
// NOTE(review): 2^18 assumes the two input degrees sum to less than 2^18 --
// confirm against the problem limits; main() does not range-check its input.
const int N = (1 << 18) + 1;
typedef complex<double> cp;
cp A[N], B[N];
int n, m;

// In-place recursive discrete Fourier transform of y[0..n-1].
//
//   y : coefficient array, overwritten with the transform
//   n : transform length; must be a power of two (n == 1 is the base case)
//   o : +1 uses the root exp(+2*pi*i/n), -1 uses its conjugate (inverse
//       direction). The result is NOT divided by n; the caller does that.
void FFT(cp *y, int n, int o) {
    if (n == 1) return;

    const int half = n >> 1;
    // Heap-backed scratch space for the even- and odd-indexed coefficients.
    vector<cp> l(half), r(half);
    for (int i = 0; i < n; i++) {
        if (i & 1) r[i >> 1] = y[i];
        else l[i >> 1] = y[i];
    }

    FFT(&l[0], half, o);
    FFT(&r[0], half, o);

    // Principal n-th root of unity; its powers are accumulated in omega.
    const cp omegan(cos(2 * pi / n), sin(2 * pi * o / n));
    cp omega(1, 0);
    for (int i = 0; i < half; i++) {
        const cp t = omega * r[i];
        y[i] = l[i] + t;
        y[i + half] = l[i] - t;
        omega *= omegan;
    }
}

// Reads degrees n and m, then n + 1 and m + 1 coefficients, and prints the
// n + m + 1 coefficients of the product, each followed by a space.
int main() {
    if (scanf("%d %d", &n, &m) != 2) return 0;

    // complex::real() returns a value, not an lvalue, so read into a local
    // double and store it as a complex number with zero imaginary part.
    for (int i = 0; i <= n; i++) {
        double coef = 0.0;
        if (scanf("%lf", &coef) != 1) coef = 0.0;
        A[i] = cp(coef, 0.0);
    }
    for (int i = 0; i <= m; i++) {
        double coef = 0.0;
        if (scanf("%lf", &coef) != 1) coef = 0.0;
        B[i] = cp(coef, 0.0);
    }

    // m becomes the degree of the product; n becomes the smallest power of
    // two strictly greater than it, i.e. the transform length.
    m += n;
    n = 1;
    while (n <= m) n <<= 1;

    FFT(A, n, 1);
    FFT(B, n, 1);
    for (int i = 0; i < n; i++)
        A[i] *= B[i];
    FFT(A, n, -1);

    // Undo the missing 1/n scaling and round to the nearest integer.
    for (int i = 0; i <= m; i++)
        printf("%d ", static_cast<int>(A[i].real() / n + 0.5));
    return 0;
}

迭代版本

 #include <cstdio>
#include <cmath>
#include <complex>
#include <iostream>
#include <cctype>
using namespace std;

// Iterative (in-place, bottom-up) radix-2 FFT and its supporting helpers.

// Capacity of every buffer below, i.e. the largest supported transform length.
// NOTE(review): 2^18 assumes both input degrees are at most 1e5 -- confirm
// against the problem limits.
const int N = 1 << 18;
typedef complex<double> cp;
const double pi = acos(-1.0);
cp A[N], B[N];
// When true, FFT() uses conjugated roots of unity (inverse direction).
bool flag;
// a, b: integer input coefficients; n: transform length; bit: log2(n);
// tar[i]: i with its low `bit` bits reversed.
int a[N], b[N], n, m, tar[N], bit;

// Reads the next run of decimal digits from stdin into ans.
// One character of lookahead is kept between calls. Non-digit characters are
// skipped, so a leading '-' is ignored. At end of input ans is left at 0.
inline void read(int &ans) {
    // Held as int so EOF stays distinguishable and isdigit() never receives a
    // negative char value.
    static int buf = getchar();
    ans = 0;
    while (buf != EOF && !isdigit(buf)) buf = getchar();
    for (; isdigit(buf); buf = getchar())
        ans = ans * 10 + (buf - '0');
}

// Returns val with its lowest `bit` bits in reversed order.
inline int rev(int val) {
    int rst = 0;
    for (int i = 0; i < bit; i++) {
        rst <<= 1;
        rst |= val & 1;
        val >>= 1;
    }
    return rst;
}

// In-place butterfly passes over y[0..n-1], where n == 1 << bit.
// y must already be in bit-reversed order. With flag == false the roots are
// exp(+2*pi*i/fac); with flag == true they are conjugated. The result is not
// divided by n.
inline void FFT(cp *y) {
    for (int i = 1; i <= bit; i++) {
        const int fac = 1 << i;   // size of the sub-transforms being merged
        const int half = fac >> 1;
        cp omegan(cos(2 * pi / fac), sin(2 * pi / fac));
        // complex::imag() returns a value, so conjugate instead of negating
        // the imaginary part in place.
        if (flag) omegan = conj(omegan);
        for (int j = 0; j < n; j += fac) {
            cp omega(1, 0);
            for (int k = 0; k < half; k++) {
                const cp t = omega * y[j + k + half];
                const cp u = y[j + k];
                y[j + k] = u + t;
                y[j + k + half] = u - t;
                omega *= omegan;
            }
        }
    }
}
int main() {
read(n); read(m); n++; m++;
for (int i = ; i < n; i++) read(a[i]);
for (int i = ; i < m; i++) read(b[i]);
m += n; for (n = ; n < m; n <<= ) bit++;
for (int i = ; i < n; i++) tar[i] = rev(i);
for (int i = ; i < n; i++) A[i].real() = a[tar[i]];
for (int i = ; i < n; i++) B[i].real() = b[tar[i]];
FFT(A); FFT(B);
for (int i = ; i < n; i++) A[i] *= B[i];
for (int i = ; i < n; i++) if (i < tar[i]) swap(A[i], A[tar[i]]);
flag = true; FFT(A);
for (int i = ; i < m - ; i++)
printf("%.0lf ", 0.0001 + A[i].real() / n);
puts("");
return ;
}