[NOIP2018]保卫王国(树形dp+倍增)

时间:2024-01-08 15:03:44

我的倍增解法吊打动态 \(dp\) 全局平衡二叉树没学过

先讲 \(NOIP\) 范围内的倍增解法。

我们先考虑只有一个点取/不取怎么做。

\(f[x][0/1]\) 表示取/不取 \(x\) 后,\(x\) 子树内的最小权覆盖集,\(g[x][0/1]\) 表示取/不取 \(x\) 后,除 \(x\) 子树的最小权覆盖集。那么这两个数组可以 \(O(n)\) 预处理出来。

\[f[x][0]+=f[y][1]
\]

\[f[x][1]+=min(f[y][0],f[y][1])
\]

\[g[y][0]=g[x][1]+f[x][1]-min(f[y][0],f[y][1])
\]

\[g[y][1]=min(g[y][0],g[x][0]+f[x][0]-f[y][1])
\]

那么我们可以 \(a\) 表示 \(x\) 结点的状态,那么 \(ans=f[x][a]+g[x][a]\)

现在我们考虑两个点取/不取怎么做。

我们发现每次影响的只有两点 \(lca\) 的子树内,所以考虑倍增。

我们用 \(anc\) 表示 \(x\) 结点上跳 \(2^i\) 层的祖先,那么 \(w[x][i][0/1][0/1]\) 表示 \(x\) 取/不取,\(anc\) 取/不取,\(anc\) 子树 \(-\) \(x\) 子树的最小权覆盖集,这个数组我们可以 \(O(n\log n)\) 预处理出来。

我们每次枚举 \(x\) 和 \(anc\) 的四种状态,然后再枚举 \(x\) 结点上跳 \(2^{i-1}\) 层的祖先的状态,然后直接取个 \(min\) 就可以了。

for(int u=0;u<2;u++)
for(int v=0;v<2;v++){
w[i][j][u][v]=inf;
for(int k=0;k<2;k++)
w[i][j][u][v]=min(w[i][j][u][v],w[i][j-1][u][k]+w[tmp][j-1][k][v]);
}

然后再倍增。我们每次想处理 \(w\) 数组一样一直将 \(x\) 结点和 \(y\) 结点向上跳,然后统计答案。

时间复杂度 \(O(n\log n)\)

