poj 1077 / hdu 1043 Eight 八数码问题:DBFS(双向广度优先搜索)、A*算法、康托展开

时间:2022-08-13 09:51:54

一,八数码问题简介

  1. 编号为1到8的8个正方形滑块被摆成3行3列(有一个格子留空),每次可以把与空格相邻(有公共边)的滑块移动到空格中,而它原来的位置就成了新的空格。给定一个局面,计算从当前状态移动到目标状态的最少步数。如果将八数码按从左到右、从上到下的顺序列出,并用0表示空格(其实也可以用9表示),一个局面就可以表示为一个数字序列,例如从一个状态到另一个状态:
    2 6 4 1 3 7 0 5 8 -> 8 1 5 7 3 6 4 0 2
    本题的目标状态是1 2 3 4 5 6 7 8 0。
  2. 预备知识
    康托展开:八数码中有8个数字,再加上一个用0表示的空格,可以看成是0到8的全排列,总共有9! = 362880个状态。康托展开可以将一个1到n的排列一一对应到整数0到n!-1,这个整数其实就是该排列在所有排列中(按字典序)出现的位置。
    如2 4 1 3,它在1到4的所有全排列中的位置可以这样计算:第一个位置比2小的全排列个数是1*(4-1)!=6;第一个位置是2、第二个位置比4小的全排列个数是(3-1)*(4-2)!=4,这里用3-1是因为比4小的数有3个,但2已经出现在4前面了,不必再计算;第三、第四个位置的贡献都是0。这样就可以算出2 4 1 3对应的数字是6+4+0+0=10。其实就是统计每个数字后边有多少个比它小的数字,乘上它后边剩余位置数的阶乘,再全部加起来。代码:
int getcode(int perm[], int len)
{
    // Cantor expansion: the 0-based rank of perm among all orderings of its
    // len values. For every position, count the later elements that are
    // smaller and weight that count by (number of positions after it)!.
    // fac[k] must already hold k! (the factorial table is filled elsewhere).
    int rank = 0;
    for (int pos = 0; pos < len; ++pos) {
        int smallerAfter = 0;
        for (int k = pos + 1; k < len; ++k) {
            smallerAfter += (perm[k] < perm[pos]) ? 1 : 0;
        }
        rank += smallerAfter * fac[len - 1 - pos];
    }
    return rank;
}

还有另外一个小优化:忽略0之后(如果空格用9表示,则忽略9),每一步移动都不改变排列逆序数的奇偶性。逆序数就是对每个数字统计其后面比它小的数字个数,再全部加起来。原因是:左右移动不改变数字的相对顺序;上下移动相当于把一个数字越过另外两个数字,逆序数只会变化0或±2。目标状态的逆序数为0(偶数),因此初始状态的逆序数为奇数时一定无解,可以直接输出unsolvable。
二,一般解法
就是使用一般的BFS:对每一个状态(如[2 6 4 1 3 7 0 5 8])扩展出它可以转到的状态,用康托展开判断这个状态以前是否出现过;用一个父节点数组存储父节点,用一个方向数组存放从父节点到当前节点的移动方向。找到解之后,按照存储的父节点位置回溯得到整个路径并输出即可。可以加上上边提到的小优化。但是这种方法效率不高,在poj上可以通过,在hdu上会超时。代码如下:

/*************************************************************************
> File Name: 1077.cpp
> Author: gwq
> Mail: gwq5210@qq.com
> Created Time: 2015年08月12日 星期三 10时35分19秒
************************************************************************/


#include <cmath>
#include <ctime>
#include <cctype>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <map>
#include <set>
#include <queue>
#include <stack>
#include <string>
#include <vector>
#include <sstream>
#include <iostream>
#include <algorithm>

// Generic contest shorthands; this program only uses clr and pb.
#define INF (INT_MAX / 10)
// clr(arr, val): memset the whole array; sizeof(arr) is only correct for real
// arrays, not for pointers.
#define clr(arr, val) memset(arr, val, sizeof(arr))
#define pb push_back
#define sz(a) ((int)(a).size())

using namespace std;
typedef set<int> si;
typedef vector<int> vi;
typedef map<int, int> mii;
typedef pair<int, int> pii;
typedef long long ll;

// Not referenced by the code below.
const double esp = 1e-5;

