HDU 6203 ping ping ping(dfs序+LCA+树状数组)

时间:2021-06-02 23:28:34

http://acm.hdu.edu.cn/showproblem.php?pid=6203

题意:

n+1 个点 n 条边的树(点标号 0 ~ n),有若干个点无法通行,导致 p 组 U V 无法连通。问无法通行的点最少有多少个。

思路:

贪心思维,破坏两个点的LCA是最佳的。那么怎么判断现在在(u,v)之间的路径上有没有被破坏的点呢,如果没有的话那么此时就要破坏这个lca点。一开始我们要把询问按照u和v的lca深度从大到小排序,如果某个点需要被破坏,那么它的所有子节点都可以不再需要破坏别的点了(因为它的子节点到别的子节点肯定是要经过该点的,要注意这个前提是lca是排好序的,自己脑补一下~)。

所以,用dfs序来维护子节点是最好的,记录in和out两个数组。然后用树状数组来维护。

 #include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<sstream>
#include<vector>
#include<stack>
#include<queue>
#include<cmath>
#include<map>
#include<set>
using namespace std;
typedef long long ll;
typedef pair<int,ll> pll;
const int INF = 0x3f3f3f3f;
const int maxn = 1e4+; int n;
int Log;
int dfs_clock;
int in[maxn],out[maxn];
int deep[maxn];
int p[maxn][];
int c[*maxn];
vector<int> G[maxn]; struct node
{
int u,v,lca;
}query[]; void dfs(int u, int fa, int d)
{
in[u]=++dfs_clock;
deep[u]=d;
p[u][]=fa;
for(int i=;i<G[u].size();i++)
{
int v=G[u][i];
if(v==fa) continue;
dfs(v,u,d+);
}
out[u]=++dfs_clock;
} bool cmp(node a, node b)
{
return deep[a.lca]>deep[b.lca];
} void init()
{
for(int j=;j<=Log;j++)
for(int i=;i<=n;i++)
p[i][j]=p[p[i][j-]][j-];
} int LCA(int x, int y)
{
if(x==y) return x;
if(deep[x]<deep[y]) swap(x,y);
for(int i=Log;i>=;i--)
{
if(deep[p[x][i]]>=deep[y])
x=p[x][i];
}
if(x==y) return x;
for(int i=Log;i>=;i--)
{
if(p[x][i]!=p[y][i])
{
x=p[x][i];y=p[y][i];
}
}
return p[x][];
} int lowbit(int x)
{
return x&(-x);
} int sum(int x)
{
int ret = ;
while(x>)
{
ret+=c[x];
x-=lowbit(x);
}
return ret;
} void add(int x, int d)
{
while(x<=*n)
{
c[x]+=d;
x+=lowbit(x);
}
} int main()
{
//freopen("in.txt","r",stdin);
while(~scanf("%d",&n))
{
dfs_clock=;
memset(c,,sizeof(c));
memset(p,,sizeof(p));
for(int i=;i<=n+;i++) G[i].clear();
for(int i=;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
u++;v++;
G[u].push_back(v);
G[v].push_back(u);
}
n++;
for(Log=;(<<Log)<=n;Log++);
Log--; dfs(,,);
init();
int q;
scanf("%d",&q);
for(int i=;i<=q;i++)
{
scanf("%d%d",&query[i].u,&query[i].v);
query[i].u++;query[i].v++;
query[i].lca=LCA(query[i].u,query[i].v);
}
sort(query+,query+q+,cmp);
int ans=;
for(int i=;i<=q;i++)
{
int u=query[i].u,v=query[i].v,lca=query[i].lca;
int tmp1=sum(in[u]),tmp2=sum(in[v]);
if(sum(in[u])+sum(in[v])) continue;
else
{
ans++;
add(in[lca],);
add(out[lca],-);
}
}
printf("%d\n",ans);
}
return ;
}