挑战程序竞赛系列(32):4.5 A*与IDA*

2019-05-26 09:29:07 浏览数 (1)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1434666

挑战程序竞赛系列(32):4.5 A*与IDA*

详细代码可以fork下Github上leetcode项目,不定期更新。

练习题如下:

  • POJ 3523: The morning after Halloween
  • POJ 2032: Square Carpets

POJ 3523: The morning after Halloween

又是一道抽象状态和BFS搜索的好题。思路:状态是什么?显然是abc在board中的位置,且由于board的长和宽至多为16,所以用4位就可以表示坐标x或y,这就意味着一个ghost可以用8位表示,三个ghost只需要24位,而所可能出现的位置总共有256 * 256 *256个状态。

既然可以表达所有的状态,我们就能从初始状态一步步通过BFS来接近终止状态,中间搜索时,无非需要判断哪些状态是非法的,以及如何根据当前状态生成下一状态。

基本步骤:

  • 定义状态(3个ghost所处的位置)
  • 生成下一状态
  • 判断是否合法
  • 合法加入队列
  • 直到搜到最终状态

此题的难点在于状态时还是较多,如果单纯的用BFS暴力搜索非常耗时,怎么加快速度?采用A*在原来的距离上增加一个下界函数,不同状态的下界函数不同,此时我们有dstart -> next hnext -> end,表示从初始状态到当前状态的距离和当前状态到最终状态的距离之和。这样我们就可以结合Dijkstra算法,对这些距离进行排序,始终选择较短的优先构造,这样必然会比传统BFS最先抵达目标状态。

此处的下界函数,简单认为是每个独立的ghost到终点的最短距离中的最大值(至少需要这些步骤,不可能比这还短了)。

代码如下:

