Codeforces 600E. Lomsat gelral(Dsu on tree学习)

时间:2023-03-09 08:51:30
Codeforces 600E. Lomsat gelral(Dsu on tree学习)

题目链接:http://codeforces.com/problemset/problem/600/E


n个点的有根树,以1为根,每个点有一种颜色。我们称一种颜色占领了一个子树当且仅当没有其他颜色在这个子树中出现得比它多。求占领每个子树的所有颜色之和

我们都知道可以$BST$启发式合并从而完美${O(nlogn^{2})}$,这太丑陋了。

那么$Dsu~~on~~tree$是在干啥呢?

找出树中每一个节点的重儿子,统计答案的时候优先进入每一个点的所有轻儿子,之后再进入重儿子,目的是保留重儿子所在子树的信息。

处理完当前点的所有儿子的子树之后开始处理自己。

先统计以当前点为根的子树不经过重儿子的所有点的影响${O(size[x]-size[hson[x]])}$

如果这个点是轻儿子则需要暴力删除这棵子树所带来的影响${O(size[x])}$,这也正是先进入轻儿子的原因,可以保留重儿子的信息。

考虑复杂度为什么正确:

  不妨想想一个点被反复计算了多少次?是不是${O(这个点到根的轻边条树)}$,这个东西是${O(logn)}$级别的。

最终复杂度:${O(nlogn)}$


 #include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<cstdlib>
#include<cmath>
#include<cstring>
using namespace std;
#define maxn 300100
#define llg long long
#define yyj(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout);
llg n,m,hson[maxn],size[maxn],ans[maxn],c[maxn],cnt[maxn],sum,mx,S;
vector<llg>a[maxn]; inline llg getint()
{
llg w=,q=; char c=getchar();
while((c<'' || c>'') && c!='-') c=getchar(); if(c=='-') q=,c=getchar();
while (c>='' && c<='') w=w*+c-'', c=getchar(); return q ? -w : w;
} void init()
{
llg x,y;
cin>>n;
for (llg i=;i<=n;i++) c[i]=getint();
for (llg i=;i<n;i++)
{
x=getint(),y=getint();
a[x].push_back(y),a[y].push_back(x);
}
} void find_hson(llg x,llg fa)
{
size[x]=;
llg w=a[x].size(),v;
for (llg i=;i<w;i++)
{
v=a[x][i];
if (v==fa) continue;
find_hson(v,x);
size[x]+=size[v];
if (size[v]>size[hson[x]]) hson[x]=v;
}
} void calc(llg x,llg fa,llg val)
{
cnt[c[x]]+=val;
if (cnt[c[x]]>mx) sum=c[x],mx=cnt[c[x]];
else
{
if (cnt[c[x]]==mx) sum+=c[x];
}
llg w=a[x].size(),v;
for (llg i=;i<w;i++)
{
v=a[x][i];
if (v==fa || v==S) continue;
calc(v,x,val);
}
} void dfs(llg x,llg fa,llg t)
{
llg w=a[x].size(),v;
for (llg i=;i<w;i++)
{
v=a[x][i];
if (v==fa || v==hson[x]) continue;
dfs(v,x,-);
}
if (hson[x]) dfs(hson[x],x,),S=hson[x];
calc(x,fa,); S=;
ans[x]=sum;
if (t==-) calc(x,fa,-),mx=sum=;
} int main()
{
yyj("tree");
init();
find_hson(,);
dfs(,,);
for (llg i=;i<=n;i++) printf("%lld ",ans[i]);
return ;
}