HDU 6035 - Colorful Tree | 2017 Multi-University Training Contest 1

时间:2024-04-04 13:37:15
/*
HDU 6035 - Colorful Tree [ DFS,分块 ]
题意:
n个节点的树,每个节点有一种颜色(1~n),一条路径的权值是这条路上不同的颜色的数量,问所有路径(n*(n-1)/2条) 权值之和是多少?
分析:
考虑单种颜色,这种颜色的贡献是 至少经过一次这种颜色的路径数 = 总路径数(n*(n-1)/2) - 没有经过这种颜色的路径数
求没有经过这种颜色的路径数,即这种颜色的点将整棵树分块,每个分块中的总路径数
*/
#include <bits/stdc++.h>
using namespace std;
#define LL long long
const int N = 200005;
struct Edge {
int to, next;
}edge[N<<1];
int tot, head[N];
void init() {
memset(head, -1, sizeof(head));
tot = 0;
}
void addedge(int u, int v) {
edge[tot].to = v; edge[tot].next = head[u];
head[u] = tot++;
}
int n;
int c[N], last[N], rem[N], cut[N];
LL ans;
LL sum2(LL x) {
return x*(x-1)/2;
}
int dfs(int u, int pre)
{
int su = 1, fa = last[c[u]];
last[c[u]] = u;
for (int i = head[u]; i != -1; i = edge[i].next)
{
int v = edge[i].to;
if (v == pre) continue;
cut[u] = 0;
int sv = dfs(v, u);
su += sv;
ans -= sum2(sv-cut[u]);
}
(fa ? cut[fa] : rem[c[u]]) += su;
last[c[u]] = fa;
return su;
}
int main()
{
int tt = 0;
while (~scanf("%d", &n))
{
init();
for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
for (int i = 1; i < n; i++)
{
int x, y; scanf("%d%d", &x, &y);
addedge(x, y); addedge(y, x);
}
memset(last, 0, sizeof(last));
memset(cut, 0, sizeof(cut));
memset(rem, 0, sizeof(rem));
ans = n*sum2(n);
dfs(1, 1);
for (int i = 1; i <= n; i++)
ans -= sum2(n-rem[i]);
printf("Case #%d: %lld\n", ++tt, ans);
}
}
//----------------------------------------------------------------------
#include <bits/stdc++.h>
using namespace std;
#define LL long long
const int N = 200005;
vector<int> c[N], G[N];
int n;
int L[N], R[N], s[N], f[N];
void dfs(int u, int pre, int&& ncnt)
{
f[u] = pre;
L[u] = ++ncnt;
s[u] = 1;
for (auto& v : G[u])
{
if (v == pre) continue;
dfs(v, u, move(ncnt));
s[u] += s[v];
}
R[u] = ncnt;
}
bool cmp(int a, int b) {
return L[a] < L[b];
}
int main()
{
int tt = 0;
while (~scanf("%d", &n))
{
for (int i = 0; i <= n; i++) c[i].clear(), G[i].clear();
for (int i = 1; i <= n; i++)
{
int x; scanf("%d", &x);
c[x].push_back(i);
}
for (int i = 1; i < n; i++)
{
int x, y; scanf("%d%d", &x, &y);
G[x].push_back(y);
G[y].push_back(x);
}
G[0].push_back(1);
dfs(0, 0, 0);
LL ans = (LL)n * n * (n-1)/2;
for (int i = 1; i <= n; i++)
{
if (c[i].empty()) {
ans -= (LL)n*(n-1)/2;
continue;
}
c[i].push_back(0);
sort(c[i].begin(), c[i].end(), cmp);
for (auto& x : c[i])
for (auto& y : G[x])
{
if (y == f[x]) continue;
int size = s[y];
int k = L[y];
while (1)
{
L[n+1] = k;
auto it = lower_bound(c[i].begin(), c[i].end(), n+1, cmp);
if (it == c[i].end() || L[*it] > R[y]) break;
size -= s[*it];
k = R[*it]+1;
}
ans -= (LL)size * (size-1)/2;
}
}
printf("Case #%d: %lld\n", ++tt, ans);
}
}