Codeforces 1097G Vladislav and a Great Legend [树形DP,斯特林数]

时间:2021-09-25 11:38:09

洛谷

Codeforces


这题真是妙的很。

通过看题解,终于知道了\(\sum_n f(n)^k​\)这种东西怎么算。

update:经过思考,我对这题有了更深的理解,现将更新内容放在原题解下方。


思路

发现\(\sum_S (f(S))^k\)这东西因为有个\(k\)次方,所以特别难算,考虑拆开:

\[x^k=\sum_{i=0}^k {x \choose i} \times i! \times S(k,i)
\]

其中\(S(n,m)\)是第二类斯特林数,即\(n\)个元素放进\(m​\)个集合里的方案数。(不会打那个大括号,会的请在评论区教我一下,谢谢)

于是式子化为:

\[\begin{align*}
ans&=\sum_S (f(S))^k\\
&=\sum_S \sum_{i=0}^k {f(S) \choose i} i!S(k,i)\\
&=\sum_{i=0}^k i!S(k,i)\sum_S {f(S) \choose i}
\end{align*}
\]

外面可以枚举\(i\),考虑里面那一堆东西怎么搞。

考虑它的组合意义:对于每一个点集组成的生成树,在树里面选出\(i\)条边涂色的方案数。

听着像是个背包,实际就是背包。

设\(dp_{x,i}\)表示\(x\)的子树内选出一棵有\(i\)条涂了色的边的生成树的方案数,答案在每一棵生成树中深度最小的点上统计,则有

\[dp_{x,i}+=dp_{x,j}dp_{v,i-j}
\]

也就是个背包。

但实际的DP比这恶心很多,有很多细节。

复杂度看似是\(O(nk^2)\),但据说是\(O(nk)\),不过我不会证。

感觉对这题的理解还不够透彻,但网上也没什么题解,就这样吧。


代码

#include<bits/stdc++.h>
namespace my_std{
using namespace std;
#define pii pair<int,int>
#define fir first
#define sec second
#define MP make_pair
#define rep(i,x,y) for (int i=(x);i<=(y);i++)
#define drep(i,x,y) for (int i=(x);i>=(y);i--)
#define go(x) for (int i=head[x];i;i=edge[i].nxt)
#define sz 101010
typedef long long ll;
typedef double db;
const ll mod=1e9+7;
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<typename T>inline T rnd(T l,T r) {return uniform_int_distribution<T>(l,r)(rng);}
template<typename T>inline void read(T& t)
{
t=0;char f=0,ch=getchar();double d=0.1;
while(ch>'9'||ch<'0') f|=(ch=='-'),ch=getchar();
while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
if(ch=='.'){ch=getchar();while(ch<='9'&&ch>='0') t+=d*(ch^48),d*=0.1,ch=getchar();}
t=(f?-t:t);
}
template<typename T,typename... Args>inline void read(T& t,Args&... args){read(t); read(args...);}
void file()
{
#ifndef ONLINE_JUDGE
freopen("a.txt","r",stdin);
#endif
}
// inline ll mul(ll a,ll b){ll d=(ll)(a*(double)b/mod+0.5);ll ret=a*b-d*mod;if (ret<0) ret+=mod;return ret;}
}
using namespace my_std; int n,K;
struct hh{int t,nxt;}edge[sz<<1];
int head[sz],ecnt;
void make_edge(int f,int t)
{
edge[++ecnt]=(hh){t,head[f]};
head[f]=ecnt;
edge[++ecnt]=(hh){f,head[t]};
head[t]=ecnt;
}
inline ll add(ll x,ll y){x+=y;if (x>=mod) x-=mod;return x;} ll dp[sz][233],f[sz],ans[233];
int size[sz];
#define v edge[i].t
void dfs(int x,int fa)
{
size[x]=1;dp[x][0]=2; // 选了自己和没选自己都合法
go(x) if (v!=fa)
{
dfs(v,x);
rep(i,0,K) f[i]=0;
rep(j,0,min(K,size[x])) // 一定要取min!!
rep(k,0,min(K-j,size[v]))
f[j+k]=add(f[j+k],dp[x][j]*dp[v][k]%mod);
rep(j,0,K) dp[x][j]=f[j];
rep(j,0,K) ans[j]=add(ans[j],mod-dp[v][j]); // 减去只在一棵子树有的,即不经过x的
size[x]+=size[v];
}
rep(i,0,K) ans[i]=add(ans[i],dp[x][i]);
drep(i,K,1) dp[x][i]=add(dp[x][i],dp[x][i-1]); // 考虑上通往祖先的那条边
dp[x][1]=add(dp[x][1],mod-1); // 减去子树为空但连了祖先边的一种情况
}
#undef v
void solve(){dfs(1,0);} ll S[233][233]; int main()
{
file();
int x,y;
read(n,K);
rep(i,1,n-1) read(x,y),make_edge(x,y);
S[1][1]=1;
rep(i,2,K)
rep(j,1,K)
S[i][j]=add(S[i-1][j-1],S[i-1][j]*j%mod);
int fac=1,Ans=0;
solve();
rep(i,1,K)
{
fac=1ll*fac*i%mod;
Ans=add(Ans,1ll*fac*S[K][i]%mod*ans[i]%mod);
}
cout<<Ans;
return 0;
}

