hdu5293 lca+dp+树状数组+时间戳

时间:2021-06-20 10:14:20

题意是给了 n 个点的树,会有m条链条 链接两个点,计算出他们没有公共点的最大价值,  公共点时这样计算的只要在他们 lca 这条链上有公共点的就说明他们相交

dp[i]为这个点包含的子树所能得到的最大价值

sum[i]表示这个点没有选择经过i这个点链条的总价值

两种选择

这个点没有被选择

dp[i]=sum[i]=sigma(dp[k])k为i的子树

选择了某个链

假设这条链 为(tyuijk)

那么dp[i]=(sum[i]-dp[u]-dp[j])+(sum[j]-dp[k])+dp[k] +(sum[u]-dp[y])+(sum[y]-dp[t])+sum[t];

整理后发现 dp[i]=sum[i] +(sum[j]-dp[j])+(sum[k]-dp[k])+(sum[u]-dp[u])+(sum[y]-dp[y])+(sum[t]-dp[t]);

使用lca计算出每条链的最近公共祖先,在这个最近公共祖先上判断是否使用这条链,还有我们可以使用时间戳加树状数组来求得sum和dp

#include <iostream>
#include <algorithm>
#include <string.h>
#include <cstdio>
#include <vector>
using namespace std;
const int maxn=+;
int to[maxn*],nx[maxn*],H[maxn*],numofedg,timoflook;
int fa[maxn][],first[maxn],last[maxn],depth[maxn];
void addedg(int u, int v)
{
numofedg++; to[numofedg]=v; nx[numofedg]=H[u]; H[u]=numofedg;
numofedg++; to[numofedg]=u; nx[numofedg]=H[v]; H[v]=numofedg;
}
void dfs(int cur, int per, int dep)
{
first[cur]=++timoflook;
depth[cur]=dep;
fa[cur][]=per;
for(int i=; i<; i++)
{
fa[cur][i]=fa[ fa[cur][i-] ][ i- ];
}
for(int i=H[cur]; i; i=nx[i])
{
if(to[i]==per)continue;
dfs(to[i],cur,dep+);
}
last[cur]=++timoflook;
}
int getlca(int u,int v)
{
if(depth[u]<depth[v])swap(u,v);
for(int i=; i>=; i--)
{
if(depth[fa[u][i]]>=depth[v])
u=fa[u][i];
if(u==v)return u;
}
for(int i=; i>=; i--)
{
if(fa[u][i]!=fa[v][i])
{
u=fa[u][i];
v=fa[v][i];
}
}
return fa[u][];
}
struct Edg
{
int u,v,lca,val;
}P[maxn];
vector<int>E[maxn];
int dp[maxn],sum[maxn],CS[maxn*],CD[maxn*];
int lowbit(int x)
{
return x&-x;
}
void add(int x, int d, int *C)
{
while(x<=timoflook)
{
C[x]+=d;
x+=lowbit(x);
}
}
int getsum(int x, int *C)
{
int ret=;
while(x>)
{
ret+=C[x];
x-=lowbit(x);
}
return ret;
}
void solve(int cur, int per)
{
dp[cur]=sum[cur]=;
for(int i=H[cur]; i; i=nx[i])
{
if(to[i]==per)continue;
solve(to[i],cur);
sum[cur]+=dp[to[i]];
}
dp[cur]=sum[cur];
for(int i=; i<E[cur].size(); i++)
{
int id=E[cur][i];
int u=P[id].u;
int v=P[id].v;
int t1=getsum(first[u],CS);
int t2=getsum(first[v],CS);
int t3=getsum(first[u],CD);
int t4=getsum(first[v],CD);
int tmp=t1+t2-t3-t4;
dp[cur]=max(dp[cur],tmp+P[id].val+sum[cur]);
}
add(first[cur],sum[cur],CS);
add(last[cur],-sum[cur],CS);
add(first[cur],dp[cur],CD);
add(last[cur],-dp[cur],CD); }
int main()
{
int cas;
scanf("%d",&cas);
for(int cc=; cc<=cas; cc++)
{
int n,m;
timoflook=numofedg=;
scanf("%d%d",&n,&m);
for(int i=; i<=n; i++)
{
CS[i*]=CS[i*+]=CD[i*]=CD[i*+]=;
H[i]=;E[i].clear(); } for(int i=; i<n; i++)
{
int u,v;
scanf("%d%d",&u,&v);
addedg(u,v);
}
fa[][]=;
dfs(,,);
for(int i=; i<m; i++)
{
scanf("%d%d%d",&P[i].u,&P[i].v,&P[i].val);
P[i].lca=getlca(P[i].u,P[i].v);
E[P[i].lca].push_back(i);
}
solve(,-);
printf("%d\n",dp[]);
}
return ;
}