Consider a modification of Strassen’s algorithm for square matrix multiplication that uses fewer summations between blocks than the ordinary algorithm (15 instead of 18) and the same number of block multiplications (7). We will use Java Streams.
Recursive partitioning of matrices into blocks pays off only up to a certain limit, below which it stops being worthwhile, since Strassen’s algorithm makes poor use of the cache of the execution environment. Therefore, for small blocks we use a parallel version of nested loops, and large blocks we partition recursively, also in parallel.
We determine the boundary between the two algorithms experimentally, tuning it to the cache of the execution environment. The benefit of Strassen’s algorithm becomes more evident on sizable matrices: the gap to the nested-loops algorithm grows with the matrix size and depends on the execution environment. Let’s compare the running time of the two algorithms.
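As a standard point of reference (not derived in the article): with 7 multiplications of half-sized blocks the running time satisfies T(n) = 7·T(n/2) + O(n²), which resolves to O(n^log₂7) ≈ O(n^2.81), versus O(n³) for plain nested loops; reducing the block summations from 18 to 15 improves only the constant factor, not the exponent.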
The algorithm using three nested loops is covered separately in the article “Optimizing matrix multiplication”.
The matrices must be the same size. We partition each matrix into 4 equally sized blocks. The blocks must be square, so if they are not, we first pad the matrices with zero rows and columns and only then partition them into blocks; the redundant rows and columns are removed from the resulting matrix at the end. For example, for n = 5 the block size is m = 3: the block a11 is a full 3×3 block, while a12, a21 and a22 are padded with zero columns and rows up to 3×3.
Summation of blocks:
s1 = a21 + a22,  s2 = s1 - a11,  s3 = a11 - a21,  s4 = a12 - s2,
s5 = b12 - b11,  s6 = b22 - s5,  s7 = b22 - b12,  s8 = s6 - b21.
Multiplication of blocks:
p1 = s2·s6,  p2 = a11·b11,  p3 = a12·b21,  p4 = s3·s7,  p5 = s1·s5,  p6 = s4·b22,  p7 = a22·s8.
Summation of blocks:
t1 = p1 + p2,  t2 = t1 + p4.
Blocks of the resulting matrix:
c11 = p2 + p3,  c12 = t1 + p5 + p6,  c21 = t2 - p7,  c22 = t2 + p5.
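To make the scheme concrete, here is a minimal check of these formulas (my own sketch, not from the article) on scalar blocks, that is, a 2×2 matrix split into 1×1 blocks; this variant of Strassen’s algorithm is commonly attributed to Winograd.
// scalar check of the 15 summations and 7 multiplications
int a11 = 1, a12 = 2, a21 = 3, a22 = 4;
int b11 = 5, b12 = 6, b21 = 7, b22 = 8;
// 8 summations over the blocks of 'a' and 'b'
int s1 = a21 + a22, s2 = s1 - a11, s3 = a11 - a21, s4 = a12 - s2;
int s5 = b12 - b11, s6 = b22 - s5, s7 = b22 - b12, s8 = s6 - b21;
// 7 multiplications
int p1 = s2 * s6, p2 = a11 * b11, p3 = a12 * b21, p4 = s3 * s7;
int p5 = s1 * s5, p6 = s4 * b22, p7 = a22 * s8;
// 7 more summations, 15 in total
int t1 = p1 + p2, t2 = t1 + p4;
int c11 = p2 + p3, c12 = t1 + p5 + p6, c21 = t2 - p7, c22 = t2 + p5;
// compare with the ordinary 2×2 product
System.out.println(c11 == a11 * b11 + a12 * b21); // true
System.out.println(c12 == a11 * b12 + a12 * b22); // true
System.out.println(c21 == a21 * b11 + a22 * b21); // true
System.out.println(c22 == a21 * b12 + a22 * b22); // true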
We partition each of the matrices A and B into 4 equally sized blocks and, if necessary, pad the missing parts with zeros. We then perform 15 summations and 7 multiplications over the blocks and obtain the 4 blocks of the matrix C. Finally, we remove the redundant zeros, if any were added, and return the resulting matrix.
We run recursive partitioning of large blocks in parallel mode, and for small blocks we call the
algorithm with nested loops.
/**
 * @param n matrix size
 * @param brd minimum block size, the boundary below which nested loops are used
 * @param a first matrix 'n×n'
 * @param b second matrix 'n×n'
 * @return resulting matrix 'n×n'
 */
public static int[][] multiplyMatrices(int n, int brd, int[][] a, int[][] b) {
// multiply small blocks using algorithm with nested loops
if (n < brd) return simpleMultiplication(n, a, b);
// midpoint of the matrix, round up — blocks should
// be square, if necessary add zero rows and columns
int m = n - n / 2;
// blocks of the first matrix
int[][] a11 = getQuadrant(m, n, a, true, true);
int[][] a12 = getQuadrant(m, n, a, true, false);
int[][] a21 = getQuadrant(m, n, a, false, true);
int[][] a22 = getQuadrant(m, n, a, false, false);
// blocks of the second matrix
int[][] b11 = getQuadrant(m, n, b, true, true);
int[][] b12 = getQuadrant(m, n, b, true, false);
int[][] b21 = getQuadrant(m, n, b, false, true);
int[][] b22 = getQuadrant(m, n, b, false, false);
// summation of blocks
int[][] s1 = sumMatrices(m, a21, a22, true);
int[][] s2 = sumMatrices(m, s1, a11, false);
int[][] s3 = sumMatrices(m, a11, a21, false);
int[][] s4 = sumMatrices(m, a12, s2, false);
int[][] s5 = sumMatrices(m, b12, b11, false);
int[][] s6 = sumMatrices(m, b22, s5, false);
int[][] s7 = sumMatrices(m, b22, b12, false);
int[][] s8 = sumMatrices(m, s6, b21, false);
int[][][] p = new int[7][][];
// multiplication of blocks in parallel streams
IntStream.range(0, 7).parallel().forEach(i -> {
switch (i) { // recursive calls
case 0 -> p[i] = multiplyMatrices(m, brd, s2, s6);
case 1 -> p[i] = multiplyMatrices(m, brd, a11, b11);
case 2 -> p[i] = multiplyMatrices(m, brd, a12, b21);
case 3 -> p[i] = multiplyMatrices(m, brd, s3, s7);
case 4 -> p[i] = multiplyMatrices(m, brd, s1, s5);
case 5 -> p[i] = multiplyMatrices(m, brd, s4, b22);
case 6 -> p[i] = multiplyMatrices(m, brd, a22, s8);
}
});
// summation of blocks
int[][] t1 = sumMatrices(m, p[0], p[1], true);
int[][] t2 = sumMatrices(m, t1, p[3], true);
// blocks of the resulting matrix
int[][] c11 = sumMatrices(m, p[1], p[2], true);
int[][] c12 = sumMatrices(m, t1, sumMatrices(m, p[4], p[5], true), true);
int[][] c21 = sumMatrices(m, t2, p[6], false);
int[][] c22 = sumMatrices(m, t2, p[4], true);
// assemble a matrix from blocks,
// remove zero rows and columns, if added
return putQuadrants(m, n, c11, c12, c21, c22);
}
// helper method for matrix summation
private static int[][] sumMatrices(int n, int[][] a, int[][] b, boolean sign) {
int[][] c = new int[n][n];
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
c[i][j] = sign ? a[i][j] + b[i][j] : a[i][j] - b[i][j];
return c;
}
// helper method, gets a block of a matrix
private static int[][] getQuadrant(int m, int n, int[][] x,
boolean first, boolean second) {
int[][] q = new int[m][m];
if (first) for (int i = 0; i < m; i++)
if (second) System.arraycopy(x[i], 0, q[i], 0, m); // x11
else System.arraycopy(x[i], m, q[i], 0, n - m); // x12
else for (int i = m; i < n; i++)
if (second) System.arraycopy(x[i], 0, q[i - m], 0, m); // x21
else System.arraycopy(x[i], m, q[i - m], 0, n - m); // x22
return q;
}
// helper method, assembles a matrix from blocks
private static int[][] putQuadrants(int m, int n,
int[][] x11, int[][] x12,
int[][] x21, int[][] x22) {
int[][] x = new int[n][n];
for (int i = 0; i < n; i++)
if (i < m) {
System.arraycopy(x11[i], 0, x[i], 0, m);
System.arraycopy(x12[i], 0, x[i], m, n - m);
} else {
System.arraycopy(x21[i - m], 0, x[i], 0, m);
System.arraycopy(x22[i - m], 0, x[i], m, n - m);
}
return x;
}
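A minimal usage sketch (mine, not from the article), assuming the methods above live in the same class: multiplying two 3×3 matrices exercises the zero-padding path in getQuadrant and putQuadrants, since the blocks of an odd-sized matrix are padded to square, and brd is set below n to force at least one recursive partition.
// multiply two small odd-sized matrices to exercise the zero padding
int[][] a = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
int[][] b = {{9, 8, 7}, {6, 5, 4}, {3, 2, 1}};
int[][] c = multiplyMatrices(3, 2, a, b); // n = 3, brd = 2
System.out.println(Arrays.deepToString(c)); // [[30, 24, 18], [84, 69, 54], [138, 114, 90]]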
To complement the previous algorithm and to compare against it, we take an optimized variant of nested loops that uses the cache of the execution environment better than the alternatives: the rows of the resulting matrix are processed independently of each other in parallel streams. Small matrices are multiplied with this algorithm directly; large matrices are first partitioned into small blocks, and the same algorithm is then applied to the blocks.
/**
* @param n matrix size
* @param a first matrix 'n×n'
* @param b second matrix 'n×n'
* @return resulting matrix 'n×n'
*/
public static int[][] simpleMultiplication(int n, int[][] a, int[][] b) {
// the resulting matrix
int[][] c = new int[n][n];
// bypass the rows of matrix 'a' in parallel mode
IntStream.range(0, n).parallel().forEach(i -> {
// bypass the indexes of the common side of two matrices:
// the columns of matrix 'a' and the rows of matrix 'b'
for (int k = 0; k < n; k++)
// bypass the indexes of the columns of matrix 'b'
for (int j = 0; j < n; j++)
// the sum of the products of the elements of the i-th
// row of matrix 'a' and the j-th column of matrix 'b'
c[i][j] += a[i][k] * b[k][j];
});
return c;
}
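For contrast, here is a sketch of the textbook i-j-k loop order (not used in the article; the method name is mine). Its innermost access b[k][j] walks down a column of 'b', jumping a whole row ahead in memory on every step, whereas the i-k-j order above scans the rows b[k] and c[i] sequentially and hits the cache far more often.
// naive i-j-k order, shown only for comparison with the i-k-j order above
private static int[][] naiveMultiplication(int n, int[][] a, int[][] b) {
    int[][] c = new int[n][n];
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < n; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}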
To check, we take two square matrices A=[1000×1000] and B=[1000×1000] filled with random numbers, and set the minimum block size to [200×200] elements. First, we verify the correctness of the two implementations: the matrix products must match. Then we execute each method 10 times and calculate the average execution time.
// start the program and output the result
public static void main(String[] args) {
// incoming data
int n = 1000, brd = 200, steps = 10;
int[][] a = randomMatrix(n, n), b = randomMatrix(n, n);
// matrix products
int[][] c1 = multiplyMatrices(n, brd, a, b);
int[][] c2 = simpleMultiplication(n, a, b);
// check the correctness of the results
System.out.println("The results match: " + Arrays.deepEquals(c1, c2));
// measure the execution time of two methods
benchmark("Hybrid algorithm", steps, () -> {
int[][] c = multiplyMatrices(n, brd, a, b);
if (!Arrays.deepEquals(c, c1)) System.out.print("error");
});
benchmark("Nested loops ", steps, () -> {
int[][] c = simpleMultiplication(n, a, b);
if (!Arrays.deepEquals(c, c2)) System.out.print("error");
});
}
// helper method, returns a matrix of the specified size filled with random numbers
private static int[][] randomMatrix(int row, int col) {
int[][] matrix = new int[row][col];
for (int i = 0; i < row; i++)
for (int j = 0; j < col; j++)
matrix[i][j] = (int) (Math.random() * row * col);
return matrix;
}
// helper method for measuring the execution time of the passed code
private static void benchmark(String title, int steps, Runnable runnable) {
long time, avg = 0;
System.out.print(title);
for (int i = 0; i < steps; i++) {
time = System.currentTimeMillis();
runnable.run();
time = System.currentTimeMillis() - time;
// execution time of one step
System.out.print(" | " + time);
avg += time;
}
// average execution time
System.out.println(" || " + (avg / steps));
}
Output depends on the execution environment, time in milliseconds:
The results match: true
Hybrid algorithm | 196 | 177 | 156 | 205 | 154 | 165 | 133 | 118 | 132 | 134 || 157
Nested loops | 165 | 164 | 168 | 167 | 168 | 168 | 170 | 179 | 173 | 168 || 169
On an eight-core Linux x64 computer, we execute the above test 100 times instead of 10, keep the minimum block size at [brd=200] elements and change only n, the size of both matrices A=[n×n] and B=[n×n]. We get a summary table of results; time in milliseconds.
n | 900 | 1000 | 1100 | 1200 | 1300 | 1400 | 1500 | 1600 | 1700 |
-----------------|-----|------|------|------|------|------|------|------|------|
Hybrid algorithm | 96 | 125 | 169 | 204 | 260 | 313 | 384 | 482 | 581 |
Nested loops | 119 | 162 | 235 | 281 | 361 | 497 | 651 | 793 | 971 |
Results: the benefit of Strassen’s algorithm becomes more evident on large matrices, when the matrix itself is several times larger than the minimal block, and the exact gain depends on the execution environment.
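The article does not show how the boundary is picked; a minimal sketch, assuming the methods above are in the same class, could simply benchmark several candidate values of brd and keep the fastest one for the current machine.
// try several boundary values and compare the timings
int n = 1000, steps = 10;
int[][] a = randomMatrix(n, n), b = randomMatrix(n, n);
for (int brd : new int[]{100, 150, 200, 250, 300})
    benchmark("brd = " + brd, steps, () -> multiplyMatrices(n, brd, a, b));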
All the methods described above, including those shown in the collapsed code blocks, can be placed in a single class with the following imports.
import java.util.Arrays;
import java.util.stream.IntStream;
© Golovin G.G., Code with comments, translation from Russian, 2022