Thursday 22 August 2024

Enumerating identities, part 2

Part 2 of Enumerating all mathematical identities (in fixed-size bitvector arithmetic with a restricted set of operations) of a certain size.

To recap, the approach in part 1 was broadly:

  • Use a CEGIS-style loop to find a pair of expressions that are equal for all possible inputs.
  • After finding such a pair, block it from being found again.
  • Repeat until no more pairs can be found.

One weakness (in addition to the other weaknesses) of that approach is that the "block list" keeps growing as more identities are found. Here is an alternative approach that does not use an ever-growing block list:

  • Interpreting the raw bit-string that represents a pair of expressions as an integer, find the lowest pair of equivalent expressions.
  • After finding the lowest pair, interpreted as the number X, set the lower bound for the next pair to X + 1.
  • Repeat until no more pairs can be found.

Finding the lowest pair of expressions

SAT by itself does not try to find the lowest solution, nor does the CEGIS-loop built on top of it. But we can use the CEGIS-loop as an oracle to answer the question "is there any solution in (a restricted part of) the search space?", and that lets us do a bitwise binary search: the "find one bit at a time" variant of binary search, which is typically discussed in a completely different context. Bitwise binary search maps well to SAT, not directly of course, but in the following way:

  1. Initialize the prefix to an empty list.
  2. If the prefix has the same size as a pair of expressions, directly return the result of the CEGIS-oracle.
  3. Extend the prefix with false.
  4. Ask the CEGIS-oracle if there is a solution that starts with the prefix. If there is, go back to step 2.
  5. Otherwise, turn the false at the end of the prefix into true, then go back to step 2.
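The steps above can be tried out on a toy problem. The sketch below is a self-contained C++ approximation, not the code from my implementation: the hypothetical `oracle` function stands in for the CEGIS-oracle and simply brute-forces all bit-strings of a tiny width (index 0 of the prefix is the most significant bit), and the predicate `pred` stands in for "this bit-string encodes a pair of equivalent expressions". Unlike the steps above, it first checks that any solution exists at all, so that the current prefix always has a solution. It also includes the outer loop that continues from X + 1, and the free extension with trailing zeroes described in the next paragraphs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

using Pred = std::function<bool(uint32_t)>;

// Hypothetical stand-in for the CEGIS-oracle: returns some value of `width` bits
// (width <= 31) that starts with `prefix` (most significant bit first), is at
// least `lower`, and satisfies `pred`. Brute force, so only for tiny widths.
std::optional<uint32_t> oracle(const std::vector<bool>& prefix, unsigned width,
                               uint32_t lower, const Pred& pred)
{
    for (uint64_t x = lower; x < (uint64_t(1) << width); x++) {
        bool matches = true;
        for (size_t i = 0; i < prefix.size() && matches; i++) {
            bool bit = (x >> (width - 1 - i)) & 1;
            matches = (bit == prefix[i]);
        }
        if (matches && pred(uint32_t(x)))
            return uint32_t(x);
    }
    return std::nullopt;
}

// Bitwise binary search for the lowest solution that is >= lower.
std::optional<uint32_t> find_lowest(unsigned width, uint32_t lower, const Pred& pred)
{
    std::vector<bool> prefix;                               // 1. empty prefix
    if (!oracle(prefix, width, lower, pred))                // no solution at all
        return std::nullopt;
    while (prefix.size() < width) {                         // 2. until complete
        prefix.push_back(false);                            // 3. try a zero first
        if (auto f = oracle(prefix, width, lower, pred)) {  // 4. still solvable?
            // free extension: copy the zeroes that directly follow the prefix in f
            while (prefix.size() < width && !((*f >> (width - 1 - prefix.size())) & 1))
                prefix.push_back(false);
            continue;
        }
        prefix.back() = true;                               // 5. this bit must be one
    }
    return oracle(prefix, width, lower, pred);
}

// Outer loop: after finding X, continue the search with lower bound X + 1.
std::vector<uint32_t> enumerate_all(unsigned width, const Pred& pred)
{
    std::vector<uint32_t> found;
    for (uint64_t lower = 0; lower < (uint64_t(1) << width);) {
        std::optional<uint32_t> x = find_lowest(width, uint32_t(lower), pred);
        if (!x)
            break;
        found.push_back(*x);
        lower = uint64_t(*x) + 1;
    }
    return found;
}
```

Copying the zeroes is safe because the lowest solution with the current prefix can be no larger than the solution the oracle returned, so wherever that solution has a zero right after the prefix, the lowest one must have a zero too.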

Asking a SAT solver for a solution that starts with a given prefix is easy, and tends to make the SAT instance easier to solve (especially when the prefix is long): it only involves forcing some variables (the ones covered by the prefix) to be true or false, using single-literal clauses.

A very useful optimization can be done in step 4: when the CEGIS-oracle says there is a solution with the current prefix, the solution may have some extra zeroes after the prefix which we can use directly to extend the prefix for free. That saves a lot of SAT solves when the solutions tend to have a lot of zeroes in them, which in my case they do, due to extensive use of one-hot encoding.

Encoding the lower bound

There is a neat way to encode that an integer must be greater-than-or-equal-to some given constant, which I haven't seen people talk about (perhaps it's part of the folklore?), using only popcnt(lower_bound) clauses. The idea here is that for every bit that's set in the bound, at least one of the following bits must be set in the solution: that bit itself, or any more-significant bit that is zero in the lower bound. That takes only one clause per set bit: a clause containing the variable that corresponds to the set bit, and the variables corresponding to the zeroes to the left of (more significant than) that set bit. If a solution were smaller than the bound, then at the highest position where the two differ the bound has a one and the solution a zero, while every more-significant zero of the bound is also zero in the solution, so the clause of that bit would be violated.
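A small sketch of this encoding, assuming clauses are represented as lists of bit indices (all literals positive, index 0 being the most significant bit) rather than as solver literals; `lower_bound_clauses`, `satisfies` and `check_encoding` are illustrative names, not part of any SAT library. The brute-force check compares the clauses against a plain integer comparison for every 6-bit bound and value.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Clause = std::vector<size_t>;  // bit indices, at least one of which must be 1

// Encode "x >= lower_bound" with one clause per set bit of the bound.
// Index 0 is the most significant bit, as with the prefix in the binary search.
std::vector<Clause> lower_bound_clauses(const std::vector<bool>& lower_bound)
{
    std::vector<Clause> clauses;
    Clause zeros_above;  // more-significant bits that are zero in the bound
    for (size_t i = 0; i < lower_bound.size(); i++) {
        if (lower_bound[i]) {
            Clause cl = zeros_above;
            cl.push_back(i);
            clauses.push_back(cl);
        } else {
            zeros_above.push_back(i);
        }
    }
    return clauses;
}

// The bits of v, most significant first.
std::vector<bool> to_bits(uint32_t v, size_t width)
{
    std::vector<bool> bits(width);
    for (size_t i = 0; i < width; i++)
        bits[i] = (v >> (width - 1 - i)) & 1;
    return bits;
}

// Does the value x (of the given width) satisfy every clause?
bool satisfies(const std::vector<Clause>& clauses, uint32_t x, size_t width)
{
    for (const Clause& cl : clauses) {
        bool any = false;
        for (size_t i : cl)
            any = any || ((x >> (width - 1 - i)) & 1);
        if (!any)
            return false;
    }
    return true;
}

// Brute force: the clauses must hold exactly when x >= bound.
bool check_encoding(size_t width)
{
    for (uint32_t bound = 0; bound < (1u << width); bound++) {
        std::vector<Clause> clauses = lower_bound_clauses(to_bits(bound, width));
        for (uint32_t x = 0; x < (1u << width); x++)
            if (satisfies(clauses, x, width) != (x >= bound))
                return false;
    }
    return true;
}
```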

Code

For concreteness, here's how I implemented constraining solutions to conform to the prefix; it's really simple:

// set the prefix (used by binary search)
for (size_t i = 0; i < prefix.size(); i++)
{
    if (prefix[i])
        s.addClause(allprogbits[i]);
    else
        s.addClause(~allprogbits[i]);
}

Here's how I set the lower bound:

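// if there is a lower bound (used by binary search), enforce it:
// for every bit that is set in the bound, add a clause containing that bit
// and every more-significant bit that is zero in the bound
if (!lower_bound.empty())
{
    vector<Lit> cl;  // the more-significant zeroes of the bound seen so far
    for (size_t i = 0; i < lower_bound.size(); i++)
    {
        if (lower_bound[i])
        {
            cl.push_back(allprogbits[i]);
            s.addClause(cl);
            cl.pop_back();
        }
        else
        {
            cl.push_back(allprogbits[i]);
        }
    }
}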

And binary search (with the optimization to keep the extra zeroes that the solver gives for free) looks like this:

optional<pair<vector<InstrB>, vector<InstrB>>> find_lowest(vector<bool>& prefix,
      vector<vector<int>>& inputs)
{
    do {
        if (prefix.size() == progbits) {
            return format_progbits(synthesize(prefix, inputs),
                inputcount, lhs_size, rhs_size);
        }
        else {
            prefix.push_back(false);
            auto f = synthesize(prefix, inputs);
            if (f.has_value()) {
                // if the next bits are already zero in the solution, keep them zero
                auto bits = *f;
                while (prefix.size() < progbits && !bits[prefix.size()])
                    prefix.push_back(false);
                continue;
            }

            prefix.pop_back();
            prefix.push_back(true);
        }
    } while (true);
}

The full code of my implementation of this idea is available on gitlab. It's a bit crap but at least it should show any detail that you may still be wondering about. In the code I also constrain expressions to not be "a funny way to write zero" and to not be "a complicated way to do nothing", otherwise a lot of less-interesting identities would be generated.

Tuesday 18 June 2024

Sorting the nibbles of a u64

I was reminded (on mastodon) of this nibble-sorting technique (it could be adapted to other element sizes), which I apparently had only vaguely tweeted about in the past. It deserves a post, so here it is.

Binary LSD radix sort can be expressed as a sequence of stable-partitions, first stable-partitioning based on the least-significant bit, then on the second-to-least-significant bit and so on.

On modern x86, pext essentially implements half of a stable partition: only the half that moves a subset of the elements down towards lower indices. If we do that twice, the second time with an inverted mask, and shift the subset of elements where the mask is set left to put it at the top, we get a gadget that partitions a u64 based on a mask:

(pext(x, mask) << popcount(~mask)) | pext(x, ~mask)

This is sometimes called the sheep-and-goats operation.

For radix sort the masks that we need are, in order, the least significant bit of each element, each broadcast to cover the whole corresponding element, then the same thing but with the second-to-least-significant bit, and so on. One way to express that is by shifting the whole thing right to put the bit that we want to broadcast in the least significant position of the element, and then multiplying by 15 to broadcast that bit into every bit of the element. Different compilers handled that multiplication by 15 differently (there are alternative ways to express that).

// C#: Bmi2 is in System.Runtime.Intrinsics.X86, BitOperations is in System.Numerics
static ulong sort_nibbles(ulong x)
{
    ulong m = 0x1111111111111111;
    // pass 1: partition on bit 0 of each nibble, broadcast to the whole nibble
    ulong t = (x & m) * 15;
    x = (Bmi2.X64.ParallelBitExtract(x, t) << BitOperations.PopCount(~t)) |
        Bmi2.X64.ParallelBitExtract(x, ~t);
    // pass 2: bit 1
    t = ((x >> 1) & m) * 15;
    x = (Bmi2.X64.ParallelBitExtract(x, t) << BitOperations.PopCount(~t)) |
        Bmi2.X64.ParallelBitExtract(x, ~t);
    // pass 3: bit 2
    t = ((x >> 2) & m) * 15;
    x = (Bmi2.X64.ParallelBitExtract(x, t) << BitOperations.PopCount(~t)) |
        Bmi2.X64.ParallelBitExtract(x, ~t);
    // pass 4: bit 3
    t = ((x >> 3) & m) * 15;
    x = (Bmi2.X64.ParallelBitExtract(x, t) << BitOperations.PopCount(~t)) |
        Bmi2.X64.ParallelBitExtract(x, ~t);
    return x;
}

It's easy to extend this to a key-value sort. Hypothetically you could use that key-value sort to invert a permutation (sorting the values 0..15 by the permutation), but you can do much better with AVX512.
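For experimenting without BMI2 hardware (or in C++ rather than C#), the same radix sort can be sketched with a portable, bit-by-bit `pext_soft` standing in for pext. This is only a model of the technique and far slower than the real instruction; all function names are illustrative. It also shows the key-value extension (the masks derived from the keys are applied to the values too) and the permutation inversion mentioned above, with the permutation stored one element per nibble, element i in nibble i.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Portable stand-in for BMI2 pext: gathers the bits of x selected by mask
// into the low bits of the result, preserving their order.
uint64_t pext_soft(uint64_t x, uint64_t mask)
{
    uint64_t result = 0;
    int k = 0;
    for (int i = 0; i < 64; i++)
        if ((mask >> i) & 1)
            result |= ((x >> i) & 1) << k++;
    return result;
}

int popcount64(uint64_t x)
{
    int count = 0;
    for (; x; x &= x - 1)
        count++;
    return count;
}

// Sheep-and-goats: stable partition of the bits of x, bits under a set mask
// bit go to the top. A C++ shift by 64 is undefined, hence the special case.
uint64_t sag(uint64_t x, uint64_t mask)
{
    int shift = popcount64(~mask);
    uint64_t top = shift == 64 ? 0 : pext_soft(x, mask) << shift;
    return top | pext_soft(x, ~mask);
}

// LSD radix sort of the 16 nibbles, ascending from the least significant nibble.
uint64_t sort_nibbles(uint64_t x)
{
    const uint64_t m = 0x1111111111111111;
    for (int bit = 0; bit < 4; bit++)
        x = sag(x, ((x >> bit) & m) * 15);  // broadcast bit `bit` of each nibble
    return x;
}

// Key-value variant: partition the values with the masks derived from the keys.
void sort_nibbles_kv(uint64_t& keys, uint64_t& values)
{
    const uint64_t m = 0x1111111111111111;
    for (int bit = 0; bit < 4; bit++) {
        uint64_t t = ((keys >> bit) & m) * 15;
        keys = sag(keys, t);
        values = sag(values, t);
    }
}

// Invert a permutation of 0..15 by sorting the values 0..15 by the permutation.
uint64_t invert_permutation(uint64_t perm)
{
    uint64_t values = 0xFEDCBA9876543210;
    sort_nibbles_kv(perm, values);
    return values;
}

// Reference for testing: std::sort on the unpacked nibbles.
uint64_t sort_nibbles_reference(uint64_t x)
{
    uint8_t nib[16];
    for (int i = 0; i < 16; i++)
        nib[i] = (x >> (4 * i)) & 15;
    std::sort(nib, nib + 16);
    uint64_t res = 0;
    for (int i = 0; i < 16; i++)
        res |= uint64_t(nib[i]) << (4 * i);
    return res;
}
```

In C# the shift count of a ulong is masked to 6 bits, so shifting by 64 shifts by 0; since pext(x, 0) is zero anyway, the original C# code does not need the special case.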

Tuesday 4 June 2024

Sharpening a lower bound with KnownBits information

I have written about this before, but that was a long time ago and I've had a lot more practice with similar things since then. This topic came up on Mastodon, inspiring me to give it another try. Actually the title is a bit of a lie: I will be using Miné's bitfield domain, in which we have bitvectors z indicating the bits that can be zero and o indicating the bits that can be one (as opposed to bitvectors which indicate the bits known to be zero and the bits known to be one respectively, or bitvectors 0b and 1b as used in the paper linked below, where 1b has bits set that are known to be set and 0b has bits unset that are known to be unset). The exact representation doesn't really matter.

The problem, to be clear, is this: suppose we have a lower bound on some variable, along with some knowledge about its bits (knowing that some bits have a fixed value, while others do not). For example, we may know that a variable is even (its least significant bit is known to be zero) and at least 5. "Sharpening" the lower bound means increasing it, if possible, so that the lower bound "fits" the knowledge we have about the bits. If a value is even and at least 5, it is also at least 6, so we can increase the lower bound.

As a more recent reference for an algorithm that is better than my old one, you can read Sharpening Constraint Programming approaches for Bit-Vector Theory.

As that paper notes, we need to find the highest bit in the current lower bound that doesn't "fit" the KnownBits (or z, o pair from the bitfield domain) information, and then either:

  • If that bit was not set but should be, we need to set it, and reset any lower bits that are not required to be set (lower bound must go up, but only as little as possible).
  • If that bit was set but shouldn't be, we need to reset it, and in order to do that we need to set a higher bit that wasn't set yet, and also reset any lower bits that are not required to be set.

So far so good. What that paper doesn't tell you is that these are essentially the same case, and we can do:

  • Starting from the highest "wrong" bit in the lower bound, find the lowest bit that is unset but could be set, set it, and clear any lower bits that are not required to be set.

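That mostly sounds like the second case. What allows the original two cases to be unified is that in the first case, the bit we find is exactly the bit that needs to be set: it is unset in the lower bound but required to be set, so it can certainly be set, and the search finds it immediately.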

As a reminder, x & -x is a common technique used to extract or isolate the lowest set bit, aka blsi. It can also be written as x & (~x + 1), and if we change the 1 to some other power of two, we can use this technique to find the lowest set bit starting from some position that is not necessarily the least significant bit. So if we apply it to ~low & o (the bits that are unset in the lower bound but could be set), starting from the highest "wrong" bit, highestSetBit(m), where m has a bit set for every bit of the lower bound that doesn't fit, we find the bit we're looking for. Actually the annoying part is highestSetBit. Putting the rest together, we may get an implementation like this:
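A quick check of the x & (~x + c) trick, as a sketch with an illustrative function name. Adding c to ~x leaves the bits below c alone (and x & ~x is zero there), while from c upward the carry runs through the zeroes of x and stops at its first one, which is the only bit that survives the AND. If there is no such bit, the carry falls off the top and the result is zero.

```cpp
#include <cassert>
#include <cstdint>

// For a power of two c, isolate the lowest set bit of x whose position is at
// least that of c (0 if there is none). With c == 1 this is x & -x (blsi).
uint64_t lowest_set_bit_from(uint64_t x, uint64_t c)
{
    return x & (~x + c);
}
```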

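uint64_t sharpen_low(uint64_t low, uint64_t z, uint64_t o)
{
    // m: bits of low that don't fit (must be one but are zero, or cannot be one but are)
    uint64_t m = (~low & ~z) | (low & ~o);
    if (m) {
        // candidates: bits that are unset in low but could be set
        uint64_t target = ~low & o;
        // keep only the lowest candidate at or above the highest wrong bit
        target &= ~target + highestSetBit(m);
        // set that bit and clear every bit below it
        low = (low & -target) | target;
        // then set the bits that are required to be one
        low |= ~z;
    }
    return low;
}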

The branch on m is a bit annoying, but on the plus side it means that the input of highestSetBit is always non-zero. Zero is otherwise a bit of an annoying case to handle in highestSetBit. In modern C++, you can use std::bit_floor for highestSetBit.

Sharpening the upper bound is symmetric: it can be implemented as ~sharpen_low(~high, o, z), or you could push the bitwise flips "inside" the algorithm and do some algebra to cancel them out.
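Here is a portable sketch of both directions, assuming no C++20 (so highestSetBit is written out with the usual bit-smearing trick instead of std::bit_floor), plus a brute-force comparison against the definition over a 6-bit universe. The brute force only considers consistent knowledge (every bit can be zero, one, or both) and only cases where some fitting value at or above the bound exists. `-target` is written as `~target + 1`, which is the same value for unsigned integers but avoids compiler warnings about negating an unsigned type.

```cpp
#include <cassert>
#include <cstdint>

// highestSetBit for nonzero x: smear the highest set bit into all lower
// positions, then keep only the top one.
uint64_t highest_set_bit(uint64_t x)
{
    x |= x >> 1;
    x |= x >> 2;
    x |= x >> 4;
    x |= x >> 8;
    x |= x >> 16;
    x |= x >> 32;
    return x ^ (x >> 1);
}

// z: bits that can be zero, o: bits that can be one.
uint64_t sharpen_low(uint64_t low, uint64_t z, uint64_t o)
{
    uint64_t m = (~low & ~z) | (low & ~o);       // bits of low that don't fit
    if (m) {
        uint64_t target = ~low & o;              // bits that could still be set
        target &= ~target + highest_set_bit(m);  // lowest one at/above the highest wrong bit
        low = (low & (~target + 1)) | target;    // set it, clear everything below
        low |= ~z;                               // set the bits that must be one
    }
    return low;
}

// Upper bound by symmetry: complement the bound and swap the roles of z and o.
uint64_t sharpen_high(uint64_t high, uint64_t z, uint64_t o)
{
    return ~sharpen_low(~high, o, z);
}

// Brute force over a 6-bit universe (bits 6..63 forced to zero): whenever a
// fitting value >= low exists, sharpen_low must return the smallest one.
bool check_sharpen_low()
{
    for (uint64_t zl = 0; zl < 64; zl++)
        for (uint64_t ol = 0; ol < 64; ol++) {
            if ((zl | ol) != 63)
                continue;  // some bit can be neither zero nor one: nothing fits
            uint64_t z = zl | ~uint64_t(63), o = ol;
            for (uint64_t low = 0; low < 64; low++)
                for (uint64_t v = low; v < 64; v++)
                    if ((v & ~o) == 0 && (~v & ~z) == 0) {  // v fits
                        if (sharpen_low(low, z, o) != v)
                            return false;
                        break;
                    }
        }
    return true;
}
```

With z = all ones and o = ~1 (the value is even), sharpen_low(5, z, o) gives 6, matching the example at the start of this post.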

Monday 3 June 2024

Multiplying 64x64 bit-matrices with GF2P8AFFINEQB

This is a relatively simple use of GF2P8AFFINEQB. By itself GF2P8AFFINEQB essentially multiplies two 8x8 bit-matrices (but with transpose-flip applied to the second operand, and an extra XOR by a byte that we can set to zero). A 64x64 matrix can be seen as a block-matrix where each block is an 8x8 matrix. You can also view this as taking the ring of 8x8 matrices over GF(2), and then working with 8x8 matrices with elements from that ring. All we really need to do is write an 8x8 matrix multiplication and let GF2P8AFFINEQB take care of the complicated part.
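To convince yourself that the block-matrix view is sound, here is a scalar sketch with no GF2P8AFFINEQB involved. `bmatxor` multiplies two 8x8 bit-matrices stored as QWORDs (byte r is row r, bit c of that byte is column c, an assumed layout that matches the convention below where bit j of A[i] is element (i, j)), and the 64x64 product is computed block by block as C_IK = XOR over J of A_IJ * B_JK. It is compared against the branch-free scalar implementation from the benchmark section; the other names are illustrative.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

using Mat64 = std::array<uint64_t, 64>;  // bit j of row i is element (i, j)

// 8x8 bit-matrix product over GF(2): row r of the result is the XOR of the
// rows of b selected by the set bits of row r of a.
uint64_t bmatxor(uint64_t a, uint64_t b)
{
    uint64_t res = 0;
    for (int r = 0; r < 8; r++) {
        uint8_t row = 0;
        for (int k = 0; k < 8; k++)
            if ((a >> (8 * r + k)) & 1)
                row ^= uint8_t(b >> (8 * k));
        res |= uint64_t(row) << (8 * r);
    }
    return res;
}

// Block (I, J): rows 8I..8I+7, columns 8J..8J+7, as an 8x8 bit-matrix.
uint64_t get_block(const Mat64& m, int I, int J)
{
    uint64_t blk = 0;
    for (int r = 0; r < 8; r++)
        blk |= ((m[8 * I + r] >> (8 * J)) & 0xFF) << (8 * r);
    return blk;
}

// 64x64 product as an 8x8 matrix of 8x8 blocks.
Mat64 mmul_blocked(const Mat64& A, const Mat64& B)
{
    Mat64 res{};
    for (int I = 0; I < 8; I++)
        for (int K = 0; K < 8; K++) {
            uint64_t acc = 0;
            for (int J = 0; J < 8; J++)
                acc ^= bmatxor(get_block(A, I, J), get_block(B, J, K));
            for (int r = 0; r < 8; r++)
                res[8 * I + r] |= ((acc >> (8 * r)) & 0xFF) << (8 * K);
        }
    return res;
}

// The branch-free scalar reference.
Mat64 mmul_gf2_branchfree(const Mat64& A, const Mat64& B)
{
    Mat64 res;
    for (size_t i = 0; i < 64; i++) {
        uint64_t result_row = 0;
        for (size_t j = 0; j < 64; j++)
            result_row ^= B[j] & (0 - ((A[i] >> j) & 1));
        res[i] = result_row;
    }
    return res;
}

// Pseudo-random matrix (xorshift64), for testing only.
Mat64 random_matrix(uint64_t seed)
{
    Mat64 m;
    uint64_t s = seed * 0x9E3779B97F4A7C15ULL + 1;
    for (auto& row : m) {
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        row = s;
    }
    return m;
}

Mat64 identity64()
{
    Mat64 m{};
    for (int i = 0; i < 64; i++)
        m[i] = uint64_t(1) << i;
    return m;
}
```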

Using the "full" 512-bit version of VGF2P8AFFINEQB from AVX-512, one VGF2P8AFFINEQB instruction performs 8 of those products. A convenient way to use them is by broadcasting an element (which is really an 8x8 bit-matrix but let's put that aside for now) from the left matrix to all QWORD elements of a ZMM vector, and multiplying that by a row of the right matrix. That way we end up with a row of the result, which is nice to work with: no reshaping or horizontal-SIMD required. XOR-ing QWORDs together horizontally could be done relatively reasonably with another GF2P8AFFINEQB trick, which is neat, but avoiding it is even better. All we need to do to compute a row of the result (still viewed as an 8x8 matrix) is 8 broadcasts, 8 VGF2P8AFFINEQB, and XOR-ing the 8 results together, which doesn't take 8 VPXORQ because VPTERNLOGQ can XOR three vectors together. Then just do this for each row of the result.
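The row accumulation in the code below XORs eight products with three VPTERNLOGQ and one VPXORQ. A scalar model of VPTERNLOGQ makes the immediate 0x96 less magical: at every bit position, the three input bits form an index into the 8-bit truth table given by the immediate, and evaluating the desired function on the constants 0xF0, 0xCC and 0xAA produces that table. This is a model for illustration only; `ternlog` and `xor8` are illustrative names.

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of VPTERNLOGQ: at every bit position, the bits of a, b and c
// form the index (a << 2) | (b << 1) | c into the 8-bit truth table imm.
uint64_t ternlog(uint64_t a, uint64_t b, uint64_t c, uint8_t imm)
{
    uint64_t res = 0;
    for (int i = 0; i < 64; i++) {
        unsigned idx = unsigned(((a >> i) & 1) << 2 | ((b >> i) & 1) << 1 | ((c >> i) & 1));
        res |= uint64_t((imm >> idx) & 1) << i;
    }
    return res;
}

// The truth table of a three-way XOR.
static_assert((0xF0 ^ 0xCC ^ 0xAA) == 0x96, "three-way XOR immediate");

// XOR eight values with three ternary-logic ops and one plain XOR, in the
// same shape as the row accumulation in mmul_gf2_avx512.
uint64_t xor8(const uint64_t v[8])
{
    uint64_t row = ternlog(v[0], v[1], v[2], 0x96);
    row = ternlog(row, v[3], v[4], 0x96);
    row = ternlog(row, v[5], v[6], 0x96);
    return row ^ v[7];
}
```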

There are two things that I've skipped so far. First, the built-in transpose-flip of GF2P8AFFINEQB needs to be cancelled out with a flip-transpose (unless explicitly working with a matrix in a weird format is OK). Second, working with an 8x8 block-matrix is mathematically "free" by imagining some dotted lines running through the matrix, but in order to get the right data into GF2P8AFFINEQB we have to actually rearrange it (again: unless the weird format is OK).

One way to implement a flip-transpose (ie the inverse of the bit-permutation that GF2P8AFFINEQB applies to its second operand) is by reversing the bytes in each QWORD and then left-multiplying (in the sense of GF2P8AFFINEQB-ing with the constant as the first operand) by a flipped identity matrix, which as a QWORD looks like: 0x0102040810204080. Reversing the bytes in each QWORD could be done with a VPERMB. There are other ways, but we're about to have a VPERMB anyway.
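The transpose-flip and its cancellation can be checked with a scalar model of one QWORD lane of GF2P8AFFINEQB, written from Intel's pseudocode for _mm512_gf2p8affine_epi64_epi8(x, A, b): result bit i of each byte is parity(A.byte[7 - i] AND x.byte) XOR b.bit[i]. Matrices use the layout byte r = row r, bit c = column c; `bmatxor`, `transpose8` and `flip_rows` are illustrative helpers. The assertions check that the second operand gets transpose-flipped, that applying a flip-transpose first cancels it, and that the recipe above (reverse the bytes, then GF2P8AFFINEQB with 0x0102040810204080 as the first operand) produces that flip-transpose.

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of one QWORD lane of _mm512_gf2p8affine_epi64_epi8(x, A, b).
uint64_t gf2p8affine(uint64_t x, uint64_t A, uint8_t b)
{
    uint64_t res = 0;
    for (int j = 0; j < 8; j++) {
        uint8_t xb = uint8_t(x >> (8 * j));
        for (int i = 0; i < 8; i++) {
            unsigned t = unsigned(uint8_t(A >> (8 * (7 - i))) & xb);
            unsigned parity = (b >> i) & 1;
            for (; t; t &= t - 1)
                parity ^= 1;
            res |= uint64_t(parity) << (8 * j + i);
        }
    }
    return res;
}

// Plain 8x8 bit-matrix product: row r = XOR of the rows of b selected by row r of a.
uint64_t bmatxor(uint64_t a, uint64_t b)
{
    uint64_t res = 0;
    for (int r = 0; r < 8; r++)
        for (int k = 0; k < 8; k++)
            if ((a >> (8 * r + k)) & 1)
                res ^= ((b >> (8 * k)) & 0xFF) << (8 * r);
    return res;
}

uint64_t transpose8(uint64_t m)
{
    uint64_t res = 0;
    for (int r = 0; r < 8; r++)
        for (int c = 0; c < 8; c++)
            res |= ((m >> (8 * r + c)) & 1) << (8 * c + r);
    return res;
}

// Reverse the order of the rows, ie the bytes of the QWORD.
uint64_t flip_rows(uint64_t m)
{
    uint64_t res = 0;
    for (int r = 0; r < 8; r++)
        res |= ((m >> (8 * r)) & 0xFF) << (8 * (7 - r));
    return res;
}
```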

Rearranging the data between a fully row-major layout and an 8x8 matrix in which each element is an 8x8 bit-matrix is easy, that's just an 8x8 transpose after all, so just VPERMB. That's needed both for the inputs and the output. The input that is the right-hand operand of the overall matrix multiplication also needs to have a byte-reverse applied to each QWORD, the same VPERMB that does that transpose can also do that byte-reverse.
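The two VPERMB index vectors used in the code below can also be generated instead of written out. A sketch, with a scalar model of VPERMB (dst.byte[k] = src.byte[idx[k]]): `make_tp` transposes the 8x8 matrix of bytes, and `make_tpr` additionally reverses the bytes within each QWORD, which reproduces the constants `tp` and `tpr` in mmul_gf2_avx512. The function names are illustrative.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Bytes64 = std::array<uint8_t, 64>;

// Scalar model of VPERMB (_mm512_permutexvar_epi8): dst[k] = src[idx[k]].
Bytes64 permute_bytes(const Bytes64& idx, const Bytes64& src)
{
    Bytes64 dst{};
    for (int k = 0; k < 64; k++)
        dst[k] = src[idx[k] & 63];
    return dst;
}

// tp: byte c of QWORD r moves to byte r of QWORD c.
Bytes64 make_tp()
{
    Bytes64 idx{};
    for (int r = 0; r < 8; r++)
        for (int c = 0; c < 8; c++)
            idx[8 * r + c] = uint8_t(8 * c + r);
    return idx;
}

// tpr: the same transpose, with the bytes of every resulting QWORD reversed.
Bytes64 make_tpr()
{
    Bytes64 idx{};
    for (int r = 0; r < 8; r++)
        for (int c = 0; c < 8; c++)
            idx[8 * r + c] = uint8_t(8 * (7 - c) + r);
    return idx;
}

// tpr[8r + c] == tp[8r + 7 - c]: tpr is tp followed by a byte-reverse per QWORD.
bool tpr_is_reversed_tp()
{
    const Bytes64 tp = make_tp(), tpr = make_tpr();
    for (int r = 0; r < 8; r++)
        for (int c = 0; c < 8; c++)
            if (tpr[8 * r + c] != tp[8 * r + 7 - c])
                return false;
    return true;
}

// Transposing twice gives back the original bytes.
bool transpose_is_involution()
{
    Bytes64 v{};
    for (int k = 0; k < 64; k++)
        v[k] = uint8_t(3 * k + 1);
    const Bytes64 tp = make_tp();
    return permute_bytes(tp, permute_bytes(tp, v)) == v;
}
```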

Here's one way to put that all together:

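#include <immintrin.h>
#include <array>
#include <cstdint>
using namespace std;

// Needs AVX-512F/BW, AVX512_VBMI (VPERMB) and GFNI.
array<uint64_t, 64> mmul_gf2_avx512(const array<uint64_t, 64>& A, const array<uint64_t, 64>& B)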
{
    __m512i id = _mm512_set1_epi64(0x0102040810204080);
    __m512i tp = _mm512_setr_epi8(
        0, 8, 16, 24, 32, 40, 48, 56,
        1, 9, 17, 25, 33, 41, 49, 57,
        2, 10, 18, 26, 34, 42, 50, 58,
        3, 11, 19, 27, 35, 43, 51, 59,
        4, 12, 20, 28, 36, 44, 52, 60,
        5, 13, 21, 29, 37, 45, 53, 61,
        6, 14, 22, 30, 38, 46, 54, 62,
        7, 15, 23, 31, 39, 47, 55, 63);
    __m512i tpr = _mm512_setr_epi8(
        56, 48, 40, 32, 24, 16, 8, 0,
        57, 49, 41, 33, 25, 17, 9, 1,
        58, 50, 42, 34, 26, 18, 10, 2,
        59, 51, 43, 35, 27, 19, 11, 3,
        60, 52, 44, 36, 28, 20, 12, 4,
        61, 53, 45, 37, 29, 21, 13, 5,
        62, 54, 46, 38, 30, 22, 14, 6,
        63, 55, 47, 39, 31, 23, 15, 7);
    array<uint64_t, 64> res;

    __m512i b_0 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[0]));
    __m512i b_1 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[8]));
    __m512i b_2 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[16]));
    __m512i b_3 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[24]));
    __m512i b_4 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[32]));
    __m512i b_5 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[40]));
    __m512i b_6 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[48]));
    __m512i b_7 = _mm512_permutexvar_epi8(tpr, _mm512_loadu_epi64(&B[56]));
    
    b_0 = _mm512_gf2p8affine_epi64_epi8(id, b_0, 0);
    b_1 = _mm512_gf2p8affine_epi64_epi8(id, b_1, 0);
    b_2 = _mm512_gf2p8affine_epi64_epi8(id, b_2, 0);
    b_3 = _mm512_gf2p8affine_epi64_epi8(id, b_3, 0);
    b_4 = _mm512_gf2p8affine_epi64_epi8(id, b_4, 0);
    b_5 = _mm512_gf2p8affine_epi64_epi8(id, b_5, 0);
    b_6 = _mm512_gf2p8affine_epi64_epi8(id, b_6, 0);
    b_7 = _mm512_gf2p8affine_epi64_epi8(id, b_7, 0);

    for (size_t i = 0; i < 8; i++)
    {
        __m512i a_tiles = _mm512_loadu_epi64(&A[i * 8]);
        a_tiles = _mm512_permutexvar_epi8(tp, a_tiles);
        __m512i row = _mm512_ternarylogic_epi64(
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(0), a_tiles), b_0, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(1), a_tiles), b_1, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(2), a_tiles), b_2, 0), 0x96);
        row = _mm512_ternarylogic_epi64(row,
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(3), a_tiles), b_3, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(4), a_tiles), b_4, 0), 0x96);
        row = _mm512_ternarylogic_epi64(row,
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(5), a_tiles), b_5, 0),
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(6), a_tiles), b_6, 0), 0x96);
        row = _mm512_xor_epi64(row,
            _mm512_gf2p8affine_epi64_epi8(_mm512_permutexvar_epi64(_mm512_set1_epi64(7), a_tiles), b_7, 0));
        row = _mm512_permutexvar_epi8(tp, row);
        _mm512_storeu_epi64(&res[i * 8], row);
    }
    
    return res;
}

When performing multiple matrix multiplications in a row, it may make sense to leave the intermediate results in the format of an 8x8 matrix of 8x8 bit-matrices. The B matrix needs to be permuted anyway, but the two _mm512_permutexvar_epi8 in the loop can be removed. And obviously, if the same matrix is used as the B matrix several times, it only needs to be permuted once. You may need to manually inline the code to convince your compiler to keep the matrix in registers.

Crude benchmarks

A very boring conventional implementation of this 64x64 matrix multiplication may look like this:

array<uint64_t, 64> mmul_gf2_scalar(const array<uint64_t, 64>& A, const array<uint64_t, 64>& B)
{
    array<uint64_t, 64> res;
    for (size_t i = 0; i < 64; i++) {
        uint64_t result_row = 0;
        for (size_t j = 0; j < 64; j++) {
            if (A[i] & (1ULL << j))
                result_row ^= B[j];
        }
        res[i] = result_row;
    }
    return res;
}

There are various ways to write this slightly differently, some of which may be a bit faster, that's not really the point.

On my PC, which has an 11600K (Rocket Lake) in it, mmul_gf2_scalar runs around 500 times as slow (in terms of the time taken to perform a chain of dependent multiplications) as the AVX-512 implementation. Really, it's that slow, but that is partly due to my choice of data: I mainly benchmarked this on random matrices where each bit has a 50% chance of being set. The AVX-512 implementation does not care about that at all, while the above scalar implementation has thousands (literally) of branch mispredictions. That can be fixed without using SIMD, for example:

array<uint64_t, 64> mmul_gf2_branchfree(const array<uint64_t, 64>& A, const array<uint64_t, 64>& B)
{
    array<uint64_t, 64> res;
    for (size_t i = 0; i < 64; i++) {
        uint64_t result_row = 0;
        for (size_t j = 0; j < 64; j++) {
            result_row ^= B[j] & -((A[i] >> j) & 1);
        }
        res[i] = result_row;
    }
    return res;
}

That was, in my benchmark, already about 8 times as fast as the branching version, if I don't let MSVC auto-vectorize. If I do let it auto-vectorize (with AVX-512), this implementation becomes 25 times as fast as the branching version. This was not supposed to be a post about branch (mis)prediction, but be careful out there, you might get snagged on a branch.

It would be interesting to compare the AVX-512 version against established finite-field (or GF(2)-specific) linear algebra packages, such as FFLAS-FFPACK and M4RI. Actually I tried to include M4RI in the benchmark, but it ended up being 10 times as slow as mmul_gf2_branchfree (when it is auto-vectorized). That's bizarrely bad so I probably did something wrong, but I've already sunk 4 times as much time into getting M4RI to work at all as it took to write the code which this blog post is really about, so I'll just accept that I did it wrong and leave it at that.

Tuesday 28 May 2024

Implementing grevmul with GF2P8AFFINEQB

As a reminder, grev is generalized bit-reversal, and performs a bit-permutation that corresponds to XORing the indices of the bits of the left operand by the number given in the right operand. A possible (not very suitable for actual use, but illustrative) implementation of grev is:

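import numpy as np

def grev(bits, k):
    # bits: NumPy array whose length is a power of two, 0 <= k < len(bits)
    # result[i] = bits[i ^ k]
    return bits[np.arange(len(bits)) ^ k]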
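For use outside of NumPy, here is a C++ sketch of the same definition for 64-bit values, together with the usual faster formulation as a sequence of conditional swap stages (stage s swaps adjacent groups of 2^s bits when bit s of k is set). The stages can be applied in any order because each one flips a different bit of the index. Both versions assume 0 <= k < 64; `grev_naive` and `grev_matches_naive` are illustrative names.

```cpp
#include <cassert>
#include <cstdint>

// grev by its definition: bit i of the result is bit (i ^ k) of x.
uint64_t grev_naive(uint64_t x, unsigned k)
{
    uint64_t res = 0;
    for (unsigned i = 0; i < 64; i++)
        res |= ((x >> (i ^ k)) & 1) << i;
    return res;
}

// The same permutation as conditional swap stages, one per bit of k.
uint64_t grev(uint64_t x, unsigned k)
{
    if (k & 1)
        x = ((x & 0x5555555555555555) << 1) | ((x >> 1) & 0x5555555555555555);
    if (k & 2)
        x = ((x & 0x3333333333333333) << 2) | ((x >> 2) & 0x3333333333333333);
    if (k & 4)
        x = ((x & 0x0F0F0F0F0F0F0F0F) << 4) | ((x >> 4) & 0x0F0F0F0F0F0F0F0F);
    if (k & 8)
        x = ((x & 0x00FF00FF00FF00FF) << 8) | ((x >> 8) & 0x00FF00FF00FF00FF);
    if (k & 16)
        x = ((x & 0x0000FFFF0000FFFF) << 16) | ((x >> 16) & 0x0000FFFF0000FFFF);
    if (k & 32)
        x = (x << 32) | (x >> 32);
    return x;
}

bool grev_matches_naive()
{
    const uint64_t samples[] = {1, 0x0123456789ABCDEF, 0xDEADBEEFCAFEF00D, 0x8000000000000001};
    for (uint64_t x : samples)
        for (unsigned k = 0; k < 64; k++)
            if (grev(x, k) != grev_naive(x, k))
                return false;
    return true;
}
```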

grevmul (see also this older post where I list some of its properties) can be defined in terms of grev (hence the name[1]), with grev replacing the left shift in a carryless multiplication. But let's do something else first. An interesting way to look at multiplication (plain old multiplication between natural numbers) is as:

  1. Form the Cartesian product of the inputs, viewed as arrays of bits.
  2. For each pair in the Cartesian product, compute the AND of the two bits.
  3. Send the AND of each pair with index (i, j) to the bin i + j, to be accumulated with the function (+).
  4. Resolve the carries.

Carryless multiplication works mostly the same way except that accumulation is done with XOR, which makes step 4 unnecessary. grevmul also works mostly the same way, with accumulation also done with XOR, but now the pair with index (i, j) is sent to the bin i XOR j.

                 accumulate with +    accumulate with ^
  bin i + j      imul                 clmul
  bin i ^ j      ???                  grevmul

That won't be important for the implementation, but it may help you think about what grevmul does, and this will be on the test. OK there is no test, but you can test yourself by explaining why (popcnt(a & b) & 1) == (grevmul(a, b) & 1), based on reasoning about the Cartesian-product-algorithm for grevmul.
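A direct (and slow) transcription of the Cartesian-product description, as a sketch for checking properties rather than for real use; the function names are illustrative. Spoiler for the question above: bin 0 only receives pairs with i ^ j == 0, that is i == j, so it accumulates the XOR of a_i AND b_i over all i, which is the parity of popcnt(a & b). The assertions check that and commutativity.

```cpp
#include <cassert>
#include <cstdint>

// The AND of every pair (a_i, b_j) goes to bin i ^ j, accumulated with XOR.
uint64_t grevmul_naive(uint64_t a, uint64_t b)
{
    uint64_t res = 0;
    for (unsigned i = 0; i < 64; i++)
        for (unsigned j = 0; j < 64; j++)
            res ^= ((a >> i) & (b >> j) & 1) << (i ^ j);
    return res;
}

unsigned popcount_parity(uint64_t x)
{
    unsigned p = 0;
    for (; x; x &= x - 1)
        p ^= 1;
    return p;
}

// Compare bit 0 of grevmul with the parity of popcnt(a & b) on pseudo-random inputs.
bool check_parity_identity()
{
    uint64_t s = 0x9E3779B97F4A7C15;
    for (int n = 0; n < 100; n++) {
        s ^= s << 13; s ^= s >> 7; s ^= s << 17;  // xorshift64
        uint64_t a = s;
        s ^= s << 13; s ^= s >> 7; s ^= s << 17;
        uint64_t b = s;
        if (popcount_parity(a & b) != (grevmul_naive(a, b) & 1))
            return false;
    }
    return true;
}
```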

Implementing grevmul with GF2P8AFFINEQB

Some of you may have already seen a version of the code that I am about to discuss, although I have made some changes since. Nothing serious, but while thinking about how the code worked, I encountered some opportunities to polish it up.

GF2P8AFFINEQB computes, for each QWORD (so for now, let's concentrate on one QWORD of the result), the product of two bit-matrices, where the left matrix comes from the first input and the right matrix is the transpose-flip (or flip-transpose, however you prefer to think about it) of the second input. You can think of it as a mutant of the mxor operation found in Knuth's MMIX instruction set, and from here on I will use the name bmatxor for the version of this operation that simply multiplies two 8x8 bit-matrices, with none of that transpose-flip business[2]. There is also a free XOR by a constant byte thrown in at the end, which can be safely ignored by setting it to zero. The transpose-flip however needs to be worked around or otherwise taken into account. You may also want to read (Ab)using gf2p8affineqb to turn indices into bits first for some extra familiarity with / alternate view of GF2P8AFFINEQB.

To implement grevmul(a, b), we need some linear combination (controlled by b) of permuted (by grev) versions of a (the roles of a and b can be swapped since grevmul is commutative, which the Cartesian product-based algorithm makes clear). GF2P8AFFINEQB is all about linear combinations, but it works with 8x8 matrices, not 64x64 (put that in AVX-4096) which would have made it really simple. Fortunately, we can slice a grevmul however we want.

Now to start in the middle, let's say we have 8 copies of a byte (a byte of b), with the i'th copy grev'ed by i, concatenated into a QWORD that I will call m. bmatxor(a, m) would (thinking of a matrix multiplication AB as forming linear combinations of rows of B), for each row of the result, form a linear combination (controlled by a) of grev'ed copies of the byte from b. That may seem like it would be wrong, since every byte of a is done separately and uses the same m, so it's "missing" a grev by 8 for the second byte, 16 for the third byte, etc. But if x is a byte (not if and only if, just regular "if"), then grev(x, 8 * i) is the same as x << (8 * i), and the second byte is indeed already in the second position anyway, so we get this for free. Thus, bmatxor(a, m) would allow us to grevmul a 64-bit number by an 8-bit number. If we could just do that 8 times (for each byte of b) and combine the results, we're done.
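The slicing argument can be checked by brute force on scalars. The sketch below uses the definition of grevmul as a carryless multiplication with grev replacing the left shift, and checks three things: that grev(x, 8 * i) == x << (8 * i) for a byte x, that grevmul by a single byte can be done byte-by-byte on a (which is what bmatxor(a, m) computes), and that combining the eight per-byte products, each grev'ed by 8 * k, gives the full grevmul. Function names are illustrative.

```cpp
#include <cassert>
#include <cstdint>

// grev by its definition: bit i of the result is bit (i ^ k) of x (0 <= k < 64).
uint64_t grev(uint64_t x, unsigned k)
{
    uint64_t res = 0;
    for (unsigned i = 0; i < 64; i++)
        res |= ((x >> (i ^ k)) & 1) << i;
    return res;
}

// grevmul as a carryless multiplication with grev in place of the left shift.
uint64_t grevmul(uint64_t a, uint64_t b)
{
    uint64_t res = 0;
    for (unsigned j = 0; j < 64; j++)
        if ((b >> j) & 1)
            res ^= grev(a, j);
    return res;
}

// Claim 1: for a byte x, grev(x, 8 * i) == x << (8 * i).
bool grev_of_byte_is_shift()
{
    for (uint64_t x = 0; x < 256; x++)
        for (unsigned i = 0; i < 8; i++)
            if (grev(x, 8 * i) != x << (8 * i))
                return false;
    return true;
}

// Claim 2: grevmul by one byte bb, done separately for every byte of a;
// each per-byte product stays a byte and stays in place.
uint64_t grevmul_by_byte(uint64_t a, uint64_t bb)
{
    uint64_t res = 0;
    for (unsigned r = 0; r < 8; r++)
        res ^= grevmul((a >> (8 * r)) & 0xFF, bb) << (8 * r);
    return res;
}

// Claim 3: the full product is the XOR of the per-byte products, grev'ed by 8 * k.
uint64_t grevmul_sliced(uint64_t a, uint64_t b)
{
    uint64_t res = 0;
    for (unsigned k = 0; k < 8; k++)
        res ^= grev(grevmul_by_byte(a, (b >> (8 * k)) & 0xFF), 8 * k);
    return res;
}
```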

But we don't have bmatxor, we have GF2P8AFFINEQB with its built-in transpose-flip, and that presents a choice: either put another GF2P8AFFINEQB before, or after, the "main" GF2P8AFFINEQB to counter that built-in transpose. Not the whole transpose-flip; let's put the flip aside for now. There is a small reason to favour the "extra GF2P8AFFINEQB after the main one" order, namely that it results in GF2P8AFFINEQB(m, broadcast(a)) (as opposed to GF2P8AFFINEQB(broadcast(a), m)), and when a comes from memory it can be loaded and broadcast directly with a {1to8} broadcasted memory operand. That option would not be available if a was the first operand of the GF2P8AFFINEQB. This is a small matter, but we have to make the choice somehow, and there seems to be no difference aside from this.

At this point there are two (and a half) pieces of the puzzle left: forming m, and horizontally combining 8 results.

Forming m

If we were to first form 8 copies of a byte of b and then try to grev those by their respective indices, that would be hard. But doing it the other way around is easy: broadcast b into each QWORD of a 512-bit vector, grev each QWORD by its index, then transpose the whole vector as an 8x8 matrix of bytes. Actually, for constant-reuse (loading fewer humongous vector constants is rarely bad) and because of the built-in transpose-flip, it turns out slightly better to transpose-flip that 8x8 matrix of bytes as well; that doesn't cost any more than just transposing it.

grev-ing each QWORD by its index is easy, perhaps easier than it sounds. A grev by 0..7 only rearranges the bits within each byte, which is easy to do with a single GF2P8AFFINEQB with a constant as the second operand (a "P" step).

An 8x8 transpose-flip is just a VPERMB by some index vector.[3]

Combining the results

After the "middle" step, if we went with the "extra GF2P8AFFINEQB before the main GF2P8AFFINEQB"-route, we would have bmatxor(a, m) in each QWORD (with a different m per QWORD), which need to be combined. If we were implementing a plain old integer multiplication, the value in the QWORD with the index i would be shifted left by 8 * i and then all of them would be summed up. Since we're implementing a grevmul, that value is grev'ed by 8 * i (which is just some byte permutation) and the resulting QWORDs are XORed.

If we go with the "extra GF2P8AFFINEQB after the main GF2P8AFFINEQB"-route, which I chose, then there is a GF2P8AFFINEQB to do before we can start combining QWORDs. We really only need it to un-transpose the result of the "main" GF2P8AFFINEQB, the rest is just a byte-permutation and we're about to do a VPERMB anyway (if we choose the cool way of XORing the QWORDs together), but there is a neat opportunity here: if we re-use the same set of bit-matrices that was used to grev b by 0..7, in addition to the transpose that we wanted we also permute the bytes such that we grev each QWORD by 8 * i as a bonus.

Now we could just XOR-fold the vector 3 times to end up with the XOR of all eight QWORDs, and that would be a totally valid implementation, but it would also be boring. Alternatively, if we transpose the vector as an 8x8 matrix of bytes again (it can be a transpose-flip, so we get to reuse the same index vector as before), then the i'th byte of each QWORD would be gathered in the i'th QWORD of the transpose and bmatxor(0xFF, vector) would XOR together all bytes with the same index and give us a vector that has one byte of the final result per QWORD, easily extractable with VPMOVQB. We still don't have bmatxor though, we have bmatxor with an extra transpose-flip, which can be countered the usual way, with yet another GF2P8AFFINEQB.

As yet another alternative[4], after similar transpose trickery bmatxor(vector, 0xFF) would also XOR together bytes with the same index but leave the result in a form that can be extracted with VPMOVB2M, which puts the result in a mask register but it's not too bad to move it from there to a GPR.

The code

In case anyone makes it this far, here is one possible embodiment[5] of the algorithm described herein:

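#include <immintrin.h>
#include <cstdint>

// Needs AVX-512BW, AVX512_VBMI and GFNI. A scalar grev(uint64_t, unsigned)
// helper is assumed to be in scope; it is only used to build constants.
uint64_t grevmul_avx512(uint64_t a, uint64_t b)
{
    uint64_t id = 0x0102040810204080;
    // QWORD i: a bit-matrix that grevs every byte by i
    __m512i grev_by_index = _mm512_setr_epi64(
        id,
        grev(id, 1),
        grev(id, 2),
        grev(id, 3),
        grev(id, 4),
        grev(id, 5),
        grev(id, 6),
        grev(id, 7));
    // VPERMB indices of an 8x8 transpose-flip of the bytes
    __m512i tp_flip = _mm512_setr_epi8(
        56, 48, 40, 32, 24, 16, 8, 0,
        57, 49, 41, 33, 25, 17, 9, 1,
        58, 50, 42, 34, 26, 18, 10, 2,
        59, 51, 43, 35, 27, 19, 11, 3,
        60, 52, 44, 36, 28, 20, 12, 4,
        61, 53, 45, 37, 29, 21, 13, 5,
        62, 54, 46, 38, 30, 22, 14, 6,
        63, 55, 47, 39, 31, 23, 15, 7);
    // forming m: broadcast b, grev QWORD i by i, then transpose-flip
    __m512i m = _mm512_set1_epi64(b);
    m = _mm512_gf2p8affine_epi64_epi8(m, grev_by_index, 0);
    m = _mm512_permutexvar_epi8(tp_flip, m);
    // the "main" GF2P8AFFINEQB, with a broadcast into every QWORD
    __m512i t512 = _mm512_gf2p8affine_epi64_epi8(m, _mm512_set1_epi64(a), 0);
    // undo the transpose, grev-ing QWORD i by 8 * i as a bonus
    t512 = _mm512_gf2p8affine_epi64_epi8(grev_by_index, t512, 0);
    // combining: transpose-flip, then the last two steps XOR the eight bytes of
    // each QWORD together, leaving bit j of that XOR in the top bit of byte j
    t512 = _mm512_permutexvar_epi8(tp_flip, t512);
    t512 = _mm512_gf2p8affine_epi64_epi8(_mm512_set1_epi64(0x8040201008040201), t512, 0);
    t512 = _mm512_gf2p8affine_epi64_epi8(t512, _mm512_set1_epi64(0xFF), 0);
    // VPMOVB2M gathers the top bit of every byte
    return _mm512_movepi8_mask(t512);
}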

The end

As far as I know, no one really uses grevmul for anything, so being able to compute it somewhat efficiently (more efficiently than a naive scalar solution at least) is not immediately useful. On the other hand, if an operation is not known to be efficiently computable, that may preclude its use. But the point of this post is more to show something neat.

Originally I had found the sequence of _mm512_gf2p8affine_epi64_epi8(m, _mm512_set1_epi64(a), 0) and _mm512_gf2p8affine_epi64_epi8(grev_by_index, t512, 0) as a solution to grevmul-ing a QWORD by a (constant) byte, using a SAT solver (shout-out to togasat for being easy to add to a project - though admittedly I eventually bit the bullet and switched to MiniSAT). That formed the starting point of this investigation / puzzle-solving session. It may be possible to pull a more complete solution out of the void with a SAT/SMT based technique such as CEGIS, perhaps the bitwise-bilinear nature of grevmul can be exploited (I used the bitwise-linear nature of grevmul-by-a-constant in my SAT-experiments to represent the problem as a composition of matrices over GF(2)).

Almost half of the steps of this algorithm are some kind of transpose, which has also been the case with some other SIMD algorithms that I recently had a hand in. I used to think of a transpose as "not really doing anything", barely worth the notation when doing linear algebra, but maybe I was wrong.


[1] Maybe it's less "the name" and more "what I decided to call it". I'm not aware of any established name for this operation.
[2] This makes it sound more negative than it really is. The transpose-flip often needs to be worked around when we don't want it, but that's not that bad. Having no easy access to a transpose would be much worse to work around when we do need it. Separate bmatxor and bmattranspose instructions would have been nice.
[3] GF2P8AFFINEQB trickery is nice, but when I recently wrote some AVX2 code it was VPERMB that I missed the most.
[4] Can we stop with the alternatives and just pick something?
[5] No patents were read in the development of this algorithm, nor in the writing of this blog post.

Thursday 4 April 2024

Enumerating all mathematical identities (in fixed-size bitvector arithmetic with a restricted set of operations) of a certain size

Once again a boring-but-specific title, I don't want to clickbait the audience after all. Even so, let's get some "disclaimers" out of the way.

  • "All" includes both boring and interesting identities. I tried to remove the most boring ones, so it's no longer truly all identities, but a lot of boring ones still remain. The way I see it, the biggest problem with the approach described in this blog post is that it generates too much "true but boring" junk.
  • This approach is, as far as I know, absolutely limited to fixed-size bitvectors, but that's what I'm interested in anyway. To keep things reasonable, the size should be small, which does result in some differences from, e.g., the arithmetic of 64-bit bitvectors. Most of the results either transfer directly to larger sizes or generalize to them.
  • The set of operations is restricted to those that are cheap to implement in CNF SAT. That is not a hard limitation but a practical one.
  • "Of a certain size" means we have to pick in advance the number of operations on both sides of the mathematical identity, and then only identities with exactly that number of operations (counted in the underlying representation, which may represent a seemingly larger expression if the result of an operation is used more than once) are found. This can be repeated for any size we're interested in, but this approach is not very scalable and tends to run out of memory if there are too many identities of the requested size.

The results look something like this. Let's say we want "all" identities that involve 2 variables, 2 operations on the left, and 0 operations (i.e. only a variable) on the right. The result would be the following; keep in mind that a bunch of redundant and boring identities are filtered out.

(a - (a - b)) == b
(a + (b - a)) == b
((a + b) - a) == b
(b | (a & b)) == b
(a ^ (a ^ b)) == b
(b & (a | b)) == b
// Done in 0s. Used a set of 3 inputs.

Nothing too interesting so far, but then we didn't ask for much. Here are a couple of selected "more interesting" (but not by any means new or unknown) identities that this approach can also enumerate:

((a & b) + (a | b)) == (a + b)
((a | b) - (a & b)) == (a ^ b)
((a ^ b) | (a & b)) == (a | b)
((a & b) ^ (a | b)) == (a ^ b)
((~ a) - (~ b)) == (b - a)
(~ (b + (~ a))) == (a - b)
  

Now that the expectations have been set accurately (hopefully), let's get into what the approach is.

The Approach

The core mechanism I use is CounterExample-Guided Inductive Synthesis (CEGIS) based on a SAT solver. Glucose worked well, but other solvers can be used. Rather than asking CEGIS to generate a snippet of code that performs some specific task, however, I ask it to generate two snippets that are equivalent. That does not fundamentally change how it operates, which is still a loop of:

  1. Synthesize code that does the right thing for each input in the set of inputs to check.
  2. Check whether the code matches the specification. If it does, we're done. If it doesn't, add the counter-example to the set of inputs to check.

Both synthesis and checking could be performed by a SAT solver, but I only use a SAT solver for synthesis. For checking, since 4-bit bitvectors have so few combinations, I just brute force every possible valuation of the variables.
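Since the checking half of the loop is plain brute force, it is easy to sketch on its own. The following is a minimal C++ sketch, not the code from the linked repository: it assumes two variables a and b, represents each side of a candidate identity as a callable (the names Expr and findCounterexample are made up for the example), and evaluates both sides on all 256 valuations of two 4-bit variables. The first mismatch is returned as a counterexample, which is what the loop would add to the set of inputs to check.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <utility>

// 4-bit bitvectors are stored in the low bits of an unsigned; results are compared mod 16.
constexpr unsigned kMask = 0xF;
using Expr = std::function<unsigned(unsigned, unsigned)>;
using Counterexample = std::pair<unsigned, unsigned>;

// The checking step of the CEGIS loop, done by brute force: evaluate both sides for
// every valuation of (a, b). The first mismatch is returned as a counterexample (to be
// added to the inputs the synthesizer must handle); std::nullopt means the two sides
// agree everywhere, so the pair is a genuine identity.
std::optional<Counterexample> findCounterexample(const Expr& lhs, const Expr& rhs) {
    for (unsigned a = 0; a <= kMask; ++a)
        for (unsigned b = 0; b <= kMask; ++b)
            if ((lhs(a, b) & kMask) != (rhs(a, b) & kMask))
                return Counterexample{a, b};
    return std::nullopt;
}
```

For example, ((a & b) + (a | b)) against (a + b) yields no counterexample, while (a + b) against (a | b) first fails at a = 1, b = 1.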

When a pair of equivalent expressions has been found, I add its negation as a single clause to prevent the same thing from being synthesized again. This is what enables pulling out one identity after the other. In my imagination, that looks like Thor smashing his mug and asking for another.
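Blocking an identity only requires the satisfying assignment that encoded it. As a sketch, assuming DIMACS-style literals (variable i is written +i when true and -i when false), assuming the model contains exactly the variables that encode the pair of expressions, and using the hypothetical helper names blockingClause and satisfies: the clause is the OR of the negations of all assigned literals, which every assignment except the blocked one satisfies.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// model[i] is the value the solver assigned to variable i + 1. The returned clause is
// the negation of the whole assignment: the OR of the opposite literals.
std::vector<int> blockingClause(const std::vector<bool>& model) {
    std::vector<int> clause;
    for (std::size_t i = 0; i < model.size(); ++i) {
        int var = static_cast<int>(i) + 1;
        clause.push_back(model[i] ? -var : var);
    }
    return clause;
}

// A clause is satisfied when at least one of its literals is true.
bool satisfies(const std::vector<bool>& model, const std::vector<int>& clause) {
    for (int lit : clause) {
        int var = lit > 0 ? lit : -lit;
        if (model[var - 1] == (lit > 0)) return true;
    }
    return false;
}
```

Because the clause mentions every encoding variable, it rules out exactly one assignment and leaves the rest of the search space untouched.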

Solving for programs may seem odd. The trick here is to represent a program as a sequence of instructions that are constructed out of boolean variables; the SAT solver is then invoked to solve for those variables.
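To make "a program constructed out of boolean variables" concrete, here is a hypothetical encoding (invented for this sketch, not taken from the linked code) together with an interpreter for it. Each instruction occupies 6 boolean variables: a 2-bit opcode and two 2-bit operand indices. A SAT-based synthesizer would express the same interpretation as clauses and let the solver choose the bits; here the bits are chosen by hand to spell out a - (a - b).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical encoding, for illustration only. Each instruction is 6 bits, each
// field read least-significant bit first:
//   bits 0-1: opcode (0: +, 1: -, 2: &, 3: |)
//   bits 2-3: index of the left operand
//   bits 4-5: index of the right operand
// Operand indices select from {a, b, r0, r1}, where rK is the result of instruction K.
unsigned field(const std::vector<bool>& bits, std::size_t pos, int width) {
    unsigned v = 0;
    for (int i = 0; i < width; ++i)
        v |= unsigned(bits[pos + i]) << i;
    return v;
}

// Evaluate the encoded program on 4-bit inputs; the last result is the output.
unsigned run(const std::vector<bool>& bits, unsigned a, unsigned b) {
    std::vector<unsigned> vals = {a & 0xF, b & 0xF};
    for (std::size_t pos = 0; pos + 6 <= bits.size(); pos += 6) {
        unsigned op = field(bits, pos, 2);
        unsigned x = field(bits, pos + 2, 2);
        unsigned y = field(bits, pos + 4, 2);
        // A real encoding would add clauses that forbid out-of-range operands.
        assert(x < vals.size() && y < vals.size());
        unsigned l = vals[x], r = vals[y], res = 0;
        switch (op) {
            case 0: res = l + r; break;
            case 1: res = l - r; break;
            case 2: res = l & r; break;
            default: res = l | r; break;
        }
        vals.push_back(res & 0xF);  // keep results 4 bits wide
    }
    return vals.back();
}
```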

The code is available on gitlab.

Original motivation

The examples I gave earlier only involve "normal" bitvector arithmetic. Originally what I set out to do is discover what sorts of mathematical identities are true in the context of trapping arithmetic (in which subtraction and addition trap on signed overflow), using the rule that two expressions are equivalent if and only if they have the same behaviour, in the following sense: for all valuations of the variables, the two expressions either yield the same value, or they both trap. That rule is also implemented in the linked source.

Many of the identities found in that context involve a trapping operation that can never actually trap. For example, the trapping subtraction (the t-suffix in -t indicates that it is the trapping version of subtraction) in (b & (~ a)) == (b -t (a & b)) cannot trap (boring to prove, so I won't bother). But the "infamous" (among whom? maybe just me) (-t (-t (-t a))) == (-t a) is also enumerated, and the sole remaining negation can trap, but does so in exactly the same case as the original three negations in a row (namely when a is the bitvector with only its sign bit set). Here is a small selection of nice identities that hold in trapping arithmetic:

((a & b) +t (a | b)) == (a +t b)
((a | b) -t (a & b)) == (a ^ b)
((~ b) -t (~ a)) == (a -t b)
(~ (b +t (~ a))) == (a -t b)
(a -t (a -t b)) == (a - (a -t b))  // note: one of the subtractions is a non-trapping subtraction
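The equivalence rule can also be checked by brute force. In this sketch (an assumed model, not the linked source), a 4-bit signed value is an int in [-8, 7], a trap is represented by std::nullopt, traps propagate through every operation, and two expressions are equivalent if they compare equal as std::optional values for every valuation, which covers both "same value" and "both trap". The helper names (addt, subt, negt, band and so on) are made up for the example.

```cpp
#include <cassert>
#include <functional>
#include <optional>

// Signed 4-bit values live in [-8, 7]; std::nullopt means "this expression trapped".
using T4 = std::optional<int>;
using Expr = std::function<T4(T4, T4)>;

// Apply a binary operation, propagating a trap from either operand.
template <class F>
T4 lift(T4 x, T4 y, F f) {
    if (!x || !y) return std::nullopt;
    return f(*x, *y);
}

// Trapping ops: trap when the exact result does not fit in 4 signed bits.
T4 checked(int r) { return (r < -8 || r > 7) ? T4{} : T4{r}; }
T4 addt(T4 x, T4 y) { return lift(x, y, [](int a, int b) { return checked(a + b); }); }
T4 subt(T4 x, T4 y) { return lift(x, y, [](int a, int b) { return checked(a - b); }); }
T4 negt(T4 x) { return subt(T4{0}, x); }

// Non-trapping ops: subtraction wraps modulo 16; bitwise ops on values in [-8, 7]
// stay in range because every bit above bit 3 is a copy of the sign bit.
T4 sub(T4 x, T4 y) { return lift(x, y, [](int a, int b) { return T4{((a - b + 8) & 15) - 8}; }); }
T4 band(T4 x, T4 y) { return lift(x, y, [](int a, int b) { return T4{a & b}; }); }
T4 bor(T4 x, T4 y) { return lift(x, y, [](int a, int b) { return T4{a | b}; }); }
T4 bxor(T4 x, T4 y) { return lift(x, y, [](int a, int b) { return T4{a ^ b}; }); }

// Same value or both trap, for every valuation (two nullopts compare equal).
bool equivalent(const Expr& lhs, const Expr& rhs) {
    for (int a = -8; a <= 7; ++a)
        for (int b = -8; b <= 7; ++b)
            if (lhs(T4{a}, T4{b}) != rhs(T4{a}, T4{b})) return false;
    return true;
}
```

Note that two negations in a row are not equivalent to a under this rule: for a = -8 the left side traps and the right side does not.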

Future directions

A large source of boring identities is the fact that if f(x) == g(x), then also f(x) + x == g(x) + x and f(x) & x == g(x) & x and so on, which causes "small" identities to show up again as part of larger ones, without introducing any new information, and multiplied in myriad ways. If there were a good way to prevent them from being enumerated (it would have to be sufficiently easy to state in terms of CNF SAT clauses, to avoid slowing down the solver too much), or to summarize the full output, that could make the output of the enumeration more human-digestible.

There is a part 2 for this post.

Saturday 9 March 2024

The solutions to 𝚙𝚘𝚙𝚌𝚗𝚝(𝚡) < 𝚝𝚣𝚌𝚗𝚝(𝚡) and why there are Fibonacci[n] of them below 2ⁿ

popcnt(x) < tzcnt(x) asks the question "does x have fewer set bits than it has trailing zeroes". It's a simple question with a simple answer, but cute enough to think about on a Sunday morning.[1]

Here are the solutions for 8 bits, in order: 0, 4, 8, 16, 24, 32, 40, 48, 64, 72, 80, 96, 112, 128, 136, 144, 160, 176, 192, 208, 224[2]

In case you find decimal hard to read (as I do), here they are again in binary: 00000000, 00000100, 00001000, 00010000, 00011000, 00100000, 00101000, 00110000, 01000000, 01001000, 01010000, 01100000, 01110000, 10000000, 10001000, 10010000, 10100000, 10110000, 11000000, 11010000, 11100000

Simply staring at the values doesn't do much for me. To get a better handle on what's going on, let's recursively (de-)construct the set of n-bit solutions.

The most significant bit of an n-bit solution is either 0 or 1:

  1. If it is 0, then that bit affects neither the popcnt nor the tzcnt so removing it must yield an (n-1)-bit solution.
  2. If it is 1, then removing it along with the least significant bit (which must be zero: there are no odd solutions, since their tzcnt would be zero) would decrease both the popcnt and the tzcnt by 1, yielding an (n-2)-bit solution.

This "deconstructive" recursion is slightly awkward. The constructive version would be: you can take the (n-1)-bit solutions and prepend a zero to them, and you can take the (n-2)-bit solutions and prepend a one and append a zero to them. However, it is less clear then (to me anyway) that those are the only n-bit solutions. The "deconstructive" version starts with all n-bit solutions and splits them into two obviously-disjoint groups, removing the possibility of solutions getting lost or being counted twice.

The F(n) = F(n - 1) + F(n - 2) structure of the number of solutions is clear, but there are different sequences that follow that same recurrence that differ in their base cases. Here we have 1 solution for 1-bit integers (namely zero) and 1 solution for 2-bit integers (also zero), so the base cases are 1 and 1 as in the Fibonacci sequence.
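The counting argument is easy to check numerically. This sketch brute-forces all n-bit solutions, compares them with the set built constructively from the recursion (prepend a 0 to the (n-1)-bit solutions; prepend a 1 and append a 0 to the (n-2)-bit solutions), and checks the counts against the Fibonacci numbers. The helper names are made up for the example, and tzcnt(0) is taken to be n, as with the TZCNT instruction at width n.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int popcnt(uint32_t x) {
    int c = 0;
    for (; x; x &= x - 1) ++c;
    return c;
}

// Trailing zero count of an n-bit value; like TZCNT, zero counts as n.
int tzcnt(uint32_t x, int n) {
    if (x == 0) return n;
    int c = 0;
    for (; !(x & 1); x >>= 1) ++c;
    return c;
}

// All n-bit x with popcnt(x) < tzcnt(x), by brute force, in increasing order.
std::vector<uint32_t> bruteForce(int n) {
    std::vector<uint32_t> r;
    for (uint32_t x = 0; x < (uint32_t(1) << n); ++x)
        if (popcnt(x) < tzcnt(x, n)) r.push_back(x);
    return r;
}

// The recursion used constructively. Prepending a 0 leaves a value unchanged;
// "1 ... s ... 0" is the (n-2)-bit solution s with a 1 prepended and a 0 appended.
// Base cases: for n = 1 and n = 2 the only solution is zero.
std::vector<uint32_t> constructive(int n) {
    if (n <= 2) return {0};
    std::vector<uint32_t> r = constructive(n - 1);
    for (uint32_t s : constructive(n - 2))
        r.push_back((uint32_t(1) << (n - 1)) | (s << 1));
    std::sort(r.begin(), r.end());
    return r;
}
```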

This is probably all useless, and it's barely even bitmath.


[1] Or whenever, but it happens to be a Sunday morning for me right now.
[2] This sequence does not seem to be on the OEIS at the time of writing.

Wednesday 17 January 2024

Partial sums of popcount

The partial sums of popcount, aka A000788: Total number of 1's in binary expansions of 0, ..., n can be computed fairly efficiently with some mysterious code found through its OEIS entry (see the link Fast C++ function for computing a(n)): (reformatted slightly to reduce width)

unsigned A000788(unsigned n)
{
    unsigned v = 0;
    for (unsigned bit = 1; bit <= n; bit <<= 1)
        v += ((n>>1)&~(bit-1)) +
             ((n&bit) ? (n&((bit<<1)-1))-(bit-1) : 0);
    return v;
}

Knowing what we (or I, anyway) know from computing the partial sums of blsi and blsmsk, let's try to improve on that code. "Improve" is a vague goal, let's say we don't want to loop over the bits, but also not just unroll by 64x to do this for a 64-bit integer.

First let's split this thing into the sum of an easy problem and a harder problem, the easy problem being the sum of (n>>1)&~(bit-1) (reminder that ~(bit-1) == -bit; unsigned negation is safe, UB-free, and does exactly what we need, even on hypothetical non-two's-complement hardware). This is the same thing we saw in the partial sum of blsi: bit k of n occurs k times in the sum (each time shifted right by one position, so it contributes k * 2^(k-1)), which we can evaluate like this:

uint64_t v = 
    ((n & 0xAAAA'AAAA'AAAA'AAAA) >> 1) +
    ((n & 0xCCCC'CCCC'CCCC'CCCC) << 0) +
    ((n & 0xF0F0'F0F0'F0F0'F0F0) << 1) +
    ((n & 0xFF00'FF00'FF00'FF00) << 2) +
    ((n & 0xFFFF'0000'FFFF'0000) << 3) +
    ((n & 0xFFFF'FFFF'0000'0000) << 4);
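As a sanity check on that claim, here is a small sketch comparing three forms of the easy part: the loop from the OEIS code (run over all 64 bits, which only adds zero terms once bit exceeds n), a per-bit sum of k * 2^(k-1), and the mask-and-shift expression above. All three wrap modulo 2^64 in the same way, so they should agree for every n. The function names are made up for the example.

```cpp
#include <cassert>
#include <cstdint>

// Loop form from the OEIS code: sum of (n >> 1) & -bit for bit = 1, 2, 4, ...
// Running over all 64 bits instead of "while bit <= n" only adds zero terms.
uint64_t easyPartLoop(uint64_t n) {
    uint64_t v = 0;
    for (int k = 0; k < 64; ++k) {
        uint64_t bit = uint64_t(1) << k;
        v += (n >> 1) & (0 - bit);  // 0 - bit == ~(bit - 1)
    }
    return v;
}

// Per-bit form: bit k of n contributes k * 2^(k-1).
uint64_t easyPartPerBit(uint64_t n) {
    uint64_t v = 0;
    for (int k = 1; k < 64; ++k)
        if ((n >> k) & 1) v += uint64_t(k) << (k - 1);
    return v;
}

// Mask-and-shift form from the text: mask i selects the bits whose index has
// bit i set, and the shift multiplies them by 2^(i-1).
uint64_t easyPartMasks(uint64_t n) {
    return ((n & 0xAAAA'AAAA'AAAA'AAAA) >> 1) +
           ((n & 0xCCCC'CCCC'CCCC'CCCC) << 0) +
           ((n & 0xF0F0'F0F0'F0F0'F0F0) << 1) +
           ((n & 0xFF00'FF00'FF00'FF00) << 2) +
           ((n & 0xFFFF'0000'FFFF'0000) << 3) +
           ((n & 0xFFFF'FFFF'0000'0000) << 4);
}
```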

The harder problem, the contribution from ((n&bit) ? (n&((bit<<1)-1))-(bit-1) : 0), has a similar pattern but is more annoying in three ways. Here's an example of the pattern, with n in the first row and the values being added together listed below the horizontal line:

00100011000111111100001010101111
--------------------------------
00000000000000000000000000000001
00000000000000000000000000000010
00000000000000000000000000000100
00000000000000000000000000001000
00000000000000000000000000010000
00000000000000000000000000110000
00000000000000000000000010110000
00000000000000000000001010110000
00000000000000000100001010110000
00000000000000001100001010110000
00000000000000011100001010110000
00000000000000111100001010110000
00000000000001111100001010110000
00000000000011111100001010110000
00000000000111111100001010110000
00000001000111111100001010110000
00000011000111111100001010110000
  1. Some anomalous thing happens for the contiguous group of rightmost set bits.
  2. The weights are based not on the column index, but sort of dynamic based on the number of set bits ...
  3. ... to the left of the bit we're looking at. That's significant, "to the right" would have been a lot nicer to deal with.

For problem 1, I'm just going to state without proof that we can add 1 to n and ignore the problem, as long as we add n & ~(n + 1) to the final sum. Problems 2 and 3 are more interesting. If we had problem 2 but counting the bits to the right of the bit we're looking at, that would have been nice and easy: instead of (n & 0xAAAA'AAAA'AAAA'AAAA) we would have _pdep_u64(0xAAAA'AAAA'AAAA'AAAA, n), problem solved. If we had a "pdep but from left to right" (aka expand_left) named _pdepl_u64, we could have done this:

uint64_t u = 
    ((_pdepl_u64(0x5555'5555'5555'5555, m) >> shift) << 0) +
    ((_pdepl_u64(0x3333'3333'3333'3333, m) >> shift) << 1) +
    ((_pdepl_u64(0x0F0F'0F0F'0F0F'0F0F, m) >> shift) << 2) +
    ((_pdepl_u64(0x00FF'00FF'00FF'00FF, m) >> shift) << 3) +
    ((_pdepl_u64(0x0000'FFFF'0000'FFFF, m) >> shift) << 4) +
    ((_pdepl_u64(0x0000'0000'FFFF'FFFF, m) >> shift) << 5);

But as far as I know (see the update below), that requires bit-reversing the inputs of a normal _pdep_u64 and bit-reversing the result, which is not so nice, at least on current x64 hardware. Every ISA should have a Generalized Reverse operation like the grevi instruction, which used to be in the drafts of the RISC-V Bitmanip Extension prior to version 1.

Update:

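It turned out there is a reasonable way to implement _pdepl_u64(v, m) in plain scalar code after all, namely as _pdep_u64(v >> (std::popcount(~m) & 63), m). The & 63 isn't meaningful; it only prevents UB at the C++ level (std::popcount(~m) is 64 when m is zero, shifting a 64-bit value by 64 is undefined, and a pdep with a zero mask returns zero regardless of v).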
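The equivalence between this formula and the earlier reverse-pdep-reverse recipe can be checked without BMI2 by using software stand-ins. In this sketch pdep64, popcount64 and reverse64 are portable helpers written for the example (not intrinsics), and expand-left means: the most significant bits of v land, in order, in the set bits of m, starting from its most significant set bit.

```cpp
#include <cassert>
#include <cstdint>

// Portable stand-in for _pdep_u64: deposit the low bits of v, in order,
// into the set bits of m, starting from the lowest set bit.
uint64_t pdep64(uint64_t v, uint64_t m) {
    uint64_t r = 0;
    for (uint64_t bb = 1; m != 0; bb <<= 1) {
        if (v & bb) r |= m & (0 - m);  // lowest set bit of m
        m &= m - 1;                    // clear that bit
    }
    return r;
}

int popcount64(uint64_t x) {
    int c = 0;
    for (; x; x &= x - 1) ++c;
    return c;
}

uint64_t reverse64(uint64_t x) {
    uint64_t r = 0;
    for (int i = 0; i < 64; ++i) r |= ((x >> i) & 1) << (63 - i);
    return r;
}

// expand_left the "old" way: bit-reverse the inputs, pdep, bit-reverse the result.
uint64_t pdeplViaReverse(uint64_t v, uint64_t m) {
    return reverse64(pdep64(reverse64(v), reverse64(m)));
}

// expand_left as in the update: move the top popcount(m) bits of v down, then pdep.
// The & 63 only avoids an undefined shift by 64 when m == 0.
uint64_t pdeplViaShift(uint64_t v, uint64_t m) {
    return pdep64(v >> (popcount64(~m) & 63), m);
}
```

For instance, with m = 0b1010, bit 63 of v ends up in bit 3 and bit 62 of v in bit 1.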

This approach turned out to be more efficient than the AVX512 approach, so that's obsolete now, but maybe still interesting to borrow ideas from. Here's the scalar implementation in full:

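#include <bit>          // std::popcount, std::countl_zero (C++20)
#include <cstdint>
#include <immintrin.h>  // _pdep_u64 (requires BMI2)

uint64_t _pdepl_u64(uint64_t v, uint64_t m)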
{
    return _pdep_u64(v >> (std::popcount(~m) & 63), m);
}

uint64_t partialSumOfPopcnt(uint64_t n)
{
    uint64_t v =
        ((n & 0xAAAA'AAAA'AAAA'AAAA) >> 1) +
        ((n & 0xCCCC'CCCC'CCCC'CCCC) << 0) +
        ((n & 0xF0F0'F0F0'F0F0'F0F0) << 1) +
        ((n & 0xFF00'FF00'FF00'FF00) << 2) +
        ((n & 0xFFFF'0000'FFFF'0000) << 3) +
        ((n & 0xFFFF'FFFF'0000'0000) << 4);
    uint64_t m = n + 1;
    int shift = std::countl_zero(m);
    m = m << shift;
    uint64_t u =
        ((_pdepl_u64(0x5555'5555'5555'5555, m) >> shift) << 0) +
        ((_pdepl_u64(0x3333'3333'3333'3333, m) >> shift) << 1) +
        ((_pdepl_u64(0x0F0F'0F0F'0F0F'0F0F, m) >> shift) << 2) +
        ((_pdepl_u64(0x00FF'00FF'00FF'00FF, m) >> shift) << 3) +
        ((_pdepl_u64(0x0000'FFFF'0000'FFFF, m) >> shift) << 4) +
        ((_pdepl_u64(0x0000'0000'FFFF'FFFF, m) >> shift) << 5);
    return u + (n & ~(n + 1)) + v;
}
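To test the scalar version on a machine without BMI2, or just to check it against a naive count, the intrinsics can be swapped for portable helpers. In this sketch, expandLeft and countlZero64 are hypothetical stand-ins for _pdepl_u64 and std::countl_zero; the rest follows the scalar code from the update. It assumes n + 1 does not wrap to zero, i.e. n < 2^64 - 1.

```cpp
#include <cassert>
#include <cstdint>

// Portable stand-in for _pdepl_u64 ("expand left"): the most significant bits of v
// are deposited, in order, into the set bits of m, starting from its top set bit.
uint64_t expandLeft(uint64_t v, uint64_t m) {
    uint64_t r = 0;
    uint64_t src = uint64_t(1) << 63;  // next bit of v to use, taken from the top
    for (uint64_t dst = uint64_t(1) << 63; dst != 0; dst >>= 1) {
        if (m & dst) {
            if (v & src) r |= dst;
            src >>= 1;
        }
    }
    return r;
}

// Portable stand-in for std::countl_zero (returns 64 for zero).
int countlZero64(uint64_t x) {
    int n = 0;
    for (uint64_t b = uint64_t(1) << 63; b != 0 && !(x & b); b >>= 1) ++n;
    return n;
}

uint64_t partialSumOfPopcntPortable(uint64_t n) {
    // Easy part: bit k of n contributes k * 2^(k-1).
    uint64_t v =
        ((n & 0xAAAA'AAAA'AAAA'AAAA) >> 1) +
        ((n & 0xCCCC'CCCC'CCCC'CCCC) << 0) +
        ((n & 0xF0F0'F0F0'F0F0'F0F0) << 1) +
        ((n & 0xFF00'FF00'FF00'FF00) << 2) +
        ((n & 0xFFFF'0000'FFFF'0000) << 3) +
        ((n & 0xFFFF'FFFF'0000'0000) << 4);
    // Harder part, weighted by the number of set bits to the left.
    uint64_t m = n + 1;
    int shift = countlZero64(m);
    m = m << shift;
    uint64_t u =
        ((expandLeft(0x5555'5555'5555'5555, m) >> shift) << 0) +
        ((expandLeft(0x3333'3333'3333'3333, m) >> shift) << 1) +
        ((expandLeft(0x0F0F'0F0F'0F0F'0F0F, m) >> shift) << 2) +
        ((expandLeft(0x00FF'00FF'00FF'00FF, m) >> shift) << 3) +
        ((expandLeft(0x0000'FFFF'0000'FFFF, m) >> shift) << 4) +
        ((expandLeft(0x0000'0000'FFFF'FFFF, m) >> shift) << 5);
    // Correction for the trailing ones of n (problem 1).
    return u + (n & ~(n + 1)) + v;
}
```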

Repeatedly calling _pdepl_u64 with the same mask creates some common subexpressions. They could be factored out manually, but compilers do that anyway: even MSVC uses only one actual popcnt instruction (although MSVC, annoyingly, does perform the meaningless & 63).

Enter AVX512

Using AVX512, we could more easily reverse the bits of a 64-bit integer; there are various ways to do that. But just using that and then going back to scalar pdep would be a waste of a good opportunity to implement the whole thing in AVX512, pdep and all. The trick to doing a pdep in AVX512, if you have several 64-bit integers that you want to pdep with the same mask, is to transpose 8x 64-bit integers into 64x 8-bit integers, use vpexpandb, then transpose back. In this case the first operand of the pdep is a constant, so the first transpose is not necessary. We still have to reverse the mask though. Since vpexpandb takes the mask input in a mask register and we only have one thing to reverse, this trick to bit-permute integers seems like a better fit than Wunk's whole-vector bit-reversal or some variant thereof.

I sort of glossed over the fact that we're supposed to be bit-reversing relative to the most significant set bit in the mask, but that's easy to do by shifting left by std::countl_zero(m) and then doing a normal bit-reverse, so in the end it still comes down to a normal bit-reverse. The results of the pdeps have to be shifted right by the same amount to compensate.

Here's the whole thing: (note that this is less efficient than the updated approach without AVX512)

uint64_t partialSumOfPopcnt(uint64_t n)
{    
    uint64_t v = 
        ((n & 0xAAAA'AAAA'AAAA'AAAA) >> 1) +
        ((n & 0xCCCC'CCCC'CCCC'CCCC) << 0) +
        ((n & 0xF0F0'F0F0'F0F0'F0F0) << 1) +
        ((n & 0xFF00'FF00'FF00'FF00) << 2) +
        ((n & 0xFFFF'0000'FFFF'0000) << 3) +
        ((n & 0xFFFF'FFFF'0000'0000) << 4);
    // 0..63
    __m512i weights = _mm512_setr_epi8(
        0, 1, 2, 3, 4, 5, 6, 7,
        8, 9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19, 20, 21, 22, 23,
        24, 25, 26, 27, 28, 29, 30, 31,
        32, 33, 34, 35, 36, 37, 38, 39,
        40, 41, 42, 43, 44, 45, 46, 47,
        48, 49, 50, 51, 52, 53, 54, 55,
        56, 57, 58, 59, 60, 61, 62, 63);
    // 63..0
    __m512i rev = _mm512_set_epi8(
        0, 1, 2, 3, 4, 5, 6, 7,
        8, 9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19, 20, 21, 22, 23,
        24, 25, 26, 27, 28, 29, 30, 31,
        32, 33, 34, 35, 36, 37, 38, 39,
        40, 41, 42, 43, 44, 45, 46, 47,
        48, 49, 50, 51, 52, 53, 54, 55,
        56, 57, 58, 59, 60, 61, 62, 63);
    uint64_t m = n + 1;
    // bit-reverse the mask to implement expand-left
    int shift = std::countl_zero(m);
    m = (m << shift);
    __mmask64 revm = _mm512_bitshuffle_epi64_mask(_mm512_set1_epi64(m), rev);
    // the reversal of expand-right with reversed inputs is expand-left
    __m512i leftexpanded = _mm512_permutexvar_epi8(rev, 
        _mm512_mask_expand_epi8(_mm512_setzero_si512(), revm, weights));
    // transpose back to 8x 64-bit integers
    leftexpanded = Transpose64x8(leftexpanded);
    // compensate for having shifted m left
    __m512i masks = _mm512_srlv_epi64(leftexpanded, _mm512_set1_epi64(shift));
    // scale and sum results
    __m512i parts = _mm512_sllv_epi64(masks,
        _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0));
    __m256i parts2 = _mm256_add_epi64(
        _mm512_castsi512_si256(parts), 
        _mm512_extracti64x4_epi64(parts, 1));
    __m128i parts3 = _mm_add_epi64(
        _mm256_castsi256_si128(parts2), 
        _mm256_extracti64x2_epi64(parts2, 1));
    uint64_t u = 
        _mm_cvtsi128_si64(parts3) + 
        _mm_extract_epi64(parts3, 1);
    
    return u + (n & ~(n + 1)) + v;
}

As for Transpose64x8, you can find out how to implement that in Permuting bits with GF2P8AFFINEQB.

Saturday 23 September 2023

Permuting bits with GF2P8AFFINEQB

It's no secret that GF2P8AFFINEQB can be tricky to think about, even in the restricted context of bit-permutations. Thinking about more than one step (such as more than one GF2P8AFFINEQB back-to-back, or GF2P8AFFINEQB flanked by byte-wise shuffles) is just too much. Or perhaps you can do it; if so, tell me your secret.

A good way for mere mortals to reason about these kinds of permutations, I think, is to think in terms of the bits of the indices of the bits that are really being permuted. So we're 4 levels deep:

  1. The value whose bits are being permuted.
  2. The bits that are being permuted.
  3. The indices of those bits.
  4. The bits of those indices.

This can get a little confusing because a lot of the time the operation that will be performed on the bits of those indices is a permutation again, but it doesn't have to be: another classic example is that a rotation corresponds to adding/subtracting a constant to/from the indices. Just keep in mind that we're 4 levels deep the entire time.

Actually we don't need to go deeper.

The building blocks

Assuming we have 512 bits to work with, the indices of those bits are 0..511: 9-bit numbers. We will split that into 3 groups of 3 bits, denoted a,b,c where a locates a QWORD in the 512-bit register, b locates a byte within that QWORD, and c locates a bit within that byte.

Here are some nice building blocks (given fairly arbitrary names):

  • Pf(a,b,c) = a,b,f(c) aka "right GF2P8AFFINEQB", where f is any mapping from a 3-bit integer to a 3-bit integer. This building block can be implemented with _mm512_gf2p8affine_epi64_epi8(input, _mm512_set1_epi64(f_as_a_reversed_matrix), 0)
  • Qf(a,b,c) = a,f(c),~b aka "left GF2P8AFFINEQB", where ~b is a 3-bit inversion, equivalent to 7 - b. f can often be the identity mapping, swapping the second and third groups of bits is useful on its own (the "bonus" inversion can be annoying to deal with). This building block can be implemented with _mm512_gf2p8affine_epi64_epi8(_mm512_set1_epi64(f_as_a_matrix), input, 0)
  • Sg(a,b,c) = g(a,b),c aka Shuffle, where g is any mapping from a 6-bit integer to a 6-bit integer. This building block can be implemented with _mm512_permutexvar_epi8(g_as_an_array, input), but in some cases also with another instruction that you may prefer, depending on the mapping.

S, though it doesn't touch c, is quite powerful. As a couple of special cases that may be of interest, it can be used to swap a and b, invert a or b, or do a combined swap-and-invert.

We could further distinguish:

  • S64f(a,b,c) = f(a),b,c aka VPERMQ. This building block can be implemented with, you guessed it, VPERMQ.
  • S8f(a,b,c) = a,f(b),c aka PSHUFB. This building block can be implemented with, you guessed it, PSHUFB. PSHUFB allows a bit more freedom than is used here, the mapping could be from 4-bit integers to 4-bit integers, but that's not nice to think about in this framework of 3 groups of 3 bits.

Building something with the blocks

Let's say that we want to take a vector of 8 64-bit integers, and transpose it into a vector of 64 8-bit integers such that the k'th bit of the n'th uint64 ends up in the n'th bit of the k'th uint8. In terms of the bits of the indices of the bits (I swear it's not as confusing as it sounds) that means we want to build something that maps a,b,c to b,c,a. It's immediately clear that we need a Q operation at some point, since it's the only way to swap some other groups of bits into the 3rd position. But if we start with a Q, we get ~b in the 3rd position while we need a. We can solve that by starting with an S that swaps a and b while also inverting a (I'm not going to bother defining what that looks like in terms of an index mapping function, just imagine that those functions are whatever they need to be in order to make it work):

Sf(a,b,c) = b,~a,c
Qid(b,~a,c) = b,c,a

Which translates into code like this:

__m512i Transpose8x64(__m512i x)
{
    x = _mm512_permutexvar_epi8(_mm512_setr_epi8(
        56, 48, 40, 32, 24, 16, 8, 0,
        57, 49, 41, 33, 25, 17, 9, 1,
        58, 50, 42, 34, 26, 18, 10, 2,
        59, 51, 43, 35, 27, 19, 11, 3,
        60, 52, 44, 36, 28, 20, 12, 4,
        61, 53, 45, 37, 29, 21, 13, 5,
        62, 54, 46, 38, 30, 22, 14, 6,
        63, 55, 47, 39, 31, 23, 15, 7), x);
    __m512i idmatrix = _mm512_set1_epi64(0x8040201008040201);
    x = _mm512_gf2p8affine_epi64_epi8(idmatrix, x, 0);
    return x;
}

Now let's say that we want to do the inverse of that, going back from b,c,a to a,b,c. Again it's clear that we need a Q, but we have some choice now. We could start by inverting the c in the middle first:

S8f1(b,c,a) = b,~c,a
Qid(b,~c,a) = b,a,c
Sf2(b,a,c) = a,b,c

Which translates into code like this:

__m512i Transpose64x8(__m512i x)
{
    x = _mm512_shuffle_epi8(x, _mm512_setr_epi8(
        7, 6, 5, 4, 3, 2, 1, 0,
        15, 14, 13, 12, 11, 10, 9, 8,
        23, 22, 21, 20, 19, 18, 17, 16,
        31, 30, 29, 28, 27, 26, 25, 24,
        39, 38, 37, 36, 35, 34, 33, 32,
        47, 46, 45, 44, 43, 42, 41, 40,
        55, 54, 53, 52, 51, 50, 49, 48,
        63, 62, 61, 60, 59, 58, 57, 56));
    __m512i idmatrix = _mm512_set1_epi64(0x8040201008040201);
    x = _mm512_gf2p8affine_epi64_epi8(idmatrix, x, 0);
    x = _mm512_permutexvar_epi8(_mm512_setr_epi8(
        0, 8, 16, 24, 32, 40, 48, 56,
        1, 9, 17, 25, 33, 41, 49, 57,
        2, 10, 18, 26, 34, 42, 50, 58,
        3, 11, 19, 27, 35, 43, 51, 59,
        4, 12, 20, 28, 36, 44, 52, 60,
        5, 13, 21, 29, 37, 45, 53, 61,
        6, 14, 22, 30, 38, 46, 54, 62,
        7, 15, 23, 31, 39, 47, 55, 63), x);
    return x;
}

Or we could start with a Q to get the a out of the third position, then use an S to swap the first and second positions and a P to invert c (in any order).

Qid(b,c,a) = b,a,~c
Sf1(b,a,~c) = a,b,~c
Pf2(a,b,~c) = a,b,c

Which translates into code like this:

__m512i Transpose64x8(__m512i x)
{
    __m512i idmatrix = _mm512_set1_epi64(0x8040201008040201);
    x = _mm512_gf2p8affine_epi64_epi8(idmatrix, x, 0);
    x = _mm512_permutexvar_epi8(_mm512_setr_epi8(
        0, 8, 16, 24, 32, 40, 48, 56,
        1, 9, 17, 25, 33, 41, 49, 57,
        2, 10, 18, 26, 34, 42, 50, 58,
        3, 11, 19, 27, 35, 43, 51, 59,
        4, 12, 20, 28, 36, 44, 52, 60,
        5, 13, 21, 29, 37, 45, 53, 61,
        6, 14, 22, 30, 38, 46, 54, 62,
        7, 15, 23, 31, 39, 47, 55, 63), x);
    x = _mm512_gf2p8affine_epi64_epi8(x, idmatrix, 0);
    return x;
}
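Since both transposes are built entirely from _mm512_permutexvar_epi8 and _mm512_gf2p8affine_epi64_epi8, they can be checked on any machine with scalar models of those two intrinsics. The models below follow my reading of Intel's documented semantics (result bit j of each byte is the parity of matrix byte 7 - j ANDed with the source byte, XORed with bit j of the immediate); treat them as an assumption rather than a reference implementation. The index tables are generated by formula but match the literal tables in the code above. The test checks that Transpose8x64 sends bit k of QWORD n to bit n of byte k, and that the second Transpose64x8 undoes it.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// A 512-bit register modeled as 64 bytes; element i is byte i of the register.
using V512 = std::array<uint8_t, 64>;

// Model of _mm512_set1_epi64(v): v broadcast to all 8 QWORDs (little-endian bytes).
V512 set1_epi64(uint64_t v) {
    V512 r{};
    for (int i = 0; i < 64; ++i) r[i] = uint8_t(v >> (8 * (i % 8)));
    return r;
}

V512 fromQwords(const std::array<uint64_t, 8>& q) {
    V512 r{};
    for (int i = 0; i < 64; ++i) r[i] = uint8_t(q[i / 8] >> (8 * (i % 8)));
    return r;
}

// Model of _mm512_permutexvar_epi8(idx, x): result byte i = x[idx[i] & 63].
V512 permutexvar_epi8(const V512& idx, const V512& x) {
    V512 r{};
    for (int i = 0; i < 64; ++i) r[i] = x[idx[i] & 63];
    return r;
}

// Model of _mm512_gf2p8affine_epi64_epi8(x, A, b): each byte of x is multiplied by the
// 8x8 bit-matrix held in the matching QWORD of A, then XORed with b.
V512 gf2p8affine_epi64_epi8(const V512& x, const V512& A, uint8_t b) {
    V512 r{};
    for (int i = 0; i < 64; ++i) {
        int q = i / 8;  // QWORD holding both the byte and its matrix
        uint8_t out = 0;
        for (int j = 0; j < 8; ++j) {
            int parity = 0;
            for (uint8_t row = A[q * 8 + (7 - j)] & x[i]; row != 0; row &= row - 1)
                parity ^= 1;
            out |= uint8_t(parity << j);
        }
        r[i] = uint8_t(out ^ b);
    }
    return r;
}

// Transpose8x64 from the text: S (swap a and b, invert a), then Q.
V512 Transpose8x64(V512 x) {
    V512 idx{};
    for (int i = 0; i < 64; ++i) idx[i] = uint8_t((7 - i % 8) * 8 + i / 8);  // 56, 48, ..., 0, 57, ...
    x = permutexvar_epi8(idx, x);
    return gf2p8affine_epi64_epi8(set1_epi64(0x8040201008040201), x, 0);
}

// Second Transpose64x8 from the text: Q, then S (swap first two groups), then P.
V512 Transpose64x8(V512 x) {
    V512 id = set1_epi64(0x8040201008040201);
    x = gf2p8affine_epi64_epi8(id, x, 0);
    V512 idx{};
    for (int i = 0; i < 64; ++i) idx[i] = uint8_t((i % 8) * 8 + i / 8);  // 0, 8, ..., 56, 1, ...
    x = permutexvar_epi8(idx, x);
    return gf2p8affine_epi64_epi8(x, id, 0);  // identity matrix here reverses bits in each byte
}
```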

I will probably keep using a SAT solver to solve for the masks (using the same techniques as in (Not) transposing a 16x16 bitmatrix), but now at least I have a proper way to think about the shape of the solution, which makes it a lot easier to ask a SAT solver to fill in the specifics.

This framework could be extended with other bit-permutation operations such as QWORD rotates, but that quickly becomes tricky to think about.

Sunday 2 July 2023

Propagating bounds through bitwise operations

This post is meant as a replacement/recap of some work that I did over a decade ago on propagating bounds through bitwise operations, which was intended as an improvement over the implementations given in Hacker's Delight chapter 4, Arithmetic Bounds.

The goal is, given two variables x and y, with known bounds a ≤ x ≤ b, c ≤ y ≤ d, compute the bounds of x | y and of x & y. Thanks to De Morgan, we have the equations (most also listed in Hacker's Delight, except the last one)

  • minAND(a, b, c, d) = ~maxOR(~b, ~a, ~d, ~c)
  • maxAND(a, b, c, d) = ~minOR(~b, ~a, ~d, ~c)
  • minXOR(a, b, c, d) = minAND(a, b, ~d, ~c) | minAND(~b, ~a, c, d)
  • maxXOR(a, b, c, d) = maxOR(a, b, c, d) & ~minAND(a, b, c, d)

Everything can be written in terms of only minOR and maxOR and some basic operations.

maxOR

To compute the upper bound of the OR of x and y, what we need to do is find the leftmost bit (henceforth the "target bit") such that it is both:

  1. set in both b and d (the upper bounds of x and y) and,
  2. changing an upper bound (either one of them, doesn't matter, but never both) by resetting the target bit and setting the bits that are less significant keeps it greater than or equal to the corresponding lower bound.

The explanation of why that works can be found in Hacker's Delight, along with a more or less direct transcription into code, but we can do better than a direct transcription.

Finding the leftmost bit that passes only the first condition would be easy: it's the highest set bit in b & d. The second condition is a bit more complex to handle, but still surprisingly easy thanks to one simple observation: the bits that can pass it are precisely those bits that are at (or to the right of) the leftmost bit where the upper and lower bound differ. Imagine two numbers in binary, one being the lower bound and the other the upper bound. The numbers have some common prefix (possibly zero bits long, up to all bits), and then, if they differ, they must differ by a bit in the upper bound being 1 while the corresponding bit in the lower bound is 0. Lowering the upper bound by resetting that bit while setting all bits to the right of it cannot make it lower than the lower bound.

For one of the inputs, say x, the position at which that second condition starts being false (looking at that bit and to the left of it) can be computed directly with 64 - lzcnt(a ^ b). We actually need the maximum of that across both pairs of bounds, but there's no need to compute that for both bounds and then take the maximum; we can use this to let the lzcnt find the maximum automatically: 64 - lzcnt((a ^ b) | (c ^ d)).

bzhi(m, k) is an operation that resets the bits in m starting at index k. It can be emulated by shifting or masking, but an advantage of bzhi is that it is well defined for any relevant k, including when k is equal to the size of the integer in bits. bzhi is not strictly required here, but it is more convenient than "classic" bitwise operations, and available on most x64 processors today[1]. Using bzhi, it's simple to take the position calculated in the previous paragraph and reset all the bits in b & d that do not pass the second condition: bzhi(b & d, 64 - lzcnt((a ^ b) | (c ^ d))).

With that bitmask in hand, all we need to do is apply it to one of the upper bounds. We can skip the "reset the target bit" part, since that bit will be set in the other upper bound and therefore also in the result. It also does not matter which upper bound is changed, regardless of which bound we were conceptually changing. Let's pick b for no particular reason. Then in total, the implementation could be:

uint64_t maxOR(uint64_t a, uint64_t b, uint64_t c, uint64_t d)
{
    uint64_t index = 64 - _lzcnt_u64((a ^ b) | (c ^ d));
    uint64_t candidates = _bzhi_u64(b & d, index);
    if (candidates) {
        uint64_t target = highestSetBit(candidates);
        b |= target - 1;
    }
    return b | d;
}

For the highestSetBit function you can choose any way you like to isolate the highest set bit in an integer.
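For a quick correctness check, maxOR can be compared against brute force over small ranges. This sketch replaces _lzcnt_u64 and _bzhi_u64 with portable helpers (lzcnt64 and bzhi64, written for the example) and picks one possible highestSetBit, which returns 0 for an input of 0.

```cpp
#include <cassert>
#include <cstdint>

// Portable stand-ins for LZCNT and BZHI, plus one choice of highestSetBit
// (assumptions of this sketch, not part of the original code).
int lzcnt64(uint64_t x) {
    int n = 0;
    for (uint64_t bit = uint64_t(1) << 63; bit != 0 && !(x & bit); bit >>= 1) ++n;
    return n;  // 64 for x == 0, like LZCNT
}

uint64_t bzhi64(uint64_t m, uint64_t k) {
    return k >= 64 ? m : m & ((uint64_t(1) << k) - 1);  // reset bits from index k up
}

uint64_t highestSetBit(uint64_t x) {
    return x ? uint64_t(1) << (63 - lzcnt64(x)) : 0;
}

uint64_t maxOR(uint64_t a, uint64_t b, uint64_t c, uint64_t d) {
    uint64_t index = 64 - lzcnt64((a ^ b) | (c ^ d));
    uint64_t candidates = bzhi64(b & d, index);
    if (candidates) {
        uint64_t target = highestSetBit(candidates);
        b |= target - 1;
    }
    return b | d;
}

// Reference: the exact maximum of x | y over a <= x <= b, c <= y <= d (small ranges only).
uint64_t maxORBruteForce(uint64_t a, uint64_t b, uint64_t c, uint64_t d) {
    uint64_t best = 0;
    for (uint64_t x = a; x <= b; ++x)
        for (uint64_t y = c; y <= d; ++y)
            if ((x | y) > best) best = x | y;
    return best;
}
```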

minOR

Computing the lower bound of x | y surprisingly seems to be more complex. The basic principles are similar, but this time bits are being reset in one of the lower bounds, and it does matter in which lower bound that happens. The computation of the mask of candidate bits also "splits" into separate candidates for each lower bound, unless there's some trick that I've missed. This whole "splitting" thing cannot be avoided by defining minOR in terms of maxAND either, because the same things happen there. But it's not too bad, a little bit of extra arithmetic. Anyway, let's see some code.

uint64_t minOR(uint64_t a, uint64_t b, uint64_t c, uint64_t d)
{
    uint64_t candidatesa = _bzhi_u64(~a & c, 64 - _lzcnt_u64(a ^ b));
    uint64_t candidatesc = _bzhi_u64(a & ~c, 64 - _lzcnt_u64(c ^ d));
    uint64_t target = highestSetBit(candidatesa | candidatesc);
    if (a & target) {
        c &= -target;
    }
    if (c & target) {
        a &= -target;
    }
    return a | c;
}

A Fun Fact here is that the target bit cannot be set in both bounds, opposite to what happens in maxOR where the target bit is always set in both bounds. You may be tempted to turn the second if into else if, but in my tests it was quite important that the ifs are compiled into conditional moves rather than branches (which of the lower bounds the target bit is found in was essentially random), and using else if here apparently discourages compilers (MSVC at least) from using conditional moves.

candidatesa | candidatesc can be zero, although that is very rare, at least in my usage of the function. As written, the code assumes that highestSetBit deals with that gracefully by returning zero if its input is zero. Branching here is (unlike in the two ifs at the end of minOR) not a big deal since this case is so rare (and therefore predictable).
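The same kind of brute-force check works for minOR. As a bonus, the sketch also derives maxAND from minOR with maxAND(a, b, c, d) = ~minOR(~b, ~a, ~d, ~c) and checks it the same way. The portable helpers are the same assumptions as for maxOR, repeated so that the block stands alone; highestSetBit returns 0 for 0, as the code above requires.

```cpp
#include <cassert>
#include <cstdint>

// Portable stand-ins for LZCNT and BZHI, plus one choice of highestSetBit.
int lzcnt64(uint64_t x) {
    int n = 0;
    for (uint64_t bit = uint64_t(1) << 63; bit != 0 && !(x & bit); bit >>= 1) ++n;
    return n;  // 64 for x == 0, like LZCNT
}

uint64_t bzhi64(uint64_t m, uint64_t k) {
    return k >= 64 ? m : m & ((uint64_t(1) << k) - 1);  // reset bits from index k up
}

uint64_t highestSetBit(uint64_t x) {
    return x ? uint64_t(1) << (63 - lzcnt64(x)) : 0;
}

uint64_t minOR(uint64_t a, uint64_t b, uint64_t c, uint64_t d) {
    uint64_t candidatesa = bzhi64(~a & c, 64 - lzcnt64(a ^ b));
    uint64_t candidatesc = bzhi64(a & ~c, 64 - lzcnt64(c ^ d));
    uint64_t target = highestSetBit(candidatesa | candidatesc);
    if (a & target) {
        c &= -target;
    }
    if (c & target) {
        a &= -target;
    }
    return a | c;
}

// De Morgan: x & y == ~(~x | ~y), where ~x ranges over [~b, ~a] and ~y over [~d, ~c].
uint64_t maxAND(uint64_t a, uint64_t b, uint64_t c, uint64_t d) {
    return ~minOR(~b, ~a, ~d, ~c);
}

// Brute-force references, only usable for small ranges.
uint64_t minORBruteForce(uint64_t a, uint64_t b, uint64_t c, uint64_t d) {
    uint64_t best = ~uint64_t(0);
    for (uint64_t x = a; x <= b; ++x)
        for (uint64_t y = c; y <= d; ++y)
            if ((x | y) < best) best = x | y;
    return best;
}

uint64_t maxANDBruteForce(uint64_t a, uint64_t b, uint64_t c, uint64_t d) {
    uint64_t best = 0;
    for (uint64_t x = a; x <= b; ++x)
        for (uint64_t y = c; y <= d; ++y)
            if ((x & y) > best) best = x & y;
    return best;
}
```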

Conclusion

In casually benchmarking these functions, I found them to be a bit faster than the ones that I came up with over a decade ago, and significantly faster than the ones from Hacker's Delight. That basic conclusion probably translates to different scenarios, but the exact ratios will vary a lot based on how predictable the branches are in that case, on your CPU, and on arbitrary codegen decisions made by your compiler.

In any case these new versions look nicer to me.

There are probably much simpler solutions if the bounds were stored in bit-reversed form, but that doesn't seem convenient.

Someone on a certain link aggregation site asked about signed integers. As Hacker's Delight explains via a table, things can go wrong if one (or both) bounds cross the negative/positive boundary - but the solution in those cases is still easy to compute. The way I see it, the basic problem is that a signed bound that crosses the negative/positive boundary effectively encodes two different unsigned intervals, one starting at zero and one ending at the greatest unsigned integer, and the basic unsigned minOR and so on cannot (by themselves) handle those "split" intervals.


[1] Sadly not all: some low-end Intel processors have AVX disabled, which is apparently done by disabling the entire VEX encoding, taking out BMI2 as collateral damage.

Monday 12 June 2023

Some ways to check whether an unsigned sum wraps

When computing x + y, does the sum wrap? There are various ways to find out, some of them well known, some less. Some of these are probably totally unknown, in some cases deservedly so.

This is not meant to be an exhaustive list.

  • ~x < y

    A cute trick in case you don't want to compute the sum, for whatever reason.

    Basically a variation of how MISRA C recommends checking for wrapping if you use the precondition test, since ~x = UINT_MAX - x.

  • (x + y) < x

    (x + y) < y

    The Classic™. Useful for being cheap to compute: x + y is often needed anyway, in which case this effectively only costs a comparison. Also recommended by MISRA C, in case you want to express the "has the addition wrapped"-test as a postcondition test.

  • avg_down(x, y) <s 0

    avg_down(x, y) & 0x80000000 // adjust to integer size

    <s is signed-less-than. Performed on unsigned integers here, so be it.

    avg_down is the unsigned average rounded down. avg_up is the unsigned average rounded up.

    Since avg_down is the sum (the full sum, without wrapping) shifted right by 1, what would have been the carry out of the top of the sum becomes the top bit of the average. So, checking the top bit of avg_down(x, y) is equivalent to checking the carry out of x + y.

    Can be converted into avg_up(~x, ~y) >=s 0 through the equivalence avg_down(x, y) = ~avg_up(~x, ~y).

  • (x + y) < min(x, y)

    (x + y) < max(x, y)

    (x + y) < avg(x, y)

    (x + y) < (x | y)

    (x + y) < (x & y)

    ~(x | y) < (x & y)

    Variants of The Classic™. They all work for essentially the same reason: addition is commutative, including at the bit-level. So if we have (x + y) < x, then we also have (x + y) < y, and together they imply that instead of putting x or y on the right hand side of the comparison, we could arbitrarily select one of them, or anything between them too. Bit-level commutativity takes care of the bottom three variants in a similar way.

    Wait, is that a signed or unsigned min? Does avg round up or down or either way depending on the phase of the moon? It doesn't matter, all of those variants work and more.

  • (x + y) != addus(x, y)

    (x + y) < addus(x, y)

    addus is addition with unsigned saturation, meaning that instead of wrapping the result would be UINT_MAX.

    When are normal addition and addition with unsigned saturation different? Precisely when one wraps and the other saturates. Wrapping addition cannot "wrap all the way back" to UINT_MAX: the highest result when the addition wraps is UINT_MAX + UINT_MAX, which wraps to UINT_MAX - 1.

    When the normal sum and saturating sum are different, the normal sum must be the smaller of the two (it certainly couldn't be greater than UINT_MAX), hence the second variant.

  • subus(y, ~x) != 0

    subus(x, ~y) != 0

    addus(~x, ~y) != UINT_MAX

    subus is subtraction with unsigned saturation.

    Strange variants of ~x < y. Since subus(a, b) will be zero when a <= b, it will be non-zero when b < a, therefore subus(y, ~x) != 0 is equivalent to ~x < y.

    subus(a, b) = ~addus(~a, b) lets us turn the subus variant into the addus variant.

  • (x + y) < subus(y, ~x)

    Looks like a cursed hybrid of (x + y) < avg(x, y) and subus(y, ~x) != 0, but the mechanism is (at least the way I see it) different from both of them.

    subus(y, ~x) will be zero when ~x >= y, which is exactly when the sum x + y would not wrap. x + y certainly cannot be unsigned-less-than zero, so overall the condition (x + y) < subus(y, ~x) must be false (which is good, it's supposed to be false when x + y would not wrap).

    In the other case, when ~x < y, we know that x + y will wrap and subus(y, ~x) won't be zero (and therefore cannot saturate). Perhaps there is a nicer way to show what happens, but at least under those conditions (predictable wrapping and no saturation) it is easy to do algebra. With $k$ the number of bits, the wrapped sum is $x + y - 2^k$ and ${\sim}x = 2^k - 1 - x$:

    • $(x + y) < \mathrm{subus}(y, {\sim}x)$
    • $x + y - 2^k < y - (2^k - 1 - x)$
    • $x + y - 2^k < y - 2^k + 1 + x$
    • $x + y < y + 1 + x$
    • $0 < 1$

    So the overall condition (x + y) < subus(y, ~x) is true IFF x + y wraps.

  • ~x < avg_up(~x, y)

    Similar to ~x < y, but stranger. Averaging y with ~x cannot take a low y to above ~x, nor a high y to below ~x. The direction of rounding is important: avg_down(~x, y) could take a y that's just one higher than ~x down to ~x itself, making it no longer higher than ~x. avg_up(~x, y) cannot do that thanks to rounding up.
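Since the list invites some skepticism ("all of those variants work and more"), here is a small C sketch that checks every variant above against the actual carry, exhaustively over all pairs of 8-bit operands. The 8-bit helpers avg_down, avg_up, addus, subus and smin are my own implementations of the operations named in the text (not any particular library's), and the width is reduced to 8 bits only so that exhaustive testing is cheap; the arguments above don't depend on the width, but this sketch only verifies the 8-bit case.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* 8-bit versions of the helpers used in the list (my own implementations). */
static uint8_t avg_down(uint8_t x, uint8_t y) { return (uint8_t)((x & y) + ((x ^ y) >> 1)); }
static uint8_t avg_up(uint8_t x, uint8_t y)   { return (uint8_t)((x | y) - ((x ^ y) >> 1)); }
static uint8_t addus(uint8_t x, uint8_t y)    { unsigned s = (unsigned)x + y; return s > 0xFF ? 0xFF : (uint8_t)s; }
static uint8_t subus(uint8_t x, uint8_t y)    { return x > y ? (uint8_t)(x - y) : 0; }
static uint8_t umin(uint8_t x, uint8_t y)     { return x < y ? x : y; }
static uint8_t umax(uint8_t x, uint8_t y)     { return x > y ? x : y; }
/* Signed min of the same bit patterns: flipping the sign bit turns a signed
   comparison into an unsigned one. The result is still either x or y. */
static uint8_t smin(uint8_t x, uint8_t y)     { return (x ^ 0x80) < (y ^ 0x80) ? x : y; }

/* True if every wrap test from the list gives the same answer as the reference. */
static bool wrap_tests_agree(uint8_t x, uint8_t y)
{
    bool wraps = (unsigned)x + y > 0xFF;   /* reference: carry out of the 8-bit sum */
    uint8_t s  = (uint8_t)(x + y);         /* the wrapped sum */
    uint8_t nx = (uint8_t)~x, ny = (uint8_t)~y;
    bool tests[] = {
        nx < y,                                 /* ~x < y */
        s < x, s < y,                           /* The Classic */
        (avg_down(x, y) & 0x80) != 0,           /* avg_down(x, y) <s 0 */
        (avg_up(nx, ny) & 0x80) == 0,           /* avg_up(~x, ~y) >=s 0 */
        s < umin(x, y), s < umax(x, y), s < smin(x, y),
        s < avg_down(x, y), s < avg_up(x, y),
        s < (x | y), s < (x & y),
        (uint8_t)~(x | y) < (x & y),
        s != addus(x, y), s < addus(x, y),
        subus(y, nx) != 0, subus(x, ny) != 0, addus(nx, ny) != 0xFF,
        s < subus(y, nx),                       /* the cursed hybrid */
        nx < avg_up(nx, y),
    };
    for (unsigned i = 0; i < sizeof tests / sizeof tests[0]; i++)
        if (tests[i] != wraps)
            return false;
    return true;
}

/* Exhaustively checks all 65536 operand pairs. */
static void check_wrap_tests(void)
{
    for (unsigned x = 0; x < 256; x++)
        for (unsigned y = 0; y < 256; y++)
            assert(wrap_tests_agree((uint8_t)x, (uint8_t)y));
}
```

Swapping avg_up for avg_down in the last test should make the check fail (for example at x = 0, y = 0), which matches the remark about the direction of rounding.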

Monday 22 May 2023

grevmul

grev (generalized bit-reverse) is an operation that implements bit-permutations corresponding to XOR-ing the indices by some value. It has been proposed to be part of the Zbp extension of RISC-V, with this reference implementation (source: release v0.93)

uint32_t grev32(uint32_t rs1, uint32_t rs2)
{
    uint32_t x = rs1;
    int shamt = rs2 & 31;
    if (shamt &  1) x = ((x & 0x55555555) <<  1) | ((x & 0xAAAAAAAA) >>  1);
    if (shamt &  2) x = ((x & 0x33333333) <<  2) | ((x & 0xCCCCCCCC) >>  2);
    if (shamt &  4) x = ((x & 0x0F0F0F0F) <<  4) | ((x & 0xF0F0F0F0) >>  4);
    if (shamt &  8) x = ((x & 0x00FF00FF) <<  8) | ((x & 0xFF00FF00) >>  8);
    if (shamt & 16) x = ((x & 0x0000FFFF) << 16) | ((x & 0xFFFF0000) >> 16);
    return x;
}

grev looks in some ways similar to bit-shifts and rotates: the left and right operands have distinct roles, with the right operand being a mask of k bits if the left operand has $2^k$ bits[1].

Carry-less multiplication normally has a left-shift in it, grevmul is what you get when that left-shift is replaced with grev.

uint32_t grevmul32(uint32_t x, uint32_t y)
{
    uint32_t r = 0;
    for (int k = 0; k < 32; k++) {
        // for every set bit k of y, XOR in x with its bit indices XOR-ed by k
        if (y & (1u << k))
            r ^= grev32(x, k);
    }
    return r;
}

grevmul is, at its core, very similar to clmul: take single-bit products (logical AND) of every bit of the left operand with every bit of the right operand, then do some XOR-reduction. The difference is in which partial products are grouped together. For clmul, the partial products that contribute to bit k of the result are pairs with indices i,j such that i + j = k. For grevmul, it's the pairs with indices such that i ^ j = k. This goes back to grev permuting the bits by XOR-ing their indices by some value, and that value is k here.
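The claim that the loop definition and the "i ^ j = k" partial-product view agree can be checked directly. This is a sketch: grev32 is copied from the reference implementation above, while grevmul_loop and grevmul_pairs are hypothetical names for the two formulations.

```c
#include <assert.h>
#include <stdint.h>

/* grev32 from the reference implementation above. */
static uint32_t grev32(uint32_t rs1, uint32_t rs2)
{
    uint32_t x = rs1;
    int shamt = rs2 & 31;
    if (shamt &  1) x = ((x & 0x55555555) <<  1) | ((x & 0xAAAAAAAA) >>  1);
    if (shamt &  2) x = ((x & 0x33333333) <<  2) | ((x & 0xCCCCCCCC) >>  2);
    if (shamt &  4) x = ((x & 0x0F0F0F0F) <<  4) | ((x & 0xF0F0F0F0) >>  4);
    if (shamt &  8) x = ((x & 0x00FF00FF) <<  8) | ((x & 0xFF00FF00) >>  8);
    if (shamt & 16) x = ((x & 0x0000FFFF) << 16) | ((x & 0xFFFF0000) >> 16);
    return x;
}

/* grevmul by its loop definition: XOR of grev(x, k) over the set bits k of y. */
static uint32_t grevmul_loop(uint32_t x, uint32_t y)
{
    uint32_t r = 0;
    for (uint32_t k = 0; k < 32; k++)
        if (y & (1u << k))
            r ^= grev32(x, k);
    return r;
}

/* grevmul by partial products: x_i AND y_j is XOR-ed into result bit i ^ j. */
static uint32_t grevmul_pairs(uint32_t x, uint32_t y)
{
    uint32_t r = 0;
    for (int i = 0; i < 32; i++)
        for (int j = 0; j < 32; j++)
            r ^= ((x >> i) & (y >> j) & 1u) << (i ^ j);
    return r;
}

static void check_grevmul(uint32_t x, uint32_t y)
{
    assert(grevmul_loop(x, y) == grevmul_pairs(x, y));
    assert(grevmul_loop(x, y) == grevmul_loop(y, x));  /* commutativity, as in the property table */
}
```

The double loop is much slower, but it makes the grouping of partial products explicit: grev(x, k) moves bit i to bit i ^ k, so each set bit j of y sends x_i to position i ^ j.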

Now that grevmul has been defined, let's look at some of its properties, comparing it to clmul and plain old imul.

                     grevmul          clmul                  imul
zero[2]              0                0                      0
identity             1                1                      1
commutative          yes              yes                    yes
associative          yes              yes                    yes
distributes over     xor              xor                    addition
op(x, 1 << k) is     grev(x, k)       x << k                 x << k
x has inverse if     popcnt(x) & 1    x & 1                  x & 1
op(x, x) is          popcnt(x) & 1    pdep(x, 0x55555555)

What is the "grevmul inverse" of x?

Time for some algebra. Looking just at the table above, and forgetting the actual definition of grevmul, can we say something about the solutions of grevmul(x, y) == 1? Surprisingly, yes.

Assuming we have some x with odd Hamming weight (numbers with even Hamming weight do not have inverses, so let's ignore them for now), we know that grevmul(x, x) == 1. The inverse in a monoid is unique, so x is not just some inverse of x but the (unique) inverse of x.

Since the "addition operator" is XOR (for which negation is the identity function), this is a non-trivial example of a ring in which $x = -x = x^{-1}$, when $x^{-1}$ exists. Strange, isn't it?

We also have that f(x) = grevmul(x, c) (for appropriate choices of c) is a (non-trivial) involution, so it may be a contender for the "middle operation" of an involutory bit finalizer, but probably useless without an efficient implementation.
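Both claims (grevmul(x, x) == popcnt(x) & 1, and f(v) = grevmul(v, c) being an involution when c has odd popcount) are easy to spot-check. This sketch uses the partial-product formulation of grevmul for brevity; parity32 and check_self_inverse are hypothetical helper names. The involution follows from associativity: grevmul(grevmul(v, c), c) = grevmul(v, grevmul(c, c)) = grevmul(v, 1) = v.

```c
#include <assert.h>
#include <stdint.h>

/* grevmul via partial products: bit i ^ j of the result collects x_i AND y_j. */
static uint32_t grevmul32(uint32_t x, uint32_t y)
{
    uint32_t r = 0;
    for (int i = 0; i < 32; i++)
        for (int j = 0; j < 32; j++)
            r ^= ((x >> i) & (y >> j) & 1u) << (i ^ j);
    return r;
}

/* popcnt(x) & 1, computed by folding. */
static uint32_t parity32(uint32_t x)
{
    x ^= x >> 16; x ^= x >> 8; x ^= x >> 4; x ^= x >> 2; x ^= x >> 1;
    return x & 1;
}

/* grevmul(x, x) keeps only the i == j pairs (the others cancel in pairs),
   all of which land in bit 0. If c has odd popcount it is its own inverse,
   so applying f twice must give back x. */
static void check_self_inverse(uint32_t x, uint32_t c)
{
    assert(grevmul32(x, x) == parity32(x));
    if (parity32(c))
        assert(grevmul32(grevmul32(x, c), c) == x);
}
```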

I was going to write about implementing grevmul by an 8-bit constant with two GF2P8AFFINEQBs but I've had enough for now, maybe later. E: see Implementing grevmul with GF2P8AFFINEQB where I went ahead and implemented the whole thing, not only the "multiply by 8-bit constant" case.


[1] The right operand of a shift is often called the shift count, but it can also be interpreted as a mask indicating some subset of shift-by-$2^i$ operations to perform. That interpretation is useful for example when implementing a shift-by-variable operation on a machine that only has a shift-by-constant instruction, following the same pattern as the reference implementation of grev32.

[2] This looks like a joke, but I mean that the numeric value 0 acts as the zero element of the corresponding semigroup.

Wednesday 12 April 2023

(Not) transposing a 16x16 bitmatrix

Inverting a 16-element permutation may be done like this:

for (int i = 0; i < 16; i++)
    inv[perm[i]] = i;

Computing a histogram of 16 nibbles may be done like this:

for (int i = 0; i < 16; i++)
    hist[data[i]] += 1;

These different-sounding but already similar-looking tasks have something in common: they can both be built around a 16x16 bitmatrix transpose. That sounds silly, why would anyone want to first construct a 16x16 bitmatrix, transpose it, and then do yet more processing to turn the resulting bitmatrix back into an array of numbers?

Because it turns out to be an efficiently-implementable operation, on some modern processors anyway.

If you know anything about the off-label application of GF2P8AFFINEQB, you may already suspect that it will be involved somehow (merely left-GF2P8AFFINEQB-ing by the identity matrix already results in some sort of 8x8 transpose, just horizontally mirrored), and it will be, but that's not the whole story.

First I will show not only how to do it with GF2P8AFFINEQB, but also how to find that solution programmatically using a SAT solver. There is nothing that fundamentally prevents a human from finding a solution by hand, but it seems difficult. Using a SAT solver to find a solution ex nihilo (requiring it to find both a sequence of instructions and their operands) is not that easy either (though that technique also exists). Thankfully, Geoff Langdale suggested a promising sequence of instructions.

The problem we have now (which the SAT solver will solve) is, under the constraint that for all X, f(X) = PERMB(GF2P8AFFINE(B, PERMB(X, A)), C) computes the transpose of X, what is a possible valuation of the variables A, B, C. Note that the variables in the SAT problem correspond to constants in the resulting code, and the variable in the resulting code (X) is quantified out of the problem.

If you know a bit about SAT solving, that "for all X" sounds like trouble, requiring either creating a set of constraints for every possible value of X (henceforth, concrete values of X will be known as "examples"), or some advanced technique such as CEGIS to dynamically discover a smaller set of examples to base the constraints on. Luckily, since we are dealing with a bit-permutation, there are simple and small sets of examples that together sufficiently constrain the problem. For a 16-bit permutation, this set of values could be used:

  • 1010101010101010
  • 1100110011001100
  • 1111000011110000
  • 1111111100000000

For a 256-bit permutation, a similar pattern can be used, where each of the examples has 256 bits and there would be 8 of them. Note that if you read the columns of the values, they list out the indices of the corresponding columns, which is no coincidence. Using that set of examples to constrain the problem essentially means that we assert that f, when applied to the sequence 0..n-1, must result in the desired permutation. The way that I actually implemented this puts a column into one "abstract bit", so that it represents the index of the bit all in one place instead of spread out.
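To make the "columns list out the indices" remark concrete, here is a small C sketch for the 16-bit case. It builds the four examples, pushes them through a concrete bit permutation (a stand-in for f; in the SAT problem f is symbolic, of course), and reads each column back. Column k of the outputs spells out which input bit ended up at position k, so those four examples pin down the whole permutation. All function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Example j has bit i set iff bit j of i is set, so reading column i
   (bit i of examples 3..0) gives back the number i. */
static void make_examples(uint16_t ex[4])
{
    for (int j = 0; j < 4; j++) {
        ex[j] = 0;
        for (int i = 0; i < 16; i++)
            if ((i >> j) & 1)
                ex[j] |= (uint16_t)(1u << i);
    }
}

/* A concrete stand-in for f: bit i of x moves to bit dst[i]. */
static uint16_t permute_bits(uint16_t x, const uint8_t dst[16])
{
    uint16_t r = 0;
    for (int i = 0; i < 16; i++)
        r |= (uint16_t)(((x >> i) & 1u) << dst[i]);
    return r;
}

/* Applies f to the 4 examples only, then reads each column of the results:
   column k holds the index of the source bit that ended up at position k. */
static void recover_sources(const uint8_t dst[16], uint8_t src[16])
{
    uint16_t ex[4];
    make_examples(ex);
    for (int j = 0; j < 4; j++)
        ex[j] = permute_bits(ex[j], dst);
    for (int k = 0; k < 16; k++) {
        int index = 0;
        for (int j = 0; j < 4; j++)
            index |= ((ex[j] >> k) & 1) << j;
        src[k] = (uint8_t)index;
    }
}

static void check_recovery(const uint8_t dst[16])
{
    uint8_t src[16];
    recover_sources(dst, src);
    for (int i = 0; i < 16; i++)
        assert(src[dst[i]] == i);
}
```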

Implementing a "left GF2P8AFFINEQB" (multiplying a constant matrix on the left by a variable matrix on the right) in CNF, operating on "abstract bits" (8 variables each), is relatively straightforward. Every (abstract) bit of the result is the XOR of the AND of some (abstract) bits; writing that down is mostly a chore, but there is one interesting aspect: the XOR can be turned into an OR, since we know that we're multiplying by a permutation matrix. In CNF, OR is simpler than XOR, and easier for the solver to reason through.

VPERMB is more difficult to implement, given that the permutation operand is a variable (if it were a constant, we could just permute the abstract bits without generating any new constraints). To make it easier, I represent the permutation operand as a 32x32 permutation matrix, letting me create a bunch of simple ternary constraints of the form (¬P(i, j) ∨ ¬A(j) ∨ R(i)) ∧ (¬P(i, j) ∨ A(j) ∨ ¬R(i)) (read: if P(i, j), then A(j) must be equal to R(i)). The same thing can be used to implement VPSHUFB, with additional constraints on the permutation matrix (to prevent cross-slice movement).
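Generating those ternary constraints is mechanical. As a sketch, here they are emitted as DIMACS-style signed-integer literals. The variable numbering, the function names and the use of single boolean variables (instead of 8-variable abstract bits) are my own assumptions, not taken from the author's program, and the constraints that force P to be a permutation matrix (exactly one true entry per row and per column) are left out.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical 1-based variable numbering for an n-element permutation. */
static int var_A(int n, int j)        { (void)n; return 1 + j; }
static int var_R(int n, int i)        { return 1 + n + i; }
static int var_P(int n, int i, int j) { return 1 + 2 * n + i * n + j; }

/* Emits (-P | -A | R) and (-P | A | -R) for every (i, j), i.e. "if P(i, j)
   then R(i) == A(j)". Returns the number of 3-literal clauses written. */
static int emit_permute_clauses(int n, int (*clauses)[3])
{
    int c = 0;
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++) {
            int p = var_P(n, i, j), a = var_A(n, j), r = var_R(n, i);
            clauses[c][0] = -p; clauses[c][1] = -a; clauses[c][2] =  r; c++;
            clauses[c][0] = -p; clauses[c][1] =  a; clauses[c][2] = -r; c++;
        }
    return c;
}

/* Evaluates all clauses under a full assignment (value[v] for variable v). */
static bool satisfied(int (*clauses)[3], int count, const bool *value)
{
    for (int c = 0; c < count; c++) {
        bool any = false;
        for (int k = 0; k < 3; k++) {
            int lit = clauses[c][k];
            bool v = value[lit > 0 ? lit : -lit];
            any = any || (lit > 0 ? v : !v);
        }
        if (!any)
            return false;
    }
    return true;
}

/* With P set to a permutation matrix and R(i) = A(perm[i]), all clauses hold;
   flipping one output bit must violate one of them. */
static void check_permute_clauses_n4(void)
{
    enum { N = 4 };
    static const int perm[N] = {2, 0, 3, 1};
    static const bool a[N] = {true, false, false, true};
    int cl[2 * N * N][3];
    bool val[1 + 2 * N + N * N] = {false};
    int count = emit_permute_clauses(N, cl);
    assert(count == 2 * N * N);
    for (int j = 0; j < N; j++)
        val[var_A(N, j)] = a[j];
    for (int i = 0; i < N; i++) {
        val[var_P(N, i, perm[i])] = true;
        val[var_R(N, i)] = a[perm[i]];
    }
    assert(satisfied(cl, count, val));
    val[var_R(N, 0)] = !val[var_R(N, 0)];
    assert(!satisfied(cl, count, val));
}
```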

Running that code, at least on my PC at this time[1], results in (with some whitespace manually added):

__m256i t0 = _mm256_permutexvar_epi8(_mm256_setr_epi8(
    14, 12, 10, 8, 6, 4, 2, 0,
    30, 28, 26, 24, 22, 20, 18, 16,
    15, 13, 11, 9, 7, 5, 3, 1,
    31, 29, 27, 25, 23, 21, 19, 17), input);
__m256i t1 = _mm256_gf2p8affine_epi64_epi8(_mm256_set1_epi64x(0x1080084004200201), t0, 0);
__m256i t2 = _mm256_shuffle_epi8(t1, _mm256_setr_epi8(
    0, 8, 1, 9, 3, 11, 5, 13,
    7, 15, 2, 10, 4, 12, 6, 14,
    0, 8, 1, 9, 3, 11, 5, 13,
    7, 15, 2, 10, 4, 12, 6, 14));

So that's it. That's the answer[2]. If you want to transpose a 16x16 bitmatrix, on a modern PC (this code requires AVX512_VBMI and AVX512_GFNI[3]), it's fairly easy and cheap, it's just not so easy to find this solution to begin with.

Using this transpose to invert a 16-element permutation is pretty easy, for example using _mm256_sllv_epi16 to construct the matrix and _mm256_popcnt_epi16(_mm256_sub_epi16(t2, _mm256_set1_epi16(1))) (sadly there is no SIMD version of TZCNT .. yet) to convert the bit-masks back into indices. It may be tempting to try to use a mirrored matrix and leading-zero count, which AVX512 does offer, but it only offers the DWORD and QWORD versions VPLZCNTD/Q.

Making a histogram is even simpler, using only _mm256_popcnt_epi16(t2) to convert the matrix into counts.
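For reference, here is what the two transpose-based applications compute, written as plain scalar C (a sketch for checking results, not a fast implementation; the function names are hypothetical). Row i of the input matrix is 1 << perm[i] (or 1 << data[i]); after transposing, row j has bit i set exactly when perm[i] == j, so popcount(row - 1) is the inverse and popcount(row) is the count.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar 16x16 bit-matrix transpose: bit c of row r moves to bit r of row c. */
static void transpose16(const uint16_t in[16], uint16_t out[16])
{
    for (int c = 0; c < 16; c++) {
        uint16_t row = 0;
        for (int r = 0; r < 16; r++)
            row |= (uint16_t)(((in[r] >> c) & 1u) << r);
        out[c] = row;
    }
}

static int popcount16(uint16_t x)
{
    int n = 0;
    for (; x; x &= (uint16_t)(x - 1))
        n++;
    return n;
}

/* Row j of the transpose is 1 << inv[j]; popcount(x - 1) of a power of two
   is its trailing zero count, as in the _mm256_popcnt_epi16 trick. */
static void invert_permutation(const uint8_t perm[16], uint8_t inv[16])
{
    uint16_t m[16], t[16];
    for (int i = 0; i < 16; i++)
        m[i] = (uint16_t)(1u << (perm[i] & 15));
    transpose16(m, t);
    for (int j = 0; j < 16; j++)
        inv[j] = (uint8_t)popcount16((uint16_t)(t[j] - 1));
}

/* Row j of the transpose has one bit per occurrence of the nibble j. */
static void histogram16(const uint8_t data[16], uint8_t hist[16])
{
    uint16_t m[16], t[16];
    for (int i = 0; i < 16; i++)
        m[i] = (uint16_t)(1u << (data[i] & 15));
    transpose16(m, t);
    for (int j = 0; j < 16; j++)
        hist[j] = (uint8_t)popcount16(t[j]);
}

static void check_inverse(const uint8_t perm[16])
{
    uint8_t inv[16];
    invert_permutation(perm, inv);
    for (int i = 0; i < 16; i++)
        assert(inv[perm[i]] == i);
}
```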

And for my next trick, I will now not transpose the matrix

What if we didn't transpose that matrix? Does that even make sense? Well, at least for the two applications that I focused on, what we really need is not so much the transpose of the matrix, but any matrix such that:

  1. Every bit of the original matrix occurs exactly once in the result.
  2. Each row of the result contains all bits from a particular column.
  3. The permutation within each row is "regular" enough that we can work with it. We don't need this when making a histogram (as Geoff already noted in one of his tweets).

There is no particular requirement on the order of the rows, any row-permutation we end up with is easy to undo.

The first two constraints leave plenty of options open, but the last constraint is quite vague. Too vague for me to do something such as searching for the best not-quite-transpose, so I don't promise to have found it. But here is a solution: rotate every row by its index, then rotate every column by its index.
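A scalar model of the basic "rotate rows, then rotate columns" idea makes it easy to see that the first two requirements hold. The rotation directions matter: in this sketch (directions chosen by me so that it works) row r is rotated left by r, and then result bit (r, c) is taken from row r + c of the row-rotated matrix. Working out the indices, result bit (r, c) is original bit (r + c, -r) mod 16, so row r of the result is exactly original column (16 - r) % 16. Rotating both in the "same" sense would instead give original bit (r - c, 2c - r), which mixes columns. Function names are hypothetical; this does not model the quadrant-based variant used in the AVX2 code.

```c
#include <assert.h>
#include <stdint.h>

static uint16_t rotl16(uint16_t x, int r)
{
    r &= 15;
    return (uint16_t)((x << r) | (x >> ((16 - r) & 15)));
}

/* Bit (r, c) of a 16x16 bit matrix stored as 16 rows, indices taken mod 16. */
static int bit_at(const uint16_t m[16], int r, int c)
{
    return (m[r & 15] >> (c & 15)) & 1;
}

/* Rotate row r left by r, then rotate column c by c, so that result (r, c)
   comes from row r + c of the row-rotated matrix. */
static void not_transpose(const uint16_t in[16], uint16_t out[16])
{
    uint16_t t[16];
    for (int r = 0; r < 16; r++)
        t[r] = rotl16(in[r], r);
    for (int r = 0; r < 16; r++) {
        uint16_t row = 0;
        for (int c = 0; c < 16; c++)
            row |= (uint16_t)(bit_at(t, r + c, c) << c);
        out[r] = row;
    }
}

/* Result row r must hold exactly original column (16 - r) % 16,
   with bit c coming from original row (r + c) % 16. */
static void check_not_transpose(const uint16_t in[16])
{
    uint16_t out[16];
    not_transpose(in, out);
    for (int r = 0; r < 16; r++)
        for (int c = 0; c < 16; c++)
            assert(bit_at(out, r, c) == bit_at(in, r + c, 16 - r));
}
```

Since every bit of row r sits at a known position (bit c came from row r + c), undoing the within-row permutation amounts to a rotation, which is the kind of "regular" permutation the third requirement asks for.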

At least, that's the starting point. Rotating the columns requires 3 rounds of blending a vector with a cross-slice-permuted copy of that vector, and a VPERMQ sandwiched by two VPSHUFBs to rotate the last 8 columns by 8. That's a lot of cross-slice permuting; most of it can be avoided by modifying the overall permutation slightly:

  1. Exchange the off-diagonal quadrants.
  2. Rotate each row by its index.
  3. For each quadrant individually, rotate each column by its index.


These three steps are implementable in AVX2:

  1. Exchanging the off-diagonal quadrants can be done by gathering the quadrants into QWORDs, permuting them, and shuffling the QWORDs back into quadrants.
  2. Rotating the rows can be done with VPMULLW (used as a variable shift-left), VPMULHUW (used as a variable shift-right), and VPOR.
  3. Rotating the columns can be done by conditionally rotating the columns with odd indices by 1, conditionally rotating the columns that have the second bit of their index set by 2, and conditionally rotating the columns that have the third bit of their index set by 4. The rotations can be done using VPALIGNR[4], and the conditionality can be implemented with blending, but since this needs to be a bit-granular blend, it cannot be performed using VPBLENDVB.
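Step 2 relies on a multiplication by 1 << r producing both halves of a rotate: the low 16 bits of the product are x << r (what VPMULLW keeps) and the high 16 bits are x >> (16 - r) (what VPMULHUW keeps). Here is a scalar model of one 16-bit lane, checked exhaustively against an ordinary rotate; the function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* One 16-bit lane of VPMULLW / VPMULHUW / VPOR with multiplier 1 << r. */
static uint16_t rotl16_via_mul(uint16_t x, unsigned r)
{
    uint32_t p = (uint32_t)x * (1u << r);  /* full 32-bit product, r in 0..15 */
    uint16_t lo = (uint16_t)p;             /* what VPMULLW keeps: x << r */
    uint16_t hi = (uint16_t)(p >> 16);     /* what VPMULHUW keeps: x >> (16 - r) */
    return (uint16_t)(lo | hi);            /* VPOR */
}

/* Plain rotate, for comparison. */
static uint16_t rotl16(uint16_t x, unsigned r)
{
    r &= 15;
    return (uint16_t)((x << r) | (x >> ((16 - r) & 15)));
}

static void check_rotl16_via_mul(void)
{
    for (unsigned r = 0; r < 16; r++)
        for (uint32_t x = 0; x <= 0xFFFF; x++)
            assert(rotl16_via_mul((uint16_t)x, r) == rotl16((uint16_t)x, r));
}
```

Note that r = 0 also works: the multiplier is 1, so the high half is zero and the low half is x itself.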

In total, here is how I don't transpose a 16x16 matrix with AVX2, hopefully there is a better way:

__m256i nottranspose16x16(__m256i x)
{
    // exchange off-diagonal quadrants
    x = _mm256_shuffle_epi8(x, _mm256_setr_epi8(
        0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15,
        0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15));
    x = _mm256_permute4x64_epi64(x, _MM_SHUFFLE(3, 1, 2, 0));
    x = _mm256_shuffle_epi8(x, _mm256_setr_epi8(
        0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15,
        0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15));
    // rotate every row by its y coordinate
    __m256i shifts = _mm256_setr_epi16(
        1 << 0, 1 << 1, 1 << 2, 1 << 3,
        1 << 4, 1 << 5, 1 << 6, 1 << 7,
        1 << 8, 1 << 9, 1 << 10, 1 << 11,
        1 << 12, 1 << 13, 1 << 14, 1 << 15);
    __m256i sll = _mm256_mullo_epi16(x, shifts);
    __m256i srl = _mm256_mulhi_epu16(x, shifts);
    x = _mm256_or_si256(sll, srl);
    // within each quadrant independently, 
    // rotate every column by its x coordinate
    __m256i x0, x1, m;
    // rotate by 4
    m = _mm256_set1_epi8(0x0F);
    x0 = _mm256_and_si256(x, m);
    x1 = _mm256_andnot_si256(m, _mm256_alignr_epi8(x, x, 8));
    x = _mm256_or_si256(x0, x1);
    // rotate by 2
    m = _mm256_set1_epi8(0x33);
    x0 = _mm256_and_si256(x, m);
    x1 = _mm256_andnot_si256(m, _mm256_alignr_epi8(x, x, 4));
    x = _mm256_or_si256(x0, x1);
    // rotate by 1
    m = _mm256_set1_epi8(0x55);
    x0 = _mm256_and_si256(x, m);
    x1 = _mm256_andnot_si256(m, _mm256_alignr_epi8(x, x, 2));
    x = _mm256_or_si256(x0, x1);
    return x;
}

Using that not-transpose to invert a 16-element permutation takes some extra steps that, without AVX512, are about as annoying as not-transposing the matrix was.

  • Constructing the matrix is more difficult. AVX2 has shift-by-variable, but not for 16-bit elements.[5] There are various work-arounds, such as using DWORDs and then narrowing, of course (boring). Another (funnier) option is to duplicate every byte, add 0xF878 to every word, then use VPSHUFB in lookup-table-mode to index into a table of powers of two. Having added 0x78 to every low byte of every word, that byte will be mapped to zero if it was 8 or higher, or otherwise to two to the power of that byte. The high byte, having had 0xF8 added to it, will be mapped to 0 if it was below 8, or otherwise to two to the power of that byte minus 8. As wild as that sounds, it is pretty fast, costing only 5 cheap instructions (whereas widening to DWORDs, shifting, and narrowing would be worse than it sounds). Perhaps there is a better way.
  • Converting masks back into indices is more difficult due to the lack of trailing zero count, leading zero count, or even popcount. What AVX2 does have, is .. VPSHUFB again. We can multiply by an order-4 de Bruijn sequence and use VPSHUFB to map the results to the indices of the set bits.
  • Then we have indices, but since the rows and columns were somewhat arbitrarily permuted, they must still be mapped back into something that makes sense. Fortunately that's no big deal, a modular subtraction (or addition, same thing really) cancels out the row-rotations, and yet another VPSHUFB cancels out the strange order that the rows are in. Fun detail: the constants that are subtracted and the permutation are both 0, 7, 6, 5, 4, 3, 2, 1, 8, 15, 14, 13, 12, 11, 10, 9.
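Both VPSHUFB tricks can be modelled per lane in scalar C, which is a convenient way to convince yourself that the constants are right. In this sketch, pshufb_byte is a hypothetical helper that models one byte of VPSHUFB used as a table lookup (zero if bit 7 of the index is set, otherwise table[index & 15]); the tables and the 0x9AF0 multiplier are the ones from invert_permutation_avx2.

```c
#include <assert.h>
#include <stdint.h>

/* One byte of VPSHUFB used as a table lookup. */
static uint8_t pshufb_byte(const uint8_t table[16], uint8_t index)
{
    return (index & 0x80) ? 0 : table[index & 15];
}

/* Index v (0..15) to the mask 1 << v: v is duplicated into both bytes of a
   word, 0x78 is added to the low byte and 0xF8 to the high byte, and both
   bytes are looked up in a table of powers of two. */
static uint16_t index_to_mask(uint8_t v)
{
    static const uint8_t pow2[16] = {1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128};
    uint8_t lo = (uint8_t)(v + 0x78);  /* 0x78..0x7F if v < 8, else 0x80..0x87 (maps to 0) */
    uint8_t hi = (uint8_t)(v + 0xF8);  /* 0xF8..0xFF if v < 8 (maps to 0), else 0x00..0x07 */
    return (uint16_t)((pshufb_byte(pow2, hi) << 8) | pshufb_byte(pow2, lo));
}

/* Single-bit mask back to its index: VPMULHUW by the de Bruijn constant 0x9AF0
   and an AND with 15 select a distinct 4-bit window for each power of two,
   and the table maps that window back to the bit index. */
static uint8_t mask_to_index(uint16_t m)
{
    static const uint8_t window_to_index[16] = {0, 1, 2, 5, 3, 9, 6, 11, 15, 4, 8, 10, 14, 7, 13, 12};
    uint8_t window = (uint8_t)((((uint32_t)m * 0x9AF0u) >> 16) & 0x0F);
    return pshufb_byte(window_to_index, window);
}

static void check_index_mask_round_trip(void)
{
    for (uint8_t v = 0; v < 16; v++) {
        assert(index_to_mask(v) == (uint16_t)(1u << v));
        assert(mask_to_index((uint16_t)(1u << v)) == v);
    }
}
```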

All put together:

void invert_permutation_avx2(uint8_t *p, uint8_t *inv)
{
    __m256i v = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)p));
    // indexes to masks
    v = _mm256_or_si256(v, _mm256_slli_epi64(v, 8));
    v = _mm256_add_epi8(v, _mm256_set1_epi16(0xF878));
    __m256i m = _mm256_shuffle_epi8(_mm256_setr_epi8(
        1, 2, 4, 8, 16, 32, 64, 128,
        1, 2, 4, 8, 16, 32, 64, 128,
        1, 2, 4, 8, 16, 32, 64, 128,
        1, 2, 4, 8, 16, 32, 64, 128), v);
    // ???
    m = nottranspose16x16(m);
    // masks to indexes
    __m256i deBruijn = _mm256_and_si256(_mm256_mulhi_epu16(m, _mm256_set1_epi16(0x9AF0)), _mm256_set1_epi16(0x000F));
    __m128i i = _mm_packs_epi16(_mm256_castsi256_si128(deBruijn), _mm256_extracti128_si256(deBruijn, 1));
    i = _mm_shuffle_epi8(_mm_setr_epi8(
        0, 1, 2, 5, 3, 9, 6, 11, 15, 4, 8, 10, 14, 7, 13, 12), i);
    // un-mess-up the indexes
    i = _mm_sub_epi8(i, _mm_setr_epi8(0, 7, 6, 5, 4, 3, 2, 1, 8, 15, 14, 13, 12, 11, 10, 9));
    i = _mm_and_si128(i, _mm_set1_epi8(0x0F));
    i = _mm_shuffle_epi8(i, _mm_setr_epi8(0, 7, 6, 5, 4, 3, 2, 1, 8, 15, 14, 13, 12, 11, 10, 9));
    _mm_storeu_si128((__m128i*)inv, i);
}

To make a histogram, emulate VPOPCNTW using, you guessed it, PSHUFB.
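The usual shape of a PSHUFB-based popcount is a 16-entry table of nibble popcounts, looked up once for the low nibbles and once for the high nibbles, with the results added. For VPOPCNTW, the two byte counts of each word then have to be added together as well. Here is a scalar model of one 16-bit lane; it is a sketch with hypothetical names, and the choice of SIMD instructions for the final byte-pair addition is left open.

```c
#include <assert.h>
#include <stdint.h>

/* The 16-entry table a PSHUFB-based popcount looks up: popcount of each nibble. */
static const uint8_t nibble_popcnt[16] = {0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4};

/* Emulated VPOPCNTW for one lane: look up both nibbles of both bytes,
   add them per byte, then add the two byte counts of the word. */
static uint16_t popcnt16_via_lookup(uint16_t w)
{
    uint8_t lo = (uint8_t)w, hi = (uint8_t)(w >> 8);
    uint8_t count_lo = (uint8_t)(nibble_popcnt[lo & 15] + nibble_popcnt[lo >> 4]);
    uint8_t count_hi = (uint8_t)(nibble_popcnt[hi & 15] + nibble_popcnt[hi >> 4]);
    return (uint16_t)(count_lo + count_hi);
}

static void check_popcnt16(void)
{
    for (uint32_t w = 0; w <= 0xFFFF; w++) {
        int expected = 0;
        for (uint32_t v = w; v; v &= v - 1)
            expected++;
        assert(popcnt16_via_lookup((uint16_t)w) == expected);
    }
}
```

Because popcount ignores bit order, the permuted rows produced by the not-transpose need no further fixing for the histogram; only the row order has to be undone.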

The end

This post is, I think, one of the many examples of how AVX512 can be an enormous improvement compared to AVX2 even when not using 512-bit vectors. Every step of every problem had a simple solution in AVX512 (even if it was not always easy to find it). With AVX2, everything felt "only barely possible".

"As complicated as it is, is this actually faster than scalar code?" Yes actually, but feel free to benchmark it yourself. The AVX2 version being somewhat more efficient than scalar code is not really the point of this post anyway. The AVX512 version is nice and efficient, I'm showing an AVX2 version mostly to show how hard it is to create it.[6]

Transposing larger matrices with AVX512 can be done by first doing some quadrant-swapping (also used at the start of the not-transpose) until the bits that need to end up together in one 512-bit block are all in there, and then a VPERMB, VGF2P8AFFINEQB, VPERMB sequence with the right constants (which can be found using the techniques that I described) can put the bits in their final positions. But well, I already did that, so there you go.

A proper transpose can be done in AVX2 of course, for example using 4 rounds of quadrant-swapping. Implementations of that already exist so I thought that would be boring to talk about, but there is an interesting aspect to that technique that is often not mentioned: every round of quadrant-swapping can be seen as exchanging two bits of the indices. Swapping the big 8x8 quadrants swaps bits 3 and 7 of the indices, transposing the 2x2 submatrices swaps bits 0 and 4 of the indices. From that point of view, it's easy to see that the order in which the four steps are performed does not matter - no matter the order, the lower nibble of the index is swapped with the higher nibble of the index.
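The "exchanging bits of the index" view is easy to test in scalar code. In this sketch (hypothetical names), the matrix is stored as 256 booleans with element index row * 16 + col, so bits 7..4 of the index are the row and bits 3..0 are the column. One round exchanges index bits b and b + 4; b = 3 is the big 8x8 quadrant swap and b = 0 transposes the 2x2 submatrices. Running all four rounds in any order gives the transpose.

```c
#include <assert.h>
#include <stdbool.h>

/* One round of quadrant swapping: exchange index bits b and b + 4. */
static void swap_index_bits(const bool in[256], bool out[256], int b)
{
    for (int i = 0; i < 256; i++) {
        int j = i;
        if (((i >> b) & 1) != ((i >> (b + 4)) & 1))
            j ^= (1 << b) | (1 << (b + 4));
        out[j] = in[i];
    }
}

/* Runs the four rounds in the given order of b values, in place. */
static void transpose_by_rounds(bool m[256], const int order[4])
{
    bool tmp[256];
    for (int k = 0; k < 4; k++) {
        swap_index_bits(m, tmp, order[k]);
        for (int i = 0; i < 256; i++)
            m[i] = tmp[i];
    }
}

static bool is_transpose(const bool a[256], const bool t[256])
{
    for (int r = 0; r < 16; r++)
        for (int c = 0; c < 16; c++)
            if (a[r * 16 + c] != t[c * 16 + r])
                return false;
    return true;
}

static void check_any_order_transposes(const int order[4])
{
    bool a[256], m[256];
    for (int i = 0; i < 256; i++)
        a[i] = m[i] = ((i * 37) % 11) < 5;  /* arbitrary non-symmetric test matrix */
    transpose_by_rounds(m, order);
    assert(is_transpose(a, m));
}
```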


[1] While MiniSAT (which this program uses as its SAT solver) is a "deterministic solver" in the sense of definitely finding a satisfying valuation if there is one, it is not deterministic in the sense of guaranteeing that the same satisfying valuation is found every time the solver is run on the same input.

[2] Not the unique answer, there are multiple solutions.

[3] But not 512-bit vectors.

[4] Nice! It's not common to see a 256-bit VPALIGNR being useful, due to it not being the natural widening of 128-bit PALIGNR, but acting more like two PALIGNRs side-by-side (with the same shifting distance).

[5] Intel, why do you keep doing this.

[6] Also as an excuse to use PSHUFB for everything.