BZOJ 3572: [Hnoi2014]世界树 [虚树 DP 倍增]

时间:2023-03-09 20:50:58
BZOJ 3572: [Hnoi2014]世界树 [虚树 DP 倍增]

传送门

题意:

一棵树,多次询问,给出$m$个点,求有几个点到给定点最近


写了一晚上...

当然要建虚树了,但是怎么$DP$啊

大爷题解传送门

我们先求出到虚树上某个点最近的关键点

然后枚举所有的边$(f,x)$,讨论一下边上的点的子树应该靠谁更近

倍增求出分界点

注意有些没出现在虚树上的子树

注意讨论的时候只讨论链上的不包括端点,否则$f$的子树会被贡献多次

学到的一些$trick:$

1.$pair$的妙用

2.不需要建出虚树只要求虚树的$dfs$序(拓扑序)和$fa$就可以$DP$了

注意$DP$的时候必须先用儿子更新父亲再用父亲更新儿子,因为父亲的最优值有可能在其他儿子

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
#define pii pair<int,int>
#define MP make_pair
#define fir first
#define sec second
typedef long long ll;
const int N=3e5+,INF=1e9;
inline int read(){
char c=getchar();int x=,f=;
while(c<''||c>''){if(c=='-')f=-;c=getchar();}
while(c>=''&&c<=''){x=x*+c-'';c=getchar();}
return x*f;
} int n,Q;
struct Edge{
int v,ne,w;
}e[N<<];
int cnt,h[N];
inline void ins(int u,int v){
cnt++;
e[cnt].v=v;e[cnt].ne=h[u];h[u]=cnt;
cnt++;
e[cnt].v=u;e[cnt].ne=h[v];h[v]=cnt;
}
int fa[N][],deep[N],dfn[N],dfc,size[N],All;
void dfs(int u){
dfn[u]=++dfc;
size[u]=;
for(int i=;(<<i)<=deep[u];i++)
fa[u][i]=fa[ fa[u][i-] ][i-];
for(int i=h[u];i;i=e[i].ne)
if(e[i].v!=fa[u][]){
fa[e[i].v][]=u;
deep[e[i].v]=deep[u]+;
dfs(e[i].v);
size[u]+=size[e[i].v];
}
}
inline int lca(int x,int y){
if(deep[x]<deep[y]) swap(x,y);
int bin=deep[x]-deep[y];
for(int i=;i>=;i--)
if((<<i)&bin) x=fa[x][i];
for(int i=;i>=;i--)
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return x==y ? x :fa[x][];
} int a[N],st[N],par[N],dis[N],t[N],m,ans[N];
int remain[N];
inline bool cmp(int x,int y){return dfn[x]<dfn[y];}
inline void ins2(int x,int y){par[y]=x;dis[y]=deep[y]-deep[x];}
pii g[N];
void dp(int m){
for(int i=m;i>;i--){
int x=t[i],f=par[x];
g[f]=min(g[f],MP(g[x].fir+dis[x],g[x].sec));
}
for(int i=;i<=m;i++){
int x=t[i],f=par[x];
g[x]=min(g[x],MP(g[f].fir+dis[x],g[f].sec));
}
}
inline int jump1(int x,int tar){
for(int i=;i>=;i--)
if(deep[ fa[x][i] ]>=tar) x=fa[x][i];
return x;
}
inline int jump(int x,int tar){
int bin=deep[x]-tar;
for(int i=;i>=;i--)
if((<<i)&bin) x=fa[x][i];
return x;
}
int ora[N];
void solve(){
int n=read(),m=;
for(int i=;i<=n;i++)
ora[i]=a[i]=read(),t[++m]=a[i],g[a[i]]=MP(,a[i]);
sort(a+,a++n,cmp); int top=;
for(int i=;i<=n;i++){
if(!top) {st[++top]=a[i];continue;}
int x=a[i],f=lca(x,st[top]);
while(dfn[f]<dfn[st[top]]){
if(dfn[f]>=dfn[st[top-]]){
ins2(f,st[top--]);
if(f!=st[top]) st[++top]=f,t[++m]=f,g[f]=MP(INF,);
break;
}else ins2(st[top-],st[top]),top--;
}
st[++top]=x;
}
while(top>) ins2(st[top-],st[top]),top--; sort(t+,t++m,cmp);
dp(m);
for(int i=;i<=m;i++) remain[t[i]]=size[t[i]]; ans[ g[t[]].sec ]+=All-size[t[]];
for(int i=;i<=m;i++){
int x=t[i],f=par[x];par[x]=;
int t=jump(x,deep[f]+);
remain[f]-=size[t];
if(g[x].sec == g[f].sec) ans[ g[x].sec ]+=size[t]-size[x];
else{
int len=g[x].fir + g[f].fir + dis[x], mid=deep[x]-(len/-g[x].fir);
if( !(len&) && g[f].sec<g[x].sec ) mid++;
int y=jump(x,mid);
ans[ g[f].sec ]+=size[t]-size[y];
ans[ g[x].sec ]+=size[y]-size[x];
}
}
for(int i=;i<=m;i++) ans[ g[t[i]].sec ]+=remain[t[i]];
for(int i=;i<=n;i++) printf("%d%c",ans[ora[i]],i==n?'\n':' '),ans[ora[i]]=;
}
int main(){
//freopen("in","r",stdin);
n=read();All=n;
for(int i=;i<n;i++) ins(read(),read());
dfs();
Q=read();
while(Q--) solve();
}