SPOJ QTREE Query on a tree [树链剖分+线段树]

时间:2022-06-18 12:35:58

Description

You are given a tree (an acyclic undirected connected graph) with N nodes, and edges numbered 1, 2, 3...N-1.

We will ask you to perform some instructions of the following form:

  • CHANGE i ti : change the cost of the i-th edge to ti
    or
  • QUERY a b : ask for the maximum edge cost on the path from node a to node b

题意: 给一棵树,边有权,两种操作:1)修改一条边的权值  2) 询问某条路径上的最大边权

解法:树链剖分入门模板题。剖分后,每条询问路径被拆成 $O(\log N)$ 段在线段树下标上连续的区间,于是每个询问变成了求这些区间上的最大值,很明显可以用线段树解决,单次操作复杂度为 $O(\log^2 N)$。

代码:280MS

#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<math.h>
#include<iostream>
#include<stdlib.h>
#include<set>
#include<map>
#include<queue>
#include<vector>
#include<bitset>
#pragma comment(linker, "/STACK:1024000000,1024000000")
// Fast reader for one signed integer or decimal number from stdin.
// Skips every character that cannot start a number, accepts an optional
// leading '-', and parses a fractional part only when the integer digits are
// followed by '.'. The character that terminates the number is consumed.
// @param ret  receives the parsed value (fraction digits contribute nothing
//             for integral T, because `bit` truncates to 0).
// @return     false if end of input is reached before a number starts.
template <class T>
bool scanff(T &ret){ //Faster Input
    int c; // int, not char: EOF must stay distinguishable from every byte value
    do{
        c=getchar();
        if(c==EOF) return 0; // also stops the skip loop from spinning forever at EOF
    }while(c!='-'&&c!='.'&&(c<'0'||c>'9'));
    int sgn=1;
    if(c=='-'){ sgn=-1; c=getchar(); }
    ret=0; // a leading '.' now starts from 0 instead of '.'-'0'
    while(c>='0'&&c<='9'){ ret=ret*10+(c-'0'); c=getchar(); }
    if(c=='.'){
        // Only a '.' introduces a fraction; any other terminator (tab, '\r', ...)
        // must not make us consume the next number.
        T bit=0.1;
        while(c=getchar(),c>='0'&&c<='9') ret+=(c-'0')*bit,bit/=10;
    }
    ret*=sgn;
    return 1;
}
// ---- Shorthand macros / aliases (competitive-programming template) ----
// Large sentinels: inf = 2^30 - 1, llinf = 2^62 - 1.
#define inf 1073741823
#define llinf 4611686018427387903LL
#define PI acos(-1.0)
// Left/right child of segment-tree node `th` (not used by the visible code).
#define lth (th<<1)
#define rth (th<<1|1)
// Inclusive ascending / descending integer loops.
#define rep(i,a,b) for(int i=int(a);i<=int(b);i++)
#define drep(i,a,b) for(int i=int(a);i>=int(b);i--)
// Iterate the edge indices i leaving `root` in the ptx/ed adjacency lists;
// the list ends at -1, for which ~i == 0.
#define gson(i,root) for(int i=ptx[root];~i;i=ed[i].next)
// Read the number of test cases and loop over them with the 1-based index `cas`.
#define tdata int testnum;scanff(testnum);for(int cas=1;cas<=testnum;cas++)
#define mem(x,val) memset(x,val,sizeof(x))
#define mkp(a,b) make_pair(a,b)
// NOTE(review): expands to a lookup in b[1..bn], but no `bn` is declared in
// the visible code, so this macro is unused/unusable here.
#define findx(x) lower_bound(b+1,b+1+bn,x)-b
#define pb(x) push_back(x)
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;

#define NN 100010
// Adjacency lists kept as singly linked lists inside the edge pool `ed`:
// ptx[x] is the slot of the most recently added edge leaving x (-1 = none),
// lnum is the next free slot of the pool.
int ptx[NN],lnum;
struct edge{
    int v,next,w; // target vertex, slot of the next edge from the same source, weight
    edge(){}
    edge(int v,int next,int w):v(v),next(next),w(w){}
}ed[NN*2];
// Prepend the directed edge x -> y (weight w) to x's adjacency list.
void addline(int x,int y,int w){
    const int slot=lnum++;
    ed[slot]=edge(y,ptx[x],w);
    ptx[x]=slot;
}

