【JZOJ5363】【NOIP2017提高A组模拟9.14】生命之树 Trie+启发式合并

时间:2022-02-13 10:18:17

题面

【JZOJ5363】【NOIP2017提高A组模拟9.14】生命之树 Trie+启发式合并
【JZOJ5363】【NOIP2017提高A组模拟9.14】生命之树 Trie+启发式合并

45

在比赛中,我只想到了45分的暴力。
对于一个树中点对,相当于在他们的LCA及其祖先加上这个点对的贡献。
那么这个可以用dfs序+树状数组来维护。

100

想法

我想到了可能要用trie树来维护这个字符串的公共前缀。
然后这就面临了两个很严重的问题。
1.我对于每个子树都要建一个trie,所以这是\(O(n^2)\)的复杂度。
我想到了要合并儿子的信息,但是这个合并似乎是无法存储。
2.我还要处理xor的问题,我的想法是在trie上的每个结点上维护一个蜜汁容器。
可能这要用到xor的某些运算法则,但我并不知道如何实现。

然后正解就恰好解决了我这两个问题。

zrO lhy Orz

1.trie数可以使用启发式合并,那么时间复杂度就降为\(O(nlogn)\)
合并的时候,可以抛弃掉子树的信息,所以空间复杂度不会超过\(O(n)\)
2.xor我们考虑按位分治,那么我们给trie上的每个结点维护一个\(cnt[i][j][0/1]\)
表示这个结点\(i\)为根的子树内,有多少个数的二进制下第\(j\)位为\(0/1\)的个数。
这个在trie合并时可以简单合并。同时在合并的时候就能利用这个\(cnt\)统计答案。
具体就不展开,也就是\(cnt(*)(*)[0]*cnt(*)(*)[1]\)之类的。

Code

#include<bits/stdc++.h>
#define ll long long
#define fo(i,x,y) for(int i=x;i<=y;i++)
#define fd(i,x,y) for(int i=x;i>=y;i--)
#define ln(x,y) int(log(x)/log(y))
using namespace std;
const char* fin="1.in";
const char* fout="1.out";
const int inf=0x7fffffff;
int read(){
int x=0;
char ch=getchar();
while (ch<'0' || ch>'9') ch=getchar();
while (ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
return x;
}
const int maxn=100007,maxm=maxn*2,maxt=600007;
int fi[maxn],la[maxm],ne[maxm],tot;
void add_line(int a,int b){
tot++;
ne[tot]=fi[a];
la[tot]=b;
fi[a]=tot;
}
void add(int a,int b){add_line(a,b);add_line(b,a);}
int n,a[maxn],rt[maxn],si[maxn],num;
ll ans[maxn];
struct node{
int ne[26],cnt[17][2],cn[17][2];
}ac[maxt];
int b[maxn][2],hd,tl;
void dfs(int p,int _p,int de,ll &z){
fo(i,0,25){
int x=ac[p].ne[i],y=ac[_p].ne[i];
if (x){
fo(j,0,16) z+=1ll*ac[x].cnt[j][0]*(ac[_p].cnt[j][1]-ac[y].cnt[j][1])*(1<<j)*de,z+=1ll*ac[x].cnt[j][1]*(ac[_p].cnt[j][0]-ac[y].cnt[j][0])*(1<<j)*de;
if (y) dfs(x,y,de+1,z);
}
}
fo(j,0,16) z+=1ll*ac[p].cn[j][0]*ac[_p].cnt[j][1]*de*(1<<j),z+=1ll*ac[p].cn[j][1]*ac[_p].cnt[j][0]*de*(1<<j);
}
void link(int p,int _p){
fo(i,0,16){
ac[p].cn[i][0]+=ac[_p].cn[i][0];
ac[p].cn[i][1]+=ac[_p].cn[i][1];
ac[p].cnt[i][0]=ac[p].cn[i][0];
ac[p].cnt[i][1]=ac[p].cn[i][1];
}
fo(i,0,25){
int x=ac[p].ne[i],y=ac[_p].ne[i];
if (x && y) link(x,y);
else if (y) ac[p].ne[i]=y;
if (ac[p].ne[i]){
int x=ac[p].ne[i];
fo(i,0,16){
ac[p].cnt[i][0]+=ac[x].cnt[i][0];
ac[p].cnt[i][1]+=ac[x].cnt[i][1];
}
}
}
}
void merge(int x,int y,ll &z){
dfs(rt[x],rt[y],0,z);
link(rt[x],rt[y]);
si[x]+=si[y];
}
int main(){
freopen(fin,"r",stdin);
freopen(fout,"w",stdout);
scanf("%d",&n);
fo(i,1,n) scanf("%d",&a[i]);
fo(i,1,n){
char ch=getchar();
while (ch<'a' || ch>'z') ch=getchar();
rt[i]=++num;
int x=rt[i];
while (ch>='a' && ch<='z'){
fo(k,0,16) ac[x].cnt[k][a[i]>>k&1]++;
int y=ch-'a';
si[i]++;
x=ac[x].ne[y]=++num;
ch=getchar();
}
fo(k,0,16) ac[x].cnt[k][a[i]>>k&1]++,ac[x].cn[k][a[i]>>k&1]++;
}
fo(i,1,n-1) add(read(),read());
hd=tl=0;
b[++tl][0]=1;
while (hd++<tl){
int v=b[hd][0],from=b[hd][1];
for(int k=fi[v];k;k=ne[k])
if (la[k]!=from) b[++tl][0]=la[k],b[tl][1]=v;
}
fd(i,tl,1){
int v=b[i][0],from=b[i][1];
int mx=v;
for(int k=fi[v];k;k=ne[k])
if (la[k]!=from){
ans[v]+=ans[la[k]];
if (!mx || si[mx]<si[la[k]]) mx=la[k];
}
if (mx!=v) merge(mx,v,ans[v]);
for(int k=fi[v];k;k=ne[k])
if (la[k]!=from && la[k]!=mx){
merge(mx,la[k],ans[v]);
}
rt[v]=rt[mx];
}
fo(i,1,n) printf("%lld\n",ans[i]);
return 0;
}