版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1434651
挑战程序竞赛系列(30):3.4矩阵的幂
详细代码可以fork下Github上leetcode项目,不定期更新。
练习题如下:
- POJ 3734: Blocks
- POJ 3420: Quad Tiling
- POJ 3735: Training Little cats
POJ 3734: Blocks
矩阵的幂入门题,写出递推式即可。题解:需要记录红色和绿色块数的奇偶性,分成三个状态:
- a:红色和绿色均为偶数时
- b:红色和绿色恰为一个奇数(注意互斥)
- c:红色和绿色均为奇数
这样当加入下一个木块时,就可以写出状态转移方程了,有点像HMM中的状态转移啊。。。
状态转移方程:
$$a_{i+1} = 2a_i + b_i$$
$$b_{i+1} = 2a_i + 2b_i + 2c_i$$
$$c_{i+1} = b_i + 2c_i$$
(另外两种颜色不改变奇偶性,各贡献一种转移;放红色或绿色则会翻转其中一种颜色的奇偶性。)
矩阵幂技术在于把上述转移状态写成矩阵的形式,因为每个状态只和前几个状态相关而不是所有状态,这点很关键,于是有:
$$\begin{pmatrix} a_i \\ b_i \\ c_i \end{pmatrix} = \begin{pmatrix} 2 & 1 & 0 \\ 2 & 2 & 2 \\ 0 & 1 & 2 \end{pmatrix}^i \begin{pmatrix} a_0 \\ b_0 \\ c_0 \end{pmatrix}$$
初始时没有木块,$a_0 = 1,\ b_0 = c_0 = 0$,因此答案 $a_n$ 就是 $A^n$ 左上角的元素 $(A^n)_{00}$。
当然可以思考下为什么矩阵的幂的时间复杂度为 $O(\log n)$:关键在于加快了求解 $A^n$ 的过程。朴素做法需要连乘 n 次,而利用指数 n 的二进制表示,快速幂只需 $O(\log n)$ 次矩阵乘法就能算出 $A^n$。
代码如下:
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
/**
 * POJ 3734 "Blocks": for each test case prints, modulo {@link #MOD}, the number of ways to color
 * {@code n} blocks so that red and green each appear an even number of times.
 *
 * <p>The state vector (a, b, c) counts colorings whose red/green counts are (both even, exactly
 * one odd, both odd). Adding one block multiplies it by a fixed 3x3 matrix A, and the start state
 * is (1, 0, 0), so the answer is {@code (A^n)[0][0]}, computed by fast exponentiation.
 */
public class Main {
    InputStream is;
    PrintWriter out;
    /** Local test-data file, used only when the ONLINE_JUDGE system property is not set. */
    String INPUT = "./data/judge/201707/3734.txt";
    static final int MOD = 10007;

    /** Reads T, then T block counts n, and prints {@code (A^n)[0][0] mod MOD} for each one. */
    void solve() {
        int T = ni();
        for (int t = 0; t < T; t++) {
            int n = ni();
            int[][] a = {{2, 1, 0}, {2, 2, 2}, {0, 1, 2}};
            Mat A = new Mat(a);
            A = A.pow(A, n, MOD);
            out.println(A.mat[0][0]);
        }
    }

    /** An {@code n x m} int matrix supporting modular multiplication and powering. */
    class Mat {
        int[][] mat;
        int n; // number of rows
        int m; // number of columns

        /** Wraps {@code arra} without copying; it must contain at least one row. */
        public Mat(int[][] arra) {
            this.mat = arra;
            this.n = arra.length;
            this.m = arra[0].length;
        }

        /**
         * Returns {@code A * B} with every entry reduced modulo {@code MOD}.
         *
         * <p>Each product is formed in {@code long}, so no intermediate value overflows as long as
         * the input entries are non-negative and below {@code MOD}.
         */
        public Mat mul(Mat A, Mat B, int MOD) {
            int[][] a = A.mat;
            int[][] b = B.mat;
            int[][] res = new int[A.n][B.m];
            for (int i = 0; i < A.n; i++) {
                for (int j = 0; j < B.m; j++) {
                    long sum = 0;
                    for (int ll = 0; ll < A.m; ll++) {
                        sum = (sum + (long) a[i][ll] * b[ll][j]) % MOD;
                    }
                    res[i][j] = (int) sum;
                }
            }
            return new Mat(res);
        }

        /**
         * Returns {@code A^n} modulo {@code MOD} by binary exponentiation; {@code A} must be square.
         * For {@code n <= 0} the identity matrix (reduced modulo {@code MOD}) is returned.
         */
        public Mat pow(Mat A, int n, int MOD) {
            int[][] one = new int[A.n][A.n];
            for (int i = 0; i < A.n; i++) {
                one[i][i] = 1 % MOD;
            }
            Mat res = new Mat(one);
            while (n > 0) {
                if ((n & 1) != 0) {
                    res = mul(res, A, MOD);
                }
                n >>= 1;
                if (n > 0) {
                    A = mul(A, A, MOD); // the next square is only needed while bits remain
                }
            }
            return res;
        }
    }

    /** Wires up input (stdin on the judge, {@link #INPUT} locally) and output, then solves. */
    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);
        long s = System.currentTimeMillis();
        try {
            solve();
            out.flush();
        } finally {
            if (is != System.in) {
                is.close(); // close the local data file; never close stdin
            }
        }
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    /**
     * Returns the next input byte (0-255), or -1 when the stream is exhausted. Calling again after
     * the stream has reported end of input throws {@link InputMismatchException}.
     */
    private int readByte() {
        if (lenbuf == -1) {
            throw new InputMismatchException();
        }
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                InputMismatchException ex = new InputMismatchException(e.getMessage());
                ex.initCause(e); // keep the underlying I/O failure visible
                throw ex;
            }
            if (lenbuf <= 0) {
                return -1;
            }
        }
        return inbuf[ptrbuf++] & 0xff; // mask so a 0xFF byte is not mistaken for end of input
    }

    /** Treats everything outside printable ASCII (33..126), including -1, as a separator. */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips separators and returns the first non-separator byte, or -1 at end of input. */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b)) {
            // keep skipping
        }
        return b;
    }

    /** Reads the next token as a double. */
    private double nd() {
        return Double.parseDouble(ns());
    }

    /** Reads the next non-separator character. */
    private char nc() {
        return (char) skip();
    }

    /** Reads the next separator-delimited token. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!isSpaceChar(b)) { // for whole-line reading use (isSpaceChar(b) && b != '\n')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads at most {@code n} characters of the next token. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !isSpaceChar(b)) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads {@code n} tokens of at most {@code m} characters each. */
    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++) {
            map[i] = ns(m);
        }
        return map;
    }

    /** Reads {@code n} ints. */
    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = ni();
        }
        return a;
    }

    /** Reads the next (optionally negative) decimal int, skipping bytes before it. */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip until a digit or sign
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Reads the next (optionally negative) decimal long, skipping bytes before it. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip until a digit or sign
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    /** Debug trace to stdout, printed only when not running on the online judge. */
    private void tr(Object... o) {
        if (!oj) {
            System.out.println(Arrays.deepToString(o));
        }
    }
}
POJ 3420: Quad Tiling
参考博文:http://blog.sina.com.cn/s/blog_69c3f0410100vnhj.html
思路:关键看怎么找递推式了,起初找递推的方式比较幼稚,出现大量子问题重复情况,而这种再做进一步递推式不知道如何干净去重,有点蛋疼。
它的思路是根据2*1的木块在4行中可能出现的轮廓来构建,进行完美贴合,呵呵哒,所以说不一定要以“正确的完美的递推式”来递推出答案,(递推就一定要保证每个n正确的情况下才能完成么?它只要是其中几种情况的一个解即可),思维很重要啊!
所以如上可以构成6种合法轮廓,如下图:
接着根据这六种情况就可以写出递推式了:
$$a_{n+1} = a_n + b_n + c_n + dx_n + dy_n$$
$$b_{n+1} = a_n$$
$$c_{n+1} = a_n + e_n$$
$$dx_{n+1} = a_n + dy_n$$
$$dy_{n+1} = a_n + dx_n$$
$$e_{n+1} = c_n$$
当然令 $d = dx + dy$,可得
$$d_{n+1} = 2a_n + d_n$$
于是我们得到了A矩阵为:
$$A = \begin{pmatrix} 1 & 1 & 1 & 1 & 0 \\ 1 & 0 & 0 & 0 & 0 \\ 1 & 0 & 0 & 0 & 1 \\ 2 & 0 & 0 & 1 & 0 \\ 0 & 0 & 1 & 0 & 0 \end{pmatrix}$$
其中状态向量按 $(a, b, c, d, e)$ 排列,初始时 $a_0 = 1$、其余为 0,所以答案就是 $(A^N)_{00} \bmod M$。
代码如下:
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
/**
 * POJ 3420 "Quad Tiling": for each (N, M) pair prints the number of ways to tile a 4 x N board
 * with 2 x 1 dominoes, modulo M.
 *
 * <p>The states (a, b, c, d, e) are the boundary profiles from the accompanying derivation. One
 * more column maps them through a fixed 5x5 matrix A, and the start state is (1, 0, 0, 0, 0), so
 * the answer is {@code (A^N)[0][0] mod M}.
 */
public class Main {
    InputStream is;
    PrintWriter out;
    /** Local test-data file, used only when the ONLINE_JUDGE system property is not set. */
    String INPUT = "./data/judge/201707/3420.txt";

    /** Reads (N, M) pairs until the terminating "0 0" and prints one answer per pair. */
    void solve() {
        while (true) {
            int N = ni();
            int M = ni();
            if (N == 0 && M == 0) {
                break;
            }
            int[][] a = {
                {1, 1, 1, 1, 0},
                {1, 0, 0, 0, 0},
                {1, 0, 0, 0, 1},
                {2, 0, 0, 1, 0},
                {0, 0, 1, 0, 0}
            };
            Mat A = new Mat(a);
            A = A.pow(A, N, M);
            out.println(A.mat[0][0]);
        }
    }

    /** An {@code n x m} int matrix supporting modular multiplication and powering. */
    class Mat {
        int[][] mat;
        int n; // number of rows
        int m; // number of columns

        /** Wraps {@code mat} without copying; it must contain at least one row. */
        public Mat(int[][] mat) {
            this.mat = mat;
            this.n = mat.length;
            this.m = mat[0].length;
        }

        /**
         * Returns {@code A * B} with every entry reduced modulo {@code MOD}.
         *
         * <p>The modulus comes from the input, and once it exceeds about 46341 the product of two
         * residues no longer fits in an int. Products are therefore accumulated in {@code long}.
         */
        public Mat mul(Mat A, Mat B, int MOD) {
            int[][] a = A.mat;
            int[][] b = B.mat;
            int[][] res = new int[A.n][B.m];
            for (int i = 0; i < A.n; i++) {
                for (int j = 0; j < B.m; j++) {
                    long sum = 0;
                    for (int ll = 0; ll < A.m; ll++) {
                        sum = (sum + (long) a[i][ll] * b[ll][j]) % MOD;
                    }
                    res[i][j] = (int) sum;
                }
            }
            return new Mat(res);
        }

        /**
         * Returns {@code A^n} modulo {@code MOD} by binary exponentiation; {@code A} must be square.
         * For {@code n <= 0} the identity matrix (reduced modulo {@code MOD}) is returned.
         */
        public Mat pow(Mat A, int n, int MOD) {
            int[][] one = new int[A.n][A.n];
            for (int i = 0; i < A.n; i++) {
                one[i][i] = 1 % MOD;
            }
            Mat res = new Mat(one);
            while (n > 0) {
                if ((n & 1) != 0) {
                    res = mul(res, A, MOD);
                }
                n >>= 1;
                if (n > 0) {
                    A = mul(A, A, MOD); // the next square is only needed while bits remain
                }
            }
            return res;
        }
    }

    /** Wires up input (stdin on the judge, {@link #INPUT} locally) and output, then solves. */
    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);
        long s = System.currentTimeMillis();
        try {
            solve();
            out.flush();
        } finally {
            if (is != System.in) {
                is.close(); // close the local data file; never close stdin
            }
        }
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    /**
     * Returns the next input byte (0-255), or -1 when the stream is exhausted. Calling again after
     * the stream has reported end of input throws {@link InputMismatchException}.
     */
    private int readByte() {
        if (lenbuf == -1) {
            throw new InputMismatchException();
        }
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                InputMismatchException ex = new InputMismatchException(e.getMessage());
                ex.initCause(e); // keep the underlying I/O failure visible
                throw ex;
            }
            if (lenbuf <= 0) {
                return -1;
            }
        }
        return inbuf[ptrbuf++] & 0xff; // mask so a 0xFF byte is not mistaken for end of input
    }

    /** Treats everything outside printable ASCII (33..126), including -1, as a separator. */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips separators and returns the first non-separator byte, or -1 at end of input. */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b)) {
            // keep skipping
        }
        return b;
    }

    /** Reads the next token as a double. */
    private double nd() {
        return Double.parseDouble(ns());
    }

    /** Reads the next non-separator character. */
    private char nc() {
        return (char) skip();
    }

    /** Reads the next separator-delimited token. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!isSpaceChar(b)) { // for whole-line reading use (isSpaceChar(b) && b != '\n')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads at most {@code n} characters of the next token. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !isSpaceChar(b)) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads {@code n} tokens of at most {@code m} characters each. */
    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++) {
            map[i] = ns(m);
        }
        return map;
    }

    /** Reads {@code n} ints. */
    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = ni();
        }
        return a;
    }

    /** Reads the next (optionally negative) decimal int, skipping bytes before it. */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip until a digit or sign
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Reads the next (optionally negative) decimal long, skipping bytes before it. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip until a digit or sign
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    /** Debug trace to stdout, printed only when not running on the online judge. */
    private void tr(Object... o) {
        if (!oj) {
            System.out.println(Arrays.deepToString(o));
        }
    }
}
POJ 3735: Training Little cats
如果能够想到矩阵幂来做,就不难了。无非就是如何根据这些操作来构造一个矩阵,就拿case为例:
3 1 6
g 1
g 2
g 2
s 1 2
g 3
e 2
0 0 0
有三只猫,可以当作变量a,b,c
g 1 : a = a + 1
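如果看成矩阵(在状态向量末尾补一个常数 1):
$$\begin{pmatrix} a' \\ b' \\ c' \\ 1 \end{pmatrix} = \begin{pmatrix} 1 & 0 & 0 & 1 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix} \begin{pmatrix} a \\ b \\ c \\ 1 \end{pmatrix}$$
得 $a' = a + 1$。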
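同理,s 1 2 无非就是把 i 和 j 对应的两行交换一下(这里交换 a 和 b):
$$\begin{pmatrix} a' \\ b' \\ c' \\ 1 \end{pmatrix} = \begin{pmatrix} 0 & 1 & 0 & 0 \\ 1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix} \begin{pmatrix} a \\ b \\ c \\ 1 \end{pmatrix}$$
得 $a' = b,\ b' = a$。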
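e 2:第 2 只猫对应下标 1,令矩阵 [1][1] = 0 即可:
$$\begin{pmatrix} a' \\ b' \\ c' \\ 1 \end{pmatrix} = \begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix} \begin{pmatrix} a \\ b \\ c \\ 1 \end{pmatrix}$$
得 $b' = 0$。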
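每个操作都可以单独写成一个矩阵 $M_j$ 作用在状态向量上,这样能保证矩阵相乘的正确性。一轮 k 个操作的总效果是 $T = M_k M_{k-1} \cdots M_1$,即最后构造的矩阵最先乘(位于最左边,代码中用栈实现);m 轮的结果就是 $T^m$ 乘以初始向量 $(0, \dots, 0, 1)^T$,所以答案就是 $T^m$ 的最后一列。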
注意两点:用 long 防止溢出导致 WA;矩阵乘法中对稀疏矩阵的零元素加个判断直接跳过,否则会 TLE。
代码如下:
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.Stack;
/**
 * POJ 3735 "Training Little Cats": N cats hold counters that start at 0. A routine of K commands
 * ({@code g i}: cat i's counter += 1, {@code s i j}: swap cats i and j, {@code e i}: cat i's
 * counter = 0) is repeated M times, and the final counters are printed.
 *
 * <p>Each command is an (N+1)x(N+1) matrix acting on the column vector (x_0, ..., x_{N-1}, 1).
 * One routine is T = M_K * ... * M_1 and M repetitions are T^M. Because the start vector is
 * (0, ..., 0, 1), the answers are the last column of T^M. Values are plain {@code long}s with no
 * modulus.
 */
public class Main {
    InputStream is;
    PrintWriter out;
    /** Local test-data file, used only when the ONLINE_JUDGE system property is not set. */
    String INPUT = "./data/judge/201707/3735.txt";
    /** Number of cats in the current test case; index N is the constant-1 coordinate. */
    int N;

    /** Reads test cases until "0 0 0" and prints the N final counters of each, space-separated. */
    void solve() {
        while (true) {
            N = ni();
            int M = ni();
            int K = ni();
            if (N == 0 && M == 0 && K == 0) {
                break;
            }
            // Pushing in read order and popping afterwards multiplies the last command first,
            // which yields T = M_K * ... * M_1 (so M_1 is applied to the vector first).
            Stack<Mat> stack = new Stack<Mat>();
            for (int i = 0; i < K; i++) {
                char c = nc();
                if (c == 's') {
                    int x = ni() - 1;
                    int y = ni() - 1;
                    stack.push(createMat(c, x, y));
                } else {
                    // 'g', 'e' (and any other command) take a single 1-based cat index
                    stack.push(createMat(c, ni() - 1, 0));
                }
            }
            long[][] one = new long[N + 1][N + 1];
            for (int i = 0; i <= N; i++) {
                one[i][i] = 1;
            }
            Mat A = new Mat(one);
            while (!stack.isEmpty()) {
                A = mul(A, stack.pop());
            }
            A = pow(A, M);
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < N; i++) {
                if (i > 0) {
                    sb.append(' ');
                }
                sb.append(A.mat[i][N]);
            }
            out.println(sb);
        }
    }

    /**
     * Builds the (N+1)x(N+1) matrix of one command, starting from the identity. {@code 'g'} adds
     * the constant column to row i, {@code 's'} swaps rows i and j (identity when i == j), and
     * {@code 'e'} zeroes row i. Any other command yields the identity. Indices are 0-based.
     */
    public Mat createMat(char command, int i, int j) {
        long[][] one = new long[N + 1][N + 1];
        for (int l = 0; l < one.length; l++) {
            one[l][l] = 1;
        }
        switch (command) {
            case 'g':
                one[i][N] = 1;
                break;
            case 's':
                one[i][i] = 0;
                one[j][j] = 0;
                one[i][j] = 1;
                one[j][i] = 1;
                break;
            case 'e':
                one[i][i] = 0;
                break;
            default:
                break;
        }
        return new Mat(one);
    }

    /** A thin wrapper around an {@code n x m} long matrix. */
    class Mat {
        long[][] mat;
        int n; // number of rows
        int m; // number of columns

        /** Wraps {@code mat} without copying; it must contain at least one row. */
        public Mat(long[][] mat) {
            this.mat = mat;
            this.n = mat.length;
            this.m = mat[0].length;
        }
    }

    /**
     * Returns {@code A * B} over plain longs. Zero entries of A are skipped, because the command
     * matrices and their products have few non-zeros per row.
     */
    public Mat mul(Mat A, Mat B) {
        long[][] a = A.mat;
        long[][] b = B.mat;
        long[][] res = new long[A.n][B.m];
        for (int i = 0; i < A.n; i++) {
            long[] resRow = res[i];
            for (int ll = 0; ll < A.m; ll++) {
                long x = a[i][ll];
                if (x == 0) {
                    continue;
                }
                long[] bRow = b[ll];
                for (int j = 0; j < B.m; j++) {
                    resRow[j] += x * bRow[j];
                }
            }
        }
        return new Mat(res);
    }

    /** Returns {@code A^n} by binary exponentiation ({@code A} square; identity for n <= 0). */
    public Mat pow(Mat A, int n) {
        long[][] one = new long[A.n][A.n];
        for (int i = 0; i < A.n; i++) {
            one[i][i] = 1;
        }
        Mat res = new Mat(one);
        while (n > 0) {
            if ((n & 1) != 0) {
                res = mul(res, A);
            }
            n >>= 1;
            if (n > 0) {
                A = mul(A, A); // the next square is only needed while bits remain
            }
        }
        return res;
    }

    /** Wires up input (stdin on the judge, {@link #INPUT} locally) and output, then solves. */
    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);
        long s = System.currentTimeMillis();
        try {
            solve();
            out.flush();
        } finally {
            if (is != System.in) {
                is.close(); // close the local data file; never close stdin
            }
        }
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    /**
     * Returns the next input byte (0-255), or -1 when the stream is exhausted. Calling again after
     * the stream has reported end of input throws {@link InputMismatchException}.
     */
    private int readByte() {
        if (lenbuf == -1) {
            throw new InputMismatchException();
        }
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                InputMismatchException ex = new InputMismatchException(e.getMessage());
                ex.initCause(e); // keep the underlying I/O failure visible
                throw ex;
            }
            if (lenbuf <= 0) {
                return -1;
            }
        }
        return inbuf[ptrbuf++] & 0xff; // mask so a 0xFF byte is not mistaken for end of input
    }

    /** Treats everything outside printable ASCII (33..126), including -1, as a separator. */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips separators and returns the first non-separator byte, or -1 at end of input. */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b)) {
            // keep skipping
        }
        return b;
    }

    /** Reads the next token as a double. */
    private double nd() {
        return Double.parseDouble(ns());
    }

    /** Reads the next non-separator character. */
    private char nc() {
        return (char) skip();
    }

    /** Reads the next separator-delimited token. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!isSpaceChar(b)) { // for whole-line reading use (isSpaceChar(b) && b != '\n')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads at most {@code n} characters of the next token. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !isSpaceChar(b)) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads {@code n} tokens of at most {@code m} characters each. */
    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++) {
            map[i] = ns(m);
        }
        return map;
    }

    /** Reads {@code n} ints. */
    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = ni();
        }
        return a;
    }

    /** Reads the next (optionally negative) decimal int, skipping bytes before it. */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip until a digit or sign
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Reads the next (optionally negative) decimal long, skipping bytes before it. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip until a digit or sign
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    /** Debug trace to stdout, printed only when not running on the online judge. */
    private void tr(Object... o) {
        if (!oj) {
            System.out.println(Arrays.deepToString(o));
        }
    }
}
当然你也可以在生成矩阵时,直接对原始矩阵进行操作,不过这是代码量的优化,无关乎算法,具体代码参考博文:http://www.hankcs.com/program/algorithm/poj-3735-training-little-cats-time.html