HDU 1402 A * B Problem Plus (FFT模板题)

时间:2023-12-04 16:07:14

FFT模板题,求A*B。

使用此FFT模板需要注意:N 必须是 2 的幂次,否则二进制平摊反转置换(brc)中的进位循环会陷入死循环。

取出结果值时要注意浮点精度:需先加上一个小量 eps(或加 0.5 四舍五入)再取整,否则无法 AC。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
typedef long long ll;
const double pi = acos(-1.0);
const int maxn = 50000 + 5;   // capacity of each operand buffer; also sizes the FFT arrays
const double eps = 1e-6;      // small floating-point tolerance

// Minimal complex number type for the FFT.
struct Complex {
    double a, b;  // real part, imaginary part

    // Zero-initialize so default-constructed values (e.g. FFT temporaries)
    // never hold indeterminate data.
    Complex() : a(0), b(0) {
    }
    Complex(double a, double b) :
            a(a), b(b) {
    }
    Complex operator +(const Complex& t) const {
        return Complex(a + t.a, b + t.b);
    }
    Complex operator -(const Complex& t) const {
        return Complex(a - t.a, b - t.b);
    }
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    Complex operator *(const Complex& t) const {
        return Complex(a * t.a - b * t.b, a * t.b + b * t.a);
    }
};
// ---- Bit-reversal permutation ----
// Reorders x[0..n-1] into bit-reversed index order, as needed before the
// iterative butterfly passes in FFT().
// Precondition: n must be a power of two. For any other n the carry loop
// below eventually runs with bit == 0 and never terminates.
void brc(Complex *x, int n) {
    const int half = n >> 1;
    int rev = half;  // bit-reversed counterpart of idx (idx == 1 maps to half)
    for (int idx = 1; idx < n - 1; ++idx) {
        // Each pair is swapped exactly once: when idx is the smaller index.
        if (idx < rev)
            swap(x[idx], x[rev]);
        // Step rev to the reversal of idx + 1: "increment" starting at the
        // most significant bit, clearing set bits while the carry propagates.
        int bit = half;
        for (; rev >= bit; bit >>= 1)
            rev -= bit;
        rev += bit;  // the loop only exits with rev < bit, so this always applies
    }
}
// ---- FFT: on == 1 forward transform, on == -1 inverse transform ----
// Iterative radix-2 Cooley-Tukey FFT, computed in place on x[0..n-1].
//   on ==  1 : X[k] = sum_j x[j] * e^{+2*pi*i*j*k/n}
//   on == -1 : X[k] = (1/n) * sum_j x[j] * e^{-2*pi*i*j*k/n}   (inverse)
// n must be a power of two (brc() does not terminate otherwise).
void FFT(Complex *x, int n, int on) {
    brc(x, n);
    for (int h = 2; h <= n; h <<= 1) {          // h = length of the sub-transforms merged at this level
        const double r = on * 2.0 * pi / h;
        const Complex wn(cos(r), sin(r));       // h-th root of unity; its direction is chosen by `on`
        const int half = h >> 1;
        for (int j = 0; j < n; j += h) {
            Complex w(1, 0);
            for (int k = j; k < j + half; k++) {
                const Complex u = x[k];
                const Complex t = w * x[k + half];
                x[k] = u + t;
                x[k + half] = u - t;
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        // Scale both the real and imaginary parts so the inverse is exact for
        // complex data too. No bias is added here: converting the output to
        // integers (rounding) is the caller's responsibility.
        for (int i = 0; i < n; i++) {
            x[i].a /= n;
            x[i].b /= n;
        }
    }
}
// Not referenced by the code in this file (solve() uses its own locals).
int n, ma, N;
Complex x1[maxn<<2], x2[maxn<<2];
char sa[maxn], sb[maxn];
int ans[maxn<<1]; void solve() {
int n1 = strlen(sa), n2 = strlen(sb);
int N = 1, tmpn = max(n1, n2) << 1;
// N应为2的幂次
while(N < tmpn) N <<= 1;
for(int i = 0;i < N; i++)
x1[i].a = x1[i].b = x2[i].a = x2[i].b = 0;
for(int i = 0;i < n1; i++)
x1[i].a = sa[n1-i-1] - '0';
for(int i = 0;i < n2; i++)
x2[i].a = sb[n2-i-1] - '0';
FFT(x1, N, 1); FFT(x2, N, 1);
for(int i = 0;i < N; i++)
x1[i] = x1[i]*x2[i];
FFT(x1, N, -1);
int pre = 0, top = 0;
for(int i = 0;i < n1+n2; i++) {
// 不加epsA不了~
int cur = (int)(x1[i].a + eps);
ans[++top] = (cur + pre)%10;
pre = (pre + cur)/10;
}
while(!ans[top] && top > 1) top--;
for(int i = top;i >= 1; i--)
printf("%d", ans[i]);
puts("");
} int main() {
while(scanf("%s%s", sa, &sb) != -1) {
solve();
}
return 0;
}