BZOJ2243[SDOI2011]染色——树链剖分+线段树

时间:2023-03-09 05:28:13
BZOJ2243[SDOI2011]染色——树链剖分+线段树

题目描述

给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。

输入

第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。

输出

对于每个询问操作,输出一行答案。

样例输入

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

样例输出

3
1
2

提示

数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。

  一道很有思维含量且细节挺多的题。首先想如何求序列上一个区间有几个颜色段?线段树每个点维护这个点所代表区间的左右端点颜色及这个区间的颜色段数就好了。在树上操作就直接dfs出树剖序然后架在线段树上。但要注意一点,每次查询时两点在跳lca的过程中每次跳的是一条重链,两条重链的相接的两个点颜色如果相同就要把答案数减1,所以还要维护每次跳的链头的颜色。当最后两个点跳到同一个重链上时这两点间的链的两头要分别判断一下。

#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
int n,m;
int tot;
int num;
int x,y,z;
char ch[2];
int s[100010];
int d[100010];
int a[800010];
int f[100010];
int v[100010];
int q[100010];
int ls[800010];
int rs[800010];
int to[200010];
int sum[800010];
int son[100010];
int top[100010];
int size[100010];
int head[100010];
int next[200010];
void add(int x,int y)
{
tot++;
next[tot]=head[x];
head[x]=tot;
to[tot]=y;
}
void dfs(int x,int fa)
{
size[x]=1;
f[x]=fa;
d[x]=d[fa]+1;
for(int i=head[x];i;i=next[i])
{
if(to[i]!=fa)
{
dfs(to[i],x);
size[x]+=size[to[i]];
if(size[to[i]]>size[son[x]])
{
son[x]=to[i];
}
}
}
}
void dfs2(int x,int tp)
{
s[x]=++num;
q[num]=x;
top[x]=tp;
if(son[x])
{
dfs2(son[x],tp);
}
for(int i=head[x];i;i=next[i])
{
if(to[i]!=f[x]&&to[i]!=son[x])
{
dfs2(to[i],to[i]);
}
}
}
void pushup(int rt)
{
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
if(rs[rt<<1]==ls[rt<<1|1])
{
sum[rt]--;
}
ls[rt]=ls[rt<<1];
rs[rt]=rs[rt<<1|1];
}
void pushdown(int rt)
{
if(a[rt]!=-1)
{
a[rt<<1]=a[rt];
a[rt<<1|1]=a[rt];
ls[rt<<1]=a[rt];
rs[rt<<1]=a[rt];
ls[rt<<1|1]=a[rt];
rs[rt<<1|1]=a[rt];
sum[rt<<1]=1;
sum[rt<<1|1]=1;
a[rt]=-1;
}
}
void build(int rt,int l,int r)
{
a[rt]=-1;
if(l==r)
{
ls[rt]=rs[rt]=v[q[l]];
sum[rt]=1;
return ;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void change(int rt,int l,int r,int L,int R,int v)
{
if(L<=l&&r<=R)
{
sum[rt]=1;
ls[rt]=v;
rs[rt]=v;
a[rt]=v;
return ;
}
pushdown(rt);
int mid=(l+r)>>1;
if(L<=mid)
{
change(rt<<1,l,mid,L,R,v);
}
if(R>mid)
{
change(rt<<1|1,mid+1,r,L,R,v);
}
pushup(rt);
}
int query(int rt,int l,int r,int L,int R)
{
if(L<=l&&r<=R)
{
return sum[rt];
}
pushdown(rt);
int mid=(l+r)>>1;
if(R<=mid)
{
return query(rt<<1,l,mid,L,R);
}
else if(L>mid)
{
return query(rt<<1|1,mid+1,r,L,R);
}
if(rs[rt<<1]==ls[rt<<1|1])
{
return query(rt<<1,l,mid,L,R)+query(rt<<1|1,mid+1,r,L,R)-1;
}
else
{
return query(rt<<1,l,mid,L,R)+query(rt<<1|1,mid+1,r,L,R);
}
}
int find(int rt,int l,int r,int k)
{
if(l==r)
{
return ls[rt];
}
pushdown(rt);
int mid=(l+r)>>1;
if(k<=mid)
{
return find(rt<<1,l,mid,k);
}
else
{
return find(rt<<1|1,mid+1,r,k);
}
}
void color(int x,int y,int v)
{
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]])
{
swap(x,y);
}
change(1,1,n,s[top[x]],s[x],v);
x=f[top[x]];
}
if(d[x]>d[y])
{
swap(x,y);
}
change(1,1,n,s[x],s[y],v);
}
int ask(int x,int y)
{
int t1=-1;
int s1=-1;
int res=0;
while(top[x]!=top[y])
{
if(d[top[x]]>d[top[y]])
{
res+=query(1,1,n,s[top[x]],s[x]);
if(s1==find(1,1,n,s[x]))
{
res--;
}
s1=find(1,1,n,s[top[x]]);
x=f[top[x]];
}
else
{
res+=query(1,1,n,s[top[y]],s[y]);
if(t1==find(1,1,n,s[y]))
{
res--;
}
t1=find(1,1,n,s[top[y]]);
y=f[top[y]];
}
}
if(x==y)
{
res++;
int miku=find(1,1,n,s[x]);
if(miku==t1)
{
res--;
}
if(miku==s1)
{
res--;
}
}
else
{
if(d[x]>d[y])
{
res+=query(1,1,n,s[y],s[x]);
}
else
{
res+=query(1,1,n,s[x],s[y]);
}
if(find(1,1,n,s[x])==s1)
{
res--;
}
if(find(1,1,n,s[y])==t1)
{
res--;
}
}
return res;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
scanf("%d",&v[i]);
}
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs(1,0);
dfs2(1,1);
build(1,1,n);
for(int i=1;i<=m;i++)
{
scanf("%s",ch);
if(ch[0]=='C')
{
scanf("%d%d%d",&x,&y,&z);
color(x,y,z);
}
else
{
scanf("%d%d",&x,&y);
printf("%d\n",ask(x,y));
}
}
}