代码语言:javascript复制
    static final int INF = 1 << 29;
    static final int[][] direction = {{1, 0},{-1, 0},{0, 1},{0, -1}, {0, 0}};
    int W;
    int H;
    int[] dist = new int[256 * 256 * 256];
    int[][] d = new int[3][256];
    void solve() {
        while (true){
            int w = ni();
            int h = ni();
            int n = ni();

            W = w;
            H = h;

            if (w   h   n == 0) break;
            char[][] board = new char[h][w];
            for (int i = 0; i < h;   i){
                for (int j = 0; j < w;   j){
                    board[i][j] = nc();
                }
            }

            //初始状态
            int init = 0;
            int goal = 0;

            for (int i = 0; i < h;   i){
                for (int j = 0; j < w;   j){
                    if (Character.isLowerCase(board[i][j])){
                        int pos = board[i][j] - 'a';
                        init |= ((i << 4) | j) << (pos * 8);
                        board[i][j] = ' ';
                    }
                    if (Character.isUpperCase(board[i][j])){
                        int pos = board[i][j] - 'A';
                        goal |= ((i << 4) | j) << (pos * 8);
                        board[i][j] = ' ';
                    }
                }
            }

            d = new int[3][256];
            for (int i = 0; i < 3;   i) Arrays.fill(d[i], INF);
            for (int i = 0; i < n;   i){
                int v = stateToVerties(goal, i);
                Node start = new Node(v, 0);
                d[i][v] = 0;
                Queue<Node> queue = new PriorityQueue<Node>();
                queue.offer(start);
                while (!queue.isEmpty()){
                    Node now = queue.poll();
                    int row = now.v / 16;
                    int col = now.v % 16;
                    for (int[] dir : direction){
                        int nrow = row   dir[0];
                        int ncol = col   dir[1];
                        int nv = nrow * 16   ncol;
                        if (nrow >= 0 && nrow < h && ncol >= 0 && ncol < w && board[nrow][ncol] != '#'
                                && d[i][now.v]   1 < d[i][nv]) {
                            d[i][nv] = d[i][now.v]   1;
                            queue.offer(new Node(nv, d[i][nv]));
                        }
                    }
                }
            }

            State start = new State(init, 0);
            State end = new State(goal, 0);
            dist = new int[256 * 256 * 256];
            Arrays.fill(dist, INF);
            dist[start.s] = 0;
            Queue<Node> queue = new PriorityQueue<Node>();
            queue.offer(new Node(start.s, 0));
            int ans = -1;
            while (!queue.isEmpty()){
                Node now = queue.poll();
                if (now.v == end.s){
                    ans = dist[now.v];
                    break;
                }

                int cost = now.d - hstar(now.v, d, n);
                if (cost > dist[now.v]) continue;

                List<State> nextStates = moveNext(new State(now.v, dist[now.v]), board, n);
                for (State state : nextStates){
                    int next = state.s;
                    if (dist[next] > dist[now.v]   1){
                        dist[next] = dist[now.v]   1;
                        queue.offer(new Node(next, dist[next]   hstar(next, d, n)));
                    }
                }
            }

            out.println(ans);
        }
    }

    public int hstar(int state, int[][] d, int n){ //计算当前状态到终态的最短距离
        int dist = 0;
        for (int i = 0; i < n;   i){
            int sv = stateToVerties(state, i);
            dist = Math.max(dist, d[i][sv]);
        }
        return dist;
    }

    class State{
        int s;
        int turn;

        public State(int s, int turn){
            this.s = s;
            this.turn = turn;
        }
    }

    public List<State> moveNext(State now, char[][] board, int n){
        List<State>[] states = new ArrayList[n];
        for (int i = 0; i < n;   i) states[i] = new ArrayList<State>();

        for (int i = 0; i < n;   i){
            int v = stateToVerties(now.s, i);
            int row = v / 16;
            int col = v % 16;
            for (int[] dir : direction){
                int nrow = row   dir[0];
                int ncol = col   dir[1];
                if (nrow >= 0 && nrow < H && ncol >= 0 && ncol < W && board[nrow][ncol] != '#'){
                    State state = new State(ghostToState(nrow, ncol, i),now.turn   1);
                    states[i].add(state);
                }
            }
        }

        //check 三种非法状态, 1. 撞墙 (在生成状态时已经避免) 2. 重合  3. 交换
        List<State> validState = new ArrayList<State>();
        for (int i = 0; 0 < n && i < states[0].size();   i){
            int state = states[0].get(i).s;
            int turn = states[0].get(i).turn;
            for (int j = 0; 1 < n && j < states[1].size();   j){
                state = states[0].get(i).s;
                state |= states[1].get(j).s;
                for (int l = 0; 2 < n && l < states[2].size();   l){
                    state = states[0].get(i).s;
                    state |= states[1].get(j).s;
                    state |= states[2].get(l).s;
                    if (3 == n && valid(state, now.s, n)) validState.add(new State(state, turn));
                }
                if (2 == n && valid(state, now.s, n)) validState.add(new State(state, turn));
            }
            if (1 == n) validState.add(new State(state, turn));
        }
        return validState;
    }


    public boolean valid(int state, int prev, int n){
        int i = state >> 16 & (0xff);
        int j = state >>  8 & (0xff);
        int k = state >>  0 & (0xff);
        if (n == 2 && (j == k)) return false;
        if (n == 3 && (i == j || j == k || k == i)) return false; //重合
        if (n == 2){
            int pj = prev >> 8 & (0xff);
            int pk = prev >> 0 & (0xff);
            if (pj == k && pk == j) return false;
        }
        if (n == 3){
            int pi = prev >> 16 & (0xff);
            int pj = prev >>  8 & (0xff);
            int pk = prev >>  0 & (0xff);
            if (swap(i, j, pi, pj) || swap(i, k, pi, pk) || swap(j, k, pj, pk)) return false; //交换
        }
        return true;
    }

    public boolean swap(int i, int j, int pi, int pj){
        return (i == pj) && (j == pi);
    }

    public int ghostToState(int i, int j, int id){
        int state = 0;
        state = (i << 4 | j) << (id * 8);
        return state;
    }

    public int stateToVerties(int state, int n){
        int ss = state >> (8 * n) & (0xff);
        int si = ss >> 4;
        int sj = ss & (0x0f);
        return si * 16   sj;
    }

    class Node implements Comparable<Node>{
        int v;
        int d;
        public Node(int v, int d){
            this.v = v;
            this.d = d;
        }
        @Override
        public int compareTo(Node that) {
            return this.d - that.d;
        }
    }

