BZOJ 1977 次小生成树

时间:2023-02-23 17:31:11

终于通过了……

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
// Capacity of every vertex-indexed array.
const long long maxv = 300500;
// Capacity of the adjacency-entry array and of the input-edge array.
const long long maxe = 800500;
// Largest long long value; starting value of dx ("no candidate found yet").
const long long inf = 0x7fffffffffffffff;
using namespace std;

// Adjacency-list entry of the spanning tree: target vertex, weight,
// and index of the next entry leaving the same vertex (0 terminates the list).
struct edge
{
    long long v, w, nxt;
} e[maxe];

// One input edge: endpoints, weight, and flag (set to 1 by kruskal when the
// edge is taken into the minimum spanning tree).
struct edge_mp
{
    long long u, v, w, flag;
} mp[maxe];

long long n, m;              // vertex count, input edge count
long long g[maxv];           // head index of each vertex's adjacency list (0 = empty)
long long nume = 0;          // number of adjacency entries stored in e[] so far
long long father[maxv];      // union-find parent pointers used by kruskal
long long anc[maxv][22];     // anc[v][k]: 2^k-th ancestor of v (0 past the root)
long long mx1[maxv][22];     // largest edge weight on the jump v -> anc[v][k]
long long mx2[maxv][22];     // largest weight strictly below mx1 on that jump
long long ans = 0;           // total weight of the minimum spanning tree
long long dx = inf;          // smallest strictly positive weight increase seen
long long dis[maxv];         // depth of each vertex in the rooted tree
long long r1 = 0;            // get_ans output: largest weight on the queried path
long long r2 = 0;            // get_ans output: strictly second-largest weight
// Sort predicate for kruskal: orders input edges by ascending weight.
bool cmp(edge_mp x,edge_mp y)
{
    return y.w > x.w;
}
void addedge(long long u,long long v,long long w)
{
    e[++nume].v=v;
    e[nume].w=w;
    e[nume].nxt=g[u];
    g[u]=nume;
}
// Union-find lookup: returns the representative (root) of x's set and points
// every vertex on the walked chain directly at that root (path compression).
long long getfather(long long x)
{
    long long root = x;
    while (father[root] != root)
        root = father[root];
    // Second pass: re-parent each vertex on the chain to the root.
    while (father[x] != root)
    {
        long long up = father[x];
        father[x] = root;
        x = up;
    }
    return root;
}
void kruskal()
{
    for (long long i=1;i<=n;i++) father[i]=i;
    sort(mp+1,mp+m+1,cmp);
    for (long long i=1;i<=m;i++)
    {
        long long u=mp[i].u,v=mp[i].v,w=mp[i].w;
        if (getfather(u)!=getfather(v))
        {
            father[getfather(u)]=getfather(v);mp[i].flag=1;ans+=w;
            addedge(u,v,w);addedge(v,u,w);
        }
    }
}
/*
 * Depth-first walk of the spanning tree stored in g/e, starting at x.
 * `father` is the vertex x was reached from; pass x itself for the root.
 * For every child v it fills the level-0 lifting entries:
 *   anc[v][0] - parent of v
 *   mx1[v][0] - weight of the edge (parent, v)
 *   mx2[v][0] - -1: a one-edge path has no strictly smaller second weight
 *   dis[v]    - depth, one more than the parent's (dis[x] itself is not written)
 * Recursion depth equals the tree height, which can approach n on a
 * path-shaped tree.
 */
void dfs(long long x,long long father)
{
    for (long long i=g[x];i;i=e[i].nxt)
    {
        long long v=e[i].v;
        if (v!=father)
        {
            anc[v][0]=x;
            mx1[v][0]=e[i].w;
            // -1 is the "no second maximum" marker that get_ans (k2/k4/r2 start
            // at -1) and main (r2==-1) test for. The old placeholder 0 is also a
            // legal weight, so it leaked into r2 and produced the bogus
            // candidate w-0.
            mx2[v][0]=-1;
            dis[v]=dis[x]+1;
            dfs(v,x);
        }
    }
}
// Builds binary-lifting tables for levels 1..20 from the level-0 entries
// written by dfs. The jump of 2^lvl edges is split into two halves of 2^(lvl-1):
// - mx1 becomes the maximum of the four half-values;
// - mx2 becomes the largest value strictly below that maximum.
// mx2 is left untouched if all four values are equal.
void get_table()
{
    for (long long lvl = 1; lvl <= 20; ++lvl)
    {
        for (long long node = 1; node <= n; ++node)
        {
            const long long mid = anc[node][lvl - 1];   // end of the first half
            anc[node][lvl] = anc[mid][lvl - 1];
            long long cand[4] = { mx1[node][lvl - 1], mx1[mid][lvl - 1],
                                  mx2[node][lvl - 1], mx2[mid][lvl - 1] };
            sort(cand, cand + 4);
            mx1[node][lvl] = cand[3];
            // Scan down from the top for the first value different from the maximum.
            for (int k = 2; k >= 0; --k)
            {
                if (cand[k] != cand[3])
                {
                    mx2[node][lvl] = cand[k];
                    break;
                }
            }
        }
    }
}
 
