[RFC] Type System for the Transform Dialect

Proposal

An increasing number of transform dialect operations expect certain properties of the payload operations they process, or even specific classes of operations. For example, most transform operations from the SCF dialect extension expect their operands to be mapped to loops; some structured transform operations only apply to structured payload operations of a certain rank or with certain properties, such as outer parallel iterators. These expectations are currently treated as preconditions of the transformation functions and result in transform failures when they are not met. In the code predating the transform dialect, unmet preconditions are hard errors (return failure()) with little associated debug information, or sometimes even assertions. This leads to a twofold usability problem for the transform dialect: (1) a transformation may fail to apply due to a precondition failure in its implementation without any useful feedback to the user other than “transformation X failed to apply”; (2) the preconditions are not visible in the transform IR: at best they are specified in English in the documentation, at worst they are not documented at all.

This document proposes to address this problem by introducing a type system for the transform dialect. Instead of uniformly using !pdl.operation, each transform dialect handle has a type that captures the relevant properties of the payload ops associated with it. With a sufficiently rich type system, both pre- and postconditions become visible in the transform IR. Verification is performed systematically when the association between payload ops and transform IR handles is established, i.e., in TransformState::set/updatePayloadOps, and does not require any action from the transform op implementation. This provides a homogeneous mechanism for reporting pre/postcondition failures to the user without necessarily modifying the transformation logic. It also enables earlier detection of pre/postcondition mismatches between transformations that feed each other. Part of the verification can be performed statically by checking that transform IR types are compatible in some sense (e.g., a transformation applies to loop-like payload ops, but it can be deduced from the script that the handle does not contain such ops).

Improved verification comes at the cost of potentially duplicating the pre/postcondition checks between transform IR types and the actual transformations, especially when transformations are implemented as utilities callable from outside the transform dialect infrastructure. The cost in terms of code duplication is acceptable because (1) the number of transform IR types is expected to be significantly smaller than the number of transform IR operations, similarly to other dialects, and (2) the actual check can be factored out of the transformation utility into a separate utility function called by both the original function and the transform IR type checker. This may even benefit the implementation of transformations: important preconditions get factored out and checked early, which avoids expensive IR modification rollbacks and generally improves the design. The compile-time cost can be mitigated by providing an option to disable the transform IR type checks when detailed error reporting is not required.
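The factoring in point (2) can be sketched in C++ as follows. This is a minimal illustration under simplified assumptions, not MLIR code: PayloadLoop, hasStaticUnitStep, fullyUnrollLoop and checkStaticUnitStepLoops are hypothetical names. The precondition lives in a single predicate that both the standalone transformation utility and the transform IR type checker call, so the two cannot drift apart, and the utility evaluates it before modifying anything.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified model of a payload loop (not an MLIR class).
struct PayloadLoop {
  int64_t lowerBound;
  int64_t upperBound;
  std::optional<int64_t> step;  // std::nullopt models a non-constant step
};

// The factored-out precondition: the single source of truth shared by the
// transformation utility and the transform IR type checker.
bool hasStaticUnitStep(const PayloadLoop &loop) {
  return loop.step && *loop.step == 1;
}

// Hypothetical standalone transformation utility, usable outside the
// transform dialect. It checks the precondition before touching the IR, so
// no rollback is ever needed. Returns the number of loop body copies a full
// unroll would create, or -1 as a stand-in for `return failure()`.
int64_t fullyUnrollLoop(const PayloadLoop &loop) {
  if (!hasStaticUnitStep(loop)) return -1;
  return loop.upperBound > loop.lowerBound ? loop.upperBound - loop.lowerBound : 0;
}

// Hypothetical transform IR type checker reusing the same predicate. Unlike
// the utility, it names the offending payload op. Empty string: success.
std::string checkStaticUnitStepLoops(const std::vector<PayloadLoop> &payload) {
  for (std::size_t i = 0; i < payload.size(); ++i)
    if (!hasStaticUnitStep(payload[i]))
      return "payload op #" + std::to_string(i) + " does not have a static unit step";
  return "";
}
```

With bounds [0, 8) and a step of 1, the utility reports 8 copies, while a loop whose step is not the constant 1 makes the type checker name the offending payload op instead of producing a bare failure.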

The transform dialect adds support for injecting types via dialect extensions, in the same way as for operations.

In practice, transform dialect types must implement the TransformTypeInterface by defining the DiagnosedSilenceableFailure checkPayload(Location, ArrayRef<Operation *>) method, which checks whether the given payload IR operations satisfy the conditions of the type. This method is called from TransformState::set/updatePayloadOps for all payload operations associated with a handle. It is intentionally called “check” rather than “verify” to indicate that it is not part of the normal IR verification process.
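The mechanism can be modeled with a small self-contained C++ sketch. It intentionally avoids the real MLIR classes: Operation is reduced to its name, CheckResult stands in for DiagnosedSilenceableFailure, the Location argument is dropped, and TransformType, AnyOpType, OpNameType and TransformState are hypothetical models of the proposed TransformTypeInterface, !transform.any_op, !transform.op<[...]> and the transform state. The point it illustrates is that the check runs inside setPayloadOps, so no transform op has to invoke it.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a payload operation, reduced to its name.
struct Operation {
  std::string name;
};

// Hypothetical stand-in for DiagnosedSilenceableFailure: an empty message
// means success, anything else is the diagnostic reported to the user.
struct CheckResult {
  std::string message;
  bool succeeded() const { return message.empty(); }
};

// Hypothetical model of the proposed TransformTypeInterface (Location omitted).
class TransformType {
 public:
  virtual ~TransformType() = default;
  virtual CheckResult checkPayload(const std::vector<const Operation *> &payload) const = 0;
};

// Models !transform.any_op: every payload is accepted.
class AnyOpType : public TransformType {
 public:
  CheckResult checkPayload(const std::vector<const Operation *> &) const override { return {}; }
};

// Models !transform.op<[...]>: every payload op must have one of the names.
class OpNameType : public TransformType {
 public:
  explicit OpNameType(std::vector<std::string> names) : names_(std::move(names)) {}
  CheckResult checkPayload(const std::vector<const Operation *> &payload) const override {
    for (const Operation *op : payload) {
      bool matched = false;
      for (const std::string &n : names_) matched = matched || op->name == n;
      if (!matched) return {"unexpected payload op '" + op->name + "'"};
    }
    return {};
  }

 private:
  std::vector<std::string> names_;
};

// Hypothetical model of TransformState: every association of payload ops with
// a handle runs the handle type's check first, so transform op
// implementations never need to call it themselves.
class TransformState {
 public:
  CheckResult setPayloadOps(int handle, const TransformType &type,
                            std::vector<const Operation *> payload) {
    CheckResult result = type.checkPayload(payload);
    if (result.succeeded()) mapping_[handle] = std::move(payload);
    return result;
  }
  std::size_t numPayloadOps(int handle) const {
    auto it = mapping_.find(handle);
    return it == mapping_.end() ? 0 : it->second.size();
  }

 private:
  std::map<int, std::vector<const Operation *>> mapping_;
};
```

For example, associating a handle typed as OpNameType({"dialect.bar"}) with one foo op and one bar op fails with "unexpected payload op 'dialect.foo'" and leaves the handle unmapped.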

Transform operations specify type constraints on their handle operands to match the transformation preconditions. When a transform operation is applied to the payload, the payload operations associated with its operand handles have already been checked against these preconditions.

Transform operations are strongly encouraged to specify broad type constraints on their results, or even no constraints at all when reasonable. When a specific instance of such an operation is created, its result may be given a more or less tight transform IR type depending on the needs of its consumers. For example, even if a loop tiling operation is known to produce an scf.for loop with fewer than N iterations, different instances of this transform IR operation can define handles typed as just scf.for loops, as “loop-like constructs”, or even as general “operations”. This is a conscious decision to keep the transform IR readable and usable, since it would otherwise require multiple cast-like operations to convert types. Such cast operations would still rely on dynamic condition checks, and they would move the eventual diagnostic away from the operation that failed to satisfy the postconditions expressed in the type to the cast operation.

Example:

transform.with_pdl_patterns {
^bb0(%root: !transform.any_op):
  transform.sequence %root failures(propagate) {
  ^bb1(%seq_root: !transform.any_op):
    // The ODS type constraint on this operation does not prescribe a concrete result
    // type, but this concrete usage indicates that the matched payload is always either
    // "foo" or "bar". Note that in the case of PDL patterns, we may check or infer
    // the matched payload op types, thanks to PDL itself having a declarative spec.
    %0 = transform.pdl_match @pattern in %seq_root
       : !transform.op<["dialect.foo", "dialect.bar"]>
    // The following results in a type error because the op expects a handle
    // to "dialect.bar" only.
    transform.do_something_with_bar %0

    %1 = transform.expand_foo_to_bar %0 : !transform.op<["dialect.bar"]>
    // This is fine because of the expansion above.
    transform.do_something_with_bar %1
  }

  pdl.pattern @pattern : benefit(1) { /* … */ }
}
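One possible way to reject the call on %0 statically is sketched below in C++, assuming that op-name types are ordered by set inclusion and that !transform.any_op is the top element; HandleType, Compatibility and checkOperand are hypothetical names, not part of the proposal. The three-way result separates operands that provably satisfy the constraint from those that provably cannot and from those, like %0 above, whose payload could only be validated dynamically.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <set>
#include <string>

// Hypothetical static view of a handle type: std::nullopt models
// !transform.any_op, a set of names models !transform.op<[...]>.
using HandleType = std::optional<std::set<std::string>>;

enum class Compatibility {
  Compatible,         // every op the handle may hold satisfies the constraint
  NeedsDynamicCheck,  // some ops may satisfy it; only payload checks can tell
  Incompatible        // no op the handle may hold can satisfy the constraint
};

// Decides statically whether a handle of type `provided` may be passed to an
// operand constrained to `required`.
Compatibility checkOperand(const HandleType &provided, const HandleType &required) {
  if (!required) return Compatibility::Compatible;        // constraint is any_op
  if (!provided) return Compatibility::NeedsDynamicCheck;  // handle may hold anything
  // std::set keeps names sorted, as std::includes requires.
  bool subset = std::includes(required->begin(), required->end(),
                              provided->begin(), provided->end());
  if (subset) return Compatibility::Compatible;
  bool overlaps = std::any_of(provided->begin(), provided->end(),
                              [&](const std::string &n) { return required->count(n) > 0; });
  return overlaps ? Compatibility::NeedsDynamicCheck : Compatibility::Incompatible;
}
```

Under this sketch, the verifier would report a type error for both non-Compatible outcomes; the distinction would only refine the diagnostic, e.g., by suggesting a transform.cast when a dynamic check could still succeed.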

Future evolution of the typed transform dialect may include a declarative specification of the constraints on payload ops using interpretable IR, for example PDL or other matching or navigation transform operations. Abstract interpretation over such specifications could then determine whether the pre/postcondition contracts may be broken, and where exactly, without applying the transformations. Furthermore, the effects of a transformation could be described in a similar way to check whether the postcondition specified by a transform IR type holds. This is out of scope for the current proposal.

Specific Types

The following transform IR types are a natural fit for inclusion into the core part of the transform dialect:

  • !transform.any_op - handle to absolutely any payload operation, drop-in replacement for the current use of !pdl.operation;
  • !transform.op<["dialect.foo", "dialect.bar"]> - handle to payload operations of the listed kinds, each given as a string containing the canonical OperationName; this scales automatically to all possible operations;
  • !transform.interface<interface_name> - handle to payload operations implementing the given “core” (living in lib/Interfaces or lib/Dialect/Transform) interface; this uses an explicit enum in the type and doesn’t scale automatically since we don’t have a generic naming mechanism for interfaces anyway.

Transform dialect extensions can introduce new transform IR types around which the transformations in these extensions are designed. For example, the “structured” subset of transform IR operations can introduce a !transform.structured.op transform IR type to materialize the concept of a structured (payload) operation in the transform IR. Such a type can carry transformation-relevant information about the computation, such as the rank of operands, properties of iteration space dimensions (parallelism structure, size relations and divisibility), etc.
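As an illustration, here is a C++ sketch of the payload check such a type could perform, assuming it is parameterized by an optional rank and a minimal number of outer parallel iterators (one of the properties mentioned at the beginning of this proposal). StructuredPayloadOp, IteratorType and StructuredOpType are hypothetical simplifications, not Linalg classes.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

enum class IteratorType { Parallel, Reduction };

// Hypothetical stand-in for a structured payload op (e.g., a linalg op).
struct StructuredPayloadOp {
  std::string name;
  std::vector<IteratorType> iterators;  // one entry per loop; size == rank
};

// Hypothetical model of a !transform.structured.op type parameterized by an
// optional rank and a minimal number of outer parallel iterators.
struct StructuredOpType {
  std::optional<std::size_t> rank;  // std::nullopt models "any rank"
  std::size_t minOuterParallel = 0;

  // Returns an empty string on success, otherwise a diagnostic for the first
  // payload op that violates the type's conditions.
  std::string checkPayload(const std::vector<StructuredPayloadOp> &payload) const {
    for (const StructuredPayloadOp &op : payload) {
      if (rank && op.iterators.size() != *rank)
        return "'" + op.name + "' has rank " + std::to_string(op.iterators.size()) +
               ", expected " + std::to_string(*rank);
      // Count the leading run of parallel iterators.
      std::size_t outer = 0;
      while (outer < op.iterators.size() && op.iterators[outer] == IteratorType::Parallel)
        ++outer;
      if (outer < minOuterParallel)
        return "'" + op.name + "' has " + std::to_string(outer) +
               " outer parallel iterators, expected at least " +
               std::to_string(minOuterParallel);
    }
    return "";
  }
};
```

For a linalg.generic-like op with iterators (parallel, parallel, reduction), a type requiring rank 3 and two outer parallel iterators accepts it, while a type requiring three outer parallel iterators reports that only two are present.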

Several open questions remain in the transform IR type system:

  • Should it be used to indicate that a handle is guaranteed to point to exactly one payload operation (block arguments of sequence, foreach)?
    This is currently a minority use case and most known ops do not rely on there being a single op, so the value of capturing it in the type system is small. However, it is technically possible to implement given the current interface.
  • Will we need logic on payload operation constraints expressed by transform IR types? For example, a transformation may require that the payload operation implement both interface A and interface B, which may benefit from being expressed as a transform IR type. While this can be a separate type, we can also consider introducing constraint-conjunction and eventually disjunction types that combine existing transform IR types into one: !transform.and<interface<A>, interface<B>>. It is unclear whether simple logic would be sufficient or whether this would need to be expanded into some additional meta-IR defining types. For now, both sound like avoidable premature generalization.
  • Introducing types may result in tighter coupling between transform dialect extensions that are currently separated. Specifically, an extension A may use a type defined by another extension B. This is technically solvable by always loading A alongside B, but may bring cross-dialect dependencies that wouldn’t be necessary otherwise (e.g., the vector transform extension may use the structured transform IR type, but a dependency of the vector dialect on Linalg may not be desirable). For now, we can assume that such cross-extension types indicate that the concept is ripe for promotion to the “core” of the transform dialect or, preferably, to an interface in lib/Interfaces.
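Should conjunction types be introduced, their checker could simply run the component checks in order. The following C++ sketch models !transform.and<interface<A>, interface<B>> this way; PayloadOp, TransformType, InterfaceType and AndType are hypothetical, and interfaces are identified by plain strings rather than by the explicit enum proposed for !transform.interface.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical payload op: its name and the interfaces it implements.
struct PayloadOp {
  std::string name;
  std::set<std::string> interfaces;
};

// Hypothetical base for transform IR types; empty string means success.
struct TransformType {
  virtual ~TransformType() = default;
  virtual std::string checkPayload(const std::vector<PayloadOp> &payload) const = 0;
};

// Models !transform.interface<I>: every payload op must implement I.
struct InterfaceType : TransformType {
  explicit InterfaceType(std::string name) : iface(std::move(name)) {}
  std::string checkPayload(const std::vector<PayloadOp> &payload) const override {
    for (const PayloadOp &op : payload)
      if (!op.interfaces.count(iface)) return "'" + op.name + "' does not implement " + iface;
    return "";
  }
  std::string iface;
};

// Models the hypothetical !transform.and<...>: the payload must satisfy every
// component type; the first failing component provides the diagnostic.
struct AndType : TransformType {
  explicit AndType(std::vector<std::shared_ptr<TransformType>> components)
      : parts(std::move(components)) {}
  std::string checkPayload(const std::vector<PayloadOp> &payload) const override {
    for (const auto &part : parts) {
      std::string diag = part->checkPayload(payload);
      if (!diag.empty()) return diag;
    }
    return "";
  }
  std::vector<std::shared_ptr<TransformType>> parts;
};
```

Because each component keeps its own diagnostic, the conjunction reports which of the combined conditions failed, not just that the combined type was not satisfied.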

Cast Operation

The core part of the transform dialect provides a single transform.cast operation that associates its result handles with the payload ops of its operand handles and has no effect on the payload IR. Since this association goes through TransformState, the checks of the new handle type are automatically run on the payload IR. The operation needs no knowledge of extension-specific types and delegates to them through the type interface.
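A minimal C++ sketch of these semantics follows, assuming handles are tracked in a plain map and that a transform IR type can be reduced to its payload check; Operation, PayloadCheck, HandleMap, applyCast and opNamed are hypothetical names. The cast binds the operand's payload to the result only if the result type's check succeeds, and it never touches the payload ops themselves.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical payload op, reduced to its name.
struct Operation {
  std::string name;
};

// A transform IR type reduced to its payload check; empty string = success.
using PayloadCheck = std::function<std::string(const std::vector<const Operation *> &)>;

// Hypothetical handle -> payload mapping, standing in for TransformState.
using HandleMap = std::map<std::string, std::vector<const Operation *>>;

// Models transform.cast: binds the operand's payload to the result handle
// after running the result type's check. The payload ops are only read, and
// the cast knows nothing about the type beyond its check.
std::string applyCast(HandleMap &handles, const std::string &operand,
                      const std::string &result, const PayloadCheck &resultType) {
  const std::vector<const Operation *> &payload = handles.at(operand);
  std::string diag = resultType(payload);
  if (diag.empty()) handles[result] = payload;  // std::map keeps references valid
  return diag;
}

// Models !transform.op<["<expected>"]> for a single operation name.
PayloadCheck opNamed(std::string expected) {
  return [expected](const std::vector<const Operation *> &payload) -> std::string {
    for (const Operation *op : payload)
      if (op->name != expected) return "'" + op->name + "' is not '" + expected + "'";
    return "";
  };
}
```

Casting a handle to one scf.for loop to an scf.for-typed handle succeeds, whereas casting a handle that also holds an scf.while loop reports the scf.while op and creates no result association.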

Design Implications

Introducing the type system makes the transform dialect more verbose, but arguably the transform IR becomes easier to understand and maintain. The dialect, like any IR, is not intended to have language-like expressive power; clients seeking brevity should use the Python bindings for the dialect or build a DSL.

Transform IR types make explicit static information about payload operations, potentially including information that isn’t explicit in the payload IR itself. This information may come from the postconditions of some transformations. For example, tiling always produces iteration space tiles whose size is less than or equal to the tile size, even when the tile size does not perfectly divide the original iteration space or when the size of that space is unknown. This information may be relevant, e.g., for vectorization, which can completely avoid vector splitting for sufficiently small sizes.
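The tiling postcondition can be checked on a small worked example. The C++ sketch below, with hypothetical function names, computes the tile sizes produced when a one-dimensional iteration space of extent iterations is tiled by tileSize, and verifies that none of them exceeds tileSize.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Sizes of the tiles obtained by tiling a 1-D iteration space of `extent`
// iterations with tile size `tileSize`; the last tile is partial when
// tileSize does not divide extent.
std::vector<int64_t> tileSizes(int64_t extent, int64_t tileSize) {
  assert(tileSize > 0 && extent >= 0);
  std::vector<int64_t> sizes;
  for (int64_t start = 0; start < extent; start += tileSize)
    sizes.push_back(std::min(tileSize, extent - start));
  return sizes;
}

// The postcondition a tiling transform could record in its result type:
// every tile has at most `tileSize` iterations, whatever the extent.
bool satisfiesTilePostcondition(const std::vector<int64_t> &sizes, int64_t tileSize) {
  return std::all_of(sizes.begin(), sizes.end(),
                     [&](int64_t s) { return s <= tileSize; });
}
```

With an extent of 10 and a tile size of 4, the tiles have 4, 4 and 2 iterations. A handle type recording only the bound 4 is enough for a consumer such as the vectorizer to know that, with vectors of at least 4 lanes, no tile needs to be split across several vectors, even if the extent is unknown when the transform script is written.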

This also opens the way to more abstract reasoning about MLIR transformations that are more complex than the simple replacements on groups of ops currently supported by PDL(L), and enables research on transformation analysis, abstract interpretation and proofs, provided the pre/postconditions of transformations are sufficiently expressive.

Example

Partial example taken from the IREE softmax transformation, annotated with types:

transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%variant_op: !transform.any_op):
  // First level of tiling + fusion parallelizes to blocks.
  // The mapping to block ids can only happen after bufferization at the moment.
  %root = transform.structured.match interface{LinalgOp}
    attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
    : !transform.structured.op<rank=3>
  %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
    : !transform.structured.op<any>
  %red = transform.structured.match interface{LinalgOp}
    attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
    : !transform.structured.op<any>  // less strict type on purpose
  %not_root = transform.merge_handles %fill, %red
    : !transform.structured.op<any>
  %foreach_thread, %tiled_generic =
    transform.iree.tile_to_foreach_thread_and_workgroup_count_region %root tile_sizes [1, 4]
    : !transform.op<["scf.foreach_thread"]>, !transform.op<["linalg.generic"]>
  transform.structured.fuse_into_containing_op %not_root into %foreach_thread
  // … (remainder of the script omitted)
}
