最近公共祖先 LCA 倍增法

时间:2023-03-09 14:12:36
最近公共祖先 LCA 倍增法

【简介】

解决LCA问题的倍增法是一种基于倍增思想的在线算法。

【原理】

原理和同样是使用倍增思想的RMQ-ST 算法类似,比较简单,想清楚后很容易实现。

对于每个节点u , ancestors[u][k] 表示 u 的第2k个祖先是谁。很容易就想到递推式: ancestors[j][i] = ancestors[ancestors[j][i - 1]][i - 1];  根据二进制原理,理论上 u 的所有祖先都可以根据ancestors数组多次跳转得到,这样就间接地记录了每个节点的祖先信息。
     查询LCA(u,v)的时候:
         (一)u和v所在的树的层数如果一样,令u'=u。否则需要平衡操作(假设u更深),先找到u的一个祖先u', 使得u'的层数和v一样,此时LCA(u,v)=LCA(u',v) 。证明很简单:如果LCA(u,v)=v , 那么u'一定等于v ;如果LCA(u,v)=k ,k!=v ,那么k 的深度一定小于 v , u、u'、v 一定在k的子树中;综上所述,LCA(u,v)=LCA(u',v)一定成立。

(二)此时u' 和 v 的祖先序列中一开始的部分一定有所重叠,重叠部分的最后一个元素(也就是深度最深,与u'、v最近的元素)就是所求的LCA(u,v)。这里ancestors数组就可以派上用场了。找到第一个不重叠的节点k,LCA(u,v)=ancestors[k][0] 。 找k的过程利用二进制贪心思想,先尽可能跳到最上层的祖先,如果两祖先相等,说明完全可以跳小点,跳的距离除2,这样一步步跳下去一定可以找到k。

【hdu 2586】

需要注意的是超界的处理。

 #pragma comment(linker, "/STACK:1024000000,1024000000")
#include <stdio.h>
#include <string.h>
#include <vector>
#include <cmath>
#include <iostream>
using namespace std;
int n,m;
struct edge
{
int d,v,next;
edge(){}
edge(int _d,int _v,int _next)
{
d=_d;v=_v;next=_next;
}
}data[];
int map[];
int pool;
void addedge(int s,int e,int v)
{
int t=map[s];
data[pool++]=edge(e,v,t);
map[s]=pool-;
}
int ANCLOG;
int depth[];
int ifv[];
int dis[];
int anc[][];
void dfs(int cur,int dep)
{
ifv[cur]=;
depth[cur]=dep;
int p=map[cur];
while (p!=-)
{
if (!ifv[data[p].d])
{
dis[data[p].d]=dis[cur]+data[p].v;
anc[data[p].d][]=cur;
dfs(data[p].d,dep+);
}
p=data[p].next;
}
}
void initLCA()
{
for (int k=;k<ANCLOG;++k)
for (int i=;i<n;++i)
{
if (anc[i][k-]==-) continue;
anc[i][k]=anc[anc[i][k-]][k-];
}
}
int getLCA(int u,int v)
{
if (depth[u]<depth[v]) swap(u,v);
for (int k=ANCLOG;k>=;--k)
{
if (anc[u][k]==-) continue;
if (depth[anc[u][k]]>=depth[v])
{
u=anc[u][k];
if (depth[u]==depth[v]) break;
}
}
if (u==v) return u;
for (int k=ANCLOG;k>=;--k)
{
if (anc[u][k]==-) continue;
if (anc[u][k]!=anc[v][k])
{
u=anc[u][k];
v=anc[v][k];
}
}
return anc[u][];
}
int main()
{
int T;
scanf("%d",&T);
while (T--)
{
pool=;
memset(anc,-,sizeof anc);
memset(map,-,sizeof map);
memset(ifv,,sizeof ifv);
scanf("%d%d",&n,&m);
ANCLOG=(int)(log(n)/log(2.0));
int s,e,v;
for (int i=;i<n-;++i)
{
scanf("%d%d%d",&s,&e,&v);
addedge(s-,e-,v);
addedge(e-,s-,v);
}
dis[]=;
dfs(,);
initLCA();
for (int i=;i<m;++i)
{
int u,v;
scanf("%d%d",&u,&v);
--u;--v;
int k=getLCA(u,v);
k=dis[u]+dis[v]-*dis[k];
printf("%d\n",k);
}
}
}