写了一上午,结果上述代码MLE了,心累,我要弃java了!正确AC代码可以参考博文:http://www.hankcs.com/program/algorithm/poj-3523-the-morning-after-halloween.html

POJ 2032: Square Carpets

神题,对IDA*的特性运用的淋漓尽致。还是简单说个思路,跟着思路做,代码就出来了。

此题给我最大的感受是,原来人一眼能看出来的最优策略,在计算机的世界里居然如此费劲,但IDA*很好的模拟了这种策略的选择。何以见得?

首先,对于每个顶点,我们尽可能的选择最大的矩形,吼吼,这点毋庸置疑。那么矩形该如何选择呢?策略很简单,那些唯一的点,也就是说只有单个矩形覆盖的点,对应的这个矩形一定是要选择的,那么剩下的是什么矩形呢?

很简单,那些没有确定的矩形,这些矩形有很明显的特征,要么是两个组合拼成一个完整的长条,要么选中一个矩形和周边的矩形组合构成一个完整的长条。我们可以肉眼分辨,显然后者最优,但程序不知道啊!!!

所以关键来了,在搜索时,为什么要利用下界控制搜索深度的原因在于,每次填矩形时,尽可能的压低深度(最小距离),此时dfs,就好像是在同一层搜索最优解一样,当深度过深时,直接就返回了,没有搜的必要,因为最优解不可能出现在深度更深的地方,而只可能是在每次慢慢提高的下界中。这样一来,两者组合的情况就被这下界给过滤掉了,因为它们组合的深度为2,而一个矩形和周边已选矩形的拼成的长条深度只为1,因为下界一定比1小,所以2的组合自然不会被选择。这就好比BFS,状态一步一步生成,只不过此时我们采用DFS模拟BFS,刚好利用了选择or不选的递归树来做更方便而已。

总而言之,IDA*用dfs的方法做bfs,一来节省了搜索空间,二来模拟了一种最优策略的选择,这点还需慢慢体会。

下界函数的选择,尽可能能的让坏点填满所有矩形,具体步骤可以参考博文:http://www.hankcs.com/program/algorithm/poj-2032-square-carpets.html

此版JAVA总算过了,欣慰。