// N is not referenced below; M sizes the per-state tables (9! = 362880 boards).
#define N 5
#define M 400000

// BFS queue cursors (queue slots start at 1) and the number of cells parsed
// from the input line.
int head, tail, orilen;
// Row/column offsets for moving the blank; index i pairs with the letter mm[i]:
// 0 = up, 1 = right, 2 = down, 3 = left.
int dx[] = {-1, 0, 1, 0};
int dy[] = {0, 1, 0, -1};
// st[slot]: the board stored in BFS queue slot `slot` (0 = blank).
int st[M][9];
// Solved board: tiles 1..8 in order, blank in the bottom-right corner.
int goal[9] = {1, 2, 3, 4, 5, 6, 7, 8, 0};
// vis[code]: nonzero once the board with that Cantor code has been enqueued.
int vis[M];
char mm[] = "urdl";
// Raw input line and the board parsed from it.
char strtmp[100];
int oritmp[9];
// fac[k] = k!, filled in by main before searching.
int fac[20];
// fa[slot]: parent queue slot (-1 for the start board);
// direct[slot]: letter of the blank move that led from the parent to this slot.
int fa[M];
char direct[M];

int getcode(int s[])
{
    // Cantor expansion of a 9-cell board (values 0..8) to its rank in [0, 9!).
    // Cell i contributes (#later cells smaller than s[i]) * (8 - i)!.
    // Relies on fac[k] == k! having been prepared by main.
    int code = 0;
    int weightIdx = 8;
    for (int *cur = s; cur != s + 9; ++cur, --weightIdx) {
        int less = 0;
        for (int *nxt = cur + 1; nxt != s + 9; ++nxt) {
            if (*nxt < *cur) {
                ++less;
            }
        }
        code += less * fac[weightIdx];
    }
    return code;
}

int try_insert(int idx)
{
int code = getcode(st[idx]);
if (vis[code]) {
return 0;
} else {
return vis[code] = 1;
}
}

void print(int t[])
{
    // Debug helper: write the 9 cells on one line, each followed by a space.
    for (int *p = t; p < t + 9; ++p) {
        printf("%d ", *p);
    }
    printf("\n");
}

int bfs(void)
{
clr(vis, 0);
head = 1;
tail = 2;
memcpy(st[head], oritmp, sizeof(oritmp));
clr(fa, -1);
fa[1] = -1;
direct[1] = '*';
vis[getcode(st[head])] = 1;
while (head < tail) {
//print(st[head]);
if (memcmp(st[head], goal, sizeof(goal)) == 0) {
return head;
}
int idx = 0;
for (int i = 0; i < 9; ++i) {
if (st[head][i] == 0) {
idx = i;
}
}
int x = idx / 3;
int y = idx % 3;
//printf("%d %d %d\n", idx, x, y);
//getchar();
for (int i = 0; i < 4; ++i) {
int nx = x + dx[i];
int ny = y + dy[i];
int nidx = nx * 3 + ny;
if (nx >= 0 && nx < 3 && ny >= 0 && ny < 3) {
memcpy(st[tail], st[head], sizeof(st[head]));
swap(st[tail][nidx], st[tail][idx]);
fa[tail] = head;
direct[tail] = mm[i];
if (try_insert(tail)) {
++tail;
}
}
}
++head;
}
return 0;
}

/*
 * Accepted on poj 1077, but TLE (time limit exceeded) on hdu 1043.
 */

int main(int argc, char *argv[])
{
    // getcode() only reads fac[0..8]. 13! no longer fits in a 32-bit int, and
    // signed overflow is undefined behaviour, so the table stops at 12!.
    fac[0] = 1;
    for (int i = 1; i <= 12; ++i) {
        fac[i] = fac[i - 1] * i;
    }
    while (fgets(strtmp, 100, stdin) != NULL) {
        // One board per line: digits are tiles and 'x' is the blank (0).
        // Every cell is counted, but at most 9 are stored, so an over-long
        // line cannot write past oritmp[].
        int len = strlen(strtmp);
        orilen = 0;
        for (int i = 0; i < len; ++i) {
            int value;
            if (isdigit(static_cast<unsigned char>(strtmp[i]))) {
                value = strtmp[i] - '0';
            } else if (strtmp[i] == 'x') {
                value = 0;
            } else {
                continue;
            }
            if (orilen < 9) {
                oritmp[orilen] = value;
            }
            ++orilen;
        }
        // Blank or malformed lines would otherwise re-solve stale cells left
        // over from the previous line.
        if (orilen != 9) {
            continue;
        }
        int p = bfs();
        if (p == 0) {
            printf("unsolvable\n");
            continue;
        }
        // Follow the parent slots back to the start, then reverse the moves.
        string ans;
        while (fa[p] != -1) {
            ans.pb(direct[p]);
            p = fa[p];
        }
        reverse(ans.begin(), ans.end());
        cout << ans << endl;
    }
    return 0;
}

