Tuesday 18 June 2024

Sorting the nibbles of a u64

I was reminded (on Mastodon) of this nibble-sorting technique (it could be adapted to other element sizes), which I apparently had only vaguely tweeted about in the past. It deserves a post, so here it is.

Binary LSD radix sort can be expressed as a sequence of stable-partitions, first stable-partitioning based on the least-significant bit, then on the second-to-least-significant bit and so on.

In modern x86, pext essentially implements half of a stable partition, only the half that moves a subset of the elements down towards lower indices. If we do that twice, the second time with an inverted mask, and shift the subset of elements where the mask is set left to put it at the top, we get a gadget that partitions a u64 based on a mask:

(pext(x, mask) << popcount(~mask)) | pext(x, ~mask)

This is sometimes called the sheep-and-goats operation.
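As a rough sketch, here is the same gadget in portable C++, with a bit-by-bit loop standing in for pext so that it runs without BMI2. The names pext_soft, popcount_soft and sheep_and_goats are made up for this illustration; on real hardware the loops would be replaced by the pext and popcnt instructions.

```cpp
#include <cassert>
#include <cstdint>

// Software stand-in for pext: gathers the bits of x selected by mask
// into the low bits of the result, preserving their order.
uint64_t pext_soft(uint64_t x, uint64_t mask)
{
    uint64_t result = 0;
    for (int out = 0; mask != 0; out++) {
        uint64_t lowest = mask & (~mask + 1); // isolate the lowest set bit of mask
        if (x & lowest)
            result |= uint64_t(1) << out;
        mask &= mask - 1;                     // clear that bit
    }
    return result;
}

// Number of set bits (hypothetical helper for this sketch).
int popcount_soft(uint64_t x)
{
    int n = 0;
    for (; x != 0; x &= x - 1)
        n++;
    return n;
}

// Stable partition of the bits of x: bits where mask is 0 go to the bottom,
// bits where mask is 1 go to the top, both groups keeping their order.
uint64_t sheep_and_goats(uint64_t x, uint64_t mask)
{
    int shift = popcount_soft(~mask);
    uint64_t top = pext_soft(x, mask);
    // If mask is all zeros, shift is 64 and top is 0. Shifting a 64-bit
    // value by 64 is undefined in C++, so that case skips the shift.
    assert(shift < 64 || top == 0);
    if (shift < 64)
        top <<= shift;
    return top | pext_soft(x, ~mask);
}
```

For example, with x = 0xABCD and mask = 0x0F0F the nibbles D and B are selected and move to the top, giving 0xBD000000000000AC.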

For radix sort the masks that we need are, in order, the least significant bit of each element, broadcast to cover the whole corresponding element, then the same thing but with the second-to-least-significant bit, and so on. One way to express that is by shifting the whole thing right to put the bit that we want to broadcast in the least significant position of each element, masking out the other bits, and then multiplying by 15 to broadcast that bit into every bit of the element. Different compilers handle that multiplication by 15 differently (there are alternative ways to express it).
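To make the mask construction concrete, here is a small C++ sketch. The function names are hypothetical. The second variant shows one possible alternative to the multiplication, relying on 15 * t == (t << 4) - t in wrapping 64-bit arithmetic; it is not necessarily what any particular compiler emits.

```cpp
#include <cassert>
#include <cstdint>

// Mask for radix pass k (0..3): every nibble whose bit k is set becomes 0xF,
// every other nibble becomes 0x0.
uint64_t nibble_bit_mask(uint64_t x, int k)
{
    assert(k >= 0 && k < 4);
    const uint64_t m = 0x1111111111111111;
    return ((x >> k) & m) * 15;
}

// Same mask without a multiplication: 15 * t == (t << 4) - t (mod 2^64).
uint64_t nibble_bit_mask_shift(uint64_t x, int k)
{
    assert(k >= 0 && k < 4);
    const uint64_t m = 0x1111111111111111;
    uint64_t t = (x >> k) & m;
    return (t << 4) - t;
}
```

With x = 0x1234, pass 0 selects the odd nibbles 3 and 1 (mask 0xF0F0), and pass 1 selects the nibbles 3 and 2 (mask 0x0FF0).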

static ulong sort_nibbles(ulong x)
{
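    // 0x1 in every nibble: selects one bit per element.
    ulong m = 0x1111111111111111;
    // Pass 0: broadcast bit 0 of each nibble across that nibble.
    ulong t = (x & m) * 15;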
    x = (Bmi2.X64.ParallelBitExtract(x, t) << BitOperations.PopCount(~t)) |
        Bmi2.X64.ParallelBitExtract(x, ~t);
    t = ((x >> 1) & m) * 15;
    x = (Bmi2.X64.ParallelBitExtract(x, t) << BitOperations.PopCount(~t)) |
        Bmi2.X64.ParallelBitExtract(x, ~t);
    t = ((x >> 2) & m) * 15;
    x = (Bmi2.X64.ParallelBitExtract(x, t) << BitOperations.PopCount(~t)) |
        Bmi2.X64.ParallelBitExtract(x, ~t);
    t = ((x >> 3) & m) * 15;
    x = (Bmi2.X64.ParallelBitExtract(x, t) << BitOperations.PopCount(~t)) |
        Bmi2.X64.ParallelBitExtract(x, ~t);
    return x;
}

It's easy to extend this to a key-value sort. Hypothetically you could use that key-value sort to invert a permutation (sorting the values 0..15 by the permutation), but you can do much better with AVX512.
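A minimal sketch of that key-value variant, in portable C++: the mask is always derived from the keys, and the same partition is applied to both the keys and the values. The partition is written as a plain loop here (standing in for the two pext calls), and all function names are hypothetical. Inverting a permutation this way is only meant to illustrate the idea, not to compete with the AVX512 approach.

```cpp
#include <cassert>
#include <cstdint>

// Stable partition of the bits of x by mask: bits where mask is 0 end up at
// the bottom, bits where mask is 1 at the top, each group in original order.
uint64_t partition_bits(uint64_t x, uint64_t mask)
{
    uint64_t low = 0, high = 0;
    int nlow = 0, nhigh = 0;
    for (int i = 0; i < 64; i++) {
        uint64_t bit = (x >> i) & 1;
        if ((mask >> i) & 1) {
            high |= bit << nhigh;
            nhigh++;
        } else {
            low |= bit << nlow;
            nlow++;
        }
    }
    // nlow == 64 means the mask was empty, so high is 0 and needs no shift.
    return (nlow < 64 ? high << nlow : 0) | low;
}

// Sorts the 16 key nibbles in ascending order (smallest in nibble 0) and
// moves the value nibbles along with their keys.
void sort_nibbles_kv(uint64_t& keys, uint64_t& vals)
{
    const uint64_t m = 0x1111111111111111;
    for (int k = 0; k < 4; k++) {
        uint64_t t = ((keys >> k) & m) * 15; // mask comes from the keys only
        keys = partition_bits(keys, t);
        vals = partition_bits(vals, t);
    }
}

// Inverts a permutation of 0..15 stored as nibbles (nibble i holds perm[i]),
// by sorting the values 0..15 with the permutation as keys.
uint64_t invert_permutation(uint64_t perm)
{
    uint64_t vals = 0xFEDCBA9876543210;
    sort_nibbles_kv(perm, vals);
    // If the input really was a permutation, the sorted keys are 0..15.
    assert(perm == 0xFEDCBA9876543210);
    return vals;
}
```

For the rotation perm[i] = (i + 1) mod 16, stored as 0x0FEDCBA987654321, this returns 0xEDCBA9876543210F, which is the rotation in the other direction.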

Tuesday 4 June 2024

Sharpening a lower bound with KnownBits information

I have written about this before, but that was a long time ago, and I've had a lot more practice with similar things since then. This topic came up on Mastodon, inspiring me to give it another try. Actually the title is a bit of a lie: I will be using Miné's bitfield domain, in which we have bitvectors z indicating the bits that can be zero and o indicating the bits that can be one (as opposed to bitvectors which indicate the bits known to be zero and the bits known to be one respectively, or bitvectors 0b and 1b as used in the paper linked below where 1b has bits set that are known to be set and 0b has bits unset that are known to be unset). The exact representation doesn't really matter.

The problem, to be clear, is this: suppose we have a lower bound on some variable, along with some knowledge about its bits (some bits have a known fixed value, while others do not). For example, we may know that a variable is even (its least significant bit is known to be zero) and at least 5. "Sharpening" the lower bound means increasing it, if possible, so that the lower bound "fits" the knowledge we have about the bits. If a value is even and at least 5, it is also at least 6, so we can increase the lower bound.
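To pin down what "fits" means in terms of z and o, here is a brute-force reference in C++. It is only a specification-by-search (it can take an enormous number of iterations on unlucky inputs), useful for testing a faster implementation on small values. The names fits and sharpen_low_reference are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// True if v is consistent with the bitfield (z, o): every set bit of v is
// allowed to be one, and every clear bit of v is allowed to be zero.
bool fits(uint64_t v, uint64_t z, uint64_t o)
{
    return (v & ~o) == 0 && (~v & ~z) == 0;
}

// Reference sharpening: the smallest value >= low that fits (z, o).
// Returns false if no such value exists.
bool sharpen_low_reference(uint64_t low, uint64_t z, uint64_t o, uint64_t& out)
{
    // Every bit must be allowed to take at least one value.
    assert((z | o) == ~uint64_t(0));
    for (uint64_t v = low; ; v++) {
        if (fits(v, z, o)) {
            out = v;
            return true;
        }
        if (v == UINT64_MAX)
            return false;
    }
}
```

For the example above, "even" is z = all ones and o = all ones except bit 0, and starting from 5 the search stops at 6.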

As a more recent reference for an algorithm that is better than my old one, you can read Sharpening Constraint Programming approaches for Bit-Vector Theory.

As that paper notes, we need to find the highest bit in the current lower bound that doesn't "fit" the KnownBits (or z, o pair from the bitfield domain) information, and then either:

  • If that bit was not set but should be, we need to set it, and reset any lower bits that are not required to be set (lower bound must go up, but only as little as possible).
  • If that bit was set but shouldn't be, we need to reset it, and in order to do that we need to set a higher bit that wasn't set yet, and also reset any lower bits that are not required to be set.

So far so good. What that paper doesn't tell you is that these are essentially the same case, and we can do:

  • Starting from the highest "wrong" bit in the lower bound, find the lowest bit that is unset but could be set, set it, and clear any lower bits that are not required to be set.

That mostly sounds like the second case. What allows the original two cases to be unified is the fact that the bit we find is the same as the bit that needs to be set in the first case too.

As a reminder, x & -x is a common technique used to extract or isolate the lowest set bit aka blsi. It can also be written as x & (~x + 1), and if we change the 1 to some other constant with a single bit set, we can use this technique to find the lowest set bit starting from some position that is not necessarily the least significant bit. So if we search ~low & o (the bits that are unset in the lower bound but could be set), starting from highestSetBit(m) where m has a bit set for every "wrong" bit of the lower bound, we find the bit we're looking for. Actually the annoying part is highestSetBit. Putting the rest together, we may get an implementation like this:
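The offset version of the lowest-set-bit trick can be sketched on its own in C++ (the function name is made up for this example). Adding start to ~x makes a carry ripple upwards through the zeros of x until it lands on the first set bit of x at or above start, and the final AND keeps only that bit.

```cpp
#include <cassert>
#include <cstdint>

// Lowest set bit of x at or above the position of start, where start must
// have exactly one bit set. Returns 0 if x has no set bit in that range.
uint64_t lowest_set_bit_from(uint64_t x, uint64_t start)
{
    assert(start != 0 && (start & (start - 1)) == 0);
    return x & (~x + start);
}
```

With start = 1 this is the usual x & -x. With x = 0b101001 and start = 0b10, bit 0 is skipped and the result is 0b1000.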

uint64_t sharpen_low(uint64_t low, uint64_t z, uint64_t o)
{
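    // Bits of low that contradict (z, o): clear but cannot be zero,
    // or set but cannot be one.
    uint64_t m = (~low & ~z) | (low & ~o);
    if (m) {
        // Candidate bits: unset in low, but allowed to be set.
        uint64_t target = ~low & o;
        // Isolate the lowest candidate at or above the highest wrong bit.
        target &= ~target + highestSetBit(m);
        // Set that bit and clear everything below it...
        low = (low & -target) | target;
        // ...then set the bits that are required to be set.
        low |= ~z;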
    }
    return low;
}

The branch on m is a bit annoying, but on the plus side it means that the input of highestSetBit is always non-zero. Zero is otherwise a bit of an annoying case to handle in highestSetBit. In C++20 and later, you can use std::bit_floor (from the <bit> header) for highestSetBit.

Sharpening the upper bound is symmetric: it can be implemented as ~sharpen_low(~high, o, z), or you could push the bitwise flips "inside" the algorithm and do some algebra to cancel them out.
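Putting the pieces together as a self-contained C++ sketch: sharpen_low is the function from above, highest_set_bit is a hypothetical portable stand-in for highestSetBit (so the example does not depend on C++20), and sharpen_high is the simple wrapper form of the symmetric case, with the flips left on the outside.

```cpp
#include <cassert>
#include <cstdint>

// Portable stand-in for highestSetBit / std::bit_floor; x must be non-zero.
uint64_t highest_set_bit(uint64_t x)
{
    assert(x != 0);
    // Smear the highest set bit into all lower positions...
    x |= x >> 1;  x |= x >> 2;  x |= x >> 4;
    x |= x >> 8;  x |= x >> 16; x |= x >> 32;
    // ...then keep only the top bit of the smeared value.
    return x ^ (x >> 1);
}

uint64_t sharpen_low(uint64_t low, uint64_t z, uint64_t o)
{
    uint64_t m = (~low & ~z) | (low & ~o);
    if (m) {
        uint64_t target = ~low & o;
        target &= ~target + highest_set_bit(m);
        low = (low & -target) | target;
        low |= ~z;
    }
    return low;
}

// Complementing a value swaps the roles of "can be zero" and "can be one",
// so the upper bound is sharpened by sharpening the complemented bound.
uint64_t sharpen_high(uint64_t high, uint64_t z, uint64_t o)
{
    return ~sharpen_low(~high, o, z);
}
```

For example, "even and at least 5" sharpens to 6, and "odd and at most 8" (z is all ones except bit 0, o is all ones) sharpens to 7.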

Monday 3 June 2024

Multiplying 64x64 bit-matrices with GF2P8AFFINEQB

This is a relatively simple use of GF2P8AFFINEQB. By itself GF2P8AFFINEQB essentially multiplies two 8x8 bit-matrices (but with transpose-flip applied to the second operand, and an extra XOR by a byte that we can set to zero). A 64x64 matrix can be seen as a block-matrix where each block is an 8x8 matrix. You can also view this as taking the ring of 8x8 matrices over GF(2), and then working with 8x8 matrices with elements from that ring. All we really need to do is write an 8x8 matrix multiplication and let GF2P8AFFINEQB take care of the complicated part.
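The block-matrix view can be sketched in scalar C++ before any SIMD gets involved. In this sketch (all names are hypothetical) a matrix is 64 row words with column j in bit j, a block is packed into a QWORD with one row per byte, and mul8x8 is the plain 8x8 product that GF2P8AFFINEQB would take over, ignoring its transpose-flip for now.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

using Matrix64 = std::array<uint64_t, 64>;

// Extracts the 8x8 block at block-row bi, block-column bj as a QWORD whose
// byte r holds row r of the block.
uint64_t get_block(const Matrix64& M, size_t bi, size_t bj)
{
    assert(bi < 8 && bj < 8);
    uint64_t tile = 0;
    for (size_t r = 0; r < 8; r++)
        tile |= ((M[bi * 8 + r] >> (bj * 8)) & 0xFF) << (r * 8);
    return tile;
}

// XORs an 8x8 block into position (bi, bj) of M.
void xor_block(Matrix64& M, size_t bi, size_t bj, uint64_t tile)
{
    assert(bi < 8 && bj < 8);
    for (size_t r = 0; r < 8; r++)
        M[bi * 8 + r] ^= ((tile >> (r * 8)) & 0xFF) << (bj * 8);
}

// Plain 8x8 bit-matrix product over GF(2): row r of the result is the XOR
// of the rows c of b for which bit c of row r of a is set.
uint64_t mul8x8(uint64_t a, uint64_t b)
{
    uint64_t res = 0;
    for (int r = 0; r < 8; r++)
        for (int c = 0; c < 8; c++)
            if ((a >> (r * 8 + c)) & 1)
                res ^= ((b >> (c * 8)) & 0xFF) << (r * 8);
    return res;
}

// 64x64 product written as an 8x8 product of 8x8 blocks.
Matrix64 mmul_gf2_blocks(const Matrix64& A, const Matrix64& B)
{
    Matrix64 res{};
    for (size_t bi = 0; bi < 8; bi++)
        for (size_t bj = 0; bj < 8; bj++)
            for (size_t k = 0; k < 8; k++)
                xor_block(res, bi, bj,
                          mul8x8(get_block(A, bi, k), get_block(B, k, bj)));
    return res;
}
```

The three nested loops are an ordinary 8x8 matrix multiplication in which "multiply" is mul8x8 and "add" is XOR.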

Using the "full" 512-bit version of VGF2P8AFFINEQB from AVX-512, one VGF2P8AFFINEQB instruction performs 8 of those products. A convenient way to use them is by broadcasting an element (which is really an 8x8 bit-matrix but let's put that aside for now) from the left matrix to all QWORD elements of a ZMM vector, and multiplying that by a row of the right matrix. That way we end up with a row of the result, which is nice to work with: no reshaping or horizontal-SIMD required. XOR-ing QWORDs together horizontally could be done relatively reasonably with another GF2P8AFFINEQB trick, which is neat but avoiding it is even better. All we need to do to compute a row of the result (still viewed as an 8x8 matrix) is 8 broadcasts, 8 VGF2P8AFFINEQB, and XOR-ing the 8 results together, which doesn't take 7 VPXORQ because VPTERNLOGQ can XOR three vectors together: three VPTERNLOGQ and one VPXORQ are enough. Then just do this for each row of the result.
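The three-way XOR relies on the truth table 0x96. Here is a small scalar model of VPTERNLOGQ on a single QWORD, as a sketch (ternlog and xor8 are hypothetical names): the three input bits at each position form an index into the 8-bit immediate, and 0x96 has exactly the bits 1, 2, 4 and 7 set, which are the indices with an odd number of ones.

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of VPTERNLOGQ on one QWORD: for every bit position, the bits
// of a, b and c form a 3-bit index into the truth table imm.
uint64_t ternlog(uint64_t a, uint64_t b, uint64_t c, uint8_t imm)
{
    uint64_t res = 0;
    for (int i = 0; i < 64; i++) {
        uint64_t idx = (((a >> i) & 1) << 2) | (((b >> i) & 1) << 1) | ((c >> i) & 1);
        assert(idx < 8);
        res |= uint64_t((imm >> idx) & 1) << i;
    }
    return res;
}

// XOR of 8 values using three 3-way XORs and one ordinary XOR.
uint64_t xor8(const uint64_t (&p)[8])
{
    uint64_t row = ternlog(p[0], p[1], p[2], 0x96);
    row = ternlog(row, p[3], p[4], 0x96);
    row = ternlog(row, p[5], p[6], 0x96);
    return row ^ p[7];
}
```

The first step combines three fresh values, and each later step folds two more values into the running result, so 8 values need only 4 operations.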

There are two things that I've skipped so far. First, the built-in transpose-flip of GF2P8AFFINEQB needs to be cancelled out with a flip-transpose (unless explicitly working with a matrix in a weird format is OK). Second, working with an 8x8 block-matrix is mathematically "free" by imagining some dotted lines running through the matrix, but in order to get the right data into GF2P8AFFINEQB we have to actually rearrange it (again: unless the weird format is OK).

One way to implement a flip-transpose (i.e. the inverse of the bit-permutation that GF2P8AFFINEQB applies to its second operand) is by reversing the bytes in each QWORD and then left-multiplying (in the sense of GF2P8AFFINEQB-ing with a constant as the first operand) by a flipped identity matrix, which as a QWORD looks like: 0x0102040810204080. Reversing the bytes in each QWORD could be done with a VPERMB. There are other ways, but we're about to have a VPERMB anyway.
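This recipe can be checked with a scalar model of the instruction. The model below is an assumption based on the documented per-byte behaviour with the immediate set to zero: bit i of each result byte is the parity of the data byte ANDed with byte 7 - i of the matrix QWORD. All function names are hypothetical. Under that assumption, applying the flip-transpose to b first turns the instruction into a plain 8x8 product.

```cpp
#include <cassert>
#include <cstdint>

// Parity (XOR of all bits) of an 8-bit value.
unsigned parity8(unsigned v)
{
    assert(v < 256);
    v ^= v >> 4; v ^= v >> 2; v ^= v >> 1;
    return v & 1;
}

// Assumed scalar model of GF2P8AFFINEQB on one QWORD with imm8 = 0:
// x holds 8 data bytes, A is the matrix operand.
uint64_t affine_model(uint64_t x, uint64_t A)
{
    uint64_t res = 0;
    for (int r = 0; r < 8; r++) {
        unsigned xb = (x >> (r * 8)) & 0xFF;
        for (int i = 0; i < 8; i++) {
            unsigned row = (A >> ((7 - i) * 8)) & 0xFF;
            res |= uint64_t(parity8(row & xb)) << (r * 8 + i);
        }
    }
    return res;
}

// Reverses the order of the 8 bytes in a QWORD.
uint64_t byte_reverse(uint64_t v)
{
    uint64_t res = 0;
    for (int i = 0; i < 8; i++)
        res |= ((v >> (i * 8)) & 0xFF) << ((7 - i) * 8);
    return res;
}

// Flip-transpose: byte-reverse, then "multiply" with the flipped identity
// as the first operand.
uint64_t flip_transpose(uint64_t b)
{
    return affine_model(0x0102040810204080, byte_reverse(b));
}

// Reference 8x8 product over GF(2), one row per byte, column c in bit c.
uint64_t mul8x8(uint64_t a, uint64_t b)
{
    uint64_t res = 0;
    for (int r = 0; r < 8; r++)
        for (int c = 0; c < 8; c++)
            if ((a >> (r * 8 + c)) & 1)
                res ^= ((b >> (c * 8)) & 0xFF) << (r * 8);
    return res;
}
```

In this model affine_model(a, flip_transpose(b)) equals mul8x8(a, b), and 0x0102040810204080 as the matrix operand leaves the data bytes unchanged.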

Rearranging the data between a fully row-major layout and an 8x8 matrix in which each element is an 8x8 bit-matrix is easy: that's just an 8x8 transpose of bytes after all, so just VPERMB. That's needed both for the inputs and the output. The input that is the right-hand operand of the overall matrix multiplication also needs to have a byte-reverse applied to each QWORD; the same VPERMB that does the transpose can also do that byte-reverse.
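The two byte permutations can be generated rather than typed out. This C++ sketch (with hypothetical names) models VPERMB as out[i] = in[idx[i]] on 64 bytes and builds the index tables for the plain transpose and for the transpose combined with a byte-reverse of each resulting QWORD.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

using Bytes64 = std::array<uint8_t, 64>;

// Scalar model of VPERMB on a 64-byte vector: out[i] = in[idx[i]].
Bytes64 permute_bytes(const Bytes64& idx, const Bytes64& in)
{
    Bytes64 out{};
    for (size_t i = 0; i < 64; i++) {
        assert(idx[i] < 64);
        out[i] = in[idx[i]];
    }
    return out;
}

// Index table for an 8x8 transpose of bytes: output byte (row q, column k)
// comes from input byte (row k, column q).
Bytes64 make_tp()
{
    Bytes64 tp{};
    for (size_t i = 0; i < 64; i++)
        tp[i] = uint8_t((i % 8) * 8 + i / 8);
    return tp;
}

// The same transpose followed by a byte-reverse within each QWORD.
Bytes64 make_tpr()
{
    Bytes64 tpr{};
    for (size_t i = 0; i < 64; i++)
        tpr[i] = uint8_t((7 - i % 8) * 8 + i / 8);
    return tpr;
}
```

Since a transpose is its own inverse, the same table converts to the tiled layout and back again.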

Here's one way to put that all together:

array<uint64_t, 64> mmul_gf2_avx512(const array<uint64_t, 64>& A, const array<uint64_t, 64>& B)
{
    __m512i id = _mm512_set1_epi64(0x0102040810204080);
    __m512i tp = _mm512_setr_epi8(
        0, 8, 16, 24, 32, 40, 48, 56,
        1, 9, 17, 25, 33, 41, 49, 57,
        2, 10, 18, 26, 34, 42, 50, 58,
        3, 11, 19, 27, 35, 43, 51, 59,
        4, 12, 20, 28, 36, 44, 52, 60,
        5, 13, 21, 29, 37, 45, 53, 61,
        6, 14, 22, 30, 38, 46, 54, 62,
        7, 15, 23, 31, 39, 47, 55, 63);
    __m512i tpr = _mm512_setr_epi8(
        56, 48, 40, 32, 24, 16, 8, 0,
        57, 49, 41, 33, 25, 17, 9, 1,
        58, 50, 42, 34, 26, 18, 10, 2,
        59, 51, 43, 35, 27, 19, 11, 3,
        60, 52, 44, 36, 28, 20, 12, 4,
        61, 53, 45, 37, 29, 21, 13, 5,
        62, 54, 46, 38, 30, 22, 14, 6,
        63, 55, 47, 39, 31, 23, 15, 7);
    array<uint64_t, 64> res;

    __m512i b_0 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[0]));
    __m512i b_1 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[8]));
    __m512i b_2 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[16]));
    __m512i b_3 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[24]));
    __m512i b_4 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[32]));
    __m512i b_5 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[40]));
    __m512i b_6 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[48]));
    __m512i b_7 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[56]));
    
    b_0 = _mm512_gf2p8affine_epi64_epi8(id, b_0, 0);
    b_1 = _mm512_gf2p8affine_epi64_epi8(id, b_1, 0);
    b_2 = _mm512_gf2p8affine_epi64_epi8(id, b_2, 0);
    b_3 = _mm512_gf2p8affine_epi64_epi8(id, b_3, 0);
    b_4 = _mm512_gf2p8affine_epi64_epi8(id, b_4, 0);
    b_5 = _mm512_gf2p8affine_epi64_epi8(id, b_5, 0);
    b_6 = _mm512_gf2p8affine_epi64_epi8(id, b_6, 0);
    b_7 = _mm512_gf2p8affine_epi64_epi8(id, b_7, 0);

    for (size_t i = 0; i < 8; i++)
    {
        __m512i a_tiles = _mm512_loadu_epi64(&A[i * 8]);
        a_tiles = _mm512_permutexvar_epi8(tp, a_tiles);
        __m512i row = _mm512_ternarylogic_epi64(
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(0), a_tiles), b_0, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(1), a_tiles), b_1, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(2), a_tiles), b_2, 0), 0x96);
        row = _mm512_ternarylogic_epi64(row,
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(3), a_tiles), b_3, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(4), a_tiles), b_4, 0), 0x96);
        row = _mm512_ternarylogic_epi64(row,
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(5), a_tiles), b_5, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(6), a_tiles), b_6, 0), 0x96);
        row = _mm512_xor_epi64(row,
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(7), a_tiles), b_7, 0));
        row = _mm512_permutexvar_epi8(tp, row);
        _mm512_storeu_epi64(&res[i * 8], row);
    }
    
    return res;
}

When performing multiple matrix multiplications in a row, it may make sense to leave the intermediate results in the format of an 8x8 matrix of 8x8 bit-matrices. The B matrix needs to be permuted anyway, but the two _mm512_permutexvar_epi8 in the loop can be removed. And obviously, if the same matrix is used as the B matrix several times, it only needs to be permuted once. You may need to manually inline the code to convince your compiler to keep the matrix in registers.

Crude benchmarks

A very boring conventional implementation of this 64x64 matrix multiplication may look like this:

array<uint64_t, 64> mmul_gf2_scalar(const array<uint64_t, 64>& A, const array<uint64_t, 64>& B)
{
    array<uint64_t, 64> res;
    for (size_t i = 0; i < 64; i++) {
        uint64_t result_row = 0;
        for (size_t j = 0; j < 64; j++) {
            if (A[i] & (1ULL << j))
                result_row ^= B[j];
        }
        res[i] = result_row;
    }
    return res;
}

There are various ways to write this slightly differently, some of which may be a bit faster, that's not really the point.

On my PC, which has an 11600K (Rocket Lake) in it, mmul_gf2_scalar runs around 500 times as slow (in terms of the time taken to perform a chain of dependent multiplications) as the AVX-512 implementation. Really, it's that slow - but that is partly due to my choice of data: I mainly benchmarked this on random matrices where each bit has a 50% chance of being set. The AVX-512 implementation does not care about that at all, while the above scalar implementation has thousands (literally) of branch mispredictions. That can be fixed without using SIMD, for example:

array<uint64_t, 64> mmul_gf2_branchfree(const array<uint64_t, 64>& A, const array<uint64_t, 64>& B)
{
    array<uint64_t, 64> res;
    for (size_t i = 0; i < 64; i++) {
        uint64_t result_row = 0;
        for (size_t j = 0; j < 64; j++) {
            result_row ^= B[j] & -((A[i] >> j) & 1);
        }
        res[i] = result_row;
    }
    return res;
}

That was, in my benchmark, already about 8 times as fast as the branching version, if I don't let MSVC auto-vectorize. If I do let it auto-vectorize (with AVX-512), this implementation becomes 25 times as fast as the branching version. This was not supposed to be a post about branch (mis)prediction, but be careful out there, you might get snagged on a branch.

It would be interesting to compare the AVX-512 version against established finite-field (or GF(2)-specific) linear algebra packages, such as FFLAS-FFPACK and M4RI. Actually I tried to include M4RI in the benchmark, but it ended up being 10 times as slow as mmul_gf2_branchfree (when it is auto-vectorized). That's bizarrely bad so I probably did something wrong, but I've already sunk 4 times as much time into getting M4RI to work at all as it took to write the code which this blog post is really about, so I'll just accept that I did it wrong and leave it at that.