ГЛАВНАЯ RU typewriter

older-tomato

Оптимизация умножения матриц

Перестановки • Вложенные циклы • Сравнение алгоритмов 09.12.2021

Рассмотрим алгоритм перемножения матриц с использованием трёх вложенных циклов. Сложность такого алгоритма по определению должна составлять O(n³), но есть особенности, связанные со средой выполнения — скорость работы алгоритма зависит от последовательности, в которой выполняются циклы.

Сравним различные варианты перестановок вложенных циклов и время выполнения алгоритмов. Возьмём две матрицы: {L×M} и {M×N} → три цикла → шесть перестановок: LMN, LNM, MLN, MNL, NLM, NML.

Быстрее других отрабатывают те алгоритмы, которые пишут данные в результирующую матрицу построчно слоями: LMN и MLN, — разница в процентах к другим алгоритмам значительная и зависит от среды выполнения.

Дальнейшая оптимизация: Умножение матриц в параллельных потоках.

Построчный алгоритм #

Внешний цикл обходит строки первой матрицы L, далее идёт цикл по общей стороне двух матриц M и за ним цикл по колонкам второй матрицы N. Запись в результирующую матрицу происходит построчно, а каждая строка заполняется слоями.

/**
 * @param l строки матрицы 'a'
 * @param m колонки матрицы 'a'
 *          и строки матрицы 'b'
 * @param n колонки матрицы 'b'
 * @param a первая матрица 'l×m'
 * @param b вторая матрица 'm×n'
 * @return результирующая матрица 'l×n'
 */
public static int[][] matrixMultiplicationLMN(int l, int m, int n, int[][] a, int[][] b) {
    // результирующая матрица
    int[][] c = new int[l][n];
    // обходим индексы строк матрицы 'a'
    for (int i = 0; i < l; i++)
        // обходим индексы общей стороны двух матриц:
        // колонок матрицы 'a' и строк матрицы 'b'
        for (int k = 0; k < m; k++)
            // обходим индексы колонок матрицы 'b'
            for (int j = 0; j < n; j++)
                // сумма произведений элементов i-ой строки
                // матрицы 'a' и j-ой колонки матрицы 'b'
                c[i][j] += a[i][k] * b[k][j];
    return c;
}

Послойный алгоритм #

Внешний цикл обходит общую сторону двух матриц M, далее идёт цикл по строкам первой матрицы L и за ним цикл по колонкам второй матрицы N. Запись в результирующую матрицу происходит слоями, а каждый слой заполняется построчно.

/**
 * @param l строки матрицы 'a'
 * @param m колонки матрицы 'a'
 *          и строки матрицы 'b'
 * @param n колонки матрицы 'b'
 * @param a первая матрица 'l×m'
 * @param b вторая матрица 'm×n'
 * @return результирующая матрица 'l×n'
 */
public static int[][] matrixMultiplicationMLN(int l, int m, int n, int[][] a, int[][] b) {
    // результирующая матрица
    int[][] c = new int[l][n];
    // обходим индексы общей стороны двух матриц:
    // колонок матрицы 'a' и строк матрицы 'b'
    for (int k = 0; k < m; k++)
        // обходим индексы строк матрицы 'a'
        for (int i = 0; i < l; i++)
            // обходим индексы колонок матрицы 'b'
            for (int j = 0; j < n; j++)
                // сумма произведений элементов i-ой строки
                // матрицы 'a' и j-ой колонки матрицы 'b'
                c[i][j] += a[i][k] * b[k][j];
    return c;
}

Прочие алгоритмы #

Обход колонок второй матрицы N происходит перед обходом общей стороны двух матриц M и/или перед обходом строк первой матрицы L.

