地址:http://www.spoj.com/problems/COT/en/
题目:
COT - Count on a tree
You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation:
- u v k : ask for the kth minimum weight on the path from node u to node v
Input
In the first line there are two integers N and M (N, M <= 100000).
In the second line there are N integers. The i-th integer denotes the weight of the i-th node.
In the next N-1 lines, each line contains two integers u v, which describes an edge (u, v).
In the next M lines, each line contains three integers u v k, which means an operation asking for the k-th minimum weight on the path from node u to node v.
Output
For each operation, print its result.
Example
Input:
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
2 5 2
2 5 3
2 5 4
7 8 2
Output:
2
8
9
105
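7
思路: 这是一道求树上两点 u、v 之间路径上第 k 小权值的题。
求第 k 小很容易想到主席树(可持久化线段树),但以前做的主席树都是建在线性序列上的,而这题是建在树上的。
其实只需把建树方式改为“在父节点的版本上插入当前节点的权值”即可,这样 rt[x] 记录的就是根到 x 这条链上各权值的出现次数。查询时,路径 u→v 上落在某个值域区间内的权值个数为 sum[u]+sum[v]-sum[lca(u,v)]-sum[fa[lca(u,v)]](根到 lca 的部分在 sum[u] 与 sum[v] 中被计了两次,减去 sum[lca] 与 sum[fa[lca]] 后 lca 恰好只保留一次),据此在线段树上二分即可。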
这还是一道模板题。
#include <bits/stdc++.h>

using namespace std;

#define MP make_pair
#define PB push_back
typedef long long LL;
typedef pair<int,int> PII;
const double eps=1e-8;
const double pi=acos(-1.0);
// MAXN bounds the number of tree nodes (N <= 100000) and therefore also the
// number of distinct weights after discretization.
const int MAXN=1e5+7;
// K bounds the persistent segment-tree node pool: build() allocates
// 2*sz-1 <= 2e5 nodes and each of the n update() calls allocates at most one
// node per tree level (18 levels for sz <= 1e5), i.e. < 2e6 nodes in total.
const int K=2e6+7;
const int mod=1e9+7;

int tot,ls[K],rs[K],sum[K];           // persistent segment-tree node pool
int rt[MAXN];                         // rt[x]: root of the version for tree node x; rt[0] is the empty version
int v[MAXN],b[MAXN];                  // v[x]: discretized weight of node x; b[i]: i-th smallest distinct weight
int up[MAXN][21],deep[MAXN],fa[MAXN]; // binary-lifting table (levels 0..20), depth, parent
vector<int>mp[MAXN];                  // adjacency lists of the tree
// Objects with static storage duration are zero-initialized, so node 0 has
// deep[0] == 0 and up[0][*] == 0; lca() uses node 0 as a sentinel above the root.
// sum[o] is how many inserted values fall inside node o's range.
// The ranges are over the discretized weights 1..sz, not the raw weights.
// Builds an empty segment tree over the discretized value range [l, r].
// It is used for version 0 (rt[0]). Each node takes the next free index from
// the pool (++tot) and starts with a count of zero.
// o: receives the index of the node created for [l, r].
void build(int &o,int l,int r)
{
    o=++tot;
    sum[o]=0;
    if(l==r) return;
    int mid=(l+r)>>1;
    build(ls[o],l,mid);
    build(rs[o],mid+1,r);
}
// Creates a new version that equals version `last` plus one occurrence of
// discretized value x. Only the nodes on the path from [l, r] down to leaf x
// are newly allocated. Every other child index is copied from `last`, so
// those subtrees are shared between the two versions.
// o:    receives the index of the new node for [l, r]
// last: node covering [l, r] in the previous version
// x:    discretized value to insert (assumed to lie in [l, r])
void update(int &o,int l,int r,int last,int x)
{
    o=++tot;
    sum[o]=sum[last]+1;
    ls[o]=ls[last];
    rs[o]=rs[last];
    if(l==r) return;
    int mid=(l+r)>>1;
    if(x<=mid) update(ls[o],l,mid,ls[last],x);
    else update(rs[o],mid+1,r,rs[last],x);
}
// Returns the k-th smallest original weight (b[]) on the tree path u..v.
// The four roots are passed in as:
//   ra = rt[u], rb = rt[v], rc = rt[lca(u,v)], rd = rt[fa[lca(u,v)]].
// For any value range, sum[ra]+sum[rb]-sum[rc]-sum[rd] is the number of path
// vertices whose weight falls in that range. The root..lca part appears twice
// in sum[ra]+sum[rb], and subtracting rc and rd leaves the lca counted once.
// The search walks down to the leaf holding the k-th value. k is not range
// checked; it is assumed to satisfy 1 <= k <= number of vertices on the path.
int query(int ra,int rb,int rc,int rd,int l,int r,int k)
{
    if(l==r) return b[l];
    int mid=(l+r)>>1;
    // Number of path vertices whose discretized weight lies in [l, mid].
    int cnt=sum[ls[ra]]+sum[ls[rb]]-sum[ls[rc]]-sum[ls[rd]];
    if(k<=cnt) return query(ls[ra],ls[rb],ls[rc],ls[rd],l,mid,k);
    return query(rs[ra],rs[rb],rs[rc],rs[rd],mid+1,r,k-cnt);
}
void dfs(int x,int f,int sz)
{
update(rt[x],,sz,rt[f],v[x]);
up[x][]=f,deep[x]=deep[f]+,fa[x]=f;
for(int i=;i<=;i++) up[x][i]=up[up[x][i-]][i-];
for(int i=;i<mp[x].size();i++)
if(mp[x][i]!=f)
dfs(mp[x][i],x,sz);
}
// Lowest common ancestor of x and y via binary lifting (levels 0..20).
// It relies on the sentinel node 0 (deep[0] == 0, up[0][*] == 0). A jump
// past the root lands on 0, whose depth is below every real node, so the
// depth check rejects it. In the second loop, such jumps give 0 for both
// x and y and are skipped.
int lca(int x,int y)
{
    if(deep[x]<deep[y]) swap(x,y);
    // Lift x up to the depth of y.
    for(int i=20;i>=0;i--)
        if(deep[up[x][i]]>=deep[y]) x=up[x][i];
    if(x==y) return x;
    // Lift both while their ancestors differ; they end up as children of the LCA.
    for(int i=20;i>=0;i--)
        if(up[x][i]!=up[y][i]) x=up[x][i],y=up[y][i];
    return up[x][0];
}
// Reads the tree and the weights, discretizes the weights, builds one
// persistent-segment-tree version per node (rooted at node 1) and answers
// each "u v k" query with the k-th smallest weight on the path u..v.
int main(void)
{
    tot=0;
    int n,m,sz;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
        scanf("%d",v+i),b[i]=v[i];
    for(int i=1,x,y;i<n;i++)
        scanf("%d%d",&x,&y),mp[x].PB(y),mp[y].PB(x);
    // Discretize: b[1..sz] holds the distinct weights in ascending order, and
    // v[i] becomes the 1-based rank of node i's weight within b.
    sort(b+1,b+1+n);
    sz=unique(b+1,b+1+n)-b-1;
    for(int i=1;i<=n;i++)
        v[i]=lower_bound(b+1,b+1+sz,v[i])-b;
    build(rt[0],1,sz);   // version 0: empty tree, serves as the root's parent version
    dfs(1,0,sz);         // root the tree at node 1, with sentinel parent 0
    int x,y,k;
    while(m--)
    {
        scanf("%d%d%d",&x,&y,&k);
        int z=lca(x,y);  // computed once per query instead of twice
        printf("%d\n",query(rt[x],rt[y],rt[z],rt[fa[z]],1,sz,k));
    }
    return 0;
}