/*
 * For input edge mp[x] = (u, v, w), computes over the tree path u..v:
 *   r1 - the largest edge weight;
 *   r2 - the largest weight strictly below r1 (-1 when none was seen).
 * Requires dfs() and get_table() to have filled anc/mx1/mx2/dis.
 * The placeholder stored in mx2 for "no second maximum" flows straight into r2,
 * so it must not collide with a real edge weight.
 */
void get_ans(long long x)
{
    long long u=mp[x].u,v=mp[x].v;r1=-1;r2=-1;
    // (k1,k2): top-two distinct weights collected on u's side of the path;
    // (k3,k4): the same for v's side. -1 means "nothing yet".
    long long k1=-1,k2=-1,k3=-1,k4=-1;
    // Make u the deeper endpoint.
    if (dis[u]<dis[v]) swap(u,v);
    if (dis[u]!=dis[v])
    {
        // Lift u to v's depth, folding each jump's (mx1, mx2) into (k1, k2).
        for (long long e=20;e>=0;e--)
        {
            long long pos=anc[u][e];
            // pos==0 means the jump would leave the tree above the root.
            if ((dis[pos]>=dis[v]) && (pos>0))
            {
                long long regis[5];
                regis[1]=mx1[u][e];regis[2]=mx2[u][e];regis[3]=k1;regis[4]=k2;
                sort(regis+1,regis+5);
                k1=regis[4];
                // k2 = first value below the maximum that differs from it.
                for (long long i=3;i>=1;i--)
                {
                    if (regis[i]==regis[i+1]) continue;
                    else {k2=regis[i];break;}
                }
                u=pos;
            }
        }
    }
    // v was an ancestor of u: the whole path lies on u's side.
    if (u==v)
    {
        r1=k1;r2=k2;
        return;
    }
    // Lift both endpoints while their ancestors differ, so they end up as two
    // distinct children of the lowest common ancestor.
    for (long long e=20;e>=0;e--)
    {
        long long posu=anc[u][e],posv=anc[v][e];
        if (posu!=posv)
        {
            long long regis[5];
            regis[1]=mx1[u][e];regis[2]=mx2[u][e];regis[3]=k1;regis[4]=k2;
            sort(regis+1,regis+5);
            k1=regis[4];
            for (long long i=3;i>=1;i--)
            {
                if (regis[i]==regis[i+1]) continue;
                else {k2=regis[i];break;}
            }
            regis[1]=mx1[v][e];regis[2]=mx2[v][e];regis[3]=k3;regis[4]=k4;
            sort(regis+1,regis+5);
            k3=regis[4];
            for (long long i=3;i>=1;i--)
            {
                if (regis[i]==regis[i+1]) continue;
                else {k4=regis[i];break;}
            }
            u=posu;v=posv;
        }
    }
    // Combine both sides with the two final edges into the common ancestor
    // (mx1[u][0], mx1[v][0]); r1 = overall max, r2 = next distinct value below it.
    long long regis[8];
    regis[1]=k1;regis[2]=k2;regis[3]=k3;regis[4]=k4;regis[5]=mx1[u][0];regis[6]=mx1[v][0];
    sort(regis+1,regis+7);
    r1=regis[6];
    for (long long i=5;i>=1;i--)
    {
        if (regis[i]==regis[i+1]) continue;
        else {r2=regis[i];break;}
    }
}
/*
 * Reads n, m and then m edges "u v w" (vertices 1..n).
 * Builds the minimum spanning tree with kruskal() and roots it at vertex 1.
 * For each non-tree edge (u, v, w), looks up the largest weight r1 and the
 * strictly second-largest weight r2 on the tree path u..v:
 * - swapping that edge in for the path edge of weight r1 raises the total by w - r1;
 * - when r1 == w that swap changes nothing, so the r2 edge is used (w - r2).
 * dx keeps the smallest such increase. Prints ans + dx.
 * Returns 1 if the input cannot be read.
 */
int main()
{
    if (scanf("%lld%lld",&n,&m)!=2) return 1;
    for (long long i=1;i<=m;i++)
    {
        if (scanf("%lld%lld%lld",&mp[i].u,&mp[i].v,&mp[i].w)!=3) return 1;
        mp[i].flag=0;
    }
    kruskal();
    memset(mx1,0,sizeof(mx1));
    // Every byte 0xFF makes every long long entry -1, the "no second maximum"
    // marker that get_ans and the r2==-1 test below rely on. 0 would be
    // indistinguishable from a real (non-negative) edge weight.
    memset(mx2,-1,sizeof(mx2));
    dfs(1,1);
    get_table();
    for (long long i=1;i<=m;i++)
    {
        if (mp[i].flag) continue;   // edges already in the tree are skipped
        const long long w=mp[i].w;
        get_ans(i);
        if (r1==w)
        {
            // Equal-weight swap keeps the total unchanged; use the strictly
            // smaller second maximum instead, if the path has one.
            if (r2==-1) continue;
            dx=min(dx,w-r2);
        }
        else if (r1<w) dx=min(dx,w-r1);
    }
    // NOTE(review): assumes the input guarantees that a strictly larger spanning
    // tree exists; otherwise dx is still inf and ans+dx overflows.
    printf("%lld\n",ans+dx);
    return 0;
}