51nod 1387 移数字

时间:2023-03-09 02:07:24
51nod 1387 移数字

任意门

  回来拉模版的时候意外发现这个题还没写题解,所以就随便补点吧。

  题意其实就是要你求n的阶乘在模意义下的值。

  首先找出来一个最大的$m$满足$m^2<=n$,对于大于$m^2$部分的数我们直接暴力求就行了,问题是求$m^2$以内的答案。

  先构造一个多项式$f(x)=(x+1)(x+2)(x+3)……(x+m)$,然后求它在$x=0、x=m……x=m(m-1)$位置的值,然后求个值全部乘起来就行了。

  

  稍微说下怎么做多点求值,构造两个多项式

  $$G_1(x)=(x-x_1)(x-x_2)……(x-x_{\left\lfloor\frac{m}{2}\right\rfloor})=x(x-m)……(x-\left\lfloor\frac{m}{2}\right\rfloor m)$$

  $$G_2(x)=(x-x_{\left\lfloor\frac{m}{2}\right\rfloor+1})……(x-x_m)=(x-\left\lfloor\frac{m}{2}\right\rfloor m-m)……(x-m^2+m)$$

  然后拿$f(x)$对$G_1(x)$取模得到一个$\left\lfloor\frac{m}{2}\right\rfloor$次的多项式,这个多项式在$x_1、x_2……x_{\left\lfloor\frac{m}{2}\right\rfloor}$位置的值跟$f(x)$是一样的(这是因为构造出来的式子在这些位置都等于0,而我们可以把多项式除法看成很多次减法,所以这个值不会变),后半部分同理用$G_2(x)$处理,这个时候问题规模就减半了,由此递归即可。

  题目最后复杂度是$O(\sqrt{n}log^2\sqrt{n})$

  

  

