题目:https://www.luogu.org/problemnew/show/P4719
感觉这篇博客写得挺好:https://blog.****.net/litble/article/details/81038415
为了动态维护DP值,首先要把它转化成一个容易维护的形式,这道题中DP状态的转移就可以转化成矩阵乘法;
于是要快速算出一个DP值,就可以矩阵连乘,用线段树维护(此时求DP值已经完全变成求区间矩阵乘积了);
可以发现,如果修改一个点的值,影响到的只有它到根的一条链;
所以树剖+线段树维护矩阵,以重链为单位修改,复杂度据说是 23nlog2n ;
注意这里的 ed[x] 不是树剖常用的那个 ed,而是重链底端的 dfn 值,并且只记在 top 上,这样就可以在线段树上从 top 提取出一条重链。
代码如下:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define mid ((l+r)>>1)
#define ls (x<<1)
#define rs (x<<1|1)
using namespace std;
int const xn=1e5+,inf=1e9;
int n,m,a[xn],hd[xn],ct,to[xn<<],nxt[xn<<],dfn[xn],tim,id[xn],top[xn];
int fa[xn],siz[xn],son[xn],f[xn][],ed[xn];
struct N{
int a[][];
N(){a[][]=a[][]=; a[][]=a[][]=-inf;}
N operator * (const N &y) const
{
N ret;
for(int i=;i<;i++)
for(int k=;k<;k++)
for(int j=;j<;j++)
ret.a[i][j]=max(ret.a[i][j],a[i][k]+y.a[k][j]);
return ret;
}
}s[xn],t[xn<<];
int rd()
{
int ret=,f=; char ch=getchar();
while(ch<''||ch>''){if(ch=='-')f=; ch=getchar();}
while(ch>=''&&ch<='')ret=(ret<<)+(ret<<)+ch-'',ch=getchar();
return f?ret:-ret;
}
void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;}
int maxx(int a,int b){return a>b?a:b;}
void dfs(int x,int ff)
{
fa[x]=ff; siz[x]=; f[x][]=a[x];
for(int i=hd[x],u;i;i=nxt[i])
{
if((u=to[i])==ff)continue;
dfs(u,x); siz[x]+=siz[u];
if(siz[u]>siz[son[x]])son[x]=u;
f[x][]+=f[u][]; f[x][]+=maxx(f[u][],f[u][]);
}
}
void dfsx(int x)
{
dfn[x]=++tim; id[tim]=x;
if(son[x])top[son[x]]=top[x],dfsx(son[x]);
for(int i=hd[x],u;i;i=nxt[i])
if((u=to[i])!=fa[x]&&u!=son[x])top[u]=u,dfsx(u);
if(!son[x])ed[top[x]]=dfn[x];//!!链底
//ed[x]=tim;
}
void build(int x,int l,int r)
{
if(l==r)
{
int nw=id[l],g0=,g1=a[nw];
for(int i=hd[nw],u;i;i=nxt[i])
if((u=to[i])!=fa[nw]&&u!=son[nw])g0+=maxx(f[u][],f[u][]),g1+=f[u][];
t[x].a[][]=t[x].a[][]=g0; t[x].a[][]=g1;
s[l]=t[x]; return;//s[l]!
}
build(ls,l,mid); build(rs,mid+,r);
t[x]=t[ls]*t[rs];
}
N query(int x,int l,int r,int L,int R)
{
if(l>=L&&r<=R)return t[x];
if(mid>=R)return query(ls,l,mid,L,R);
if(mid<L)return query(rs,mid+,r,L,R);
return query(ls,l,mid,L,R)*query(rs,mid+,r,L,R);
}
N get(int x){return query(,,n,dfn[x],ed[x]);}
void chg(int x,int l,int r,int pos)
{
if(l==r){t[x]=s[l]; return;}//!s[x]!
if(pos<=mid)chg(ls,l,mid,pos);
else chg(rs,mid+,r,pos);
t[x]=t[ls]*t[rs];
}
void work(int x,int ss)
{
s[dfn[x]].a[][]+=ss-a[x]; a[x]=ss;//dfn[x]
N nw,pr;
while()
{
pr=get(top[x]); chg(,,n,dfn[x]); nw=get(top[x]);
x=fa[top[x]]; if(!x)return;
s[dfn[x]].a[][]+=maxx(nw.a[][],nw.a[][])-maxx(pr.a[][],pr.a[][]);
s[dfn[x]].a[][]=s[dfn[x]].a[][];
s[dfn[x]].a[][]+=nw.a[][]-pr.a[][];//dfn[x]
}
}
int main()
{
n=rd(); m=rd();
for(int i=;i<=n;i++)a[i]=rd();
for(int i=,x,y;i<n;i++)x=rd(),y=rd(),add(x,y),add(y,x);
dfs(,); top[]=; dfsx(); build(,,n);
for(int i=,x,y;i<=m;i++)
{
x=rd(); y=rd(); work(x,y); N tmp=get();
printf("%d\n",maxx(tmp.a[][],tmp.a[][]));
}
return ;
}