代码语言:javascript复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.List;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P2032.txt";

    static final int MAX_N = 10;
    int ans, H, W, limit, max_W;
    int[][] T; //以(x,y)为右下角的正方形最大边长
    int[][] F; //输入
    int[][] X; // 每个点被几个正方形覆盖
    List<int[]> candicates;

    void solve() {
        while (true){
            int w = ni();
            int h = ni();
            if (w   h == 0) break;

            W = w;
            H = h;
            F = new int[MAX_N][MAX_N]; //输入
            for (int i = 0; i < h;   i){
                for (int j = 0; j < w;   j){
                    F[i][j] = ni();
                }
            }

            //初始化
            init();
            out.println(idaStar()   ans);
        }
    }

    void init(){
        ans = 0;
        max_W = 0;
        T = new int[MAX_N][MAX_N]; //以(x,y)为右下角的正方形最大边长
        X = new int[MAX_N][MAX_N]; // 每个点被几个正方形覆盖
        candicates = new ArrayList<int[]>();

        //统计每个顶点可覆盖的最大宽度
        for (int i = 0; i < W;   i){
            T[0][i] = F[0][i];
        }
        for (int j = 0; j < H;   j){
            T[j][0] = F[j][0];
        }

        for (int i = 1; i < H;   i){
            for (int j = 1; j < W;   j){
                if (F[i][j] != 0){
                    T[i][j] = Math.min(Math.min(T[i - 1][j], T[i][j - 1]), T[i - 1][j - 1])   1;
                }
            }
        }
        //删除完全覆盖的正方形
        for (int i = 0; i < H;   i){
            for (int j = 0; j < W;   j){
                if (T[i][j] != 0){
                    int len = T[i][j];
                    for (int k = 0; k < len;   k){
                        for (int l = 0; l < len;   l){
                            if (k == 0 && l == 0) continue;
                            if (i - k >= 0 && j - l >= 0 && T[i - k][j - l]   Math.max(l, k) <= T[i][j]){ //完全包含,则排除
                                T[i - k][j - l] = 0;
                            }
                        }
                    }
                }
            }
        }

        int[][] K = new int[MAX_N][MAX_N]; //每个点被几个正方形覆盖
        for (int i = 0; i < H;   i){
            for (int j = 0; j < W;   j){
                if (T[i][j] != 0){
                    int t = T[i][j];
                    for (int l = 0; l < t;   l){
                        for (int k = 0; k < t;   k){
                            K[i - l][j - k]   ;
                        }
                    }
                }
            }
        }

        for (int i = 0; i < H;   i){
            for (int j = 0; j < W;   j){
                outer:
                if (T[i][j] != 0){
                    int t = T[i][j];
                    for (int l = 0; l < t;   l){
                        for (int k = 0; k < t;   k){
                            if (K[i - l][j - k] == 1){
                                for (int x = 0; x < t;   x){
                                    for (int y = 0; y < t;   y){
                                        X[i - x][j - y]   ;
                                    }
                                }
                                ans   ;
                                T[i][j] = 0;
                                break outer;
                            }
                        }
                    }
                }
            }
        }

        for (int i = H - 1; i >= 0; --i){
            for (int j = W - 1; j >= 0; --j){
                if (T[i][j] != 0){
                    max_W = Math.max(max_W, T[i][j]);
                    candicates.add(new int[]{i, j});
                }
            }
        }
    }

    boolean done(){
        for (int i = 0; i < H;   i){
            for (int j = 0; j < W;   j){
                if (F[i][j] != 0 && X[i][j] == 0) return false;
            }
        }
        return true;
    }

    int h(){
        int sum = 0;
        for (int i = 0; i < H;   i){
            for (int j = 0; j < W;   j){
                if (F[i][j] != 0 && X[i][j] == 0) sum   ;
            }
        }
        return sum / (max_W * max_W);
    }

    int idaStar(){
        if (max_W == 0){
            return 0;
        }
        for (limit = h(); limit < 100;   limit){
            if (dfs(0, 0)) return limit;
        }
        return -1;
    }

    boolean dfs(int s, int cost){
        if (done()) return true;
        if (s >= candicates.size()) return false;
        if (cost   h() >= limit) return false;

        for (int i = candicates.get(s)[0]   1; i < H; i  ){
            for (int j = 0; j < W;   j){
                if (F[i][j] != 0 && X[i][j] == 0)
                    return false;
            }
        }

        if (dfs(s   1, cost)) return true;
        int[][] x_backup = copy(X);
        int pi = candicates.get(s)[0];
        int pj = candicates.get(s)[1];
        int w = T[pi][pj];
        for (int l = 0; l < w;   l){
            for (int k = 0; k < w;   k){
                X[pi - l][pj - k]  ;
            }
        }
        if (dfs(s   1, cost   1)) return true;
        X = copy(x_backup);
        return false;
    }

    int[][] copy(int[][] X){
        int n = X.length;
        int m = X[0].length;
        int[][] clone = new int[n][m];
        for (int i = 0; i < n;   i){
            for (int j = 0; j < n;   j){
                clone[i][j] = X[i][j];
            }
        }
        return clone;
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s   "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf  ];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p  ] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i  )
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i  )
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10   (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10   (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

0 人点赞