CF438E The Child and Binary Tree

时间:2020-12-20 20:41:51

思路

设F(x)的第x项系数为权值和为x的答案

题目中要求权值必须在集合中出现,这个不好处理,考虑再设一个C,C的第x项如果是1代表x出现在值域里,如果是0,代表x没有出现在值域里,然后由于二叉树可以分别对左右子树处理,所以

\[F_k=\sum_{i=1}^k C_i \sum_{j=0}^{k-i}F_j F_{k-i-j}
\]

\[F_0=1
\]

可以看出这是一个卷积的形式

\[F=1+C*F*F
\]

然后解一个一元二次方程

\[F=\frac{1 \pm \sqrt{1-4C}}{2C}=\frac{2}{1 \pm \sqrt{1-4C}}
\]

因为\(C_0=0\),\(F_0=1\),所以去掉负号

然后上多项式求逆和多项式开方即可

代码

#include <cstdio>
#include <algorithm>
#include <cstring>
#define int long long
using namespace std;
const int MAXN = 300000;
const int MOD = 998244353;
const int G = 3;
const int invG = 332748118;
struct Poly{
int t,data[MAXN];
Poly(){};
};
int pow(int a,int b){
int ans=1;
while(b){
if(b&1)
ans=(1LL*ans*a)%MOD;
a=(1LL*a*a)%MOD;
b>>=1;
}
return ans;
}
void NTT(Poly &a,int opt,int n){
int lim=0;
while((1<<(lim))<n){
lim++;
}
n=(1<<lim);
for(int i=0;i<n;i++){
int t=0;
for(int j=0;j<lim;j++)
if((i>>j)&1)
t|=(1<<(lim-j-1));
if(i<t)
swap(a.data[i],a.data[t]);
}
for(int i=2;i<=n;i<<=1){
int len=i/2;
int tmp=pow((opt)?G:invG,(MOD-1)/i);
for(int j=0;j<n;j+=i){
int arr=1;
for(int k=j;k<j+len;k++){
int t=(1LL*a.data[k+len]*arr)%MOD;
a.data[k+len]=(1LL*a.data[k]-t+MOD)%MOD;
a.data[k]=(1LL*a.data[k]+t)%MOD;
arr=(1LL*arr*tmp)%MOD;
}
}
}
if(!opt){
int invN = pow(n,MOD-2);
for(int i=0;i<n;i++){
a.data[i]=(1LL*a.data[i]*invN)%MOD;
}
}
}
void save(Poly &a,int top){
for(int i=top+1;i<=a.t;i++)
a.data[i]=0;
a.t=top;
}
void mul(Poly &a,Poly b){//a=a*b
int num=a.t+b.t,lim=0;
while((1<<lim)<=(num+2))
lim++;
lim=(1<<lim);
NTT(a,1,lim);
NTT(b,1,lim);
for(int i=0;i<lim;i++)
a.data[i]=(1LL*a.data[i]*b.data[i])%MOD;
NTT(a,0,lim);
a.t=num;
for(int i=num+1;i<lim;i++)
a.data[i]=0;
}
void Inv(Poly a,Poly &b,int dep,int &midlen){
if(dep==1){
b.data[0]=pow(a.data[0],MOD-2);
b.t=dep-1;
return;
}
Inv(a,b,(dep+1)>>1,midlen);
static Poly tmp;
while((dep<<1)>midlen)
midlen<<=1;
for(int i=0;i<dep;i++)
tmp.data[i]=a.data[i];
for(int i=dep;i<midlen;i++)
tmp.data[i]=0;
NTT(tmp,1,midlen);
NTT(b,1,midlen);
for(int i=0;i<midlen;i++)
b.data[i]=1LL*b.data[i]*((2-1LL*tmp.data[i]*b.data[i])%MOD+MOD)%MOD;
NTT(b,0,midlen);
for(int i=dep;i<midlen;i++)
b.data[i]=0;
b.t=dep-1;
}
void sqrt(Poly a,Poly &b,int &midlen,int dep){
if(dep==1){
b.data[0]=1;
b.t=dep-1;
return;
}
sqrt(a,b,midlen,(dep+1)>>1);
while((dep<<1)>(midlen))
midlen<<=1;
static Poly tmp1,tmp2,tmp3;
tmp1=b;tmp3=b;
save(tmp1,dep-1);
save(tmp2,-1);
save(tmp3,dep-1);
int midlent=1;
for(int i=0;i<dep;i++)
tmp1.data[i]=(tmp1.data[i]*2)%MOD;
Inv(tmp1,tmp2,dep,midlent);
mul(b,tmp3);
for(int i=0;i<dep;i++)
b.data[i]=(b.data[i]+a.data[i])%MOD;
mul(b,tmp2);
for(int i=dep;i<midlen;i++)
b.data[i]=0;
b.t=dep-1;
}
Poly c,C;
int n,m;
signed main(){
scanf("%lld %lld",&n,&m);
c.t=100000;
for(int i=1;i<=n;i++){
int x;
scanf("%lld",&x);
c.data[x]=1;
}
for(int i=0;i<=c.t;i++)
c.data[i]=((-4LL*c.data[i])%MOD+MOD)%MOD;
c.data[0]=(1+c.data[0])%MOD;
int midlen=1;
sqrt(c,C,midlen,c.t+1);
C.data[0]=(1+C.data[0])%MOD;
midlen=1;
save(c,-1);
Inv(C,c,C.t+1,midlen);
for(int i=1;i<=m;i++)
printf("%lld\n",(c.data[i]*2)%MOD);
return 0;
}