Winograd — Strassen algorithm

Multithreading • Block matrices • Comparing algorithms 11.02.2022

Consider a modification of Strassen’s algorithm for square matrix multiplication that needs fewer block summations than the ordinary algorithm, 15 instead of 18, with the same number of block multiplications, 7. We will use Java Streams.

Recursive partitioning of matrices into blocks during multiplication pays off only down to a certain block size. Below that size it stops making sense, since Strassen’s algorithm does not make good use of the cache of the execution environment. Therefore, for small blocks we use a parallel version of nested loops, and for large blocks we perform recursive partitioning in parallel.

We determine the boundary between the two algorithms experimentally, adjusting it to the cache of the execution environment. The benefit of Strassen’s algorithm becomes more evident on sizable matrices: the gap to the algorithm using nested loops grows with the matrix size and depends on the execution environment. Let’s compare the running time of the two algorithms.

Algorithm using three nested loops: Optimizing matrix multiplication.

Algorithm description

Matrices must be the same size. We partition each matrix into 4 equally sized blocks. The blocks must be square, so if the matrix size is odd, we first pad the matrices with a zero row and a zero column and only then partition them into blocks. We remove the redundant rows and columns from the resulting matrix afterwards.
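
The padding described above can be illustrated on a small matrix. The following is a minimal sketch, not the article's helper: the class name `QuadrantDemo` and the method `quadrant` are hypothetical and only show how a 3×3 matrix splits into four 2×2 blocks with zeros filling the missing cells.

```java
import java.util.Arrays;

// Illustrative sketch: splitting an odd-sized matrix into zero-padded
// square blocks. QuadrantDemo and quadrant are hypothetical names.
class QuadrantDemo {
    // returns the m×m block at block row 'br' and block column 'bc' (0 or 1),
    // cells that fall outside the n×n source matrix stay zero
    static int[][] quadrant(int[][] x, int br, int bc) {
        int n = x.length;
        int m = n - n / 2; // block size, rounded up
        int[][] q = new int[m][m];
        for (int i = 0; i < m; i++) {
            int row = br * m + i;
            if (row >= n) break; // padding row, remains zero
            int from = bc * m;
            System.arraycopy(x[row], from, q[i], 0, Math.min(m, n - from));
        }
        return q;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
        System.out.println(Arrays.deepToString(quadrant(a, 0, 0))); // [[1, 2], [4, 5]]
        System.out.println(Arrays.deepToString(quadrant(a, 0, 1))); // [[3, 0], [6, 0]]
        System.out.println(Arrays.deepToString(quadrant(a, 1, 0))); // [[7, 8], [0, 0]]
        System.out.println(Arrays.deepToString(quadrant(a, 1, 1))); // [[9, 0], [0, 0]]
    }
}
```

Only the right column of blocks and the bottom row of blocks receive padding, so the extra zero row and column are easy to cut off when the result is assembled.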

{\displaystyle A={\begin{pmatrix}A_{11}&A_{12}\\A_{21}&A_{22}\end{pmatrix}},\quad B={\begin{pmatrix}B_{11}&B_{12}\\B_{21}&B_{22}\end{pmatrix}}.}

Summation of blocks.

{\displaystyle{\begin{aligned}S_{1}&=(A_{21}+A_{22});\\S_{2}&=(S_{1}-A_{11});\\S_{3}&=(A_{11}-A_{21});\\S_{4}&=(A_{12}-S_{2});\\S_{5}&=(B_{12}-B_{11});\\S_{6}&=(B_{22}-S_{5});\\S_{7}&=(B_{22}-B_{12});\\S_{8}&=(S_{6}-B_{21}).\end{aligned}}}

Multiplication of blocks.

{\displaystyle{\begin{aligned}P_{1}&=S_{2}S_{6};\\P_{2}&=A_{11}B_{11};\\P_{3}&=A_{12}B_{21};\\P_{4}&=S_{3}S_{7};\\P_{5}&=S_{1}S_{5};\\P_{6}&=S_{4}B_{22};\\P_{7}&=A_{22}S_{8}.\end{aligned}}}

Summation of blocks.

{\displaystyle{\begin{aligned}T_{1}&=P_{1}+P_{2};\\T_{2}&=T_{1}+P_{4}.\end{aligned}}}

Blocks of the resulting matrix.

{\displaystyle{\begin{pmatrix}C_{11}&C_{12}\\C_{21}&C_{22}\end{pmatrix}}={\begin{pmatrix}P_{2}+P_{3}&T_{1}+P_{5}+P_{6}\\T_{2}-P_{7}&T_{2}+P_{5}\end{pmatrix}}.}
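
The formulas can be checked by hand on a 2×2 matrix, where every block is a single number. The sketch below is only an illustration under that assumption; the class `WinogradCheck` and its methods are hypothetical names and are not part of the article's code.

```java
import java.util.Arrays;

// Illustrative check of the formulas above on 2×2 matrices,
// where each block A11..B22 is a single number.
class WinogradCheck {
    // 7 multiplications and 15 additions or subtractions
    static long[][] winograd2x2(long[][] a, long[][] b) {
        long a11 = a[0][0], a12 = a[0][1], a21 = a[1][0], a22 = a[1][1];
        long b11 = b[0][0], b12 = b[0][1], b21 = b[1][0], b22 = b[1][1];
        // first group of summations
        long s1 = a21 + a22, s2 = s1 - a11, s3 = a11 - a21, s4 = a12 - s2;
        long s5 = b12 - b11, s6 = b22 - s5, s7 = b22 - b12, s8 = s6 - b21;
        // multiplications
        long p1 = s2 * s6, p2 = a11 * b11, p3 = a12 * b21, p4 = s3 * s7;
        long p5 = s1 * s5, p6 = s4 * b22, p7 = a22 * s8;
        // second group of summations
        long t1 = p1 + p2, t2 = t1 + p4;
        return new long[][]{{p2 + p3, t1 + p5 + p6}, {t2 - p7, t2 + p5}};
    }

    // ordinary definition of the 2×2 product, 8 multiplications
    static long[][] direct2x2(long[][] a, long[][] b) {
        long[][] c = new long[2][2];
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                c[i][j] = a[i][0] * b[0][j] + a[i][1] * b[1][j];
        return c;
    }

    public static void main(String[] args) {
        long[][] a = {{1, 2}, {3, 4}}, b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(winograd2x2(a, b))); // [[19, 22], [43, 50]]
        System.out.println(Arrays.deepToString(direct2x2(a, b)));   // [[19, 22], [43, 50]]
    }
}
```

For real blocks the same expressions apply, except that the products are matrix products, so the order of the factors matters.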

Hybrid algorithm

We partition each of the matrices A and B into 4 equally sized blocks and, if necessary, pad the missing parts with zeros. We perform 15 summations and 7 multiplications over the blocks and get 4 blocks of the matrix C. Then we remove the redundant zeros, if any were added, and return the resulting matrix. Large blocks are partitioned recursively in parallel mode, and for small blocks we call the algorithm with nested loops.

/**
 * @param n   matrix size
 * @param brd boundary size, must be at least 2: smaller
 *            matrices are multiplied with nested loops
 * @param a   first matrix 'n×n'
 * @param b   second matrix 'n×n'
 * @return resulting matrix 'n×n'
 */
public static int[][] multiplyMatrices(int n, int brd, int[][] a, int[][] b) {
    // multiply small blocks using algorithm with nested loops
    if (n < brd) return simpleMultiplication(n, a, b);
    // midpoint of the matrix, round up — blocks should
    // be square, if necessary add zero rows and columns
    int m = n - n / 2;
    // blocks of the first matrix
    int[][] a11 = getQuadrant(m, n, a, true, true);
    int[][] a12 = getQuadrant(m, n, a, true, false);
    int[][] a21 = getQuadrant(m, n, a, false, true);
    int[][] a22 = getQuadrant(m, n, a, false, false);
    // blocks of the second matrix
    int[][] b11 = getQuadrant(m, n, b, true, true);
    int[][] b12 = getQuadrant(m, n, b, true, false);
    int[][] b21 = getQuadrant(m, n, b, false, true);
    int[][] b22 = getQuadrant(m, n, b, false, false);
    // summation of blocks
    int[][] s1 = sumMatrices(m, a21, a22, true);
    int[][] s2 = sumMatrices(m, s1, a11, false);
    int[][] s3 = sumMatrices(m, a11, a21, false);
    int[][] s4 = sumMatrices(m, a12, s2, false);
    int[][] s5 = sumMatrices(m, b12, b11, false);
    int[][] s6 = sumMatrices(m, b22, s5, false);
    int[][] s7 = sumMatrices(m, b22, b12, false);
    int[][] s8 = sumMatrices(m, s6, b21, false);
    int[][][] p = new int[7][][];
    // multiplication of blocks in parallel streams
    IntStream.range(0, 7).parallel().forEach(i -> {
        switch (i) { // recursive calls
            case 0 -> p[i] = multiplyMatrices(m, brd, s2, s6);
            case 1 -> p[i] = multiplyMatrices(m, brd, a11, b11);
            case 2 -> p[i] = multiplyMatrices(m, brd, a12, b21);
            case 3 -> p[i] = multiplyMatrices(m, brd, s3, s7);
            case 4 -> p[i] = multiplyMatrices(m, brd, s1, s5);
            case 5 -> p[i] = multiplyMatrices(m, brd, s4, b22);
            case 6 -> p[i] = multiplyMatrices(m, brd, a22, s8);
        }
    });
    // summation of blocks
    int[][] t1 = sumMatrices(m, p[0], p[1], true);
    int[][] t2 = sumMatrices(m, t1, p[3], true);
    // blocks of the resulting matrix
    int[][] c11 = sumMatrices(m, p[1], p[2], true);
    int[][] c12 = sumMatrices(m, t1, sumMatrices(m, p[4], p[5], true), true);
    int[][] c21 = sumMatrices(m, t2, p[6], false);
    int[][] c22 = sumMatrices(m, t2, p[4], true);
    // assemble a matrix from blocks,
    // remove zero rows and columns, if added
    return putQuadrants(m, n, c11, c12, c21, c22);
}

Helper methods

// helper method for matrix summation
private static int[][] sumMatrices(int n, int[][] a, int[][] b, boolean sign) {
    int[][] c = new int[n][n];
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            c[i][j] = sign ? a[i][j] + b[i][j] : a[i][j] - b[i][j];
    return c;
}
// helper method, gets a block of a matrix
private static int[][] getQuadrant(int m, int n, int[][] x,
                                   boolean first, boolean second) {
    int[][] q = new int[m][m];
    if (first) for (int i = 0; i < m; i++)
        if (second) System.arraycopy(x[i], 0, q[i], 0, m); // x11
        else System.arraycopy(x[i], m, q[i], 0, n - m); // x12
    else for (int i = m; i < n; i++)
        if (second) System.arraycopy(x[i], 0, q[i - m], 0, m); // x21
        else System.arraycopy(x[i], m, q[i - m], 0, n - m); // x22
    return q;
}
// helper method, assembles a matrix from blocks
private static int[][] putQuadrants(int m, int n,
                                    int[][] x11, int[][] x12,
                                    int[][] x21, int[][] x22) {
    int[][] x = new int[n][n];
    for (int i = 0; i < n; i++)
        if (i < m) {
            System.arraycopy(x11[i], 0, x[i], 0, m);
            System.arraycopy(x12[i], 0, x[i], m, n - m);
        } else {
            System.arraycopy(x21[i - m], 0, x[i], 0, m);
            System.arraycopy(x22[i - m], 0, x[i], m, n - m);
        }
    return x;
}

Nested loops

To complement the previous algorithm and to compare with it, we take the optimized variant of nested loops that uses the cache of the execution environment better than the others: rows of the resulting matrix are processed independently of each other in parallel streams. For small matrices we use this algorithm directly; large matrices we partition into small blocks and apply the same algorithm to them.

/**
 * @param n matrix size
 * @param a first matrix 'n×n'
 * @param b second matrix 'n×n'
 * @return resulting matrix 'n×n'
 */
public static int[][] simpleMultiplication(int n, int[][] a, int[][] b) {
    // the resulting matrix
    int[][] c = new int[n][n];
    // iterate over the rows of matrix 'a' in parallel mode
    IntStream.range(0, n).parallel().forEach(i -> {
        // iterate over the indexes of the common side of two matrices:
        // the columns of matrix 'a' and the rows of matrix 'b'
        for (int k = 0; k < n; k++)
            // iterate over the indexes of the columns of matrix 'b'
            for (int j = 0; j < n; j++)
                // the sum of the products of the elements of the i-th
                // row of matrix 'a' and the j-th column of matrix 'b'
                c[i][j] += a[i][k] * b[k][j];
    });
    return c;
}

Testing

To check, we take two square matrices A=[1000×1000] and B=[1000×1000] filled with random numbers and set the minimum block size to [200×200] elements. First we check that the two implementations agree: the matrix products must match. With entries this large the int products overflow, but Java int arithmetic wraps around in the same way in both algorithms, so the comparison stays valid. Then we execute each method 10 times and calculate the average execution time.

// start the program and output the result
public static void main(String[] args) {
    // incoming data
    int n = 1000, brd = 200, steps = 10;
    int[][] a = randomMatrix(n, n), b = randomMatrix(n, n);
    // matrix products
    int[][] c1 = multiplyMatrices(n, brd, a, b);
    int[][] c2 = simpleMultiplication(n, a, b);
    // check the correctness of the results
    System.out.println("The results match: " + Arrays.deepEquals(c1, c2));
    // measure the execution time of two methods
    benchmark("Hybrid algorithm", steps, () -> {
        int[][] c = multiplyMatrices(n, brd, a, b);
        if (!Arrays.deepEquals(c, c1)) System.out.print("error");
    });
    benchmark("Nested loops    ", steps, () -> {
        int[][] c = simpleMultiplication(n, a, b);
        if (!Arrays.deepEquals(c, c2)) System.out.print("error");
    });
}

Helper methods

// helper method, returns a matrix of the specified size
private static int[][] randomMatrix(int row, int col) {
    int[][] matrix = new int[row][col];
    for (int i = 0; i < row; i++)
        for (int j = 0; j < col; j++)
            matrix[i][j] = (int) (Math.random() * row * col);
    return matrix;
}
// helper method for measuring the execution time of the passed code
private static void benchmark(String title, int steps, Runnable runnable) {
    long time, avg = 0;
    System.out.print(title);
    for (int i = 0; i < steps; i++) {
        time = System.currentTimeMillis();
        runnable.run();
        time = System.currentTimeMillis() - time;
        // execution time of one step
        System.out.print(" | " + time);
        avg += time;
    }
    // average execution time
    System.out.println(" || " + (avg / steps));
}

Output depends on the execution environment, time in milliseconds:

The results match: true
Hybrid algorithm | 196 | 177 | 156 | 205 | 154 | 165 | 133 | 118 | 132 | 134 || 157
Nested loops     | 165 | 164 | 168 | 167 | 168 | 168 | 170 | 179 | 173 | 168 || 169
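
The first hybrid runs in the output above are slower than the later ones, which may be an effect of JIT compilation warming up. A possible variant of the timing helper, given here only as a sketch, runs a few untimed iterations first and measures with `System.nanoTime()`. The class name `WarmBenchmark`, the method `averageMillis` and the parameter `warmup` are assumptions for illustration and do not come from the original code.

```java
// Illustrative sketch of a timing helper with warm-up runs.
// WarmBenchmark, averageMillis and 'warmup' are hypothetical names.
class WarmBenchmark {
    // runs 'warmup' untimed iterations, then returns the average
    // time of 'steps' timed iterations in milliseconds
    static double averageMillis(int warmup, int steps, Runnable task) {
        for (int i = 0; i < warmup; i++) task.run();
        long total = 0;
        for (int i = 0; i < steps; i++) {
            long start = System.nanoTime();
            task.run();
            total += System.nanoTime() - start;
        }
        // nanoseconds to milliseconds, then average over the steps
        return total / 1e6 / steps;
    }

    public static void main(String[] args) {
        // sample workload: sum of square roots
        Runnable task = () -> {
            double sum = 0;
            for (int i = 0; i < 1_000_000; i++) sum += Math.sqrt(i);
            if (sum < 0) System.out.print("unreachable");
        };
        System.out.println("average, ms: " + averageMillis(5, 10, task));
    }
}
```

To time the methods of this article, the lambda would call `multiplyMatrices` or `simpleMultiplication` instead of the sample workload.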

Comparing algorithms

On an eight-core Linux x64 computer, we execute the above test 100 times instead of 10. The minimum block size stays at brd=200 elements. We change only n, the size of both matrices A=[n×n] and B=[n×n], and collect the results in a summary table. Time in milliseconds.

               n | 900 | 1000 | 1100 | 1200 | 1300 | 1400 | 1500 | 1600 | 1700 |
-----------------|-----|------|------|------|------|------|------|------|------|
Hybrid algorithm |  96 |  125 |  169 |  204 |  260 |  313 |  384 |  482 |  581 |
Nested loops     | 119 |  162 |  235 |  281 |  361 |  497 |  651 |  793 |  971 |

Results: the benefit of the Strassen algorithm becomes more evident on large matrices, when the matrix is several times larger than the minimal block. In the table above the hybrid algorithm takes about 81% of the nested-loop time at n=900 and about 60% at n=1700. The exact figures depend on the execution environment.
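
One way to see why the gap widens is to count scalar multiplications. The sketch below assumes the same splitting rule as the hybrid method (`m = n - n / 2`, recursion while `n >= brd`) and counts n³ multiplications for each block handled by nested loops. The class `MultiplicationCount` is a hypothetical helper. Block summations, block copies and memory traffic are ignored, so measured times need not follow these ratios.

```java
// Illustrative sketch: number of scalar multiplications in the hybrid
// scheme versus plain nested loops. MultiplicationCount is a hypothetical name.
class MultiplicationCount {
    // hybrid scheme: below the boundary a block of size n costs n^3,
    // otherwise 7 recursive products of half-size blocks; assumes brd >= 2
    static long hybrid(int n, int brd) {
        if (n < brd) return (long) n * n * n;
        int m = n - n / 2; // block size, rounded up
        return 7 * hybrid(m, brd);
    }

    // plain nested loops: n^3 multiplications
    static long nested(int n) {
        return (long) n * n * n;
    }

    public static void main(String[] args) {
        // 1000 -> 500 -> 250 -> 125, three levels of recursion
        System.out.println("hybrid: " + hybrid(1000, 200)); // hybrid: 669921875
        System.out.println("nested: " + nested(1000));      // nested: 1000000000
    }
}
```

For n=1000 and brd=200 the hybrid scheme performs about two thirds of the multiplications of the nested loops, and every additional level of recursion multiplies that share by 7/8.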

All the methods described above, including the helper methods, can be placed in one class.

Required imports
import java.util.Arrays;
import java.util.stream.IntStream;

© Golovin G.G., Code with comments, translation from Russian, 2022