AtCoder Regular Contest 086 E - Smuggling Marbles(树形迭屁)

时间:2023-03-08 22:09:36

  好强的题。

  方案不好算,改成算概率,注意因为是模意义下的概率所以直接乘法逆元就好不要傻傻地开double。

  设$f[i][d][0]$为第i个节点离d层的球球走到第i个点时第i个点没有球的概率, $f[i][d][1]$为有1个球的概率, $f[i][d][2]$为有2个球及以上的概率。

  我们可以把$f[i]$看成一个队列, 然后从儿子转移的时候, 就是把儿子的队列一个一个合并起来,最后在队列头加上一个$f[i][0]$, 并且把队列里的所有$f[i][0$~$d][2]$加上$f[i][0$~$d][0]$,并且$f[i][0$~$d][2]$变成0就好了。

  合并的时候转移为:

  $f[i][d][0]=f[i][d][0]*f[j][d][0]$

  $f[i][d][1]=f[i][d][1]*f[j][d][0]+f[i][d][0]*f[j][d][1]$

  $f[i][d][2]=f[i][d][0]*f[j][d][2]+f[i][d][1]*f[j][d][2]+f[i][d][1]*f[j][d][1]+f[i][d][2]*f[j][d][2]+f[i][d][2]*f[j][d][1]+f[i][d][2]*f[j][d][0]$

  复杂度为O(N),因为每层元素只加1,交集最多为N。

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<vector>
#define ll long long
#define MOD(x) ((x)>=mod?(x)-mod:(x))
using namespace std;
const int maxn=, mod=1e9+;
struct tjm{int too, pre;}e[maxn<<];
struct poi{int f[];};
int n, x, ans, tot, tott;
int last[maxn], root[maxn];
vector<poi>q[maxn];
inline void read(int &k)
{
int f=; k=; char c=getchar();
while(c<'' || c>'') c=='-' && (f=-), c=getchar();
while(c<='' && c>='') k=k*+c-'', c=getchar();
k*=f;
}
inline void add(int x, int y){e[++tot]=(tjm){y, last[x]}; last[x]=tot;}
inline int merge(int x, int y)
{
if(q[x].size()<q[y].size()) swap(x, y);
int nx=q[x].size()-, ny=q[y].size()-;
for(int i=;i<=ny;i++)
{
int sum0=, sum1=, sum2=;
sum0=1ll*q[x][nx-i].f[]*q[y][ny-i].f[]%mod;
sum1=(1ll*q[x][nx-i].f[]*q[y][ny-i].f[]+1ll*q[x][nx-i].f[]*q[y][ny-i].f[])%mod;
for(int j=;j<;j++)
for(int k=;j+k>=;k--)
sum2=(1ll*sum2+1ll*q[x][nx-i].f[j]*q[y][ny-i].f[k])%mod;
q[x][nx-i].f[]=sum0; q[x][nx-i].f[]=sum1; q[x][nx-i].f[]=sum2;
}
q[y].clear(); return x;
}
void dfs(int x, int fa)
{
if(!last[x]) root[x]=++tott; int dep=;
for(int i=last[x], too;i;i=e[i].pre)
if((too=e[i].too)!=fa)
{
dfs(too, x);
if(!root[x]) root[x]=root[too];
else dep=max(dep, (int)min(q[root[x]].size(), q[root[too]].size())), root[x]=merge(root[x], root[too]);
}
int nx=q[root[x]].size()-;
for(int i=;i<dep;i++)
q[root[x]][nx-i].f[]=MOD(q[root[x]][nx-i].f[]+q[root[x]][nx-i].f[]), q[root[x]][nx-i].f[]=;
poi tmp; tmp.f[]=tmp.f[]=(mod+)>>; tmp.f[]=; q[root[x]].push_back(tmp);
}
inline int power(int a, int b)
{
int ans=;
for(;b;b>>=, a=1ll*a*a%mod)
if(b&) ans=1ll*ans*a%mod;
return ans;
}
int main()
{
read(n);
for(int i=;i<=n;i++) read(x), add(x, i);
dfs(, -);
for(int i=;i<q[root[]].size();i++) ans=MOD(ans+q[root[]][i].f[]);
printf("%lld\n", 1ll*ans*power(, n+)%mod);
}