挑战程序竞赛系列(30):3.4矩阵的幂

2019-05-26 09:26:57 浏览数 (1)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1434651

挑战程序竞赛系列(30):3.4矩阵的幂

详细代码可以fork下Github上leetcode项目,不定期更新。

练习题如下:

  • POJ 3734: Blocks
  • POJ 3420: Quad Tiling
  • POJ 3735: Training Little cats

POJ 3734: Blocks

矩阵的幂入门题,写出递推式即可,题解:需要记录红色和绿色的状态,分成三个状态:

  • a:红色和绿色均为偶数时
  • b:红色和绿色恰为一个奇数(注意互斥)
  • c:红色和绿色均为奇数

这样当加入下一个木块时,就可以写出状态转移方程了,有点像HMM中的状态转移啊。。。

状态转移方程:

代码语言:javascript复制
a = 2a + b;
b = 2a + 2b + 2c;
c = 2c + b;

矩阵幂技术在于把上述转移状态写成矩阵的形式,因为每个状态只和前几个状态相关而不是所有状态,这点很关键,于是有:

$$
\begin{pmatrix} a_i \\ b_i \\ c_i \end{pmatrix}
=
\begin{pmatrix} 2 & 1 & 0 \\ 2 & 2 & 2 \\ 0 & 1 & 2 \end{pmatrix}^{i}
\begin{pmatrix} a_0 \\ b_0 \\ c_0 \end{pmatrix}
$$

当然可以思考下为什么矩阵的幂的时间复杂度为 $O(\log n)$,关键在于求解 $A^n$ 的过程加快了速度:传统的乘法需要循环 n 次,但我们可以利用 n 的二进制表示,用快速幂(反复平方)来计算 A 的 n 次幂,只需要 $O(\log n)$ 次矩阵乘法。

代码如下:

