【bzoj1937】 Shoi2004—Mst 最小生成树

时间:2023-03-10 02:49:54
【bzoj1937】 Shoi2004—Mst 最小生成树

http://www.lydsy.com/JudgeOnline/problem.php?id=1937 (题目链接)

题意

  一个无向图,给出一个生成树,可以修改每条边的权值,问最小修改多少权值使得给出的生成树是最小生成树。

Solution

  好神!!!!!

  首先,由贪心可知,生成树上的边我们肯定是减小它的权值,非树边我们肯定是增大它的权值。假设树边$i$的权值$w_i$,修改后的权值$w_i-d_i$;非树边$j$的权值$w_j$,修改后的权值$w_j+d_j$。如果$j$有可能代替$i$,那么它们必须满足式子$w_i-d_i<=w_j+d_j$,移下项$w_i-w_j<=d_i+d_j$,是不是很像KM里面的顶标,所以我们把边当做点,边权为两个有制约关系的边的权值差,跑KM求最大权完美匹配就可以了。

  纠结了好久,蛋疼死了。我们的确是要求最小的$\sum d_i$,但是$w_i-w_j<=d_i+d_j$的意义是要求对所有的$i,j$都得满足。我们需要在满足条件的情况下不断缩小$\sum d_i$,所以完美匹配以后我们可以使$\sum d_i$最小。

细节

  边权非负。可能不会完美匹配,需要加点加边。

代码

// bzoj1937
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<queue>
#define LL long long
#define inf (1ll<<30)
#define MOD 1000000007
#define Pi acos(-1.0)
#define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout);
using namespace std; const int maxn=1010;
int head[maxn],deep[maxn],vis[maxn],fa[maxn],id[maxn][maxn],n,m,cnt;
int slack[maxn],vx[maxn],vy[maxn],lx[maxn],ly[maxn],p[maxn],mp[maxn][maxn];
struct data {int u,v,w;}a[maxn];
struct edge {int to,next;}e[maxn<<1]; void link(int u,int v) {
e[++cnt]=(edge){v,head[u]};head[u]=cnt;
e[++cnt]=(edge){u,head[v]};head[v]=cnt;
}
void dfs(int x) {
for (int i=head[x];i;i=e[i].next) if (e[i].to!=fa[x]) {
deep[e[i].to]=deep[x]+1;
fa[e[i].to]=x;
dfs(e[i].to);
}
}
bool match(int x) {
vx[x]=cnt;
for (int y=1;y<=m;y++) if (vy[y]!=cnt) {
int t=lx[x]+ly[y]-mp[x][y];
if (!t) {
vy[y]=cnt;
if (!p[y] || match(p[y])) {p[y]=x;return 1;}
}
else slack[y]=min(slack[y],t);
}
return 0;
}
int KM() {
for (int i=1;i<=m;i++) {
lx[i]=-inf;
for (int j=1;j<=m;j++) lx[i]=max(lx[i],mp[i][j]);
}
cnt=0;
for (int x=1;x<=m;x++) {
for (int i=1;i<=m;i++) slack[i]=inf;
while (1) {
int d=inf;cnt++;
if (match(x)) break;
for (int i=1;i<=m;i++) if (vy[i]!=cnt) d=min(d,slack[i]);
for (int i=1;i<=m;i++) {
if (vx[i]==cnt) lx[i]-=d;
if (vy[i]==cnt) ly[i]+=d;
}
}
}
int ans=0;
for (int i=1;i<=m;i++) ans+=mp[p[i]][i];
return ans;
}
int main() {
scanf("%d%d",&n,&m);
for (int i=1;i<=m;i++) {
scanf("%d%d%d",&a[i].u,&a[i].v,&a[i].w);
id[a[i].u][a[i].v]=id[a[i].v][a[i].u]=i;
}
for (int u,v,i=1;i<n;i++) {
scanf("%d%d",&u,&v);
link(u,v);vis[id[u][v]]=1;
}
dfs(1);memset(head,0,sizeof(head));cnt=0;
for (int i=1;i<=m;i++) if (!vis[i]) {
int x=a[i].u,y=a[i].v,w=a[i].w;
if (deep[x]<deep[y]) swap(x,y);
int t=deep[x]-deep[y];
while (t--) mp[id[x][fa[x]]][i]=max(0,a[id[x][fa[x]]].w-w),x=fa[x];
while (x!=y) {
mp[id[x][fa[x]]][i]=max(0,a[id[x][fa[x]]].w-w);
mp[id[y][fa[y]]][i]=max(0,a[id[y][fa[y]]].w-w);
x=fa[x],y=fa[y];
}
}
printf("%d",KM());
return 0;
}