\(Code\ Below:\)

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=100000+10;
const ll inf=0x7f7f7f7f7f7f;
int n,m,val[maxn],dep[maxn],fa[maxn][18],head[maxn],to[maxn<<1],nxt[maxn<<1],tot;
ll f[maxn][2],g[maxn][2],w[maxn][18][2][2];char op[10]; inline int read(){
register int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return (f==1)?x:-x;
} inline void addedge(int x,int y){
to[++tot]=y;
nxt[tot]=head[x];
head[x]=tot;
} void dfs1(int x,int Fa){
dep[x]=dep[Fa]+1;
fa[x][0]=Fa;f[x][1]=val[x];
for(int i=head[x],y;i;i=nxt[i]){
y=to[i];
if(y==Fa) continue;
dfs1(y,x);
f[x][0]+=f[y][1];
f[x][1]+=min(f[y][0],f[y][1]);
}
} void dfs2(int x){
for(int i=head[x],y;i;i=nxt[i]){
y=to[i];
if(y==fa[x][0]) continue;
g[y][0]=g[x][1]+f[x][1]-min(f[y][0],f[y][1]);
g[y][1]=min(g[y][0],g[x][0]+f[x][0]-f[y][1]);
dfs2(y);
}
} ll solve(int a,int x,int b,int y){
if(dep[x]<dep[y]) swap(x,y),swap(a,b);
ll nx[2],ny[2],tx[2]={inf,inf},ty[2]={inf,inf};
tx[a]=f[x][a];ty[b]=f[y][b];
for(int i=17;i>=0;i--)
if(dep[fa[x][i]]>=dep[y]){
nx[0]=nx[1]=inf;
for(int j=0;j<2;j++)
for(int k=0;k<2;k++)
nx[j]=min(nx[j],tx[k]+w[x][i][k][j]);
tx[0]=nx[0];tx[1]=nx[1];x=fa[x][i];
}
if(x==y) return tx[b]+g[y][b];
for(int i=17;i>=0;i--)
if(fa[x][i]!=fa[y][i]){
nx[0]=nx[1]=ny[0]=ny[1]=inf;
for(int j=0;j<2;j++)
for(int k=0;k<2;k++){
nx[j]=min(nx[j],tx[k]+w[x][i][k][j]);
ny[j]=min(ny[j],ty[k]+w[y][i][k][j]);
}
tx[0]=nx[0];tx[1]=nx[1];x=fa[x][i];
ty[0]=ny[0];ty[1]=ny[1];y=fa[y][i];
}
int lca=fa[x][0];
ll ans1=f[lca][0]-f[x][1]-f[y][1]+tx[1]+ty[1]+g[lca][0];
ll ans2=f[lca][1]-min(f[x][0],f[x][1])-min(f[y][0],f[y][1])+min(tx[0],tx[1])+min(ty[0],ty[1])+g[lca][1];
return min(ans1,ans2);
} int main()
{
n=read(),m=read();scanf("%s",op);
int a,x,b,y,tmp;
for(int i=1;i<=n;i++) val[i]=read();
for(int i=1;i<n;i++){
x=read(),y=read();
addedge(x,y);addedge(y,x);
}
dfs1(1,0);dfs2(1);
for(int i=1;i<=n;i++){
tmp=fa[i][0];
w[i][0][0][0]=inf;
w[i][0][0][1]=f[tmp][1]-min(f[i][0],f[i][1]);
w[i][0][1][0]=f[tmp][0]-f[i][1];
w[i][0][1][1]=w[i][0][0][1];
}
for(int j=1;j<=17;j++)
for(int i=1;i<=n;i++){
tmp=fa[i][j-1];
if(fa[tmp][j-1]){
fa[i][j]=fa[tmp][j-1];
for(int u=0;u<2;u++)
for(int v=0;v<2;v++){
w[i][j][u][v]=inf;
for(int k=0;k<2;k++)
w[i][j][u][v]=min(w[i][j][u][v],w[i][j-1][u][k]+w[tmp][j-1][k][v]);
}
}
}
while(m--){
x=read(),a=read(),y=read(),b=read();
if(!a&&!b&&(x==fa[y][0]||y==fa[x][0])){
printf("-1\n");
continue;
}
printf("%lld\n",solve(a,x,b,y));
}
return 0;
}

然后就是 \(O(8n\log^2 n)\) 的树剖+线段树维护矩阵的动态 \(dp\) 了。

发现取/不取我们可以用 \(inf\) 和 \(-inf\) 代替,转化为最大权独立集来做。

\(Code\ Below:\)

