HYSBZ 4034 【树链剖分】+【线段树 】

时间:2021-09-23 16:25:58

<题目链接>

题目大意:

有一棵点数为 N 的树,以点 1 为根,且树点有权值。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。

Input

第一行包含两个整数 N, M 。表示点数和操作数。接下来一行 N 个整数,表示树中节点的初始权值。接下来 N-1 
行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。再接下来 M 行,每行分别表示一次操作。其中
第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。

Output

对于每个询问操作,输出该询问的答案。答案之间用换行隔开。

Sample Input

5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3

Sample Output

6

9

13

Hint

 对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。

解题分析:

很明显本题用树链剖分解决,重点是解决第二个操作,以x为根的子树所有的节点权值都增加a。其实,在明白了树链剖分原理后,不难发现,以x为根的所有节点必然在线段树上是一段连续的区间,并且x在线段树上的编号为区间的左端点,下面我们就可以直接通过该节点的子节点数量找出以该节点为根的在线段树上编号最大的子节点,即该区间的右端点,找到了对应的区间,然后进行线段树的区间更新就行。

下面的代码WA了。。。先记录一下

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std; #define ll long long
#define Lson l,mid,rt<<1
#define Rson mid+1,r,rt<<1|1
const int M =1e6+;
int n,m,cnt,tot,endloc;
int head[M],sz[M],son[M],dep[M],f[M],id[M],rnk[M],top[M];
int arr[M]; struct EDGE{
int to;
int next;
}edge[M<<];
struct Tree{
int lazy;
ll sum;
}tr[M<<];
void init(){
cnt=tot=;
memset(head,-,sizeof(head));
}
void Pushup(int rt){
tr[rt].sum=tr[rt<<].sum+tr[rt<<|].sum;
}
void Pushdown(int rt,int len){
if(tr[rt].lazy){
int tmp=tr[rt].lazy;
tr[rt].lazy=;
tr[rt<<].lazy+=tmp;
tr[rt<<|].lazy+=tmp;
tr[rt<<].sum+=tmp*(len-(len>>));
tr[rt<<|].sum+=tmp*(len>>);
}
}
void add(int u,int v){
edge[++cnt].to=v,edge[cnt].next=head[u];
head[u]=cnt;
} void dfs(int u,int fa,int d){
sz[u]=,f[u]=fa,dep[u]=d,son[u]=-;
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].to;
if(v==fa)continue;
dfs(v,u,d+);
sz[u]+=sz[v];
if(son[u]==-||sz[v]>sz[son[u]])son[u]=v;
}
} void dfs1(int u,int t){
id[u]=++tot;
rnk[tot]=u;
top[u]=t;
if(son[u]==-)return;
dfs1(son[u],t);
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].to;
if(v==f[u]||v==son[u])continue;
dfs1(v,v);
}
} void dfsrot(int u,int fa){
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].to;
if(v==fa)continue;
endloc=max(endloc,id[v]); //在以rot为根的所有子节点,在线段树上是一段连续的区间,所以只需要记录区间的右端点就行
dfsrot(v,u);
}
} void build(int l,int r,int rt){
tr[rt].lazy=;
if(l==r){
tr[rt].sum=arr[rnk[l]];
return;
}
int mid=(l+r)>>;
build(Lson);
build(Rson);
Pushup(rt);
} void update1(int loc,int val,int l,int r,int rt){ //单点更新
if(l==r){
tr[rt].sum+=val;
return;
}
Pushdown(rt,r-l+);
int mid=(l+r)>>;
if(loc<=mid)update1(loc,val,Lson);
if(loc>mid)update1(loc,val,Rson);
Pushup(rt);
} void update2(int L,int R,int val,int l,int r,int rt){ //线段树区间修改
if(L<=l&&r<=R){
tr[rt].lazy+=val;
tr[rt].sum+=val*(r-l+);
return;
}
Pushdown(rt,r-l+);
int mid=(l+r)>>;
if(L<=mid)
update2(L,R,val,Lson);
if(R>mid)
update2(L,R,val,Rson);
Pushup(rt);
} ll query(int L,int R,int l,int r,int rt){
if(L<=l&&r<=R){
return tr[rt].sum;
}
Pushdown(rt,r-l+);
int mid=(l+r)>>;
ll ans=;
if(L<=mid)
ans+=query(L,R,Lson);
if(R>mid)
ans+=query(L,R,Rson);
return ans;
} void Query(int x){
ll ans=;
int fx=top[x];
while(fx!=){ //(因为1为整棵树的根)当x点不在以1为链首的重链上时
ans+=query(id[fx],id[x],,n,);
x=f[fx],fx=top[x];
}
ans+=query(id[],id[x],,n,);
printf("%lld\n",ans);
} int main(){
while(scanf("%d%d",&n,&m)!=EOF){
init();
for(int i=;i<=n;i++)scanf("%d",&arr[i]);
for(int i=;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
} dfs(,-,);
dfs1(,);
build(,n,);
while(m--){
int cal;scanf("%d",&cal);
if(cal==){
int loc,val;
scanf("%d%d",&loc,&val);
update1(id[loc],val,,n,);
}
else if(cal==){
int rot,val;
scanf("%d%d",&rot,&val);
int start=id[rot]; //根据树链剖分的原理,根节点就是该区间的左端点
update2(start,start+sz[rot]-,val,,n,);
}
else{
int loc;
scanf("%d",&loc);
Query(loc);
}
}
}
return ;
}

