BZOJ 2286 消耗战 (虚树+树形DP)

时间:2022-06-29 12:57:33

给出一个n节点的无向树,每条边都有一个边权,给出m个询问,
每个询问询问ki个点,问切掉一些边后使得这些顶点无法与顶点1连接。
最少的边权和是多少。
(n<=250000,sigma(ki)<=500000)

考虑树形DP,我们令mn[i]表示i节点无法与1节点相连切除的最小权值。
显然有mn[i]=min(E(fa,i),mn[fa]).
大致就是i到1的简单路径上的最小边。
我们对于每个询问。把询问的点不妨称为关键点。
令dp[i]表示i节点不能与子树的关键点连接切掉的最小权值。
那么有,如果son[i]是关键点,则dp[i]+=E(i,son(i)).
如果son[i]不是关键点,则dp[i]+=min(dp[son(i)],E(i,son(i))).

考虑最坏每次只询问一个点,则复杂度为O(n*sigma(ki)).显然无法承受。

我们观察到sigma(ki)有限制,这启发了我们构造一颗新树,这棵树称为虚树。
我们把每个节点和每对节点的lca单独拉出来模仿原来的树的形态构造一颗虚树。
这样再在这颗新树上进行树形DP。

构造这棵树的核心思想是每次维护一条最右边的链。
首先把关键点按dfs序排序。
然后相邻的点取lca。
再单调栈维护一下最右边的链就ok啦。

# include <stdio.h>
# include <string.h>
# include <stdlib.h>
# include <iostream>
# include <vector>
# include <queue>
# include <stack>
# include <map>
# include <math.h>
# include <algorithm>
using namespace std;
# define lowbit(x) ((x)&(-x))
# define pi acos(-1.0)
# define MAXN
# define eps 1e-
# define MAXM
# define MOD
# define INF
# define mem(a,b) memset(a,b,sizeof(a))
# define FOR(i,a,n) for(int i=a; i<=n; ++i)
# define FO(i,a,n) for(int i=a; i<n; ++i)
# define bug puts("H");
typedef long long LL;
typedef unsigned long long ULL;
int _MAX(int a, int b){return a>b?a:b;}
int _MIN(int a, int b){return a>b?b:a;}
int Scan() {
int res=, flag=;
char ch;
if((ch=getchar())=='-') flag=;
else if(ch>=''&&ch<='') res=ch-'';
while((ch=getchar())>=''&&ch<='') res=res*+(ch-'');
return flag?-res:res;
}
void Out(int a) {
if(a<) {putchar('-'); a=-a;}
if(a>=) Out(a/);
putchar(a%+'');
} struct Edge{int p, next, w;}edge[MAXN<<];
int head[MAXN], cnt=, bin[], ind;
int id[MAXN], dep[MAXN], fa[MAXN][], h[MAXN], st[MAXN], top;
LL ans[MAXN], dp[MAXN]; void add_edge(int u, int v, int w)
{
if (u==v) return ;
edge[cnt].p=v; edge[cnt].next=head[u]; edge[cnt].w=w; head[u]=cnt++;
}
void bin_init(){bin[]=; FO(i,,) bin[i]=bin[i-]<<;}
bool comp(int a, int b){return id[a]<id[b];}
void dfs(int x, int fat)
{
id[x]=++ind;
fa[x][]=fat;
for (int i=; bin[i]<=dep[x]; ++i) fa[x][i]=fa[fa[x][i-]][i-];
for (int i=head[x]; i; i=edge[i].next) {
int v=edge[i].p;
if (v==fat) continue;
dep[v]=dep[x]+;
ans[v]=min(ans[x],(LL)edge[i].w);
dfs(v,x);
}
}
int lca(int x, int y)
{
if (dep[x]<dep[y]) swap(x,y);
int t=dep[x]-dep[y];
for (int i=; bin[i]<=t; ++i) if (bin[i]&t) x=fa[x][i];
for (int i=; i>=; --i) if (fa[x][i]!=fa[y][i]) x=fa[x][i], y=fa[y][i];
if (x==y) return x;
else return fa[x][];
}
void dp_dfs(int x)
{
dp[x]=ans[x];
LL temp=;
for (int i=head[x]; i; i=edge[i].next) {
int v=edge[i].p;
dp_dfs(v);
temp+=dp[v];
}
head[x]=;
if (temp) dp[x]=min(dp[x],temp);
}
void sol()
{
int k, tot=;
cnt=;
scanf("%d",&k);
FOR(i,,k) h[i]=Scan();
sort(h+,h+k+,comp);
h[++tot]=h[];
FOR(i,,k) if (lca(h[tot],h[i])!=h[tot]) h[++tot]=h[i];
st[++top]=;
FOR(i,,tot) {
int f=lca(h[i],st[top]);
while (dep[f]<dep[st[top-]]) add_edge(st[top-],st[top],), top--;
add_edge(f,st[top--],);
if (f!=st[top]) st[++top]=f;
st[++top]=h[i];
}
while (top>) add_edge(st[top-],st[top],), top--;
dp_dfs();
printf("%lld\n",dp[]);
}
int main()
{
int n, m, u, v, w;
bin_init();
n=Scan();
FO(i,,n) u=Scan(), v=Scan(), w=Scan(), add_edge(u,v,w), add_edge(v,u,w);
ans[]=(LL)<<; dfs(,);
m=Scan();
mem(head,);
while (m--) sol();
return ;
}