HDU 4609 3-idiots (FFT-快速傅立叶变换)

时间:2023-03-09 19:28:05
HDU 4609 3-idiots (FFT-快速傅立叶变换)

【题意】给定N个树枝,求从中取出三个可以围成三角形的概率

【思路】

2013多校训练第一场比赛1010题。

一开始就想到了O(n^2)枚举前两个树枝和的算法,赛后群里大牛说计算所有两个树枝和的情况可以用FFT在O(NlogN)时间内做到,于是剩下的问题就便简单了,于是就滚去学FFT了~

FFT可以在O(NlogN)时间内计算点值将多项式A = a1•x1 + a2•x2 + …… + an•xn, B = b1•x1 + b2•x2 + …… + bn•xn由系数表示法( 系数向量a = (a1, a2, ……, an) )转换成点值表示法从而可以在O(N)时间内计算出多项式乘积然后再在O(NlogN)时间内把乘积的点值表示法转换成系数表示法(这一过程叫插值),达到加速多项式乘法的作用。

那么我们怎么通过多项式乘法来快速计算出两个数组任意一项和的结果呢?应该想到乘法和加法间的联系:指数。所以我们可以令数组中的项为多项式的系数,比如(1, 3, 3, 4)就可以表示成P = 1*x1+2*x3+1*x4,这样两个多项式的乘积对应指数的系数就表示两个多项式指数和为该数的方案数。

在此题中我们令num[i]表示树枝长度为i的树枝个数,我们把它当作多项式的系数向量,对自身做一次乘法,结果的系数向量就对应和为某个长度的个数。

举个例子,num = {0 1 0 2 1}, 则num * num = {0 0 1 0 4 2 4 4 1},

这个结果的意义如下:

从{1 3 3 4}取一个数,从{1 3 3 4}再取一个数

取两个数和为 2 的取法是一种:1+1

和为 4 的取法有四种:1+3, 1+3  ,3+1 ,3+1

和为 5 的取法有两种:1+4 ,4+1;

和为 6的取法有四种:3+3,3+3,3+3,3+3,3+3

和为 7 的取法有四种: 3+4,3+4,4+3,4+3

和为 8 的取法有 一种:4+4

当然这样的结果是任意两个数相加,而题目中要求本身不能重复使用,所以要把取同一个的组合的情况删掉:for (int i = 0; i < n; i ++)      num[a[i]+a[i]] --;

然后我们统计方案采用无序的组合方法(即1\2\3和3\2\1等价),并假定x1<x2<x3,所以在总方案中要减一半:for (int i = 1; i < maxn; i ++)   num[i] /= 2;

最后对数组求前缀和就求出了所有x+y<=z的情况:for (int i = 1; i < maxn; i ++)   sum[i] = sum[i-1] + num[i];

枚举第三个树枝的长度把结果加起来计算概率,用1减后就是最后的结果了。

重要:这里为什么要先求x+y<=z而不直接求x+y>z呢?因为前面说了我们计算过程中是按照组合统计的,并且假定第三个树枝数最大,那么统计x+y>z的方案就比较麻烦,因为需要去除掉z<x\y的情况,但是统计x+y<=z就简单了~因为这样z必定大于x\y。

【代码】
这道题内存和时间卡的紧,FFT要用迭代的算法,递归算法会超内存。

#include
#include
#include
#include
#include
#include
#define MID(x,y) ((x+y)/2)
#define mem(a,b) memset(a,b,sizeof(a))
using namespace std;

const double eps = 1e-8;
const double Pi = acos(-1.0);
struct complex{
double r,i;
complex(double _r = 0,double _i = 0){
r = _r; i = _i;
}
complex operator +(const complex &b){
return complex(r+b.r,i+b.i);
}
complex operator -(const complex &b){
return complex(r-b.r,i-b.i);
}
complex operator *(const complex &b){
return complex(r*b.r-i*b.i,r*b.i+i*b.r);
}
};
struct FastFourierTransform{
inline int dcmp(double a){ if (aeps); }
void ReverseBits(complex *y, int len){
int i,j,k;
for(i = 1, j = len>>1; i > 1;
while(j >= k){
j -= k;
k >>= 1;
}
if(j nb) ? na : nb;
nc = 1;
while(nc 0 && dcmp(c[nc-1]) == 0; nc--);
// 这句加上时间还多了,不知道算不算优化……
delete ya; delete yb; delete yc;
}
//Convolution: r(k) = sigma(a[i]*b[i-k]){i=0~n-1}
//N must be power of 2
void Convolution(int *a, int *b, int *r, int n){
complex *d1 = new complex[n], *d2 = new complex[n], *y = new complex[n];
d1[0] = b[0];
int nc = 1;
while(nc