hdu 4622 Reincarnation trie树+树状数组/dp

时间:2023-01-28 07:04:20

题意:给你一个字符串和m个询问,问你l,r这个区间内出现过多少字串。

连接:http://acm.hdu.edu.cn/showproblem.php?pid=4622

网上也有用后缀数组搞得、

思路(虎哥):用字典树把每一个字符串对应成一个整数 相同的字符串对应到相同的整数上

把所用的串对应的整数放在一个数组里 比如书字符串s[l...r]对应的整数是 k

那么二维数组 [l][r] 就等于k

假设一个对应好的二维数组  左下角是原点

3     4     5     2

2     3     4     0

1     6     0     0

2     0     0     0

这样求解 从l到r的不同字符串的个数 其实就是求 从[l][r] 到右下角所在的矩阵所包含不同整数的个数(不包括0)

这里需要一定的去重处理 处理后是

-1    0    1     1

0     1    1     0

1     1    0     0

1     0    0     0

然后一边dp就可以求出所有答案(因为是求一个矩形矩阵,所以我用了一个二维树状数组做的,感觉好慢- -。)

注意常用的next[26]写法的字典树有可能超内存 要优化

代码:

 #include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <stdlib.h>
#include <vector>
#include <queue>
#define loop(s,i,n) for(i = s;i < n;i++)
#define cl(a,b) memset(a,b,sizeof(a))
#define lowbit(x) x&-x
using namespace std; const int maxm = ;
const int maxn = ; int head[maxm]; struct node
{
int next,v;
}g[maxm]; int cnt;
int c[maxn][maxn];
int pos[maxm];
char s[maxn];
int len;
void add(int a,int b,int val)
{
int i,j;
for(i = a;i <= len;i += lowbit(i))
{
for(j = b;j <= len;j += lowbit(j))
c[i][j] += val;
}
} int sum(int a,int b)
{
int res = ;
int i,j;
for(i = a;i > ;i -= lowbit(i))
{
for(j = b;j > ;j -= lowbit(j))
res+=c[i][j];
}
return res;
}
void insert(int &u,int key)
{
int i;
for(i = head[u];i != -;i = g[i].next)
{
int v;
v = g[i].v;
if(v == key)
{
u = i;
return ;
}
}
cnt++;
g[cnt].next = head[u];
g[cnt].v = key;
head[u] = cnt;
u = cnt;
return ;
}
int main()
{
int t;
//freopen("out.txt","w",stdout);
scanf("%d",&t);
while(t--)
{
scanf("%s",s);
len = strlen(s);
cl(c,);
cl(pos,-);
cl(head,-);
int j,i,loc;
loc = ;
cnt = ;
loop(,j,len)
{
loc = ;
for(i = j;i >= ;--i)
{
int key = s[i]-'a'; insert(loc,key); if(pos[loc] == -)
{
add(i+,j+,);
pos[loc] = i+;
}
else if(pos[loc] < (i+))
{
add(i+,j+,);
add(pos[loc],j+,-);
pos[loc] = i+;
}
}
}
int m,l,r;
scanf("%d",&m);
while(m--)
{
scanf("%d %d",&l,&r);
printf("%d\n",sum(l-,)+sum(len,r)-sum(l-,r)-sum(len,));
} }
return ;
}

这个是dp的。

 #include<stdio.h>
#include<iostream>
#include<algorithm>
#include<string.h>
using namespace std;
struct list
{
int next;
int v;
}g[];
int head[];
int vis[];
int ct;
char str[];
int c[][];
void add(int x,int &u)
{
for(int i=head[u];i!=-;i=g[i].next)
{
if(x==g[i].v)
{
u=i;return ;
}
}
g[ct].next=head[u];
g[ct].v=x;
head[u]=ct;
u=ct++;
}
int main()
{
int n,p;
int T,i,j;
scanf("%d%*c",&T);
while(T--)
{
ct=;
memset(head,-,sizeof(head));
memset(vis,-,sizeof(vis));
memset(c,,sizeof(c));
gets(str);
n=strlen(str);
for(j=;j<n;j++)
{
p=;
for(i=j;i>=;i--)
{
add(str[i],p);
if(vis[p]==-)
{
vis[p]=i;
c[i][j]++;
}
else if(vis[p]<i)
{
c[i][j]++;
c[vis[p]][j]--;
vis[p]=i;
}
}
}
// for(j=1;j<n;j++)c[0][j]+=c[0][j-1];
for(j=;j<n;j++)
{
for(i=j;i>=;i--)
{
c[i][j]+=c[i+][j]+c[i][j-]-c[i+][j-];
}
}
int q,a,b;
cin>>q;
while(q--)
{
scanf("%d%d%*c",&a,&b);
printf("%d\n",c[a-][b-]);
}
}
return ;
}