Tuesday 28 May 2024

Implementing grevmul with GF2P8AFFINEQB

As a reminder, grev is generalized bit-reversal, and performs a bit-permutation that corresponds to XORing the indices of the bits of the left operand by the number given in the right operand. A possible (not very suitable for actual use, but illustrative) implementation of grev is:

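import numpy as np

def grev(bits, k):
    # bits: NumPy array of bits whose length is a power of two, with 0 <= k < len(bits).
    # Bit i of the result is bit i ^ k of the input.
    return bits[np.arange(len(bits)) ^ k]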
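The array version above is for illustration. On a 64-bit integer, grev is usually computed with a butterfly: stage s swaps neighbouring blocks of 2**s bits whenever bit s of k is set. Below is a minimal Python sketch that checks the butterfly against the index-XOR definition. The names grev64, grev_reference and MASKS are illustrative and do not come from the post.

```python
# Illustrative sketch: grev64 / grev_reference are hypothetical names, not from the post.

# Masks selecting the lower half of every block of 2**s bits, for s = 0..5.
MASKS = [
    0x5555555555555555,
    0x3333333333333333,
    0x0F0F0F0F0F0F0F0F,
    0x00FF00FF00FF00FF,
    0x0000FFFF0000FFFF,
    0x00000000FFFFFFFF,
]


def grev64(x: int, k: int) -> int:
    """grev of a 64-bit integer: bit p of the result is bit p ^ k of x (0 <= k < 64)."""
    for s, mask in enumerate(MASKS):
        if (k >> s) & 1:
            shift = 1 << s
            # Swap every lower half-block with the upper half-block next to it.
            x = ((x & mask) << shift) | ((x >> shift) & mask)
    return x


def grev_reference(x: int, k: int) -> int:
    """Straight from the index-XOR definition, on a list of 64 bits."""
    bits = [(x >> i) & 1 for i in range(64)]
    return sum(bits[i ^ k] << i for i in range(64))
```

For example, grev64(0xFF, 8) is 0xFF00, because a grev by 8 swaps neighbouring bytes.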

grevmul (see also this older post where I list some of its properties) can be defined in terms of grev (hence the name[1]), with grev replacing the left shift in a carryless multiplication. But let's do something else first. An interesting way to look at multiplication (plain old multiplication between natural numbers) is as:

  1. Form the Cartesian product of the inputs, viewed as arrays of bits.
  2. For each pair in the Cartesian product, compute the AND of the two bits.
  3. Send the AND of each pair with index (i, j) to the bin i + j, to be accumulated with the function (+).
  4. Resolve the carries.

Carryless multiplication works mostly the same way except that accumulation is done with XOR, which makes step 4 unnecessary. grevmul also works mostly the same way, with accumulation also done with XOR, but now the pair with index (i, j) is sent to the bin i XOR j.

                 accumulate with +    accumulate with XOR
bin i + j        imul                 clmul
bin i XOR j      ???                  grevmul
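All four cells of that table fit into one small function. Here is a minimal Python sketch that treats plain integers as bit arrays. cartesian_mul and its parameters are illustrative names. The final weighted sum plays the role of step 4: it resolves imul's carries, and for XOR-accumulated bins (which only ever hold 0 or 1) it simply packs the bits.

```python
# Illustrative sketch of the Cartesian-product view; names are hypothetical.
import operator


def cartesian_mul(a, b, width, combine, accumulate):
    """Send a_i & b_j to bin combine(i, j), accumulate each bin with
    `accumulate`, then weigh bin k by 2**k (this resolves imul's carries)."""
    bins = {}
    for i in range(width):
        for j in range(width):
            bit = (a >> i) & (b >> j) & 1
            k = combine(i, j)
            bins[k] = accumulate(bins.get(k, 0), bit)
    return sum(v << k for k, v in bins.items())


def imul(a, b, width=8):
    return cartesian_mul(a, b, width, operator.add, operator.add)


def clmul(a, b, width=8):
    return cartesian_mul(a, b, width, operator.add, operator.xor)


def grevmul(a, b, width=8):
    return cartesian_mul(a, b, width, operator.xor, operator.xor)
```

The remaining cell, cartesian_mul(a, b, width, operator.xor, operator.add), is the "???" entry. This sketch does not give it a name either.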

That won't be important for the implementation, but it may help you think about what grevmul does, and this will be on the test. OK there is no test, but you can test yourself by explaining why (popcnt(a & b) & 1) == (grevmul(a, b) & 1), based on reasoning about the Cartesian-product-algorithm for grevmul.
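Before explaining the identity, you can at least check it by brute force. Below is a small Python sketch that writes grevmul directly from the Cartesian-product description and tries all pairs of 6-bit values. The function names are illustrative.

```python
# Illustrative brute-force check; function names are hypothetical.


def grevmul(a, b, width):
    """Cartesian-product grevmul: a_i & b_j is XORed into bit i ^ j."""
    result = 0
    for i in range(width):
        for j in range(width):
            result ^= ((a >> i) & (b >> j) & 1) << (i ^ j)
    return result


def popcnt(x):
    return bin(x).count("1")


# Exhaustive over all pairs of 6-bit values (small enough to run instantly).
mismatches = [
    (a, b)
    for a in range(64)
    for b in range(64)
    if (popcnt(a & b) & 1) != (grevmul(a, b, 6) & 1)
]
print(len(mismatches))  # prints 0
```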

Implementing grevmul with GF2P8AFFINEQB

Some of you may have already seen a version of the code that I am about to discuss, although I have made some changes since. Nothing serious, but while thinking about how the code worked, I encountered some opportunities to polish it up.

GF2P8AFFINEQB computes, for each QWORD (so for now, let's concentrate on one QWORD of the result), the product of two bit-matrices, where the left matrix comes from the first input and the right matrix is the transpose-flip (or flip-transpose, however you prefer to think about it) of the second input. You can think of it as a mutant of the mxor operation found in Knuth's MMIX instruction set, and from here on I will use the name bmatxor for the version of this operation that simply multiplies two 8x8 bit-matrices, with none of that transpose-flip business[2]. There is also a free XOR by a constant byte thrown in at the end, which can be safely ignored by setting it to zero. The transpose-flip, however, needs to be worked around or otherwise taken into account. You may also want to read (Ab)using gf2p8affineqb to turn indices into bits first for some extra familiarity with / alternate view of GF2P8AFFINEQB.
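To pin down the transpose-flip, here is a plain-Python model of one QWORD of GF2P8AFFINEQB, following the pseudocode in Intel's SDM. Bit i of each output byte is the parity of byte 7 - i of the matrix ANDed with the input byte, XORed with bit i of the immediate. Next to it is a bmatxor that treats byte i as row i.

- The function names are illustrative.
- The bmatxor row/bit convention (bit k of row i of A selects row k of B) is an assumption.

With that convention, GF2P8AFFINEQB(x, A) works out to bmatxor(x, transpose_flip(A)).

```python
# Illustrative models; names and the bmatxor row/bit convention are assumptions.


def parity(x: int) -> int:
    return bin(x).count("1") & 1


def gf2p8affine_qword(x: int, A: int, imm: int = 0) -> int:
    """One QWORD of GF2P8AFFINEQB: bit i of each output byte is
    parity(A.byte[7 - i] & input byte) ^ imm.bit[i]."""
    out = 0
    for c in range(8):
        xb = (x >> (8 * c)) & 0xFF
        for i in range(8):
            row = (A >> (8 * (7 - i))) & 0xFF
            out |= (parity(row & xb) ^ ((imm >> i) & 1)) << (8 * c + i)
    return out


def bmatxor(A: int, B: int) -> int:
    """8x8 bit-matrix product over GF(2) with byte i as row i: row i of the
    result is the XOR of the rows B.byte[k] for which bit k of A.byte[i] is set."""
    out = 0
    for i in range(8):
        a_row = (A >> (8 * i)) & 0xFF
        row = 0
        for k in range(8):
            if (a_row >> k) & 1:
                row ^= (B >> (8 * k)) & 0xFF
        out |= row << (8 * i)
    return out


def transpose_flip(A: int) -> int:
    """Bit i of row k of the result is bit k of row 7 - i of A."""
    out = 0
    for k in range(8):
        for i in range(8):
            out |= ((A >> (8 * (7 - i) + k)) & 1) << (8 * k + i)
    return out
```

In this model, 0x0102040810204080 acts as the identity matrix for GF2P8AFFINEQB, while 0x8040201008040201 is the identity for bmatxor.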

To implement grevmul(a, b), we need some linear combination (controlled by b) of permuted (by grev) versions of a (the roles of a and b can be swapped since grevmul is commutative, which the Cartesian product-based algorithm makes clear). GF2P8AFFINEQB is all about linear combinations, but it works with 8x8 matrices, not 64x64 (put that in AVX-4096) which would have made it really simple. Fortunately, we can slice a grevmul however we want.

Now to start in the middle, let's say we have 8 copies of a byte (a byte of b), with the i'th copy grev'ed by i, concatenated into a QWORD that I will call m. bmatxor(a, m) would (thinking of a matrix multiplication AB as forming linear combinations of rows of B), for each row of the result, form a linear combination (controlled by a) of grev'ed copies of the byte from b. That may seem like it would be wrong, since every byte of a is done separately and uses the same m, so it's "missing" a grev by 8 for the second byte, 16 for the third byte, etc. But if x is a byte (not if and only if, just regular "if"), then grev(x, 8 * i) is the same as x << (8 * i), and the second byte of the result is already in the second position anyway, so we get this for free. Thus, bmatxor(a, m) would allow us to grevmul a 64-bit number by an 8-bit number. If we could just do that 8 times (once for each byte of b) and combine the results, we'd be done.
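You can check directly the claim that bmatxor(a, m) grevmuls a 64-bit a by a single byte. Below is a minimal Python sketch. It assumes a bmatxor in which byte i is row i and bit k of row i of A selects row k of B. grev, grevmul, bmatxor and make_m are illustrative helpers, not code from the post.

```python
# Illustrative sketch; helper names and the bmatxor convention are assumptions.


def grev(x, k, width):
    """Bit p of the result is bit p ^ k of x (k < width, width a power of two)."""
    return sum(((x >> (p ^ k)) & 1) << p for p in range(width))


def grevmul(a, b, width=64):
    """grev-based carryless product: XOR of grev(a, j) for every set bit j of b."""
    result = 0
    for j in range(width):
        if (b >> j) & 1:
            result ^= grev(a, j, width)
    return result


def bmatxor(A, B):
    """Row i of the result is the XOR of rows B.byte[k] selected by bit k of A.byte[i]."""
    out = 0
    for i in range(8):
        a_row = (A >> (8 * i)) & 0xFF
        row = 0
        for k in range(8):
            if (a_row >> k) & 1:
                row ^= (B >> (8 * k)) & 0xFF
        out |= row << (8 * i)
    return out


def make_m(byte):
    """QWORD whose byte i is `byte` grev'ed by i, for i = 0..7."""
    return sum(grev(byte, i, 8) << (8 * i) for i in range(8))
```

As a small sanity check, make_m(1) is 0x8040201008040201, the identity matrix for this bmatxor, and grevmul-ing by 1 is also the identity.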

But we don't have bmatxor, we have GF2P8AFFINEQB with its built-in transpose-flip, and that presents a choice: put another GF2P8AFFINEQB either before or after the "main" GF2P8AFFINEQB to counter that built-in transpose. Not the whole transpose-flip; let's put the flip aside for now. There is a small reason to favour the "extra GF2P8AFFINEQB after the main one" order, namely that it results in GF2P8AFFINEQB(m, broadcast(a)) (as opposed to GF2P8AFFINEQB(broadcast(a), m)), and when a comes from memory it can be loaded and broadcasted directly with a {1to8} broadcasted memory operand. That option would not be available if a were the first operand of the GF2P8AFFINEQB. This is a small matter, but we have to make the choice somehow, and there seems to be no difference aside from this.

At this point there are two (and a half) pieces of the puzzle left: forming m, and horizontally combining 8 results.

Forming m

If we first formed 8 copies of a byte of b and then tried to grev those by their respective indices, that would be hard. But doing it the other way around is easy: broadcast b into each QWORD of a 512-bit vector, grev each QWORD by its index, then transpose the whole vector as an 8x8 matrix of bytes. Actually, for constant reuse (loading fewer humongous vector constants is rarely bad) and because of the built-in transpose-flip, it turns out slightly better to transpose-flip that 8x8 matrix of bytes instead of merely transposing it, which doesn't cost any more.
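You can model forming m for all eight bytes of b at once on a list of eight QWORDs. In this Python sketch (helper names are illustrative), QWORD r of the result holds byte r of b in every byte position c, grev'ed by 7 - c. That is the m described earlier with its bytes in reverse order, which is what the flip part of the transpose-flip contributes.

```python
# Illustrative model of forming m; helper names are hypothetical.


def grev(x, k, width=64):
    """Bit p of the result is bit p ^ k of x."""
    return sum(((x >> (p ^ k)) & 1) << p for p in range(width))


def get_byte(q, c):
    return (q >> (8 * c)) & 0xFF


def transpose_flip_bytes(v):
    """8x8 byte transpose-flip of eight QWORDs: byte c of QWORD r <- byte r of QWORD 7 - c."""
    return [sum(get_byte(v[7 - c], r) << (8 * c) for c in range(8)) for r in range(8)]


def form_m(b):
    """Broadcast b, grev QWORD i by i, then transpose-flip the byte matrix."""
    v = [grev(b, i) for i in range(8)]
    return transpose_flip_bytes(v)
```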

grev-ing each QWORD by its index is easy, perhaps easier than it sounds. A grev by 0..7 only rearranges the bits within each byte, which is easy to do with a single GF2P8AFFINEQB with a constant as the second operand (a "P" step).
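The constants for that step are grev(id, i) for i = 0..7, where id = 0x0102040810204080 is the identity matrix for GF2P8AFFINEQB. The Python check below confirms that GF2P8AFFINEQB with grev(id, i) as the matrix operand grevs every byte by i. It uses a model of GF2P8AFFINEQB written from the SDM pseudocode, and the names are illustrative.

```python
# Illustrative check; GF2P8AFFINEQB is modelled, names are hypothetical.

ID = 0x0102040810204080  # the `id` constant from the code at the end of the post


def grev(x, k, width=64):
    """Bit p of the result is bit p ^ k of x."""
    return sum(((x >> (p ^ k)) & 1) << p for p in range(width))


def parity(x):
    return bin(x).count("1") & 1


def gf2p8affine_qword(x, A, imm=0):
    """One QWORD of GF2P8AFFINEQB: bit i of each output byte is
    parity(A.byte[7 - i] & input byte) ^ imm.bit[i]."""
    out = 0
    for c in range(8):
        xb = (x >> (8 * c)) & 0xFF
        for i in range(8):
            row = (A >> (8 * (7 - i))) & 0xFF
            out |= (parity(row & xb) ^ ((imm >> i) & 1)) << (8 * c + i)
    return out


grev_by_index = [grev(ID, i) for i in range(8)]
print([hex(q) for q in grev_by_index])
```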

An 8x8 transpose-flip is just a VPERMB by some index vector.[3]
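The index vector follows from the transpose-flip itself: byte c of QWORD r takes source byte 8 * (7 - c) + r. The Python sketch below generates it, prints it as an 8x8 table, and models VPERMB as dst[j] = src[idx[j] & 63]. The names are illustrative.

```python
# Illustrative sketch; names are hypothetical.


def tp_flip_index():
    """VPERMB indices for an 8x8 byte transpose-flip:
    byte c of QWORD r takes source byte 8 * (7 - c) + r."""
    return [8 * (7 - c) + r for r in range(8) for c in range(8)]


def vpermb(idx, src):
    """Model of VPERMB on 64 bytes: dst[j] = src[idx[j] & 63]."""
    return [src[i & 63] for i in idx]


idx = tp_flip_index()
for r in range(8):
    print(idx[8 * r: 8 * r + 8])
```

Applying the permutation twice reverses all 64 bytes, so four applications give back the original order. That makes a handy sanity check.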

Combining the results

After the "middle" step, if we went with the "extra GF2P8AFFINEQB before the main GF2P8AFFINEQB" route, we would have bmatxor(a, m) in each QWORD (with a different m per QWORD), and those eight QWORDs need to be combined. If we were implementing a plain old integer multiplication, the value in the QWORD with index i would be shifted left by 8 * i and then all of them would be summed. Since we're implementing a grevmul, that value is instead grev'ed by 8 * i (which is just a byte permutation) and the resulting QWORDs are XORed together.
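In both cases the recombination can be written out per byte of b. Below is a Python sketch (illustrative names) of the integer version and the grevmul version, each checked against the direct product.

```python
# Illustrative sketch; names are hypothetical.


def grev(x, k, width=64):
    """Bit p of the result is bit p ^ k of x."""
    return sum(((x >> (p ^ k)) & 1) << p for p in range(width))


def grevmul(a, b, width=64):
    """XOR of grev(a, j) for every set bit j of b."""
    result = 0
    for j in range(width):
        if (b >> j) & 1:
            result ^= grev(a, j, width)
    return result


def imul_by_bytes(a, b):
    """Plain multiplication: shift the partial product of byte i left by 8*i and add."""
    return sum((a * ((b >> (8 * i)) & 0xFF)) << (8 * i) for i in range(8))


def grevmul_by_bytes(a, b):
    """grevmul: grev the partial product of byte i by 8*i and XOR."""
    result = 0
    for i in range(8):
        result ^= grev(grevmul(a, (b >> (8 * i)) & 0xFF), 8 * i)
    return result
```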

If we go with the "extra GF2P8AFFINEQB after the main GF2P8AFFINEQB" route, which I chose, then there is one more GF2P8AFFINEQB to do before we can start combining QWORDs. Strictly speaking, we only need it to un-transpose the result of the "main" GF2P8AFFINEQB. Any remaining byte-permutation could be folded into the VPERMB that we're about to do anyway (if we choose the cool way of XORing the QWORDs together). But there is a neat opportunity here: if we re-use the same set of bit-matrices that was used to grev b by 0..7, then in addition to the transpose that we wanted, the bytes also get permuted such that each QWORD i is grev'ed by 8 * i as a bonus.
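You can check the effect of the extra GF2P8AFFINEQB on a single QWORD r. This is a Python sketch, with illustrative names and GF2P8AFFINEQB modelled from the SDM pseudocode:

- m_r holds a byte of b grev'ed by 7 - c in byte c, as produced by the transpose-flip.
- After the two GF2P8AFFINEQBs, the result is grevmul(a, byte) grev'ed by 8 * r.

Part of why this works out is that the 8x8 matrix of grev'ed copies of a byte is symmetric (bit j of row k is bit j XOR k of the byte), so being transposed does not change it.

```python
# Illustrative check of one QWORD; names are hypothetical, GF2P8AFFINEQB is modelled.

ID = 0x0102040810204080


def grev(x, k, width=64):
    """Bit p of the result is bit p ^ k of x."""
    return sum(((x >> (p ^ k)) & 1) << p for p in range(width))


def grevmul(a, b, width=64):
    """XOR of grev(a, j) for every set bit j of b."""
    result = 0
    for j in range(width):
        if (b >> j) & 1:
            result ^= grev(a, j, width)
    return result


def parity(x):
    return bin(x).count("1") & 1


def gf2p8affine_qword(x, A):
    """One QWORD of GF2P8AFFINEQB with imm8 = 0."""
    out = 0
    for c in range(8):
        xb = (x >> (8 * c)) & 0xFF
        for i in range(8):
            out |= parity(((A >> (8 * (7 - i))) & 0xFF) & xb) << (8 * c + i)
    return out


def main_then_untranspose(a, b_byte, r):
    """m_r, the main GF2P8AFFINEQB, then the extra one with grev(ID, r) as data."""
    m_r = sum(grev(b_byte, 7 - c, 8) << (8 * c) for c in range(8))
    t = gf2p8affine_qword(m_r, a)
    return gf2p8affine_qword(grev(ID, r), t)
```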

Now we could just XOR-fold the vector 3 times to end up with the XOR of all eight QWORDs, and that would be a totally valid implementation, but it would also be boring. Alternatively, if we transpose the vector as an 8x8 matrix of bytes again (it can be a transpose-flip, so we get to reuse the same index vector as before), then the i'th byte of each QWORD would be gathered in the i'th QWORD of the transpose and bmatxor(0xFF, vector) would XOR together all bytes with the same index and give us a vector that has one byte of the final result per QWORD, easily extractable with VPMOVQB. We still don't have bmatxor though, we have bmatxor with an extra transpose-flip, which can be countered the usual way, with yet another GF2P8AFFINEQB.
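Both ways of combining can be modelled in Python on a list of eight QWORDs. The names are illustrative, and bmatxor again assumes that byte i is row i and that bit k of row i of A selects row k of B. With 0xFF as the left matrix, only row 0 is non-zero and it selects every row, so bmatxor(0xFF, q) leaves the XOR of all bytes of q in byte 0. Keeping byte 0 of each QWORD is what VPMOVQB does.

```python
# Illustrative models; names and the bmatxor convention are assumptions.


def bmatxor(A, B):
    """Row i of the result is the XOR of rows B.byte[k] selected by bit k of A.byte[i]."""
    out = 0
    for i in range(8):
        a_row = (A >> (8 * i)) & 0xFF
        row = 0
        for k in range(8):
            if (a_row >> k) & 1:
                row ^= (B >> (8 * k)) & 0xFF
        out |= row << (8 * i)
    return out


def get_byte(q, c):
    return (q >> (8 * c)) & 0xFF


def transpose_flip_bytes(v):
    """Byte c of QWORD r <- byte r of QWORD 7 - c."""
    return [sum(get_byte(v[7 - c], r) << (8 * c) for c in range(8)) for r in range(8)]


def xor_fold(v):
    """The 'boring' way: fold 8 -> 4 -> 2 -> 1 QWORDs with XOR."""
    while len(v) > 1:
        half = len(v) // 2
        v = [x ^ y for x, y in zip(v[:half], v[half:])]
    return v[0]


def xor_via_transpose(v):
    t = transpose_flip_bytes(v)        # QWORD r now holds byte r of every input QWORD
    t = [bmatxor(0xFF, q) for q in t]  # byte 0 of each QWORD = XOR of its 8 bytes
    return sum((q & 0xFF) << (8 * r) for r, q in enumerate(t))  # like VPMOVQB
```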

As yet another alternative[4], after similar transpose trickery bmatxor(vector, 0xFF) would also XOR together bytes with the same index but leave the result in a form that can be extracted with VPMOVB2M, which puts the result in a mask register but it's not too bad to move it from there to a GPR.
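The last GF2P8AFFINEQB in the code below uses 0xFF as its matrix operand. The Python model below covers one QWORD, with illustrative names and GF2P8AFFINEQB taken from the SDM pseudocode. That step leaves the parity of each input byte in the byte's top bit and zeroes everything else. That top bit is exactly what VPMOVB2M (_mm512_movepi8_mask) collects.

```python
# Illustrative model; names are hypothetical, GF2P8AFFINEQB is modelled.


def parity(x):
    return bin(x).count("1") & 1


def gf2p8affine_qword(x, A):
    """One QWORD of GF2P8AFFINEQB with imm8 = 0."""
    out = 0
    for c in range(8):
        xb = (x >> (8 * c)) & 0xFF
        for i in range(8):
            out |= parity(((A >> (8 * (7 - i))) & 0xFF) & xb) << (8 * c + i)
    return out


def top_bits(q):
    """Like VPMOVB2M restricted to one QWORD: bit c of the result is bit 7 of byte c."""
    return sum(((q >> (8 * c + 7)) & 1) << c for c in range(8))
```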

The code

In case anyone makes it this far, here is one possible embodiment[5] of the algorithm described herein:

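#include <immintrin.h>
#include <stdint.h>

// Scalar 64-bit grev: bit p of the result is bit p ^ k of x.
static uint64_t grev(uint64_t x, unsigned k)
{
    if (k & 1)  x = ((x & 0x5555555555555555) << 1)  | ((x >> 1)  & 0x5555555555555555);
    if (k & 2)  x = ((x & 0x3333333333333333) << 2)  | ((x >> 2)  & 0x3333333333333333);
    if (k & 4)  x = ((x & 0x0F0F0F0F0F0F0F0F) << 4)  | ((x >> 4)  & 0x0F0F0F0F0F0F0F0F);
    if (k & 8)  x = ((x & 0x00FF00FF00FF00FF) << 8)  | ((x >> 8)  & 0x00FF00FF00FF00FF);
    if (k & 16) x = ((x & 0x0000FFFF0000FFFF) << 16) | ((x >> 16) & 0x0000FFFF0000FFFF);
    if (k & 32) x = (x << 32) | (x >> 32);
    return x;
}

// Requires AVX-512F, AVX-512BW, AVX-512 VBMI (for VPERMB) and GFNI.
uint64_t grevmul_avx512(uint64_t a, uint64_t b)
{
    // id is the identity matrix for GF2P8AFFINEQB; using grev(id, i) as the
    // matrix operand grevs every byte by i.
    uint64_t id = 0x0102040810204080;
    __m512i grev_by_index = _mm512_setr_epi64(
        grev(id, 0), grev(id, 1), grev(id, 2), grev(id, 3),
        grev(id, 4), grev(id, 5), grev(id, 6), grev(id, 7));
    // VPERMB indices for an 8x8 byte transpose-flip, 8 bytes per QWORD:
    // byte c of QWORD r is 8 * (7 - c) + r, i.e. 56+r, 48+r, ..., 8+r, 0+r.
    __m512i tp_flip = _mm512_setr_epi64(
        0x0008101820283038, 0x0109111921293139,
        0x020A121A222A323A, 0x030B131B232B333B,
        0x040C141C242C343C, 0x050D151D252D353D,
        0x060E161E262E363E, 0x070F171F272F373F);
    // Form m: broadcast b, grev QWORD i by i, transpose-flip the byte matrix.
    __m512i m = _mm512_set1_epi64(b);
    m = _mm512_gf2p8affine_epi64_epi8(m, grev_by_index, 0);
    m = _mm512_permutexvar_epi8(tp_flip, m);
    // Main step, then undo its transpose; reusing grev_by_index also grevs
    // QWORD i by 8 * i, so the XOR of all eight QWORDs is grevmul(a, b).
    __m512i t512 = _mm512_gf2p8affine_epi64_epi8(m, _mm512_set1_epi64(a), 0);
    t512 = _mm512_gf2p8affine_epi64_epi8(grev_by_index, t512, 0);
    // Gather byte i of every QWORD into QWORD i, then transpose-flip each
    // QWORD as a bit-matrix (GF2P8AFFINEQB with the identity as first operand).
    t512 = _mm512_permutexvar_epi8(tp_flip, t512);
    t512 = _mm512_gf2p8affine_epi64_epi8(_mm512_set1_epi64(0x8040201008040201), t512, 0);
    // Parity of each byte into its top bit, then collect those 64 bits.
    t512 = _mm512_gf2p8affine_epi64_epi8(t512, _mm512_set1_epi64(0xFF), 0);
    return _mm512_movepi8_mask(t512);
}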
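Without AVX-512 hardware at hand, you can still check the intrinsic sequence with a slow plain-Python emulation on lists of eight QWORDs. The instruction models and all function names are illustrative, written for this check rather than taken from the post:

- GF2P8AFFINEQB follows the SDM pseudocode with imm8 = 0.
- VPERMB is modelled as dst[j] = src[idx[j] & 63].
- VPMOVB2M collects the top bit of every byte.

The reference grevmul XORs grev(a, j) over the set bits j of b.

```python
# Illustrative emulation of the intrinsic sequence; all names are hypothetical models.

ID = 0x0102040810204080


def grev(x, k, width=64):
    """Bit p of the result is bit p ^ k of x."""
    return sum(((x >> (p ^ k)) & 1) << p for p in range(width))


def grevmul(a, b):
    """Reference: XOR of grev(a, j) over the set bits j of b."""
    result = 0
    for j in range(64):
        if (b >> j) & 1:
            result ^= grev(a, j)
    return result


def parity(x):
    return bin(x).count("1") & 1


def affine_qword(x, A):
    """One QWORD of GF2P8AFFINEQB with imm8 = 0."""
    out = 0
    for c in range(8):
        xb = (x >> (8 * c)) & 0xFF
        for i in range(8):
            out |= parity(((A >> (8 * (7 - i))) & 0xFF) & xb) << (8 * c + i)
    return out


def gf2p8affine(xs, As):
    """Like _mm512_gf2p8affine_epi64_epi8(xs, As, 0) on lists of eight QWORDs."""
    return [affine_qword(x, A) for x, A in zip(xs, As)]


def permutexvar_epi8(idx, v):
    """Like _mm512_permutexvar_epi8: byte j of the result is byte idx[j] & 63 of v."""
    src = [(v[j // 8] >> (8 * (j % 8))) & 0xFF for j in range(64)]
    dst = [src[i & 63] for i in idx]
    return [sum(dst[8 * r + c] << (8 * c) for c in range(8)) for r in range(8)]


def movepi8_mask(v):
    """Like _mm512_movepi8_mask: bit j of the result is the top bit of byte j."""
    return sum(((v[j // 8] >> (8 * (j % 8) + 7)) & 1) << j for j in range(64))


def set1(q):
    return [q] * 8


def grevmul_avx512_model(a, b):
    grev_by_index = [grev(ID, i) for i in range(8)]
    tp_flip = [8 * (7 - c) + r for r in range(8) for c in range(8)]
    m = gf2p8affine(set1(b), grev_by_index)
    m = permutexvar_epi8(tp_flip, m)
    t = gf2p8affine(m, set1(a))
    t = gf2p8affine(grev_by_index, t)
    t = permutexvar_epi8(tp_flip, t)
    t = gf2p8affine(set1(0x8040201008040201), t)
    t = gf2p8affine(t, set1(0xFF))
    return movepi8_mask(t)
```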

The end

As far as I know, no one really uses grevmul for anything, so being able to compute it somewhat efficiently (more efficiently than a naive scalar solution at least) is not immediately useful. On the other hand, if an operation is not known to be efficiently computable, that may preclude its use. But the point of this post is more to show something neat.

Originally I had found the sequence of _mm512_gf2p8affine_epi64_epi8(m, _mm512_set1_epi64(a), 0) and _mm512_gf2p8affine_epi64_epi8(grev_by_index, t512, 0) as a solution to grevmul-ing a QWORD by a (constant) byte, using a SAT solver (shout-out to togasat for being easy to add to a project, though admittedly I eventually bit the bullet and switched to MiniSAT). That formed the starting point of this investigation / puzzle-solving session. It may be possible to pull a more complete solution out of the void with a SAT/SMT-based technique such as CEGIS; perhaps the bitwise-bilinear nature of grevmul can be exploited there (I used the bitwise-linear nature of grevmul-by-a-constant in my SAT experiments to represent the problem as a composition of matrices over GF(2)).

Almost half of the steps of this algorithm are some kind of transpose, which has also been the case with some other SIMD algorithms that I recently had a hand in. I used to think of a transpose as "not really doing anything", barely worth the notation when doing linear algebra, but maybe I was wrong.

[1] Maybe it's less "the name" and more "what I decided to call it". I'm not aware of any established name for this operation.
[2] This makes it sound more negative than it really is. The transpose-flip often needs to be worked around when we don't want it, but that's not that bad. Having no easy access to a transpose would be much worse to work around when we do need it. Separate bmatxor and bmattranspose instructions would have been nice.
[3] GF2P8AFFINEQB trickery is nice, but when I recently wrote some AVX2 code it was VPERMB that I missed the most.
[4] Can we stop with the alternatives and just pick something?
[5] No patents were read in the development of this algorithm, nor in the writing of this blog post.