【学习笔记】dsu on tree

时间:2023-03-09 05:15:31
【学习笔记】dsu on tree

我也不知道为啥这要起这名,完完全全没看到并查集的影子啊……

实际上原理就是一个树上的启发式合并。

特点是可以在$O(nlogn)$的时间复杂度内完成对无修改的子树的统计,复杂度优于莫队算法。

局限性也很明显:1.不能支持修改  2.只能支持子树统计,不能链上统计。(链上统计你不能直接树剖吗?)

那么它是怎么实现的呢?首先有一个例子:
树上每个节点都有一个颜色(那么一定是蓝色),

求每个节点的子树上有多少颜色为k的节点。(每个节点的k不一定相同)

$O(n^2)$的算法非常好想,以每个点为起点dfs一下就没了。

当然也有不那么暴力的做法,dfs序一下再主席树或者莫队随便搞搞也行。

那么我们先看看暴力是怎么做的。

每次统计x节点前,暴力将x的子树的贡献加入,统计结束后,再暴力删除贡献,消除影响。

我们发现这有很多无用的删除操作,考虑优化?

那么我们怎么用dsu上树优雅的解决这个问题呢?我们想到了树链剖分(轻重链剖分)。

具体的做法是,我们先统计一个点的轻儿子,再把它的影响消除。再统计重儿子,此时不必消除影响。

为了完成统计,最后再统计一遍轻儿子。

可以这么考虑:只有dfs到轻边时,才会将轻边的子树中合并到上一级的重链,

树链剖分将一棵树分割成了不超过$logn$条重链。
每一个节点最多向上合并$logn$次,单次修改复杂度$O(1)$。
所以整体复杂度是$O(nlogn)$的。

所以大概的模版是这样的:

 void dfs2(int u,int f,int k){
for(int i=head[u];i;i=G[i].next){
int v=G[i].v;if(v==f||v==wson[u])continue;
dfs2(v,u,);
}
if(wson[u])dfs(wson[u],u,),now=wson[u];
calc(u,f,);
now=;ans[u]=sum;
if(k==)calc(u,f,-),sumv=,maxv=;
}

下面是两道烂大街的例题:

1. Lomsat gelral(cf600E)

n个点的有根树,以1为根,每个点有一种颜色。我们称一种颜色占领了一个子树当且仅当没有其他颜色在这个子树中出现得比它多。求占领每个子树的所有颜色之和。

就是刚才的裸题啊。

 #include<bits/stdc++.h>
#define N 700010
using namespace std;
struct Edge{int u,v,next;}G[*N];
typedef long long ll;
int n,c[N],val[N],size[N],wson[N],fa[N];
ll ans[N];
int head[*N],tot=;
void addedge(int u,int v){
G[++tot].u=u;G[tot].v=v;G[tot].next=head[u];head[u]=tot;
G[++tot].u=v;G[tot].v=u;G[tot].next=head[v];head[v]=tot;
}
void dfs1(int u,int f=){
size[u]=;
for(int i=head[u];i;i=G[i].next){
int v=G[i].v;if(v==f)continue;
if(v==f)continue;
dfs1(v,u);
size[u]+=size[v];
if(size[v]>size[wson[u]])wson[u]=v;
}
}
bool vis[N];int maxv=;ll sum=;
void change(int u,int f,int k){
c[val[u]]+=k;
if(k>&&c[val[u]]>=maxv){
if(c[val[u]]>maxv)sum=,maxv=c[val[u]];
sum+=val[u];
}
for(int i=head[u];i;i=G[i].next){
int v=G[i].v;if(v==f||vis[v])continue;
change(v,u,k);
}
}
void dfs2(int u,int f=,bool used=){
for(int i=head[u];i;i=G[i].next){
int v=G[i].v;if(v==f||v==wson[u])continue;
dfs2(v,u);
}
if(wson[u])dfs2(wson[u],u,),vis[wson[u]]=;
change(u,f,);ans[u]=sum;
if(wson[u])vis[wson[u]]=;
if(!used)change(u,f,-),maxv=sum=;
}
inline int read(){
int f=,x=;char ch;
do{ch=getchar();if(ch=='-')f=-;}while(ch<''||ch>'');
do{x=x*+ch-'';ch=getchar();}while(ch>=''&&ch<='');
return f*x;
}
int main(){
n=read();
for(int i=;i<=n;i++)val[i]=read();
for(int i=;i<n;i++){
int u=read(),v=read();
addedge(u,v);
}
dfs1();dfs2();
for(int i=;i<=n;i++)printf("%I64d ",ans[i]);
}

当然这题也有不这么做的做法,随便从cf上粘了一个,大家自行意会……

 #include<bits/stdc++.h>
