Умножение троичных матриц для нейросетей

от автора

В статье «Как исследователи нарушают привычные подходы в ИИ, исключая матричное умножение» упоминалось, в частности, что перспективным кажется хранение в нейросетевых матрицах лишь троичных значений: (-1, 0, 1), иначе говоря — тритов. Такие матрицы умножать друг на друга проще. И в моей статье я расскажу, как собственно, матрицы из тритов хранить и умножать.

Как известно, при умножении матриц, мы строку левой матрицы умножаем на столбец правой, и результат записываем в соответствующую ячейку результирующей матрицы. Чтобы было быстрее, мы правую матрицу предварительно транспонируем: тогда строку левой будем умножать на строку правой. Иначе говоря, мы скалярно перемножаем два вектора из тритов, причём оба занимают непрерывную область памяти.

Предположим, что процессор ориентирован на 32-битную арифметику. Тогда разобьём строку матрицы на векторы по 32 трита. Каждый из этих векторов будем хранить в виде двух 32-битных целых чисел, назовём их «плюс-вектором» и «минус-вектором». Трит с номером N равен разности бита с номером N плюс-вектора, и бита с номером N минус-вектора. При этом нулевое значение трита кодируется двумя способами: когда оба бита равны 0, или когда они оба равны 1.

трит

-1

0

0

1

плюс-вектор

0

0

1

1

минус-вектор

1

0

1

0

Такой способ хранения позволяет быстро подсчитать сумму всех тритов вектора. Для этого мы из суммы битов плюс-вектора вычитаем сумму битов минус-вектора. Различные алгоритмы нахождения суммы битов описаны в статье «Обстоятельно о подсчёте единичных битов», к тому же процессор может поддерживать инструкцию POPCNT, которая эту сумму подсчитывает.
Например, можно использовать следующий алгоритм из вышеупомянутой статьи:

// Количество единичных битов unsigned __int32 popcnt(unsigned __int32 value) {     // Суммируем чётные и нечётные биты     //  00  01  10  11  value     //   0   0   1   1  (value >> 1) & 0x55555555     //  --  --  --  --  --     //  00  01  01  10  =     value -= (value >> 1) & 0x55555555;      // Повторяем уже для пар битов     value = ((value >> 2) & 0x33333333) + (value & 0x33333333);      // Умножение на 0x01010101 эквивалентно сумированию значений 4 байт числа,     // при условии, что в младших байтах не будет переполнения     // (а его в нашем случае не будет, так как там содержатся суммы битов).     // Результат сложения будет в старшем байте произведения.     return ((((value >> 4) + value) & 0x0F0F0F0F) * 0x01010101) >> 24; } 

Остаётся разобраться с потритовым умножением векторов. Для наглядности, будем считать трит структурой из двух битовых полей с именами p и m, причём p содержит значение из плюс-вектора, а m – из минус-вектора. Назовём «r» результат умножения тритов a и b. Тогда:
r.p = (a.p | b.m) & (a.m | b.p)
r.m = (a.p | b.p) & (a.m | b.m)
Итак, мы сначала по этим формулам потритово умножаем два троичных вектора, затем у результата этого умножения находим суммы битов плюс-вектора и минус-вектора, затем находим разность этих сумм. Это и будет результатом скалярного произведения двух троичных векторов друг на друга.

Проиллюстрируем этот алгоритм программой на языке C++. В этой программе мы создадим две троичные матрицы размерами 32×32, заполним их случайными значениями, и перемножим друг на друга двумя способами: классическим алгоритмом, и оптимизированным для троичных вычислений.

