hdu 5877 Weak Pair: DFS order + Fenwick tree + coordinate compression

Time: 2021-07-11 10:37:41

Weak Pair

Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)

Problem Description
You are given a rooted tree of N nodes, labeled from 1 to N. A non-negative value a_i is assigned to the i-th node. An ordered pair of nodes (u, v) is said to be weak if
  (1) u is an ancestor of v (Note: in this problem a node u is not considered an ancestor of itself);
  (2) a_u × a_v ≤ k.

Can you find the number of weak pairs in the tree?

 
Input
There are multiple cases in the data set.
  The first line of input contains an integer T denoting the number of test cases.
  For each case, the first line contains two space-separated integers, N and k, respectively.
  The second line contains N space-separated integers, denoting a_1 to a_N.
  Each of the subsequent N-1 lines contains two space-separated integers defining an edge connecting nodes u and v, where node u is the parent of node v.

Constraints:
  1 ≤ N ≤ 10^5
  0 ≤ a_i ≤ 10^9
  0 ≤ k ≤ 10^18
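These bounds have two consequences for the solution. First, the product a_u × a_v is at most 10^9 × 10^9 = 10^18. That is below the signed 64-bit limit of about 9.22 × 10^18, so the product can be evaluated in long long without overflow. Second, because a_i may be 0, any rewrite of the condition as a_u ≤ k / a_v must treat a_v = 0 separately. In that case every ancestor qualifies, since 0 ≤ k. For a_v > 0 and non-negative integers, a_u × a_v ≤ k holds exactly when a_u ≤ ⌊k / a_v⌋. The sketch below checks that equivalence exhaustively on small values. The names maxAncestorWeight, isWeak and thresholdMatchesProduct are illustrative and are not part of the problem or the solution code.

```cpp
#include <limits>

// Illustrative helper (not from the original solution): the largest ancestor
// weight a_u that still forms a weak pair with a node of weight av.
// When av == 0 every ancestor qualifies, so return the maximum long long.
long long maxAncestorWeight(long long av, long long k)
{
    if (av == 0)
        return std::numeric_limits<long long>::max();
    return k / av;  // floor division, since k >= 0 and av > 0
}

// The weak-pair condition taken literally; au * av <= 1e18 fits in long long.
bool isWeak(long long au, long long av, long long k)
{
    return au * av <= k;
}

// Compare the product form with the threshold form for all values in [0, limit].
bool thresholdMatchesProduct(int limit)
{
    for (long long k = 0; k <= limit; ++k)
        for (long long au = 0; au <= limit; ++au)
            for (long long av = 0; av <= limit; ++av)
                if (isWeak(au, av, k) != (au <= maxAncestorWeight(av, k)))
                    return false;
    return true;
}

// Compile-time check that the largest possible product does not overflow.
static_assert(1000000000LL * 1000000000LL <= std::numeric_limits<long long>::max(),
              "a_u * a_v fits in long long");
```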

 
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
 
Sample Input
1
2 3
1 2
1 2
 
Sample Output
1
 
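Problem summary: given a tree, a weight on each node and the tree's edges, count the pairs (u, v) where u is a proper ancestor of v and a[u]*a[v] <= k.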
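Read literally, this definition gives a simple O(N × depth) count: for each node v, walk up its parent chain and test every proper ancestor. That is far too slow for a path of 10^5 nodes. It is still handy as a reference when testing a faster solution on small random trees. The function bruteForceWeakPairs and its signature (1-based weights, edges given as parent/child pairs) are assumptions made for this sketch.

```cpp
#include <utility>
#include <vector>

// Hypothetical brute-force reference: for every node v, walk up the parent chain
// and test a[u] * a[v] <= k for each proper ancestor u.
// O(N * depth): a checker for small trees only, not a solution for N = 1e5.
// Nodes are 1-based; edges are (parent, child) pairs as in the input format.
long long bruteForceWeakPairs(int n, long long k,
                              const std::vector<long long>& a,  // a[1..n], a[0] unused
                              const std::vector<std::pair<int, int>>& edges)
{
    std::vector<int> parent(n + 1, 0);  // 0 means "no parent" (the root)
    for (const auto& e : edges)
        parent[e.second] = e.first;

    long long count = 0;
    for (int v = 1; v <= n; ++v)
        for (int u = parent[v]; u != 0; u = parent[u])  // proper ancestors only
            if (a[u] * a[v] <= k)
                ++count;
    return count;
}
```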
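Approach: in a DFS from the root, the nodes on the current recursion path are exactly the ancestors of the node being visited. On entering v, count the ancestors u on that path with a[u] <= k/a[v] (integer division; if a[v] = 0, every ancestor counts), then insert a[v]. On leaving v, remove a[v] again. Because the weights go up to 10^9, they are coordinate-compressed together with the thresholds k/a[i]. The counts are kept in a Fenwick tree (binary indexed tree), so each test case runs in O(N log N).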
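The counting step can be sketched on its own. A Fenwick tree over the compressed weights stores how many nodes on the current root path have each weight. The prefix sum up to the rank of the largest weight not exceeding k/a[v] is then the number of valid ancestors of v. The Fenwick and Compressor structs below are illustrative and are not taken from the solution code that follows.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Fenwick tree (binary indexed tree) over positions 1..n.
struct Fenwick {
    std::vector<int> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i, int delta) {            // point update at position i (i >= 1)
        for (; i < (int)t.size(); i += i & -i)
            t[i] += delta;
    }
    long long prefix(int i) const {         // sum of positions 1..i
        long long s = 0;
        for (; i > 0; i -= i & -i)
            s += t[i];
        return s;
    }
};

// Coordinate compression: sorted distinct values, addressed by 1-based rank.
struct Compressor {
    std::vector<long long> vals;
    explicit Compressor(std::vector<long long> v) : vals(std::move(v)) {
        std::sort(vals.begin(), vals.end());
        vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
    }
    int size() const { return (int)vals.size(); }
    // Number of stored values <= x, i.e. the rank of the largest value not exceeding x.
    int rankAtMost(long long x) const {
        return (int)(std::upper_bound(vals.begin(), vals.end(), x) - vals.begin());
    }
    // 1-based rank of a value that is known to be present.
    int rankOf(long long x) const {
        return (int)(std::lower_bound(vals.begin(), vals.end(), x) - vals.begin()) + 1;
    }
};
```

Because rankAtMost locates the threshold with upper_bound, only the weights themselves need to be compressed. The solution below instead inserts every k/a[i] into the value list and looks it up with lower_bound. That gives the same counts at the cost of a list twice as long.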
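#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pi (4*atan(1.0))
const int N=1e5+10,M=4e6+10,inf=1e9+10,mod=1e9+7;
const ll INF=1e18+10;
vector<int>v[N];      // v[u]: children of node u
int du[N];            // du[w]: in-degree of w (0 means w is the root)
ll l[N<<1],a[N];      // l: sorted, deduplicated values a[i] and k/a[i]; a: node weights
int n,len;            // len: number of distinct values in l
ll k,ans;
int tree[N<<1];       // Fenwick tree indexed by compressed value (1-based)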
// Reset all per-test-case state.
void init(int n)
{
    for(int i=1;i<=n;i++)
        v[i].clear();
    memset(tree,0,sizeof(tree));
    memset(du,0,sizeof(du));
    ans=0;
}
// 1-based index of x among the compressed values (x must be present in l).
int getpos(ll x)
{
    int pos=lower_bound(l,l+len,x)-l;
    return pos+1;
}
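int lowbit(int x)
{
    return x&(-x);
}
// Add c at position x of the Fenwick tree.
void update(int x,int c)
{
    while(x<(N<<1))
    {
        tree[x]+=c;
        x+=lowbit(x);
    }
}
// Sum over positions 1..x: how many nodes on the current path have a compressed value <= x.
ll query(int x)
{
    ll ans=0;
    while(x)
    {
        ans+=tree[x];
        x-=lowbit(x);
    }
    return ans;
}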
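// On entry, the Fenwick tree holds exactly the weights of the nodes from the root down to u's parent.
void dfs(int u)
{
    int p,q;
    // p: index of the largest allowed ancestor weight, k/a[u].
    // If a[u]==0, every ancestor qualifies, so query the whole range.
    if(a[u])
        p=getpos(k/a[u]);
    else
        p=(N<<1)-1;
    // q: index of a[u] itself (0 is the smallest value, so it maps to index 1).
    if(a[u])
        q=getpos(a[u]);
    else
        q=1;
    ans+=query(p);      // ancestors with weight <= k/a[u]
    update(q,1);        // u is now an ancestor of everything in its subtree
    for(int i=0;i<v[u].size();i++)
    {
        dfs(v[u][i]);
    }
    update(q,-1);       // leaving u's subtree: remove u again
}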
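int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        int flag=0;
        scanf("%d%lld",&n,&k);
        init(n);
        // Collect every value that will be looked up: each a[i] and each threshold k/a[i].
        // k/a[i] is skipped when a[i]==0 (division by zero); dfs handles that case separately.
        for(int i=1;i<=n;i++)
        {
            scanf("%lld",&a[i]);
            if(a[i])
                l[flag++]=k/a[i];
            l[flag++]=a[i];
        }
        sort(l,l+flag);
        len=unique(l,l+flag)-l;
        // Read the N-1 edges u->w (u is the parent of w).
        for(int i=1;i<n;i++)
        {
            int u,w;
            scanf("%d%d",&u,&w);
            v[u].push_back(w);
            du[w]++;
        }
        // The root is the only node with in-degree 0.
        for(int i=1;i<=n;i++)
            if(du[i]==0)
                dfs(i);
        printf("%lld\n",ans);
    }
    return 0;
}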
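The recursive dfs above can reach depth N = 10^5 when the tree is a single path, which may overflow the default call stack on some judges. One alternative is to run the same scheme (count and insert on entry, remove on exit) with an explicit stack. The sketch below is not the original author's code. The function countWeakPairs, its parameters (1-based weights, edges given as parent/child pairs) and the choice to compress only the weights are all assumptions.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Iterative sketch: same counting idea as the recursive dfs, but with an explicit
// stack so that a path-shaped tree cannot overflow the call stack.
long long countWeakPairs(int n, long long k,
                         const std::vector<long long>& a,  // a[1..n], a[0] unused
                         const std::vector<std::pair<int, int>>& edges)  // (parent, child)
{
    std::vector<std::vector<int>> children(n + 1);
    std::vector<int> indeg(n + 1, 0);
    for (const auto& e : edges) {
        children[e.first].push_back(e.second);
        ++indeg[e.second];
    }

    // Compress only the weights; thresholds are located with upper_bound.
    std::vector<long long> vals(a.begin() + 1, a.begin() + n + 1);
    std::sort(vals.begin(), vals.end());
    vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
    const int m = static_cast<int>(vals.size());

    std::vector<int> bit(m + 1, 0);  // Fenwick tree over ranks 1..m
    auto add = [&](int i, int d) {
        for (; i <= m; i += i & -i) bit[i] += d;
    };
    auto prefix = [&](int i) {
        long long s = 0;
        for (; i > 0; i -= i & -i) s += bit[i];
        return s;
    };
    auto rankOf = [&](long long x) {
        return static_cast<int>(std::lower_bound(vals.begin(), vals.end(), x) - vals.begin()) + 1;
    };

    long long ans = 0;
    std::vector<std::pair<int, int>> st;  // (node, index of the next child to visit)
    auto enter = [&](int u) {
        int p = m;  // a[u] == 0: every ancestor qualifies
        if (a[u] != 0)
            p = static_cast<int>(std::upper_bound(vals.begin(), vals.end(), k / a[u]) - vals.begin());
        ans += prefix(p);      // ancestors currently in the tree with weight <= k / a[u]
        add(rankOf(a[u]), 1);  // u becomes an ancestor of its subtree
        st.push_back({u, 0});
    };

    for (int root = 1; root <= n; ++root) {
        if (indeg[root] != 0) continue;
        enter(root);
        while (!st.empty()) {
            int u = st.back().first;
            int& next = st.back().second;
            if (next < static_cast<int>(children[u].size())) {
                int c = children[u][next++];
                enter(c);  // push_back may invalidate `next`; it is not used afterwards
            } else {
                add(rankOf(a[u]), -1);  // leaving u's subtree
                st.pop_back();
            }
        }
    }
    return ans;
}
```

A main() for the judge would read T, then for each case n, k, the weights into a[1..n] and the n-1 edges, and print countWeakPairs(n, k, a, edges).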