初探 插头DP

时间:2023-03-09 16:20:44
初探 插头DP

因为这题,气得我火冒三丈!

这数据是不是有问题啊!我用cin代替scanf后居然就AC了(本来一直卡在Test 18)!导致我调(对)试(排)了一个小时!!

UPD:后来细细想想,问题很可能出在用 scanf 读入时数组开小了:scanf("%s") 会在读到的字符后面再写入一个 '\0',所以读一个长为\(n\)的字符串需要\(str[n + 1]\)的空间。原来的 char A[MaxN][MaxN] 在某一行恰好有 MaxN 个字符时放不下这个结束符,就会越界写。

题目

题意:给定一个网格,求恰好经过每个可行格子('.')一次的回路有多少条。

插头DP

一直没有时间学习,最近膜拜了一下cdq的《基于连通性状态压缩的动态规划问题》,然后写了一道裸题。

其实也很好写嘛,不过在转移的时候要万分小心。另外还要注意记录下第一个可行格子所在的行,以及最后一个可行格子的位置(回路只允许在最后一个可行格子处闭合)。

代码

我的写法是把竖直的那一段轮廓线(即从左边进入当前格子的插头)放在状态的最后一位(也就是最高的那一位,第 m 位)。

//#define debug
//#define local

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <string>
#include <utility>
#include <assert.h>
using namespace std;

// ep(...): trace to stderr in debug builds, no-op otherwise.
#ifdef debug
#define ep(...) fprintf(stderr, __VA_ARGS__)
#else
#define ep(...) assert(true)
#endif

typedef long long i64;

const int MaxN = 12;
int n, m;   // rows / columns of the current test case

// Grid rows. std::string is used instead of `char A[MaxN][MaxN]`: scanf("%s")
// stores a terminating '\0' after the characters it reads, so a row of MaxN
// characters needs MaxN + 1 bytes and would overflow a char[MaxN] buffer.
string A[MaxN];

// Capacity of one state table (distinct contour states kept per step).
const int MaxHashTable = 30001;

// Contour encoding: every position holds a 2-bit plug code
//   0 = no plug, 2 = '(', 3 = ')'   (a matching pair = the two ends of one path).
// Positions 0..m-1 are the plugs pointing down out of each column; position m
// (the highest bits) is the plug entering the current cell from its left.
// Written as inline functions instead of macros: same results, but
// type-checked and not sensitive to what follows on the #define line.

// 2-bit code at position i of x.
inline int getbit(int x, int i) { return (x >> (i << 1)) & 3; }
// x with code b OR-ed into position i (position i is expected to be empty).
inline int copybit(int x, int i, int b) { return x | (b << (i << 1)); }
// x with position i cleared.
inline int clrbit(int x, int i) { return x & ~(3 << (i << 1)); }
// x with positions i and j cleared.
inline int clrbit2(int x, int i, int j) { return clrbit(clrbit(x, j), i); }
// x with the low bit of position i flipped: '(' (2) <-> ')' (3).
inline int revbit(int x, int i) { return x ^ (1 << (i << 1)); }

// State table: maps a contour state to the number of ways to reach it.
// Entries are stored densely in A[0..n); per-bucket chains of Links index
// into A. A bucket is valid only while pool_mark[bucket] == pool_counter,
// which makes clear() O(1).
struct Hash
{
    pair<int, i64> A[MaxHashTable];
    int n;

    // Drop states that cannot be continued past the end of a row.
    //   lastone == true : only the empty state survives (the loop was closed
    //                     at the last free cell and no plug is left open).
    //   lastone == false: drop states with a plug at position m, i.e. a path
    //                     leaving the grid through the right border.
    // Removal swaps the last entry into the hole without fixing the chains, so
    // this table must not be pushed to again before the next clear().
    void update(bool lastone)
    {
        int i = 0;
        if (lastone)
        {
            while (i < n)
            {
                if (A[i].first)
                {
                    n --;
                    A[i] = A[n];
                }
                else i ++;
            }
        }
        else
        {
            while (i < n)
            {
                int s = A[i].first;
                if (getbit(s, m))
                {
                    n --;
                    A[i] = A[n];
                }
                else i ++;
            }
        }
    }

    // Sum of the counts of all stored states.
    i64 total()
    {
        i64 ret = 0;
        for (int i = 0; i < n; i ++)
            ret += A[i].second;
        return ret;
    }

    // Chain node: `to` is an index into A.
    struct Link
    {
        int to;
        Link *next;
    } pool[MaxHashTable], *pool_cur, *info[MaxHashTable];
    int pool_counter, pool_mark[MaxHashTable];

    // Empty the table: bump the generation so every bucket reads as empty.
    void clear()
    {
        pool_counter ++;
        n = 0;
        pool_cur = pool;
    }

#ifdef debug
    void print()
    {
        for (int i = 0; i < n; i ++)
            ep("%d %lld\n", A[i].first, A[i].second);
        ep("\n");
    }
#endif