#include <bits/stdc++.h>
#define int long long
#define lson (rt<<1)
#define rson (rt<<1|1)
using namespace std;
const int maxn=100000+10;
const int inf=1e10;
int n,m,v[maxn],val[maxn],dp[maxn][2],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,num,ans;
int top[maxn],ed[maxn],siz[maxn],son[maxn],fa[maxn],id[maxn],mp[maxn],tim;
char op[5]; struct Matrix{
int mat[2][2];
Matrix(){
memset(mat,0,sizeof(mat));
}
};
Matrix operator * (const Matrix &a,const Matrix &b){
Matrix c;
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
for(int k=0;k<2;k++)
c.mat[i][j]=max(c.mat[i][j],a.mat[i][k]+b.mat[k][j]);
return c;
}
Matrix a[maxn],sum[maxn<<2]; inline void read(int &x){
x=0;bool f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
if(!f) x=-x;
} void print(int x){
if(x<0){putchar('-');x=-x;}
if(x>9) print(x/10);
putchar(x%10+'0');
} inline void add(int x,int y){
to[++tot]=y;
nxt[tot]=head[x];
head[x]=tot;
} void dfs1(int x,int f){
siz[x]=1;fa[x]=f;
int maxson=-1;
for(int i=head[x],y;i;i=nxt[i]){
y=to[i];
if(y==f) continue;
dfs1(y,x);
siz[x]+=siz[y];
if(siz[y]>maxson){
maxson=siz[y];
son[x]=y;
}
}
} void dfs2(int x,int topf){
id[x]=++tim;
mp[tim]=x;
top[x]=topf;
ed[topf]=x;
if(son[x]) dfs2(son[x],topf);
for(int i=head[x],y;i;i=nxt[i]){
y=to[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
} void treedp(int x){
dp[x][0]=0;dp[x][1]=val[x];
for(int i=head[x],y;i;i=nxt[i]){
y=to[i];
if(y==fa[x]) continue;
treedp(y);
dp[x][0]+=max(dp[y][0],dp[y][1]);
dp[x][1]+=dp[y][0];
}
} inline void pushup(int rt){
sum[rt]=sum[lson]*sum[rson];
} void build(int l,int r,int rt){
if(l == r){
int x=mp[l],b[2]={0,val[x]};
for(int i=head[x],y;i;i=nxt[i]){
y=to[i];
if(y==fa[x]||y==son[x]) continue;
b[0]+=max(dp[y][0],dp[y][1]);
b[1]+=dp[y][0];
}
sum[rt].mat[0][0]=sum[rt].mat[0][1]=b[0];
sum[rt].mat[1][0]=b[1];a[x]=sum[rt];
return ;
}
int mid=(l+r)>>1;
build(l,mid,lson);
build(mid+1,r,rson);
pushup(rt);
} void update(int x,int l,int r,int rt){
if(l == r){
sum[rt]=a[mp[l]];
return ;
}
int mid=(l+r)>>1;
if(x <= mid) update(x,l,mid,lson);
else update(x,mid+1,r,rson);
pushup(rt);
} Matrix query(int L,int R,int l,int r,int rt){
if(L <= l && r <= R){
return sum[rt];
}
int mid=(l+r)>>1;
if(L > mid) return query(L,R,mid+1,r,rson);
if(R <= mid) return query(L,R,l,mid,lson);
return query(L,R,l,mid,lson)*query(L,R,mid+1,r,rson);
} void modify(int x,int y){
Matrix u,v;
a[x].mat[1][0]+=y-val[x];val[x]=y;
while(x){
u=query(id[top[x]],id[ed[top[x]]],1,n,1);
update(id[x],1,n,1);
v=query(id[top[x]],id[ed[top[x]]],1,n,1);
x=fa[top[x]];
if(x){
a[x].mat[0][0]+=max(v.mat[0][0],v.mat[1][0])-max(u.mat[0][0],u.mat[1][0]);
a[x].mat[0][1]=a[x].mat[0][0];
a[x].mat[1][0]+=v.mat[0][0]-u.mat[0][0];
}
}
} signed main()
{
read(n),read(m);
scanf("%s",op+1);
int x,c,d,y;
for(int i=1;i<=n;i++){
read(val[i]);
v[i]=val[i];num+=val[i];
}
for(int i=1;i<n;i++){
read(x),read(y);
add(x,y);add(y,x);
}
dfs1(1,0);dfs2(1,1);
treedp(1);build(1,n,1);
Matrix u;
for(int i=1;i<=m;i++){
read(x),read(c),read(y),read(d);
if(c==0&&d==0&&(x==fa[y]||y==fa[x])){
printf("-1\n");
continue;
}
ans=num;
if(c==0) ans+=inf-val[x];
if(d==0) ans+=inf-val[y];
modify(x,(c==0)?inf:-inf);
modify(y,(d==0)?inf:-inf);
u=query(id[1],id[ed[1]],1,n,1);
ans-=max(u.mat[0][0],u.mat[1][0]);
modify(x,v[x]);modify(y,v[y]);
print(ans);putchar('\n');
}
return 0;
}