为了优化算法,可以预先求出所有状态到目标状态的路径,之后对于每一个输入状态,直接输出对应路径即可。搜索时从目标状态开始,同时记录路径。这种方法适合hdu这种多组数据的情况;poj上数据很少,预处理全部状态的开销反而占主导,用这种方法会超时。
代码如下:

/*************************************************************************
> File Name: 1077pre.cpp
> Author: gwq
> Mail: gwq5210@qq.com
> Created Time: 2015年08月12日 星期三 19时07分45秒
************************************************************************/


#include <cmath>
#include <ctime>
#include <cctype>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <map>
#include <set>
#include <queue>
#include <stack>
#include <string>
#include <vector>
#include <sstream>
#include <iostream>
#include <algorithm>

// Generic contest shorthands; this program only uses clr.
#define INF (INT_MAX / 10)
// clr(arr, val): memset the whole array; sizeof(arr) is only correct for real
// arrays, not for pointers.
#define clr(arr, val) memset(arr, val, sizeof(arr))
#define pb push_back
#define sz(a) ((int)(a).size())

using namespace std;
typedef set<int> si;
typedef vector<int> vi;
typedef map<int, int> mii;
typedef pair<int, int> pii;
typedef long long ll;

// Not referenced by the code below.
const double esp = 1e-5;

// Size of the per-state tables; larger than 9! = 362880.
#define N 600000