    // Add `value` ways to state `st`, inserting it if it is not present yet.
    void push(const int &st, const i64 &value)
    {
        int hash = st % MaxHashTable;
        if (pool_mark[hash] != pool_counter)
        {
            // First touch of this bucket since clear(): reset its chain.
            pool_mark[hash] = pool_counter;
            info[hash] = NULL;
        }
        for (Link *p = info[hash]; p; p = p->next)
        {
            if (A[p->to].first == st)
            {
                A[p->to].second += value;
                return;
            }
        }
#ifdef debug
        // Checked before writing A[n] / *pool_cur, not after.
        assert(n < MaxHashTable);
#endif
        pool_cur->to = n;
        pool_cur->next = info[hash];
        info[hash] = pool_cur ++;
        A[n ++] = make_pair(st, value);
    }
};

// Position i of s holds '('. Scan positions i+1..m-1 (position m is not
// scanned) and return the index of the ')' matching it.
int getbracket0(const int &s, const int &i)
{
    int cnt = 1;
    for (int j = i + 1; j < m; j ++)
    {
        int t = getbit(s, j);
        if (t)
        {
            if (t & 1) cnt --;
            else cnt ++;
        }
        if (! cnt) return j;
    }
    assert(false);
    return -1;
}

// With one unmatched ')' pending just right of position i-1, scan positions
// i-1 down to 0 and return the index of the '(' matching it.
int getbracket1(const int &s, const int &i)
{
    int cnt = -1;
    for (int j = i - 1; j >= 0; j --)
    {
        int t = getbit(s, j);
        if (t)
        {
            if (t & 1) cnt --;
            else cnt ++;
        }
        if (! cnt) return j;
    }
    assert(false);
    return -1;
}

// Reads test cases "n m" followed by n grid rows ('.' = free cell) until input
// ends, and prints for each the number of closed loops that visit every free
// cell exactly once.
int main()
{
#if defined(debug) || defined(local)
    freopen("a.in", "r", stdin);
    freopen("a.out", "w", stdout);
#endif

#ifndef debug
    while (true)
#else
    for (int a = 0; a == 0; a = 1)
#endif
    {
        if (! (cin >> n >> m)) break;

        // First row containing a free cell, and the last free cell (row/col).
        int lastrow = -1, lastcol = -1, firstrow = -1;
        for (int i = 0; i < n; i ++)
        {
            cin >> A[i];
            for (int j = 0; j < m; j ++)
                if (A[i][j] == '.')
                {
                    if (firstrow == -1) firstrow = i;
                    lastrow = i;
                    lastcol = j;
                }
        }

        if (lastrow == -1)
        {
            printf("0\n");
            continue;
        }

        static Hash dp[2];
        dp[0].clear(), dp[1].clear();
        int cur = 0, next = 1;
        dp[cur].push(0, 1);

        for (int i = firstrow; i <= lastrow; i ++)
        {
            for (int j = 0; j < m; j ++)
            {
                for (int k = 0; k < dp[cur].n; k ++)
                {
                    int s = dp[cur].A[k].first;
                    i64 val = dp[cur].A[k].second;

                    int U = getbit(s, j);   // plug entering from above
                    int L = getbit(s, m);   // plug entering from the left

                    if (A[i][j] == '.')
                    {
                        if (L && U)
                        {
                            L &= 1, U &= 1;   // now 0 = '(', 1 = ')'
                            if (!L && !U)
                            {
                                // "((": the ')' matching U becomes '('.
                                dp[next].push(revbit( clrbit2(s, j, m), getbracket0(s, j)), val);
                            }
                            else if (L ^ U)
                            {
                                // ")(" joins two different paths.
                                // "()" closes the loop: allowed only at the last free cell.
                                if (L || (i == lastrow && j == lastcol))
                                    dp[next].push(clrbit2(s, j, m), val);
                            }
                            else
                            {
                                // "))": the '(' matching L becomes ')'.
                                dp[next].push(revbit( clrbit2(s, j, m), getbracket1(s, j)), val);
                            }
                        }
                        else if (L)
                        {
                            // Extend the left plug to the right, or turn it downwards.
                            dp[next].push(copybit(s, m, L), val);
                            dp[next].push(clrbit( copybit(s, j, L), m), val);
                        }
                        else if (U)
                        {
                            // Extend the upper plug downwards, or turn it to the right.
                            dp[next].push(s, val);
                            dp[next].push(clrbit( copybit(s, m, U), j), val);
                        }
                        else
                        {
                            // Start a new path: '(' going down, ')' going right.
                            dp[next].push(copybit( copybit(s, m, 3), j, 2), val);
                        }
                    }
                    else if (!U && !L)
                    {
                        // Blocked cell: only states with no plug into it survive.
                        dp[next].push(s, val);
                    }
                }
                swap(cur, next);
                dp[next].clear();
                ep("for %d %d\n", i, j);
#ifdef debug
                dp[cur].print();
#endif
            }
            dp[cur].update(i == lastrow);
        }
        ep("final:\n");
        ep("%lld\n", dp[cur].total());
        cout << dp[cur].total() << endl;
    }
    return 0;
}