[血压游戏](https://ac.nowcoder.com/acm/contest/5278/G)
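核心技巧是 tag 数组:tag[x] 是 mp[x] 的懒标记。mp[x] 中存的是「真实值 + tag[x]」,读取时用 max(存储值 - tag[x], 1) 还原真实值;整个 map 向上移动一层(每层值减 1,但不低于 1)时只需 tag[x]++,不必逐个修改元素。之所以可以这样累加,是因为对 v ≥ 1,连续做 k 次 max(v-1, 1) 恰好等于 max(v-k, 1)。这样就巧妙地处理了值随上升高度而产生的损耗。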
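方法一:dsu on tree(启发式合并)
思路类似长链剖分:每个节点 x 用 unordered_map<深度, 值> 记录其子树中各深度的值汇合后的结果。同一深度的值会同时到达同一个祖先,所以只有深度相同的值才会合并。合并子树时总是把较小的 map 并入较大的 map;由于 unordered_map 的 swap 是 O(1) 的,交换两个 map(连同各自的 tag)不需要额外代价。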
#include<bits/stdc++.h>
using namespace std;
// 64-bit integer shorthand used for every accumulated value.
using ll = long long;
// Large sentinel constant from the common template (not referenced in this program).
constexpr int inf = 0x3f3f3f3f;

// Debug helper: dbg(a, b) prints, in green, the argument text "a, b", then " -> ",
// then each argument's value followed by a space, and finally resets the colour.
#define dbg(x...) do { std::cout << "\033[32;1m" << #x << " -> "; err(x); } while (0)

// End of a dbg() line: reset the terminal colour and terminate (and flush) the line.
void err() {
    std::cout << "\033[39;0m" << std::endl;
}

// Prints the first argument followed by one space, then recurses on the rest.
template <class Head, class... Tail>
void err(const Head& head, const Tail&... tail) {
    std::cout << head << " ";
    err(tail...);
}
// Capacity: maximum vertex count plus a little slack (vertices are 1-indexed).
const int N = 200000 + 5;

// Tree stored as a forward-star adjacency list: head[x] is the first edge out of x,
// ver[e] the endpoint of edge e, nxt[e] the next edge with the same source.
// Each undirected edge is stored in both directions, hence 2*N slots.
// tot is the number of edge slots used so far.
int head[N];
int ver[N << 1];
int nxt[N << 1];
int tot;

// dep[x]: depth of x, with the root at depth 1 (dep[0] stays 0).
int dep[N];

// Vertex count and root vertex.
int n;
int rt;

// a[x]: initial value at x.
// tag[x]: lazy offset of mp[x]; a stored value s represents max(s - tag[x], 1).
ll a[N];
ll tag[N];

// mp[x]: depth -> (real value + tag[x]) for the values currently gathered at x.
unordered_map<int, ll> mp[N];
// Appends the directed edge x -> y to the forward-star adjacency list.
void add(int x, int y){
    ++tot;
    ver[tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}
void ins(int x, int d, ll cnt){
if(!mp[x].count(d)){
mp[x][d] = cnt + tag[x]; // x 下面的边数
} else {
mp[x][d] = max(mp[x][d] - tag[x], 1ll) + cnt + tag[x];
}
}
void merge(int x, int y){
if(mp[x].size() < mp[y].size()){
swap(mp[x], mp[y]);
swap(tag[x], tag[y]);
}
for(auto t : mp[y]){
if(t.second){
ins(x, t.first, max(t.second - tag[y], 1ll));
}
}
}
// Post-order DFS (recursion depth equals the tree height):
//  - sets dep[x];
//  - folds every child's depth map into mp[x];
//  - adds x's own nonzero value at depth dep[x];
//  - bumps tag[x], which lowers every value in mp[x] by one (floored at 1)
//    as it leaves x.
void dfs(int x, int fa){
    dep[x] = dep[fa] + 1;
    for(int e = head[x]; e != 0; e = nxt[e]){
        const int child = ver[e];
        if(child != fa){
            dfs(child, x);
            merge(x, child);
        }
    }
    if(a[x] != 0){
        ins(x, dep[x], a[x]);
    }
    ++tag[x];
}
int main(){
scanf("%d%d", &n, &rt);
for(int i=1;i<=n;i++){
scanf("%lld", &a[i]);
}
for(int i=1;i<n;i++){
int x, y;scanf("%d%d", &x, &y);
add(x, y);add(y, x);
}
dfs(rt, 0);
ll res = 0;
for(auto t : mp[rt]){
if(t.second) res += max(1ll, t.second - tag[rt]);
}
cout << res << endl;
return 0;
}
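方法二:按深度分组 + 虚树 + 树形 DP
把节点按深度分组:同一深度的值同时到达各个祖先,不同深度之间互不影响。对每一组建立虚树,在虚树上做树形 DP:子节点 y 的值上升到虚树父节点 x 时减去 dep[y]-dep[x](不低于 1),在 x 处求和;最后离开根节点时若结果 ≥ 2 再减 1。把各组结果相加即为答案。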
#include<bits/stdc++.h>
using namespace std;
// 64-bit integer shorthand used for every accumulated value.
using ll = long long;
// Large sentinel constant from the common template (not referenced in this program).
constexpr int inf = 0x3f3f3f3f;

// Debug helper: dbg(a, b) prints, in green, the argument text "a, b", then " -> ",
// then each argument's value followed by a space, and finally resets the colour.
#define dbg(x...) do { std::cout << "\033[32;1m" << #x << " -> "; err(x); } while (0)

// End of a dbg() line: reset the terminal colour and terminate (and flush) the line.
void err() {
    std::cout << "\033[39;0m" << std::endl;
}

// Prints the first argument followed by one space, then recurses on the rest.
template <class Head, class... Tail>
void err(const Head& head, const Tail&... tail) {
    std::cout << head << " ";
    err(tail...);
}
// Capacity: maximum vertex count plus slack; M covers both directions of every edge.
const int N = 200000 + 5;
const int M = 2 * N;

// Original tree as a forward-star adjacency list (each edge stored in both directions).
int head[N];
int ver[M];
int nxt[M];

// dfn[x]: DFS entry index of x; rnk[i]: vertex whose entry index is i; cnt: counter.
int dfn[N];
int rnk[N];
int cnt;

// dep[x]: depth (root = 1); f[x][j]: 2^j-th ancestor of x (0 above the root).
int dep[N];
int f[N][20];

// st[1..top]: stack used while building a virtual tree.
// inq[x]: 1 while x belongs to the depth group currently being processed.
int st[N];
int top;
int inq[N];

// a[x]: initial value at vertex x.
ll a[N];

// Vertex count, root, and number of edge slots used in the original adjacency list.
int n;
int rt;
int tot;

// node[d]: all vertices at depth d.
vector<int> node[N];
// Forward-star adjacency list holding the virtual tree of one depth group.
// The single instance G is reused for every group: get() resets tot, and
// get()/insert() reset head[] only for the vertices they put into the tree.
struct Graph {
    int head[N];
    int ver[M];
    int nxt[M];
    int tot;

    // Appends the directed edge x -> y.
    void add(int x, int y) {
        tot += 1;
        ver[tot] = y;
        nxt[tot] = head[x];
        head[x] = tot;
    }
} G;
// Appends the directed edge x -> y to the original tree's adjacency list.
void add(int x, int y){
    ++tot;
    ver[tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}
// Preorder DFS over the original tree. For every vertex it records the entry
// order (dfn/rnk). For every child it records the parent f[child][0] and the
// depth dep[child] = dep[x] + 1.
void dfs(int x, int fa){
    ++cnt;
    dfn[x] = cnt;
    rnk[cnt] = x;
    for(int e = head[x]; e; e = nxt[e]){
        const int child = ver[e];
        if(child == fa) continue;
        f[child][0] = x;
        dep[child] = dep[x] + 1;
        dfs(child, x);
    }
}
// Lowest common ancestor by binary lifting (needs dep[] and f[][] filled in).
int lca(int x, int y){
    // Make y the deeper vertex, then lift it to x's depth.
    if(dep[x] > dep[y]) swap(x, y);
    for(int j = 19; j >= 0; --j){
        if(dep[f[y][j]] >= dep[x]) y = f[y][j];
    }
    if(x == y) return x;
    // Lift both while their ancestors differ; they stop just below the LCA.
    for(int j = 19; j >= 0; --j){
        if(f[x][j] != f[y][j]){
            x = f[x][j];
            y = f[y][j];
        }
    }
    return f[x][0];
}
// Adds vertex x to the virtual tree being built on the stack st[1..top].
// get() calls this for the vertices of one depth group in increasing dfn order,
// with st[1] = rt. The stack holds a root-to-leaf chain of virtual-tree vertices,
// and an edge parent -> child is emitted into G when a vertex is popped or replaced.
void insert(int x){
// rt is already at the bottom of the stack.
if(x == rt) return;
int t = lca(x, st[top]);
// If t is the current top, x just extends the chain; otherwise the chain must be cut back to t.
if(t != st[top]){
// While the vertex below the top still comes after t in dfn order, it is the top's
// virtual parent: emit that edge and pop.
while(top > 1 && dfn[st[top-1]] > dfn[t]){
G.add(st[top-1], st[top]);
top --;
}
if(dfn[t] > dfn[st[top-1]]){
// t is not on the stack yet: it becomes a new virtual vertex (fresh adjacency list)
// that adopts the current top as its child and takes its place on the stack.
G.head[t] = 0;
G.add(t, st[top]);
st[top] = t;
} else {
// t is already the vertex below the top: attach the top to it and pop.
G.add(t, st[top--]);
}
}
// Push x with an empty adjacency list.
G.head[x] = 0, st[++top] = x;
}
// Tree DP on the virtual tree, starting at x.
// A vertex of the current depth group (inq) returns its own value a[x].
// Any other vertex sums what reaches it from its children: each nonzero child
// value is lowered by the depth gap dep[child] - dep[x], floored at 1.
ll dfs(int x){
    if(inq[x]) return a[x];
    ll total = 0;
    for(int e = G.head[x]; e; e = G.nxt[e]){
        const int child = G.ver[e];
        const ll carried = dfs(child);
        if(carried == 0) continue; // nothing arrives from this branch
        const int gap = dep[child] - dep[x];
        total += max(carried - gap, 1ll);
    }
    return total;
}
// Contribution of all vertices at depth x:
//  1. sort the group by dfn and build its virtual tree (rooted at rt) in G;
//  2. run the DP from rt;
//  3. apply the final decrement for leaving the root (only when the result is >= 2).
// Returns 0 for an empty group. The inq marks are cleared before returning.
ll get(int x){
    vector<int> &group = node[x];
    if(group.empty()) return 0;
    sort(group.begin(), group.end(), [](int u, int v){ return dfn[u] < dfn[v]; });
    top = 1;
    st[top] = rt;
    G.tot = 0;
    G.head[rt] = 0;
    for(int v : group){
        insert(v);
        inq[v] = 1;
    }
    // Whatever remains on the stack is a single chain; link it top-down.
    for(int i = 1; i < top; ++i) G.add(st[i], st[i + 1]);
    ll res = dfs(rt);
    if(res >= 2) --res;
    for(int v : group) inq[v] = 0;
    return res;
}
int main(){
scanf("%d%d", &n, &rt);
for(int i=1;i<=n;i++){
scanf("%lld", &a[i]);
}
for(int i=1;i<n;i++){
int x, y;scanf("%d%d", &x, &y);
add(x, y);
add(y, x);
}
dep[rt] = 1;
dfs(rt, 0);
for(int i=1;i<=n;i++){
node[dep[i]].push_back(i);
}
for(int j=1;j<20;j++){
for(int i=1;i<=n;i++){
f[i][j] = f[f[i][j-1]][j-1];
}
}
ll res = 0;
for(int i=1;i<=n;i++){
res += get(i);
}
cout << res <<endl;
return 0;
}