bzoj1036: [ZJOI2008]树的统计Count(树链剖分+线段树维护)

时间:2021-07-15 09:53:20

bzoj1036: [ZJOI2008]树的统计Count

  Time Limit: 10 Sec
  Memory Limit: 162 MB

Description

  一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作:
  I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
  II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
 

Input

  输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
  对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
 

Output

  对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
 

Sample Input

  4
  1 2
  2 3
  4 1
  4 2 1 3
  12
  QMAX 3 4
  QMAX 3 3
  QMAX 3 2
  QMAX 2 3
  QSUM 3 4
  QSUM 2 1
  CHANGE 1 5
  QMAX 3 4
  CHANGE 3 6
  QMAX 3 4
  QMAX 2 4
  QSUM 3 4
 

Sample Output

  4
  1
  2
  2
  10
  6
  5
  6
  5
  16
  

题目地址:  bzoj1036: [ZJOI2008]树的统计Count

题目大意:

  题目已经很清楚了
  

题解:

  树链剖分裸题
  
  先把树轻重链剖分(两遍dfs)
  
  然后线段树维护区间最大值和区间和
  
  将剖出来的序列合并一下就好了
  


AC代码

#include <cstdio>
#include <cstring>
#include <algorithm>
#define inf 0x7fffffff
#define N 30005
using namespace std;
int n,Q,cnt,sz;
int w[N],dep[N],size[N],head[N],fa[N];
int pos[N],top[N];
char ch[10];
struct edge{
    int to,next;
}e[N+N];
struct seg{
    int l,r,mx,sum;
}t[N<<2];
void add_edge(int u,int v){
    e[++cnt]=(edge){v,head[u]};head[u]=cnt;
    e[++cnt]=(edge){u,head[v]};head[v]=cnt;
}
void dfs1(int u){
    size[u]=1;
    for(int i=head[u];i;i=e[i].next){
        if(e[i].to==fa[u])continue;
        dep[e[i].to]=dep[u]+1;
        fa[e[i].to]=u;
        dfs1(e[i].to);
        size[u]+=size[e[i].to];
    }
}
void dfs2(int u,int chain){
    int k=0;sz++;
    pos[u]=sz;
    top[u]=chain;
    for(int i=head[u];i;i=e[i].next)
        if(dep[e[i].to]>dep[u]&&size[e[i].to]>size[k])
            k=e[i].to;
    if(k==0)return;
    dfs2(k,chain);
    for(int i=head[u];i;i=e[i].next)
        if(dep[e[i].to]>dep[u]&&k!=e[i].to)
            dfs2(e[i].to,e[i].to);
}
void build(int l,int r,int id){
    t[id].l=l;t[id].r=r;
    if(l==r)return;
    int mid=(l+r)>>1;
    build(l,mid,id<<1);
    build(mid+1,r,id<<1|1);
}
void change(int id,int k,int w){
    int l=t[id].l,r=t[id].r,mid=(l+r)>>1;
    if(l==r){
        t[id].sum=t[id].mx=w;
        return;
    }
    if(k<=mid)change(id<<1,k,w);
    else change(id<<1|1,k,w);
    t[id].sum=t[id<<1].sum+t[id<<1|1].sum;
    t[id].mx=max(t[id<<1].mx,t[id<<1|1].mx);
}
int querysum(int id,int L,int R){
    int l=t[id].l,r=t[id].r,mid=(l+r)>>1;
    if(l==L&&R==r)return t[id].sum;
    if(R<=mid)return querysum(id<<1,L,R);
    else if(L>mid)return querysum(id<<1|1,L,R);
    else return querysum(id<<1,L,mid)+querysum(id<<1|1,mid+1,R);
}
int querymx(int id,int L,int R){
    int l=t[id].l,r=t[id].r,mid=(l+r)>>1;
    if(l==L&&R==r)return t[id].mx;
    if(R<=mid)return querymx(id<<1,L,R);
    else if(L>mid)return querymx(id<<1|1,L,R);
    else return max(querymx(id<<1,L,mid),querymx(id<<1|1,mid+1,R));
}
int solvesum(int a,int b){
    int sum=0;
    while(top[a]!=top[b]){
        if(dep[top[a]]<dep[top[b]])swap(a,b);
        sum+=querysum(1,pos[top[a]],pos[a]);
        a=fa[top[a]];
    }
    if(pos[a]>pos[b])swap(a,b);
    sum+=querysum(1,pos[a],pos[b]);
    return sum;
}
int solvemx(int a,int b){
    int mx=-inf;
    while(top[a]!=top[b]){
        if(dep[top[a]]<dep[top[b]])swap(a,b);
        mx=max(mx,querymx(1,pos[top[a]],pos[a]));
        a=fa[top[a]];
    }
    if(pos[a]>pos[b])swap(a,b);
    mx=max(mx,querymx(1,pos[a],pos[b]));
    return mx;
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<n;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        add_edge(u,v);
    }
    for(int i=1;i<=n;i++)scanf("%d",&w[i]);
    dfs1(1);
    dfs2(1,1);
     
    build(1,n,1);
    for(int i=1;i<=n;i++)
        change(1,pos[i],w[i]);
    scanf("%d",&Q);
    while(Q--){
        int x,y;scanf("%s%d%d",ch+1,&x,&y);
        if(ch[1]=='C'){
            w[x]=y;
            change(1,pos[x],y);
        }else
            if(ch[2]=='M')
                printf("%d\n",solvemx(x,y));
            else
                printf("%d\n",solvesum(x,y));
    }
    return 0;
}