代码语言:javascript复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * POJ 3734 "Blocks": for each test case reads n and prints entry [0][0] of the
 * n-th power of the 3x3 transition matrix, modulo {@link #MOD}.
 *
 * <p>Per the accompanying write-up the three states are: red and green counts
 * both even (a), exactly one of them odd (b), both odd (c).
 */
public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3734.txt";

    /** Modulus applied to every matrix entry. */
    static final int MOD = 10007;

    /** Reads the number of test cases, then one n per case, and prints one result per line. */
    void solve() {
        int T = ni();
        for (int t = 0; t < T; ++t){
            int n = ni();
            // Transition matrix of the recurrence a' = 2a + b, b' = 2a + 2b + 2c, c' = b + 2c.
            int[][] a = {{2, 1, 0},{2, 2, 2},{0, 1, 2}};
            Mat A = new Mat(a);
            A = A.pow(A, n, MOD);
            out.println(A.mat[0][0]);
        }
    }

    /** Thin wrapper around an {@code int[][]} with modular multiply and power. */
    class Mat{
        int[][] mat;
        /** Row count. */
        int n;
        /** Column count. */
        int m;

        /**
         * Wraps the given array without copying it.
         *
         * @param arra matrix data; must have at least one row
         */
        public Mat(int[][] arra){
            this.mat = arra;
            this.n = arra.length;
            this.m = arra[0].length;
        }

        /**
         * Returns the matrix product {@code A * B} with every entry reduced modulo {@code MOD}.
         *
         * @param A left operand
         * @param B right operand; its row count is expected to equal {@code A.m}
         * @param MOD modulus applied after every accumulation step
         * @return a new matrix of size {@code A.n x B.m}
         */
        public Mat mul(Mat A, Mat B, int MOD){
            int[][] a = A.mat;
            int[][] b = B.mat;
            int[][] res = new int[A.n][B.m];
            for (int i = 0; i < A.n; ++i){
                for (int j = 0; j < B.m; ++j){
                    // Accumulate in long so the product of two residues cannot overflow int.
                    long acc = 0;
                    for (int ll = 0; ll < A.m; ++ll){
                        acc = (acc + (long) a[i][ll] * b[ll][j]) % MOD;
                    }
                    res[i][j] = (int) acc;
                }
            }
            return new Mat(res);
        }

        /**
         * Returns {@code A} raised to the {@code n}-th power modulo {@code MOD},
         * using binary exponentiation.
         *
         * @param A square base matrix
         * @param n exponent; for {@code n <= 0} the identity matrix is returned
         * @param MOD modulus passed to {@link #mul}
         * @return a new matrix holding the power
         */
        public Mat pow(Mat A, int n, int MOD){
            int[][] one = new int[A.n][A.m];
            for (int i = 0; i < A.n; ++i) one[i][i] = 1;
            Mat res = new Mat(one);

            while (n > 0){
                if (n % 2 != 0){
                    res = mul(res, A, MOD);
                }
                n >>= 1;
                A = mul(A, A, MOD);
            }
            return res;
        }
    }

    /**
     * Selects the input source (stdin when the ONLINE_JUDGE property is set,
     * otherwise the file named by {@code INPUT}), runs {@link #solve()}, flushes
     * the output and traces the elapsed time.
     */
    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    /** Read buffer used by {@link #readByte()}. */
    private byte[] inbuf = new byte[1024];
    /** Number of valid bytes in {@code inbuf}, and index of the next byte to hand out. */
    public int lenbuf = 0, ptrbuf = 0;

    /**
     * Returns the next byte from {@code is}, refilling the buffer when it is used up.
     *
     * @return the next byte, or -1 when the read yields no data
     * @throws InputMismatchException if the previous read reported end of stream,
     *         or if the underlying read fails
     */
    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException(String.valueOf(e));
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Treats everything outside the printable ASCII range 33..126 as a separator. */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips separators and returns the first non-separator byte, or -1 at end of input. */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    /** Reads the next token and parses it as a double. */
    private double nd() {
        return Double.parseDouble(ns());
    }

    /** Reads the next non-separator character. */
    private char nc() {
        return (char) skip();
    }

    /** Reads the next token, i.e. a maximal run of non-separator characters. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        // To read a whole line (keeping blanks) the stop test would instead be
        // (isSpaceChar(b) && b != ' ').
        while (!(isSpaceChar(b))) {
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads up to {@code n} characters of the next token; the result is trimmed if the token is shorter. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads {@code n} tokens of up to {@code m} characters each. */
    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    /** Reads {@code n} integers. */
    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    /**
     * Reads the next (optionally negative) decimal integer, skipping anything
     * that is neither a digit nor '-'. Returns 0 if input ends before a digit is found.
     */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Same as {@link #ni()} but accumulates into a long. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** True when the ONLINE_JUDGE system property is set. */
    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    /** Debug trace; prints only when not running on the online judge. */
    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3420: Quad Tiling

参考博文:http://blog.sina.com.cn/s/blog_69c3f0410100vnhj.html

思路:关键看怎么找递推式了,起初找递推的方式比较幼稚,出现大量子问题重复情况,而这种再做进一步递推式不知道如何干净去重,有点蛋疼。

它的思路是根据2*1的木块在4行中可能出现的轮廓来构建,进行完美贴合,呵呵哒,所以说不一定要以“正确的完美的递推式”来递推出答案,(递推就一定要保证每个n正确的情况下才能完成么?它只要是其中几种情况的一个解即可),思维很重要啊!

所以如上可以构成6种合法轮廓,如下图:

接着根据这六种情况就可以写出递推式了:

$$a_{n+1} = a_n + b_n + c_n + dx_n + dy_n$$

$$b_{n+1} = a_n$$

$$c_{n+1} = a_n + e_n$$

$$dx_{n+1} = a_n + dy_n$$

$$dy_{n+1} = a_n + dx_n$$

$$e_{n+1} = c_n$$

当然令 d = dx + dy,可得

$$d_{n+1} = 2a_n + d_n$$

于是我们得到了A矩阵为:

状态向量按 $(a, b, c, d, e)$ 的顺序排列:

$$
A = \begin{pmatrix}
1 & 1 & 1 & 1 & 0 \\
1 & 0 & 0 & 0 & 0 \\
1 & 0 & 0 & 0 & 1 \\
2 & 0 & 0 & 1 & 0 \\
0 & 0 & 1 & 0 & 0
\end{pmatrix}
$$

代码如下:

代码语言:javascript复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * POJ 3420 "Quad Tiling": repeatedly reads a pair (N, M) until both are zero and
 * prints entry [0][0] of the N-th power of the 5x5 transition matrix, modulo M.
 *
 * <p>Per the accompanying write-up the state vector is ordered (a, b, c, d, e)
 * with d = dx + dy.
 */
public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3420.txt";

    /** Processes test cases until the terminating "0 0" line. */
    void solve() {
        while (true){
            int N = ni();
            int M = ni();
            if (N + M == 0) break;
            int[][] a = {{1,1,1,1,0},{1,0,0,0,0},{1,0,0,0,1},{2,0,0,1,0},{0,0,1,0,0}};
            Mat A = new Mat(a);
            A = A.pow(A, N, M);
            out.println(A.mat[0][0]);
        }

    }

    /** Thin wrapper around an {@code int[][]} with modular multiply and power. */
    class Mat{
        int[][] mat;
        /** Row count. */
        int n;
        /** Column count. */
        int m;

        /**
         * Wraps the given array without copying it.
         *
         * @param mat matrix data; must have at least one row
         */
        public Mat(int[][] mat){
            this.mat = mat;
            this.n = mat.length;
            this.m = mat[0].length;
        }

        /**
         * Returns the matrix product {@code A * B} with every entry reduced modulo {@code MOD}.
         *
         * <p>The modulus is supplied by the input, so two residues can be large enough
         * that their product exceeds the int range; the accumulation is done in long.
         *
         * @param A left operand
         * @param B right operand; its row count is expected to equal {@code A.m}
         * @param MOD modulus applied after every accumulation step
         * @return a new matrix of size {@code A.n x B.m}
         */
        public Mat mul(Mat A, Mat B, int MOD){
            int[][] a = A.mat;
            int[][] b = B.mat;
            int[][] res = new int[A.n][B.m];
            for (int i = 0; i < A.n; ++i){
                for (int j = 0; j < B.m; ++j){
                    long acc = 0;
                    for (int ll = 0; ll < A.m; ++ll){
                        acc = (acc + (long) a[i][ll] * b[ll][j]) % MOD;
                    }
                    res[i][j] = (int) acc;
                }
            }
            return new Mat(res);
        }

        /**
         * Returns {@code A} raised to the {@code n}-th power modulo {@code MOD},
         * using binary exponentiation.
         *
         * @param A square base matrix
         * @param n exponent; for {@code n <= 0} the identity matrix is returned
         * @param MOD modulus passed to {@link #mul}
         * @return a new matrix holding the power
         */
        public Mat pow(Mat A, int n, int MOD){
            int[][] one = new int[A.n][A.n];
            for (int i = 0; i < A.n; ++i) one[i][i] = 1;
            Mat res = new Mat(one);
            while (n > 0){
                if ((n & 1) != 0){
                    res = mul(res, A, MOD);
                }
                n >>= 1;
                A = mul(A, A, MOD);
            }
            return res;
        }
    }

    /**
     * Selects the input source (stdin when the ONLINE_JUDGE property is set,
     * otherwise the file named by {@code INPUT}), runs {@link #solve()}, flushes
     * the output and traces the elapsed time.
     */
    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    /** Read buffer used by {@link #readByte()}. */
    private byte[] inbuf = new byte[1024];
    /** Number of valid bytes in {@code inbuf}, and index of the next byte to hand out. */
    public int lenbuf = 0, ptrbuf = 0;

    /**
     * Returns the next byte from {@code is}, refilling the buffer when it is used up.
     *
     * @return the next byte, or -1 when the read yields no data
     * @throws InputMismatchException if the previous read reported end of stream,
     *         or if the underlying read fails
     */
    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException(String.valueOf(e));
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Treats everything outside the printable ASCII range 33..126 as a separator. */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips separators and returns the first non-separator byte, or -1 at end of input. */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    /** Reads the next token and parses it as a double. */
    private double nd() {
        return Double.parseDouble(ns());
    }

    /** Reads the next non-separator character. */
    private char nc() {
        return (char) skip();
    }

    /** Reads the next token, i.e. a maximal run of non-separator characters. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        // To read a whole line (keeping blanks) the stop test would instead be
        // (isSpaceChar(b) && b != ' ').
        while (!(isSpaceChar(b))) {
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads up to {@code n} characters of the next token; the result is trimmed if the token is shorter. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads {@code n} tokens of up to {@code m} characters each. */
    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    /** Reads {@code n} integers. */
    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    /**
     * Reads the next (optionally negative) decimal integer, skipping anything
     * that is neither a digit nor '-'. Returns 0 if input ends before a digit is found.
     */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Same as {@link #ni()} but accumulates into a long. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** True when the ONLINE_JUDGE system property is set. */
    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    /** Debug trace; prints only when not running on the online judge. */
    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3735: Training Little cats

如果能够想到矩阵幂来做,就不难了。无非就是如何根据这些操作来构造一个矩阵,就拿case为例:

代码语言:javascript复制
3 1 6
g 1
g 2
g 2
s 1 2
g 3
e 2
0 0 0

有三只猫,可以当作变量a,b,c
g 1 : a = a + 1
如果看成矩阵
a   1 0 0 1   0
b = 0 1 0 0 * 0
c   0 0 1 0   0
1   0 0 0 1   1

得 a = a + 1
同理,s 1 2 无非就是把元素i和j对应的位置交换下:
a   0 1 0 0   0
b = 1 0 0 0 * 0
c   0 0 1 0   0
1   0 0 0 1   1

e 2
在单位矩阵的基础上令矩阵[1][1] = 0即可
a   1 0 0 0   0
b = 0 0 0 0 * 0
c   0 0 1 0   0
1   0 0 0 1   1
得 b = 0

每个操作可以单独和初始向量相乘,保证矩阵相乘的正确性,最后构造的最先乘,最后再幂乘m次。

注意两点:long防止溢出wa,稀疏矩阵加个判断,否则TLE。

代码如下:

代码语言:javascript复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.Stack;

/**
 * POJ 3735 "Training Little Cats": each test case gives N cats, a repeat count M
 * and K operations. Every operation is encoded as an (N+1)x(N+1) matrix acting on
 * the vector (cat_1, ..., cat_N, 1); the operations are combined into one matrix,
 * which is raised to the M-th power, and the last column gives the final counts.
 */
public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3735.txt";

    /** Number of cats in the current test case; read by {@link #createMat}. */
    int N;

    /** Processes test cases until the terminating "0 0 0" line. */
    void solve() {
        while (true){
            N = ni();
            int M = ni();
            int K = ni();
            if (N + M + K == 0) break;
            Stack<Mat> stack = new Stack<Mat>();
            for (int i = 0; i < K; ++i){
                char c = nc();
                if (c == 's'){
                    // Java evaluates arguments left to right, so the two indices are read in order.
                    stack.push(createMat(c, ni() - 1, ni() - 1));
                }
                else{
                    // 'g' and 'e' both take a single cat index.
                    stack.push(createMat(c, ni() - 1, 0));
                }
            }

            long[][] one = new long[N + 1][N + 1];
            for (int i = 0; i < N + 1; ++i) one[i][i] = 1;
            Mat A = new Mat(one);
            // Pop order puts the last operation leftmost, so the first operation
            // is the one applied to the vector first.
            while (!stack.isEmpty()){
                A = mul(A, stack.pop());
            }
            A = pow(A, M);
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < N; ++i){
                if (i > 0) sb.append(' ');
                sb.append(A.mat[i][N]);
            }
            out.println(sb.toString());
        }
    }

    /**
     * Builds the (N+1)x(N+1) matrix of a single operation, starting from the identity.
     *
     * @param command 'g' adds one to cat {@code i}, 's' swaps cats {@code i} and {@code j},
     *                'e' resets cat {@code i} to zero; anything else yields the identity
     * @param i zero-based index of the first cat
     * @param j zero-based index of the second cat (used by 's' only)
     * @return the operation matrix
     */
    public Mat createMat(char command, int i, int j){
        long[][] one = new long[N + 1][N + 1];
        for (int l = 0; l < one.length; ++l) one[l][l] = 1;
        switch (command) {
        case 'g':
            one[i][N] = 1;
            break;
        case 's':
            // Clear both diagonal entries first so that i == j still ends up as the identity.
            one[i][i] = 0;
            one[j][j] = 0;
            one[i][j] = 1;
            one[j][i] = 1;
            break;
        case 'e':
            one[i][i] = 0;
            break;
        default:
            break;
        }
        return new Mat(one);
    }

    /** Thin wrapper around a {@code long[][]} that also records its dimensions. */
    class Mat{
        long[][] mat;
        /** Row count. */
        int n;
        /** Column count. */
        int m;
        public Mat(long[][] mat){
            this.mat = mat;
            this.n = mat.length;
            this.m = mat[0].length;
        }
    }

    /**
     * Returns the matrix product {@code A * B} (no modulus).
     *
     * <p>Zero entries of {@code A} are skipped, which keeps the cost low for the
     * sparse matrices this problem produces.
     */
    public Mat mul(Mat A, Mat B){
        long[][] a = A.mat;
        long[][] b = B.mat;
        long[][] res = new long[A.n][B.m];
        for (int i = 0; i < A.n; ++i){
            for (int ll = 0; ll < A.m; ++ll){
                if (a[i][ll] != 0){
                    for (int j = 0; j < B.m; ++j){
                        res[i][j] += a[i][ll] * b[ll][j];
                    }
                }
            }
        }
        return new Mat(res);
    }

    /**
     * Returns {@code A} raised to the {@code n}-th power using binary exponentiation.
     *
     * @param A square base matrix
     * @param n exponent; for {@code n <= 0} the identity matrix is returned
     */
    public Mat pow(Mat A, int n){
        long[][] one = new long[A.n][A.n];
        for (int i = 0; i < A.n; ++i) one[i][i] = 1;
        Mat res = new Mat(one);
        while (n > 0){
            if ((n & 1) != 0){
                res = mul(res, A);
            }
            n >>= 1;
            A = mul(A, A);
        }
        return res;
    }

    /**
     * Selects the input source (stdin when the ONLINE_JUDGE property is set,
     * otherwise the file named by {@code INPUT}), runs {@link #solve()}, flushes
     * the output and traces the elapsed time.
     */
    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    /** Read buffer used by {@link #readByte()}. */
    private byte[] inbuf = new byte[1024];
    /** Number of valid bytes in {@code inbuf}, and index of the next byte to hand out. */
    public int lenbuf = 0, ptrbuf = 0;

    /**
     * Returns the next byte from {@code is}, refilling the buffer when it is used up.
     *
     * @return the next byte, or -1 when the read yields no data
     * @throws InputMismatchException if the previous read reported end of stream,
     *         or if the underlying read fails
     */
    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException(String.valueOf(e));
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Treats everything outside the printable ASCII range 33..126 as a separator. */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips separators and returns the first non-separator byte, or -1 at end of input. */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    /** Reads the next token and parses it as a double. */
    private double nd() {
        return Double.parseDouble(ns());
    }

    /** Reads the next non-separator character. */
    private char nc() {
        return (char) skip();
    }

    /** Reads the next token, i.e. a maximal run of non-separator characters. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        // To read a whole line (keeping blanks) the stop test would instead be
        // (isSpaceChar(b) && b != ' ').
        while (!(isSpaceChar(b))) {
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads up to {@code n} characters of the next token; the result is trimmed if the token is shorter. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads {@code n} tokens of up to {@code m} characters each. */
    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    /** Reads {@code n} integers. */
    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    /**
     * Reads the next (optionally negative) decimal integer, skipping anything
     * that is neither a digit nor '-'. Returns 0 if input ends before a digit is found.
     */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Same as {@link #ni()} but accumulates into a long. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** True when the ONLINE_JUDGE system property is set. */
    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    /** Debug trace; prints only when not running on the online judge. */
    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

当然你也可以在生成矩阵时,直接对原始矩阵进行操作,不过这是代码量的优化,无关乎算法,具体代码参考博文:http://www.hankcs.com/program/algorithm/poj-3735-training-little-cats-time.html

0 人点赞