// Not referenced by the code below.
int orilen = 9;
// Number of cells parsed from the current input line.
int goalen;
// Solved board with 9 standing for the blank; the precomputing BFS starts here.
int ori[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
// Board parsed from the current input line (9 = blank).
int goal[9];
// fac[k] = k!, filled in by main.
int fac[20];
// path[code]: for a board reached by the BFS, the opposite of each blank move
// taken from the solved board, in search order (main prints it reversed).
string path[N];
// vis[code]: nonzero once the board with that code has been reached.
int vis[N];
// BFS queue of boards (slots start at 1) and its cursors.
int st[N][9], head, tail;
// Row/column offsets for moving the blank; index i pairs with the letter mm[i]:
// 0 = up, 1 = right, 2 = down, 3 = left.
int dx[] = {-1, 0, 1, 0};
int dy[] = {0, 1, 0, -1};
char mm[] = "urdl";
// Raw input line buffer.
char buf[100];

int getcode(int s[])
{
    // Permutation code for a board holding the values 1..9 (9 = blank).
    // Each value v contributes (#smaller values placed before it) * (v - 1)!.
    // That count is at most v - 1, so the result lies in [0, 9!).
    // Relies on fac[k] == k! having been prepared by main.
    int code = 0;
    for (int i = 0; i < 9; ++i) {
        const int v = s[i];
        int smallerBefore = 0;
        for (int j = i - 1; j >= 0; --j) {
            if (s[j] < v) {
                ++smallerBefore;
            }
        }
        code += smallerBefore * fac[v - 1];
    }
    return code;
}

void bfs(void)
{
    // Breadth-first search over every board reachable from the solved board
    // ori (9 = blank); main runs it once before reading input.
    // For each new board the opposite of the blank move just made,
    // mm[(i + 2) % 4], is appended to the parent's path, because the answer
    // has to undo that move. path[code] therefore lists the moves from that
    // board back to ori in reverse order, and main prints it reversed.
    // Boards that are never reached keep vis[code] == 0.
    head = 1;
    tail = 2;
    clr(vis, 0);
    memcpy(st[head], ori, sizeof(ori));
    int code = getcode(st[head]);
    vis[code] = 1;
    path[code] = "";
    while (head < tail) {
        // The dequeued board's code does not change while its neighbours are
        // generated, so compute it once here instead of once per neighbour.
        code = getcode(st[head]);
        int idx = 0;
        for (int i = 0; i < 9; ++i) {
            if (st[head][i] == 9) {
                idx = i;
            }
        }
        int x = idx / 3;
        int y = idx % 3;
        for (int i = 0; i < 4; ++i) {
            int nx = x + dx[i];
            int ny = y + dy[i];
            if (nx < 0 || nx >= 3 || ny < 0 || ny >= 3) {
                continue;
            }
            memcpy(st[tail], st[head], sizeof(st[head]));
            swap(st[tail][nx * 3 + ny], st[tail][idx]);
            int ncode = getcode(st[tail]);
            if (!vis[ncode]) {
                // ncode != code (neighbours differ), so this copy never aliases.
                path[ncode] = path[code];
                path[ncode] += mm[(i + 2) % 4];
                vis[ncode] = 1;
                ++tail;
            }
        }
        ++head;
    }
}

/*
1. 从终点扩展,记录所用的路径
2. hdu1043可以过
*/

int main(int argc, char *argv[])
{
fac[0] = 1;
for (int i = 1; i < 20; ++i) {
fac[i] = fac[i - 1] * i;
}
bfs();
while (fgets(buf, 100, stdin) != NULL) {
int len = strlen(buf);
goalen = 0;
for (int i = 0; i < len; ++i) {
if (isdigit(buf[i])) {
goal[goalen++] = buf[i] - '0';
} else if (buf[i] == 'x') {
goal[goalen++] = 9;
}
}
int code = getcode(goal);
if (vis[code]) {
int len = path[code].size();
for (int i = len - 1; i >= 0; --i) {
printf("%c", path[code][i]);
}
printf("\n");
} else {
printf("unsolvable\n");
}
}

return 0;
}

三,a*算法
先来介绍一下A算法:在BFS算法中,若对每个状态n都设定估价函数f(n)=g(n)+h(n),并且每次从Open表中选节点进行扩展时,都选取f值最小的节点,则称该搜索算法为启发式搜索算法,又称A算法。
在估价函数f(n)中,g(n)是从起始状态到当前状态n的代价,h(n)是从当前状态n到目标状态的估计代价。

A算法中若对估价函数选取不当,则可能找不到解,或者找到的解不是最优解。因此,需要对估价函数做一些限制,使得算法确保找到最优解。A*算法即为对估价函数做了特定限制,且确保找到最优解的A算法。

f*(n) = g*(n) + h*(n),其中,f*(n)是从初始节点S0出发,经过节点n到达目标节点的最小步数(真实值),g*(n)是从S0出发,到达n的最小步数(真实值),h*(n)是从n出发,到达目标节点的最少步数(真实值),而估价函数f(n)是f*(n)的估计值。

f(n)=g(n)+h(n),其中g(n)是从S0到n的实际步数(未必是最优的),因此g(n)>=0且g(n)>=g*(n);h(n)是从n到目标的估计步数。如果估计总是乐观的,即h(n)<=h*(n),则A算法就成为A*算法;若h(n)还满足相容性(见下),搜索会更高效。

h(n)相容是指:对任意从s1转移到s2的情况,都满足h(s1)<=h(s2)+c(s1,s2),其中c(s1,s2)是从s1转移到s2的步数,则称h是相容的。h相容能确保随着搜索一步步往前走,f值单调不减,每个节点第一次被取出扩展时其g值就已是最优的,这样A*能更高效地找到最优解。一般来说,在满足h(n)<=h*(n)的前提下,h(n)的值越大越好。

一般用从当前节点到目标节点的直线距离或者曼哈顿距离作为估值函数h,但也要具体问题具体分析。

下面是伪代码(原文链接在这里):

OPEN = priority queue containing START
CLOSED = empty set
while lowest rank in OPEN is not the GOAL:
current = remove lowest rank item from OPEN
add current to CLOSED
for neighbors of current:
cost = g(current) + movementcost(current, neighbor)
if neighbor in OPEN and cost less than g(neighbor):
remove neighbor from OPEN, because new path is better
if neighbor in CLOSED and cost less than g(neighbor):   (**) 若h相容,这种情况不会发生
remove neighbor from CLOSED
if neighbor not in OPEN and neighbor not in CLOSED:
set g(neighbor) to cost
add neighbor to OPEN
set priority queue rank to g(neighbor) + h(neighbor)
set neighbor's parent to current

reconstruct reverse path from goal to start
by following parent pointers

但实际上我们平常写的A*并不是这个样子,而是和普通的BFS类似,只是将FIFO队列换成优先队列,其他部分类似。

使用A*算法的代码如下,估值函数用的是曼哈顿距离。注意:下面代码中优先队列的比较函数只按h排序(h相同时再按g),并没有使用f=g+h,实际上是贪心的最佳优先搜索,不保证得到最少步数的解;本题接受任意一个可行的移动序列,所以也能通过。

/*************************************************************************
> File Name: 1043_astar.cpp
> Author: gwq
> Mail: gwq5210@qq.com
> Created Time: 2015年08月13日 星期四 16时44分33秒
************************************************************************/


#include <cmath>
#include <ctime>
#include <cctype>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <map>
#include <set>
#include <queue>
#include <stack>
#include <string>
#include <vector>
#include <sstream>
#include <iostream>
#include <algorithm>

// Generic contest shorthands; this program only uses clr.
#define INF (INT_MAX / 10)
// clr(arr, val): memset the whole array; sizeof(arr) is only correct for real
// arrays, not for pointers.
#define clr(arr, val) memset(arr, val, sizeof(arr))
#define pb push_back
#define sz(a) ((int)(a).size())

using namespace std;
typedef set<int> si;
typedef vector<int> vi;
typedef map<int, int> mii;
typedef pair<int, int> pii;
typedef long long ll;

// Not referenced by the code below.
const double esp = 1e-5;

// Size of the per-state tables; larger than 9! = 362880.
#define N 400000

// Start board parsed from input (0 = blank), its cell count, and the blank's
// cell index.
int ori[9];
int orilen;
int oripos;
// Solved board (blank bottom-right). goalpos is not referenced below;
// goalcode is goal's Cantor code, set in main.
int goal[9] = {1, 2, 3, 4, 5, 6, 7, 8, 0};
int goalpos = 8;
int goalcode;
// vis[code]: nonzero once that board has been pushed onto the search queue.
int vis[N];
// Row/column offsets for moving the blank; index i pairs with the letter mm[i]:
// 0 = up, 1 = right, 2 = down, 3 = left.
int dx[] = {-1, 0, 1, 0};
int dy[] = {0, 1, 0, -1};
char mm[] = "urdl";
// Raw input line buffer.
char str[100];
// fac[k] = k!, filled in by main.
int fac[20];
// pre[code]: Cantor code of the board this one was generated from (-1 = start);
// direct[code]: letter of that blank move.
int pre[N];
char direct[N];
// Length of the current input line.
int len;

int getcode(int s[])
{
    // Cantor expansion: rank of the 9-cell board s (values 0..8) among all
    // permutations, in [0, 9!). Uses fac[k] = k! prepared by main.
    int code = 0;
    for (int i = 0; i < 9; ++i) {
        int smaller = 0;
        for (int j = i + 1; j < 9; ++j) {
            smaller += s[j] < s[i];
        }
        code += smaller * fac[8 - i];
    }
    return code;
}

struct Node {
    int perm[9];  // board cells, 0 = blank
    // h: heuristic value, g: moves from the start, x/y: blank row/column,
    // st: Cantor code of perm, pos: blank cell index, f: g + h.
    int h, g, x, y, st, pos, f;

    Node(int s[], int hh, int gg, int xx, int yy, int sst, int ppos)
        : h(hh), g(gg), x(xx), y(yy), st(sst), pos(ppos), f(gg + hh)
    {
        memcpy(perm, s, sizeof(perm));
    }
    Node() {}

    // Debug helper: print the board, starting a new line before each row.
    void output(void)
    {
        for (int r = 0; r < 3; ++r) {
            printf("\n");
            for (int c = 0; c < 3; ++c) {
                printf("%d ", perm[r * 3 + c]);
            }
        }
    }

    // True when this node's code equals the goal board's code.
    bool check(void)
    {
        return st == goalcode;
    }
};

int geth(int s[])
{
    // Manhattan-distance heuristic toward the board 1 2 3 4 5 6 7 8 0:
    // for every tile (the blank, 0, is skipped) add the row and column
    // distance between its current cell and its home cell tile - 1.
    int total = 0;
    for (int cell = 0; cell < 9; ++cell) {
        const int tile = s[cell];
        if (tile == 0) {
            continue;
        }
        const int home = tile - 1;
        total += abs(home / 3 - cell / 3) + abs(home % 3 - cell % 3);
    }
    return total;
}

bool operator <(Node u, Node v)
{
return u.h != v.h ? u.h > v.h : u.g > v.g;
}

int check(int s[])
{
    // Parity of the inversion count among the tiles, with the blank (0)
    // ignored: 0 for an even count, 1 for an odd count.
    int inversions = 0;
    for (int i = 0; i < 9; ++i) {
        if (s[i] == 0) {
            continue;
        }
        for (int j = i + 1; j < 9; ++j) {
            inversions += (s[j] != 0 && s[j] < s[i]);
        }
    }
    return inversions & 1;
}

void bfs(void)
{
priority_queue<Node> q;
clr(vis, 0);
clr(pre, -1);
clr(direct, '*');
int code = getcode(ori);
Node u = Node(ori, geth(ori), 0, oripos / 3, oripos % 3, code, oripos);
vis[code] = 1;
q.push(u);
while (!q.empty()) {
u = q.top();
q.pop();
//u.output();
//getchar();
if (u.check()) {
string path;
int p = u.st;
while (pre[p] != -1) {
path += direct[p];
p = pre[p];
}
reverse(path.begin(), path.end());
printf("%s\n", path.c_str());
return;
}
for (int i = 0; i < 4; ++i) {
int nx = u.x + dx[i];
int ny = u.y + dy[i];
int npos = nx * 3 + ny;
Node v;
if (nx >= 0 && nx < 3 && ny >= 0 && ny < 3) {
memcpy(v.perm, u.perm, sizeof(u.perm));
swap(v.perm[npos], v.perm[u.pos]);
int nh = geth(v.perm);
int ng = u.g + 1;
int ncode = getcode(v.perm);
v.h = nh;
v.g = ng;
v.f = v.h + v.g;
v.x = nx;
v.y = ny;
v.pos = npos;
v.st = ncode;
if (!vis[ncode] && !check(v.perm)) {
pre[ncode] = u.st;
direct[ncode] = mm[i];
q.push(v);
vis[ncode] = 1;
}
}
}
}
cout << "unsolvable" << endl;
}

int main(int argc, char *argv[])
{
    // getcode() reads fac[0..8]. 13! overflows a 32-bit int (undefined
    // behaviour for signed int), so the table is filled only up to 12!.
    fac[0] = 1;
    for (int i = 1; i <= 12; ++i) {
        fac[i] = fac[i - 1] * i;
    }
    goalcode = getcode(goal);
    while (fgets(str, 100, stdin) != NULL) {
        // One board per line: digits are tiles, and 'x' is the blank (stored
        // as 0, with its index kept in oripos). Cells beyond the ninth are
        // counted but not stored, so a long line cannot overrun ori[].
        len = strlen(str);
        orilen = 0;
        int blanks = 0;
        for (int i = 0; i < len; ++i) {
            if (isdigit(static_cast<unsigned char>(str[i]))) {
                if (orilen < 9) {
                    ori[orilen] = str[i] - '0';
                }
                ++orilen;
            } else if (str[i] == 'x') {
                if (orilen < 9) {
                    oripos = orilen;
                    ori[orilen] = 0;
                }
                ++orilen;
                ++blanks;
            }
        }
        // Skip blank or malformed lines; otherwise stale ori/oripos from the
        // previous line would be searched again.
        if (orilen != 9 || blanks != 1) {
            continue;
        }
        // goal has even parity (0 inversions), and moves preserve parity, so
        // an odd start can never reach it.
        if (check(ori)) {
            printf("unsolvable\n");
            continue;
        }
        bfs();
    }
    return 0;
}

四,DBFS双向广度优先搜索算法(参考pdf)
DBFS算法是对BFS算法的一种扩展。BFS算法以广度优先的顺序不断扩展直到遇到目标节点。DBFS算法从起始节点和目标节点两个方向以广度优先的顺序同时扩展,直到一个队列中已经出现了另一个队列中已经扩展了的节点,也就相当于两个扩展方向有了交点,那么可以认为找到了一条路径。

DBFS算法相对于BFS算法,因为采用了双向扩展的方法,搜索树的宽度明显减小,时间和空间效率都有了明显的提高。DBFS每次选择待扩展节点较少的那一边进行扩展,而不是机械地交替扩展。

DBFS框架:

void dbfs()
{
1. 将起始节点放入队列q0 ,将目标节点放入队列q1;
2. 当两个队列都非空时,作如下循环:
1) 如果队列q0里的节点比q1中的少,则扩展队列q0;

2) 否则扩展队列q1
3. 如果队列q0非空,不断扩展q0直到为空;
4. 如果队列q1非空,不断扩展q1直到为空;
}

