AHOI2013 差异 【后缀数组】

时间:2022-01-22 10:00:12

题目分析:

求出height以后很明显跨越最小height的一定贡献是最小height,所以对于区间找出最小height再将区间对半分。

代码:

 #include<bits/stdc++.h>
using namespace std; const int maxn = ;
const int N = ; int n;
char str[maxn]; int sa[maxn],rk[maxn],X[maxn],Y[maxn];
int height[maxn],h[maxn],RMQ[maxn][],pos[maxn][]; int chk(int x,int k){
return rk[sa[x]]==rk[sa[x-]]&&rk[sa[x]+(<<k)]==rk[sa[x-]+(<<k)];
} void getsa(){
for(int i=;i<n;i++) X[str[i]]++;
for(int i=;i<=N;i++) X[i] += X[i-];
for(int i=n-;i>=;i--) sa[X[str[i]]--] = i;
for(int i = , num = ;i <= n;i++)
rk[sa[i]] = (str[sa[i]] == str[sa[i-]]?num:++num);
rk[sa[]] = ;
for(int k=;(<<k-)<=n;k++){
for(int i=;i<=N;i++) X[i] = ;
for(int i=n-(<<k-);i<n;i++) Y[i-n+(<<k-)+]=i;
for(int i=,j=(<<k-)+;i<=n;i++)
if(sa[i]>=(<<k-))Y[j++]=sa[i]-(<<k-);
for(int i=;i<n;i++) X[rk[i]]++;
for(int i=;i<=N;i++) X[i]+=X[i-];
for(int i=n;i>=;i--) sa[X[rk[Y[i]]]--] = Y[i];
int num = ; Y[sa[]] = ;
for(int i=;i<=n;i++) Y[sa[i]] = (chk(i,k-)?num:++num);
for(int i=;i<n;i++) rk[i] = Y[i];
if(num == n) break;
}
}
void getheight(){
for(int i=;i<n;i++){
if(i) h[i] = max(,h[i-]-); else h[i] = ;
if(rk[i] == ) continue;
int comp = sa[rk[i]-];
while(str[comp+h[i]] == str[i+h[i]])h[i]++;
}
for(int i=;i<n;i++) height[rk[i]] = h[i];
for(int i=;i<=n;i++) RMQ[i][] = height[i],pos[i][] = i;
for(int k=;(<<k)<=n;k++){
for(int i=;i<=n;i++){
if(i+(<<k-)>n) RMQ[i][k]=RMQ[i][k-],pos[i][k]=pos[i][k-];
else {
if(RMQ[i][k-]<RMQ[i+(<<k-)][k-]) pos[i][k] = pos[i][k-];
else pos[i][k] = pos[i+(<<k-)][k-];
RMQ[i][k] = min(RMQ[i][k-],RMQ[i+(<<k-)][k-]);
}
}
}
}
int getLCP(int L,int R){
if(L > R) swap(L,R);
if(L == R) return n-sa[L];
L++;
int k = ; while((<<k+)<=R-L+)k++;
if(RMQ[L][k]<RMQ[R-(<<k)+][k]) return pos[L][k];
else return pos[R-(<<k)+][k];
} long long ans = ; void divide(int l,int r){
if(l == r) return;
int ps = getLCP(l,r);
ans -= 2ll*(ps-l)*(r-ps+)*height[ps];
divide(l,ps-); divide(ps,r);
} void work(){
n = strlen(str);
getsa();
getheight();
for(int i=;i<=n;i++) ans += 1ll*i*i-i;
for(int i=;i<=n;i++) ans += 1ll*i*(n-i);
divide(,n);
printf("%lld\n",ans);
} int main(){
scanf("%s",str);
work();
return ;
}