Codeforces 161D Distance in Tree (tree DP)

Date: 2023-03-08 18:10:49
Problem link:

http://codeforces.com/contest/161/problem/D

D. Distance in Tree

time limit per test: 3 seconds
memory limit per test: 512 megabytes
#### Problem statement
> A tree is a connected graph that doesn't contain any cycles.
>
> The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.
>
> You are given a tree with n vertices and a positive number k. Find the number of distinct pairs of the vertices which have a distance of exactly k between them. Note that pairs (v, u) and (u, v) are considered to be the same pair.
#### Input
> The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.
>
> Next n - 1 lines describe the edges as "ai bi" (without the quotes) (1 ≤ ai, bi ≤ n, ai ≠ bi), where ai and bi are the vertices connected by the i-th edge. All given edges are different.
#### Output
> Print a single integer — the number of distinct pairs of the tree's vertices which have a distance of exactly k between them.
>
> Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.
#### Sample
> **sample input**
> 5 2
> 1 2
> 2 3
> 3 4
> 2 5
>
> **sample output**
> 4

#### Problem summary

Given a tree in which every edge has length 1, count the pairs of vertices whose distance is exactly k. (u,v) and (v,u) count as one pair.
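To make the task concrete, here is a minimal brute-force sketch that runs a BFS from every vertex. It is not the approach of this post and is far too slow for n = 50000 in the worst case, but it is handy for checking a faster solution on small inputs. The function name `countPairsBruteForce` and its signature are hypothetical, chosen only for this illustration.

```cpp
#include <queue>
#include <utility>
#include <vector>

// Hypothetical helper (not from the original post): counts unordered vertex
// pairs at distance exactly k by running a BFS from every vertex.
// Vertices are numbered 1..n; edges holds the n-1 tree edges.
long long countPairsBruteForce(int n, int k,
                               const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    long long ordered = 0;  // counts (u, v) and (v, u) separately
    for (int s = 1; s <= n; s++) {
        std::vector<int> dist(n + 1, -1);
        std::queue<int> q;
        dist[s] = 0;
        q.push(s);
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            if (dist[u] == k) {
                ordered++;
                continue;  // anything deeper is farther than k from s
            }
            for (int v : adj[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
            }
        }
    }
    return ordered / 2;  // every pair was seen from both endpoints
}
```

Called with the sample tree (edges 1-2, 2-3, 3-4, 2-5) and k = 2, it returns 4: the pairs are (1,3), (1,5), (3,5) and (2,4).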

#### Solution

Tree DP:

dp[i][j] is the number of nodes at distance j from node i.

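Run two DFS passes:

The first pass computes, within the subtree rooted at i, the number of nodes at distance j from i and stores it in dp[i][j].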
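Written as a recurrence, the first pass looks like this. The name down is a label chosen here for the subtree-only counts; the code stores them directly in dp.

$$
\mathrm{down}[u][0] = 1, \qquad
\mathrm{down}[u][j+1] = \sum_{v \in \mathrm{children}(u)} \mathrm{down}[v][j] \quad (0 \le j < k)
$$

A node at distance j+1 below u lies at distance j below exactly one child of u, so nothing is counted twice.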

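The second pass adds, for every non-root node i, the number of nodes outside i's subtree at distance j from i. Every such node is reached through i's parent fa. When i is visited, dp[fa] already holds the counts for the whole tree while dp[i] still holds subtree counts only, so the number to add is dp[fa][j-1] - dp[i][j-2]: the nodes at distance j-1 from fa, minus those that lie inside i's subtree. For j = 1 it is just fa itself.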
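The same step as a formula, writing p for the parent of u, down for the subtree counts of the first pass and all for the final whole-tree counts. These names are labels chosen here; the code keeps everything in the single array dp and uses a temporary buffer so that the old subtree counts are read before they are overwritten.

$$
\mathrm{all}[u][0] = 1, \qquad
\mathrm{all}[u][1] = \mathrm{down}[u][1] + 1,
$$

$$
\mathrm{all}[u][j+1] = \mathrm{down}[u][j+1] + \bigl(\mathrm{all}[p][j] - \mathrm{down}[u][j-1]\bigr) \quad (1 \le j < k)
$$

For the root, all equals down. In the sample, rooted at 1, node 3 has parent 2 with $\mathrm{all}[2][1] = 3$ and $\mathrm{down}[3][0] = 1$, so $\mathrm{all}[3][2] = 0 + (3 - 1) = 2$, which are nodes 1 and 5.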

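Adding the two gives the number of nodes in the whole tree at distance j from i.

Summing dp[i][k] over all i counts every pair twice, once from each endpoint, so the answer is sigma(dp[i][k]) / 2.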
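As a quick check on the sample (k = 2), the per-node counts are dp[1][2] = 2 (nodes 3, 5), dp[2][2] = 1 (node 4), dp[3][2] = 2 (nodes 1, 5), dp[4][2] = 1 (node 2) and dp[5][2] = 2 (nodes 1, 3):

$$
\sum_{i=1}^{5} dp[i][2] = 2 + 1 + 2 + 1 + 2 = 8, \qquad \frac{8}{2} = 4
$$

which matches the expected output.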

#### Code

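#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;

typedef long long LL;

const int maxn = 5e4 + 10;  // upper bound on n, with some slack
const int maxm = 555;       // upper bound on k, with some slack

int n, k;
// dp[u][j]: number of nodes at distance j from u.
// Globals are zero-initialized; the table takes roughly 220 MB.
LL dp[maxn][maxm];
vector<int> G[maxn];  // adjacency lists

// First pass: count nodes inside u's subtree at each distance from u.
void dfs1(int u, int fa) {
    dp[u][0] = 1;
    for (size_t i = 0; i < G[u].size(); i++) {
        int v = G[u][i];
        if (v == fa) continue;
        dfs1(v, u);
        // A node at distance j from child v is at distance j+1 from u.
        for (int j = 0; j + 1 <= k; j++) {
            dp[u][j + 1] += dp[v][j];
        }
    }
}

LL tmp[maxm];

// Second pass: add the nodes outside u's subtree, reached through fa.
// On entry dp[fa] is already complete and dp[u] holds subtree counts only.
void dfs2(int u, int fa) {
    if (fa != -1) {
        // tmp[j]: nodes at distance j from fa that are not in u's subtree.
        // Buffered so the old dp[u] values are read before being updated.
        tmp[0] = dp[fa][0];
        for (int j = 1; j <= k; j++) {
            tmp[j] = dp[fa][j] - dp[u][j - 1];
        }
        for (int j = 0; j + 1 <= k; j++) {
            dp[u][j + 1] += tmp[j];
        }
    }
    for (size_t i = 0; i < G[u].size(); i++) {
        int v = G[u][i];
        if (v == fa) continue;
        dfs2(v, u);
    }
}

int main() {
    scanf("%d%d", &n, &k);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs1(1, -1);
    dfs2(1, -1);
    LL ans = 0;
    for (int i = 1; i <= n; i++) {
        ans += dp[i][k];
    }
    // Each pair was counted from both endpoints.
    cout << ans / 2 << endl;
    return 0;
}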