Petrozavodsk Winter Camp, Day 8, 2014, Second Trip

时间:2023-03-08 17:03:05

给你一棵树,每次询问一个点对 (a, b),问有多少条路径与 a-b 这条路径没有交集(单个点也算一条路径)。

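思路:找 LCA。删掉 a-b 路径后,答案是各连通块 calc(大小) 之和;预处理 f1[u](u 各儿子子树内的路径数之和)和沿根路径的前缀和 f2[u],每个询问用 LCA 在 O(log n) 内回答。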

#include <bits/stdc++.h>
using namespace std;
// Inclusive loop helpers: rep counts i up from j to k, dwn counts i down from j to k.
#define rep(i, j, k) for (int i = int(j); i <= int(k); ++ i)
#define dwn(i, j, k) for (int i = int(j); i >= int(k); -- i)
typedef long long LL;
typedef pair<int, int> P;
// Array capacity; vertices are indexed 1..n and 0 is used as a "no parent" sentinel.
// NOTE(review): assumes n <= 1e5 from the problem limits — confirm.
const int N = 1e5 + 10;
// g:   adjacency lists of the tree.
// dep: depth of u, computed as dep[parent] + 1 (dep[0] == 0 for the sentinel).
// fa:  binary-lifting table, fa[u][j] is the 2^j-th ancestor of u (0 when it
//      does not exist); 20 levels cover any depth below 2^20.
// sz:  number of vertices in the subtree of u.
// f1:  sum of calc(sz[v]) over the children v of u.
// f2:  for a non-root u, sum over strict ancestors w of u of
//      f1[w] - calc(sz[child of w on the way to u]).
vector<int> g[N];
int dep[N], fa[N][20], sz[N]; LL f1[N], f2[N]; int n, q;
// Number of paths inside a connected component of x vertices, counting every
// unordered pair {u, v} including u == v: x * (x + 1) / 2.
long long calc(long long x) {
    return x * (x + 1) / 2LL;
}
// First pass from the root: fills dep, the direct-parent column fa[u][0] and
// the subtree sizes sz. f is the parent of u (the sentinel 0 for the root).
// NOTE(review): recursion depth equals the tree height, which can reach n.
void dfs(int u, int f) {
    dep[u] = dep[f] + 1;
    fa[u][0] = f;
    sz[u] = 1;
    for (int v : g[u])
        if (v != f) {
            dfs(v, u);
            sz[u] += sz[v];
        }
}
// Second pass: after returning, f1[u] holds the sum of calc(sz[child]) over all
// children of u, i.e. the number of paths lying entirely inside a single child
// subtree. Relies on sz having been filled by dfs.
void dfs2(int u, int f) {
    for (int child : g[u]) {
        if (child == f) continue;  // skip the edge back to the parent
        dfs2(child, u);
        f1[u] += calc(sz[child]);
    }
}
// Third pass, top-down: f2[u] = f2[parent] + f1[parent] - calc(sz[u]), a prefix
// sum along the root path of each strict ancestor's child-subtree path counts,
// excluding the child that lies on the path toward u. The root is recognised by
// its parent being the sentinel 0 and keeps its zero-initialised f2.
void dfs3(int u, int f) {
    if (f != 0) f2[u] = f2[f] + f1[f] - calc(sz[u]);
    for (int v : g[u])
        if (v != f) dfs3(v, u);
}
// Lowest common ancestor of u and v by binary lifting over fa (20 levels).
int lca(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    // Lift the deeper vertex u until it is at the same depth as v.
    for (int i = 19; i >= 0; --i)
        if (dep[u] - (1 << i) >= dep[v]) u = fa[u][i];
    if (u == v) return u;
    // Lift both while their 2^i-th ancestors differ; they stop as the two
    // children of the LCA, whose common parent is the answer.
    for (int i = 19; i >= 0; --i)
        if (fa[u][i] && fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
    return fa[u][0];
}
int find(int u, int d) {
dwn(i, , ) if (d >= ( << i)) u = fa[u][d]; return u;
}
// Answers one query: the number of paths (single vertices included) sharing no
// vertex with the tree path x..y. Removing that path splits the tree into
// components, and the answer is the sum of calc(size) over those components.
LL solve(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);  // from here on dep[x] >= dep[y]
    int lc = lca(x, y);
    LL ret = 0;
    if (y == lc) {
        // Vertical path y..x (also covers x == y): subtrees below x, off-path
        // subtrees of every vertex from y down to parent(x), and everything
        // outside y's subtree.
        ret = f1[x] + f2[x] - f2[y] + calc(n - sz[y]);
    } else {
        // Path x..lc..y. f2[x] - f2[lc] and f2[y] - f2[lc] account for the
        // off-path subtrees of every vertex strictly between an endpoint and lc,
        // and each also adds lc's term once: f1[lc] minus its child toward x
        // (resp. toward y). Subtracting f1[lc] once leaves lc with
        // f1[lc] - calc(sz[child toward x]) - calc(sz[child toward y]), exactly
        // its off-path subtrees, so those two children need not be located.
        ret = f1[x] + f1[y] + calc(n - sz[lc]);
        ret += (f2[x] - f2[lc]) + (f2[y] - f2[lc]);
        ret -= f1[lc];
    }
    return ret;
}
// Reads n and q, then the n - 1 tree edges, then q queries (x, y); prints the
// answer to each query on its own line.
int main() {
    scanf("%d%d", &n, &q);
    rep(i, 1, n - 1) {
        int x, y;
        scanf("%d%d", &x, &y);
        g[x].push_back(y);
        g[y].push_back(x);
    }
    // Root the tree at vertex 1; 0 is the "no parent" sentinel.
    dfs(1, 0);
    // The 2^j-th ancestor is the 2^(j-1)-th ancestor of the 2^(j-1)-th ancestor.
    // fa[0][*] stays 0, so jumps past the root land on the sentinel.
    rep(j, 1, 19) rep(i, 1, n) fa[i][j] = fa[fa[i][j - 1]][j - 1];
    dfs2(1, 0);
    dfs3(1, 0);
    while (q --) {
        int x, y;
        scanf("%d%d", &x, &y);
        printf("%lld\n", solve(x, y));  // C stdio on both sides, matching scanf
    }
    return 0;
}
/* Sample input: first line "n q", then n - 1 edges "x y", then q queries "x y".
6 2
1 2
3 2
3 4
3 5
6 3
5 4
1 6
*/