NTT算法小结

时间:2023-03-09 00:15:08
NTT算法小结

从理论上说,经过人们优化的FFT已经十分优秀,能够处理大部分的多项式乘法,但是有的时候仍然会出现下面的情况:

1)常数仍然比较大

2)在进行与整数有关的FFT时,发现得到的结果是一堆诡异的数,你需要不停的和精度搏斗

那么在这时,你就需要学会快速数论变换(NTT)

前置芝士

快速傅里叶变换

你可以上网百度,或者看我的博客

阶与原根

我们由欧拉定理可以知道,对于任意的正整数\(a、m\),如果满足\(gcd(a,m)=1\),就有\(a^{\varphi(m)}\equiv 1(mod\ m)\)

但我们发现,还有一些数满足\(a^p\equiv 1(mod\ m)\)且\(p< \varphi(m)\),因此人们定义了阶

设\(m>1\),且\(gcd(a,m)=1\),则满足\(a^p\equiv 1(mod\ m)\)的最小正整数\(p\)成为\(a\)对模\(m\)的阶,记作\(\delta_m(a)\)

于是就会有\(\delta_m(a)|p\),充分性很显然,我们证一下必要性

我们设\(p=\delta_m(a)·q+r\)(其中\(0\leq r < \delta_m(a)\))

那么\(a^p=a^{\delta_m(a)·q}·a^r\equiv a^r\equiv 1(mod m)\)

由\(p\)是最小正整数知\(r=0\),所以\(\delta_m(a)|p\)

然后你就有了\(\delta_m(a)|\varphi(p)\)

好吧这个性质并没有什么用

由上面的欧拉定理,我们不难理解数学家们为什么搞出这么一个蛋疼的定义——原根

如果\(\delta_m(a)=\varphi(m)\),那么称\(a\)是模\(m\)的一个原根,为了下面表述的方便我们将它记作\(g\)

我们发现原根有一个这样的性质:\(g^0,g^1,g^2,\cdots,g^{\varphi(m)-1}\)构成了一个模\(m\)的完全剩余系

证明:考虑反证法,即假设c存在\(i,j\)(\(i>j\))满足\(g^i\equiv g^j(mod\ m)\)

​ 两边同时除以\(g^j\),有\(g^{i-j}\equiv 1\),而很明显\(i-j<\varphi(m)\),这与原根的定义相矛盾

​ 于是性质得证

性质验证

在FFT中,我们使用单位根的原因就是单位根满足的一些性质可以加速计算,如果原根也满足的话,那么我们在计算时可以直接替换

性质1

\(w_n^0,w_n^1,w_n^2,\cdots,w_n^{n-1}\)两两不同

这在上面已经得到了证明

性质2

\(w_{2n}^{2p}=w_n^p\)

如果\(w_{n}=g^p\),那么就应该有\(w_{2n}=g^{\frac{p}{2}}\),他们在乘上两倍的次幂之后值相等

性质3

\(w_{n}^{\frac{n}{2}+p}=-w_n^p\)

因为有\((g^{\frac{n}{2}})^2\equiv 1\),为了保持原根的定义,就会有\(g^{\frac{n}{2}}\equiv -1(mod \ n)\)

带回去运算即可

综上所述,原根满足原来的单位根所具有的性质,因此我们可以考虑用原根来代替单位根

实际运用

一点注意事项

1、原根的话在模数不确定的情况下需要自己求,不过如果模数是\(998244353\)或者\(1004535809\)的话,它们的原根是3

2、注意在IDFT的时候,原来直接除以的地方要换做求逆元

代码

#include<iostream>
#include<string>
#include<string.h>
#include<stdio.h>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#include<map>
using namespace std;
#define rep(i,a,b) for (i=a;i<=b;i++)
typedef long long ll;
#define maxd 998244353
const double pi=acos(-1.0);
#define int long long
ll n,m,a[5005000],b[5005000];
int lim=1,r[5005000]; int qpow(int x,int y)
{
int ans=1,sum=x;
while (y)
{
int tmp=y%2;y/=2;
if (tmp) ans=(1ll*ans*sum)%maxd;
sum=(1ll*sum*sum)%maxd;
}
return ans;
} void ntt(int lim,ll *a,int typ)
{
int i;
for (i=0;i<lim;i++)
if (i<r[i]) swap(a[i],a[r[i]]);
int mid;
for (mid=1;mid<lim;mid<<=1)
{
int gn=qpow(3,(maxd-1)/(mid<<1));
int sta,len=mid<<1,j;
for (sta=0;sta<lim;sta+=len)
{
int g=1;
for (j=0;j<mid;j++,g=(g*gn)%maxd)
{
int x1=a[j+sta],y1=(g*a[j+sta+mid])%maxd;
a[j+sta]=(x1+y1)%maxd;
a[j+sta+mid]=(x1-y1+maxd)%maxd;
}
}
}
if (typ==-1) reverse(&a[1],&a[lim]);
} int read()
{
int x=0,f=1;char ch=getchar();
while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
return x*f;
} signed main()
{
n=read();m=read();int i,cnt=0;
for (i=0;i<=n;i++) a[i]=read();
for (i=0;i<=m;i++) b[i]=read();
while (lim<=n+m) {lim<<=1;cnt++;}
for (i=0;i<=lim;i++)
r[i]=((r[i>>1]>>1)|((i&1)<<(cnt-1)));
ntt(lim,a,1);
ntt(lim,b,1);
for (i=0;i<=lim;i++) a[i]=(a[i]*b[i])%maxd;
ntt(lim,a,-1);
int tmp=qpow(lim,maxd-2);
for (i=0;i<=n+m;i++)
{
a[i]=(a[i]*tmp)%maxd;
printf("%lld ",a[i]);
}
return 0;
}