2018.11.24 poj3415Common Substrings(后缀数组+单调栈)

时间:2022-11-26 18:11:25

传送门

常数实在压不下来(蒟蒻开O(3)都过不了)。

但有正确性233.

首先肯定得把两个字符串接在一起。

相当于heightheightheight数组被height&lt;kheight&lt;kheight<k的分成了几段,统计每段的贡献。

考虑段中每个heightheightheight作为最小值出现的次数就行了。

于是我们用单调栈求出每个位置向左右分别能延展到的最远下标然后统计答案就行了。

细节有点多。

代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#define ri register int
using namespace std;
typedef long long ll;
const int N=2e5+5;
int n,m,len,ban,sa[N],sa2[N],ht[N],rk[N],L[N],R[N],stk[N],top=0;
ll ans=0;
char s[N],t[N];
inline void Sort(){
	static int cnt[N];
	for(ri i=1;i<=m;++i)cnt[i]=0;
	for(ri i=1;i<=n;++i)++cnt[rk[i]];
	for(ri i=2;i<=m;++i)cnt[i]+=cnt[i-1];
	for(ri i=n;i;--i)sa[cnt[rk[sa2[i]]]--]=sa2[i];
}
inline void getsa(){
	for(ri i=1;i<=n;++i)rk[i]=(int)s[i],sa2[i]=i;
	m=130,Sort();
	for(ri w=1,p=0;m^n;w<<=1,p=0){
		for(ri i=n-w+1;i<=n;++i)sa2[++p]=i;
		for(ri i=1;i<=n;++i)if(sa[i]>w)sa2[++p]=sa[i]-w;
		Sort(),swap(rk,sa2),rk[sa[1]]=p=1;
		for(ri i=2;i<=n;++i)rk[sa[i]]=(sa2[sa[i]]==sa2[sa[i-1]]&&sa2[sa[i]+w]==sa2[sa[i-1]+w])?p:++p;
		m=p;
	}
	for(ri i=1,j,k=0;i<=n;ht[rk[i++]]=k)for(k?--k:k,j=sa[rk[i]-1];s[i+k]==s[j+k];++k);
}
inline int min(int a,int b){return a<b?a:b;}
inline void calc(int l,int r){
	static int ca[N],cb[N];
	for(ri i=min(l-1,0);i<=r;++i)L[i]=R[i]=i,ca[i]=cb[i]=0;
	for(ri i=l;i<=r;++i)ca[i]=ca[i-1],cb[i]=cb[i-1],sa[i]<=len?++ca[i]:++cb[i];
	for(ri i=l+1;i<=r;++i){
		while(top&&ht[stk[top]]>ht[i])R[stk[top--]]=i-1;
		stk[++top]=i;
	}
	while(top)R[stk[top--]]=r;
	for(ri i=r;i>l;--i){
		while(top&&ht[stk[top]]>=ht[i])L[stk[top--]]=i+1;
		stk[++top]=i;
	}
	while(top)L[stk[top--]]=l;
	for(ri i=l+1;i<=r;++i)ans+=(ll)(ht[i]-ban+1)*((ll)(ca[R[i]]-ca[i-1])*(cb[i-1]-cb[L[i]-2])+(ll)(cb[R[i]]-cb[i-1])*(ca[i-1]-ca[L[i]-2]));
}
inline void solve(){
	for(ri l=1,r;l<=n;l=r+1){
		while(ht[l]<ban&&l<=n)++l;
		if(l>n)break;
		r=l;
		while(ht[r+1]>=ban&&r<n)++r;
		calc(l-1,r);
	}
	cout<<ans<<'\n';
}
int main(){
	while(scanf("%d",&ban)&&ban){
		ans=0,scanf("%s%s",s+1,t+1),n=strlen(s+1),len=strlen(t+1),s[++n]='@';
		for(ri i=1;i<=len;++i)s[++n]=t[i];
		len=n-len-1,getsa();
		solve();
	}
	return 0;
}