Thursday 4 April 2024

Enumerating all mathematical identities (in fixed-size bitvector arithmetic with a restricted set of operations) of a certain size

Once again a boring-but-specific title; I don't want to clickbait the audience, after all. Even so, let's get some "disclaimers" out of the way.

  • "All" includes both boring and interesting identities. I tried to remove the most boring ones, so it's no longer truly all identities, but a lot of boring ones still remain. The way I see it, the biggest problem with the approach described in this post is that it generates too much "true but boring" junk.
  • This approach is, as far as I know, strictly limited to fixed-size bitvectors, but that's what I'm interested in anyway. To keep things reasonable, the size should be small (4 bits here), which does result in some differences with, e.g., the arithmetic of 64-bit bitvectors. Most of the results either transfer directly to larger sizes or generalize to them.
  • The set of operations is restricted to those that are cheap to implement in CNF SAT. That is not a hard limitation but a practical one.
  • "Of a certain size" means we have to pick in advance the number of operations on both sides of the mathematical identity, and then only identities with exactly that number of operations (counted in the underlying representation, which may represent a seemingly larger expression if the result of an operation is used more than once) are found. This can be repeated for any size we're interested in, but this approach is not very scalable and tends to run out of memory if there are too many identities of the requested size.

The results look something like this. Let's say we want "all" identities that involve 2 variables, 2 operations on the left, and 0 operations (i.e. only a variable) on the right. The result would be the following (keep in mind that a bunch of redundant and boring identities are filtered out):

(a - (a - b)) == b
(a + (b - a)) == b
((a + b) - a) == b
(b | (a & b)) == b
(a ^ (a ^ b)) == b
(b & (a | b)) == b
// Done in 0s. Used a set of 3 inputs.

Nothing too interesting so far, but then we didn't ask for much. Here are a couple of selected "more interesting" (but not by any means new or unknown) identities that this approach can also enumerate:

((a & b) + (a | b)) == (a + b)
((a | b) - (a & b)) == (a ^ b)
((a ^ b) | (a & b)) == (a | b)
((a & b) ^ (a | b)) == (a ^ b)
((~ a) - (~ b)) == (b - a)
(~ (b + (~ a))) == (a - b)

Now that the expectations have been set accurately (hopefully), let's get into what the approach is.

The Approach

The core mechanism I use is CounterExample-Guided Inductive Synthesis (CEGIS) based on a SAT solver. Glucose worked well, but other solvers can be used. Rather than asking CEGIS to generate a snippet of code that performs some specific task, however, I ask it to generate two snippets that are equivalent. That does not fundamentally change how it operates, which is still a loop of:

  1. Synthesize code that does the right thing for each input in the set of inputs to check.
  2. Check whether the code matches the specification. If it does, we're done. If it doesn't, add the counter-example to the set of inputs to check.

Both synthesis and checking could be performed by a SAT solver, but I only use a SAT solver for synthesis. For checking, since 4-bit bitvectors have so few combinations, I just brute force every possible valuation of the variables.
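
The brute-force check is small enough to sketch directly. This is only an illustration and not the code from the repository: the names `Expr` and `findCounterExample`, and the `bits` parameter, are made up for this example. It tries every valuation of two variables and reports the first one on which the two sides differ, which is exactly what would be fed back into the set of inputs.

```cpp
#include <cassert>
#include <functional>

// Hypothetical helper type: an expression over two bitvector variables.
typedef std::function<unsigned(unsigned, unsigned)> Expr;

// Tries every valuation of (a, b) for the given width. Returns true and
// stores the first differing valuation if the two sides are not equivalent.
bool findCounterExample(const Expr& lhs, const Expr& rhs, unsigned bits,
                        unsigned& cexA, unsigned& cexB)
{
    assert(bits >= 1 && bits <= 8);   // brute force only makes sense for tiny widths
    const unsigned mask = (1u << bits) - 1;
    for (unsigned a = 0; a <= mask; a++)
        for (unsigned b = 0; b <= mask; b++)
            if ((lhs(a, b) & mask) != (rhs(a, b) & mask)) {
                cexA = a;
                cexB = b;
                return true;
            }
    return false;
}
```

With 4-bit variables this is at most 256 evaluations per candidate pair, so there is little reason to involve a solver for this step.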

When a pair of equivalent expressions has been found, I add its negation as a single clause to prevent the same thing from being synthesized again. This is what enables pulling out one identity after the other. In my imagination, that looks like Thor smashing his mug and asking for another.

Solving for programs may seem odd. The trick here is to represent a program as a sequence of instructions that are constructed out of boolean variables; the SAT solver is then invoked to solve for those variables.
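
To make the loop and the "program as boolean variables" idea concrete, here is a toy sketch. It is not the linked implementation: a linear scan over all candidates stands in for the SAT solver, the candidate shape is fixed to ((a op1 b) op2 (a op3 b)) == (a opR b), and all names (`encode`, `agrees`, `enumerateIdentities`, ...) are hypothetical. Each candidate is just 12 bits (four 3-bit opcode fields), a found identity is "blocked" by remembering it, and a refuted candidate contributes a counter-example to the input set.

```cpp
#include <cassert>
#include <set>
#include <vector>

struct Input { unsigned a, b; };

const unsigned kNumOps = 5;   // opcodes: 0:+  1:-  2:&  3:|  4:^

// Packs four 3-bit opcode fields into one 12-bit candidate.
unsigned encode(unsigned op1, unsigned op2, unsigned op3, unsigned opR)
{
    assert(op1 < kNumOps && op2 < kNumOps && op3 < kNumOps && opR < kNumOps);
    return op1 | (op2 << 3) | (op3 << 6) | (opR << 9);
}

// Opcodes 5..7 name no operation. A SAT encoding would forbid them with
// clauses; this sketch simply skips such candidates.
bool validEncoding(unsigned c)
{
    for (unsigned field = 0; field < 4; field++)
        if (((c >> (3 * field)) & 7) >= kNumOps)
            return false;
    return true;
}

// 4-bit arithmetic: results are reduced modulo 16.
unsigned applyOp(unsigned op, unsigned x, unsigned y)
{
    switch (op) {
    case 0: return (x + y) & 15;
    case 1: return (x - y) & 15;
    case 2: return x & y;
    case 3: return x | y;
    default: return x ^ y;
    }
}

// True if both sides of candidate c evaluate to the same value on this input.
bool agrees(unsigned c, Input in)
{
    unsigned left = applyOp((c >> 3) & 7,
                            applyOp(c & 7, in.a, in.b),
                            applyOp((c >> 6) & 7, in.a, in.b));
    return left == applyOp((c >> 9) & 7, in.a, in.b);
}

std::set<unsigned> enumerateIdentities(std::vector<Input>& inputs)
{
    std::set<unsigned> found;   // doubles as the list of blocked candidates
    for (;;) {
        // 1. "Synthesize": first unblocked candidate that is right on all inputs so far.
        int candidate = -1;
        for (unsigned c = 0; c < 4096 && candidate < 0; c++) {
            if (!validEncoding(c) || found.count(c) != 0)
                continue;
            bool ok = true;
            for (const Input& in : inputs)
                if (!agrees(c, in)) { ok = false; break; }
            if (ok)
                candidate = int(c);
        }
        if (candidate < 0)
            return found;       // nothing left: enumeration is complete
        // 2. Check exhaustively; a failure becomes a new input.
        bool refuted = false;
        for (unsigned a = 0; a < 16 && !refuted; a++)
            for (unsigned b = 0; b < 16 && !refuted; b++)
                if (!agrees(unsigned(candidate), Input{a, b})) {
                    inputs.push_back(Input{a, b});
                    refuted = true;
                }
        if (!refuted)
            found.insert(unsigned(candidate));
    }
}
```

Calling `enumerateIdentities` with an empty input vector returns the set of encodings that are identities, including plenty of boring ones, and leaves the accumulated counter-examples in the vector. The real thing differs mainly in that the solver, rather than a scan, picks the next candidate.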

The code is available on gitlab.

Original motivation

The examples I gave earlier only involve "normal" bitvector arithmetic. Originally what I set out to do is discover what sorts of mathematical identities are true in the context of trapping arithmetic (in which subtraction and addition trap on signed overflow), using the rule that two expressions are equivalent if and only if they have the same behaviour, in the following sense: for all valuations of the variables, the two expressions either yield the same value, or they both trap. That rule is also implemented in the linked source.

Many of the identities found in that context involve a trapping operation that can never actually trap. For example, the trapping subtraction (the t-suffix in -t indicates that it is the trapping version of subtraction) in (b & (~ a)) == (b -t (a & b)) cannot trap (boring to prove so I won't bother). But the "infamous" (among whom, maybe just me) (-t (-t (-t a))) == (-t a) is also enumerated: the sole remaining negation can trap, but it does so in exactly the same case as the original three-negations-in-a-row (namely when a is the bitvector with only its sign-bit set). Here is a small selection of nice identities that hold in trapping arithmetic:

((a & b) +t (a | b)) == (a +t b)
((a | b) -t (a & b)) == (a ^ b)
((~ b) -t (~ a)) == (a -t b)
(~ (b +t (~ a))) == (a -t b)
(a -t (a -t b)) == (a - (a -t b))  // note: one of the subtractions is a non-trapping subtraction
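
The "same value, or both trap" rule can be checked by brute force as well. The sketch below is an assumption-laden model rather than the linked source: values are 4-bit signed integers held in an `int`, a trap is a flag that propagates through every later operation, and the helper names (`Val`, `addT`, `subT`, `subWrap`, `sameBehaviour`, ...) are invented for the example.

```cpp
#include <cassert>
#include <functional>

// A 4-bit signed value in [-8, 7], or the "trapped" state.
struct Val {
    bool trapped;
    int value;
};

Val trap() { return Val{true, 0}; }
Val lit(int v) { assert(v >= -8 && v <= 7); return Val{false, v}; }

// Trapping ops: compute in a wider type, trap if the result does not fit.
Val checked(int wide) { return (wide < -8 || wide > 7) ? trap() : lit(wide); }
Val addT(Val x, Val y) { return (x.trapped || y.trapped) ? trap() : checked(x.value + y.value); }
Val subT(Val x, Val y) { return (x.trapped || y.trapped) ? trap() : checked(x.value - y.value); }
Val negT(Val x) { return subT(lit(0), x); }

// Non-trapping (wrapping) subtraction: reduce modulo 16 back into [-8, 7].
Val subWrap(Val x, Val y)
{
    if (x.trapped || y.trapped) return trap();
    int w = ((x.value - y.value) % 16 + 16) % 16;
    return lit(w >= 8 ? w - 16 : w);
}

// Bitwise ops on sign-extended values (assumes two's complement ints).
Val andV(Val x, Val y) { return (x.trapped || y.trapped) ? trap() : lit(x.value & y.value); }
Val orV(Val x, Val y) { return (x.trapped || y.trapped) ? trap() : lit(x.value | y.value); }
Val notV(Val x) { return x.trapped ? trap() : lit(~x.value); }

typedef std::function<Val(Val, Val)> TrapExpr;

// Equivalent iff, for every valuation, both trap or both yield the same value.
bool sameBehaviour(const TrapExpr& f, const TrapExpr& g)
{
    for (int a = -8; a <= 7; a++)
        for (int b = -8; b <= 7; b++) {
            Val x = f(lit(a), lit(b));
            Val y = g(lit(a), lit(b));
            if (x.trapped != y.trapped) return false;
            if (!x.trapped && x.value != y.value) return false;
        }
    return true;
}
```

Under this model, (a -t (a -t b)) is not equivalent to plain b even though the values agree whenever nothing traps: for a = 7 and b = -8 the inner subtraction traps, while b on its own cannot.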

Future directions

A large source of boring identities is the fact that if f(x) == g(x), then also f(x) + x == g(x) + x and f(x) & x == g(x) & x and so on, which causes "small" identities to show up again as part of larger ones, without introducing any new information, and multiplied in myriad ways. If there were a good way to prevent them from being enumerated (it would have to be sufficiently easy to state in terms of CNF SAT clauses, to avoid slowing down the solver too much), or to summarize the full output, that could make the output of the enumeration more human-digestible.

Saturday 9 March 2024

The solutions to 𝚙𝚘𝚙𝚌𝚗𝚝(𝚡) < 𝚝𝚣𝚌𝚗𝚝(𝚡) and why there are Fibonacci[n] of them below 2ⁿ

popcnt(x) < tzcnt(x) asks the question "does x have fewer set bits than it has trailing zeroes". It's a simple question with a simple answer, but cute enough to think about on a Sunday morning.[1]

Here are the solutions for 8 bits, in order: 0, 4, 8, 16, 24, 32, 40, 48, 64, 72, 80, 96, 112, 128, 136, 144, 160, 176, 192, 208, 224[2]

In case you find decimal hard to read (as I do), here they are again in binary: 00000000, 00000100, 00001000, 00010000, 00011000, 00100000, 00101000, 00110000, 01000000, 01001000, 01010000, 01100000, 01110000, 10000000, 10001000, 10010000, 10100000, 10110000, 11000000, 11010000, 11100000

Simply staring at the values doesn't do much for me. To get a better handle on what's going on, let's recursively (de-)construct the set of n-bit solutions.

The most significant bit of an n-bit solution is either 0 or 1:

  1. If it is 0, then that bit affects neither the popcnt nor the tzcnt so removing it must yield an (n-1)-bit solution.
  2. If it is 1, then removing it along with the least significant bit (which must be zero, there are no odd solutions since their tzcnt would be zero) would decrease both the popcnt and the tzcnt by 1, yielding an (n-2)-bit solution.

This "deconstructive" recursion is slightly awkward. The constructive version would be: you can take the (n-1)-bit solutions and prepend a zero to them, and you can take the (n-2)-bit solutions and prepend a one and append a zero to them. However, it is less clear then (to me anyway) that those are the only n-bit solutions. The "deconstructive" version starts with all n-bit solutions and splits them into two obviously-disjoint groups, removing the possibility of solutions getting lost or being counted double.

The F(n) = F(n - 1) + F(n - 2) structure of the number of solutions is clear, but there are different sequences that follow that same recurrence that differ in their base cases. Here we have 1 solution for 1-bit integers (namely zero) and 1 solution for 2-bit integers (also zero), so the base cases are 1 and 1 as in the Fibonacci sequence.
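
For the skeptical, the argument above is easy to check mechanically. The following sketch compares a brute-force scan with the constructive recursion and with the Fibonacci numbers. The function names are made up for this example, and tzcnt of zero is taken to be the width n, which is what makes zero a solution.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

unsigned popcnt(uint32_t x)
{
    unsigned count = 0;
    for (; x != 0; x &= x - 1)   // clear the lowest set bit each round
        count++;
    return count;
}

// Trailing zero count of an n-bit value; by convention tzcnt(0) == n.
unsigned tzcnt(uint32_t x, unsigned n)
{
    if (x == 0) return n;
    unsigned count = 0;
    while ((x & 1) == 0) { x >>= 1; count++; }
    return count;
}

// All n-bit x with popcnt(x) < tzcnt(x), in increasing order.
std::vector<uint32_t> bruteForce(unsigned n)
{
    assert(n <= 20);             // keep the scan short
    std::vector<uint32_t> out;
    for (uint32_t x = 0; x < (1u << n); x++)
        if (popcnt(x) < tzcnt(x, n))
            out.push_back(x);
    return out;
}

// Constructive version: (n-1)-bit solutions with a 0 prepended, followed by
// (n-2)-bit solutions with a 1 prepended and a 0 appended.
std::vector<uint32_t> construct(unsigned n)
{
    if (n == 0) return std::vector<uint32_t>();
    if (n == 1) return std::vector<uint32_t>(1, 0u);
    std::vector<uint32_t> out = construct(n - 1);
    for (uint32_t s : construct(n - 2))
        out.push_back((1u << (n - 1)) | (s << 1));
    return out;
}

unsigned long long fibonacci(unsigned n)
{
    unsigned long long a = 0, b = 1;   // F(0), F(1)
    for (unsigned i = 0; i < n; i++) {
        unsigned long long next = a + b;
        a = b;
        b = next;
    }
    return a;
}
```

Because the second group always has its top bit set and the first never does, the constructive list comes out already sorted, so the two vectors can be compared directly.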

This is probably all useless, and it's barely even bitmath.

[1] Or whenever, but it happens to be a Sunday morning for me right now.
[2] This sequence does not seem to be on the OEIS at the time of writing.

Wednesday 17 January 2024

Partial sums of popcount

The partial sums of popcount, aka A000788: Total number of 1's in binary expansions of 0, ..., n can be computed fairly efficiently with some mysterious code found through its OEIS entry (see the link Fast C++ function for computing a(n)): (reformatted slightly to reduce width)

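unsigned A000788(unsigned n)
{
    // note: assumes the top bit of n is clear; otherwise `bit` wraps
    // around to 0 and the loop condition never becomes false
    unsigned v = 0;
    for (unsigned bit = 1; bit <= n; bit <<= 1)
        v += ((n>>1)&~(bit-1)) +
             ((n&bit) ? (n&((bit<<1)-1))-(bit-1) : 0);
    return v;
}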
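
Since the loop is a bit mysterious, it helps to have an obviously-correct reference to compare against. This is a small sketch with made-up names (`popcountNaive`, `partialSumNaive`, `A000788Loop`); the loop itself is copied from above.

```cpp
#include <cassert>

unsigned popcountNaive(unsigned x)
{
    unsigned count = 0;
    for (; x != 0; x >>= 1)
        count += x & 1;
    return count;
}

// Reference: literally add up popcount(0) + ... + popcount(n).
unsigned partialSumNaive(unsigned n)
{
    unsigned total = 0;
    for (unsigned i = 0; ; i++) {
        total += popcountNaive(i);
        if (i == n) break;       // written this way so n == UINT_MAX cannot loop forever
    }
    return total;
}

// The loop from the OEIS entry, as quoted in the text.
unsigned A000788Loop(unsigned n)
{
    assert(n <= (~0u >> 1));     // with the top bit set, `bit` would wrap to 0
    unsigned v = 0;
    for (unsigned bit = 1; bit <= n; bit <<= 1)
        v += ((n >> 1) & ~(bit - 1)) +
             ((n & bit) ? (n & ((bit << 1) - 1)) - (bit - 1) : 0);
    return v;
}
```

For example, the popcounts of 0 through 7 are 0, 1, 1, 2, 1, 2, 2, 3, so both functions should return 12 for n = 7.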

Knowing what we (or I, anyway) know from computing the partial sums of blsi and blsmsk, let's try to improve on that code. "Improve" is a vague goal, let's say we don't want to loop over the bits, but also not just unroll by 64x to do this for a 64-bit integer.

First let's split this thing into the sum of an easy problem and a harder problem, the easy problem being the sum of (n>>1)&~(bit-1) (reminder that ~(bit-1) == -bit, unsigned negation is safe, UB-free, and does exactly what we need, even on hypothetical non-two's-complement hardware). This is the same thing we saw in the partial sum of blsi, bit k of n occurs k times in the sum, which we can evaluate like this:

uint64_t v = 
    ((n & 0xAAAA'AAAA'AAAA'AAAA) >> 1) +
    ((n & 0xCCCC'CCCC'CCCC'CCCC) << 0) +
    ((n & 0xF0F0'F0F0'F0F0'F0F0) << 1) +
    ((n & 0xFF00'FF00'FF00'FF00) << 2) +
    ((n & 0xFFFF'0000'FFFF'0000) << 3) +
    ((n & 0xFFFF'FFFF'0000'0000) << 4);
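
A quick way to convince yourself that the masks do the same thing as the loop: bit i of n ends up contributing i * 2^(i-1) either way. The sketch below compares the two for 64-bit inputs. The function names are made up, and the loop runs over all 64 values of bit rather than stopping at n, which is harmless because the extra terms are zero.

```cpp
#include <cassert>
#include <cstdint>

// The "easy" part as a loop: sum of (n >> 1) & ~(bit - 1) over every
// power-of-two value of bit.
uint64_t easyPartLoop(uint64_t n)
{
    uint64_t v = 0;
    for (unsigned j = 0; j < 64; j++) {
        uint64_t bit = uint64_t(1) << j;
        v += (n >> 1) & ~(bit - 1);
    }
    return v;
}

// The same sum with masks: each mask selects the bit positions i whose
// index has a particular bit set, and the shift supplies that bit's weight.
uint64_t easyPartMasks(uint64_t n)
{
    return ((n & 0xAAAAAAAAAAAAAAAAULL) >> 1) +
           ((n & 0xCCCCCCCCCCCCCCCCULL) << 0) +
           ((n & 0xF0F0F0F0F0F0F0F0ULL) << 1) +
           ((n & 0xFF00FF00FF00FF00ULL) << 2) +
           ((n & 0xFFFF0000FFFF0000ULL) << 3) +
           ((n & 0xFFFFFFFF00000000ULL) << 4);
}
```

Both versions wrap modulo 2^64 in the same way, so they can be compared on arbitrary inputs and not just small ones.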

The harder problem, the contribution from ((n&bit) ? (n&((bit<<1)-1))-(bit-1) : 0), has a similar pattern but is more annoying in three ways. Here's an example of the pattern, starting with n in the first row and listing the values being added together below the horizontal line:

  1. Some anomalous thing happens for the contiguous group of rightmost set bits.
  2. The weights are based not on the column index, but sort of dynamic based on the number of set bits ...
  3. ... to the left of the bit we're looking at. That's significant, "to the right" would have been a lot nicer to deal with.

For problem 1, I'm just going to state without proof that we can add 1 to n and ignore the problem, as long as we add n & ~(n + 1) to the final sum. Problems 2 and 3 are more interesting. If we had problem 2 but counting the bits to the right of the bit we're looking at, that would have been nice and easy: instead of (n & 0xAAAA'AAAA'AAAA'AAAA) we would have _pdep_u64(0xAAAA'AAAA'AAAA'AAAA, n), problem solved. If we had a "pdep but from left to right" (aka expand_left) named _pdepl_u64 we could have done this (with m being n + 1 shifted left by shift, its number of leading zeroes, as in the full code further down):

uint64_t u = 
    ((_pdepl_u64(0x5555'5555'5555'5555, m) >> shift) << 0) +
    ((_pdepl_u64(0x3333'3333'3333'3333, m) >> shift) << 1) +
    ((_pdepl_u64(0x0F0F'0F0F'0F0F'0F0F, m) >> shift) << 2) +
    ((_pdepl_u64(0x00FF'00FF'00FF'00FF, m) >> shift) << 3) +
    ((_pdepl_u64(0x0000'FFFF'0000'FFFF, m) >> shift) << 4) +
    ((_pdepl_u64(0x0000'0000'FFFF'FFFF, m) >> shift) << 5);

But as far as I know, that requires bit-reversing the inputs (see the update below) of a normal _pdep_u64 and bit-reversing the result, which is not so nice at least on current x64 hardware. Every ISA should have a Generalized Reverse operation like the grevi instruction which used to be in the drafts of the RISC-V Bitmanip Extension prior to version 1.


It turned out there is a reasonable way to implement _pdepl_u64(v, m) in plain scalar code after all, namely as _pdep_u64(v >> (std::popcount(~m) & 63), m). The & 63 isn't meaningful, it's just to prevent UB at the C++ level.
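
To see why the shift trick works, it can be compared against a direct "walk from the top" definition of expand-left. The sketch below does that without needing BMI2: `pdepSoft` is a plain-loop stand-in for _pdep_u64, and `expandLeftNaive` is my reading of what expand_left should mean (the top bit of v goes to the highest set bit of m, and so on downwards). All the names are made up for this example.

```cpp
#include <cassert>
#include <cstdint>

unsigned popcountSoft(uint64_t x)
{
    unsigned count = 0;
    for (; x != 0; x &= x - 1)
        count++;
    return count;
}

// Stand-in for _pdep_u64: deposit the low bits of v into the set-bit
// positions of m, lowest position first.
uint64_t pdepSoft(uint64_t v, uint64_t m)
{
    uint64_t result = 0;
    for (uint64_t src = 1; m != 0; src <<= 1) {
        uint64_t lowest = m & (~m + 1);   // isolate the lowest set bit of m
        if (v & src)
            result |= lowest;
        m &= m - 1;                       // and remove it
    }
    return result;
}

// The trick from the text: keep only the top popcount(m) bits of v,
// then do a normal pdep.
uint64_t pdeplSoft(uint64_t v, uint64_t m)
{
    return pdepSoft(v >> (popcountSoft(~m) & 63), m);
}

// Reference: consume bits of v from the top, fill set bits of m from the top.
uint64_t expandLeftNaive(uint64_t v, uint64_t m)
{
    uint64_t result = 0;
    uint64_t src = uint64_t(1) << 63;
    for (int i = 63; i >= 0; i--) {
        uint64_t pos = uint64_t(1) << i;
        if (m & pos) {
            if (v & src)
                result |= pos;
            src >>= 1;
        }
    }
    return result;
}
```

The shift by popcount(~m) leaves exactly popcount(m) bits of v, so the normal low-to-high deposit lines the top bit of v up with the highest set bit of m. When m is zero the masked shift count is 0, but depositing into an empty mask gives 0 regardless.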

This approach turned out to be more efficient than the AVX512 approach, so that's obsolete now, but maybe still interesting to borrow ideas from. Here's the scalar implementation in full:

#include <bit>          // std::popcount, std::countl_zero (C++20)
#include <cstdint>
#include <immintrin.h>  // _pdep_u64 (BMI2)

uint64_t _pdepl_u64(uint64_t v, uint64_t m)
{
    return _pdep_u64(v >> (std::popcount(~m) & 63), m);
}

uint64_t partialSumOfPopcnt(uint64_t n)
{
    // the easy part
    uint64_t v =
        ((n & 0xAAAA'AAAA'AAAA'AAAA) >> 1) +
        ((n & 0xCCCC'CCCC'CCCC'CCCC) << 0) +
        ((n & 0xF0F0'F0F0'F0F0'F0F0) << 1) +
        ((n & 0xFF00'FF00'FF00'FF00) << 2) +
        ((n & 0xFFFF'0000'FFFF'0000) << 3) +
        ((n & 0xFFFF'FFFF'0000'0000) << 4);
    // the harder part: left-align the mask, expand left, shift back
    uint64_t m = n + 1;
    int shift = std::countl_zero(m);
    m = m << shift;
    uint64_t u =
        ((_pdepl_u64(0x5555'5555'5555'5555, m) >> shift) << 0) +
        ((_pdepl_u64(0x3333'3333'3333'3333, m) >> shift) << 1) +
        ((_pdepl_u64(0x0F0F'0F0F'0F0F'0F0F, m) >> shift) << 2) +
        ((_pdepl_u64(0x00FF'00FF'00FF'00FF, m) >> shift) << 3) +
        ((_pdepl_u64(0x0000'FFFF'0000'FFFF, m) >> shift) << 4) +
        ((_pdepl_u64(0x0000'0000'FFFF'FFFF, m) >> shift) << 5);
    return u + (n & ~(n + 1)) + v;
}

Repeatedly calling _pdepl_u64 with the same mask creates some common subexpressions. They could be factored out manually, but compilers do that anyway: even MSVC only uses one actual popcnt instruction (though MSVC, annoyingly, actually performs the meaningless & 63).
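
If you want to test the scalar approach on a machine without BMI2 or C++20, a portable variant is easy to put together. The sketch below replaces the intrinsics with plain loops; the `...Soft` helpers and `partialSumPortable` are names invented for this example, and it is of course much slower than the real thing. Note the assumption that n is not the all-ones value, since n + 1 would then wrap to zero and the left shift would be by 64.

```cpp
#include <cassert>
#include <cstdint>

unsigned popcountSoft(uint64_t x)
{
    unsigned count = 0;
    for (; x != 0; x &= x - 1)
        count++;
    return count;
}

unsigned countlZeroSoft(uint64_t x)
{
    unsigned count = 0;
    for (uint64_t top = uint64_t(1) << 63; top != 0 && !(x & top); top >>= 1)
        count++;
    return count;
}

// Stand-in for _pdep_u64.
uint64_t pdepSoft(uint64_t v, uint64_t m)
{
    uint64_t result = 0;
    for (uint64_t src = 1; m != 0; src <<= 1) {
        if (v & src)
            result |= m & (~m + 1);   // lowest set bit of m
        m &= m - 1;
    }
    return result;
}

uint64_t pdeplSoft(uint64_t v, uint64_t m)
{
    return pdepSoft(v >> (popcountSoft(~m) & 63), m);
}

uint64_t partialSumPortable(uint64_t n)
{
    assert(n != UINT64_MAX);          // n + 1 must not wrap to 0
    uint64_t v =
        ((n & 0xAAAAAAAAAAAAAAAAULL) >> 1) +
        ((n & 0xCCCCCCCCCCCCCCCCULL) << 0) +
        ((n & 0xF0F0F0F0F0F0F0F0ULL) << 1) +
        ((n & 0xFF00FF00FF00FF00ULL) << 2) +
        ((n & 0xFFFF0000FFFF0000ULL) << 3) +
        ((n & 0xFFFFFFFF00000000ULL) << 4);
    uint64_t m = n + 1;
    unsigned shift = countlZeroSoft(m);
    m <<= shift;
    const uint64_t weights[6] = {
        0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL,
        0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL };
    uint64_t u = 0;
    for (unsigned k = 0; k < 6; k++)
        u += (pdeplSoft(weights[k], m) >> shift) << k;
    return u + (n & ~(n + 1)) + v;
}

// Reference: add up the popcounts one by one.
uint64_t partialSumNaive(uint64_t n)
{
    uint64_t total = 0;
    for (uint64_t i = 0; i <= n; i++)
        total += popcountSoft(i);
    return total;
}
```

A handy spot check for large inputs is n = 2^k - 1, where every one of the k bit positions is set in exactly half of the 2^k numbers, so the sum must be k * 2^(k-1).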

Enter AVX512

Using AVX512, we could more easily reverse the bits of a 64-bit integer, there are various ways to do that. But just using that and then going back to scalar pdep would be a waste of a good opportunity to implement the whole thing in AVX512, pdep and all. The trick to doing a pdep in AVX512, if you have several 64-bit integers that you want to pdep with the same mask, is to transpose 8x 64-bit integers into 64x 8-bit integers, use vpexpandb, then transpose back. In this case the first operand of the pdep is a constant, so the first transpose is not necessary. We still have to reverse the mask though. Since vpexpandb takes the mask input in a mask register and we only have one thing to reverse, this trick to bit-permute integers seems like a better fit than Wunk's whole-vector bit-reversal or some variant thereof.

I sort of glossed over the fact that we're supposed to be bit-reversing relative to the most significant set bit in the mask, but that's easy to do by shifting left by std::countl_zero(m) and then doing a normal bit-reverse, so in the end it still comes down to a normal bit-reverse. The results of the pdeps have to be shifted right by the same amount to compensate.

Here's the whole thing: (note that this is less efficient than the updated approach without AVX512)

uint64_t partialSumOfPopcnt(uint64_t n)
{
    uint64_t v = 
        ((n & 0xAAAA'AAAA'AAAA'AAAA) >> 1) +
        ((n & 0xCCCC'CCCC'CCCC'CCCC) << 0) +
        ((n & 0xF0F0'F0F0'F0F0'F0F0) << 1) +
        ((n & 0xFF00'FF00'FF00'FF00) << 2) +
        ((n & 0xFFFF'0000'FFFF'0000) << 3) +
        ((n & 0xFFFF'FFFF'0000'0000) << 4);
    // 0..63
    __m512i weights = _mm512_setr_epi8(
        0, 1, 2, 3, 4, 5, 6, 7,
        8, 9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19, 20, 21, 22, 23,
        24, 25, 26, 27, 28, 29, 30, 31,
        32, 33, 34, 35, 36, 37, 38, 39,
        40, 41, 42, 43, 44, 45, 46, 47,
        48, 49, 50, 51, 52, 53, 54, 55,
        56, 57, 58, 59, 60, 61, 62, 63);
    // 63..0
    __m512i rev = _mm512_set_epi8(
        0, 1, 2, 3, 4, 5, 6, 7,
        8, 9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19, 20, 21, 22, 23,
        24, 25, 26, 27, 28, 29, 30, 31,
        32, 33, 34, 35, 36, 37, 38, 39,
        40, 41, 42, 43, 44, 45, 46, 47,
        48, 49, 50, 51, 52, 53, 54, 55,
        56, 57, 58, 59, 60, 61, 62, 63);
    uint64_t m = n + 1;
    // bit-reverse the mask to implement expand-left
    int shift = std::countl_zero(m);
    m = (m << shift);
    __mmask64 revm = _mm512_bitshuffle_epi64_mask(_mm512_set1_epi64(m), rev);
    // the reversal of expand-right with reversed inputs is expand-left
    __m512i leftexpanded = _mm512_permutexvar_epi8(rev, 
        _mm512_mask_expand_epi8(_mm512_setzero_si512(), revm, weights));
    // transpose back to 8x 64-bit integers
    leftexpanded = Transpose64x8(leftexpanded);
    // compensate for having shifted m left
    __m512i masks = _mm512_srlv_epi64(leftexpanded, _mm512_set1_epi64(shift));
    // scale and sum results
    __m512i parts = _mm512_sllv_epi64(masks,
        _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0));
    // horizontal sum: add the upper half to the lower half, twice
    __m256i parts2 = _mm256_add_epi64(
        _mm512_castsi512_si256(parts),
        _mm512_extracti64x4_epi64(parts, 1));
    __m128i parts3 = _mm_add_epi64(
        _mm256_castsi256_si128(parts2),
        _mm256_extracti64x2_epi64(parts2, 1));
    uint64_t u = 
        _mm_cvtsi128_si64(parts3) + 
        _mm_extract_epi64(parts3, 1);
    return u + (n & ~(n + 1)) + v;
}

As for Transpose64x8, you can find out how to implement that in Permuting bits with GF2P8AFFINEQB.