[EZOJ1007] 神奇的三角形

时间:2023-03-09 15:12:54
[EZOJ1007] 神奇的三角形

Description

求 \(\sum\limits_{i=0}^{n-1}\sum\limits_{j=0}^{i}C(i,j)\times (j+1)^m\operatorname{mod}998244353\)

\(n\leq10^9,m\leq 100000\)

Solution

傻逼推式子题...

首先 \(\sum\limits_{i=0}^nC(i,j)=C(n+1,j+1)\),所以原式可化为

\[\sum_{i=1}^nC(n,i)\times i^m
\]

斯特林展开 \(n^k=\sum\limits_{i=0}^nS(k,i)\times i!\times C(n,i)\)

\[\sum_{i=1}^nC(n,i)\times \sum_{k=0}^mC(i,k)\times k!\times S(m,k)
\]

因为 \(S(i,j)=0(i<j)\),所以将 \(k\) 的枚举提前

\[\sum_{k=0}^mS(m,k)\times k!\times \sum_{i=1}^nC(n,i)\times C(i,k)
\]

观察 \(\sum\limits_{i=1}^nC(n,i)\times C(i,k)\) 的组合意义,即先从 \(n\) 个球中选 \(i\) 个,再从 \(i\) 个球中选 \(k\) 个。这和从 \(n\) 个球中先取 \(k\) 个,剩下的球随意拿是等价的。所以 \(\sum\limits_{i=1}^nC(n,i)\times C(i,k)=C(n,k)\times 2^{n-k}\)

\[\sum_{k=0}^mS(m,k)\times k!\times C(n,k)\times 2^{n-k}
\]

将组合数拆开

\[\sum_{k=0}^mS(m,k)\times \frac{n!\times 2^{n-k}}{(n-k)!}
\]

这是个卷积的形式,那么就先 \(NTT\) 一遍求出第二类斯特林数,再 \(NTT\) 求答案就行了。

因为 \(n\) 很大但是 \(k\) 很小,所以 \(\frac{n!}{(n-k)!}\) 是可以算的,数组下标再平移一下就好了。

Code

#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
#define inv(x) ksm(x,mod-2)
const int N=4e5+5;
const int mod=998244353; int fac[N],ifac[N];
int rev[N],a[N],b[N];
int n,m,lim,c[N],d[N]; int ksm(int a,int b,int ans=1){
while(b){
if(b&1) ans=1ll*ans*a%mod;
a=1ll*a*a%mod;b>>=1;
} return ans;
} int getint(){
int X=0,w=0;char ch=getchar();
while(!isdigit(ch))w|=ch=='-',ch=getchar();
while( isdigit(ch))X=X*10+ch-48,ch=getchar();
if(w) return -X;return X;
} void ntt(int *f,int opt){
for(int i=0;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
for(int mid=1;mid<lim;mid<<=1){
int tmp=ksm(3,(mod-1)/(mid<<1));
if(opt<0) tmp=inv(tmp);
for(int R=mid<<1,j=0;j<lim;j+=R){
int w=1;
for(int k=0;k<mid;k++,w=1ll*w*tmp%mod){
int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
}
}
} if(opt<0){
for(int in=inv(lim),i=0;i<lim;i++)
f[i]=1ll*f[i]*in%mod;
}
} void mul(int *a,int *b){
ntt(a,1),ntt(b,1);
for(int i=0;i<lim;i++) a[i]=1ll*a[i]*b[i]%mod;
ntt(a,-1);
} signed main(){
n=getint(),m=getint();
fac[0]=ifac[0]=1;
for(int i=1;i<=m;i++) fac[i]=1ll*fac[i-1]*i%mod;
ifac[m]=inv(fac[m]);
for(int i=m-1;i;i--) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
if(n<m){
int ans=0;
for(int i=1;i<=n;i++)
ans=(ans+1ll*fac[n]%mod*ifac[i]%mod*ifac[n-i]%mod*ksm(i,m)%mod)%mod;
printf("%d\n",ans);return 0;
}
lim=1;while(lim<=m+m) lim<<=1;
for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
for(int i=0;i<=m;i++){
a[i]=1ll*(i&1?mod-1:1)*ifac[i]%mod;
b[i]=1ll*ksm(i,m)*ifac[i]%mod;
} mul(a,b);int now=1;
for(int i=n;i>=n-m;i--){
c[i-n+m]=1ll*ksm(2,i)*now%mod;
now=1ll*now*i%mod;
} mul(a,c);
printf("%d\n",a[m]);
return 0;
}