Hi,
it seems like I found a really horrible bug today. Let me first describe how to reproduce it in a concrete test case, before I get into my analysis.
Reproducing the issue
Consider the following (quite stupid) implementation of matrix multiplication using linalg.map by computing dot products and reducing them with linalg.reduce.
func.func @matmul(%A: tensor<3x4xf64>, %B: tensor<4x3xf64>) -> tensor<3x3xf64> {
  %init = tensor.empty() : tensor<3x3xf64>
  %mm = linalg.map outs(%init: tensor<3x3xf64>) (){
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %init2 = tensor.empty() : tensor<3xf64>
    %mv = linalg.map outs(%init2: tensor<3xf64>) (){
      %k = linalg.index 0 : index
      %a = tensor.extract %A[%i, %k] : tensor<3x4xf64>
      %b = tensor.extract %B[%k, %j] : tensor<4x3xf64>
      %c = arith.mulf %a, %b : f64
      linalg.yield %c : f64
    }
    %init3 = arith.constant dense<0.0> : tensor<f64>
    %reduce = linalg.reduce { arith.addf }
      ins(%mv: tensor<3xf64>)
      outs(%init3: tensor<f64>)
      dimensions = [0]
    %out = tensor.extract %reduce[] : tensor<f64>
    linalg.yield %out : f64
  }
  return %mm : tensor<3x3xf64>
}
If you run -cse on this, you get the following wrong program:
module {
  func.func @matmul(%arg0: tensor<3x4xf64>, %arg1: tensor<4x3xf64>) -> tensor<3x3xf64> {
    %0 = tensor.empty() : tensor<3x3xf64>
    %mapped = linalg.map outs(%0 : tensor<3x3xf64>)
      () {
        %1 = linalg.index 0 : index
        %2 = linalg.index 1 : index
        %3 = tensor.empty() : tensor<3xf64>
        %mapped_0 = linalg.map outs(%3 : tensor<3xf64>)
          () {
            %extracted_1 = tensor.extract %arg0[%1, %1] : tensor<3x4xf64>
            %extracted_2 = tensor.extract %arg1[%1, %2] : tensor<4x3xf64>
            %4 = arith.mulf %extracted_1, %extracted_2 : f64
            linalg.yield %4 : f64
          }
        %cst = arith.constant dense<0.000000e+00> : tensor<f64>
        %reduced = linalg.reduce { arith.addf } ins(%mapped_0 : tensor<3xf64>) outs(%cst : tensor<f64>) dimensions = [0]
        %extracted = tensor.extract %reduced[] : tensor<f64>
        linalg.yield %extracted : f64
      }
    return %mapped : tensor<3x3xf64>
  }
}
The program's result is now different, because cse has substituted %k with %i, which does not preserve the semantics. This is not the expected behavior, as the documentation for linalg.index states:
The linalg.index operation returns the iteration index of the immediately enclosing linalg structured operation for the iteration dimension dim. The dim attribute specifies the position of the accessed dimension in the indexing map domain.
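To make the difference concrete, here is a small plain-Python model of what the two listings compute. It does not involve MLIR at all; the input values and the function names before_cse and after_cse are made up for illustration. It mirrors the listings as written, so the inner map iterates over three elements (its output is tensor<3xf64>).

```python
# Plain-Python model of the two listings; the values are arbitrary test data.
A = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0],
     [9.0, 10.0, 11.0, 12.0]]   # plays the role of %A : tensor<3x4xf64>
B = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0],
     [3.0, 0.0, 1.0],
     [0.0, 2.0, 0.0]]           # plays the role of %B : tensor<4x3xf64>

INNER = 3  # extent of the inner linalg.map (tensor<3xf64> in the listing)


def before_cse(a, b):
    # %k is the index of the inner map, so it varies over the inner iteration.
    return [[sum(a[i][k] * b[k][j] for k in range(INNER))
             for j in range(3)] for i in range(3)]


def after_cse(a, b):
    # %k was replaced by %i, so every inner element is a[i][i] * b[i][j].
    return [[sum(a[i][i] * b[i][j] for _ in range(INNER))
             for j in range(3)] for i in range(3)]


if __name__ == "__main__":
    print(before_cse(A, B))
    print(after_cse(A, B))
```

In this model, each entry of the transformed program collapses to three times a single product, which is clearly not the dot product the original computes.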
Tracing the issue
In my opinion, the root cause of this issue lies in how OperationEquivalence decides that the two linalg.index operations are equivalent: they are OpState equal (see OperationEquivalence::isEquivalentTo). Since one dominates the other and is in scope, the CSEDriver replaces the dominated one with it.
OpState equivalence does not imply semantic equivalence, as operation semantics in MLIR may depend on structural properties, such as parent-child relationships (as is the case here) or sibling-sibling relationships (think of a struct definition where the fields are not control flow but still ordered).
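The mechanism can be sketched with a toy model. This is not MLIR's implementation: the Op class, structural_key and naive_cse are hypothetical stand-ins that assume only what is described above, namely that equivalence is decided from the op's own state (name, attributes, operands) and that a dominating op in an enclosing scope is a valid replacement.

```python
class Op:
    """Toy operation: a name, attributes, operands, and one nested body."""

    def __init__(self, name, attrs=(), operands=(), body=()):
        self.name = name
        self.attrs = tuple(attrs)
        self.operands = list(operands)
        self.body = list(body)


def structural_key(op):
    # Name, attributes, and operand identities only. The op's position in
    # the nesting structure is deliberately not part of the key.
    return (op.name, op.attrs, tuple(id(v) for v in op.operands))


def naive_cse(ops, known=None, replaced=None):
    known = dict(known or {})      # nested bodies inherit the outer entries
    replaced = {} if replaced is None else replaced
    for op in ops:
        # Rewrite operands that refer to ops already marked as replaced.
        op.operands = [replaced.get(v, v) for v in op.operands]
        if op.body:                # ops with bodies are only recursed into
            naive_cse(op.body, known, replaced)
            continue
        key = structural_key(op)
        if key in known:
            replaced[op] = known[key]
        else:
            known[key] = op
    return replaced


# Same shape as the test case: %i and %j in the outer map, %k in the inner.
i = Op("linalg.index", attrs=[("dim", 0)])
j = Op("linalg.index", attrs=[("dim", 1)])
k = Op("linalg.index", attrs=[("dim", 0)])
a = Op("tensor.extract", attrs=[("source", "A")], operands=[i, k])
inner = Op("linalg.map", body=[k, a])
outer = Op("linalg.map", body=[i, j, inner])

replaced = naive_cse([outer])

if __name__ == "__main__":
    print(replaced[k] is i)
    print([v is i for v in a.operands])
```

In this toy, %k and %i get the same key because nothing in the key records which linalg.map encloses them, so the extract ends up indexing with %i twice, just like in the -cse output.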
Fixing the issue
I hope that you all agree that this is a bug and not intended behavior. I can see how the way forward to fix it might be controversial. I will give my own point of view first.
Let me begin by saying that, under the current status quo, I don’t believe linalg is in the wrong here. Declaring the operation as [Pure] is semantically sound in my opinion. Introducing some new SSA dependency on the parent is clearly against the spirit of the abstraction, too.
I have two points I’d like to address:
- OperationEquivalence should respect structural semantics.

  To me, MLIR not having mandatory traits that indicate operation “well-behaved-ness” is a huge issue. I do want to have safe “metadata-like”/“non-control-flow-like” operations, and others that have semantics that depend on their surrounding structure. If such a trait, e.g. OpTrait::Structural, was introduced, OperationEquivalence could use it.

  For the particular case of linalg.index, however, we do want CSE to collapse repeated uses in the same scope. In this special case, the actual semantics are along the lines of “this op is structurally dependent on its parent”. A trait that reflects this relationship, e.g. OpTrait::Scoped<Parent>, could be introduced, which would require OperationEquivalence to check that both ops have the same first ancestor of the given op constraint.

  This means that linalg.index and other ops with these semantics will need to explicitly declare them going forward, but I think this is desirable.

- CSEDriver should be more conservative.

  Looking at the current implementation (see CSEDriver::simplifyOperation), the CSEDriver only makes a special exception for ops with memory side effects. I think this is way too optimistic.

  The CSEDriver should probably bail on any op that has side-effects it does not recognize. In addition, CSE-ing out non-speculatable ops may change the observed program behavior, and there seems to be no guard against that either.
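As a rough illustration of the first point, the Scoped<Parent> idea could look like this in a toy model. Everything here is hypothetical: neither the trait nor the SCOPED_TO registry exists in MLIR, and using linalg.map as the only parent constraint is a simplification for this example.

```python
class Op:
    """Toy operation with a parent pointer, set when it is placed in a body."""

    def __init__(self, name, attrs=(), body=()):
        self.name = name
        self.attrs = tuple(attrs)
        self.body = list(body)
        self.parent = None
        for child in self.body:
            child.parent = self


# Hypothetical stand-in for OpTrait::Scoped<Parent>: op name -> parent name.
SCOPED_TO = {"linalg.index": "linalg.map"}


def nearest_ancestor(op, parent_name):
    # Walk up until the first enclosing op that satisfies the constraint.
    current = op.parent
    while current is not None and current.name != parent_name:
        current = current.parent
    return current


def equivalence_key(op):
    base = (op.name, op.attrs)
    parent_name = SCOPED_TO.get(op.name)
    if parent_name is None:
        return base                # unscoped ops keep the old behavior
    # Scoped ops are only equivalent under the same first matching ancestor.
    return base + (id(nearest_ancestor(op, parent_name)),)


i = Op("linalg.index", attrs=[("dim", 0)])
i_again = Op("linalg.index", attrs=[("dim", 0)])   # repeated use, same scope
k = Op("linalg.index", attrs=[("dim", 0)])
inner = Op("linalg.map", body=[k])
outer = Op("linalg.map", body=[i, i_again, inner])

if __name__ == "__main__":
    print(equivalence_key(i) == equivalence_key(i_again))
    print(equivalence_key(i) == equivalence_key(k))
```

With such a key, repeated linalg.index ops inside the same map still collapse, while the one in the nested map is kept apart from the one in the outer map.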