#include <stdlib.h> #include <stdio.h>  // Неупакованные троичные матрицы, // которые будем перемножать. // Каждый элемент принимает значения (-1, 0, 1) int A[32][32]; int B[32][32];  // Результат умножения неупакованных матриц int C[32][32];  typedef unsigned __int32 u32;  // Вектор из 32 тритов struct TritVector32 {     u32 p; // плюс-вектор     u32 m; // минус-вектор      TritVector32() {p = 0; m = 0;}      // Получить значение трита с указанным номером     int getTrit(int index)     {         int mask = 1 << index;         return ((p & mask) != 0) - ((m & mask) != 0);     }     void setTrit(int index, int trit)     {         int mask = 1 << index;         p |= mask;         p ^= mask & ((trit - 1) >> 1); // p ^= mask & -(trit <= 0);         m |= mask;         m ^= mask & ((~trit) >> 1);    // p ^= mask & -(trit >= 0);     } };  // Упакованные троичные матрицы, // содержащие те же значения, что и неупакованные A и B TritVector32 A3[32]; TritVector32 B3t[32]; // эта матрица транспонирована  // Результат умножения упакованных матриц int C3[32][32];  // Возвращает случайное значение (-1, 0, 1) int rand3() {     return (int)((u32) rand() * 3 / (RAND_MAX + 1)) - 1; }  // Количество единичных битов int popcnt(u32 value) {     // Для наглядности используется простой, но медленный алгоритм.     // Существуют гораздо более быстрые.     int result = 0;     while(value)     {         result += value & 1;         value >>= 1;     }     return result; }  int main() {     // Заполняем матрицы сомножителей случайными значениями     for(int i = 0; i < 32; i++)         for(int j = 0; j < 32; j++)         {             A[i][j] = rand3();             B[i][j] = rand3();         }     // Умножаем матрицы классическим способом     for(int i = 0; i < 32; i++)         for(int j = 0; j < 32; j++)         {             int c = 0;             for(int k = 0; k < 32; k++)                 c += A[i][k] * B[k][j];             C[i][j] = c;         }      // Заполняем упакованные матрицы     for(int i = 0; i < 32; i++)         for(int j = 0; j < 32; j++)         {             A3[i].setTrit(j, A[i][j]);             B3t[j].setTrit(i, B[i][j]);         }     // Умножаем оптимизированным способом     for(int i = 0; i < 32; i++)     {         TritVector32 *a = A3 + i;         for(int j = 0; j < 32; j++)         {             TritVector32 *b = B3t + j;             TritVector32 r;              // Потритовое умножение             r.p = (a->p | b->m) & (a->m | b->p);             r.m = (a->p | b->p) & (a->m | b->m);              C3[i][j] = popcnt(r.p) - popcnt(r.m);         }     }      // Выводим результаты     FILE *fp = fopen("classic.txt", "wt");     fputs("A =\n", fp);     for(int i = 0; i < 32; i++)     {         for(int j = 0; j < 32; j++)             fprintf(fp, "%2d ", A[i][j]);         fputs("\n", fp);     }     fputs("\nB =\n", fp);     for(int i = 0; i < 32; i++)     {         for(int j = 0; j < 32; j++)             fprintf(fp, "%2d ", B[i][j]);         fputs("\n", fp);     }     fputs("\nC =\n", fp);     for(int i = 0; i < 32; i++)     {         for(int j = 0; j < 32; j++)             fprintf(fp, "%3d ", C[i][j]);         fputs("\n", fp);     }     fclose(fp);      fp = fopen("trit.txt", "wt");     fputs("A =\n", fp);     for(int i = 0; i < 32; i++)     {         for(int j = 0; j < 32; j++)             fprintf(fp, "%2d ", A3[i].getTrit(j));         fputs("\n", fp);     }     fputs("\nB =\n", fp);     for(int i = 0; i < 32; i++)     {         for(int j = 0; j < 32; j++)             fprintf(fp, "%2d ", B3t[j].getTrit(i));         fputs("\n", fp);     }     fputs("\nC =\n", fp);     for(int i = 0; i < 32; i++)     {         for(int j = 0; j < 32; j++)             fprintf(fp, "%3d ", C3[i][j]);         fputs("\n", fp);     }     fclose(fp);     return 0; } 

Сравнение файлов результатов показывает, что они одинаковы. Всё работает.

Теперь возникает вопрос: а что, если у нас больше двух матриц? Ведь матрица произведения двух троичных матриц содержит числа, вообще говоря, не троичные. Как её умножить на троичную матрицу?
Но и здесь можно обойтись без умножения.
Пусть a – обычное целое число, которое может быть отрицательным. Пусть b – трит, имеющий два поля p и m, соответствующие плюс- и минус-векторам. Тогда произведение a на b можно записать так:
a•b = a•(b.p – b.m) = (a & -b.p) — (a & -b.m) =
= ((a — b.m) ^ -b.m) & -(b.m ^ b.p) =
= ((a — b.m) ^ -b.m) & ((-b.m) ^ -b.p) =
= ((a ^ -b.m) + b.m) & ((-b.m) ^ -b.p)
Можно выбрать любую из этих формул, или им аналогичных.


ссылка на оригинал статьи https://habr.com/ru/articles/857788/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *