hdu 4679 (树形DP)

时间:2023-03-10 02:43:58
hdu 4679 (树形DP)

题意:给一棵树,边的权值都是1,摧毁每条边是有代价的,选择摧毁一条边,把一棵树分成两部分,求出两部分中距离最大的两点的距离,求出距离*代价最小的边,多条的话输出序号最小的。

刚开始理解错题意了,wrong了几次,一直在纠结摧毁一条边后上边的树的最远距离怎么求,儿子树的最远距离就是所有子树的最长边+次长边就可以了。当我们求到一个节点u时,肯定有一个祖先节点,该祖先节点在摧毁与u链接的边后剩余的子树中最长的边和次长边之和是最大的,如果摧毁u与子节点的边时,就要考虑那个祖先节点的位置了,可能就是u这个节点。如果摧毁u与子节点的一条边后,可以求出u的子树中的最长边和次长边,如果次长边要是大于祖先节点的最长边,祖先节点就更新为u节点,如果u的子树的最长边+到祖先节点的距离大于祖先的次长边的话,祖先节点也更新为u,u的最长边和次长边要更新。

#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<stdio.h>
#include<string.h>
#define N 100001
int n,head[N],num,vis[N],dp[N][3],dp1[N][2],min,iid;
struct edge
{
int st,ed,next,w,id;
}E[N*2];
void addedge(int x,int y,int w,int id)
{
E[num].st=x;E[num].ed=y; E[num].w=w;E[num].id=id;E[num].next=head[x];head[x]=num++;
}
int dfs(int u)
{
int i,v;
vis[u]=1;
for(i=head[u];i!=-1;i=E[i].next)
{
v=E[i].ed;
if(vis[v]==1)continue;
int temp=dfs(v)+1;
if(temp>dp[u][2])//所有子树的最长边
{
dp[u][0]=dp[u][1];
dp[u][1]=dp[u][2];
dp[u][2]=temp;
}
else if(temp>dp[u][1])//次长边
{
dp[u][0]=dp[u][1];
dp[u][1]=temp;
}
else if(temp>dp[u][0])//第三长边
dp[u][0]=temp;
}
return dp[u][2];
}
void dfs1(int u,int father,int dis)
{
vis[u]=1;
int i,v,ans,temp,flag;
int y1,y0;//摧毁一条边后当前节点u链接的所有子树的最长边和次长边
for(i=head[u];i!=-1;i=E[i].next)
{
dp1[u][1]=dp1[father][1];dp1[u][0]=dp1[father][0];//祖先节点的最长和次长边之和最大的
v=E[i].ed;
flag=dis;//当前节点到最长和次长边之和最大的祖先节点的距离
if(vis[v]==1)continue;
if(dp[v][2]+1==dp[u][2])//当前边在父节点u的最长边上
{y1=dp[u][1];y0=dp[u][0];}
else if(dp[v][2]+1==dp[u][1])//当前边在父节点u的次长边上
{y1=dp[u][2];y0=dp[u][0];}
else {y1=dp[u][2];y0=dp[u][1];}
if(y0>dp1[u][1])//如果子树的次长边比祖先节点的最长边大。就更新当前节点的最长边和次长边
{
dp1[u][1]=y1;
dp1[u][0]=y0;
flag=0;//祖先节点变成当前节点,
}
else if(y1+dis>dp1[u][0])//如果子树的最长边+到祖先节点的距离大于祖先的次长边
{
dp1[u][1]+=dis;//当前节点的最长边加上到祖先的距离
dp1[u][0]=y1;//次长边=子树的最长边
flag=0;
}
if(dp1[u][1]<dp1[u][0]){temp=dp1[u][1];dp1[u][1]=dp1[u][0];dp1[u][0]=temp;}
ans=dp1[u][1]+dp1[u][0];//摧毁当前边后,u所在树的最远两点距离
if(ans<dp[v][2]+dp[v][1])ans=dp[v][2]+dp[v][1];//与儿子v所在树的最远两点距离比较
if(ans*E[i].w<min){min=ans*E[i].w;iid=E[i].id;}//更新最小值
else if(ans*E[i].w==min&&iid>E[i].id){min=ans*E[i].w;iid=E[i].id;}
dfs1(v,u,flag+1);
}
}
int main()
{
int i,x,y,w,t,op=1;
scanf("%d",&t);
while(t--)
{
memset(head,-1,sizeof(head));
num=0;
scanf("%d",&n);
for(i=1;i<n;i++)
{
scanf("%d%d%d",&x,&y,&w);
addedge(x,y,w,i);
addedge(y,x,w,i);
}
min=1000000000;iid=1000000;
memset(dp,0,sizeof(dp));
memset(dp1,0,sizeof(dp1));
memset(vis,0,sizeof(vis));
dfs(1);
memset(vis,0,sizeof(vis));
dp1[0][1]=0;dp1[0][0]=0;
dfs1(1,0,0);
printf("Case #%d: ",op++);
printf("%d\n",iid);
}
return 0;
}