bzoj 2243: [SDOI2011]染色 线段树区间合并+树链剖分

时间:2023-03-09 03:50:45
bzoj 2243: [SDOI2011]染色 线段树区间合并+树链剖分

2243: [SDOI2011]染色

Time Limit: 20 Sec  Memory Limit: 512 MB
Submit: 7925  Solved: 2975
[Submit][Status][Discuss]

Description

给定一棵有n个节点的无根树和m个操作,操作有2类:

1、将节点a到节点b路径上所有点都染成颜色c;

2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。

请你写一个程序依次完成这m个操作。

Input

第一行包含2个整数n和m,分别表示节点数和操作数;

第二行包含n个正整数表示n个节点的初始颜色

下面 行每行包含两个整数x和y,表示xy之间有一条无向边。

下面 行每行描述一个操作:

“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;

“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。

Output

对于每个询问操作,输出一行答案。

Sample Input

6 5

2 2 1 2 1 1

1 2

1 3

2 4

2 5

2 6

Q 3 5

C 2 1 1

Q 3 5

C 5 1 2

Q 3 5

Sample Output

3

1

2

HINT

数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。

Source

#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<iostream>
#include<cstdio>
#include<cmath>
#include<string>
#include<queue>
#include<algorithm>
#include<stack>
#include<cstring>
#include<vector>
#include<list>
#include<set>
#include<map>
using namespace std;
#define LL long long
#define pi (4*atan(1.0))
#define eps 1e-8
#define bug(x) cout<<"bug"<<x<<endl;
const int N=3e5+,M=2e6+,inf=1e9+;
const LL INF=1e18+,mod=1e9+; struct edge
{
int v,next;
} edge[N<<];
int head[N<<],edg,id,n;
/// 树链剖分 int fa[N],dep[N],son[N],siz[N]; // fa父亲,dep深度,son重儿子,siz以该点为子树的节点个数
int a[N],ran[N],top[N],tid[N]; // tid表示边的标号,top通过重边可以到达最上面的点,ran表示标记tid
void init()
{
memset(son,-,sizeof(son));
memset(head,-,sizeof(head));
edg=;
id=;
} void add(int u,int v)
{
edg++;
edge[edg].v=v;
edge[edg].next=head[u];
head[u]=edg;
} void dfs1(int u,int fath,int deep)
{
fa[u]=fath;
siz[u]=;
dep[u]=deep;
for(int i=head[u]; i!=-; i=edge[i].next)
{
int v=edge[i].v;
if(v==fath)continue;
dfs1(v,u,deep+);
siz[u]+=siz[v];
if(son[u]==-||siz[v]>siz[son[u]])
son[u]=v;
}
} void dfs2(int u,int tp)
{
tid[u]=++id;
top[u]=tp;
ran[tid[u]]=u;
if(son[u]==-)return;
dfs2(son[u],tp);
for(int i=head[u]; i!=-; i=edge[i].next)
{
int v=edge[i].v;
if(v==fa[u])continue;
if(v!=son[u])
dfs2(v,v);
}
} struct SGT
{
int la[N<<],ra[N<<],ma[N<<],lazy[N<<];
void pushup(int pos)
{
if(ra[pos<<]==la[pos<<|])ma[pos]=ma[pos<<]+ma[pos<<|]-;
else ma[pos]=ma[pos<<|]+ma[pos<<];
la[pos]=la[pos<<];
ra[pos]=ra[pos<<|];
}
void pushdown(int pos)
{
if(lazy[pos])
{
la[pos<<]=la[pos<<|]=lazy[pos];
ra[pos<<]=ra[pos<<|]=lazy[pos];
ma[pos<<]=ma[pos<<|]=;
lazy[pos<<]=lazy[pos<<|]=lazy[pos];
lazy[pos]=;
}
}
pair<int,pair<int,int> > Union( pair<int,pair<int,int> > a, pair<int,pair<int,int> > b)
{
if(a.second.second==b.second.first)
return make_pair(a.first+b.first-,make_pair(a.second.first,b.second.second));
return make_pair(a.first+b.first,make_pair(a.second.first,b.second.second));
}
void build(int l,int r,int pos)
{
lazy[pos]=;
if(l==r)
{
la[pos]=ra[pos]=a[ran[l]];
ma[pos]=;
return;
}
int mid=(l+r)>>;
build(l,mid,pos<<);
build(mid+,r,pos<<|);
pushup(pos);
}
void update(int L,int R,int c,int l,int r,int pos)
{
if(L<=l&&r<=R)
{
lazy[pos]=c;
la[pos]=ra[pos]=c;
ma[pos]=;
return;
}
pushdown(pos);
int mid=(l+r)>>;
if(L<=mid)update(L,R,c,l,mid,pos<<);
if(R>mid)update(L,R,c,mid+,r,pos<<|);
pushup(pos);
}
pair<int,pair<int,int> > query(int L,int R,int l,int r,int pos)
{
if(L<=l&&r<=R)
return make_pair(ma[pos],make_pair(la[pos],ra[pos]));
pushdown(pos);
int mid=(l+r)>>;
if(L>mid)return query(L,R,mid+,r,pos<<|);
else if(R<=mid)return query(L,R,l,mid,pos<<);
else
{
pair<int,pair<int,int> > a=query(L,mid,l,mid,pos<<);
pair<int,pair<int,int> > b=query(mid+,R,mid+,r,pos<<|);
return Union(a,b);
}
}
}tree; int lca(int l,int r)
{
while(top[l]!=top[r])
{
if(dep[top[l]]<dep[top[r]])swap(l,r);
l=fa[top[l]];
}
if(dep[l]<dep[r])swap(l,r);
return r;
}
int up(int l,int r)
{
int pre=-,ans=;
while(top[l]!=top[r])
{
if(dep[top[l]]<dep[top[r]])swap(l,r);
pair<int,pair<int,int> > x=tree.query(tid[top[l]],tid[l],,n,);
//cout<<tid[top[l]]<<" "<<tid[l]<<" "<<x.first<<endl;
ans+=x.first;
if(pre==x.second.second)ans--;
pre=x.second.first;
l=fa[top[l]];
}
if(dep[l]<dep[r])swap(l,r);
pair<int,pair<int,int> > x=tree.query(tid[r],tid[l],,n,);
//cout<<tid[r]<<" "<<tid[l]<<" "<<x.first<<endl;
ans+=x.first;
if(pre==x.second.second)ans--;
return ans;
}
void go(int l,int r,int c)
{
while(top[l]!=top[r])
{
if(dep[top[l]]<dep[top[r]])swap(l,r);
tree.update(tid[top[l]],tid[l],c,,n,);
l=fa[top[l]];
}
if(dep[l]<dep[r])swap(l,r);
tree.update(tid[r],tid[l],c,,n,);
}
char ch[];
int main()
{
init();
int q;
scanf("%d%d",&n,&q);
for(int i=;i<=n;i++)
scanf("%d",&a[i]);
for(int i=;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
dfs1(,-,);
dfs2(,);
tree.build(,n,);
while(q--)
{
int u,v;
scanf("%s%d%d",ch,&u,&v);
if(ch[]=='C')
{
int c;
scanf("%d",&c);
go(u,v,c);
}
else
{
int x=lca(u,v);
printf("%d\n",up(u,x)+up(v,x)-);
}
}
return ;
}