BZOJ 3530: [Sdoi2014]数数 [AC自动机 数位DP]

时间:2023-03-08 20:35:39
BZOJ 3530: [Sdoi2014]数数 [AC自动机 数位DP]

3530: [Sdoi2014]数数

题意:\(\le N\)的不含模式串的数字有多少个,\(n=|N| \le 1200\)


考虑数位DP

对于长度\(\le n\)的,普通套路DP\(g[i][j]\)即可

对于长度\(=n\)的,需要考虑天际线,\(f[i][j][0/1]\)表示从高开始i位走到节点j,是否卡上界的方案数

需要注意的是前导0的处理,不能出现前导0,所以\(f[0]\)往外转移的时候不能走0

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int N=2005, P=1e9+7;
typedef long long ll;
inline int read(){
char c=getchar();int x=0,f=1;
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
return x*f;
} int n, m;
char a[N], s[N];
inline void mod(int &x) {if(x>=P) x-=P;}
namespace ac{
struct meow{int ch[10], fail, val;}t[N];
int sz;
void insert(char *s) {
int len=strlen(s+1), u=0;
for(int i=1; i<=len; i++) {
int c=s[i]-'0';
if(!t[u].ch[c]) t[u].ch[c] = ++sz;
u=t[u].ch[c];
}
t[u].val=1;
}
int q[N], head, tail;
void build() {
head=tail=1;
for(int i=0; i<10; i++) if(t[0].ch[i]) q[tail++]=t[0].ch[i];
while(head!=tail) {
int u=q[head++];
t[u].val |= t[t[u].fail].val;
for(int i=0; i<10; i++) {
int &v=t[u].ch[i];
if(!v) v = t[t[u].fail].ch[i];
else t[v].fail = t[t[u].fail].ch[i], q[tail++]=v;
}
}
}
int f[N][N][2], g[N][N], ans;
void dp() {
g[0][0]=1;
for(int i=0; i<n; i++)
for(int u=0; u<=sz; u++) if(!t[u].val) {
for(int k=0; k<10; k++) if(!t[t[u].ch[k]].val) {
if(i==0 && k==0) continue;
mod(g[i+1][ t[u].ch[k] ] += g[i][u]);
}
}
for(int i=1; i<n; i++) for(int j=0; j<=sz; j++) mod(ans += g[i][j]); f[0][0][1]=1; //f[0][0][0]=1;
for(int i=0; i<n; i++) { //printf("\niii %d %d\n",i, a[i+1]-'0');
for(int u=0; u<=sz; u++) if(!t[u].val) { //printf("uuu %d %d %d\n",u,f[i][u][0],f[i][u][1]);
for(int k=0; k<10; k++) if(!t[t[u].ch[k]].val) {
if(i==0 && k==0) continue;
int v=t[u].ch[k]; //printf("v %d %d\n",k,v);
mod(f[i+1][v][0] += f[i][u][0]);
if(k < a[i+1]-'0') mod(f[i+1][v][0] += f[i][u][1]);
if(k == a[i+1]-'0') mod(f[i+1][v][1] += f[i][u][1]);
}
}
}
//for(int i=1; i<=n; i++) for(int j=0; j<=sz; j++) printf("f %d %d %d %d\n",i,j,f[i][j][0],f[i][j][1]);
for(int i=0; i<=sz; i++) {
mod(ans += f[n][i][0]);
mod(ans += f[n][i][1]);
}
printf("%d", ans);
}
}
int main() {
freopen("in","r",stdin);
scanf("%s",a+1); n=strlen(a+1);
m=read();
for(int i=1; i<=m; i++) scanf("%s",s+1), ac::insert(s);
ac::build();
ac::dp();
}