这道题目的代码如下:

/*************************************************************************
> File Name: 1077dbfs.cpp
> Author: gwq
> Mail: gwq5210@qq.com
> Created Time: 2015年08月12日 星期三 17时09分43秒
************************************************************************/


#include <cmath>
#include <ctime>
#include <cctype>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <map>
#include <set>
#include <queue>
#include <stack>
#include <string>
#include <vector>
#include <sstream>
#include <iostream>
#include <algorithm>

// Generic contest shorthands; this program only uses clr.
#define INF (INT_MAX / 10)
// clr(arr, val): memset the whole array; sizeof(arr) is only correct for real
// arrays, not for pointers.
#define clr(arr, val) memset(arr, val, sizeof(arr))
#define pb push_back
#define sz(a) ((int)(a).size())

using namespace std;
typedef set<int> si;
typedef vector<int> vi;
typedef map<int, int> mii;
typedef pair<int, int> pii;
typedef long long ll;

// Not referenced by the code below.
const double esp = 1e-5;

// Size of the per-state tables; larger than 9! = 362880.
#define M 400000

// Number of cells parsed from input, and the parsed start board (0 = blank).
int orilen;
int oritmp[9];
// Row/column offsets for moving the blank; index i pairs with the letter mm[i]:
// 0 = up, 1 = right, 2 = down, 3 = left.
int dx[] = {-1, 0, 1, 0};
int dy[] = {0, 1, 0, -1};
char mm[] = "urdl";
// Solved board (blank bottom-right); side 1 of the search starts here.
int goal[9] = {1, 2, 3, 4, 5, 6, 7, 8, 0};
// 20-byte input line buffer.
char strtmp[20];
// fac[k] = k!; vis[side][code] is nonzero once that board has been reached
// from that side (side 0 = from the start board, side 1 = from goal).
int fac[20], vis[2][M];
// Per-side BFS queues of boards (slots start at 1) and their cursors.
int st[2][M][9];
int head[2];
int tail[2];
// fa[side][slot]: parent slot (-1 for the root). direct[side][slot]: move
// letter; side 1 stores the opposite move, so it reads toward goal.
int fa[2][M];
char direct[2][M];

