题解-AtCoder Code-Festival2017 Final-J Tree MST

时间:2023-03-09 15:28:53
题解-AtCoder Code-Festival2017 Final-J Tree MST

Problem

\(\mathrm{Code~Festival~2017~Final~J}\)

题意概要:一棵 \(n\) 个节点有点权边权的树。构建一张完全图,对于任意一对点 \((x,y)\),连一条长度为 \(w[x] + w[y]+ dis(x, y)\) 的边。求这张图的最小生成树。

\(n\leq 2\times 10^5\)

Solution

在操场上晒太阳时想到的做法,求 \(\mathrm{MST}\) 可以使用另一种贪心算法:每次找到每个点连出去的最短的边,并将其合并,一次是 \(O(n)\),由于每次点数至少减半,所以总共不超过 \(\log n\)次,总复杂度 \(O(n\log n)\)

使用这种贪心算法后,只需每次找到离每个点最近的点。

可以使用点分治,设已经合并的点为同一连通块。考虑分治中心为 \(x\),只考虑过分治中心的路径,求出 \(dep+w\) 最小的点,对于每棵子树内的点,只有非子树内的点可能做贡献,而对于每个点,只有非同连通块的点可做贡献。所以需要维护四个值,这样较麻烦,或者是只维护两个值加上处理前后缀(具体可以看代码)。复杂度 \(O(n\log^2n)\)

然后搜了一波题解,发现一群人在同一天使用了同一个做法(可能是他们在讲课后统一发的题解):同样考虑上述贪心,只是点分治时不用考虑是否在同一子树内,而是都连过去,这样保证结果不会低于答案,稍加分析发现能得到最优解。

又看了看官方正解,发现不需要点分治,直接换根Dp即可……可能是老年选手已经开始老年痴呆了

Code

哦,这样常数有点大,我的代码跑极限数据 \(\mathrm {5.01s}\),会 T 三个点,预处理点分树即可

