Wednesday 9 June 2021

Partial sums of blsi and blsmsk

blsi is an x86 operation which extracts the rightmost set bit from a number. It can be implemented efficiently in terms of more familiar operations as i&-i. blsmsk is a closely related operation which extracts the rightmost set bit and also "smears" it to the right, setting the bits to the right of that bit as well. blsmsk can be implemented as i^(i-1). The smearing makes the result of blsmsk almost (but not quite) twice as high as the result of blsi for the same input: blsmsk(i) = floor(blsi(i) * (2 - ε)), or, stated exactly for nonzero i, blsmsk(i) = 2 * blsi(i) - 1.
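As a minimal sketch of those two identities in portable C (the function names blsi and blsmsk here are just illustrative helpers, not the compiler intrinsics), using unsigned arithmetic so that the negation and the subtraction wrap around in a well-defined way:

```c
#include <stdint.h>

// Isolate the rightmost set bit: i & -i. Gives 0 for i == 0.
static uint32_t blsi(uint32_t i) {
    return i & (0u - i);
}

// Mask up to and including the rightmost set bit: i ^ (i - 1).
// For i == 0 this expression gives all ones, because i - 1 wraps around.
static uint32_t blsmsk(uint32_t i) {
    return i ^ (i - 1u);
}
```

For example, with i = 12 (binary 1100) these helpers give blsi(12) = 4 and blsmsk(12) = 7.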

The partial sums of blsi and blsmsk can be defined as b(n) = sum(i=1, n, blsi(i)) and a(n) = sum(i=1, n, blsmsk(i)) respectively. These sequences are on OEIS as A006520 and A080277. Direct evaluation of those definitions takes time proportional to n, which is inefficient. Is there a better way?
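For reference, a direct evaluation of those definitions might look like the sketch below. The names NaivePartialSumBLSI and NaivePartialSumBLSMSK are made up for this illustration. They are slow, but handy as a baseline for checking faster versions.

```c
#include <stdint.h>

// b(n) = blsi(1) + blsi(2) + ... + blsi(n), evaluated term by term.
static uint64_t NaivePartialSumBLSI(uint32_t n) {
    uint64_t sum = 0;
    // A 64-bit counter avoids an endless loop when n is UINT32_MAX.
    for (uint64_t i = 1; i <= n; i++)
        sum += i & (0 - i);
    return sum;
}

// a(n) = blsmsk(1) + blsmsk(2) + ... + blsmsk(n), evaluated term by term.
static uint64_t NaivePartialSumBLSMSK(uint32_t n) {
    uint64_t sum = 0;
    for (uint64_t i = 1; i <= n; i++)
        sum += i ^ (i - 1);
    return sum;
}
```

The first few values are a(1..5) = 1, 4, 5, 12, 13 and b(1..5) = 1, 3, 4, 8, 9.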

Unrecursing the recursive definition

Let's look at the partial sums of blsmsk first. Its OEIS entry suggests the recursion below, which is already significantly better than the naive summation:

a(1) = 1, a(2*n) = 2*a(n) + 2*n, a(2*n+1) = 2*a(n) + 2*n + 1

A "more code-ish"/"less mathy" way to implement that could be like this:

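int PartialSumBLSMSK(int i) {
    // Base case a(1) = 1; using <= also returns a(0) = 0 instead of recursing forever.
    if (i <= 1)
        return i;
    // a(i) = 2 * a(i / 2) + i covers both the even and the odd case.
    return (PartialSumBLSMSK(i >> 1) << 1) + i;
}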

Let's understand what this actually computes, and then find another way to do it. The overall picture of the recursion is that n is shifted right on the way "down", and the results of the recursive calls are shifted left on the way "up", in a way that cancels out. So in total, a bunch of "copies" of n are added up, except that at the kth step of the recursion, the kth bit of n is reset, and it stays reset in all later copies.

This non-tail recursion can be turned into tail recursion using a standard trick: turning the "sum on the way up" logic into "sum on the way down", by passing two accumulators as extra arguments (initially 0 and 1), one to keep track of the sum, and another to keep track of how much to multiply i by:

int PartialSumBLSMSKTail(int i, int accum_a, int accum_m) {
    if (i <= 1)
        return accum_a + i * accum_m;
    return PartialSumBLSMSKTail(i >> 1, accum_a + i * accum_m, accum_m * 2);
}

Such a tail-recursive function is then simple to transform into a loop. Shockingly, GCC (even quite old versions) manages to compile the original non-tail recursive function into a loop as well without much help (just changing the left-shift into the equivalent multiplication), although some details differ.
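As a rough sketch of that transformation (the name PartialSumBLSMSKLoop is my own placeholder, and the accumulators start at 0 and 1 as in the initial call of the tail-recursive version), the parameters simply become local variables that are updated in a loop:

```c
// Loop form of PartialSumBLSMSKTail(i, 0, 1).
static int PartialSumBLSMSKLoop(int i) {
    int accum_a = 0;  // running sum
    int accum_m = 1;  // current multiplier, doubles every step
    while (i > 1) {
        accum_a += i * accum_m;
        accum_m *= 2;
        i >>= 1;
    }
    // Same expression as the base case of the tail-recursive function.
    return accum_a + i * accum_m;
}
```

For an input of 5 the loop adds 5*1 and 2*2, and the final step adds 1*4, giving 13.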

Anyway, let's put in a value and see what happens. For example, if the input was 11101001, then the following numbers would be added up:

11101001 (reset bit 0)
11101000 (reset bit 1, it's already 0)
11101000 (reset bit 2)
11101000 (reset bit 3)
11100000 (reset bit 4)
11100000 (reset bit 5)
11000000 (reset bit 6)
10000000 (base case)

Look at the columns of the matrix of numbers above: the column for bit 3 has four ones in it, and the column for bit 5 has six ones in it. The pattern is that if bit k is set in n, then that bit is set in k+1 rows.
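To make the matrix concrete, here is a small sketch (SumOfRows and RowsWithBitSet are hypothetical helper names) that builds each row as n with its k lowest bits cleared, and counts how often a given bit shows up across the rows:

```c
#include <stdint.h>

// Adds up the rows of the matrix: row k is n with its k lowest bits cleared.
static uint64_t SumOfRows(uint32_t n) {
    uint64_t sum = 0;
    for (int k = 0; k < 32; k++)
        sum += (n >> k) << k;  // rows past the top set bit are zero
    return sum;
}

// Counts the rows in which the given bit position is set.
static int RowsWithBitSet(uint32_t n, int bit) {
    int count = 0;
    for (int k = 0; k < 32; k++)
        count += (int)((((n >> k) << k) >> bit) & 1u);
    return count;
}
```

For the input 11101001 (0xE9), this counts four rows for bit 3 and six rows for bit 5, and the rows add up to 1697.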

Using the pattern

Essentially, what that pattern means is that a(n) can be expressed as the dot-product between n viewed as a vector of bits (weighted according to their position) (𝓷), and a constant vector (𝓬) with entries 1, 2, 3, 4, etc, up to the size of the integer. For example for n=5, 𝓷 would be (1, 0, 4), and the dot-product with 𝓬 would be 1*1 + 0*2 + 4*3 = 13. A dot-product like that can be implemented with some bitwise trickery, by using bit-slicing. The trick there is that instead of multiplying the entries of 𝓷 by the entries of 𝓬 directly, we multiply the entries of 𝓷 by the least-significant bits of the entries of 𝓬, and then separately multiply them by all the second bits of the entries of 𝓬, and so on. Multiplying every entry of 𝓷 at once by a bit of an entry of 𝓬 can be implemented using just a bitwise-AND operation. The slice for bit j of 𝓬 carries a weight of 2^j, so it is shifted left by j before being added to the sum.
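Before applying bit-slicing, the dot-product itself can be written as a plain loop over the bit positions. This is only a sketch for comparison (DotProductBLSMSK is a name invented here), with one multiplication per bit:

```c
#include <stdint.h>

// a(n) as a dot-product: entry k of the bit vector is 0 or 2^k,
// and entry k of the constant vector is k + 1.
static uint64_t DotProductBLSMSK(uint32_t n) {
    uint64_t sum = 0;
    for (int k = 0; k < 32; k++) {
        uint64_t weighted_bit = n & (UINT32_C(1) << k);
        sum += weighted_bit * (uint64_t)(k + 1);
    }
    return sum;
}
```

For n=5 the only nonzero terms are 1*1 and 4*3, which gives 13 as in the example above.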

Although this trick lends itself well to any vector 𝓬, I will use 0, 1, 2, 3, ... and add an extra n separately (this corresponds to factoring out the +1 that appears at the end of the recursive definition), because that way part of the code can be reused directly by the solution of the partial sums of blsi (and also because it looks nicer). The masks that correspond to the chosen vector 𝓬 are easy to compute: each column across the masks is an entry of that vector. In this case, for 32-bit integers:

c0 10101010101010101010101010101010
c1 11001100110011001100110011001100
c2 11110000111100001111000011110000
c3 11111111000000001111111100000000
c4 11111111111111110000000000000000
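The masks can also be generated mechanically from the vector. The sketch below assumes the vector 0, 1, 2, ..., 31 and uses a hypothetical helper SliceMask: bit k of mask j is simply bit j of the number k.

```c
#include <stdint.h>

// Builds mask c_j for the vector 0, 1, 2, ..., 31:
// bit k of the mask is set when bit j of k is set.
static uint32_t SliceMask(int j) {
    uint32_t mask = 0;
    for (uint32_t k = 0; k < 32; k++) {
        if ((k >> j) & 1u)
            mask |= UINT32_C(1) << k;
    }
    return mask;
}
```

Calling SliceMask with j = 0 through 4 reproduces the five rows above: 0xAAAAAAAA, 0xCCCCCCCC, 0xF0F0F0F0, 0xFF00FF00 and 0xFFFF0000.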

The whole function could look like this:

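int PartialSumBLSMSK(int n)
{
    // The vector 1, 2, 3, ... is handled as one plain copy of n plus 0, 1, 2, ...
    int sum = n;
    // Slice j keeps the bits of n whose position has bit j set, weighted by 2^j.
    sum += (n & 0xAAAAAAAA);
    sum += (n & 0xCCCCCCCC) << 1;
    sum += (n & 0xF0F0F0F0) << 2;
    sum += (n & 0xFF00FF00) << 3;
    sum += (n & 0xFFFF0000) << 4;
    return sum;
}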

PartialSumBLSI works almost the same way, with its recursive formula being b(1) = 1, b(2n) = 2b(n) + n, b(2n+1) = 2b(n) + n + 1. The +1 can be factored out as before, and the other part (n instead of 2*n) is exactly half of what it was before. Halving 𝓬 seems like a problem, since its entries are not all even, but it can be done implicitly by shifting the bit-slices of the product to the right by 1 bit. No bits are lost that way, because the least significant bit is always zero in this case (𝓬 has zero as its first element).

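int PartialSumBLSI(int n)
{
    int sum = n;
    // Same slices as in PartialSumBLSMSK, each shifted one position further right to halve the weights.
    sum += (n & 0xAAAAAAAA) >> 1;
    sum += (n & 0xCCCCCCCC);
    sum += (n & 0xF0F0F0F0) << 1;
    sum += (n & 0xFF00FF00) << 2;
    sum += (n & 0xFFFF0000) << 3;
    return sum;
}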
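One way to gain some confidence in both functions is to compare them against running sums of i^(i-1) and i&-i. The sketch below uses unsigned copies of the two functions (SlicedBLSMSK, SlicedBLSI and CheckPartialSums are names I made up for this check), so that any wrap-around is well-defined:

```c
#include <stdint.h>

// Unsigned copy of the bit-sliced partial sum of blsmsk.
static uint32_t SlicedBLSMSK(uint32_t n) {
    return n + (n & 0xAAAAAAAAu) + ((n & 0xCCCCCCCCu) << 1)
             + ((n & 0xF0F0F0F0u) << 2) + ((n & 0xFF00FF00u) << 3)
             + ((n & 0xFFFF0000u) << 4);
}

// Unsigned copy of the bit-sliced partial sum of blsi.
static uint32_t SlicedBLSI(uint32_t n) {
    return n + ((n & 0xAAAAAAAAu) >> 1) + (n & 0xCCCCCCCCu)
             + ((n & 0xF0F0F0F0u) << 1) + ((n & 0xFF00FF00u) << 2)
             + ((n & 0xFFFF0000u) << 3);
}

// Returns 1 if both functions agree with the running sums for 1..limit.
// Assumes limit is below UINT32_MAX so that the loop counter cannot wrap.
static int CheckPartialSums(uint32_t limit) {
    uint32_t a = 0, b = 0;
    for (uint32_t i = 1; i <= limit; i++) {
        a += i ^ (i - 1u);   // blsmsk(i)
        b += i & (0u - i);   // blsi(i)
        if (SlicedBLSMSK(i) != a || SlicedBLSI(i) != b)
            return 0;
    }
    return 1;
}
```

A call such as CheckPartialSums(100000) is cheap and exercises the four lowest masks plus the start of the fifth. Larger limits are needed to cover the high bits of the 0xFFFF0000 slice.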

Wrapping up

The particular set of constants I used is very useful and appears in more tricks, such as collecting indexes of set bits. They are the bitwise complements of a set of masks that Knuth (in The Art of Computer Programming volume 4, section 7.1.3) calls "magic masks", labeled µk.

This post was inspired by this question on Stack Overflow and partially based on my own answer and the answer of Eric Postpischil, without which I probably would not have come up with any of this, although I used a different derivation and explanation for this post.