#define N 100005
using namespace std;
vector<int>a[N];map<int,int>S[N];
int F[N],id[N],c[N],n,i,x,y;
long long G[N],ans[N];
void work(int x,int y,int color){
if (y>F[x]) F[x]=y,G[x]=;
if (y==F[x]) G[x]+=color;
}
void Union(int &x,int y){
if (S[x].size()<S[y].size()) swap(x,y);
for (map<int,int>::iterator it=S[y].begin();it!=S[y].end();it++)
work(x,S[x][it->first]+=it->second,it->first);
}
void DFS(int x,int fa){
id[x]=x;S[x][c[x]]=;
F[x]=;G[x]=c[x];
for (int i=,y;i<a[x].size();i++)
if ((y=a[x][i])!=fa)
DFS(y,x),Union(id[x],id[y]);
ans[x]=G[id[x]];
}
int main(){
scanf("%d",&n);
for (i=;i<=n;i++)
scanf("%d",&c[i]);
for (i=;i<n;i++)
scanf("%d%d",&x,&y),
a[x].push_back(y),
a[y].push_back(x);
DFS(,);
for (i=;i<=n;i++)
printf("%I64d ",ans[i]);
}

例2: Arpa's letter-marked tree and Mehrdad's Dokhtar-kosh paths(CF741D)

这题也很显然,如果重排后能形成回文串,那么出现奇数次的字符应该少于2个(即最多1个)如果只有a~v的话考虑把每个字符当做一个二进制位,把一个点i到根的路径异或值记为s[i],那么我们就是要对于每个x在子树中找到a和b,使得s[a]^s[b]为0或2的次幂,且dep[a]+dep[b]-dep[lca]*2最大。

 #include<bits/stdc++.h>
#define N 500005
using namespace std;
int size[N],head[*N],tot=,wson[N],s[N],f[*N],ans[N],d[N],a[N];
char c[N];
int maxv,n,inf;
struct Edge{int u,v,next;}G[*N];
void addedge(int u,int v){
G[++tot].u=u;G[tot].v=v;G[tot].next=head[u];head[u]=tot;
//G[++tot].u=v;G[tot].v=u;G[tot].next=head[v];head[v]=tot;
}
void dfs1(int u,int fa){
size[u]=;d[u]=d[fa]+;
if(u!=)s[u]=s[fa]^(<<a[u]);
for(int i=head[u];i;i=G[i].next){
int v=G[i].v;
dfs1(v,u);
size[u]+=size[v];if(size[v]>size[wson[u]])wson[u]=v;
}
}
void calc(int rt,int u){
int now=s[u];
maxv=max(maxv,f[now]+d[u]-*d[rt]);
if((s[u]^s[rt])==)maxv=max(maxv,d[u]-d[rt]);
for(int i=;i<;i++){
now=(<<i)^s[u];
maxv=max(maxv,f[now]+d[u]-*d[rt]);
if((s[u]^s[rt])==(<<i))maxv=max(maxv,d[u]-d[rt]);
}
for(int i=head[u];i;i=G[i].next){
int v=G[i].v;calc(rt,v);
}
}
void change(int u,int k){
if(k)f[s[u]]=max(f[s[u]],d[u]);
else f[s[u]]=inf;
for(int i=head[u];i;i=G[i].next)change(G[i].v,k);
}
void dfs2(int u,int k){
for(int i=head[u];i;i=G[i].next){
int v=G[i].v;if(v==wson[u])continue;
dfs2(v,);
}
if(wson[u])dfs2(wson[u],);
maxv=;int now=s[u];
maxv=max(maxv,f[now]-d[u]);
for(int i=;i<;i++){
now=(<<i)^s[u];
maxv=max(maxv,f[now]-d[u]);
}
for(int i=head[u];i;i=G[i].next){
int v=G[i].v;if(v==wson[u])continue;
calc(u,v);change(v,);
}
ans[u]=maxv;
if(!k){
for(int i=head[u];i;i=G[i].next)change(G[i].v,);
f[s[u]]=inf;
}else f[s[u]]=max(f[s[u]],d[u]);
}
void erase(int u){
for(int i=head[u];i;i=G[i].next){
int v=G[i].v;erase(v);
ans[u]=max(ans[u],ans[v]);
}
}
int main(){
scanf("%d",&n);
for(int i=;i<=n;i++){
int u;scanf("%d %c\n",&u,&c[i]);
addedge(u,i);a[i]=c[i]-'a';
}
dfs1(,);
memset(f,,sizeof(f));inf=f[];dfs2(,);
erase();
for (int i=;i<=n;++i)printf("%d%c",ans[i]," \n"[i==n]);
}

大概是这样。

参考:

http://blog.****.net/qq_35392050/article/details/64537364

http://www.cnblogs.com/zzqsblog/p/6146916.html