【HDU-4965】Fast Matrix Calculation

【HDU-4965】Fast Matrix Calculation

矩阵乘法结合律的应用

问题链接:https://vjudge.net/problem/HDU-4965

Solution:

计算\((A B) ^{n * n}\),由于AB是n * n矩阵,而BA是k * k矩阵,其中n最大值1000,k最大值6,显然需要将上式改成\(A (BA)^{n * n – 1}B\),中间的部分用矩阵快速幂来算。

如果开1000 * 1000的long long矩阵,在全局作用域下声明多个这样的对象,程序最后可能产生与数组开太同样的问题。

#include <iostream>
#include <cstring>

using namespace std;

const int N = 1000 + 10;
const int K = 6 + 2;
const int MOD = 6;
struct Matrix {
    int v[K][K];
    int size;
    Matrix operator * (const Matrix & m) const {
        Matrix ret = {{0}};
        ret.size = size;
        for (int i = 0; i < size; i++) {
            for (int k = 0; k < size; k++) {
                if (v[i][k]) {
                    for (int j = 0; j < size; j++) {
                        ret.v[i][j] = (ret.v[i][j] + v[i][k] * m.v[k][j]) % MOD;
                    }
                }
            }
        }
        return ret;
    }
} c;
int a[N][K], b[K][N];
long long tmp[N][K], ans[N][N];
int n, k;

Matrix build() {
    Matrix ret = {{0}};
    ret.size = k;
    for (int i = 0; i < k; i++) {
        for (int j = 0; j < k; j++) {
            for (int w = 0; w < n; w++) {
                ret.v[i][j] = (ret.v[i][j] + b[i][w] * a[w][j]) % MOD;
            }
        }
    }
    return ret;
}

Matrix powMod(int p, Matrix m) {
    Matrix ret = {{0}};
    ret.size = m.size;
    for (int i = 0; i < ret.size; i++) ret.v[i][i] = 1;
    while (p) {
        if (p & 1) ret = m * ret;
        if (p) m = m * m;
        p >>= 1;
    }
    return ret;
}

int main(void) {
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    while (cin >> n >> k && n && k) {
        memset(tmp, 0, sizeof tmp);
        memset(ans, 0, sizeof ans);
        for (int i = 0; i < n; i++)
            for (int j = 0; j < k; j++)
                cin >> a[i][j];
        for (int i = 0; i < k; i++)
            for (int j = 0; j < n; j++)
                cin >> b[i][j];
        Matrix c = build();
        c = powMod(n * n - 1, c);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < k; j++) {
                for (int w = 0; w < k; w++) {
                    tmp[i][j] = (tmp[i][j] + a[i][w] * c.v[w][j]) % MOD;
                }
            }
        }
        long long sum = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                for (int w = 0; w < k; w++) {
                    ans[i][j] = (ans[i][j] + tmp[i][w] * b[w][j]) % MOD;
                }
                sum += ans[i][j];
            }
        }
        cout << sum << endl;
    }
    return 0;
}