Thursday 22 August 2024

Enumerating identities, part 2

Part 2 of Enumerating all mathematical identities (in fixed-size bitvector arithmetic with a restricted set of operations) of a certain size.

To recap, the approach in part 1 was broadly:

  • Use a CEGIS-style loop to find a pair of expressions that are equal for all possible inputs.
  • After finding such a pair, block it from being found again.
  • Repeat until no more pairs can be found.

One weakness (in addition to the other weaknesses) of that approach is that the "block list" keeps growing as more identities are found. Here is an alternative approach that does not use an ever-growing block list:

  • Interpreting the raw bit-string that represents a pair of expressions as an integer, find the lowest pair of equivalent expressions.
  • After finding the lowest pair, interpreted as the number X, set the lower bound for the next pair to X + 1.
  • Repeat until no more pairs can be found.
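
As a minimal sketch of that outer loop (not the code from my implementation): assume some routine that returns the lowest solution greater than or equal to a given bound, here the hypothetical callback `lowest_at_least`. The names `Bits`, `increment` and `enumerate_all` are also made up for illustration, and bit-strings are stored with the most significant bit at index 0.

```cpp
#include <cstddef>
#include <functional>
#include <optional>
#include <vector>

using Bits = std::vector<bool>; // index 0 is the most significant bit

// Add 1 to a bit-string in place. Returns false on overflow, i.e. when the
// input was all ones and there is no X + 1 of the same width.
bool increment(Bits& x) {
    for (std::size_t i = x.size(); i-- > 0;) {
        if (!x[i]) { x[i] = true; return true; }
        x[i] = false; // carry into the next more-significant bit
    }
    return false;
}

// Hypothetical stand-in for the solver-backed search:
// returns the lowest solution >= bound, or nothing if there is none.
using LowestAtLeast = std::function<std::optional<Bits>(const Bits&)>;

// Enumerate all solutions in increasing order without keeping a block list.
std::vector<Bits> enumerate_all(std::size_t width, const LowestAtLeast& lowest_at_least) {
    std::vector<Bits> found;
    Bits bound(width, false); // start from the lowest possible bit-string
    while (auto x = lowest_at_least(bound)) {
        found.push_back(*x);
        bound = *x;
        if (!increment(bound)) break; // X was all ones, nothing can follow it
    }
    return found;
}
```

The only state carried from one iteration to the next is the bound itself, so the amount of constraint data stays fixed no matter how many identities have been found.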

Finding the lowest pair of expressions

SAT by itself does not try to find the lowest solution, and neither does the CEGIS-loop built on top of it. But we can use the CEGIS-loop as an oracle to answer the question "is there any solution in (a restricted part of) the search space?", and that lets us do a bitwise binary search: the "find one bit at a time" variant of binary search, typically discussed in a completely different context. Bitwise binary search maps well to SAT, not directly of course, but in the following way:

  1. Initialize the prefix to an empty list.
  2. If the prefix has the same size as a pair of expressions, directly return the result of the CEGIS-oracle.
  3. Extend the prefix with false.
  4. Ask the CEGIS-oracle if there is a solution that starts with the prefix. If there is, go back to step 2.
  5. Otherwise, turn the false at the end of the prefix into true, then go back to step 2.

Asking a SAT solver for a solution that starts with a given prefix is easy, and tends to make the SAT instance easier to solve (especially when the prefix is long). It only involves forcing some variables (the ones covered by the prefix) to be true or false, using single-literal clauses.

A very useful optimization can be done in step 4: when the CEGIS-oracle says there is a solution with the current prefix, the solution may have some extra zeroes after the prefix which we can use directly to extend the prefix for free. That saves a lot of SAT solves when the solutions tend to have a lot of zeroes in them, which in my case they do, due to extensive use of one-hot encoding.
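
To make the search and the free-zeroes optimization concrete without a SAT solver, here is a small self-contained sketch. The oracle is a hypothetical callback that returns any solution starting with the given prefix (or nothing), standing in for the CEGIS-oracle. The names `Bits`, `Oracle` and this simplified `find_lowest` signature are assumptions made for illustration.

```cpp
#include <cstddef>
#include <functional>
#include <optional>
#include <vector>

using Bits = std::vector<bool>; // index 0 is the most significant bit

// Hypothetical oracle: some solution that starts with `prefix`, or nothing.
// It is not required to return the lowest such solution.
using Oracle = std::function<std::optional<Bits>(const Bits& prefix)>;

// Bitwise binary search for the lowest solution of the given width.
std::optional<Bits> find_lowest(std::size_t width, const Oracle& oracle) {
    Bits prefix;
    while (true) {
        if (prefix.size() == width)
            return oracle(prefix); // the prefix is a complete candidate now
        prefix.push_back(false);   // try to keep the next bit at zero
        if (auto sol = oracle(prefix)) {
            // Zeroes that directly follow the prefix in the returned solution
            // can be adopted without asking the oracle again.
            while (prefix.size() < width && !(*sol)[prefix.size()])
                prefix.push_back(false);
        } else {
            // No solution with a zero here, so the bit must be one (if any
            // solution exists at all).
            prefix.pop_back();
            prefix.push_back(true);
        }
    }
}
```

If there is no solution at all, every bit ends up flipped to true and the final oracle call on the full-width prefix reports the failure.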

Encoding the lower bound

There is a neat way to encode that an integer must be greater-than-or-equal-to some given constant, which I haven't seen people talk about (perhaps it's part of the folklore?), using only popcnt(lower_bound) clauses. The idea here is that for every bit that's set in the bound, at least one of the following bits must be set in the solution: that bit itself, or any more-significant bit that is zero in the lower bound. Each set bit takes only one clause to encode: a clause containing the variable that corresponds to the set bit, and the variables corresponding to the zeroes of the lower bound to the left of (more significant than) that set bit. To see why this works: a solution is below the bound exactly when, at the most significant position where the two differ, the bound has a one and the solution has a zero, and in that case the clause for that position has no true literal.
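
The encoding is small enough to check by brute force. The sketch below is solver-independent and uses made-up names (`Bits`, `Clause`, `lower_bound_clauses`, `satisfies`): a clause is just a list of bit positions of which at least one must be set, and index 0 is assumed to be the most significant bit.

```cpp
#include <cstddef>
#include <vector>

using Bits = std::vector<bool>;          // index 0 is the most significant bit
using Clause = std::vector<std::size_t>; // positions of which at least one must be set

// One clause per set bit of the bound: that bit itself, plus every
// more-significant position where the bound has a zero.
std::vector<Clause> lower_bound_clauses(const Bits& bound) {
    std::vector<Clause> clauses;
    Clause zeros_to_the_left;
    for (std::size_t i = 0; i < bound.size(); i++) {
        if (bound[i]) {
            Clause c = zeros_to_the_left;
            c.push_back(i);
            clauses.push_back(c);
        } else {
            zeros_to_the_left.push_back(i);
        }
    }
    return clauses;
}

// True if x sets at least one position of every clause.
bool satisfies(const Bits& x, const std::vector<Clause>& clauses) {
    for (const Clause& c : clauses) {
        bool any = false;
        for (std::size_t pos : c) any = any || x[pos];
        if (!any) return false;
    }
    return true;
}
```

For the 4-bit bound 0101, this produces two clauses: positions {0, 1} for the first set bit, and positions {0, 2, 3} for the second. Comparing `satisfies` against a plain integer comparison for every bound and every value of some small width is an easy sanity check.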

Code

For concreteness, here's how I implemented constraining solutions to conform to the prefix; it's really simple:

// set the prefix (used by binary search)
for (size_t i = 0; i < prefix.size(); i++)
{
    if (prefix[i])
        s.addClause(allprogbits[i]);
    else
        s.addClause(~allprogbits[i]);
}

Here's how I set the lower bound:

// if there is a lower bound (used by binary search), enforce it
if (!lower_bound.empty())
{
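    // index 0 is the most significant bit;
    // cl holds the variables at the positions where the bound is zero so far
    vector<Lit> cl;
    for (size_t i = 0; i < lower_bound.size(); i++) {
        if (lower_bound[i]) {
            // one clause per set bit: this bit, or a more-significant zero of the bound
            cl.push_back(allprogbits[i]);
            s.addClause(cl);
            cl.pop_back();
        }
        else {
            cl.push_back(allprogbits[i]);
        }
    }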
}

And binary search (with the optimization to keep the extra zeroes that the solver gives for free) looks like this:

optional<pair<vector<InstrB>, vector<InstrB>>> find_lowest(vector<bool>& prefix,
      vector<vector<int>>& inputs)
{
    do {
        if (prefix.size() == progbits) {
            return format_progbits(synthesize(prefix, inputs),
                inputcount, lhs_size, rhs_size);
        }
        else {
            prefix.push_back(false);
            auto f = synthesize(prefix, inputs);
            if (f.has_value()) {
                // if the next bits are already zero in the solution, keep them zero
                auto bits = *f;
                while (prefix.size() < progbits && !bits[prefix.size()])
                    prefix.push_back(false);
                continue;
            }

            prefix.pop_back();
            prefix.push_back(true);
        }
    } while (true);
}

The full code of my implementation of this idea is available on GitLab. It's a bit crap but at least it should show any detail that you may still be wondering about. In the code I also constrain expressions to not be "a funny way to write zero" and to not be "a complicated way to do nothing", otherwise a lot of less-interesting identities would be generated.