P2633 Count on a tree

时间:2023-03-09 15:35:01
P2633 Count on a tree

思路

运用树上差分的思想,转化成一个普通的主席树模型即可求解

代码

#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
struct Node{
int lson,rson,sz;
}pt[100100*30];
const int MAXlog=19;
int dep[100100],jump[100100][MAXlog],lastans=0,n,m,u[100100<<1],Nodecnt,v[100100<<1],w_p[100100],ax[100100],nx,fir[100100],nxt[100100<<1],cnt,root[100100];
void addedge(int ui,int vi){
++cnt;
u[cnt]=ui;
v[cnt]=vi;
nxt[cnt]=fir[ui];
fir[ui]=cnt;
}
void insert(int L,int R,int pos,int &o){
pt[++Nodecnt]=pt[o];
o=Nodecnt;
pt[o].sz++;
if(L==R)
return;
int mid=(L+R)>>1;
if(pos<=mid)
insert(L,mid,pos,pt[o].lson);
else
insert(mid+1,R,pos,pt[o].rson);
}
int lca(int x,int y){
if(dep[x]<dep[y])
swap(x,y);
for(int i=MAXlog-1;i>=0;i--)
if(dep[x]-(1<<i)>=dep[y])
x=jump[x][i];
if(x==y)
return x;
for(int i=MAXlog-1;i>=0;i--)
if(jump[x][i]!=jump[y][i])
x=jump[x][i],y=jump[y][i];
return jump[x][0];
}
int query(int L,int R,int k,int rootx,int rooty,int rootlca,int falca){
if(L==R)
return L;
int lch=pt[pt[rootx].lson].sz+pt[pt[rooty].lson].sz-pt[pt[rootlca].lson].sz-pt[pt[falca].lson].sz;
int mid=(L+R)>>1;
if(lch<k)
return query(mid+1,R,k-lch,pt[rootx].rson,pt[rooty].rson,pt[rootlca].rson,pt[falca].rson);
else
return query(L,mid,k,pt[rootx].lson,pt[rooty].lson,pt[rootlca].lson,pt[falca].lson);
}
int query(int u,int v,int k){
u^=lastans;
int Lca=lca(u,v);
return lastans=ax[query(1,n,k,root[u],root[v],root[Lca],root[jump[Lca][0]])];
}
void init(void){
sort(ax+1,ax+n+1);
nx=unique(ax+1,ax+n+1)-(ax+1);
for(int i=1;i<=n;i++)
w_p[i]=lower_bound(ax+1,ax+nx+1,w_p[i])-ax;
}
void dfs(int u,int fa){
dep[u]=dep[fa]+1;
jump[u][0]=fa;
for(int i=1;i<MAXlog;i++)
jump[u][i]=jump[jump[u][i-1]][i-1]; root[u]=root[fa];
insert(1,n,w_p[u],root[u]); for(int i=fir[u];i;i=nxt[i]){
if(v[i]==fa)
continue;
dfs(v[i],u);
}
}
int main(){
scanf("%d %d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&w_p[i]),ax[i]=w_p[i];
init();
for(int i=1;i<=n-1;i++){
int a,b;
scanf("%d %d",&a,&b);
addedge(a,b);
addedge(b,a);
}
dfs(1,0);
for(int i=1;i<=m;i++){
int ux,vx,kx;
scanf("%d %d %d",&ux,&vx,&kx);
printf("%d\n",query(ux,vx,kx));
}
return 0;
}