Luogu4338 ZJOI2018 历史 LCT、贪心

时间:2023-12-16 13:11:14

传送门

题意:在$N$个点的$LCT$中,最开始每条边的虚实不定,给出每一个点的$access$次数,求一种$access$方案使得每条边的虚实变换次数之和最大,需要支持动态增加某个点的$access$次数。$N \leq 4 \times 10^5$


ZJOI2018真的都是大火题

首先一个小小的转化:对于每个非叶子节点,新开一个叶子节点,将当前非叶子节点的$access$次数转移到这些叶子节点上,这样所有的$access$操作都在叶子节点进行,可以少很多的判断。

接着我们需要考虑在每一个点上最大化方案总数。因为必须相邻两次$access$由不同子树的叶子节点发起才能够贡献$1$的答案,而我们希望每一次$access$都能尽可能多贡献答案,所以尽可能让相邻两次$access$来自不同子树。考虑树上某一个非叶子节点$x$,其所有儿子为集合$i$,$S$表示子树$access$次数总和,显然答案贡献的最大值与$\sum S_i$与$max\{S_i\}$相关,因为当$max\{S_i\}$占$\sum S_i$比例特别大的时候,则必定要有很多同一子树来的$access$操作被放在一起。

现讲结论吧,最大$access$贡献是

$$min\{S_i-1 , 2 \times (\sum S_i - max \{S_i\})\}$$

也就是在$max\{S_i\} > \frac{\sum S_i + 1}{2}$时总贡献数量会取右边一项

$min$中的左边一项表示的是任意两个$access$之间都产生$1$的贡献(最优的情况),而对于右边的项,因为只有取到最大值的子树的贡献次数变少了,那么我们考虑所有其他子树,它们每一次$access$都可以在这一次$access$的之前、之后的$access$中取到$2$的贡献,所以贡献总和就是右边那一项,树形$DP$计算一次就能获得$30pts$。

接着我们考虑修改操作。如果在某一个点的子树集合$i$上存在一个子树$x$满足$S_x > \frac{\sum S_i + 1}{2}$,那么如果我们在子树$x$上加上$access$次数$w$,和和最大值同时增加了$w$,也就是说贡献没有变化。我们考虑将满足$S_x > \frac{\sum S_i + 1}{2}$的点与其父亲连一条实边,表示这一条边连接的两个点之间无需转移,而其他的边就连为轻边。

观察一下条件:$S_x > \frac{\sum S_i + 1}{2}$,是不是很像重链剖分?其实实质就是重链剖分

然后我们就只需要考虑轻边上的转移了。可以知道每一个点到根的轻边的数量是$log \, \sum (access \text{次数})$级别的,复杂度也符合要求。所以可以使用$LCT$动态维护轻重边的划分,外部计算全局答案,每一次找到一条轻边的时候,看能否将其改为重边,去掉以前这个点的贡献,加上当前的贡献即可。

 #include<bits/stdc++.h>
#define int long long
#define lch Tree[x].ch[0]
#define rch Tree[x].ch[1]
#define mid ((Tree[x].sum + 1) >> 1)
//This code is written by Itst
using namespace std; inline int read(){
int a = ;
bool f = ;
char c = getchar();
while(c != EOF && !isdigit(c)){
if(c == '-')
f = ;
c = getchar();
}
while(c != EOF && isdigit(c)){
a = (a << ) + (a << ) + (c ^ '');
c = getchar();
}
return f ? -a : a;
} const int MAXN = ;
struct node{
int fa , ch[] , sum , non_sum;
bool type;
}Tree[MAXN << ];
struct Edge{
int end , upEd;
}Ed[MAXN << ];
int a[MAXN] , head[MAXN] , sum , N , cntEd; inline bool nroot(int x){
return Tree[Tree[x].fa].ch[] == x || Tree[Tree[x].fa].ch[] == x;
} inline bool son(int x){
return Tree[Tree[x].fa].ch[] == x;
} inline void pushup(int x){
Tree[x].sum = Tree[lch].sum + Tree[rch].sum + Tree[x].non_sum + (x > N ? a[x - N] : );
} inline void rotate(int x){
bool f = son(x);
int y = Tree[x].fa , z = Tree[y].fa , w = Tree[x].ch[f ^ ];
if(nroot(y))
Tree[z].ch[son(y)] = x;
Tree[x].fa = z;
Tree[x].ch[f ^ ] = y;
Tree[y].fa = x;
Tree[y].ch[f] = w;
if(w)
Tree[w].fa = y;
pushup(y);
} inline void Splay(int x){
while(nroot(x)){
if(nroot(Tree[x].fa))
rotate(son(x) == son(Tree[x].fa) ? Tree[x].fa : x);
rotate(x);
}
pushup(x);
} inline void access(int x , int w){
a[x - N] += w;
Splay(x);
while(Tree[x].fa){
Splay(Tree[x].fa);
int k = Tree[x].fa , t = Tree[k].sum - Tree[Tree[k].ch[]].sum;
if(Tree[k].ch[])
sum -= (t - Tree[Tree[k].ch[]].sum) << ;
else
sum -= t - ;
Tree[k].non_sum += w;
Tree[k].sum += w;
t += w;
if(Tree[Tree[k].ch[]].sum < Tree[x].sum){
Tree[k].non_sum = Tree[k].non_sum - Tree[x].sum + Tree[Tree[k].ch[]].sum;
Tree[k].ch[] = x;
}
if(((t + ) >> ) < Tree[Tree[k].ch[]].sum)
sum += (t - Tree[Tree[k].ch[]].sum) << ;
else{
sum += t - ;
Tree[k].non_sum += Tree[Tree[k].ch[]].sum;
Tree[k].ch[] = ;
}
x = k;
}
} inline void addEd(int a , int b){
Ed[++cntEd].end = b;
Ed[cntEd].upEd = head[a];
head[a] = cntEd;
} void dfs(int x , int fa){
Tree[x].fa = fa;
if(x > N)
return;
for(int i = head[x] ; i ; i = Ed[i].upEd)
if(Ed[i].end != fa){
dfs(Ed[i].end , x);
Tree[x].sum += Tree[Ed[i].end].sum;
}
for(int i = head[x] ; i ; i = Ed[i].upEd)
if(Ed[i].end != fa && mid < Tree[Ed[i].end].sum){
sum += (Tree[x].sum - Tree[Ed[i].end].sum) << ;
Tree[x].non_sum = Tree[x].sum - Tree[Ed[i].end].sum;
Tree[x].ch[] = Ed[i].end;
return;
}
sum += Tree[x].sum - ;
Tree[x].non_sum = Tree[x].sum;
} signed main(){
#ifndef ONLINE_JUDGE
freopen("4338.in" , "r" , stdin);
//freopen("4338.out" , "w" , stdout);
#endif
N = read();
int M = read();
for(int i = ; i <= N ; ++i)
Tree[i + N].sum = a[i] = read();
for(int i = ; i < N ; ++i){
int a = read() , b = read();
addEd(a , b);
addEd(b , a);
}
for(int i = ; i <= N ; ++i)
addEd(i , i + N);
dfs( , );
printf("%lld\n" , sum);
while(M--){
int a = read() , x = read();
access(a + N , x);
printf("%lld\n" , sum);
}
return ;
}