//Code Festival 2017 Final-J
#include <bits/stdc++.h>
using namespace std;
typedef long long ll; template <typename _tp> inline void cmax(_tp&A,const _tp&B){if(A < B) A = B;}
template <typename _tp> inline void cmin(_tp&A,const _tp&B){if(A > B) A = B;} template <typename _tp> inline void read(_tp&x){
char c11=getchar(),ob=0;x=0;
while(c11!='-'&&!isdigit(c11))c11=getchar();if(c11=='-')ob=1,c11=getchar();
while(isdigit(c11))x=x*10+c11-'0',c11=getchar();if(ob)x=-x;
} const ll Inf = 2e18;
const int N = 201000;
struct Edge{int v,nxt;ll w;}a[N+N+N];
int head[N],Head[N],vs[N],w[N],id[N];
int n,_; inline void ad(){
static int x,y,z; read(x), read(y), read(z);
a[++_].v = y, a[_].w = z, a[_].nxt = head[x], head[x] = _;
a[++_].v = x, a[_].w = z, a[_].nxt = head[y], head[y] = _;
} namespace dsu{
int dad[N];
int find(int x){return dad[x]? dad[x] = find(dad[x]): x;}
bool check(int x,int y){return find(x) == find(y);}
bool merge(int x,int y){
static int p1,p2;
if((p1 = find(x)) == (p2 = find(y))) return false;
dad[p1] = p2; return true;
}
} namespace TD{
int sz[N], rt, Mi, nn;
void get_rt(int x,int las){
sz[x] = 1;int mx = 0;
for(int i=head[x];i;i=a[i].nxt)
if(a[i].v!=las and !vs[a[i].v]){
get_rt(a[i].v,x);
sz[x] += sz[a[i].v];
cmax(mx, sz[a[i].v]);
}
cmax(mx, nn - sz[x]);
if(mx < Mi) Mi = mx, rt = x;
} void Get_rt(int x,int xn){rt = 0, nn = xn, Mi = 2e9; get_rt(x,0);} void build(int x,int las){
vs[x] = 1;
a[++_].v = x, a[_].nxt = Head[las], Head[las] = _;
get_rt(x,0);
for(int i=head[x];i;i=a[i].nxt)
if(!vs[a[i].v]){
Get_rt(a[i].v,sz[a[i].v]);
build(rt,x);
}
}
} struct node{
ll v;int id;
inline node(){}
inline node(const ll&V,const int&Id):v(V),id(Id){}
}tr[N], p[N], Mx, Mi, Fir[N], Sec[N]; node pre_fir[N], pre_sec[N];
node suf_fir[N], suf_sec[N]; inline void upd(node&A, node&B, node nw){
if(nw.v < A.v) {
if(nw.id == A.id) {A = nw; return ;}
B = A, A = nw; return ;
}
if(nw.v < B.v)
if(nw.id != A.id) B = nw;
} void get_val(int x,int las,ll dep){
upd(Mi, Mx, node(dep+w[x],id[x]));
for(int i=head[x];i;i=a[i].nxt)
if(a[i].v!=las and !vs[a[i].v])
get_val(a[i].v,x,dep+a[i].w);
} void cover(int x,int las,ll dep,node A,node B){
if(id[x] != A.id and dep + A.v < p[x].v)
p[x].v = dep + A.v, p[x].id = A.id;
if(id[x] != B.id and dep + B.v < p[x].v)
p[x].v = dep + B.v, p[x].id = B.id;
for(int i=head[x];i;i=a[i].nxt)
if(a[i].v!=las and !vs[a[i].v])
cover(a[i].v,x,dep+a[i].w,A,B);
} int to[N], to_w[N]; void work(int x){
vs[x] = 1;
int top = 0;
for(int i=head[x];i;i=a[i].nxt)
if(!vs[a[i].v]){
Mi = node(w[x],id[x]), Mx = node(Inf,0);
get_val(a[i].v,x,a[i].w);
++top, to[top] = a[i].v, to_w[top] = a[i].w;
Fir[top] = Mi, Sec[top] = Mx;
} pre_fir[1] = Fir[1];
pre_sec[1] = Sec[1];
for(int i=2;i<=top;++i){
pre_fir[i] = pre_fir[i-1];
pre_sec[i] = pre_sec[i-1];
upd(pre_fir[i], pre_sec[i], Fir[i]);
upd(pre_fir[i], pre_sec[i], Sec[i]);
}
suf_fir[top] = Fir[top];
suf_sec[top] = Sec[top];
for(int i=top-1;i>=1;--i){
suf_fir[i] = suf_fir[i+1];
suf_sec[i] = suf_sec[i+1];
upd(suf_fir[i], suf_sec[i], Fir[i]);
upd(suf_fir[i], suf_sec[i], Sec[i]);
} node A,B;
for(int i=1;i<=top;++i){
A = node(w[x],id[x]), B = node(Inf,0);
if(i!=1) upd(A,B,pre_fir[i-1]);
if(i!=1) upd(A,B,pre_sec[i-1]);
if(i!=top) upd(A,B,suf_fir[i+1]);
if(i!=top) upd(A,B,suf_sec[i+1]);
cover(to[i],x,to_w[i],A,B);
} if(top){
A = pre_fir[top], B = pre_sec[top];
if(id[x] != A.id and A.v < p[x].v)
p[x].v = A.v, p[x].id = A.id;
if(id[x] != B.id and B.v < p[x].v)
p[x].v = B.v, p[x].id = B.id;
} for(int i=Head[x];i;i=a[i].nxt)
work(a[i].v);
} int main(){
read(n);
for(int i=1;i<=n;++i)read(w[i]);
for(int i=1;i<n;++i)ad(); TD::build(1,0); int Tot = n; ll Ans = 0ll;
while(Tot > 1){
for(int i=1;i<=n;++i) id[i] = dsu::find(i), p[i].v = Inf, vs[i] = 0;
work(a[Head[0]].v);
for(int i=1;i<=n;++i) tr[i].v = Inf;
for(int i=1,t;i<=n;++i){
t = dsu::find(i);
if(tr[t].v > w[i] + p[i].v)
tr[t].v = p[i].v + w[i], tr[t].id = p[i].id;
}
for(int i=1;i<=n;++i)
if(dsu::find(i) == i){
if(dsu::check(i, tr[i].id)) continue;
Ans += tr[i].v, dsu::merge(i, tr[i].id);
--Tot;
}
}
printf("%lld\n",Ans);
return 0;
}