AC代码:转载于 >>>

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define maxn 100005
#define lson l,mid,node<<1
#define rson mid+1,r,node<<1|1
using namespace std;
typedef long long ll;
ll sum[maxn*],col[maxn*];
int n,m,q;
int cot=,cnt;
int tid[maxn],top[maxn],tim;
int dep[maxn],fa[maxn],sz[maxn],Rank[maxn];
int first[maxn],son[maxn];
int End[maxn];
ll w[maxn]; struct Edge{
int v,nxt;
Edge(int _v=,int _nxt=){
v=_v,nxt=_nxt;
}
}e[maxn*]; void init(){
cot=;
tim=;
memset(first,-,sizeof(first));
memset(son,-,sizeof(son));
} void add(int u,int v){
cot++;
e[cot]=Edge(v,first[u]);
first[u]=cot;
cot++;
e[cot]=Edge(u,first[v]);
first[v]=cot;
} //树链剖分部分 void dfs1(int u,int father,int d){
dep[u]=d;
fa[u]=father;
sz[u]=;
for(int i=first[u];i!=-;i=e[i].nxt){
int v=e[i].v;
if(v!=father){
dfs1(v,u,d+);
sz[u]+=sz[v];
if(son[u]==-||sz[v]>sz[son[u]])
son[u]=v;
}
}
} void dfs2(int u,int tp){
top[u]=tp;
tid[u]=++tim;
Rank[tid[u]]=u;
if(son[u]!=-) dfs2(son[u],tp);
for(int i=first[u];i!=-;i=e[i].nxt){
int v=e[i].v;
if(v!=son[u]&&v!=fa[u]) dfs2(v,v);
}
End[u]=tim;
} //线段树部分
inline void pushup(int node){
sum[node]=sum[node<<]+sum[node<<|];
} inline void pushdown(int node,int l,int r){
int mid=(l+r)>>;
if(col[node]){
col[node<<]+=col[node];
col[node<<|]+=col[node];
sum[node<<]+=(mid-l+)*col[node];
sum[node<<|]+=(r-mid)*col[node];
col[node]=;
}
} void build(int l,int r,int node){
col[node]=;
if(l==r){
sum[node]=w[Rank[l]];
return;
}
int mid=(l+r)>>;
build(lson);
build(rson);
pushup(node);
} void update(int l,int r,int node,int L,int R,int val){
if(L<=l&&R>=r){
col[node]+=1LL*val;
sum[node]+=1LL*val*(r-l+);
return;
}
pushdown(node,l,r);
int mid=(l+r)>>;
//只要在范围内都要考虑
if(L<=mid) update(lson,L,R,val);
if(R>mid) update(rson,L,R,val);
pushup(node);
} ll query(int l,int r,int node,int L,int R){
if(l>=L&&r<=R) return sum[node];
pushdown(node,l,r);
int mid=(l+r)>>;
ll ret=;
if(L<=mid) ret+=query(lson,L,R);
if(R>mid) ret+=query(rson,L,R);
pushup(node);
return ret;
} ll getans(int x,int y){
ll ans=;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=query(,n,,tid[top[x]],tid[x]);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=query(,n,,tid[x],tid[y]);
return ans;
} int main(){
int u,v;
int op;
int a,x;
while(~scanf("%d%d",&n,&m)) {
init();
for(int i=;i<=n;i++) scanf("%lld",&w[i]);
for(int i=;i<n;i++){
scanf("%d%d",&u,&v);
add(u,v);
}
dfs1(,-,);
dfs2(,);
build(,n,); for(int i=;i<=m;i++){
scanf("%d",&op);
if(op==){
scanf("%d",&x);
ll ans=getans(,x);
printf("%lld\n",ans);
}
else{
scanf("%d%d",&x,&a);
update(,n,,tid[x],op==?tid[x]:End[x],a);
}
}
}
return ;
}