BZOJ4543 [POI2014]Hotel加强版

时间:2021-04-01 13:38:14

Hotel加强版

有一个树形结构,每条边的长度相同,任意两个节点可以相互到达。选3个点。两两距离相等。有多少种方案?

数据范围:n<=100000

yyb的题解

我们先考虑一个\(O(n^2)\)的dp,也就是原题的做法。

我们考虑一下,三个点两两的距离相同是什么情况,

  1. 存在一个三个点公共的LCA,所以我们在LCA统计答案即可。

  2. 存在一个点,使得这个点到另外两个子树中距离它为d的点以及这个点的d次祖先。

所以,设计DP状态为

  • \(f[i][j]\)表示以\(i\)为根的子树中,距离当前点为\(j\)的点数。

  • \(g[i][j]\)表示以\(i\)为根的子树中,两个点到LCA的距离为\(d\),并且他们的LCA到\(i\)的距离为\(d−j\)的点对数,简单来说就是\(i\)往其他地方走\(j\)步就能找到一组解。

考虑合并的时候的转移:

\[ans+=g[i][0],ans+=g[i][j]∗f[son][j−1],f[i][j]+=f[son][j−1],g[i][j]+=g[son][j+1]
\]

转移的正确性比较显然,不在多讲了,并不是这里的重点。这样子的复杂度是\(O(n^2)\)的。

我们观察一下转移的时候有这样两步:

\[f[i][j]+=f[son][j−1],g[i][j]+=g[son][j+1]
\]

如果我们钦定一个儿子的话,那么这个数组是可以直接赋值的,并不需要再重复计算。

所以我们用指针来写,也就是:\(f[i]=f[son]−1,g[i]=g[son]+1\)。

如果整棵树是链我们发现复杂度可以做到O(n),既然如此,我们推广到树。我们进行长链剖分,每次钦定从重儿子直接转移,那么我们还需要从轻儿子进行转移。不难证明所有轻儿子都是一条重链的顶部,转移时的复杂度是重链长度。

那么,复杂度拆分成两个部分:直接从重儿子转移\(O(1)\),从轻儿子转移\(O(∑len)\)。发现每个点有且仅有一个父亲,因此一条重链算且仅被一个点暴力转移,而每次转移复杂度是链长。所以全局复杂度是∑链长,也就是\(O(n)\),因此总复杂度就是\(O(n)\)。

这样子写下来,发现长链剖分之后,我们的复杂度变为了线性。但是注意到复杂度证明中的一点:转移和链长相关。而链长和什么相关呢?深度。所以说对于这一类与深度相关的、可以快速合并的信息,使用长链剖分可以优化到一个非常完美的复杂度。如果需要维护的与深度无关的信息的话,或许dsu on tree是一个更好的选择。

代码

DP是的for是在用相对深度,比较简单的实现方法是之前统计重儿子的时候用高度代替深度。

然后tmp必须开到4倍是因为g数组指针给儿子的时候在前移。

co int N=1e5+1;
int n,head[N],to[N*2],nx[N*2],tot;
void add(int x,int y){to[++tot]=y,nx[tot]=head[x],head[x]=tot;}
int dep[N],md[N],son[N];
void dfs1(int x,int fa){
for(int i=head[x];i;i=nx[i]){
int y=to[i];if(y==fa) continue;
dfs1(y,x),md[x]=std::max(md[x],md[y]);
if(md[y]>md[son[x]]) son[x]=y;
}
md[x]=md[son[x]]+1;
}
ll*f[N],*g[N],tmp[N*4],*id=tmp,ans;
void dfs2(int x,int fa){
if(son[x]) f[son[x]]=f[x]+1,g[son[x]]=g[x]-1,dfs2(son[x],x);
f[x][0]=1,ans+=g[x][0];
for(int i=head[x];i;i=nx[i]){
int y=to[i]; if(y==fa||y==son[x]) continue;
f[y]=id,id+=md[y]*2,g[y]=id,id+=md[y]*2;
dfs2(y,x);
for(int j=0;j<md[y];++j){
if(j)ans+=f[x][j-1]*g[y][j];
ans+=g[x][j+1]*f[y][j];
}
for(int j=0;j<md[y];++j){
g[x][j+1]+=f[x][j+1]*f[y][j];
if(j)g[x][j-1]+=g[y][j];
f[x][j+1]+=f[y][j];
}
}
}
int main(){
// freopen("BZOJ4543.in","r",stdin);
// freopen(".out","w",stdout);
read(n);
for(int i=1,x,y;i<n;++i){
read(x),read(y);
add(x,y),add(y,x);
}
dfs1(1,0);
f[1]=id,id+=md[1]*2,g[1]=id,id+=md[1]*2;
dfs2(1,0);
printf("%lld\n",ans);
return 0;
}