Код без комментариев
public static int[][] matrixMultiplicationLNM(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int i = 0; i < l; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
public static int[][] matrixMultiplicationNLM(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int i = 0; i < l; i++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
public static int[][] matrixMultiplicationMNL(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int k = 0; k < m; k++)
        for (int j = 0; j < n; j++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
public static int[][] matrixMultiplicationNML(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int k = 0; k < m; k++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}

Сравнение алгоритмов #

Для проверки возьмём две матрицы A=[500×700] и B=[700×450], заполненные случайными числами. Сначала сравниваем между собой корректность реализации алгоритмов — все полученные результаты должны совпадать. Затем выполняем каждый метод по 10 раз и подсчитываем среднее время выполнения.

// запускаем программу и выводим результат
public static void main(String[] args) throws Exception {
    // входящие данные
    int l = 500, m = 700, n = 450, steps = 10;
    int[][] a = randomMatrix(l, m), b = randomMatrix(m, n);
    // карта методов для сравнения
    var methods = new TreeMap<String, Callable<int[][]>>(Map.of(
            "LMN", () -> matrixMultiplicationLMN(l, m, n, a, b),
            "LNM", () -> matrixMultiplicationLNM(l, m, n, a, b),
            "MLN", () -> matrixMultiplicationMLN(l, m, n, a, b),
            "MNL", () -> matrixMultiplicationMNL(l, m, n, a, b),
            "NLM", () -> matrixMultiplicationNLM(l, m, n, a, b),
            "NML", () -> matrixMultiplicationNML(l, m, n, a, b)));
    int[][] last = null;
    // обходим карту методов, проверяем корректность результатов,
    // все полученные результаты должны быть равны друг другу
    for (var method : methods.entrySet()) {
        // следующий метод для сравнения
        var next = methods.higherEntry(method.getKey());
        // если текущий метод не последний — сравниваем результаты двух методов
        if (next != null) System.out.println(method.getKey() + "=" + next.getKey() + ": "
                // сравниваем результат выполнения текущего метода и следующего за ним
                + Arrays.deepEquals(method.getValue().call(), next.getValue().call()));
            // результат выполнения последнего метода
        else last = method.getValue().call();
    }
    int[][] test = last;
    // обходим карту методов, замеряем время работы каждого метода
    for (var method : methods.entrySet())
        // параметры: заголовок, количество шагов, исполняемый код 
        benchmark(method.getKey(), steps, () -> {
            try { // выполняем метод, получаем результат
                int[][] result = method.getValue().call();
                // проверяем корректность результатов на каждом шаге
                if (!Arrays.deepEquals(result, test)) System.out.print("error");
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
}
Вспомогательные методы
// вспомогательный метод, возвращает матрицу указанного размера
private static int[][] randomMatrix(int row, int col) {
    int[][] matrix = new int[row][col];
    for (int i = 0; i < row; i++)
        for (int j = 0; j < col; j++)
            matrix[i][j] = (int) (Math.random() * row * col);
    return matrix;
}
// вспомогательный метод для замера времени работы переданного кода
private static void benchmark(String title, int steps, Runnable runnable) {
    long time, avg = 0;
    System.out.print(title);
    for (int i = 0; i < steps; i++) {
        time = System.currentTimeMillis();
        runnable.run();
        time = System.currentTimeMillis() - time;
        // время выполнения одного шага
        System.out.print(" | " + time);
        avg += time;
    }
    // среднее время выполнения
    System.out.println(" || " + (avg / steps));
}

Вывод зависит от среды выполнения, время в миллисекундах:

LMN=LNM: true
LNM=MLN: true
MLN=MNL: true
MNL=NLM: true
NLM=NML: true
LMN | 191 | 109 | 105 | 106 | 105 | 106 | 106 | 105 | 123 | 109 || 116
LNM | 417 | 418 | 419 | 416 | 416 | 417 | 418 | 417 | 416 | 417 || 417
MLN | 113 | 115 | 113 | 115 | 114 | 114 | 114 | 115 | 114 | 113 || 114
MNL | 857 | 864 | 857 | 859 | 860 | 863 | 862 | 860 | 858 | 860 || 860
NLM | 404 | 404 | 407 | 404 | 406 | 405 | 405 | 404 | 403 | 404 || 404
NML | 866 | 872 | 867 | 868 | 867 | 868 | 867 | 873 | 869 | 863 || 868

Все описанные выше методы, включая свёрнутые блоки, можно поместить в одном классе.

Необходимые импорты
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;

© Головин Г.Г., Код с комментариями, 2021