bzoj 3992 [SDOI2015]序列统计——NTT(循环卷积&&快速幂)

时间:2023-03-09 19:26:43
bzoj 3992 [SDOI2015]序列统计——NTT(循环卷积&&快速幂)

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3992

有转移次数、模M余数、方案数三个值,一看就是系数的地方放一个值、指数的地方放一个值、做卷积的次数表示一个值(应该是表示转移次数)。

可以余数和方案数都要求相乘,指数只能相加,怎么办?

然后看题解,原来可以用M的原根的幂来表示余数那个信息!因为原根的几次幂和%M剩余类可以一一对应(除了%M==0!!!),所以用原根的幂表示%M余几,两个余数相乘就变成原根的指数相加了!把该余数对应的原根的指数放在多项式指数的位置,就可以NTT啦!

原根是 1~P-1 次幂的值%P各不相同的,所以可以用 0次项~M-2次项 或者 1次项~M-1 次项来表示。

n的范围要求快速幂。但不是把点值拿出来之后对点值快速幂一番再用点值还原回系数,因为每次卷积那个多项式的长度都要翻倍,所以最后n个点的个数就不够了。

所以要快速幂中每次卷积了一下后把它翻倍的长度手动循环一番变回原长M。这样就行啦!

注意数据范围!!!求的那个 x 不能为0,而给出的元素可以为0!而原根的那些幂都不会为0!(仔细一想,只有0或M的倍数才会%M==0)考虑到那个 x 不会为0、而数列中放入一个0的话值就变成0了,所以给出0以后要认为没有那个元素!!!!!

快速幂时,ans的初值应该像1一样;也就是一个多项式卷积它之后还是该多项式本身。想一想,就是在0次项赋1、其他项赋0即可。

发现>(M<<1)的项的值一定是0;所以循环的时候可以直接减掉1个(M-1)而不用模什么的。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=; const ll mod=;
int n,m,M,pn,s[N],zb[N],pri[N],len,r[N<<];
int a[N<<],ans[N<<];
int rdn()
{
int ret=;bool fx=;char ch=getchar();
while(ch>''||ch<''){if(ch=='-')fx=;ch=getchar();}
while(ch>=''&&ch<='') ret=(ret<<)+(ret<<)+ch-'',ch=getchar();
return fx?ret:-ret;
}
void upd(int &x,int md){x>=md?x-=md:;}
int pw(int x,int k,int md)
{int ret=;while(k){if(k&)ret=(ll)ret*x%md;x=(ll)x*x%md;k>>=;}return ret;}
int gtrt()
{
int k=M-,tot=;
for(int i=;i*i<=k;i++)
if(k%i==){pri[++tot]=i;while(k%i==)k/=i;}
if(k>)pri[++tot]=k;
for(int g=;;g++)
{
int i;
for(i=;i<=tot;i++)
if(pw(g,(M-)/pri[i],M)==)break;
if(i>tot)return g;
}
}
void ntt(int *a,bool fx)
{
for(int i=;i<len;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(int R=;R<=len;R<<=)
{
int m=R>>;
int Wn=pw(,(mod-)/R,mod);
fx?Wn=pw(Wn,mod-,mod):;
for(int i=;i<len;i+=R)
for(int j=,w=;j<m;j++,w=(ll)w*Wn%mod)
{
int tmp=(ll)w*a[i+m+j]%mod;
a[i+m+j]=a[i+j]+mod-tmp; upd(a[i+m+j],mod);
a[i+j]=a[i+j]+tmp; upd(a[i+j],mod);
}
}
if(!fx)return; int inv=pw(len,mod-,mod);
for(int i=;i<len;i++)a[i]=(ll)a[i]*inv%mod;
}
int main()
{
n=rdn(); M=rdn(); pn=rdn(); m=rdn();
for(int i=;i<=m;i++)s[i]=rdn();
int rt=gtrt();
for(int i=,k=rt;i<M;i++,k=k*rt%M) zb[k]=i;
len=;
for(;len<=M<<;len<<=);
for(int i=;i<len;i++)r[i]=(r[i>>]>>)+((i&)?len>>:); for(int i=;i<=m;i++)if(s[i])a[zb[s[i]]]=;////if
ans[]=;///
while(n)
{
ntt(a,);
if(n&)
{
ntt(ans,);
for(int i=;i<len;i++)ans[i]=(ll)ans[i]*a[i]%mod;
ntt(ans,);
for(int i=;i<M;i++)//pos which >(M<<1) surely have no value
ans[i]+=ans[i+M-],ans[i+M-]=,upd(ans[i],mod);
}
for(int i=;i<len;i++)a[i]=(ll)a[i]*a[i]%mod;
ntt(a,);
for(int i=;i<M;i++)
a[i]+=a[i+M-],a[i+M-]=,upd(a[i],mod);
n>>=;
}
printf("%d\n",ans[zb[pn]]);
return ;
}