spoj COT - Count on a tree (树上第K小 LCA+主席树)

时间:2021-05-25 20:30:44

链接:

https://www.spoj.com/problems/COT/en/

思路:

首先看到求两点之前的第k小很容易想到用主席树去写,但是主席树处理的是线性结构,而这道题要求的是树形结构,我们可以用dfs跑出所有点离根的距离-dep[i](根为1,dep[1]也为1)在dfs的过程

中,我们对每一个节点建一棵线段树,那么【a,b】就是:root[a] + root[b] - root[lca(a,b)] - root[f[lca(a,b)]]; (因为a-b的路径上的权值还要算上lca(a,b)这个点,所以不是减2*root[lca(a,b)]);

实现代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define mid int m = (l + r) >> 1
const int M = 2e5 + ;
int p[M][],dep[M],head[M],sum[M*],ls[M*],rs[M*],root[M*];
int cnt1,n,idx,cnt,f[M]; struct node {
int to,next;
}e[M]; void add(int u,int v){
e[++cnt1].to=v;e[cnt1].next=head[u];head[u]=cnt1;
e[++cnt1].to=u;e[cnt1].next=head[v];head[v]=cnt1;
} int lca(int a,int b){
if(dep[a] > dep[b]) swap(a,b);
int h = dep[b] - dep[a]; //h为高度差
for(int i = ;(<<i)<=h;i++){ //(1<<i)&f找到h化为2进制后1的位置,移动到相应的位置
if((<<i)&h) b = p[b][i];
//比如h = 5(101),先移动2^0祖先,然后再移动2^2祖先
}
//cout<<a<<" "<<b<<endl;
if(a!=b){
for(int i = ;i >= ;i --){
if(p[a][i]!=p[b][i]){ //从最大祖先开始,判断a,b祖先,是否相同
a = p[a][i]; b = p[b][i]; //如不相同,a,b,同时向上移动2^j
}
}
a = p[a][]; //这时a的father就是LCA
}
return a;
} void build(int l,int r,int &rt){
rt = ++idx;
sum[rt] = ;
if(l == r) return;
mid;
build(l,m,ls[rt]);
build(m+,r,rs[rt]);
} void update(int p,int l,int r,int old,int &rt){
rt = ++idx;
ls[rt] = ls[old]; rs[rt] = rs[old]; sum[rt] = sum[old] + ;
if(l == r) return ;
mid;
if(p <= m) update(p,l,m,ls[old],ls[rt]);
else update(p,m+,r,rs[old],rs[rt]);
} int query(int a,int b,int lc,int cl,int l,int r,int k){
if(l == r) return l;
mid;
int cnt = sum[ls[a]] + sum[ls[b]] - sum[ls[lc]] - sum[ls[cl]];
if(k <= cnt)
return query(ls[a],ls[b],ls[lc],ls[cl],l,m,k);
else
return query(rs[a],rs[b],rs[lc],rs[cl],m+,r,k-cnt);
}
int a[M],b[M]; void dfs(int u,int fa){
f[u] = fa;
dep[u] = dep[fa] + ;
p[u][] = fa;
for(int i = ;i < ;i ++) p[u][i] = p[p[u][i-]][i-];
update(a[u],,cnt,root[fa],root[u]);
for(int i = head[u];i!=-;i = e[i].next){
int v = e[i].to;
if(v == fa) continue;
dfs(v,u);
}
} int main()
{
int m;
while(scanf("%d%d",&n,&m)!=EOF){
cnt1 = ;
memset(head,-,sizeof(head));
memset(dep,,sizeof(dep));
memset(p,,sizeof(p));
memset(f,,sizeof(f));
for(int i = ; i <= n;i ++){
scanf("%d",&a[i]);
b[i] = a[i];
}
idx = ;
int l,r,c;
sort(b+,b+n+);
cnt = unique(b+,b++n)-b-;
for(int i = ;i <= n;i ++)
a[i] = lower_bound(b+,b+cnt+,a[i]) - b;
for(int i = ;i <= n-;i ++){
scanf("%d%d",&l,&r);
add(l,r);
}
build(,cnt,root[]);
dfs(,);
for(int i = ;i <= m;i ++){
scanf("%d%d%d",&l,&r,&c);
int lc = lca(l,r);
int id = query(root[l],root[r],root[lc],root[f[lc]],,cnt,c);
printf("%d\n",b[id]);
}
}
return ;
}