// Per-node results of the heavy-light decomposition, indexed by node 1..n:
//   sz  - subtree size             son - heavy child (0 if none)
//   f   - parent (0 for the root)  dep - depth (the root has depth 1)
int sz[NN],son[NN],f[NN],dep[NN];
// tid[x] - position of x in the DFS order where each heavy chain is contiguous;
// tn     - last position handed out (must be reset per test case)
int tid[NN],tn;
// top[x] - head (shallowest node) of the heavy chain containing x
int top[NN];
// a[x] - weight of the edge between x and its parent (the root keeps 0);
// b[p] - a[] permuted into tid order: the base array of the segment tree
int a[NN],b[NN];

int getson(int x,int fa,int d){
    int maxval=0;
    //这些数据千万不要忘记初始化,错了一万遍
    son[x]=0;
    sz[x]=1;
    f[x]=fa;
    dep[x]=dep[fa]+1;
    gson(i,x){
        int y=ed[i].v;
        if(y==fa)continue;
        sz[x]+=getson(y,x,d+1);
        a[y]=ed[i].w;
        if(sz[y]>maxval)
            maxval=sz[y],son[x]=y;
    }
    return sz[x];
}
void getchain(int r,int x,int fa){
    tid[x]=++tn;
    top[x]=r;
    if(son[x])getchain(r,son[x],x);
    gson(i,x){
        int y=ed[i].v;
        if(y==fa||y==son[x])continue;
        getchain(y,y,x);
    }
}

struct segtree{
    int val[NN*16],m;
    void init(int n){
        for(m=1;m<n+3;m<<=1);
        rep(i,1,m<<1)val[i]=0;
        rep(i,1,n)val[i+m]=b[i];
        drep(i,m-1,1)val[i]=max(val[i<<1],val[i<<1|1]);
    }
    void update(int pos,int v){
        val[pos+m]=v;
        for(int i=(pos+m)>>1;i;i>>=1)
            val[i]=max(val[i<<1],val[i<<1|1]);
    }
    int query(int l,int r){
        if(l>r)return -inf;
        int ans=-inf;
        for(l=l+m-1,r=r+m+1;l^r^1;l>>=1,r>>=1){
            if(~l&1)ans=max(ans,val[l^1]);
            if(r&1) ans=max(ans,val[r^1]);
        }
        return ans;
    }
}st;
char op[11]; // command word: "CHANGE", "QUERY" or "DONE"

// Input: the number of test cases; per case N, then N-1 lines "a b c" (edge i
// joins a and b with cost c), then commands until "DONE":
//   CHANGE i t : set the cost of edge i to t
//   QUERY a b  : print the maximum edge cost on the path from a to b
// Each edge's cost is stored at the tid of its deeper endpoint, so a path query
// becomes a series of contiguous segment-tree ranges, one per heavy chain crossed.
int main(){
    int x,y,w,n;
    tdata{
        if(cas>1)printf("\n");   // blank line between consecutive test cases
        scanff(n);
        lnum=tn=0; // reset the edge pool and the DFS-order counter for this test case
        rep(i,1,n)ptx[i]=-1;
        rep(i,1,n-1){
            scanff(x);scanff(y);scanff(w);
            addline(x,y,w);addline(y,x,w);   // edge i -> pool slots 2i-2 and 2i-1
        }
        getson(1,0,0);
        getchain(1,1,0);
        rep(i,1,n)b[tid[i]]=a[i];
        st.init(n);
        // The field width bounds the read to op's capacity (10 chars + '\0'),
        // so an unexpectedly long token cannot overflow the buffer.
        while(scanf("%10s",op)==1){
            if(op[0]=='C'){
                scanff(x);scanff(y);
                // Edge x occupies slots 2x-2 and 2x-1; its cost lives at the
                // deeper of its two endpoints (the child side of the edge).
                int u=ed[x*2-2].v;
                int v=ed[x*2-1].v;
                if(dep[u]<dep[v])swap(u,v);
                st.update(tid[u],y);
            }else if(op[0]=='Q'){
                scanff(x);scanff(y);
                int ans=-inf;
                // Climb one chain at a time from the endpoint whose chain head is deeper.
                while(top[x]!=top[y]){
                    if(dep[top[x]]<dep[top[y]])swap(x,y);
                    ans=max(ans,st.query(tid[top[x]],tid[x]));
                    x=f[top[x]];
                }
                if(x!=y){
                    // Same chain now: the shallower node is the LCA. Its own slot
                    // holds the edge above it, which is not on the path.
                    if(dep[x]>dep[y])swap(x,y);
                    ans=max(ans,st.query(tid[x]+1,tid[y]));
                }
                // NOTE(review): when a == b the path has no edges and -inf is printed.
                printf("%d\n",ans);
            }else if(op[0]=='D'){
                break;
            }
        }
    }
}