int getcode(int s[])
{
    // Rank of board s among all orderings of its 9 values (Cantor expansion):
    // the sum over positions of (#later smaller cells) * (cells after it)!.
    // Relies on fac[k] == k! having been prepared by main.
    int code = 0;
    int i = 0;
    while (i < 9) {
        int below = 0;
        for (int j = 8; j > i; --j) {
            if (s[j] < s[i]) {
                ++below;
            }
        }
        code += below * fac[8 - i];
        ++i;
    }
    return code;
}

// With the blank (0) ignored, a move does not change the parity of the
// permutation. Returns that parity: 0 for an even inversion count, 1 for odd.
int check(int s[])
{
    int tiles[9];
    int n = 0;
    for (int i = 0; i < 9; ++i) {
        if (s[i] != 0) {
            tiles[n++] = s[i];
        }
    }
    int parity = 0;
    for (int i = 0; i < n; ++i) {
        for (int j = i + 1; j < n; ++j) {
            if (tiles[i] > tiles[j]) {
                parity ^= 1;
            }
        }
    }
    return parity;
}

void dbfs(void)
{
head[0] = 1;
head[1] = 1;
tail[0] = 2;
tail[1] = 2;
clr(vis, 0);
memcpy(st[0][1], oritmp, sizeof(oritmp));
memcpy(st[1][1], goal, sizeof(goal));
fa[0][1] = -1;
fa[1][1] = -1;
direct[0][1] = '*';
direct[1][1] = '*';
int code0 = getcode(st[0][1]);
int code1 = getcode(st[1][1]);
vis[0][code0] = 1;
vis[1][code1] = 1;
while (head[0] < tail[0] && head[1] < tail[1]) {
int no = 0;
if (head[0] == tail[0]) {
no = 1;
} else if (head[1] == tail[1]) {
no = 0;
} else {
if (tail[0] - head[0] < tail[1] - head[1]) {
no = 0;
} else {
no = 1;
}
}
int ono = 1 - no;
int code = getcode(st[no][head[no]]);
//printf("\n%d..%d", code, no);
for (int i = 0; i < 9; ++i) {
if (i % 3 == 0) {
//printf("\n");
}
//printf("%d ", st[no][head[no]][i]);
}
if (vis[ono][code]) {
//printf("done\n");
string ans;
int pos = head[no];
if (no) {
for (int i = 0; i < tail[0]; ++i) {
int tmp = getcode(st[0][i]);
if (tmp == code) {
pos = i;
break;
}
}
} else {
pos = head[no];
}
int p = pos;
//printf("\n%d.....\n", pos);
while (fa[0][p] != -1) {
ans += direct[0][p];
p = fa[0][p];
}
reverse(ans.begin(), ans.end());
//cout << ans << endl;
if (no == 0) {
for (int i = 0; i < tail[1]; ++i) {
int tmp = getcode(st[1][i]);
if (tmp == code) {
pos = i;
break;
}
}
} else {
pos = head[no];
}
p = pos;
//printf("%d.....%d\n", pos, head[no]);
while (fa[1][p] != -1) {
ans += direct[1][p];
p = fa[1][p];
}
printf("%s\n", ans.c_str());
return;
}
int idx = 0;
for (int i = 0; i < 9; ++i) {
if (st[no][head[no]][i] == 0) {
idx = i;
break;
}
}
int x = idx / 3;
int y = idx % 3;
for (int i = 0; i < 4; ++i) {
int nx = x + dx[i];
int ny = y + dy[i];
int nidx = nx * 3 + ny;
if (nx >= 0 && nx < 3 && ny >= 0 && ny < 3) {
memcpy(st[no][tail[no]], st[no][head[no]], sizeof(st[no][head[no]]));
swap(st[no][tail[no]][idx], st[no][tail[no]][nidx]);
int ncode = getcode(st[no][tail[no]]);
if (!vis[no][ncode]) {
vis[no][ncode] = 1;
fa[no][tail[no]] = head[no];
direct[no][tail[no]] = mm[no ? (i + 2) % 4 : i];
++tail[no];
}
}
}
++head[no];
}
printf("unsolvable\n");
}