#include<cstdio>
#include<cstring>
#include<algorithm>
#define lp (p<<1)
#define rp ((p<<1)|1)
#define ll long long
#define MN 200200
using namespace std;
int read_p,read_ca;
inline int read(){
read_p=;read_ca=getchar();
while(read_ca<''||read_ca>'') read_ca=getchar();
while(read_ca>=''&&read_ca<='') read_p=read_p*+read_ca-,read_ca=getchar();
return read_p;
}
int _n,n,m,t,e[MN],_e[MN],Mmh=,D[MN],C_a[MN],C_b[MN],C_c[MN],N_c[MN],D_a[MN],D_b[MN],D_c[MN],tot,gg=,MMH[MN],L[MN*];
int rt[MN*],B[MN*],sz=;
int MOD=;
inline void M(int &x){while(x>=MOD)x-=MOD;}
inline int mi(int a,int b){
int mmh=;
while (b){
if (b&) mmh=1LL*mmh*a%MOD;
b>>=;a=1LL*a*a%MOD;
}
return mmh;
}
inline void inv(){
int base=mi(gg,(MOD-)/tot),_base=mi(base,MOD-);
e[]=_e[]=;
for (register int i=;i<=tot;i++) e[i]=1LL*e[i-]*base%MOD,_e[i]=1LL*_e[i-]*_base%MOD;
}
inline void NTT(int N,int a[],int w[]){
register int i,j,k,m,z;
for (i=j=;i<N;i++){
if (i>j) swap(a[i],a[j]);
for (k=N>>;(j^=k)<k;k>>=);
}
for (i=;i<=N;i<<=)
for (m=i>>,j=;j<N;j+=i)
for (k=;k<m;k++){
z=1LL*a[j+k+m]*w[tot/i*k]%MOD;
a[j+k+m]=a[j+k]>z?a[j+k]-z:MOD-z+a[j+k];
a[j+k]=a[j+k]-MOD+z;if (a[j+k]<) a[j+k]+=MOD;
}
}
inline void cc(int N,int a[],int b[],int c[]){
memcpy(C_a,a,N<<);memcpy(C_b,b,N<<);
NTT(N,C_a,e);NTT(N,C_b,e);
for (register int i=;i<N;i++) c[i]=1LL*C_a[i]*C_b[i]%MOD;
NTT(N,c,_e);
int w=mi(N,MOD-);
for (register int i=;i<N;i++) c[i]=1LL*c[i]*w%MOD;
}
inline void cc(int n,int m,int a[],int b[],int c[]){
int N;
for (N=;N<(n+m);N<<=);
memcpy(C_a,a,n<<);memcpy(C_b,b,m<<);
fill(C_a+n,C_a+N,);fill(C_b+m,C_b+N,);
NTT(N,C_a,e);NTT(N,C_b,e);
for (register int i=;i<N;i++) c[i]=1LL*C_a[i]*C_b[i]%MOD;
NTT(N,c,_e);
int w=mi(N,MOD-);
for (register int i=;i<N;i++) c[i]=1LL*c[i]*w%MOD;
}
inline void ny(int p,int a[],int b[]){
if (p==) b[]=mi(a[],MOD-);else{
ny((p+)>>,a,b);
int N=;
while (N<(p<<))N<<=;
copy(a,a+p,N_c);fill(N_c+p,N_c+N,);
NTT(N,N_c,e);NTT(N,b,e);
for (register int i=;i<N;i++) b[i]=(2LL-1LL*N_c[i]*b[i]%MOD+MOD)*b[i]%MOD;
NTT(N,b,_e);
int w=mi(N,MOD-);
for (register int i=;i<N;i++) b[i]=1LL*b[i]*w%MOD;
fill(b+p,b+N,);
}
}
inline void re_copy(int n,int a[],int b[]){for (register int i=;i<n;i++) b[i]=a[n-i-];}
inline void div(int n,int m,int a[],int b[],int d[],int r[]){
int N=,t=n-m+,i;
while (N<t<<)N<<=;
memset(D_a,,N<<);
memset(D_b,,N<<);
memset(D_c,,N<<);
memset(d,,N<<);
re_copy(m,b,D_b);
re_copy(n,a,D_a);
ny(t,D_b,D_c);
for (N=;N<(n<<);N<<=);
cc(n,t,D_a,D_c,D_b);
re_copy(t,D_b,d);
fill(d+t,d+N,);
cc(t,m,d,b,D_a);
for (i=;i<m;i++) r[i]=(1LL*a[i]-D_a[i]+MOD)%MOD;
fill(r+m,r+N,);
}
inline bool ju(int x){
int u=MOD-;
for (register int i=;i*i<=u;i++)
if (u%i==) if (mi(x,u/i)==) return ;
return ;
}
int mmh=;
inline void Mmhp(int p,int l,int r){
if (l==r){
L[p]=sz;
rt[sz]=l;
rt[sz+]=;
sz+=;
return;
}
int mid=l+r>>;
Mmhp(lp,l,mid);Mmhp(rp,mid+,r);
cc(mid-l+,r-mid+,rt+L[lp],rt+L[rp],rt+sz);
L[p]=sz;
sz+=r-l+;
}
inline void Mmhrt(int p,int l,int r){
if (l==r){
L[p]=sz;
rt[sz]=(MOD-1LL*m*l%MOD)%MOD;
rt[sz+]=;
sz+=;
return;
}
int mid=l+r>>;
Mmhrt(lp,l,mid);Mmhrt(rp,mid+,r);
cc(mid-l+,r-mid+,rt+L[lp],rt+L[rp],rt+sz);
L[p]=sz;
sz+=r-l+;
}
inline void _Mmh(int p,int l,int r,int fi,int LL){
div(LL,r-l+,B+fi,rt+L[p],D,B+sz);
int mid=l+r>>,s=sz;
sz+=r-l+;
if (l==r) mmh=1LL*B[s]*mmh%MOD;else _Mmh(lp,l,mid,s,r-l+),_Mmh(rp,mid+,r,s,r-l+);
}
int main(){
register int i;
n=read();
MOD=read();
if (n>=MOD) return printf("0\n"),;
while(ju(gg))gg++;
for (m=;(m+)*(m+)<=n;m++);
for (tot=;tot<((m+)<<);tot<<=);inv();
for (i=m*m+;i<=n;i++) mmh=1LL*mmh*i%MOD;
sz=;Mmhp(,,m);
for (i=L[];i<=L[]+m;i++) B[i-L[]]=rt[i];
sz=;Mmhrt(,,m-);
sz=m+;_Mmh(,,m-,,m+);
if (n&) mmh=1LL*mmh*mi(,MOD-)%MOD;
printf("%d\n",mmh);
}