[MLIR][RFC] Semantics of `linalg.reduce` with variadic operands

linalg.reduce currently has variadic input/output operands, like any DPS linalg op. The verifier implementation, the documentation, and the tests clearly show that the authors intended to support variadic operands in some form. However, there is confusion about what a reduce with several operands should mean. In current linalg, this op also has the trait SameVariadicOperandSize, which adds to the confusion. This RFC aims to discuss:

  1. Whether linalg.reduce should even allow variadic operands, and if yes,
  2. What the semantics of the op are when it has several operands.

@rengolin @MaheshRavishankar

Current behavior

Current linalg.reduce allows several inputs and outputs. In effect, it allows you to write, e.g.:

linalg.reduce
  ins(%input, %input2 : memref<16x32x64xi32>, memref<16x32x64xi32>)
  outs(%init, %init2 : memref<16x64xi32>, memref<16x64xi32>)
  dimensions = [1]
  (%in: i32, %in2: i32, %out: i32, %out2: i32) {
    %0 = arith.muli %in, %in2 : i32
    %1 = arith.addi %out, %0 : i32
    linalg.yield %1, %1 : i32, i32
  }

The current semantics of this code can be seen more clearly when lowering it, for instance, to affine:

affine.for %i = 0 to 16 {
  affine.for %k = 0 to 32 {
    affine.for %j = 0 to 64 {
      %4 = affine.load %input[%i, %k, %j] : memref<16x32x64xi32>
      %5 = affine.load %input2[%i, %k, %j] : memref<16x32x64xi32>
      %6 = affine.load %init[%i, %j] : memref<16x64xi32>
      %7 = affine.load %init2[%i, %j] : memref<16x64xi32>
      %8 = arith.muli %4, %5 : i32
      %9 = arith.addi %6, %8 : i32
      affine.store %9, %init[%i, %j] : memref<16x64xi32>
      affine.store %9, %init2[%i, %j] : memref<16x64xi32>
    }
  }
}

So, in this case, the op's current semantics are that you can reduce over the same dimensions of two or more same-shaped tensors in a single loop nest. That seems well defined.
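For comparison, here is an untested sketch of what should be the equivalent linalg.generic. It spells out the indexing maps and iterator types that the reduce form leaves implicit. The function name @reduce_as_generic and the map aliases #id and #red are made up for illustration.

```mlir
#id  = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#red = affine_map<(d0, d1, d2) -> (d0, d2)>

// d0 = 16, d1 = 32 (the reduced dimension, dimensions = [1]), d2 = 64.
func.func @reduce_as_generic(%input: memref<16x32x64xi32>,
                             %input2: memref<16x32x64xi32>,
                             %init: memref<16x64xi32>,
                             %init2: memref<16x64xi32>) {
  linalg.generic {
      indexing_maps = [#id, #id, #red, #red],
      iterator_types = ["parallel", "reduction", "parallel"]}
    ins(%input, %input2 : memref<16x32x64xi32>, memref<16x32x64xi32>)
    outs(%init, %init2 : memref<16x64xi32>, memref<16x64xi32>) {
  ^bb0(%in: i32, %in2: i32, %out: i32, %out2: i32):
    // %out2 is never read: both outputs receive the accumulator of %init.
    %0 = arith.muli %in, %in2 : i32
    %1 = arith.addi %out, %0 : i32
    linalg.yield %1, %1 : i32, i32
  }
  return
}
```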

But there is a catch. The op requires the same number of inputs and outputs. That means that if you want to contract two tensors as above, you can use linalg.reduce, but you need a bogus second output that you then ignore. This is IMO an accident and not the intended behavior.
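A sketch of the form that lifting this restriction would allow: two inputs and a single output, with no bogus second init. Given the restriction described above, the current verifier is expected to reject it. @contract_two_inputs is a hypothetical name used only for illustration.

```mlir
func.func @contract_two_inputs(%input: memref<16x32x64xi32>,
                               %input2: memref<16x32x64xi32>,
                               %init: memref<16x64xi32>) {
  // Two inputs, one output. Only valid if the
  // "num inputs == num outputs" restriction is lifted.
  linalg.reduce
    ins(%input, %input2 : memref<16x32x64xi32>, memref<16x32x64xi32>)
    outs(%init : memref<16x64xi32>)
    dimensions = [1]
    (%in: i32, %in2: i32, %out: i32) {
      %0 = arith.muli %in, %in2 : i32
      %1 = arith.addi %out, %0 : i32
      linalg.yield %1 : i32
    }
  return
}
```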

Ways forward

In #107005 I argued that this restriction is accidental and removed it. The change was pushed back because there is no consensus on whether linalg.reduce should allow several operands in the first place.

In any case there seems to be something wrong with the semantics of this op. Either

  1. The accident is SameVariadicOperandSize: the op should support variadic operands, but without the restriction that the number of outputs equals the number of inputs.
  2. Or the accident is that the op supports variadic operands at all, and it should be unary.
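Under option 2, every linalg.reduce would take the unary shape sketched below: one input, one init, and a two-argument combiner. @unary_reduce is a hypothetical name. Multi-input cases like the first example would then have to be written with linalg.generic, or with the proposed linalg.contract where it applies.

```mlir
func.func @unary_reduce(%input: memref<16x32x64xi32>,
                        %init: memref<16x64xi32>) {
  // Exactly one input and one init: the combiner takes one input element
  // and the running accumulator.
  linalg.reduce
    ins(%input : memref<16x32x64xi32>)
    outs(%init : memref<16x64xi32>)
    dimensions = [1]
    (%in: i32, %out: i32) {
      %0 = arith.addi %out, %in : i32
      linalg.yield %0 : i32
    }
  return
}
```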

Option 1 seems, to me at least, to fit what the original authors intended. Option 2 makes reduce more restrictive, but it may fit better in the new tree-like design of this RFC.

As far as I can tell, with option 1, a reduce with 2 inputs and one output is similar to the proposed linalg.contract, but more general, as it is not restricted to a combination of products of the input elements. It could be, e.g., a combination of sums, or anything else. Such a powerful op would seem to fit between generic and contract in the abstraction tree. However, it is probably useful to keep an operator for simple reductions over one dimension, as that seems to be how linalg.reduce is used in the wild (based on a search on GitHub). So IMHO, a good compromise would be:

  • restrict linalg.reduce to be unary;
  • in the future, if needed, introduce a new operator between generic and contract that supports an arbitrary combine/reduce region.
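Until such an operator exists, an arbitrary combine/reduce region can already be written with linalg.generic. The untested sketch below reduces over dimension 1 with max(acc, a + b), which is neither a unary reduction nor a product-based contraction. The function name, the map aliases, and the assumption that %init is pre-filled with the smallest i32 value (the identity of arith.maxsi) are all illustrative.

```mlir
#id  = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#red = affine_map<(d0, d1, d2) -> (d0, d2)>

// Reduces over d1 with max(acc, a + b).
// Assumes %init already holds the smallest i32 value in every element.
func.func @max_of_sums(%a: memref<16x32x64xi32>,
                       %b: memref<16x32x64xi32>,
                       %init: memref<16x64xi32>) {
  linalg.generic {
      indexing_maps = [#id, #id, #red],
      iterator_types = ["parallel", "reduction", "parallel"]}
    ins(%a, %b : memref<16x32x64xi32>, memref<16x32x64xi32>)
    outs(%init : memref<16x64xi32>) {
  ^bb0(%x: i32, %y: i32, %acc: i32):
    %0 = arith.addi %x, %y : i32
    %1 = arith.maxsi %acc, %0 : i32
    linalg.yield %1 : i32
  }
  return
}
```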

Implementation

  • PR #107005 implemented option 1 above, but has been rejected for now.
  • PR #119871 sets the current behavior in stone by adding a test and fixing the verifier.
  • If we want to make linalg.reduce unary, the change might be more involved. For instance, we may need to provide a new builder that takes a single Value for the input and one for the output.

Thank you for the very detailed RFC!

Agree. Even if someone is relying on this “feature”, it could be dangerous. I may be wrong, but this looks like an eager over-generalization without much thought behind it.

We discussed in the contract RFC that we do want to support that case, but not now, as there are no real users yet and it adds complexity. We’ll leave it as a TODO for now, but will definitely add it as soon as the op is stable and its semantics are agreed upon.

Agree! It can even be contract itself that supports it. But we can discuss that when time comes.

Unless there’s a real use case (one that cannot be solved in other ways), it should mostly be a matter of updating the builders and the tests.