[RFC] Making the constructor of the TransformState class protected?

Problem
Transform Dialect is a significant part of our production compiler infrastructure.
Recently, we have been working on a pass that embeds transform op execution. Inside this pass, we first initialize some state variables, pass them to the transform ops for further processing, and then post-process them in the same pass. We ran into a problem with this interaction between the pass and the transform ops.

Currently, we call applyTransforms to run the transform ops programmatically. Since we also perform analysis on the IR, there is information that we wish to pass into the transform ops as part of the transform state. There seems to be no good way to do this without hacking the upstream code.

Current Methodology
We currently create a CustomTransformState class inheriting from the TransformState class (a simplified example is shown below) that includes extra data members to store the additional information.

We then create a customApplyTransforms function that constructs a CustomTransformState object with the additional information. Constructing a CustomTransformState object invokes the TransformState constructor. Since that constructor is private, we had to change it to protected to make this work.

Inside customApplyTransforms, we then call state.applyTransform, which eventually calls the apply function of each transform op. Each transform op must downcast its state argument from TransformState to CustomTransformState before accessing the additional information.


// schedules and constants are the additional information fields.
class CustomTransformState : public TransformState {
 
  friend LogicalResult customApplyTransforms(
    Operation *, TransformOpInterface, const RaggedArray<MappedValue> &,
    const TransformOptions &,
    DenseMap<Operation *, SmallVector<int>> *schedules,
    DenseMap<Operation*, int> *constants);

private:
  CustomTransformState(Region *region, Operation *payloadRoot,
                 const RaggedArray<MappedValue> &extraMappings = {},
                 const TransformOptions &options = TransformOptions()) :
    TransformState(region, payloadRoot, extraMappings, options) {}

public:
  DenseMap<Operation *, SmallVector<int>> *schedules;
  DenseMap<Operation*, int> *constants;
};

// customApplyTransforms creates a CustomTransformState object.
// schedules and constants are initialized by the pass that invokes customApplyTransforms. 
LogicalResult transform::customApplyTransforms(
    Operation *payloadRoot, TransformOpInterface transform,
    const RaggedArray<MappedValue> &extraMapping,
    const TransformOptions &options,
    DenseMap<Operation *, SmallVector<int>> *schedules,
    DenseMap<Operation *, int> *constants) {

  CustomTransformState state(transform->getParentRegion(), payloadRoot,
                                extraMapping, options);
  state.schedules = schedules;
  state.constants = constants;
  return state.applyTransform(transform).checkAndReport();
}

void CustomPass::runOnOperation() {
    // declare variables
    DenseMap<Operation *, SmallVector<int>> schedules;
    DenseMap<Operation *, int> constants;

    // initialize schedules and constants
    // ...

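    // call the transform ops; the maps are passed by pointer, matching the
    // customApplyTransforms signature above
    if (failed(transform::customApplyTransforms(module, op, extraMappings,
                                                options, &schedules, &constants)))
      return signalPassFailure();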

    // post process schedules and constants
    // ...
}

Other Approaches
We also tried the TransformState extension mechanism, which seems the most relevant. However, it only lets transform ops communicate state among themselves, while we also need to pass information in and out at the pass level.

The change proposed in this RFC is very simple (i.e. a single word change from private->protected :) ) but we’re not sure if there are some design points that we missed. We would be really grateful for guidance and feedback from upstream!

Thank you!

I don’t see any problem with making the TransformState constructor protected. But I think it can also be done with an extension.

  template <typename Ty, typename... Args>
  Ty &TransformState::addExtension(Args &&...args);

  template <typename Ty>
  Ty *TransformState::getExtension();

As you said, you can store all necessary information in a custom extension class. You can call getExtension after all transforms have been applied (to get information about the transforms), or before applying the first transform (to pass information into the transforms). Does that not work for some reason? Generally speaking, I would prefer a composition-based approach (i.e., extensions) over an inheritance-based approach.
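To make the composition-based idea concrete, here is a minimal standalone sketch of a type-indexed extension store. It is only a model: MiniState, ScheduleExtension, applyDemoOp, and runExtensionDemo are hypothetical names, std::map stands in for DenseMap, and none of this is the actual MLIR implementation. It only mirrors the shape of the addExtension/getExtension pair quoted above.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <utility>

// Standalone stand-in for transform::TransformState; NOT the MLIR class.
class MiniState {
public:
  // Common base so extensions of different types share one container.
  struct Extension {
    virtual ~Extension() = default;
  };

  // Constructs an extension of type Ty in place and returns a reference to it.
  template <typename Ty, typename... Args>
  Ty &addExtension(Args &&...args) {
    auto ext = std::make_unique<Ty>(std::forward<Args>(args)...);
    Ty &ref = *ext;
    extensions[std::type_index(typeid(Ty))] = std::move(ext);
    return ref;
  }

  // Returns the extension of type Ty, or nullptr if none was added.
  template <typename Ty>
  Ty *getExtension() {
    auto it = extensions.find(std::type_index(typeid(Ty)));
    if (it == extensions.end())
      return nullptr;
    return static_cast<Ty *>(it->second.get());
  }

private:
  std::map<std::type_index, std::unique_ptr<Extension>> extensions;
};

// Hypothetical extension carrying pass-owned data into the transform ops.
struct ScheduleExtension : MiniState::Extension {
  explicit ScheduleExtension(std::map<int, int> *passConstants)
      : constants(passConstants) {}
  std::map<int, int> *constants; // owned by the pass, not by the state
};

// Stand-in for a transform op's apply(): looks up the extension and
// updates the pass-owned data through it.
inline bool applyDemoOp(MiniState &state) {
  auto *ext = state.getExtension<ScheduleExtension>();
  if (!ext)
    return false;
  (*ext->constants)[0] += 1;
  return true;
}

// Pass-level round trip: populate, run the op, read the result back.
inline int runExtensionDemo() {
  std::map<int, int> constants{{0, 41}};
  MiniState state;
  state.addExtension<ScheduleExtension>(&constants);
  return applyDemoOp(state) ? constants[0] : -1;
}
```

Because each extension is keyed by its own type, two independent extensions can coexist on the same state object, which is the composability argument made above.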

One question I had: why not use an Analysis for querying the state? I don’t know what state you need, nor do I recall whether the transform interpreter allows querying an Analysis (I think it does, but my memory is fuzzy and I haven’t done it myself).

If I were to guess, this may be querying some target attributes, for which others have used an Attribute on some associated op (the attribute could now refer to a Resource too).

(Not to say I don’t like Matthias’s idea or yours, just wondering about usage and feasibility as alternative)

I’d consider adding a mechanism to pre-initialize extensions somehow. It’s a more complex change than just making the constructor protected, but it will compose better (you can have your extension simultaneously with mine, which is much harder with inheritance) as Matthias suggested above.

Thank you for the suggestion!

The entry point to the transform ops is the applyTransforms function, which is a friend of the TransformState class and can therefore create a state object (i.e., access the private constructor) and use state extensions per your suggestion. If my pass wishes to pass extra information into applyTransforms, I would need to either modify applyTransforms to accept extra arguments or create a customApplyTransforms function that is made a friend of the TransformState class. Either way, I would have to modify the upstream code. Since we maintain our own MLIR directory out of tree, modifying upstream MLIR is undesirable, which led to this RFC.

I do believe extensions can be used; we just cannot get to them! :)

I can see the problem now… Making the TransformState constructor public sounds reasonable to me.

A pull request with a very small change: [MLIR][Transform] Make TransformState constructor public by kaitingwang · Pull Request #101186 · llvm/llvm-project · GitHub

And since you already have to modify the code, you might as well do it the most future-proof way.

It is not a one-line change, but barely larger:

diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
index 842e244dcde5..0c8f6437bb43 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
@@ -135,7 +135,8 @@ LogicalResult
 applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
                 const RaggedArray<MappedValue> &extraMapping = {},
                 const TransformOptions &options = TransformOptions(),
-                bool enforceToplevelTransformOp = true);
+                bool enforceToplevelTransformOp = true,
+                function_ref<void (TransformState &)> stateInitializer = nullptr);
 
 /// The state maintained across applications of various ops implementing the
 /// TransformOpInterface. The operations implementing this interface and the
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index f8f85e4615c5..b3978b797275 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1999,7 +1999,8 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
 LogicalResult transform::applyTransforms(
     Operation *payloadRoot, TransformOpInterface transform,
     const RaggedArray<MappedValue> &extraMapping,
-    const TransformOptions &options, bool enforceToplevelTransformOp) {
+    const TransformOptions &options, bool enforceToplevelTransformOp,
+    function_ref<void (TransformState &)> stateInitializer) {
   if (enforceToplevelTransformOp) {
     if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
         transform->getNumOperands() != 0) {
@@ -2013,6 +2014,8 @@ LogicalResult transform::applyTransforms(
 
   TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
                        options);
+  if (stateInitializer)
+    stateInitializer(state);
   return state.applyTransform(transform).checkAndReport();
 }

plus documentation changes and tests.
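For readers who want to see the callback idea in isolation, below is a small standalone model of it. It is a sketch under stated assumptions: MiniState, applyDemoTransform, miniApplyTransforms, and runInitializerDemo are hypothetical stand-ins, and std::function is used in place of llvm::function_ref so the example compiles without MLIR. Only the control flow (construct the state, run the optional initializer, then apply) follows the diff above.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Standalone stand-in for transform::TransformState; NOT the MLIR class.
struct MiniState {
  std::vector<int> *schedule = nullptr; // hypothetical pass-owned data
};

// Stand-in for applying the transform ops: fails if the pass passed
// nothing in, otherwise doubles every schedule entry in place.
inline bool applyDemoTransform(MiniState &state) {
  if (!state.schedule)
    return false;
  for (int &entry : *state.schedule)
    entry *= 2;
  return true;
}

// Same shape as the patched entry point: the state is created internally,
// and the optional callback runs after construction, before application.
inline bool miniApplyTransforms(
    const std::function<void(MiniState &)> &stateInitializer = nullptr) {
  MiniState state;
  if (stateInitializer)
    stateInitializer(state);
  return applyDemoTransform(state);
}

// Pass-level caller: wires its own data into the state via the callback.
inline std::vector<int> runInitializerDemo() {
  std::vector<int> schedule{1, 2, 3};
  bool ok = miniApplyTransforms(
      [&](MiniState &state) { state.schedule = &schedule; });
  if (!ok)
    schedule.clear();
  return schedule;
}
```

The point of the design is that the caller never constructs the state itself, so the constructor can stay private while the pass still gets a hook to populate the state (for example, by adding an extension).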

Thank you for the suggestion! @ftynse

Would you be OK with also adding a stateExporter argument, as below, to hand the updated state back to the pass for further post-processing:

LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
                const RaggedArray<MappedValue> &extraMapping = {},
                const TransformOptions &options = TransformOptions(),
                bool enforceToplevelTransformOp = true,
                function_ref<void (TransformState &)> stateInitializer = nullptr,
                function_ref<LogicalResult (TransformState &)> stateExporter = nullptr) {

   ...
   if (stateInitializer)
     stateInitializer(state);
   if (state.applyTransform(transform).checkAndReport().failed())
     return failure();
   if (stateExporter)
     return stateExporter(state);
   return success();
}  

With both a stateInitializer and a stateExporter, it would definitely work nicely for our use case. I’m more than happy to contribute a testcase for this (which should include a small test pass, a test transform op that updates a test extension, and a stateInitializer and stateExporter).
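As a rough illustration of the round trip, here is a standalone sketch of the initializer/exporter control flow. All names (MiniState, applyDemoTransform, miniApplyTransforms, runRoundTrip) are hypothetical, bool stands in for LogicalResult, and std::function stands in for llvm::function_ref; it is not the code from the pull request. It shows that the exporter runs only when the transform application succeeds and that its result becomes the overall result.

```cpp
#include <cassert>
#include <functional>

// Standalone stand-in for transform::TransformState; NOT the MLIR class.
struct MiniState {
  int counter = 0;            // hypothetical data updated by the transform
  bool failTransform = false; // lets the demo simulate a failing transform
};

// Stand-in for applying the transform ops.
inline bool applyDemoTransform(MiniState &state) {
  if (state.failTransform)
    return false;
  ++state.counter;
  return true;
}

// Initializer runs before application; exporter runs only on success and
// its return value is propagated to the caller.
inline bool miniApplyTransforms(
    const std::function<void(MiniState &)> &stateInitializer = nullptr,
    const std::function<bool(MiniState &)> &stateExporter = nullptr) {
  MiniState state;
  if (stateInitializer)
    stateInitializer(state);
  if (!applyDemoTransform(state))
    return false;
  if (stateExporter)
    return stateExporter(state);
  return true;
}

// Pass-level caller: seeds the state, then copies the updated value out.
// Returns -1 when the transform fails, in which case nothing is exported.
inline int runRoundTrip(int initial, bool forceFailure) {
  int exported = -1;
  bool ok = miniApplyTransforms(
      [&](MiniState &state) {
        state.counter = initial;
        state.failTransform = forceFailure;
      },
      [&](MiniState &state) {
        exported = state.counter;
        return true;
      });
  return ok ? exported : -1;
}
```

Skipping the exporter on failure keeps the pass from post-processing a state that the transform ops left half-updated.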

Sounds good to me!

Would you be OK to take a quick look at the review: [MLIR][Transform] Allow stateInitializer and stateExporter for applyTransforms by kaitingwang · Pull Request #101186 · llvm/llvm-project · GitHub
Thank you!