线段树区间合并就挺麻烦了,再套个树链就更加鬼畜,不过除了代码量大就没什么其他的了。。
一些细节:线段树每个结点用结构体保存,pushup等合并函数改成返回一个结构体,这样好写一些
struct Seg{
int lc,rc,tot;
Seg(){lc=rc=-;tot=;}
};
Seg seg[maxn<<];int lazy[maxn<<];
Seg pushup(Seg a,Seg b){
if(!a.tot)return b;
if(!b.tot)return a;
Seg res;
res.lc=a.lc,res.rc=b.rc;
res.tot=a.tot+b.tot;
if(a.rc==b.lc)res.tot--;
return res;
}
向上爬时更新操作不用变,但是询问操作需要改变
同样有一些值得注意的地方:向上爬的两条链是有顺序的,合并时顺序不能搞反,也不能像普通树链剖分那样直接swap
int Query(int x,int y){
Seg A,B;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]){
B=pushup(query(id[top[y]],id[y],,n,),B);
y=f[top[y]];
}
else {
A=pushup(query(id[top[x]],id[x],,n,),A);
x=f[top[x]];
}
}
if(id[x]>id[y])
A=pushup(query(id[y],id[x],,n,),A);
else
B=pushup(query(id[x],id[y],,n,),B);
if(A.lc==B.lc)return A.tot+B.tot-;
else return A.tot+B.tot;
}
最后是完整代码
#include<bits/stdc++.h>
using namespace std;
#define maxn 100005
struct Edge{int to,nxt;}edge[maxn<<];
int c[maxn],head[maxn],tot,n; int f[maxn],son[maxn],d[maxn],size[maxn];
int cnt,id[maxn],rk[maxn],top[maxn];
void dfs1(int x,int pre,int deep){
size[x]=,d[x]=deep;
for(int i=head[x];i!=-;i=edge[i].nxt){
int y=edge[i].to;
if(y==pre)continue;
f[y]=x;dfs1(y,x,deep+);size[x]+=size[y];
if(size[son[x]]<size[y])son[x]=y;
}
}
void dfs2(int x,int tp){
top[x]=tp;id[x]=++cnt;rk[cnt]=x;
if(son[x])dfs2(son[x],tp);
for(int i=head[x];i!=-;i=edge[i].nxt){
int y=edge[i].to;
if(y!=son[x] && y!=f[x])dfs2(y,y);
}
} #define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
struct Seg{
int lc,rc,tot;
Seg(){lc=rc=-;tot=;}
};
Seg seg[maxn<<];int lazy[maxn<<];
Seg pushup(Seg a,Seg b){
if(!a.tot)return b;
if(!b.tot)return a;
Seg res;
res.lc=a.lc,res.rc=b.rc;
res.tot=a.tot+b.tot;
if(a.rc==b.lc)res.tot--;
return res;
}
void pushdown(int rt){
if(lazy[rt]<)return;
lazy[rt<<]=lazy[rt<<|]=lazy[rt];
seg[rt<<].lc=seg[rt<<].rc=lazy[rt];
seg[rt<<].tot=;
seg[rt<<|].lc=seg[rt<<|].rc=lazy[rt];
seg[rt<<|].tot=;
lazy[rt]=-;
}
void build(int l,int r,int rt){
if(l==r){
seg[rt].lc=seg[rt].rc=c[rk[l]];
seg[rt].tot=;return;
}
int m=l+r>>;
build(lson);build(rson);
seg[rt]=pushup(seg[rt<<],seg[rt<<|]);
}
void update(int L,int R,int c,int l,int r,int rt){
if(L<=l && R>=r){
lazy[rt]=c;seg[rt].lc=seg[rt].rc=c;
seg[rt].tot=;return;
}
pushdown(rt);
int m=l+r>>;
if(L<=m)update(L,R,c,lson);
if(R>m)update(L,R,c,rson);
seg[rt]=pushup(seg[rt<<],seg[rt<<|]);
}
Seg query(int L,int R,int l,int r,int rt){
if(L<=l && R>=r)return seg[rt];
pushdown(rt);
int m=l+r>>;
Seg res;
if(L<=m)res=pushup(res,query(L,R,lson));
if(R>m)res=pushup(res,query(L,R,rson));
return res;
} void Update(int x,int y,int c){
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]])swap(x,y);
update(id[top[x]],id[x],c,,n,);
x=f[top[x]];
}
if(id[x]>id[y])swap(x,y);
update(id[x],id[y],c,,n,);
}
int Query(int x,int y){
Seg A,B;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]){
B=pushup(query(id[top[y]],id[y],,n,),B);
y=f[top[y]];
}
else {
A=pushup(query(id[top[x]],id[x],,n,),A);
x=f[top[x]];
}
}
if(id[x]>id[y])
A=pushup(query(id[y],id[x],,n,),A);
else
B=pushup(query(id[x],id[y],,n,),B);
if(A.lc==B.lc)return A.tot+B.tot-;
else return A.tot+B.tot;
} void init(){
memset(head,-,sizeof head);
memset(lazy,-,sizeof lazy);
tot=;
}
void addedge(int u,int v){
edge[tot].to=v;edge[tot].nxt=head[u];head[u]=tot++;
}
int main(){
init();int q;
cin>>n>>q;
for(int i=;i<=n;i++)cin>>c[i];
for(int i=;i<n;i++){
int x,y;cin>>x>>y;
addedge(x,y);addedge(y,x);
}
cnt=;dfs1(,,),dfs2(,);
build(,n,);
char op[];int x,y,z;
while(q--){
scanf("%s",op);
if(op[]=='Q'){scanf("%d%d",&x,&y);
cout<<Query(x,y)<<'\n';}
if(op[]=='C'){scanf("%d%d%d",&x,&y,&z);Update(x,y,z);}
}
}