BZOJ3129 SDOI2013方程(容斥原理+扩展lucas)

时间:2021-11-02 11:30:49

  没有限制的话算一个组合数就好了。对于不小于某个数的限制可以直接减掉,而不大于某个数的限制很容易想到容斥,枚举哪些超过限制即可。

  一般情况下n、m、p都是1e9级别的组合数没办法算。不过可以发现模数已经被给出,并且这些模数的最大质因子幂都不是很大,那么扩展lucas就可以了。

#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
int read()
{
int x=,f=;char c=getchar();
while (c<''||c>'') {if (c=='-') f=-;c=getchar();}
while (c>=''&&c<='') x=(x<<)+(x<<)+(c^),c=getchar();
return x*f;
}
int T,P,n,n1,n2,m,ans,a[];
int p[],b[],c[],s[],t,f[][];
void inc(int &x,int y,int p){x+=y;if (x>=p) x-=p;}
void exgcd(int a,int b,int &x,int &y)
{
if (b==)
{
x=,y=;
return;
}
exgcd(b,a%b,x,y);
int t=x;x=y;y=t-a/b*x;
}
int inv(int a,int p)
{
int x,y;
exgcd(a,p,x,y);
return (x+p)%p;
}
int ksm(int a,int k,int p)
{
if (k==) return ;
int tmp=ksm(a,k>>,p);
if (k&) return 1ll*tmp*tmp%p*a%p;
else return 1ll*tmp*tmp%p;
}
int fac(int n,int i)
{
if (n==) return ;
return 1ll*fac(n/p[i],i)*ksm(f[i][c[i]],n/c[i],c[i])%c[i]*f[i][n%c[i]]%c[i];
}
int C(int n,int m,int i)
{
int s=;
for (long long j=p[i];j<=n;j*=p[i]) s+=n/j;
for (long long j=p[i];j<=m;j*=p[i]) s-=m/j;
for (long long j=p[i];j<=n-m;j*=p[i]) s-=(n-m)/j;
if (s>=b[i]) return ;
return 1ll*fac(n,i)*inv(fac(m,i),c[i])%c[i]*inv(fac(n-m,i),c[i])%c[i]*ksm(p[i],s,c[i])%c[i];
}
int crt()
{
int ans=;
for (int i=;i<=t;i++)
inc(ans,1ll*s[i]*(P/c[i])%P*inv(P/c[i],c[i])%P,P);
return ans;
}
int calc(int n,int m)
{
if (n<m) return ;
for (int i=;i<=t;i++)
s[i]=C(n,m,i);
return crt();
}
void dfs(int k,int s,int m)
{
if (k>n1)
{
if (s&) inc(ans,(P-calc(m-,n-))%P,P);
else inc(ans,calc(m-,n-),P);
return;
}
dfs(k+,s+,m-a[k]);
dfs(k+,s,m);
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("bzoj3129.in","r",stdin);
freopen("bzoj3129.out","w",stdout);
const char LL[]="%I64d";
#else
const char LL[]="%lld";
#endif
T=read(),P=read();
if (P==) t=,p[]=,b[]=,c[]=;
else if (P==)
{
t=;
p[]=,p[]=,p[]=,p[]=,p[]=;
b[]=,b[]=,b[]=,b[]=,b[]=;
c[]=,c[]=,c[]=,c[]=,c[]=;
}
else
{
t=;
p[]=,p[]=,p[]=;
b[]=,b[]=,b[]=;
c[]=,c[]=,c[]=;
}
for (int i=;i<=t;i++)
{
f[i][]=;
for (int j=;j<=c[i];j++)
if (j%p[i]==) f[i][j]=f[i][j-];
else f[i][j]=1ll*f[i][j-]*j%c[i];
}
while (T--)
{
n=read(),n1=read(),n2=read(),m=read();
for (int i=;i<=n1;i++) a[i]=read();
for (int i=;i<=n2;i++) m-=read()-;
ans=;
if (m>) dfs(,,m);
cout<<ans<<endl;
}
return ;
}