BZOJ - 2243 染色 (树链剖分+线段树+区间合并)

时间:2023-03-09 22:52:48
BZOJ - 2243 染色 (树链剖分+线段树+区间合并)

题目链接

线段树维护区间连续段个数即可。设lc为区间左端点颜色,rc为区间右端点颜色,则合并两区间的时候,如果左区间右端点和右区间左端点颜色相同,则连续段个数-1。

在树链上的区间合并可以定义一个结构体作为线段,分成左右两条链暴力合并。也可以考虑到树上的路径中每两个树链“断开”的地方必然有一个结点是另一个结点的祖先,因此如果top[u]的颜色与fa[top[u]]的颜色相同时答案-1即可。

树剖和线段树结合真容易把人搞晕啊,什么时候要用l,r,什么时候要用u,什么时候要用dfn[u],一定要分清楚~~

 #include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+,inf=0x3f3f3f3f;
int hd[N],ne,n,k,fa[N],son[N],siz[N],dep[N],top[N],dfn[N],rnk[N],tot,a[N],cnt[N<<],mk[N<<],lc[N<<],rc[N<<];
struct E {int v,nxt;} e[N<<];
void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
void dfs1(int u,int f,int d) {
fa[u]=f,fa[u]=f,siz[u]=,dep[u]=d;
for(int i=hd[u]; ~i; i=e[i].nxt) {
int v=e[i].v;
if(v==fa[u])continue;
dfs1(v,u,d+),siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int tp) {
top[u]=tp,dfn[u]=++tot,rnk[dfn[u]]=u;
if(!son[u])return;
dfs2(son[u],top[u]);
for(int i=hd[u]; ~i; i=e[i].nxt) {
int v=e[i].v;
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
}
#define ls (u<<1)
#define rs (u<<1|1)
#define mid ((l+r)>>1)
void pu(int u) {lc[u]=lc[ls],rc[u]=rc[rs],cnt[u]=cnt[ls]+cnt[rs]; if(rc[ls]==lc[rs])cnt[u]--;}
void pd(int u) {if(mk[u])lc[u]=rc[u]=mk[u],cnt[u]=,mk[ls]=mk[rs]=mk[u],mk[u]=;}
void build(int u=,int l=,int r=tot) {
if(l==r) {lc[u]=rc[u]=a[rnk[l]],cnt[u]=; return;}
build(ls,l,mid),build(rs,mid+,r),pu(u);
}
void upd(int L,int R,int x,int u=,int l=,int r=tot) {
pd(u);
if(l>=L&&r<=R) {mk[u]=x,pd(u); return;}
if(l>R||r<L)return;
upd(L,R,x,ls,l,mid),upd(L,R,x,rs,mid+,r),pu(u);
}
int getcol(int p,int u=,int l=,int r=tot) {
pd(u);
if(l==r)return lc[u];
return p<=mid?getcol(p,ls,l,mid):getcol(p,rs,mid+,r);
}
int qry(int L,int R,int u=,int l=,int r=tot) {
pd(u);
if(l>=L&&r<=R)return cnt[u];
if(l>R||r<L)return ;
int t1=qry(L,R,ls,l,mid),t2=qry(L,R,rs,mid+,r);
int ret=t1+t2;
if(t1&&t2&&rc[ls]==lc[rs])ret--;
return ret;
}
void change(int u,int v,int x) {
for(; top[u]!=top[v]; u=fa[top[u]]) {
if(dep[top[u]]<dep[top[v]])swap(u,v);
upd(dfn[top[u]],dfn[u],x);
}
if(dep[u]<dep[v])swap(u,v);
upd(dfn[v],dfn[u],x);
}
int ask(int u,int v) {
int ret=;
for(; top[u]!=top[v]; u=fa[top[u]]) {
if(dep[top[u]]<dep[top[v]])swap(u,v);
ret+=qry(dfn[top[u]],dfn[u]);
if(getcol(dfn[top[u]])==getcol(dfn[fa[top[u]]]))ret--;
}
if(dep[u]<dep[v])swap(u,v);
ret+=qry(dfn[v],dfn[u]);
return ret;
}
int main() {
memset(hd,-,sizeof hd),ne=;
scanf("%d%d",&n,&k);
for(int i=; i<=n; ++i)scanf("%d",&a[i]),a[i]++;
for(int i=; i<n; ++i) {
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
tot=,dfs1(,,),dfs2(,),build();
while(k--) {
char ch;
int a,b,c;
scanf(" %c%d%d",&ch,&a,&b);
if(ch=='Q')printf("%d\n",ask(a,b));
else scanf("%d",&c),c++,change(a,b,c);
}
return ;
}