int main(int argc, char *argv[])
{
    // getcode() needs 0!..8!. Stopping at 12! keeps every entry within a
    // 32-bit int; 13! overflows, which is undefined behaviour for signed int.
    fac[0] = 1;
    for (int i = 1; i <= 12; ++i) {
        fac[i] = fac[i - 1] * i;
    }
    // The global strtmp is only 20 bytes, so fgets would split a line with
    // extra spacing (e.g. two spaces between the 9 cells = 25 chars) across
    // two calls, and each half would be solved as a separate puzzle.
    // Use a roomier local buffer instead.
    char line[256];
    while (fgets(line, sizeof(line), stdin) != NULL) {
        // One board per line: digits are tiles and 'x' is the blank (0).
        // Every cell is counted, but at most 9 are stored, so a long line
        // cannot overrun oritmp[].
        int cells = 0;
        for (const char *c = line; *c != '\0'; ++c) {
            int value;
            if (isdigit(static_cast<unsigned char>(*c))) {
                value = *c - '0';
            } else if (*c == 'x') {
                value = 0;
            } else {
                continue;
            }
            if (cells < 9) {
                oritmp[cells] = value;
            }
            ++cells;
        }
        orilen = cells;
        // Blank or malformed lines would otherwise re-solve stale cells left
        // over from the previous line.
        if (cells != 9) {
            continue;
        }
        // Moves preserve the inversion parity and goal's parity is even.
        if (check(oritmp)) {
            printf("unsolvable\n");
            continue;
        }
        dbfs();
    }
    return 0;
}