Declarative folding

I have a proposal for an alternative way to express operational semantics with regard to folding.

While implementing my candidate for poison semantics in MLIR, I quickly ran into the limitations of the way operational semantics are encoded in C++. In particular, I ended up creating a maintenance disaster just to implement these identities in my bit dialect.

I came up with a C++ template solution that may be interesting to others as well. Here’s what it looks like:

static const StaticFolder andFolder(
    // and %x, %x -> %x
    [](Value x, Value y) -> OpFoldResult {
        if (x == y) return x;
        return {};
    },
    // and %x, 0 -> 0
    [](Any, Zero z) -> OpFoldResult { return z; },
    ...);

// andFolder can then be invoked as:
// andFolder(
//    Operation *,
//    ArrayRef<Attribute>,
//    SmallVectorImpl<OpFoldResult> &) -> LogicalResult;

IR Concepts

I’ve already adopted custom IR concepts basically everywhere, i.e., types deriving from Type, Attribute and Value (we now also have TypedValue, although on my SHA it doesn’t seem to participate in LLVM-style casting yet). This lets me eliminate a lot of C++ code, improve maintainability, ensure consistent canonicalization, and sometimes even catch a few bugs. In general, I prefer having the C++ type system check things over passing around plain fancy pointers.

Basically, if interfaces and other concepts are used rigorously everywhere (i.e., if TableGen actually emitted them as C++ types instead of anonymous trait-checking methods), folding is already much cleaner. Still, the least code-duplicating way to implement it seems to be building overload sets that range from OpFoldResult down to the specific attribute concepts. (The OpFoldResult end is needed so that something like and %x, 0, where %x is not a constant, is still foldable.)
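To illustrate the overload-set approach, here is a toy sketch that uses std::visit over a two-alternative variant instead of real MLIR types. Value, IntConst, Operand, FoldResult and foldAnd are hypothetical stand-ins, not MLIR's API; overload resolution simply picks the handler for each combination of operand kinds.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <variant>

// Hypothetical stand-ins (not MLIR types): an operand is either an opaque
// SSA value or a known integer constant.
struct Value { int id; };
struct IntConst { std::int64_t value; };
using Operand = std::variant<Value, IntConst>;

// nullopt means "did not fold"; otherwise the folded operand.
using FoldResult = std::optional<Operand>;

// Common helper that merges several lambdas into one overload set.
template <class... Ts> struct Overloaded : Ts... { using Ts::operator()...; };
template <class... Ts> Overloaded(Ts...) -> Overloaded<Ts...>;

// Fold `and lhs, rhs` by dispatching on the concrete kind of each operand.
FoldResult foldAnd(const Operand &lhs, const Operand &rhs) {
  return std::visit(
      Overloaded{
          // Both constant: compute a new constant.
          [](IntConst a, IntConst b) -> FoldResult {
            return Operand{IntConst{a.value & b.value}};
          },
          // and %x, 0 -> 0, even though %x itself is not a constant.
          [](Value, IntConst c) -> FoldResult {
            if (c.value == 0) return Operand{c};
            return std::nullopt;
          },
          // The mirrored case and 0, %x needs its own overload.
          [](IntConst c, Value) -> FoldResult {
            if (c.value == 0) return Operand{c};
            return std::nullopt;
          },
          // and %x, %x -> %x.
          [](Value a, Value b) -> FoldResult {
            if (a.id == b.id) return Operand{a};
            return std::nullopt;
          },
      },
      lhs, rhs);
}
```

Even in this tiny example, the symmetric zero case is written twice; every such redundant overload is more code to keep in sync.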

Matchers

So, what if we generalize this to arbitrary matchers that bind to something in the fold state during the fold operation? Multiple matchers form a pattern, which is rejected if any of its matchers rejects. Otherwise, the pattern’s body is allowed to attempt folding on the bound values. Of course, we want these patterns to be lambdas, and we don’t want to have to annotate them with any metadata.
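A minimal standalone sketch of the matcher idea, assuming a simplified operand model. Operand, FoldResult, Any, Zeros, LambdaTraits and applyPattern are hypothetical names for illustration, not the actual implementation. The parameter types of a non-generic lambda are recovered from its call operator, each operand is bound by the corresponding type's static match(), and the body only runs if every matcher succeeded.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical fold operand: SSA value identity plus an optional known
// constant (MLIR would use Value and Attribute here).
struct Operand {
  int id;
  std::optional<std::int64_t> constant;
};
using FoldResult = std::optional<Operand>;  // nullopt = "did not fold"

// Matchers: a static match() either binds the operand or rejects it.
struct Any {
  Operand value;
  static std::optional<Any> match(const Operand &o) { return Any{o}; }
};
struct Zeros {
  Operand value;
  static std::optional<Zeros> match(const Operand &o) {
    if (o.constant && *o.constant == 0) return Zeros{o};
    return std::nullopt;
  }
};

// Recover a non-generic lambda's parameter types, so no metadata is needed.
template <typename F>
struct LambdaTraits : LambdaTraits<decltype(&F::operator())> {};
template <typename C, typename R, typename... Args>
struct LambdaTraits<R (C::*)(Args...) const> {
  using Params = std::tuple<std::decay_t<Args>...>;
};

template <typename F, typename... Ms, std::size_t... I>
FoldResult applyImpl(const F &f, const std::vector<Operand> &ops,
                     std::tuple<Ms...> *, std::index_sequence<I...>) {
  // Bind operand I with the I-th parameter's matcher.
  std::tuple<std::optional<Ms>...> bound{Ms::match(ops[I])...};
  // The pattern is rejected if any matcher rejected its operand.
  if (!(std::get<I>(bound).has_value() && ...)) return std::nullopt;
  // Otherwise the body may attempt the fold on the bound values.
  return f(*std::get<I>(bound)...);
}

// Try one pattern against an operation's operands.
template <typename F>
FoldResult applyPattern(const F &f, const std::vector<Operand> &ops) {
  using Params = typename LambdaTraits<F>::Params;
  constexpr std::size_t N = std::tuple_size_v<Params>;
  if (ops.size() != N) return std::nullopt;  // arity mismatch: reject
  return applyImpl(f, ops, static_cast<Params *>(nullptr),
                   std::make_index_sequence<N>{});
}
```

With this, the and %x, 0 identity becomes the lambda [](Any, Zeros z) -> FoldResult { return z.value; }, which applyPattern rejects whenever the second operand is not a known zero. Generic lambdas (auto parameters) are out of scope for this sketch, since their parameter types cannot be recovered this way.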

I designed a mechanism that allows implementing custom matchers or adding this functionality to existing types. It already covers all the basic concepts.
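One plausible shape for such an extension mechanism is a traits-style customization point. In this hypothetical sketch (Operand, MatcherTraits, Ones and matchAs are illustrative names, not the actual design), MatcherTraits<T> defaults to calling a static T::match, a new matcher such as Ones only has to provide that member, and an existing type that cannot be modified (here std::int64_t) gets a specialization instead.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical fold operand: value identity plus an optional constant.
struct Operand {
  int id;
  std::optional<std::int64_t> constant;
};

// Customization point: bind an operand to a T, or reject it.
// By default this defers to a static T::match member.
template <typename T>
struct MatcherTraits {
  static std::optional<T> match(const Operand &o) { return T::match(o); }
};

// A custom matcher written from scratch: an all-ones constant
// (all bits set, i.e. -1 for a signed 64-bit toy constant).
struct Ones {
  std::int64_t value;
  static std::optional<Ones> match(const Operand &o) {
    if (o.constant && *o.constant == -1) return Ones{*o.constant};
    return std::nullopt;
  }
};

// Retrofitting a type we cannot add members to: a plain std::int64_t
// parameter binds any constant operand.
template <>
struct MatcherTraits<std::int64_t> {
  static std::optional<std::int64_t> match(const Operand &o) {
    return o.constant;
  }
};

// What a folder would call for each lambda parameter type T.
template <typename T>
std::optional<T> matchAs(const Operand &o) {
  return MatcherTraits<T>::match(o);
}
```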

StaticFolder

The whole thing is put together by a StaticFolder, which is meant to have static lifetime. It tries the patterns in the order in which they are listed.
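The ordering behavior can be sketched as follows, under simplifying assumptions. This hypothetical StaticFolder stores its patterns in a std::tuple and tries them in declaration order, stopping at the first one that folds. For brevity, each pattern here receives the raw operand list instead of having its parameters bound by matchers, and Operand/FoldResult are simplified stand-ins for MLIR types.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical fold operand and result (nullopt = did not fold).
struct Operand {
  int id;
  std::optional<std::int64_t> constant;
};
using FoldResult = std::optional<Operand>;

// Holds a fixed list of patterns and tries them in the order given.
template <typename... Patterns>
class StaticFolder {
 public:
  explicit StaticFolder(Patterns... patterns)
      : patterns_(std::move(patterns)...) {}

  FoldResult operator()(const std::vector<Operand> &ops) const {
    FoldResult result;
    // The || fold evaluates left to right and short-circuits, so later
    // patterns are skipped once one of them has folded.
    std::apply(
        [&](const auto &...pattern) {
          (void)((result = pattern(ops)).has_value() || ...);
        },
        patterns_);
    return result;
  }

 private:
  std::tuple<Patterns...> patterns_;
};

// Example folder for a binary `and` (assumes exactly two operands).
static const StaticFolder andFolder(
    // and %x, %x -> %x
    [](const std::vector<Operand> &ops) -> FoldResult {
      if (ops[0].id == ops[1].id) return ops[0];
      return std::nullopt;
    },
    // and %x, 0 -> 0
    [](const std::vector<Operand> &ops) -> FoldResult {
      if (ops[1].constant == 0) return ops[1];
      return std::nullopt;
    });
```

Because andFolder is a namespace-scope static, its pattern tuple is constructed once and then shared by every fold call.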

I am quite pleased with the current implementation, although it has not seen much testing yet, and it remains to be seen how it can be improved. I don’t think it provides a good solution for container folding yet; you still have to do that by hand. Let me know what you think and whether you have ideas for improvements. Maybe we can find a way to connect this to real declarative folding coming from TableGen, or something similar, in the future.

I’ve ironed out the bugs in my implementation and made it a lot more flexible. I’ve also completed the migration of my bit dialect to the declarative folders. For a more concrete example, you can look at this file.

The implementation satisfies the same identities as LLVM (unless I screwed up). For example, here is the folder for the select operation:

static const StaticFolder folder(
        // select %c, %x, %x -> %x
        [](Any, Any trueValue, Any falseValue) -> OpFoldResult {
            if (trueValue == falseValue) return trueValue;
            return {};
        },
        // select poison, %t, %f -> poison
        [](AnyPoison, Type type, Value) {
            return ConstOrPoison::get(llvm::cast<BitSequenceLikeType>(type));
        },
        // select false, %t, %f -> %f
        [](Zeros, Any, Any falseValue) { return falseValue; },
        // select true, %t, %f -> %t
        [](Ones, Any trueValue, Any) { return trueValue; },
        // All operands constant (or poison): fold lane by lane.
        [](ConstOrPoison cond,
           ConstOrPoison trueValue,
           ConstOrPoison falseValue) {
            return cond.zip(
                [](const auto &c,
                   const auto &t,
                   const auto &f) -> std::optional<BitSequence> {
                    // A poison condition lane yields a poison lane.
                    if (!c) return poison;
                    return c->isOnes() ? t : f;
                },
                trueValue,
                falseValue,
                trueValue.getType().getElementType());
        });
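To make the lane-wise semantics of the last pattern concrete, here is a standalone sketch that models each lane as std::optional<std::uint64_t>, with std::nullopt standing in for poison. Lane and selectLanes are hypothetical names, and this representation is an assumption for illustration, not necessarily how ConstOrPoison is implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Assumption: each lane holds a bit pattern; std::nullopt marks poison.
using Lane = std::optional<std::uint64_t>;

// Lane-wise select: a poison condition lane yields poison; otherwise the
// selected arm's lane is passed through unchanged (poison included), and
// the unselected arm's lane is ignored.
std::vector<Lane> selectLanes(const std::vector<Lane> &cond,
                              const std::vector<Lane> &trueValue,
                              const std::vector<Lane> &falseValue) {
  assert(cond.size() == trueValue.size() &&
         cond.size() == falseValue.size());
  std::vector<Lane> result;
  result.reserve(cond.size());
  for (std::size_t i = 0; i < cond.size(); ++i) {
    const Lane &c = cond[i];
    if (!c) {
      result.push_back(std::nullopt);  // poison condition -> poison lane
      continue;
    }
    // 1-bit condition: "all ones" is 1 and selects the true arm.
    result.push_back(*c == 1 ? trueValue[i] : falseValue[i]);
  }
  return result;
}
```

For example, with a condition of {1, 0, poison}, the result takes lane 0 from the true arm and lane 1 from the false arm, while lane 2 is poison regardless of what either arm holds there.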