Optimizing matrix multiplication

Permutations • Nested loops • Comparing algorithms 10.12.2021

Consider an algorithm for multiplying matrices using three nested loops. Its time complexity is O(n³) regardless of the order of the loops, but in practice the execution environment matters: the speed of the algorithm depends on the order in which the loops are nested.
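As a worked count: with a first matrix A of size L×M and a second matrix B of size M×N, every entry of the product is a sum over the common dimension M, so any nesting of the three loops executes the body `c[i][j] += a[i][k] * b[k][j]` exactly L·M·N times. Only the order of the memory accesses differs, not the amount of arithmetic.

```latex
\begin{aligned}
c_{ij} &= \sum_{k=1}^{M} a_{ik}\, b_{kj}, \qquad i = 1,\dots,L,\quad j = 1,\dots,N \\
\text{multiply-adds} &= L \cdot M \cdot N \qquad (= n^{3} \text{ for } L = M = N = n) \\
500 \cdot 700 \cdot 450 &= 157\,500\,000 \qquad \text{(the sizes used in the comparison)}
\end{aligned}
```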

Let’s compare the different permutations of the nested loops and the execution times of the resulting algorithms. Take two matrices of sizes L×M and M×N. Three loops can be nested in 3! = 6 ways: LMN, LNM, MLN, MNL, NLM, NML.
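To show that the six permutations differ only in iteration order, here is a hedged sketch of a single hypothetical helper, `multiplyInOrder`, which is not part of the original code. It takes the loop order as a string such as "LMN" or "NML" and maps each loop depth to the indexes `i` (L), `k` (M) and `j` (N) used in the methods below. It is for illustration only: the extra indirection makes it slower than the hand-written loops, so it is not suitable for timing.

```java
import java.util.Arrays;
import java.util.List;

class LoopOrderDemo {
    // hypothetical helper: multiplies a (l×m) by b (m×n), nesting the loops in the
    // order given by a permutation of 'L', 'M', 'N' (assumed to be a valid permutation)
    static int[][] multiplyInOrder(String order, int[][] a, int[][] b) {
        int l = a.length, m = b.length, n = b[0].length;
        // loop depth (0 = outer, 2 = inner) of each index
        int depthI = order.indexOf('L'), depthK = order.indexOf('M'), depthJ = order.indexOf('N');
        // upper bound of the loop at each depth
        int[] bound = new int[3];
        bound[depthI] = l;
        bound[depthK] = m;
        bound[depthJ] = n;
        int[][] c = new int[l][n];
        int[] idx = new int[3]; // current value of the loop counter at each depth
        for (idx[0] = 0; idx[0] < bound[0]; idx[0]++)
            for (idx[1] = 0; idx[1] < bound[1]; idx[1]++)
                for (idx[2] = 0; idx[2] < bound[2]; idx[2]++) {
                    int i = idx[depthI], k = idx[depthK], j = idx[depthJ];
                    c[i][j] += a[i][k] * b[k][j];
                }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};      // 2×3
        int[][] b = {{7, 8}, {9, 10}, {11, 12}}; // 3×2
        // every order prints the same matrix: [[58, 64], [139, 154]]
        for (String order : List.of("LMN", "LNM", "MLN", "MNL", "NLM", "NML"))
            System.out.println(order + " " + Arrays.deepToString(multiplyInOrder(order, a, b)));
    }
}
```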

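The algorithms that write to the resulting matrix row-wise and in layers, LMN and MLN, are faster than the others. In both of them the innermost loop runs over the columns of the second matrix (N), so consecutive iterations access neighboring elements of one row of the second matrix and one row of the result. The difference from the other algorithms is substantial, and its exact size depends on the execution environment.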

Further optimization: Matrix multiplication in parallel streams.

Row-wise algorithm

The outer loop iterates over the rows of the first matrix (L), the middle loop over the common dimension of the two matrices (M), and the inner loop over the columns of the second matrix (N). The resulting matrix is written row by row, and each row is filled in layers.

/**
 * @param l rows of matrix 'a'
 * @param m columns of matrix 'a'
 *          and rows of matrix 'b'
 * @param n columns of matrix 'b'
 * @param a first matrix 'l×m'
 * @param b second matrix 'm×n'
 * @return resulting matrix 'l×n'
 */
public static int[][] matrixMultiplicationLMN(int l, int m, int n, int[][] a, int[][] b) {
    // resulting matrix
    int[][] c = new int[l][n];
    // iterate over the rows of matrix 'a'
    for (int i = 0; i < l; i++)
        // iterate over the common dimension of the two matrices:
        // the columns of matrix 'a' and the rows of matrix 'b'
        for (int k = 0; k < m; k++)
            // iterate over the columns of matrix 'b'
            for (int j = 0; j < n; j++)
                // accumulate the sum of the products of the elements of the
                // i-th row of matrix 'a' and the j-th column of matrix 'b'
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
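In Java, `int[][]` is an array of separate row arrays, so with `j` as the innermost index the LMN loop reads `b[k]` and writes `c[i]` sequentially, while `a[i][k]` stays fixed. The following hedged sketch (a variant, not the author's code; the name `matrixMultiplicationLMNRows` is hypothetical) makes this explicit by hoisting the row references and the fixed factor out of the inner loop. The JIT compiler may already perform similar hoisting, so any speed-up should be measured rather than assumed.

```java
import java.util.Arrays;

class RowHoisting {
    // hypothetical variant of matrixMultiplicationLMN with row references hoisted
    static int[][] matrixMultiplicationLMNRows(int l, int m, int n, int[][] a, int[][] b) {
        int[][] c = new int[l][n];
        for (int i = 0; i < l; i++) {
            int[] ai = a[i]; // i-th row of 'a'
            int[] ci = c[i]; // i-th row of the result, filled in layers
            for (int k = 0; k < m; k++) {
                int aik = ai[k]; // constant during the inner loop
                int[] bk = b[k]; // k-th row of 'b', read sequentially
                for (int j = 0; j < n; j++)
                    ci[j] += aik * bk[j];
            }
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2}, {3, 4}};
        int[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(matrixMultiplicationLMNRows(2, 2, 2, a, b)));
        // prints [[19, 22], [43, 50]]
    }
}
```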

Layer-wise algorithm

The outer loop iterates over the common dimension of the two matrices (M), the middle loop over the rows of the first matrix (L), and the inner loop over the columns of the second matrix (N). The resulting matrix is written layer by layer, and each layer is filled row by row.

/**
 * @param l rows of matrix 'a'
 * @param m columns of matrix 'a'
 *          and rows of matrix 'b'
 * @param n columns of matrix 'b'
 * @param a first matrix 'l×m'
 * @param b second matrix 'm×n'
 * @return resulting matrix 'l×n'
 */
public static int[][] matrixMultiplicationMLN(int l, int m, int n, int[][] a, int[][] b) {
    // resulting matrix
    int[][] c = new int[l][n];
    // iterate over the common dimension of the two matrices:
    // the columns of matrix 'a' and the rows of matrix 'b'
    for (int k = 0; k < m; k++)
        // iterate over the rows of matrix 'a'
        for (int i = 0; i < l; i++)
            // iterate over the columns of matrix 'b'
            for (int j = 0; j < n; j++)
                // accumulate the sum of the products of the elements of the
                // i-th row of matrix 'a' and the j-th column of matrix 'b'
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
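One way to read the word "layer": for a fixed `k`, the two inner loops of MLN add the product of column `k` of `a` and row `k` of `b` (an outer product, an l×n matrix) to the whole result, and the result is the sum of `m` such layers. A hedged sketch illustrating this reading; the class and method names are hypothetical, and the temporary layer matrix makes it slower than MLN itself:

```java
import java.util.Arrays;

class LayerDemo {
    // hypothetical helper: the k-th layer, i.e. the outer product of
    // column k of 'a' (length l) and row k of 'b' (length n)
    static int[][] layer(int[][] a, int[][] b, int k) {
        int l = a.length, n = b[0].length;
        int[][] result = new int[l][n];
        for (int i = 0; i < l; i++)
            for (int j = 0; j < n; j++)
                result[i][j] = a[i][k] * b[k][j];
        return result;
    }

    // sums the m layers; each c[i][j] receives the same terms in the same order as in MLN
    static int[][] sumOfLayers(int[][] a, int[][] b) {
        int l = a.length, m = b.length, n = b[0].length;
        int[][] c = new int[l][n];
        for (int k = 0; k < m; k++) {
            int[][] current = layer(a, b, k);
            for (int i = 0; i < l; i++)
                for (int j = 0; j < n; j++)
                    c[i][j] += current[i][j];
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};
        System.out.println(Arrays.deepToString(layer(a, b, 0)));   // [[7, 8], [28, 32]]
        System.out.println(Arrays.deepToString(sumOfLayers(a, b))); // [[58, 64], [139, 154]]
    }
}
```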

Other algorithms

In the remaining four permutations, the loop over the columns of the second matrix (N) is nested outside the loop over the common dimension (M), outside the loop over the rows of the first matrix (L), or outside both.

Code without comments
public static int[][] matrixMultiplicationLNM(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int i = 0; i < l; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
public static int[][] matrixMultiplicationNLM(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int i = 0; i < l; i++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
public static int[][] matrixMultiplicationMNL(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int k = 0; k < m; k++)
        for (int j = 0; j < n; j++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
public static int[][] matrixMultiplicationNML(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int k = 0; k < m; k++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
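In LNM and NLM the innermost loop runs over `k`, so each step reads `b[k][j]` from a different row array of `b`. A commonly used remedy, which is not one of the six permutations and was not benchmarked in this article, is to transpose `b` once so that the inner loop reads both operands along rows. A hedged sketch; `transpose` and `matrixMultiplicationLNMTransposed` are hypothetical names:

```java
import java.util.Arrays;

class TransposedDemo {
    // hypothetical helper: returns the transpose of an m×n matrix as an n×m matrix
    static int[][] transpose(int[][] b) {
        int m = b.length, n = b[0].length;
        int[][] t = new int[n][m];
        for (int k = 0; k < m; k++)
            for (int j = 0; j < n; j++)
                t[j][k] = b[k][j];
        return t;
    }

    // LNM loop order, but column j of 'b' is read as row j of its transpose
    static int[][] matrixMultiplicationLNMTransposed(int l, int m, int n, int[][] a, int[][] b) {
        int[][] bt = transpose(b);
        int[][] c = new int[l][n];
        for (int i = 0; i < l; i++)
            for (int j = 0; j < n; j++) {
                int sum = 0; // dot product of row i of 'a' and row j of 'bt'
                for (int k = 0; k < m; k++)
                    sum += a[i][k] * bt[j][k];
                c[i][j] = sum;
            }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};
        System.out.println(Arrays.deepToString(matrixMultiplicationLNMTransposed(2, 3, 2, a, b)));
        // prints [[58, 64], [139, 154]]
    }
}
```

The transposition costs one extra pass over `b` and an extra m×n array; whether it pays off for a given size should be checked with the same benchmark.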

Comparing algorithms

To check, we take two matrices A=[500×700] and B=[700×450] filled with random numbers. First, we verify that the algorithms are implemented correctly: all of them must return the same result. Then we run each method 10 times and calculate the average execution time.

// start the program and output the result
public static void main(String[] args) throws Exception {
    // incoming data
    int l = 500, m = 700, n = 450, steps = 10;
    int[][] a = randomMatrix(l, m), b = randomMatrix(m, n);
    // map of methods for comparison
    var methods = new TreeMap<String, Callable<int[][]>>(Map.of(
            "LMN", () -> matrixMultiplicationLMN(l, m, n, a, b),
            "LNM", () -> matrixMultiplicationLNM(l, m, n, a, b),
            "MLN", () -> matrixMultiplicationMLN(l, m, n, a, b),
            "MNL", () -> matrixMultiplicationMNL(l, m, n, a, b),
            "NLM", () -> matrixMultiplicationNLM(l, m, n, a, b),
            "NML", () -> matrixMultiplicationNML(l, m, n, a, b)));
    int[][] last = null;
    // iterate over the methods map and check the correctness of the
    // returned results: all of them must be equal to each other
    for (var method : methods.entrySet()) {
        // next method for comparison
        var next = methods.higherEntry(method.getKey());
        // if the current method is not the last one, compare its result with the next one
        if (next != null) System.out.println(method.getKey() + "=" + next.getKey() + ": "
                // compare the result of executing the current method and the next one
                + Arrays.deepEquals(method.getValue().call(), next.getValue().call()));
        // otherwise keep the result of the last method as the reference
        else last = method.getValue().call();
    }
    int[][] test = last;
    // iterate over the methods map and measure the execution time of each method
    for (var method : methods.entrySet())
        // parameters: title, number of steps, runnable code
        benchmark(method.getKey(), steps, () -> {
            try { // execute the method, get the result
                int[][] result = method.getValue().call();
                // check the correctness of the results at each step
                if (!Arrays.deepEquals(result, test)) System.out.print("error");
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
}
Helper methods
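// helper method, returns a matrix of the specified size filled with random numbers;
// note: products of such values can exceed the int range, and Java int arithmetic
// silently wraps around (modulo 2^32), so the results are not the true products,
// but they are deterministic and identical for all six loop orders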
private static int[][] randomMatrix(int row, int col) {
    int[][] matrix = new int[row][col];
    for (int i = 0; i < row; i++)
        for (int j = 0; j < col; j++)
            matrix[i][j] = (int) (Math.random() * row * col);
    return matrix;
}
// helper method for measuring the execution time of the passed code
private static void benchmark(String title, int steps, Runnable runnable) {
    long time, avg = 0;
    System.out.print(title);
    for (int i = 0; i < steps; i++) {
        time = System.currentTimeMillis();
        runnable.run();
        time = System.currentTimeMillis() - time;
        // execution time of one step
        System.out.print(" | " + time);
        avg += time;
    }
    // average execution time
    System.out.println(" || " + (avg / steps));
}
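In the output below, the first LMN run (191 ms) is noticeably slower than the following ones, which is most likely a JIT warm-up effect. A hedged sketch of an alternative helper (the name `benchmarkWithWarmup` is hypothetical, not the author's code) that performs a few untimed runs first and measures elapsed time with `System.nanoTime()`, which is intended for measuring intervals rather than wall-clock time. For rigorous measurements a dedicated harness such as JMH is the usual choice.

```java
class WarmupBenchmark {
    // hypothetical alternative to benchmark(): untimed warm-up runs first,
    // then returns (and prints) the average of the timed runs in milliseconds
    static double benchmarkWithWarmup(String title, int warmups, int steps, Runnable runnable) {
        for (int i = 0; i < warmups; i++)
            runnable.run(); // warm-up runs are not timed
        long total = 0;
        for (int i = 0; i < steps; i++) {
            long start = System.nanoTime();
            runnable.run();
            total += System.nanoTime() - start;
        }
        double avgMillis = total / 1e6 / steps;
        System.out.printf("%s || %.1f ms%n", title, avgMillis);
        return avgMillis;
    }

    public static void main(String[] args) {
        int[] counter = {0};
        benchmarkWithWarmup("demo", 3, 10, () -> counter[0]++);
        System.out.println(counter[0]); // prints 13 (3 warm-up runs + 10 timed runs)
    }
}
```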

The output depends on the execution environment; times are in milliseconds:

LMN=LNM: true
LNM=MLN: true
MLN=MNL: true
MNL=NLM: true
NLM=NML: true
LMN | 191 | 109 | 105 | 106 | 105 | 106 | 106 | 105 | 123 | 109 || 116
LNM | 417 | 418 | 419 | 416 | 416 | 417 | 418 | 417 | 416 | 417 || 417
MLN | 113 | 115 | 113 | 115 | 114 | 114 | 114 | 115 | 114 | 113 || 114
MNL | 857 | 864 | 857 | 859 | 860 | 863 | 862 | 860 | 858 | 860 || 860
NLM | 404 | 404 | 407 | 404 | 406 | 405 | 405 | 404 | 403 | 404 || 404
NML | 866 | 872 | 867 | 868 | 867 | 868 | 867 | 873 | 869 | 863 || 868
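The timings fall into three pairs, and one plausible explanation is the access pattern of the innermost loop. Each row of a Java `int[][]` is a separate contiguous array, so an access is sequential when the innermost index is the last index of the array, constant when the array does not depend on the innermost index, and strided (a different row array on every step) otherwise. A small hedged sketch (the class name is hypothetical) that classifies the three accesses for every order:

```java
import java.util.List;

class AccessPatterns {
    // classifies an access arr[row][col] for the given innermost loop letter
    static String pattern(char row, char col, char inner) {
        if (inner == col) return "sequential";
        if (inner == row) return "strided";
        return "constant";
    }

    // describes a[i][k] (L, M), b[k][j] (M, N) and c[i][j] (L, N)
    // for a loop order such as "LMN", whose last letter is the innermost loop
    static String describe(String order) {
        char inner = order.charAt(2);
        return order + ": a " + pattern('L', 'M', inner)
                + ", b " + pattern('M', 'N', inner)
                + ", c " + pattern('L', 'N', inner);
    }

    public static void main(String[] args) {
        for (String order : List.of("LMN", "MLN", "LNM", "NLM", "MNL", "NML"))
            System.out.println(describe(order));
        // LMN: a constant, b sequential, c sequential
        // MLN: a constant, b sequential, c sequential
        // LNM: a sequential, b strided, c constant
        // NLM: a sequential, b strided, c constant
        // MNL: a strided, b constant, c strided
        // NML: a strided, b constant, c strided
    }
}
```

This classification is consistent with the measured grouping: LMN and MLN (about 115 ms), LNM and NLM (about 410 ms), MNL and NML (about 865 ms), although the actual numbers depend on the hardware and the JVM.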

All the methods described above, including the uncommented variants and the helper methods, can be placed in a single class.

Required imports
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;

© Golovin G.G., Code with comments, translation from Russian, 2021