算法应用场景
当数据巨大时、A*算法将会更加优化的进行搜索。
但是如果存在没有解的情况下、A*算法会被卡成bfs(甚至更慢、因为A*使用优先队列、导致每次查询时间复杂度$O(logn)$,bfs查询时间复杂度$O(1)$)
并且当数据量小的时候、A*算法无法突出其优势。
算法流程
while(!q.empty())
{
t = q.top(); //获得小根堆堆顶
if (t == end) //当终点第一次出队时退出
{
break;
}
for (t的所有邻边)
{
将 邻边 进队。
}
}
当终点第一出队时,即为答案。
需要特殊保证:
估计距离 <= 实际距离。
即估计的距离一定要比实际的距离短。
证明A*算法正确性
反证法:
设第一次出队时、不是最小值。设这个点为t。通过条件、可以推出,$dis[t] > dis[ans]$
因为实际距离 >= 估计距离。
由于实际距离 >= 估计距离,所以dis[ans] > f(ans) + g(ans);
所以,堆头一定存在一个比这个值小的值。但是因为已经是从堆头拿出来的、造成了矛盾。所以A*算法是正确的。
A*算法如何计算估值函数
依靠不同题目的要求,没有固定的方法,但是尽量估计值接近实际值。
例题
八数码问题
我们可以通过经典的估值函数计算估计值。
剩下的就可以直接套A*模板。
Code:
#include <iostream>
#include <cstdio>
#include <unordered_map>
#include <cstring>
#include <algorithm>
#include <string>
#include <queue>
#include <cmath>
using namespace std;
const int fx[4] = {0, 0, -1, 1};
const int fy[4] = {-1, 1, 0, 0};
string eds = "12345678x";
char cz[5] = "lrud";
bool check_where(int x, int y)
{
return x < 0 ? false : x >= 3 ? false : y < 0 ? false : y < 3;
}
int f(string str)
{
int res = 0;
int len = str.size();
for (int i = 0; i < len; i ++ )
{
if (str[i] != 'x')
{
int nx = str[i] - '1';
res += abs(i / 3 - nx / 3) + abs(i % 3 - nx % 3);
}
}
return res;
}
string astar(string start)
{
unordered_map<string, int> dis;
unordered_map<string, pair<char, string> > p;
priority_queue<pair<int, string>, vector<pair<int, string> >, greater<pair<int, string> > >q;
dis[start] = 0;
q.push(make_pair(f(start), start));
while(!q.empty())
{
pair<int, string> t = q.top();
q.pop();
string ne = t.second;
if (ne == eds)
{
break;
}
int x, y;
for (int i = 0; i < 9; i ++ )
{
if (ne[i] == 'x')
{
x = i / 3, y = i % 3;
break;
}
}
string src = ne;
for (int i = 0; i < 4; i ++ )
{
int nx = x + fx[i];
int ny = y + fy[i];
if (!check_where(nx, ny))
{
continue;
}
ne = src;
swap(ne[x * 3 + y], ne[nx * 3 + ny]);
if (dis.count(ne) == 0 || dis[ne] > dis[src] + 1)
{
dis[ne] = dis[src] + 1;
p[ne] = make_pair(cz[i], src);
q.push(make_pair(dis[ne] + f(ne), ne));
}
}
}
string res;
int cfs = 1;
while (start != eds)
{
cfs ++;
res += p[eds].first;
eds = p[eds].second;
}
reverse(res.begin(), res.end());
return res;
}
int main()
{
char c;
string s, sa;
while (cin >> c)
{
s += c;
if (c != 'x')
{
sa += c;
}
}
int cnt = 0;
for (int i = 0; i < 8; i ++ )
{
for (int j = i; j < 8; j ++ )
{
if (sa[i] > sa[j])
{
cnt ++;
}
}
}
if (cnt % 2 == 1)
{
printf("unsolvable");
}
else
{
cout << astar(s);
}
return 0;
}