hdu5730 Shell Necklace(CDQ分治+FFT|多项式求逆)

时间:2021-06-22 21:47:05

题目链接

题目描述:长为 i 的项链有 a [ i ] 种装饰方法,问长度为 n 的项链有多少种装饰方式

分析:
说实话我感觉这个题目描述有点模棱两可

显然,用不同的方式分割这个序列,就会产生一定数量的装饰方法
一开始想dp方程有点懵了,实际上非常简单,我们枚举分割出来的一部分 i

f [ n ] = i = 0 n f [ n i ] a [ i ]

暴力转移时间复杂度 O ( n 2 )
观察式子,存在一个类似卷积的形式: i = 0 n f [ n i ] a [ i ] ,可以考虑用FFT/NTT优化

dalao们表示这是一道很好的CDQ+FFT的模板题

这道题我们不能做n次FFT,会重复计算,最终导致TLE。
我们采取CDQ分治的方式做加速

所谓CDQ分治,利用了这道题目中的——想求 f [ i ] ,需要加进所有 f [ j j < i ] f [ i ] 的贡献
我们把区间分成了以下两块 [ l , m i d ] , [ m i d + 1 , r ]
我们先处理 [ l , m i d ] , 然后把 [ l , m i d ] 的影响向 [ m i d + 1 , r ] 转移, 再去处理 [ m i d + 1 , r ]

那么我们怎么用FFT维护 [ m i d + 1 , r ] 的答案呢?
我们回忆一下FFT的原理(多项式乘法)

                           4   3   2   1
                           4   3   2   1
------------------------------------------
                          1*4 1*3 1*2 1*1
                      2*4 2*3 2*2 2*1
                  3*4 3*3 3*2 3*1
              4*4 4*3 4*2 4*1

一个多项式为 f [ l ] , f [ l + 1 ] , f [ l + 2 ] , f [ l + 3 ] , . . . , f [ m i d 1 ] , f [ m i d ]
另一个多项式为 a [ 0 ] , a [ 1 ] , a [ 2 ] , . . . , a [ r l ]
我们乘起来就变成了——
( f [ l ] a [ 0 ] )

( f [ l ] a [ m i d + 1 l ] + f [ l + 1 ] a [ m i d + 1 l 1 ] + . . . ) = f [ m i d + 1 ]

即,我们需要 f [ l ] f [ m i d ] 中的每个数
同时需要 a [ 0 ] a [ r l ] 中的每个数
然后FFT就可以得到结果

tip

按理来说,FFT的序列长度应该是 2 ( r l + 1 ) ,但是这道题中 r l + 1 就够了

一开始写CDQ中的初始化写崩了(初始化少了),调了好久QwQ

#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>

using namespace std;

const int p=313;
const int N=400010;
const double Pi=acos(-1.0);

struct node{
    double x,y;
    node(double xx=0,double yy=0) {
        x=xx;y=yy;
    }
};
node A[N],B[N];

node operator +(const node &A,const node &B) {return node(A.x+B.x,A.y+B.y);}
node operator -(const node &A,const node &B) {return node(A.x-B.x,A.y-B.y);}
node operator *(const node &A,const node &B) {return node(A.x*B.x-A.y*B.y,A.x*B.y+A.y*B.x);}

int n,f[N],a[N];

void FFT(int n,node *a,int opt) {
    int i,j=0,k;
    for (i=0;i<n;i++) {
        if (i>j) swap(a[i],a[j]);
        for (int l=n>>1;(j^=l)<l;l>>=1);
    }
    for (int i=2;i<=n;i<<=1) {
        int m=i>>1;
        node wn(cos(2.0*opt*Pi/i),sin(2.0*opt*Pi/i));
        for (int j=0;j<n;j+=i) {
            node w(1,0);
            for (int k=0;k<m;k++,w=w*wn) {
                node z=a[j+m+k]*w;
                a[j+m+k]=a[j+k]-z;
                a[j+k]=a[j+k]+z;
            }
        }
    }
}

void CDQ(int l,int r) {
    if (l==r) {
        f[l]=(f[l]+a[l])%p;
        return;
    }
    int mid=(l+r)>>1;
    CDQ(l,mid);

    int fn=1;
    while (fn<=r-l+1) fn<<=1;

    for (int i=l;i<mid+1;i++) A[i-l]=node(f[i],0);
    for (int i=mid-l+1;i<fn;i++) A[i]=node(0,0);     //一定要预处理到fn
    for (int i=0;i<r-l+1;i++) B[i]=node(a[i],0);
    for (int i=r-l+1;i<fn;i++) B[i]=node(0,0);
    FFT(fn,A,1); FFT(fn,B,1);
    for (int i=0;i<fn;i++) A[i]=A[i]*B[i];
    FFT(fn,A,-1);
    for (int i=0;i<fn;i++) A[i].x/=fn;
    for (int i=mid+1;i<=r;i++) f[i]=f[i]+(int)(A[i-l].x+0.5)%p;

    CDQ(mid+1,r);
}

int main()
{
    while (scanf("%d",&n)!=EOF&&n) {
        memset(f,0,sizeof(f));
        memset(a,0,sizeof(a));

        for (int i=1;i<=n;i++) scanf("%d",&a[i]),a[i]%=p;
        CDQ(1,n);
        printf("%d\n",f[n]%p);
    }
    return 0;
}

这道题当然还有简单一点的生成函数解法
F f 的生成函数( x i 的系数为 f [ i ] ), A a 的生成函数( x i 的系数为 a i

F = F A + 1

+1是因为 f [ 0 ] = 1 , a [ 0 ] = 0 (其实我也不是很明白)

F = 1 1 A

多项式 1 A 求逆即可

找到一篇很好的博客

tip

式子中的 1 A 中的1其实就是常数项

多项式求逆得到的就是系数表达式

一开始我直接%313,但是总是得不到正确答案
(后来问了一下舒老师:NTT需要一些特殊的模数
因为模313,最终各个系数 <= n(m-1)*(m-1) = 9734400000, 故选了一个大素数206158430209(原根为22)
但是由于模数太大,乘一下就会爆ll,很难过

所以网上此解法的AC代码都是FFT。。。(然而我并不是很清楚FFT的求逆。。。先坑着吧)

//NTT未AC代码
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#define ll long long

using namespace std;

const int N=400010;
const ll p=998244353;
const ll o=3;
int n;
ll a[N],b[N],c[N];

ll KSM(ll a,ll b) {
    ll t=1;
    while (b) {
        if (b&1) t=(t*a)%p;
        b>>=1;
        a=(a*a)%p;
    }
    return t%p;
}

void NTT(int n,ll *a,int opt) {
    int i,j=0,k;
    for (i=0;i<n;i++) {
        if (i>j) swap(a[i],a[j]);
        for (int l=n>>1;(j^=l)<l;l>>=1);
    }
    for (i=2;i<=n;i<<=1) {
        int m=i>>1;
        ll wn=KSM(10,(p-1)/i);
        for (j=0;j<n;j+=i) {
            ll w=1;
            for (k=0;k<m;k++,w=(w*wn)%p) {
                ll z=(a[j+k+m]*w)%p;
                a[j+k+m]=(a[j+k]-z+p)%p;
                a[j+k]=(a[j+k]+z)%p;
            }
        }
    }
    if (opt==-1) reverse(a+1,a+n);
}

void inv(int n,ll *a,ll *b,ll *c) {
    if (n==1) {
        b[0]=KSM(a[0],p-2);
        return;
    }
    inv(n>>1,a,b,c);
    int k=n<<1;
    for (int i=0;i<n;i++) c[i]=a[i];
    for (int i=n;i<k;i++) c[i]=0;
    NTT(k,c,1); NTT(k,b,1);
    for (int i=0;i<k;i++) b[i]=(2-b[i]*c[i]%p+p)%p*b[i]%p;
    NTT(k,b,-1);
    ll _inv=KSM(k,p-2);
    for (int i=0;i<n;i++) b[i]=(b[i]*_inv)%p;
    for (int i=n;i<k;i++) b[i]=0; 
}

int main()
{
    while (scanf("%d",&n)!=EOF&&n) {
        memset(a,0,sizeof(a));
        memset(b,0,sizeof(b));

        for (int i=1;i<=n;i++) {
            ll x; scanf("%lld",&x);
            a[i]=((-x)%p+p)%p;
        }
        a[0]=1;
        int fn=1;
        while (fn<=n) fn<<=1;
        inv(fn,a,b,c);
        printf("%lld\n",b[n]%313);
    }
    return 0;
}