UOJ#395. 【NOI2018】你的名字 字符串,SAM,线段树合并

时间:2023-03-08 17:13:18
UOJ#395. 【NOI2018】你的名字   字符串,SAM,线段树合并

原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ395.html

题解

记得同步赛的时候这题我爆0了,最暴力的暴力都没调出来。

首先我们看看 68 分怎么做

——求两个串的本质不同的公共子串个数。

  它是一个模板题,然而我当时并不会,甚至连SAM都忘了怎么写QAQ。

再简化一下:如何求一个串的本质不同的子串个数。

  给串建一个SAM,把所有节点代表的字符串个数(也就是 Max(x) - Max(fa(x)) 加起来就好了。

回到上一个问题。

假设这两个串分别是 S,T 。对 T 建个SAM。

对于T的SAM,考虑对于它的任何一个节点 x ,算出 x 的 Right 集合代表的所有前缀与 S 的所有前缀的 LCS 的最大值(也就是这个节点代表的状态能在 S 上匹配的最长长度),设为 val(x)。然后对于所有 x 把 $(1,val(x)] \cap (Max(fa(x)),Max(x)]$ 的长度加起来就好了。

那么如何求那个最长的匹配长度?对 S 建一个 SAM,然后用 T 在 S 的 SAM 上走一遍,找到 T 的每一个前缀的 最长的是 S 的子串的后缀  然后 T 的 SAM 上的一个节点的 val 就是他在 parent 树上的所有后代节点的 Max 。

由于 S 的 SAM 可以预先建好,所以询问一个 T 串的复杂度是 $O(|T|)$ 的。

那么 S 有 [L,R] 的限制呢?

线段树合并预处理一下 S 的 SAM 的每一个节点的 Right 集合。

修改一下求最长的匹配长度的过程,保证走转移边的时候在 [L,R] 中有匹配。

注意这里有一个易错点:我们匹配失败跳 father 的时候,不能直接 len' = Max(father) ,只能不断减一。原因是在 len 不断减一的过程中可能会找到匹配,而直接跳 father 会漏过这个匹配。然而出题人数据出的很水,没注意到这个东西还是有96分!

至此,我们得到了一个 $O((|S|+\sum |T|)\log |S|)$ 的做法。

但是,由于在 SAM 上遍历节点暴力跳祖先的复杂度是 $O(n\sqrt n)$ 的,然后加个线段树合并多个 $\log$ ,总复杂度 $O(n\sqrt n \log  n)$ 的可以通过原题数据……wft??(UOJ Hack数据过不去的)

代码

#include <bits/stdc++.h>
#define clr(x) memset(x,0,sizeof (x))
using namespace std;
typedef long long LL;
LL read(){
LL x=0,f=0;
char ch=getchar();
while (!isdigit(ch))
f|=ch=='-',ch=getchar();
while (isdigit(ch))
x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return f?-x:x;
}
const int N=500005*4;
int n,m,q;
char s[N];
struct Node{
int Next[26],fa,Max,pos;
};
namespace Seg{
const int S=N*35;
int ls[S],rs[S],cnt=0;
void Ins(int &rt,int L,int R,int x){
if (!rt)
rt=++cnt;
if (L==R)
return;
int mid=(L+R)>>1;
if (x<=mid)
Ins(ls[rt],L,mid,x);
else
Ins(rs[rt],mid+1,R,x);
}
int Merge(int a,int b,int L,int R){
if (!a||!b)
return a+b;
int rt=++cnt;
if (L<R){
int mid=(L+R)>>1;
ls[rt]=Merge(ls[a],ls[b],L,mid);
rs[rt]=Merge(rs[a],rs[b],mid+1,R);
}
return rt;
}
int Query(int rt,int L,int R,int xL,int xR){
if (!rt||R<xL||L>xR||xL>xR)
return 0;
if (xL<=L&&R<=xR)
return 1;
int mid=(L+R)>>1;
return Query(ls[rt],L,mid,xL,xR)
|Query(rs[rt],mid+1,R,xL,xR);
}
}
namespace SAM{
Node t[N];
int root,last,size;
int rt[N],id[N];
void Init(){
while (size){
clr(t[size].Next);
t[size].fa=t[size].Max=t[size].pos=rt[size]=0;
size--;
}
root=last=size=1;
}
void extend(int c,int ps){
int p=last,np=++size,q,nq;
t[np].Max=t[p].Max+1,t[np].pos=ps;
Seg::Ins(rt[np],1,n,ps);
for (;p&&!t[p].Next[c];p=t[p].fa)
t[p].Next[c]=np;
if (!p)
t[np].fa=1;
else {
q=t[p].Next[c];
if (t[p].Max+1==t[q].Max)
t[np].fa=q;
else {
nq=++size;
t[nq]=t[q],t[nq].Max=t[p].Max+1,t[nq].pos=ps;
t[np].fa=t[q].fa=nq;
for (;p&&t[p].Next[c]==q;p=t[p].fa)
t[p].Next[c]=nq;
}
}
last=np;
}
void Sort(){
static int tax[N];
for (int i=0;i<=size;i++)
tax[i]=0;
for (int i=1;i<=size;i++)
tax[t[i].Max]++;
for (int i=1;i<=size;i++)
tax[i]+=tax[i-1];
for (int i=1;i<=size;i++)
id[tax[t[i].Max]--]=i;
}
void build(){
Sort();
for (int i=size;i>1;i--){
int x=id[i],f=t[x].fa;
rt[f]=Seg::Merge(rt[f],rt[x],1,n);
}
}
}
namespace sam{
Node t[N];
int root,last,size;
int id[N],val[N];
void Init(){
while (size){
clr(t[size].Next);
t[size].fa=t[size].Max=t[size].pos=val[size]=0;
size--;
}
root=last=size=1;
}
void extend(int c,int ps){
int p=last,np=++size,q,nq;
t[np].Max=t[p].Max+1,t[np].pos=ps;
for (;p&&!t[p].Next[c];p=t[p].fa)
t[p].Next[c]=np;
if (!p)
t[np].fa=1;
else {
q=t[p].Next[c];
if (t[p].Max+1==t[q].Max)
t[np].fa=q;
else {
nq=++size;
t[nq]=t[q],t[nq].Max=t[p].Max+1,t[nq].pos=ps;
t[np].fa=t[q].fa=nq;
for (;p&&t[p].Next[c]==q;p=t[p].fa)
t[p].Next[c]=nq;
}
}
last=np;
}
void Sort(){
static int tax[N];
for (int i=0;i<=size;i++)
tax[i]=0;
for (int i=1;i<=size;i++)
tax[t[i].Max]++;
for (int i=1;i<=size;i++)
tax[i]+=tax[i-1];
for (int i=1;i<=size;i++)
id[tax[t[i].Max]--]=i;
}
LL solve(){
Sort();
LL ans=0;
for (int i=size;i>1;i--){
int x=id[i],f=t[x].fa;
val[f]=max(val[x],val[f]);
ans+=max(0,t[x].Max-max(t[f].Max,val[x]));
}
return ans;
}
}
int main(){
scanf("%s",s+1);
n=strlen(s+1),q=read();
SAM::Init();
for (int i=1;i<=n;i++)
SAM::extend(s[i]-'a',i);
SAM::build();
SAM::t[0].Max=-1;
while (q--){
scanf("%s",s+1);
m=strlen(s+1);
sam::Init();
int L=read(),R=read();
int x=1,len=0;
for (int i=1;i<=m;i++){
int c=s[i]-'a',nowx=sam::size+1;
sam::extend(c,i);
while (x){
int nx=SAM::t[x].Next[c];
if (nx&&Seg::Query(SAM::rt[nx],1,n,L+len,R))
break;
if ((--len)==SAM::t[SAM::t[x].fa].Max)
x=SAM::t[x].fa;
}
if (!x)
x=1,len=0;
else {
x=SAM::t[x].Next[c];
sam::val[nowx]=++len;
}
}
printf("%lld\n",sam::solve());
}
return 0;
}