题解:
就是按照常规的合并
期望有一点麻烦
首先计算全部的和
再减去有多少种
具体看看http://blog.****.net/PoPoQQQ/article/category/2542261这个博客吧
代码:
#include<bits/stdc++.h>
using namespace std;
#define pa t[x].fa
#define lc t[x].ch[0]
#define rc t[x].ch[1]
const int N=5e4+;
typedef long long ll;
int read()
{
char c=getchar();int x=,f=;
while(c<''||c>''){if (c=='-')f=-;c=getchar();}
while(c>=''&&c<=''){x=x*+c-'';c=getchar();}
return x*f;
}
struct node
{
int ch[],fa,rev;
ll add,lsum,rsum,sum,exp,w,size;
}t[N];
int wh(int x){return t[pa].ch[]==x;}
int isRoot(int x){return t[pa].ch[]!=x&&t[pa].ch[]!=x;}
void update(int x)
{
t[x].size=t[lc].size+t[rc].size+;
t[x].sum=t[lc].sum+t[rc].sum+t[x].w;
t[x].lsum=t[lc].lsum+t[x].w*(t[lc].size+)+t[rc].lsum+t[rc].sum*(t[lc].size+);
t[x].rsum=t[rc].rsum+t[x].w*(t[rc].size+)+t[lc].rsum+t[lc].sum*(t[rc].size+);
t[x].exp=t[lc].exp+t[rc].exp
+t[lc].lsum*(t[rc].size+)+t[rc].rsum*(t[lc].size+)
+t[x].w*(t[lc].size+)*(t[rc].size+);
}
ll cal1(ll x){return x*(x+)/;}
ll cal2(ll x){return x*(x+)*(x+)/;}
void paint(int x,ll d)
{
t[x].w+=d;
t[x].add+=d;
t[x].sum+=d*t[x].size;
t[x].lsum+=d*cal1(t[x].size);
t[x].rsum+=d*cal1(t[x].size);
t[x].exp+=d*cal2(t[x].size);
}
void rever(int x)
{
swap(lc,rc);
swap(t[x].lsum,t[x].rsum);
t[x].rev^=;
}
void pushDown(int x)
{
if (t[x].rev)
{
rever(lc);
rever(rc);
t[x].rev=;
}
if (t[x].add)
{
paint(lc,t[x].add);
paint(rc,t[x].add);
t[x].add=;
}
}
void rotate(int x)
{
int f=t[x].fa,g=t[f].fa,c=wh(x);
if (!isRoot(f)) t[g].ch[wh(f)]=x;t[x].fa=g;
t[f].ch[c]=t[x].ch[c^];t[t[f].ch[c]].fa=f;
t[x].ch[c^]=f;t[f].fa=x;
update(f);update(x);
}
int st[N],top;
void splay(int x)
{
top=;st[++top]=x;
for (int i=x;!isRoot(i);i=t[i].fa) st[++top]=t[i].fa;
for (int i=top;i>=;i--) pushDown(st[i]);
for (;!isRoot(x);rotate(x))
if (!isRoot(pa)) rotate(wh(x)==wh(pa)?pa:x);
}
void Access(int x)
{
for (int y=;x;y=x,x=pa)
{
splay(x);
rc=y;
update(x);
}
}
void MakeR(int x){Access(x);splay(x);rever(x);}
int FindR(int x){Access(x);splay(x);while(lc) x=lc;return x;}
void Link(int x,int y){MakeR(x);t[x].fa=y;}
void Cut(int x,int y)
{
MakeR(x);Access(y);splay(y);
t[y].ch[]=t[x].fa=;
update(y);
}
void Add(int x,int y,int d)
{
if (FindR(x)!=FindR(y)) return;
MakeR(x);Access(y);splay(y);
paint(y,d);
}
ll gcd(ll a,ll b){return b==?a:gcd(b,a%b);}
void Que(int x,int y)
{
if (FindR(x)!=FindR(y)){puts("-1");return;}
MakeR(x);Access(y);splay(y);
ll a=t[y].exp,b=t[y].size*(t[y].size+)/;
ll g=gcd(a,b);
printf("%lld/%lld\n",a/g,b/g);
}
int n,Q,a,op,x,y,d;
int main()
{
n=read();Q=read();
for (int i=;i<=n;i++)
{
a=read();
t[i].size=;
t[i].w=t[i].lsum=t[i].rsum=t[i].sum=t[i].exp=a;
}
for (int i=;i<=n-;i++) x=read(),y=read(),Link(x,y);
while(Q--)
{
op=read();x=read();y=read();
if (op==) if (FindR(x)==FindR(y)) Cut(x,y);
if (op==) if (FindR(x)!=FindR(y)) Link(x,y);
if (op==) d=read(),Add(x,y,d);
if (op==) Que(x,y);
}
}