BZOJ 4516: [Sdoi2016]生成魔咒——后缀数组、并查集

时间:2023-03-09 16:43:52
BZOJ 4516: [Sdoi2016]生成魔咒——后缀数组、并查集

传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=4516

题意

一开始串为空,每次往串后面加一个字符,求本质不同的子串的个数,可以离线。即长度为N的字符串,对于每一个前缀,求本质不同的子串的个数。(字符集为int)

做法

首先,我们把所有的数字离散化。然后考虑后缀数组,我们把字符串倒过来,于是很神奇地,往最后加字符变成了添加一个后缀。

我们知道,在求出SA之后,一个字符串的本质不同的子串的个数等于(子串的个数)-(重复计数的个数)等于\(\frac{N*(N+1)}{2}-\sum height[i]\)。

那么,我们可以先把整个(倒过来的)字符串的SA求出来,然后尝试模拟插入后缀这个过程。

每插入一个后缀(插入的位置是这个后缀的rank),在所有已经插入的后缀中,找到它的前驱和后继。又因为两个后缀的LCP,即重复计数的子串个数等于height上的区间最小值,所以我们可以在\(O(1)\)的时间去用height上的最小值更新答案。

实现

这个当然是可以用ST表+线段树/平衡树维护的,但是因为可以离线(求后缀数组本来就要求离线...),我们有更简单的做法。

考虑从后往前做,将所有插入操作变成删除操作。当我们删除一个后缀时,只需要维护它在未删除的后缀中的前驱、后继到它之间的最小值即可(实际上并不需要显式地求出前驱后继)。

可以用双向链表/并查集实现。

我写了并查集的做法:将每个后缀看成一条边,连接的点代表两个后缀之间的LCP(height),删除后缀时连上相应的边,在unite()的时候维护min值。

其他

并查集也是可以离线求前驱后继的,在unite()的时候维护该段最左、最右点即可。

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN=100050, INF=0x3f3f3f3f;
void rd(int &x){
x=0; int ch=getchar();
while(ch<'0'||'9'<ch) ch=getchar();
while('0'<=ch&&ch<='9') x=x*10+ch-'0', ch=getchar();
}
int N, M, up;
int w[MAXN], rk[MAXN], ht[MAXN], sa[MAXN], c[MAXN];
int fa[MAXN], sz[MAXN], mn[MAXN], tmp[MAXN];
ll sum, ans[MAXN];
inline void chkmn(int &x, int y){if(x>y)x=y;}
void init(){
for(int i=0; i<=N; ++i) fa[i]=i, mn[i]=ht[i];
}
int find(int x){return x==fa[x]?x:(fa[x]=find(fa[x]));}
void unite(int x, int y){
x=find(x); y=find(y);
if(x==y) return;
if(sz[x]>sz[y]) swap(x,y);
fa[x]=y;
sum+=max(mn[x],mn[y]);
chkmn(mn[y],mn[x]);
}
inline int wcmp(int *x, int a, int b, int k){
return x[a]==x[b]&&x[a+k]==x[b+k];
}
inline void rsort(int *x, int *y){
memset(c, 0, sizeof(c));
for(int i=0; i<N; ++i) c[x[i]]++;
for(int i=1; i<up; ++i) c[i]+=c[i-1];
for(int i=N-1; i>=0; --i) sa[--c[x[y[i]]]]=y[i];
}
void getsa(){
int *x=rk, *y=ht;
for(int i=0; i<N; ++i) x[i]=w[i], y[i]=i;
rsort(x,y);
for(int k=1, p=0; p<N; k<<=1, up=p){
p=0;
for(int i=N-k; i<N; ++i) y[p++]=i;
for(int i=0; i<N; ++i) if(sa[i]>=k) y[p++]=sa[i]-k;
rsort(x,y); swap(x,y); p=0; x[sa[0]]=p++;
for(int i=1; i<N; ++i)
if(wcmp(y,sa[i],sa[i-1],k)) x[sa[i]]=p-1;
else x[sa[i]]=p++;
}
for(int i=0; i<N; ++i) rk[sa[i]]=i;
ht[0]=0;
for(int i=0, j, p=0; i<N-1; ++i){
for((p?p--:0),j=sa[rk[i]-1];w[i+p]==w[j+p];++p);
ht[rk[i]]=p;
}
}
int main(){
rd(N);
for(int i=1; i<=N; ++i) rd(w[N-i]), tmp[M++]=w[N-i];
sort(tmp,tmp+M); M=unique(tmp,tmp+M)-tmp;
for(int i=0; i<N; ++i) w[i]=lower_bound(tmp,tmp+M,w[i])-tmp+1;
M=N++; up=N+1; getsa();
sum=(ll)M*(M+1)/2;
for(int i=0; i<N; ++i) sum-=ht[i];
ans[N]=sum; init();
for(int i=0; i<M; ++i){
sum-=M-i;
unite(rk[i],rk[i]+1);
ans[M-i]=sum;
}
for(int i=2; i<=N; ++i) printf("%lld\n", ans[i]);
return 0;
}