HDU3341 Lost's revenge(AC自动机&&dp)

时间:2023-03-09 06:31:40
HDU3341 Lost's revenge(AC自动机&&dp)

一看到ACGT就会想起AC自动机上的dp,这种奇怪的联想可能是源于某道叫DNA什么的题的。

题意,给你很多个长度不大于10的小串,小串最多有50个,然后有一个长度<40的串,然后让你将这个这个长度<40的串经过重新排列之后,小串在里面出现的次数总和最大。譬如如果我的小串是AA,AAC,长串是CAAA,我们重新排列成AAAC之后,AA在里面出现了2次,AAC出现了1次,总和是3次,这个数字就是我们要求的。

思路:思路跟HDU4758 walk through squares很像的,首先对每个小串插Trie树,建自动机,然后要做一下预处理,对于每个状态预处理出到达该状态时匹配了多少个小串,方法就是沿着失配边将cnt加起来。然后对于每个状态,如果它不存在某个字母的后继,就沿着失配边走找到存在该字母的后继,这样预处理后,后面的状态转移起来就比较方便。然后定义状态dp[A][C][G][T][sta]表示已经匹配的A,C,G,T对应为A,C,G,T个,在自动机上的状态为sta时所能匹配到的最大的状态数。然后转移就好。

Trick的部分是,虽然A,C,G,T所能产生的状态数最大是11*11*11*11(即40平均分的时候的情况),但是因为有可能有些字母出现40次,所以开的时候要dp[41][41][41][41][550],想到这里我就不知道怎么写了- -0。后来发现其实可以先hash一下,对于sta[i][j][k][t]=用一个数字s代表其状态,然后开一个数组p[s][0~3]存的是该状态对应的A,C,G,T数,然后再转移就好。

不过貌似跑的有点慢,3s多,感觉挺容易TLE的。

#pragma warning(disable:4996)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<cmath>
#include<iostream>
#include<queue>
#define maxn 1500
using namespace std; char str[50][15];
char T[50];
int n; void convert(char *s){
int len = strlen(s);
for (int i = 0; i < len; i++){
if (s[i] == 'A') s[i] = 'a';
else if (s[i] == 'C') s[i] = 'b';
else if (s[i] == 'G') s[i] = 'c';
else s[i] = 'd';
}
} struct Trie{
Trie *fail, *go[4];
int cnt; bool flag;
void init(){
memset(go, 0, sizeof(go)); fail = NULL; cnt = 0; flag = false;
}
}pool[maxn],*root;
int tot; void insert(char *c){
int len = strlen(c); Trie *p = root;
for (int i = 0; i < len; i++){
if (p->go[c[i] - 'a'] != 0) p = p->go[c[i] - 'a'];
else{
pool[tot].init();
p->go[c[i] - 'a'] = &pool[tot++];
p = p->go[c[i] - 'a'];
}
}
p->cnt++;
} void getFail()
{
queue<Trie*> que;
que.push(root);
root->fail = NULL;
while (!que.empty()){
Trie *temp = que.front(); que.pop();
Trie *p = NULL;
for (int i = 0; i < 4; i++){
if (temp->go[i] != NULL){
if (temp == root) temp->go[i]->fail = root;
else{
p = temp->fail;
while (p != NULL){
if (p->go[i] != NULL){
temp->go[i]->fail = p->go[i]; break;
}
p = p->fail;
}
if (p == NULL) temp->go[i]->fail = root;
}
que.push(temp->go[i]);
}
}
}
} int dfs(Trie *p){
if (p == root) return 0;
if (p->flag == true) return p->cnt;
p->cnt += dfs(p->fail); p->flag = true;
return p->cnt;
} int sta[45][45][45][45];
int p[15000][4];
int stanum;
int dp[15000][520];
int A, B, C, D; int main()
{
int ca = 0;
while (cin >> n&&n)
{
tot = 0; root = &pool[tot++]; root->init();
for (int i = 0; i < n; i++){
scanf("%s", str[i]); convert(str[i]);
insert(str[i]);
}
scanf("%s", T); A = B = C = D = 0; int len = strlen(T);
for (int i = 0; i < len; i++){
if (T[i] == 'A') A++;
else if (T[i] == 'C') B++;
else if (T[i] == 'G') C++;
else D++;
}
getFail();
for (int i = 0; i < tot; i++) dfs(&pool[i]);
for (int i = 0; i < tot; i++){
Trie *p = &pool[i];
for (int k = 0; k < 4; k++){
if (p->go[k] == NULL){
Trie *temp = p; temp = temp->fail;
while (temp != NULL){
if (temp->go[k] != NULL) {
p->go[k] = temp->go[k]; break;
}
temp = temp->fail;
}
if (temp == NULL) p->go[k] = root;
}
}
}
stanum = 0;
for (int i = 0; i <= A; i++){
for (int j = 0; j <= B; j++){
for (int k = 0; k <= C; k++){
for (int t = 0; t <= D; t++){
sta[i][j][k][t] = stanum;
p[stanum][0] = i; p[stanum][1] = j;
p[stanum][2] = k; p[stanum][3] = t; stanum++;
}
}
}
}
memset(dp, -1, sizeof(dp)); int a, b, c, d;
dp[0][0] = 0;
for (int i = 0; i < stanum; i++){
a = p[i][0]; b = p[i][1]; c = p[i][2]; d = p[i][3];
for (int j = 0; j < tot; j++){
if (dp[i][j] == -1) continue;
if (a + 1 <= A) dp[sta[a + 1][b][c][d]][pool[j].go[0] - pool] =
max(dp[sta[a + 1][b][c][d]][pool[j].go[0] - pool], dp[i][j] + pool[j].go[0]->cnt); if (b + 1 <= B) dp[sta[a][b + 1][c][d]][pool[j].go[1] - pool] =
max(dp[sta[a][b + 1][c][d]][pool[j].go[1] - pool], dp[i][j] + pool[j].go[1]->cnt); if (c + 1 <= C) dp[sta[a][b][c + 1][d]][pool[j].go[2] - pool] =
max(dp[sta[a][b][c + 1][d]][pool[j].go[2] - pool], dp[i][j] + pool[j].go[2]->cnt); if (d + 1 <= D) dp[sta[a][b][c][d + 1]][pool[j].go[3] - pool] =
max(dp[sta[a][b][c][d + 1]][pool[j].go[3] - pool], dp[i][j] + pool[j].go[3]->cnt);
}
}
int ans = 0; int fin = sta[A][B][C][D];
for (int i = 0; i < tot; i++){
ans = max(ans, dp[fin][i]);
}
printf("Case %d: %d\n", ++ca, ans);
}
return 0;
}