Monday, 3 June 2024

Multiplying 64x64 bit-matrices with GF2P8AFFINEQB

This is a relatively simple use of GF2P8AFFINEQB. By itself, GF2P8AFFINEQB essentially multiplies two 8x8 bit-matrices (but with a transpose-flip applied to the second operand, and an extra XOR with a byte that we can set to zero). A 64x64 matrix can be seen as a block-matrix where each block is an 8x8 matrix. You can also view this as taking the ring of 8x8 matrices over GF(2), and then working with 8x8 matrices with elements from that ring. All we really need to do is write an 8x8 matrix multiplication and let GF2P8AFFINEQB take care of the complicated part.
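
To make the "transpose-flip" concrete, here is a minimal scalar sketch of what one QWORD lane of GF2P8AFFINEQB computes with the constant byte set to zero, following Intel's pseudocode for the instruction. It assumes an 8x8 bit-matrix is stored in a uint64_t with byte r as row r and bit c as column c; affine_lane, mul8x8, flip_rows and transpose8 are illustrative names, not library functions. The lane result equals an ordinary 8x8 product of the first operand with the second operand after its rows are flipped (bytes reversed) and it is then transposed.

```cpp
#include <cstdint>

// Assumed convention for this sketch: byte r of a uint64_t is row r,
// bit c of that byte is column c.

static int parity8(uint8_t v) {
    v ^= v >> 4; v ^= v >> 2; v ^= v >> 1;
    return v & 1;
}

// One QWORD lane of GF2P8AFFINEQB(x, a, imm8 = 0), per Intel's pseudocode:
// bit k of output byte j is parity(a.byte[7 - k] AND x.byte[j]).
uint64_t affine_lane(uint64_t x, uint64_t a) {
    uint64_t out = 0;
    for (int j = 0; j < 8; j++) {
        uint8_t xb = uint8_t(x >> (8 * j));
        for (int k = 0; k < 8; k++) {
            uint8_t row = uint8_t(a >> (8 * (7 - k)));
            out |= uint64_t(parity8(uint8_t(row & xb))) << (8 * j + k);
        }
    }
    return out;
}

// Plain 8x8 product over GF(2): row r of the result is the XOR of the rows
// of y selected by the set bits of row r of x.
uint64_t mul8x8(uint64_t x, uint64_t y) {
    uint64_t out = 0;
    for (int r = 0; r < 8; r++) {
        uint8_t acc = 0;
        for (int j = 0; j < 8; j++)
            if ((x >> (8 * r + j)) & 1) acc ^= uint8_t(y >> (8 * j));
        out |= uint64_t(acc) << (8 * r);
    }
    return out;
}

// Reverse the order of the rows (the bytes).
uint64_t flip_rows(uint64_t m) {
    uint64_t f = 0;
    for (int r = 0; r < 8; r++) f |= ((m >> (8 * r)) & 0xFF) << (8 * (7 - r));
    return f;
}

uint64_t transpose8(uint64_t m) {
    uint64_t t = 0;
    for (int r = 0; r < 8; r++)
        for (int c = 0; c < 8; c++)
            if ((m >> (8 * r + c)) & 1) t |= 1ULL << (8 * c + r);
    return t;
}

// The lane computes x * transpose(flip_rows(a)): an ordinary product with
// the transpose-flip applied to the second operand.
bool lane_is_transpose_flipped_product(uint64_t x, uint64_t a) {
    return affine_lane(x, a) == mul8x8(x, transpose8(flip_rows(a)));
}
```

So affine_lane(x, a) only matches mul8x8(x, b) when a is chosen such that transpose8(flip_rows(a)) == b, which is what the flip-transpose step has to arrange.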

Using the "full" 512-bit version of VGF2P8AFFINEQB from AVX-512, one VGF2P8AFFINEQB instruction performs 8 of those products. A convenient way to use them is to broadcast an element of the left matrix (which is really an 8x8 bit-matrix, but let's put that aside for now) to all QWORD elements of a ZMM vector, and multiply that by a row of the right matrix. That way we end up with a row of the result, which is nice to work with: no reshaping or horizontal SIMD required. XOR-ing QWORDs together horizontally could be done reasonably well with another GF2P8AFFINEQB trick, which is neat, but avoiding it is even better. All we need to do to compute a row of the result (still viewed as an 8x8 matrix) is 8 broadcasts, 8 VGF2P8AFFINEQB, and XOR-ing the 8 results together. That doesn't take 7 VPXORQ, because VPTERNLOGQ can XOR three vectors together: three VPTERNLOGQ and one VPXORQ are enough. Then just do this for each row of the result.
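
As a rough scalar model of that row computation, here std::array<uint64_t, 8> stands in for a ZMM register (8 QWORD lanes), and both operands are assumed to be already split into 8x8 grids of 8x8 bit-matrices. It also assumes that every tile of the right-hand matrix has already been through the flip-transpose (so the instruction's built-in transpose-flip cancels out). affine_lane and ternlog are hypothetical scalar stand-ins for VGF2P8AFFINEQB and VPTERNLOGQ, not real intrinsics.

```cpp
#include <array>
#include <cstdint>

using Zmm = std::array<uint64_t, 8>;  // stand-in for one __m512i: 8 QWORD lanes
using Tiles = std::array<Zmm, 8>;     // tiles[i][q] = 8x8 block at block-row i, block-column q

static int parity8(uint8_t v) {
    v ^= v >> 4; v ^= v >> 2; v ^= v >> 1;
    return v & 1;
}

// Scalar stand-in for one QWORD lane of VGF2P8AFFINEQB(x, a, 0).
static uint64_t affine_lane(uint64_t x, uint64_t a) {
    uint64_t out = 0;
    for (int j = 0; j < 8; j++)
        for (int k = 0; k < 8; k++) {
            uint8_t row = uint8_t(a >> (8 * (7 - k)));
            uint8_t xb = uint8_t(x >> (8 * j));
            out |= uint64_t(parity8(uint8_t(row & xb))) << (8 * j + k);
        }
    return out;
}

// Scalar stand-in for VPTERNLOGQ on one lane: bit i of the result is bit
// (a_i << 2 | b_i << 1 | c_i) of imm. With imm = 0x96 this is a ^ b ^ c.
static uint64_t ternlog(uint64_t a, uint64_t b, uint64_t c, uint8_t imm) {
    uint64_t r = 0;
    for (int i = 0; i < 64; i++) {
        int idx = int(((a >> i) & 1) << 2 | ((b >> i) & 1) << 1 | ((c >> i) & 1));
        r |= uint64_t((imm >> idx) & 1) << i;
    }
    return r;
}

// Block-row i of A*B: broadcast tile A(i, t) to all lanes, multiply it by
// block-row t of the prepared B, then XOR the 8 products together with
// three 3-input XORs (ternlog 0x96) and one plain XOR.
Zmm result_block_row(const Tiles& a, const Tiles& b_prepared, int i) {
    Zmm prod[8];
    for (int t = 0; t < 8; t++)
        for (int q = 0; q < 8; q++)
            prod[t][q] = affine_lane(a[i][t], b_prepared[t][q]);
    Zmm row{};
    for (int q = 0; q < 8; q++) {
        uint64_t r = ternlog(prod[0][q], prod[1][q], prod[2][q], 0x96);
        r = ternlog(r, prod[3][q], prod[4][q], 0x96);
        r = ternlog(r, prod[5][q], prod[6][q], 0x96);
        row[q] = r ^ prod[7][q];
    }
    return row;
}
```

Lane q of the returned row is the sum over t of A(i, t) times B(t, q), exactly the block-matrix product, and no lane ever reads another lane.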

There are two things that I've skipped so far. First, the built-in transpose-flip of GF2P8AFFINEQB needs to be cancelled out with a flip-transpose (unless explicitly working with a matrix in a weird format is OK). Second, working with an 8x8 block-matrix is mathematically "free" by imagining some dotted lines running through the matrix, but in order to get the right data into GF2P8AFFINEQB we have to actually rearrange it (again: unless the weird format is OK).

One way to implement a flip-transpose (i.e. the inverse of the bit-permutation that GF2P8AFFINEQB applies to its second operand) is to reverse the bytes in each QWORD and then left-multiply (in the sense of GF2P8AFFINEQB-ing with a constant as the first operand) by a flipped identity matrix, which as a QWORD looks like 0x0102040810204080. Reversing the bytes in each QWORD could be done with a VPERMB. There are other ways, but we're about to have a VPERMB anyway.
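
A scalar sketch of that flip-transpose for a single QWORD, assuming the same layout as before (byte r is row r, bit c is column c): byte-reverse the matrix, then feed the flipped identity 0x0102040810204080 to an emulated GF2P8AFFINEQB lane as its first (data) operand. affine_lane, byte_reverse and flip_transpose are illustrative names, not library functions.

```cpp
#include <cstdint>

static int parity8(uint8_t v) {
    v ^= v >> 4; v ^= v >> 2; v ^= v >> 1;
    return v & 1;
}

// One QWORD lane of GF2P8AFFINEQB(x, a, 0), per Intel's pseudocode:
// bit k of output byte j is parity(a.byte[7 - k] AND x.byte[j]).
static uint64_t affine_lane(uint64_t x, uint64_t a) {
    uint64_t out = 0;
    for (int j = 0; j < 8; j++)
        for (int k = 0; k < 8; k++) {
            uint8_t row = uint8_t(a >> (8 * (7 - k)));
            uint8_t xb = uint8_t(x >> (8 * j));
            out |= uint64_t(parity8(uint8_t(row & xb))) << (8 * j + k);
        }
    return out;
}

// Reverse the 8 bytes of a QWORD, i.e. flip the rows of the bit-matrix.
static uint64_t byte_reverse(uint64_t m) {
    uint64_t r = 0;
    for (int i = 0; i < 8; i++) r |= ((m >> (8 * i)) & 0xFF) << (8 * (7 - i));
    return r;
}

// Flip-transpose: returns b' such that affine_lane(x, b') == x * b for every x.
// Byte-reverse b, then left-multiply by the flipped identity by passing it
// as the first operand.
uint64_t flip_transpose(uint64_t b) {
    const uint64_t flipped_identity = 0x0102040810204080ULL;
    return affine_lane(flipped_identity, byte_reverse(b));
}
```

For example, applying flip_transpose to the identity matrix 0x8040201008040201 gives the flipped identity 0x0102040810204080.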

Rearranging the data between a fully row-major layout and an 8x8 matrix in which each element is an 8x8 bit-matrix is easy: that's just an 8x8 transpose of bytes after all, so a single VPERMB does it. That's needed both for the inputs and for the output. The input that is the right-hand operand of the overall matrix multiplication also needs a byte-reverse applied to each QWORD, and the same VPERMB that does the transpose can do that byte-reverse as well.
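
Here is a scalar model of those permutations, where a 64-byte array stands in for a ZMM register and vpermb mimics _mm512_permutexvar_epi8 (byte i of the result is byte idx[i] & 63 of the source). make_tp and make_tpr rebuild the same index vectors as the tp and tpr constants in the code below, and rows_to_tiles is a hypothetical helper showing what the tp permutation does to 8 consecutive rows of a 64x64 matrix loaded little-endian.

```cpp
#include <array>
#include <cstdint>

using Bytes64 = std::array<uint8_t, 64>;  // one ZMM register viewed as 64 bytes

// Scalar model of VPERMB: byte i of the result is byte (idx[i] & 63) of src.
Bytes64 vpermb(const Bytes64& idx, const Bytes64& src) {
    Bytes64 dst{};
    for (int i = 0; i < 64; i++) dst[i] = src[idx[i] & 63];
    return dst;
}

// tp: byte k of QWORD q comes from byte q of QWORD k, an 8x8 transpose of
// bytes (0, 8, 16, ..., then 1, 9, 17, ...).
Bytes64 make_tp() {
    Bytes64 idx{};
    for (int q = 0; q < 8; q++)
        for (int k = 0; k < 8; k++) idx[8 * q + k] = uint8_t(8 * k + q);
    return idx;
}

// tpr: the same transpose, with the bytes inside each QWORD reversed
// (56, 48, ..., 0, then 57, 49, ...).
Bytes64 make_tpr() {
    Bytes64 idx{};
    for (int q = 0; q < 8; q++)
        for (int k = 0; k < 8; k++) idx[8 * q + k] = uint8_t(8 * (7 - k) + q);
    return idx;
}

// Load 8 consecutive rows of a 64x64 bit-matrix little-endian (as
// _mm512_loadu_epi64 would) and permute with tp: QWORD q of the result is
// the 8x8 block in block-column q of that block-row, one byte per row.
std::array<uint64_t, 8> rows_to_tiles(const std::array<uint64_t, 8>& rows) {
    Bytes64 src{};
    for (int r = 0; r < 8; r++)
        for (int b = 0; b < 8; b++) src[8 * r + b] = uint8_t(rows[r] >> (8 * b));
    Bytes64 dst = vpermb(make_tp(), src);
    std::array<uint64_t, 8> tiles{};
    for (int q = 0; q < 8; q++)
        for (int b = 0; b < 8; b++) tiles[q] |= uint64_t(dst[8 * q + b]) << (8 * b);
    return tiles;
}
```

Because tp is a transpose, applying it twice is the identity, which is why the same index vector converts the result back to row-major.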

Here's one way to put that all together:

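#include <immintrin.h>
#include <array>
#include <cstddef>
#include <cstdint>
using std::array;

// Requires AVX-512F, AVX512_VBMI (for VPERMB) and GFNI.
array<uint64_t, 64> mmul_gf2_avx512(const array<uint64_t, 64>& A, const array<uint64_t, 64>& B)
{
    // Flipped (anti-diagonal) identity matrix, as an 8x8 bit-matrix in every QWORD
    __m512i id = _mm512_set1_epi64(0x0102040810204080);
    // tp: 8x8 transpose of bytes, converts between row-major and 8x8 tiles
    __m512i tp = _mm512_setr_epi8(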
        0, 8, 16, 24, 32, 40, 48, 56,
        1, 9, 17, 25, 33, 41, 49, 57,
        2, 10, 18, 26, 34, 42, 50, 58,
        3, 11, 19, 27, 35, 43, 51, 59,
        4, 12, 20, 28, 36, 44, 52, 60,
        5, 13, 21, 29, 37, 45, 53, 61,
        6, 14, 22, 30, 38, 46, 54, 62,
        7, 15, 23, 31, 39, 47, 55, 63);
    // tpr: the same transpose, plus a byte-reverse within each QWORD
    __m512i tpr = _mm512_setr_epi8(
        56, 48, 40, 32, 24, 16, 8, 0,
        57, 49, 41, 33, 25, 17, 9, 1,
        58, 50, 42, 34, 26, 18, 10, 2,
        59, 51, 43, 35, 27, 19, 11, 3,
        60, 52, 44, 36, 28, 20, 12, 4,
        61, 53, 45, 37, 29, 21, 13, 5,
        62, 54, 46, 38, 30, 22, 14, 6,
        63, 55, 47, 39, 31, 23, 15, 7);
    array<uint64_t, 64> res;

    // B to tiles, byte-reversing each tile (first half of the flip-transpose)
    __m512i b_0 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[0]));
    __m512i b_1 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[8]));
    __m512i b_2 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[16]));
    __m512i b_3 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[24]));
    __m512i b_4 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[32]));
    __m512i b_5 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[40]));
    __m512i b_6 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[48]));
    __m512i b_7 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[56]));
    
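    // Second half: left-multiply by the flipped identity, which cancels the
    // transpose-flip that GF2P8AFFINEQB applies to its second operand
    b_0 = _mm512_gf2p8affine_epi64_epi8(id, b_0, 0);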
    b_1 = _mm512_gf2p8affine_epi64_epi8(id, b_1, 0);
    b_2 = _mm512_gf2p8affine_epi64_epi8(id, b_2, 0);
    b_3 = _mm512_gf2p8affine_epi64_epi8(id, b_3, 0);
    b_4 = _mm512_gf2p8affine_epi64_epi8(id, b_4, 0);
    b_5 = _mm512_gf2p8affine_epi64_epi8(id, b_5, 0);
    b_6 = _mm512_gf2p8affine_epi64_epi8(id, b_6, 0);
    b_7 = _mm512_gf2p8affine_epi64_epi8(id, b_7, 0);

    for (size_t i = 0; i < 8; i++)
    {
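        // QWORD t of a_tiles is the 8x8 block A(i, t)
        __m512i a_tiles = _mm512_loadu_epi64(&A[i * 8]);
        a_tiles = _mm512_permutexvar_epi8(tp, a_tiles);
        // Broadcast each A(i, t), multiply by block-row t of B, XOR the 8 products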
        __m512i row = _mm512_ternarylogic_epi64(
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(0), a_tiles), b_0, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(1), a_tiles), b_1, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(2), a_tiles), b_2, 0), 0x96);
        row = _mm512_ternarylogic_epi64(row,
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(3), a_tiles), b_3, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(4), a_tiles), b_4, 0), 0x96);
        row = _mm512_ternarylogic_epi64(row,
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(5), a_tiles), b_5, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(6), a_tiles), b_6, 0), 0x96);
        row = _mm512_xor_epi64(row,
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(7), a_tiles), b_7, 0));
        // Back from tiles to row-major
        row = _mm512_permutexvar_epi8(tp, row);
        _mm512_storeu_epi64(&res[i * 8], row);
    }
    
    return res;
}

When performing several matrix multiplications in sequence, it may make sense to leave the intermediate results in the format of an 8x8 matrix of 8x8 bit-matrices. The B matrix needs to be permuted anyway, but the two _mm512_permutexvar_epi8 calls in the loop can then be removed. And, obviously, if the same matrix is used as the B matrix several times, it only needs to be permuted once. You may need to manually inline the code to convince your compiler to keep the matrix in registers.

Crude benchmarks

A very boring conventional implementation of this 64x64 matrix multiplication may look like this:

array<uint64_t, 64> mmul_gf2_scalar(const array<uint64_t, 64>& A, const array<uint64_t, 64>& B)
{
    array<uint64_t, 64> res;
    for (size_t i = 0; i < 64; i++) {
        uint64_t result_row = 0;
        for (size_t j = 0; j < 64; j++) {
            if (A[i] & (1ULL << j))
                result_row ^= B[j];
        }
        res[i] = result_row;
    }
    return res;
}

There are various ways to write this slightly differently, some of which may be a bit faster, but that's not really the point.

On my PC, which has an 11600K (Rocket Lake) in it, mmul_gf2_scalar takes around 500 times as long as the AVX-512 implementation (measured as the time taken to perform a chain of dependent multiplications). Really, it's that slow. That is partly due to my choice of data, though: I mainly benchmarked this on random matrices where each bit has a 50% chance of being set. The AVX-512 implementation does not care about that at all, while the above scalar implementation suffers (literally) thousands of branch mispredictions per multiplication: it executes 64x64 = 4096 data-dependent branches, and on random data roughly half of them are mispredicted. That can be fixed without using SIMD, for example:

array<uint64_t, 64> mmul_gf2_branchfree(const array<uint64_t, 64>& A, const array<uint64_t, 64>& B)
{
    array<uint64_t, 64> res;
    for (size_t i = 0; i < 64; i++) {
        uint64_t result_row = 0;
        for (size_t j = 0; j < 64; j++) {
            result_row ^= B[j] & -((A[i] >> j) & 1);
        }
        res[i] = result_row;
    }
    return res;
}

In my benchmark, that was already about 8 times as fast as the branching version when MSVC was not allowed to auto-vectorize it. With auto-vectorization (for AVX-512) enabled, this implementation became 25 times as fast as the branching version. This was not supposed to be a post about branch (mis)prediction, but be careful out there: you might get snagged on a branch.
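
For anyone who wants to try this kind of measurement, here is a minimal sketch of a latency benchmark over a chain of dependent multiplications, using std::chrono and random matrices in which each bit is set with probability 1/2 (raw std::mt19937_64 output). It is not the harness behind the numbers above: time_chain, random_matrix and the choice of multiplying by a fixed B on each step are assumptions for illustration. Only the branch-free version is copied in, to keep the block self-contained; the other two functions could be passed in the same way.

```cpp
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <random>

using std::array;
using Mat = array<uint64_t, 64>;

Mat mmul_gf2_branchfree(const Mat& A, const Mat& B) {
    Mat res;
    for (size_t i = 0; i < 64; i++) {
        uint64_t result_row = 0;
        for (size_t j = 0; j < 64; j++)
            result_row ^= B[j] & -((A[i] >> j) & 1);  // all-ones mask if bit j of A[i] is set
        res[i] = result_row;
    }
    return res;
}

// Random 64x64 bit-matrix: every bit is set with probability 1/2.
Mat random_matrix(std::mt19937_64& rng) {
    Mat m;
    for (auto& row : m) row = rng();
    return m;
}

// Multiply x by b `iters` times. Each product depends on the previous one,
// so this measures latency rather than throughput. Returns the average
// nanoseconds per multiplication; the final matrix is written to `out`
// so the loop has an observable result.
template <class MulFn>
double time_chain(MulFn mul, Mat x, const Mat& b, int iters, Mat& out) {
    auto start = std::chrono::steady_clock::now();
    for (int k = 0; k < iters; k++) x = mul(x, b);
    auto stop = std::chrono::steady_clock::now();
    out = x;
    std::chrono::duration<double, std::nano> elapsed = stop - start;
    return elapsed.count() / iters;
}
```

Comparing two implementations is then a matter of calling time_chain with the same x, b and iteration count for each and dividing the results.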

It would be interesting to compare the AVX-512 version against established finite-field (or GF(2)-specific) linear algebra packages, such as FFLAS-FFPACK and M4RI. Actually, I tried to include M4RI in the benchmark, but it ended up being 10 times as slow as the auto-vectorized build of mmul_gf2_branchfree. That's bizarrely bad, so I probably did something wrong, but I've already sunk 4 times as much time into getting M4RI to work at all as it took to write the code that this blog post is really about, so I'll just accept that I did it wrong and leave it at that.