LG4238 【【模板】多项式求逆】

时间:2023-03-09 01:59:39
LG4238 【【模板】多项式求逆】

前言

学习了Great_Influence的递推实现,我给大家说一下多项式求逆严格的边界条件,因为我发现改动一些很小的边界条件都会使程序出错。怎么办,背代码吗?背代码是不可能,这辈子都不会背代码的。理解了边界条件就不会出错了。

分析

理论基础
\[A \times B \equiv 1 \qquad (\mod{x^n})\]
\[A \times B' \equiv 1 \qquad (\mod{x^{\frac{n}{2}}})\]
\[A \times (B-B') \equiv 0 \qquad (\mod{x^{\frac{n}{2}}})\]
\[B-B' \equiv 0 \qquad(\mod{x^{\frac{n}{2}}})\]
\[(B-B')^2 \equiv 0 \qquad(\mod{x^n})\]
\[B^2-2BB'+B'^2 \equiv 0 \qquad (\mod{x^n})\]
\[A(B^2-2BB'+B'^2) \equiv 0 \qquad (\mod{x^n})\]
\[B-2B'+AB'^2 \equiv 0 \qquad (\mod{x^n})\]
\[B \equiv 2B'-AB'^2 \qquad (\mod{x^n})\]

根据这个式子就可以倍增求多项式逆元了。但是如何倍增呢?或许你已兴冲冲地打出了NTT的板子,然后感觉无从下手。

代码

前置是NTT

inline void FFT(int*t,int lim,int type)
{
    for(rg int i=0;i<lim;++i)
        if(i<rev[i])
            swap(t[i],t[rev[i]]);
    for(rg int i=1;i<lim;i<<=1)
    {
        int gn=qpow(g,(mod-1)/(i<<1));
        if(type==-1)
            gn=qpow(gn,mod-2);
        for(rg int j=0;j<lim;j+=(i<<1))
        {
            int gi=1;
            for(rg int k=0;k<i;++k,gi=(ll)gi*gn%mod)
            {
                int x=t[j+k],y=(ll)gi*t[j+i+k]%mod;
                t[j+k]=module(x,y);
                t[j+i+k]=module(x,mod-y);
            }
        }
    }
    if(type==-1)
    {
        int inv=qpow(lim,mod-2);
        for(rg int i=0;i<lim;++i)
            t[i]=(ll)t[i]*inv%mod;
    }
}
  • 为什么i从1开始小于lim?因为i是下层长度,这也是qpow里面i要乘2的原因。
  • 为什么j从0开始小于lim?因为j是当前层的起始位置,而数组是base 0的。
  • 为什么k从0开始小于i?因为k是当前合并区间的下标,为方便计算从0开始小于i。

首先要封装多项式柯西乘法(卷积),减少代码量以及出错的可能性。

int X[MAXN],Y[MAXN];

inline void mul(int*x,int*y,int lim)
{
    memset(X,0,sizeof(X));
    memset(Y,0,sizeof(Y));
    for(rg int i=0;i<(lim>>1);++i) // edit 2 lim>>1
        X[i]=x[i],Y[i]=y[i];
    FFT(X,lim,1);
    FFT(Y,lim,1);
    for(rg int i=0;i<lim;++i)
        X[i]=(ll)X[i]*Y[i]%mod;
    FFT(X,lim,-1);
    for(rg int i=0;i<lim;++i)
        x[i]=X[i];
}
  • 为什么要把数组复制到XY上?因为避免自乘出错,自己乘自己会导致NTT了两次。

  • 为什么第一层for循环边界为i<lim/2?因为乘出来度数是lim,乘之前度数是lim/2.

然后实现三个辅助函数,分别是 快速幂、 快速模加 和 计算rev 。

inline int qpow(int x,int k)
{
    int ans=1;
    while(k)
    {
        if(k&1)
            ans=(ll)ans*x%mod;
        x=(ll)x*x%mod,k>>=1;
    }
    return ans;
}

inline int module(int x,int y)
{
    x+=y;
    if(x>=mod)
        x-=mod;
    return x;
}

int rev[MAXN];

inline void calrev(int lim,int l)
{
    for(rg int i=1;i<lim;++i)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
}
  • 这三个东西的作用不用细说,主要是计算rev数组下表i可以从1开始的原因是rev[0]恒等于0.

最后就是主菜求逆元了,可以用滚动数组减少空间

int a[MAXN],b[2][MAXN];

int main()
{
//  freopen(".in","r",stdin);
//  freopen(".out","w",stdout);
    read(n);
    --n;
    for(rg int i=0;i<=n;++i)
        read(a[i]);
    int cur=0;
    b[cur][0]=qpow(a[0],mod-2);
    int bas=1,lim=2,len=1; // bas:下层长度(当前计算的) lim:上层长度(计算出来的) len:log_2(lim)
    calrev(lim,len);
    while(bas<=(n<<1)) // edit 1 <=
    {
        cur^=1;
        memset(b[cur],0,sizeof(b[cur]));
        for(int i=0;i<bas;++i)
            b[cur][i]=module(b[cur^1][i]<<1,0);
        mul(b[cur^1],b[cur^1],lim);
        mul(b[cur^1],a,lim);
        for(int i=0;i<bas;++i)
            b[cur][i]=module(b[cur][i],mod-b[cur^1][i]);
        bas<<=1,lim<<=1,++len;
        if(bas<=(n<<1))
            calrev(lim,len);
    }
    for(rg int i=0;i<=n;++i)
        printf("%d ",b[cur][i]);
//  fclose(stdin);
//  fclose(stdout);
    return 0;
}
  • 由于这道题就是裸的求逆元,所以我就没封装求逆。其实封装也很简单,加个void inv(int*a,int**b,int&cur,int n)就行了cur要引用是因为需要知道算出来的结果是b数组的哪个
  • 为什么第一个for循环边界是bas,第二个是lim?因为长度倍增了。另外mul里面也得用lim也是这个原因。
  • 为什么bas<=2*n?因为...当前层的长度得覆盖、大于n...吗?其实不是。

为什么bas<=2*n?

考虑我们的数组范围。长度看似倍增了,实则不然。数组下标为0~bas-1,代表\(\sum_{i=0}^{bas-1}a_ix^i\)的各项系数,然而自乘之后最高项次数应变为2(bas-1)而不是程序里面认为的2bas-1,所以我们求出来的多项式其实有虚假成分。怎么处理?程序里面是不好更改的,没必要为此增加代码量和出错性。那么求大一点就好了,比如求到2*n,这样虚假成分不会影响到最终答案。