bzoj 4034: [HAOI2015]T2

时间:2023-03-09 23:03:33
bzoj 4034: [HAOI2015]T2

4034: [HAOI2015]T2

Description

有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个

操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。

Input

第一行包含两个整数 N, M 。表示点数和操作数。

接下来一行 N 个整数,表示树中节点的初始权值。

接下来 N-1 行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。
再接下来 M 行,每行分别表示一次操作。其中第一个数表示该操
作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。

Output

对于每个询问操作,输出该询问的答案。答案之间用换行隔开。

Sample Input

5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3

Sample Output

6
9
13

HINT

对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不

会超过 10^6 。

Source

鸣谢bhiaibogf提供

题解:

树链剖分模板题。

刚学树剖,贴个模板吧。。

#include<stdio.h>
#include<iostream>
using namespace std;
const int N=30005;
#define p1 (p<<1)
#define p2 (p<<1|1)
char c[5];
int Q,n,i,x,y,k,id[N],fa[N],top[N],dep[N],s[N],t[N<<2],T[N<<2];
int tot,head[N],to[N<<1],Next[N<<1];
inline void read(int &v){
char ch,fu=0;
for(ch='*'; (ch<'0'||ch>'9')&&ch!='-'; ch=getchar());
if(ch=='-') fu=1, ch=getchar();
for(v=0; ch>='0'&&ch<='9'; ch=getchar()) v=v*10+ch-'0';
if(fu) v=-v;
}
void add(int x,int y)
{
to[tot]=y;
Next[tot]=head[x];
head[x]=tot++;
}
inline void dfs(int x,int pre)
{
s[x]=1;
for(int i=head[x];i!=-1;i=Next[i])
if(to[i]!=pre)
{
dep[to[i]]=dep[x]+1;
fa[to[i]]=x;
dfs(to[i],x);
s[x]+=s[to[i]];
}
}
inline void Dfs(int x,int v)
{
id[x]=++k;
top[x]=v;
int son=0,i;
for(i=head[x];i!=-1;i=Next[i])
if(dep[to[i]]>dep[x]&&s[to[i]]>s[son]) son=to[i];
if(!son) return;
Dfs(son,v);
for(i=head[x];i!=-1;i=Next[i])
if(dep[to[i]]>dep[x]&&to[i]!=son) Dfs(to[i],to[i]);
}
void update(int l,int r,int x,int y,int p)
{
if(l==r)
{
t[p]=T[p]=y;
return;
}
int mid=(l+r)>>1;
if(x<=mid) update(l,mid,x,y,p1);else update(mid+1,r,x,y,p2);
t[p]=max(t[p1],t[p2]);
T[p]=T[p1]+T[p2];
}
int Max(int l,int r,int x,int y,int p)
{
if(x<=l&&r<=y) return t[p];
int mid=(l+r)>>1,ans=-1e9;
if(x<=mid) ans=Max(l,mid,x,y,p1);
if(y>mid) ans=max(ans,Max(mid+1,r,x,y,p2));
return ans;
}
int Sum(int l,int r,int x,int y,int p)
{
if(x<=l&&r<=y) return T[p];
int mid=(l+r)>>1,ans=0;
if(x<=mid) ans=Sum(l,mid,x,y,p1);
if(y>mid) ans+=Sum(mid+1,r,x,y,p2);
return ans;
}
int solvemax(int x,int y)
{
int ans=-1e9;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=max(ans,Max(1,n,id[top[x]],id[x],1));
x=fa[top[x]];
}
if(id[x]>id[y]) swap(x,y);
ans=max(ans,Max(1,n,id[x],id[y],1));
return ans;
}
int solvesum(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=Sum(1,n,id[top[x]],id[x],1);
x=fa[top[x]];
}
if(id[x]>id[y]) swap(x,y);
ans+=Sum(1,n,id[x],id[y],1);
return ans;
}
int main()
{
read(n);
for(i=1;i<=n;i++) head[i]=-1;
for(i=1;i<=n<<2;i++) t[i]=-1e9;
for(i=1;i<n;i++)
{
read(x),read(y);
add(x,y);
add(y,x);
}
dfs(1,0);
Dfs(1,1);
for(i=1;i<=n;i++)
{
read(x);
update(1,n,id[i],x,1);
}
read(Q);
while(Q--)
{
scanf("%s",c);read(x),read(y);
if(c[1]=='M') printf("%d\n",solvemax(x,y));else
if(c[1]=='S') printf("%d\n",solvesum(x,y));else
update(1,n,id[x],y,1);
}
return 0;
}