BZOJ 4817 [SDOI2017]树点涂色 (LCT+线段树维护dfs序)

时间:2022-06-27 04:12:44

题目大意:略

涂色方式明显符合$LCT$里$access$操作的性质,相同颜色的节点在一条深度递增的链上

用$LCT$维护一个树上集合就好

因为它维护了树上集合,所以它别的啥都干不了了

发现树是静态的,可以用$dfs$序搞搞

把问题当成树上节点涂色会很麻烦

但只有相邻的不同颜色节点才会对答案产生影响

所以我们把涂色当成一种连边/断边操作

这样,问题就容易解决得多了

维护一个数组$f_{x}$表示$x$节点到根的路径上一共有$f_{x}$种颜色,$f_{x}-1$条断边

显然它的初始值就是节点x的深度

第一个操作,把这个位置到根打通

每断一条边,子树每个节点答案$+1$,连一条边,答案$-1$

在$access$操作中进行讨论即可

第二个操作,求链上不同颜色数量,即一个链的断边数量$-1$

显然,答案是$(f_{x}-1)+(f_{y}-1)-2*(f_{lca(x,y)}-1)+1=f_{x}+f_{y}-2*f_{lca(x,y)}+1$

即断边总数$+1$

不要把它当成节点的颜色去想

第三个操作,求子树内$f_{x}$最大值

以上操作皆可用$dfs$序+线段树实现!

随时保持清醒头脑,千万不要把这个$LCT$当成真正的$LCT$,它只是一个维护链集合的媒介!

尤其是断边/连边,进行区间修改操作时,需要找出开头/后继节点,而不是当前splay的根节点

 #include <queue>
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N1 101000
#define S1 (N1<<1)
#define T1 (N1<<2)
#define ll long long
#define uint unsigned int
#define rint register int
#define ull unsigned long long
#define dd double
#define il inline
#define inf 1000000000
using namespace std; int gint()
{
int ret=,fh=;char c=getchar();
while(c<''||c>''){if(c=='-')fh=-;c=getchar();}
while(c>=''&&c<=''){ret=ret*+c-'';c=getchar();}
return ret*fh;
}
int n,m,T;
struct Edge{
int to[N1*],nxt[N1*],head[N1],cte;
void ae(int u,int v)
{cte++;to[cte]=v,nxt[cte]=head[u],head[u]=cte;}
}E; struct SEG{
int ma[N1<<],tag[N1<<];
void pushup(int rt){ma[rt]=max(ma[rt<<],ma[rt<<|]);}
void pushdown(int rt)
{
if(!tag[rt]) return;
ma[rt<<]+=tag[rt]; ma[rt<<|]+=tag[rt];
tag[rt<<]+=tag[rt]; tag[rt<<|]+=tag[rt];
tag[rt]=;
}
void build(int *a,int *id,int l,int r,int rt)
{
if(l==r) {ma[rt]=a[id[l]];return;}
int mid=(l+r)>>;
build(a,id,l,mid,rt<<);
build(a,id,mid+,r,rt<<|);
pushup(rt);
}
void update(int L,int R,int l,int r,int rt,int w)
{
if(!L||!R||L>R) return;
if(L<=l&&r<=R) {ma[rt]+=w,tag[rt]+=w;return;}
int mid=(l+r)>>; pushdown(rt);
if(L<=mid) update(L,R,l,mid,rt<<,w);
if(R>mid) update(L,R,mid+,r,rt<<|,w);
pushup(rt);
}
int query(int L,int R,int l,int r,int rt)
{
if(L<=l&&r<=R) return ma[rt];
int mid=(l+r)>>,ans=; pushdown(rt);
if(L<=mid) ans=max(ans,query(L,R,l,mid,rt<<));
if(R>mid) ans=max(ans,query(L,R,mid+,r,rt<<|));
return ans;
}
}s; namespace lct{
int ch[N1][],fa[N1];
int idf(int x){return ch[fa[x]][]==x?:;}
int isroot(int x){return (ch[fa[x]][]==x||ch[fa[x]][]==x)?:;}
//int stk[N1],tp;
void rot(int x)
{
int y=fa[x],ff=fa[y],px=idf(x),py=idf(y);
if(!isroot(y)) ch[ff][py]=x; fa[x]=ff;
fa[ch[x][px^]]=y,ch[y][px]=ch[x][px^];
ch[x][px^]=y,fa[y]=x;
//pushup(y),pushup(x);
}
void splay(int x)
{
int y=x; /*stk[++tp]=x;
while(!isroot(y)){stk[++tp]=fa[y],y=fa[y];}
while(tp){pushdown(stk[tp--]);}*/
while(!isroot(x))
{
y=fa[x];
if(isroot(y)) rot(x);
else if(idf(y)==idf(x)) rot(y),rot(x);
else rot(x),rot(x);
}
}
int First(int x){while(ch[x][]) x=ch[x][];return x;}
int upper(int x){x=ch[x][];while(ch[x][]) x=ch[x][];return x;}
void access(int x,int *st,int *ed)
{
for(int y=,z;x;)
{
splay(x);
z=upper(x);
if(z) s.update(st[z],ed[z],,n,,);
z=First(y);
s.update(st[z],ed[z],,n,,-);
ch[x][]=y; y=x; x=fa[x];
}
}
void init(int *ff){
for(int i=;i<=n;i++) fa[i]=ff[i];
}
}; int fa[N1],son[N1],sz[N1],tp[N1],dep[N1];
int st[N1],ed[N1],id[N1],tot;
void dfs1(int u,int ff)
{
for(int j=E.head[u];j;j=E.nxt[j])
{
int v=E.to[j];
if(v==ff) continue;
dep[v]=dep[u]+; dfs1(v,u); fa[v]=u;
sz[u]+=sz[v]; son[u]=sz[v]>sz[son[u]]?v:son[u];
}
sz[u]++;
}
void dfs2(int u)
{
st[u]=ed[u]=++tot; id[tot]=u;
if(son[u]) tp[son[u]]=tp[u], dfs2(son[u]), ed[u]=max(ed[u],ed[son[u]]);;
for(int j=E.head[u];j;j=E.nxt[j])
{
int v=E.to[j];
if(v==fa[u]||v==son[u]) continue;
tp[v]=v; dfs2(v);
ed[u]=max(ed[u],ed[v]);
}
}
int lca(int x,int y)
{
while(tp[x]!=tp[y])
{
if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
x=fa[tp[x]];
}
return dep[x]<dep[y]?x:y;
}
void mksame(int x)
{
lct::access(x,st,ed);
}
int split(int x,int y)
{
int sx,sy,sf,f=lca(x,y);
sx=s.query(st[x],st[x],,n,);
sy=s.query(st[y],st[y],,n,);
sf=s.query(st[f],st[f],,n,);
return sx+sy-*sf+;
}
int qmax(int x){return s.query(st[x],ed[x],,n,);}
void init()
{
dep[]=; dfs1(,-);
tp[]=; dfs2();
lct::init(fa);
s.build(dep,id,,n,);
} int qf(int x){return s.query(st[x],st[x],,n,);} int main()
{
scanf("%d%d",&n,&m);
int i,j,fl,x,y,cnt=,de;
for(i=;i<n;i++) x=gint(), y=gint(), E.ae(x,y), E.ae(y,x);
init();
for(j=;j<=m;j++)
{
fl=gint();
if(fl==){
x=gint();
mksame(x);
}else if(fl==){
x=gint(); y=gint();
printf("%d\n",split(x,y));
}else{
x=gint();
printf("%d\n",qmax(x));
}
}
return ;
}