Saturday 28 September 2013

Determining the cardinality of a set described by a mask and bounds

In other words, calculating this \begin{equation} \left| \left\{ x \mid (x \& m) = v \land a \le x \le b \right\} \right| \end{equation}
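
To make the definition concrete, here is a minimal brute-force sketch in C that simply enumerates the interval. The name cardinality_brute is my own, not something from haroldbot, and it is only practical for small ranges. It is meant as a reference to compare faster versions against.

```c
#include <stdint.h>

/* Reference count of x in [a, b] with (x & m) == v, by direct enumeration.
   Hypothetical helper for checking results; far too slow for wide ranges.
   Returns 64 bits so that a full count of 2^32 would still be representable. */
static uint64_t cardinality_brute(uint32_t m, uint32_t v, uint32_t a, uint32_t b)
{
    uint64_t count = 0;
    if (a > b)
        return 0;                 /* empty interval */
    for (uint32_t x = a; ; x++) {
        if ((x & m) == v)
            count++;
        if (x == b)               /* stop before x++ could wrap past UINT32_MAX */
            break;
    }
    return count;
}
```

For example, with m = 5 and v = 4 the matching values in [0, 15] are 4, 6, 12 and 14, and narrowing the bounds to [5, 13] leaves 6 and 12.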

A BDD (which haroldbot uses to solve this) has no trouble with that at all, but the structure of the problem is so nice that I thought it should be possible to do better. And it is, though the solution I'm about to present is probably far from optimal. I don't know how to significantly improve on it, but I just get a very non-scientific gut-feeling that there should be a fundamentally better way to do it.

The cardinality of the set without the bounds (only the mask) is trivial to compute: it is 1 << popcnt(~m), assuming v has no bits set outside m (otherwise the set is empty). The bounds divide that set into three parts: items that are lower than the lower bound ("left part"), items that are higher than the upper bound ("right part"), and items that form the actual set of interest ("middle part"). The main idea is that the problem is relatively easy to solve when there is no lower bound and the upper bound is a power of two, with a formula roughly similar to the one above. Using that, the power-of-two-sized blocks that fit before the lower bound can be counted from high to low, giving the cardinality of the "left part". The same can be done for the "right part", in exactly the same way, by complementing the value (within the mask) and the upper bound, since x > b is the same as ~x < ~b. The cardinality of the "middle part" is then the total minus the left and right parts.
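
The building block described above can be sketched in C as a count over one aligned power-of-two block. The names popcount32 and count_in_block are hypothetical, and the sketch assumes v has no bits outside m. Within a block the high bits are fixed, so either they contradict v (count 0) or every unmasked low bit is free.

```c
#include <stdint.h>

/* Portable population count; a compiler builtin could be used instead. */
static unsigned popcount32(uint32_t x)
{
    unsigned n = 0;
    while (x) { x &= x - 1; n++; }   /* clear the lowest set bit each round */
    return n;
}

/* Count x in the aligned block [base, base + 2^i) with (x & m) == v.
   Assumes the low i bits of base are clear, i <= 31, and (v & ~m) == 0. */
static uint32_t count_in_block(uint32_t m, uint32_t v, uint32_t base, unsigned i)
{
    uint32_t high = 0u - ((uint32_t)1 << i);    /* bits i..31, fixed by base */
    if ((base & m & high) != (v & high))
        return 0;                               /* a fixed bit contradicts v */
    /* free bits are the low bits that the mask does not constrain */
    return (uint32_t)1 << popcount32(~(m | high));
}
```

With a lower bound of 13 = 8 + 4 + 1, the "left part" is the sum over the blocks [0, 8), [8, 12) and [12, 13), that is count_in_block(m, v, 0, 3) + count_in_block(m, v, 8, 2) + count_in_block(m, v, 12, 0). For m = 5, v = 4 that gives 2 + 0 + 1, matching the values 4, 6 and 12.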

And here's the code. It doesn't like it when the cardinality is 2^32 (the result does not fit in a uint), and watch out for weird corner cases such as when the lower bound is bigger than the upper bound (why would you even do that?). It usually works, that's about as much as I can say - I didn't prove it correct.
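
One way those corner cases could be sidestepped, sketched in C under my own assumptions: return a 64-bit count so 2^32 is representable, and treat a > b or a v with bits outside m as an empty set. The names count_below and cardinality64 are hypothetical, and like the original this is unproven.

```c
#include <stdint.h>

/* Portable population count; a compiler builtin could be used instead. */
static unsigned popcount32(uint32_t x)
{
    unsigned n = 0;
    while (x) { x &= x - 1; n++; }
    return n;
}

/* Count x in [0, limit) with (x & m) == v, assuming (v & ~m) == 0.
   Each set bit i of limit contributes one aligned block of size 2^i. */
static uint64_t count_below(uint32_t m, uint32_t v, uint32_t limit)
{
    uint64_t total = 0;
    for (int i = 31; i >= 0; i--) {
        uint32_t bit = (uint32_t)1 << i;
        if (limit & bit) {
            uint32_t high = 0u - bit;              /* bits i..31 */
            uint32_t base = limit & high & ~bit;   /* prefix of limit, bit i cleared */
            if ((base & m) == (v & high))
                total += (uint64_t)1 << popcount32(~(m | high));
        }
    }
    return total;
}

/* Count x in [a, b] with (x & m) == v, returning 0 for the degenerate inputs. */
static uint64_t cardinality64(uint32_t m, uint32_t v, uint32_t a, uint32_t b)
{
    if (a > b || (v & ~m) != 0)
        return 0;   /* empty interval, or v needs a bit that the mask clears */
    uint64_t all   = (uint64_t)1 << popcount32(~m);
    uint64_t left  = count_below(m, v, a);           /* x < a */
    uint64_t right = count_below(m, ~v & m, ~b);     /* x > b, via ~x < ~b */
    return all - left - right;
}
```

Returning 0 when a > b is a design choice on my part; wrapping intervals could be given a different meaning.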

static uint Cardinality(uint m, uint v, uint a, uint b)
{
    // count x such that a <= x <= b && (x & m) == v
    // assumes (v & ~m) == 0 and a <= b; popcnt is a 32-bit population count
    // split in three parts:
    // left  = 0 <= x < a
    // right = b < x <= 0xFFFFFFFF
    // the piece in the middle is (1 << popcnt(~m)) - left - right
    uint left = 0;
    uint x = 0;
    for (int i = 32 - 1; i >= 0; i--)
    {
        uint mask = 1u << i;
        if (x + mask <= a)
        {
            x += mask;
            uint mask2 = 0 - mask;
            if ((x - 1 & m & mask2) == (v & mask2))
            {
                uint amount = 1u << popcnt(~(m | mask2));
                left += amount;
            }
        }
    }
    uint right = 0;
    uint y = 0;
    for (int i = 32 - 1; i >= 0; i--)
    {
        uint mask = 1u << i;
        if (y + mask <= ~b)
        {
            y += mask;
            uint mask2 = 0 - mask;
            if ((y - 1 & m & mask2) == (~v & m & mask2))
            {
                uint amount = 1u << popcnt(~(m | mask2));
                right += amount;
            }
        }
    }
    uint res = (uint)((1UL << popcnt(~m)) - (left + right));
    return res;
}
The loops can be merged of course, but for clarity they're separate here.
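
For what it's worth, a merged version might look roughly like this sketch in C with fixed-width types. The names cardinality_merged and popcount32 are mine, and it has the same limits as the separate-loop code: a <= b, no bits of v outside m, and a cardinality below 2^32.

```c
#include <stdint.h>

/* Portable population count; a compiler builtin could be used instead. */
static unsigned popcount32(uint32_t x)
{
    unsigned n = 0;
    while (x) { x &= x - 1; n++; }
    return n;
}

/* Count x in [a, b] with (x & m) == v, handling both sides in one loop. */
static uint32_t cardinality_merged(uint32_t m, uint32_t v, uint32_t a, uint32_t b)
{
    uint32_t nb = ~b;         /* complemented upper bound */
    uint32_t nv = ~v & m;     /* complemented value, kept inside the mask */
    uint32_t outside = 0;     /* left part plus right part */
    for (int i = 31; i >= 0; i--) {
        uint32_t bit = 1u << i;
        uint32_t high = 0u - bit;                         /* bits i..31 */
        uint32_t amount = 1u << popcount32(~(m | high));  /* matches per block */
        /* block below a: prefix of a with bit i cleared */
        if ((a & bit) && ((a & high & ~bit & m) == (v & high)))
            outside += amount;
        /* block above b, seen through the complement */
        if ((nb & bit) && ((nb & high & ~bit & m) == (nv & high)))
            outside += amount;
    }
    return (uint32_t)(((uint64_t)1 << popcount32(~m)) - outside);
}
```

The block size for a given bit is the same on both sides, so it only needs to be computed once per iteration.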

If you have any improvements, please let me know.