更新

原来的那种DP方法有很多特殊情况需要考虑,感觉不是很好,这里给出另外一种。

设\(dp_{x,i}\)表示\(x\)的子树内选出一棵有\(i\)条涂了色的边的非空生成树的方案数,方程也差不多,但多了一些分类讨论。

考虑原来的\(x\)加入了新的子树\(v\),那么:

  1. 只有原来的生成树
  2. 只有新的子树(加上\(x\rightarrow v​\)这条边)
  3. 由原来的和新的共同组成(并加上\(x\rightarrow v\)这条边)

加上分类讨论和非空的限制之后,思路清晰了很多(至少我是这么认为的),也避免了很多特殊情况。

至于复杂度的问题,似乎取个min之后变成\(n^2\)(\(nk\))是很多树形DP的trick,看来是我孤陋寡闻了。

#include<bits/stdc++.h>
namespace my_std{
using namespace std;
#define pii pair<int,int>
#define fir first
#define sec second
#define MP make_pair
#define rep(i,x,y) for (int i=(x);i<=(y);i++)
#define drep(i,x,y) for (int i=(x);i>=(y);i--)
#define go(x) for (int i=head[x];i;i=edge[i].nxt)
#define sz 101010
typedef long long ll;
typedef double db;
const ll mod=1e9+7;
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<typename T>inline T rnd(T l,T r) {return uniform_int_distribution<T>(l,r)(rng);}
template<typename T>inline void read(T& t)
{
t=0;char f=0,ch=getchar();double d=0.1;
while(ch>'9'||ch<'0') f|=(ch=='-'),ch=getchar();
while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
if(ch=='.'){ch=getchar();while(ch<='9'&&ch>='0') t+=d*(ch^48),d*=0.1,ch=getchar();}
t=(f?-t:t);
}
template<typename T,typename... Args>inline void read(T& t,Args&... args){read(t); read(args...);}
void file()
{
#ifndef ONLINE_JUDGE
freopen("a.txt","r",stdin);
#endif
}
// inline ll mul(ll a,ll b){ll d=(ll)(a*(double)b/mod+0.5);ll ret=a*b-d*mod;if (ret<0) ret+=mod;return ret;}
}
using namespace my_std; int n,K;
struct hh{int t,nxt;}edge[sz<<1];
int head[sz],ecnt;
void make_edge(int f,int t)
{
edge[++ecnt]=(hh){t,head[f]};
head[f]=ecnt;
edge[++ecnt]=(hh){f,head[t]};
head[t]=ecnt;
}
inline ll add(ll x,ll y){x+=y;if (x>=mod) x-=mod;return x;} ll dp[sz][233],f[sz],ans[233];
int size[sz];
#define v edge[i].t
void dfs(int x,int fa)
{
size[x]=1;dp[x][0]=1;
go(x) if (v!=fa)
{
dfs(v,x);
rep(j,0,K) f[j]=dp[x][j];
rep(j,0,K) dp[x][j]=dp[x][j]; // 只有原来的
rep(j,1,min(K,size[v])) // 只有新的
dp[x][j]=add(dp[x][j],add(dp[v][j],dp[v][j-1]));
dp[x][0]=add(dp[x][0],dp[v][0]); // 只有新的
rep(j,0,min(K,size[x])) // 原来的和新的合并,顺便更新答案
rep(k,0,min(K-j,size[v]))
{
ll val=f[j]*dp[v][k]%mod;
dp[x][j+k]=add(dp[x][j+k],val);
ans[j+k]=add(ans[j+k],val);
dp[x][j+k+1]=add(dp[x][j+k+1],val);
ans[j+k+1]=add(ans[j+k+1],val);
}
size[x]+=size[v];
}
}
#undef v
void solve(){dfs(1,0);} ll S[233][233]; int main()
{
file();
int x,y;
read(n,K);
rep(i,1,n-1) read(x,y),make_edge(x,y);
S[1][1]=1;
rep(i,2,K)
rep(j,1,K)
S[i][j]=add(S[i-1][j-1],S[i-1][j]*j%mod);
int fac=1,Ans=0;
solve();
rep(i,1,K)
{
fac=1ll*fac*i%mod;
Ans=add(Ans,1ll*fac*S[K][i]%mod*ans[i]%mod);
}
cout<<Ans;
return 0;
}