Tuesday, 10 December 2024

Bit-permuting 16 u32s at once with AVX-512

The basic trick to apply the same bit-permutation to each of the u32s is to view them as a matrix of 16 rows by 32 columns, transpose it into 32 u16s, permute those u16s in the same way that we wanted to permute the bits of the u32s [1], then transpose back to 16 u32s. Easy:

__m512i permbits_16x32(__m512i data, __m512i indices)
{
    __m512i x = data;
    x = transpose_16_dwords_to_32_words(x);   // word c now holds bit c of every u32
    x = _mm512_permutexvar_epi16(indices, x); // word j becomes word indices[j]
    x = transpose_32_words_to_16_dwords(x);   // back to 16 u32s
    return x;
}
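To pin down what the sandwich computes, here is a portable scalar model of it. The helpers transpose_16x32, transpose_32x16 and permbits_16x32_scalar are hypothetical names for this sketch, not the AVX-512 routines. It assumes the 32 indices are given as bytes and that word j of the permuted vector is word indices[j] of the transposed one, mirroring _mm512_permutexvar_epi16. Under those assumptions, bit j of every output u32 is bit indices[j] of the corresponding input u32.

```c
#include <stdint.h>

/* Transpose 16 u32s (rows) into 32 u16s (columns):
   bit r of t[c] is bit c of in[r]. */
void transpose_16x32(const uint32_t in[16], uint16_t t[32])
{
    for (int c = 0; c < 32; c++) {
        uint16_t w = 0;
        for (int r = 0; r < 16; r++)
            w |= (uint16_t)(((in[r] >> c) & 1u) << r);
        t[c] = w;
    }
}

/* Inverse transpose: bit c of out[r] is bit r of t[c]. */
void transpose_32x16(const uint16_t t[32], uint32_t out[16])
{
    for (int r = 0; r < 16; r++) {
        uint32_t d = 0;
        for (int c = 0; c < 32; c++)
            d |= (uint32_t)((t[c] >> r) & 1u) << c;
        out[r] = d;
    }
}

/* Transpose, permute the u16 columns, transpose back:
   bit j of out[r] ends up as bit idx[j] of in[r]. */
void permbits_16x32_scalar(const uint32_t in[16], const uint8_t idx[32],
                           uint32_t out[16])
{
    uint16_t t[32], u[32];
    transpose_16x32(in, t);
    for (int j = 0; j < 32; j++)
        u[j] = t[idx[j] & 31];  /* word permutation, like vpermw */
    transpose_32x16(u, out);
}
```

For example, indices[j] = 31 - j reverses the bits of all 16 u32s at once.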

transpose_16_dwords_to_32_words and transpose_32_words_to_16_dwords may be tricky to implement by hand, but the solver at haroldbot.nl/avx512bpc.html can generate them. To transpose 16 dwords into 32 words, the 4 most significant bits of every bit-index (the 9-bit index of each bit in a 512-bit __m512i) should become the 4 least significant bits, and the 5 least significant bits of the index should become the most significant bits. In other words, the bit order of the bit-indices should become 5,6,7,8,0,1,2,3,4, listing for each bit of the new index, from least to most significant, which bit of the old index it comes from. For the inverse, it is 4,5,6,7,8,0,1,2,3.
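The bit-order notation can be checked mechanically. In this sketch, apply_bit_order is a hypothetical helper: bit k of the new index is taken from bit order[k] of the old one. With the two orders given above, bit c of dword r (index r*32 + c) lands at bit r of word c (index c*16 + r), and the second order undoes the first.

```c
/* Bit k of the result is bit order[k] of index (9-bit bit-indices). */
unsigned apply_bit_order(const int order[9], unsigned index)
{
    unsigned result = 0;
    for (int k = 0; k < 9; k++)
        result |= ((index >> order[k]) & 1u) << k;
    return result;
}

/* 16 dwords -> 32 words, and the inverse. */
const int to_words[9]  = {5, 6, 7, 8, 0, 1, 2, 3, 4};
const int to_dwords[9] = {4, 5, 6, 7, 8, 0, 1, 2, 3};
```

The same helper accepts the other orders used below, such as 5,6,7,0,1,2,3,4,8.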

Taking the solutions from that AVX512 BPC permute solver and using them as the bread of the transpose-shuffle-transpose sandwich works, but there is room for improvement: the sandwich ends up with 3 back-to-back permutes in the middle, and perhaps they can be merged somehow.

One solution I've found is to use the bit order 5,6,7,0,1,2,3,4,8 for the bit-indices. Keeping the most significant bit in place means the first "transpose" no longer needs a shuffle at the end. However, instead of shuffling each pair of adjacent bytes the same way (which a permutation of 16-bit elements accomplishes for free), we now need to shuffle the low and high halves of a 64-byte vector the same way. That requires some pre-processing of the index vector: duplicate a 32-byte vector into the top and bottom halves of a 64-byte vector, and add 32 to every byte in the top half.

The second not-quite-transpose in the new sandwich, corresponding to a bit order of 3,4,5,6,7,0,1,2,8, starts with a simple permute: byte-reverse every u64. This permute can be absorbed into the pre-processing of the index vector. In general that would not help, but when the same index vector is reused multiple times (probably a common usage pattern; I've already used it that way myself), moving that permute into the pre-processing step can hoist it out of the loop.
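A scalar sketch of that index pre-processing, with hypothetical helper names: permute_bytes models vpermb (_mm512_permutexvar_epi8), reverse_u64_index expresses "byte-reverse every u64" as a byte-permutation index, and preprocess_indices builds the 64-byte index vector from 32 byte indices in 0..31. Permuting each half with the duplicated index vector and then byte-reversing every u64 gives the same result as a single permute with the pre-processed vector, which is why the reverse can be moved out of the loop.

```c
#include <stdint.h>

/* Scalar model of vpermb (_mm512_permutexvar_epi8): out[j] = in[idx[j] & 63]. */
void permute_bytes(const uint8_t idx[64], const uint8_t in[64], uint8_t out[64])
{
    for (int j = 0; j < 64; j++)
        out[j] = in[idx[j] & 63];
}

/* "Byte-reverse every u64", written as a byte-permutation index. */
void reverse_u64_index(uint8_t rev[64])
{
    for (int j = 0; j < 64; j++)
        rev[j] = (uint8_t)((j & ~7) | (7 - (j & 7)));
}

/* Pre-process 32 byte indices (0..31) into a 64-byte index vector:
   byte-reverse within each u64 (absorbing the permute that starts the second
   not-quite-transpose), duplicate into both halves, add 32 in the top half. */
void preprocess_indices(const uint8_t idx32[32], uint8_t q[64])
{
    for (int j = 0; j < 64; j++) {
        int k = j & 31;                          /* position within the half */
        int src = (k & ~7) | (7 - (k & 7));      /* byte-reversed within its u64 */
        q[j] = (uint8_t)(idx32[src] + (j & 32)); /* +32 for the top half */
    }
}
```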

The code

Putting it all together, permbits_16x32 can be implemented like this (this version takes its 32 indices as bytes in an __m256i):

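#include <immintrin.h>
#include <stdint.h>

// Requires AVX512BW, AVX512_VBMI (vpermb) and GFNI.
// p: 64 byte indices, pre-processed as in permbits_16x32 below.
__m512i permbits_16x32_weirdindex(__m512i x, __m512i p) {
    // Bytes 0x01, 0x02, ..., 0x80 in every u64. Passed as the first operand,
    // gf2p8affine then performs an 8x8 bit transpose of each u64 of x, up to
    // a reversal of the row order that s1 and p account for.
    __m512i mID = _mm512_set1_epi64(0x8040201008040201);
    // Gather byte k of u32s 7..0 (resp. 15..8) into u64 k of each 256-bit half.
    __m512i s1 = _mm512_set_epi8(
        35, 39, 43, 47, 51, 55, 59, 63,
        34, 38, 42, 46, 50, 54, 58, 62,
        33, 37, 41, 45, 49, 53, 57, 61,
        32, 36, 40, 44, 48, 52, 56, 60,
        3, 7, 11, 15, 19, 23, 27, 31,
        2, 6, 10, 14, 18, 22, 26, 30,
        1, 5, 9, 13, 17, 21, 25, 29,
        0, 4, 8, 12, 16, 20, 24, 28);
    // Interleave the transposed bytes back into u32 order.
    __m512i s2 = _mm512_set_epi8(
        63, 55, 47, 39, 62, 54, 46, 38,
        61, 53, 45, 37, 60, 52, 44, 36,
        59, 51, 43, 35, 58, 50, 42, 34,
        57, 49, 41, 33, 56, 48, 40, 32,
        31, 23, 15, 7, 30, 22, 14, 6,
        29, 21, 13, 5, 28, 20, 12, 4,
        27, 19, 11, 3, 26, 18, 10, 2,
        25, 17, 9, 1, 24, 16, 8, 0);
    x = _mm512_permutexvar_epi8(s1, x);
    x = _mm512_gf2p8affine_epi64_epi8(mID, x, 0); // byte c of each half: bit c of its 8 u32s
    x = _mm512_permutexvar_epi8(p, x);            // the bit permutation, as a byte permutation
    x = _mm512_gf2p8affine_epi64_epi8(mID, x, 0); // transpose back (byte-reverse folded into p)
    x = _mm512_permutexvar_epi8(s2, x);
    return x;
}

// indices: 32 byte indices in 0..31. Bit j of every output u32 is
// bit indices[j] of the corresponding input u32.
__m512i permbits_16x32(__m512i x, __m256i indices)
{
    // Byte-reverse each u64 (absorbing the permute that starts the second
    // not-quite-transpose) and duplicate into both 256-bit halves.
    __m512i s1 = _mm512_set_epi8(
        24, 25, 26, 27, 28, 29, 30, 31,
        16, 17, 18, 19, 20, 21, 22, 23,
        8, 9, 10, 11, 12, 13, 14, 15,
        0, 1, 2, 3, 4, 5, 6, 7,
        24, 25, 26, 27, 28, 29, 30, 31,
        16, 17, 18, 19, 20, 21, 22, 23,
        8, 9, 10, 11, 12, 13, 14, 15,
        0, 1, 2, 3, 4, 5, 6, 7);
    // s1 only selects bytes 0..31, so the undefined upper half of the cast is never read.
    __m512i p = _mm512_permutexvar_epi8(s1, _mm512_castsi256_si512(indices));
    // Add 32 to every byte of the top half so it selects from the top half of x.
    uint64_t m = 0x2020202020202020;
    p = _mm512_add_epi8(p, _mm512_set_epi64(m, m, m, m, 0, 0, 0, 0));
    return permbits_16x32_weirdindex(x, p);
}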
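Since the intrinsics need AVX512_VBMI and GFNI hardware to run, one portable way to sanity-check the constants is to emulate each instruction in scalar C. The sketch below is a hypothetical test harness, not part of the original code: emu_permb models vpermb, emu_gf2p8affine models _mm512_gf2p8affine_epi64_epi8 with a zero constant (following Intel's documented pseudocode), and permbits_16x32_emulated rebuilds mID, s1, s2 and the pre-processed index vector from formulas that reproduce the tables above. It assumes the 16 u32s are stored little-endian, as they are in a vector register.

```c
#include <stdint.h>

/* Emulates _mm512_permutexvar_epi8: dst[j] = src[idx[j] & 63]. */
static void emu_permb(const uint8_t idx[64], const uint8_t src[64], uint8_t dst[64])
{
    for (int j = 0; j < 64; j++) dst[j] = src[idx[j] & 63];
}

/* Emulates _mm512_gf2p8affine_epi64_epi8(x, A, 0): bit i of result byte j is
   the parity of (byte 7-i of A's u64) AND (byte j of x). */
static void emu_gf2p8affine(const uint8_t x[64], const uint8_t A[64], uint8_t dst[64])
{
    for (int j = 0; j < 64; j++) {
        const uint8_t *qw = &A[j & ~7];
        uint8_t r = 0;
        for (int i = 0; i < 8; i++) {
            uint8_t t = qw[7 - i] & x[j];
            t ^= t >> 4; t ^= t >> 2; t ^= t >> 1;   /* parity in bit 0 */
            r |= (uint8_t)((t & 1) << i);
        }
        dst[j] = r;
    }
}

/* Mirrors permbits_16x32 + permbits_16x32_weirdindex step by step. */
void permbits_16x32_emulated(const uint32_t in[16], const uint8_t idx32[32],
                             uint32_t out[16])
{
    uint8_t mID[64], s1[64], s2[64], p[64], x[64], y[64];
    for (int j = 0; j < 64; j++) {
        int h = j & 32, k = (j >> 3) & 3, m = j & 7, kk = j & 31;
        mID[j] = (uint8_t)(1u << m);                      /* 0x8040201008040201 */
        s1[j] = (uint8_t)(h + 4 * (7 - m) + k);           /* byte k of u32 7-m */
        s2[j] = (uint8_t)(h + 8 * (j & 3) + (kk >> 2));   /* back to u32 order */
        p[j] = (uint8_t)(idx32[(kk & ~7) | (7 - (kk & 7))] + h);
    }
    for (int r = 0; r < 16; r++)
        for (int b = 0; b < 4; b++) x[4 * r + b] = (uint8_t)(in[r] >> (8 * b));
    emu_permb(s1, x, y);
    emu_gf2p8affine(mID, y, x);
    emu_permb(p, x, y);
    emu_gf2p8affine(mID, y, x);
    emu_permb(s2, x, y);
    for (int r = 0; r < 16; r++)
        out[r] = (uint32_t)y[4 * r] | (uint32_t)y[4 * r + 1] << 8 |
                 (uint32_t)y[4 * r + 2] << 16 | (uint32_t)y[4 * r + 3] << 24;
}
```

Comparing its output against a direct per-bit loop (bit j of out[r] = bit idx32[j] of in[r]) exercises every constant in the real code without needing the hardware.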

[1]: As a bonus, you can put other permutation-like operations between the transposes, such as vpcompressw to perform a pext on every u32 (but with a single mask shared by all 16 u32s), or even operations that are not permutation-like.
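The pext idea can be modelled in scalar C as well. pext_16x32_scalar is a hypothetical name for this sketch: it transposes the 16 u32s into 32 bit-columns, keeps only the columns selected by mask, packed toward column 0 with the rest zeroed, and transposes back. In the vector version that middle step would be a zero-masking word compress such as _mm512_maskz_compress_epi16 (AVX512_VBMI2) between the two transposes; this sketch only models its effect.

```c
#include <stdint.h>

/* Every out[r] becomes pext(in[r], mask), using one mask for all 16 u32s. */
void pext_16x32_scalar(const uint32_t in[16], uint32_t mask, uint32_t out[16])
{
    uint16_t cols[32] = {0};
    int n = 0;
    for (int c = 0; c < 32; c++) {
        if (!((mask >> c) & 1u)) continue;
        uint16_t w = 0;                    /* column c: bit r is bit c of in[r] */
        for (int r = 0; r < 16; r++)
            w |= (uint16_t)(((in[r] >> c) & 1u) << r);
        cols[n++] = w;                     /* compress: selected columns packed low */
    }
    for (int r = 0; r < 16; r++) {         /* transpose back; unused columns stay 0 */
        uint32_t d = 0;
        for (int c = 0; c < 32; c++)
            d |= (uint32_t)((cols[c] >> r) & 1u) << c;
        out[r] = d;
    }
}
```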