Does the Linalg dialect support fusion-on-memrefs?

Hello everyone.

The Linalg dialect supports fusion with tensor semantics, and even mixed tensor/memref fusion, as demonstrated in the following example (from mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir):

// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
#map0 = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: @mixed_fusion
func.func @mixed_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, %arg8 : memref<?x?xf32>)
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%2 : tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
      %4 = arith.addf %arg3, %arg4 : f32
      linalg.yield %4 : f32
  } -> tensor<?x?xf32>
  // CHECK: linalg.generic {
  // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
  linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
      ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arg8 : memref<?x?xf32>) {
    // CHECK: ^{{[a-zA-Z0-9_]*}}
    // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
    // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
    // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
      // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]]
      // CHECK-NOT: linalg.yield
      // CHECK: arith.mulf [[T1]], [[ARG2]]
      // CHECK: linalg.yield
      %5 = arith.mulf %arg5, %arg6 : f32
      linalg.yield %5 : f32
    }
  return
}

Does Linalg support fusion for memref-only semantics? I saw a post mentioning that fusion for memref semantics was removed. I am implementing a DSL based on MLIR that requires a container for elements of custom types. Since tensor does not support custom element types, memref is used instead. The DSL's IR is lowered to Linalg to leverage optimizations such as fusion and vectorization, but I am running into issues with memref-based fusion. What is the best practice in this scenario?

We removed it because, to work properly, it needs to be backed by a strong memory dependency analysis, which was not a priority.
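For instance, here is a minimal sketch (all names are made up) of the kind of situation such an analysis has to resolve: the producer writes %buf, the consumer reads it, and the store in between may or may not alias %buf. With tensor semantics the dependency is explicit in the SSA use-def chain; on memrefs, fusing the two generics is only legal if the analysis can prove %other and %buf do not alias.

#map_id = affine_map<(d0) -> (d0)>
func.func @needs_alias_analysis(%a: memref<?xf32>, %buf: memref<?xf32>,
                                %other: memref<?xf32>, %res: memref<?xf32>,
                                %c: f32, %i: index) {
  // Producer: writes every element of %buf.
  linalg.generic {indexing_maps = [#map_id, #map_id], iterator_types = ["parallel"]}
      ins(%a : memref<?xf32>) outs(%buf : memref<?xf32>) {
  ^bb0(%in: f32, %out: f32):
    %0 = arith.addf %in, %in : f32
    linalg.yield %0 : f32
  }
  // May clobber %buf if %other aliases it.
  memref.store %c, %other[%i] : memref<?xf32>
  // Consumer: reads %buf; fusing it into the producer is only valid if the
  // intervening store provably does not touch %buf.
  linalg.generic {indexing_maps = [#map_id, #map_id], iterator_types = ["parallel"]}
      ins(%buf : memref<?xf32>) outs(%res : memref<?xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.mulf %in, %in : f32
    linalg.yield %1 : f32
  }
  return
}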

I think someone close to @TobiasGrosser's team mentioned they had such an analysis?

Thanks for your reply. Can I ask what the best practice would be in this scenario if fusion-on-memrefs is not available?

Use a patched fork of MLIR that supports custom element types in tensor. You could also consider proposing an RFC/patch to enable support for custom element types in tensor, the same way they are supported in memref.

That is a good suggestion. What does everyone think? Extending tensor semantics this way shouldn't cause compatibility issues, right? Could it potentially break assumptions made by certain passes? I can give it a try.

Please start a new topic with an appropriate title and a specific proposal.

I tried simply making MemRefType and TensorType support all element types.

/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  // return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
  //                  IndexType>(type) ||
  //        !llvm::isa<BuiltinDialect>(type.getDialect());
  return true;
}

inline bool BaseMemRefType::isValidElementType(Type type) {
  // return type.isIntOrIndexOrFloat() ||
  //        llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
  //            type) ||
  //        llvm::isa<MemRefElementTypeInterface>(type);
  return true;
}

However, it seems that the TensorBufferizePass no longer works.

Test mlir file: test.mlir

func.func private @printMemrefF32(tensor<*xf32>)
func.func @main() {
  %0 = tensor.empty() : tensor<3xtensor<3xf32>>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
                                        iterator_types = ["parallel"]}
  outs(%0 : tensor<3xtensor<3xf32>>) {
  ^bb0(%out0: tensor<3xf32>):
    %2 = tensor.empty() : tensor<3xf32>
    %fzero = arith.constant 0.0 : f32
    %3 = linalg.fill ins(%fzero: f32) outs(%2: tensor<3xf32>) -> tensor<3xf32>
    linalg.yield %3 : tensor<3xf32>
  } -> tensor<3xtensor<3xf32>>

  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %three = arith.constant 3 : index

  scf.for %i = %zero to %three step %one {
    %extract = tensor.extract %1[%zero] : tensor<3xtensor<3xf32>>
    %print = tensor.cast %extract : tensor<3xf32> to tensor<*xf32>
    func.call @printMemrefF32(%print) : (tensor<*xf32>) -> ()
  }
  return
}

Command: mlir-opt test.mlir -tensor-bufferize

Error: mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp:712: void mlir::bufferization::replaceOpWithBufferizedValues(mlir::RewriterBase&, mlir::Operation*, mlir::ValueRange): Assertion `(llvm::isa(replacement.getType()) || llvm::isa(replacement.getType())) && "tensor op result should be replaced with a memref value"' failed.

If only the outermost layer uses tensors, other complex types such as memref can be used internally. The following program produces correct output:

func.func private @printMemrefF32(memref<*xf32>)
func.func @main() {
  %0 = tensor.empty() : tensor<3xmemref<3xf32>>
  %alloc = memref.alloc() : memref<3xf32>
  %fzero = arith.constant 0.0 : f32
  linalg.fill ins(%fzero: f32) outs(%alloc: memref<3xf32>)

  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
                                        iterator_types = ["parallel"]}
  outs(%0 : tensor<3xmemref<3xf32>>) {
  ^bb0(%o: memref<3xf32>):
    linalg.yield %alloc : memref<3xf32>
  } -> tensor<3xmemref<3xf32>>

  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %three = arith.constant 3 : index

  scf.for %i = %zero to %three step %one {
    %extract = tensor.extract %1[%zero] : tensor<3xmemref<3xf32>>
    %print = memref.cast %extract : memref<3xf32> to memref<*xf32>
    
    func.call @printMemrefF32(%print) : (memref<*xf32>) -> ()
  }

  memref.dealloc %alloc : memref<3xf32>
  return
}

Command: mlir-opt test.mlir -linalg-bufferize -tensor-bufferize -func-bufferize -finalizing-bufferize -convert-linalg-to-loops -convert-scf-to-cf -convert-arith-to-llvm -finalize-memref-to-llvm -convert-func-to-llvm -convert-math-to-llvm --canonicalize --cse -reconcile-unrealized-casts|mlir-cpu-runner -e main -entry-point-result=void -shared-libs=$mlir_runner_utils

I am unsure how many other passes would be affected.

Tensor of tensors is slightly more complex than “custom types on tensors”. It’s also slightly orthogonal to fusion on memrefs. If you look at the PR, it’s mostly about memory dependency checks, as @nicolasvasilache mentioned.

Your examples don't seem to be on memrefs, though, so I'm unsure what you really want. If you want a tensor of memrefs, then that's crossing too many barriers (SSA vs pointer, two different dialects), and I don't see how this could work.

A simpler take could be trying tensor of vectors, since vectors are closer to element types than memrefs.

Nicolas tried this a while back; it's not a trivial change, nor is it a meaningful one by itself. We have discussed tensor of tensors upstream many times before and never quite reached consensus on what we can do, or even what we should do.

I started writing something last year but never quite got to finish it:

Don't expect those docs to make sense; they raise more questions and don't answer any. This is why I'm skeptical and asking for more work to be done. It's not a trivial matter, so we should only consider proposals that advance beyond what has already been done.

I believe tensor of vectors will be simpler to implement, and allowing vector types inside linalg is a more reasonable expectation than allowing memrefs or tensors. We just need to be careful with the affine maps (or use scalar evolution composition).
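As a rough illustration of what I mean, a tensor whose elements are vectors could look like the sketch below. This is purely illustrative; whether the upstream verifiers, fusion, and bufferization accept it today is exactly what a proposal would have to work out.

#map_v = affine_map<(d0) -> (d0)>
func.func @tensor_of_vectors(%arg0: tensor<8xvector<4xf32>>) -> tensor<8xvector<4xf32>> {
  %empty = tensor.empty() : tensor<8xvector<4xf32>>
  // Elementwise op whose "scalars" are whole vector<4xf32> values.
  %0 = linalg.generic {indexing_maps = [#map_v, #map_v], iterator_types = ["parallel"]}
      ins(%arg0 : tensor<8xvector<4xf32>>)
      outs(%empty : tensor<8xvector<4xf32>>) {
  ^bb0(%in: vector<4xf32>, %out: vector<4xf32>):
    %1 = arith.addf %in, %in : vector<4xf32>
    linalg.yield %1 : vector<4xf32>
  } -> tensor<8xvector<4xf32>>
  return %0 : tensor<8xvector<4xf32>>
}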

As @ftynse said, if you want to add this, work on a proposal and start a new RFC. Search the forum for related material, create a PoC with relevant examples, and show the impact of that change on the existing dialects (tensor, memref, linalg, etc.).

Thank you all for your helpful responses.

Initially, my question was about using custom types in linalg to reuse passes such as linalg fusion and vectorization.

However, I found that tensors don’t support arbitrary element types, and linalg fusion doesn’t support memref semantics.

One solution would be to enable linalg to support memref semantics for fusion. Another would be to make tensors support arbitrary types.

I then wondered whether, ideally, we could make both tensors and memrefs support any element type, in terms of both IR expressiveness and pass compatibility, including custom and nested tensor types.

However, I realize now that I may have underestimated the complexity of this approach, and the goal might be too ambitious. I'll take more time to reorganize the problem and do further research.

Indeed, the current discussion has drifted from the original topic. I'll create a new thread to continue it if necessary.

Another solution is to do fusion at the loop level. There you can fuse whatever you want, as long as you take care of the correct semantics of your ops/types. This is what most people do.
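For example, here is a minimal before/after sketch (names are hypothetical) of fusing two scf.for loops over memrefs by hand: you merge the loop bodies yourself, and you are the one responsible for knowing there is no cross-iteration dependency through the intermediate buffer.

// Before: the producer loop writes %tmp, the consumer loop reads it.
func.func @unfused(%a: memref<?xf32>, %b: memref<?xf32>, %tmp: memref<?xf32>,
                   %res: memref<?xf32>, %n: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.for %i = %c0 to %n step %c1 {
    %x = memref.load %a[%i] : memref<?xf32>
    %y = memref.load %b[%i] : memref<?xf32>
    %add = arith.addf %x, %y : f32
    memref.store %add, %tmp[%i] : memref<?xf32>
  }
  scf.for %j = %c0 to %n step %c1 {
    %t = memref.load %tmp[%j] : memref<?xf32>
    %sq = arith.mulf %t, %t : f32
    memref.store %sq, %res[%j] : memref<?xf32>
  }
  return
}

// After fusing by hand: one loop, the intermediate value stays in SSA form and
// %tmp is no longer needed. This is legal here because iteration %i only
// touches element %i, which is the kind of fact you must establish yourself.
func.func @fused(%a: memref<?xf32>, %b: memref<?xf32>,
                 %res: memref<?xf32>, %n: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.for %i = %c0 to %n step %c1 {
    %x = memref.load %a[%i] : memref<?xf32>
    %y = memref.load %b[%i] : memref<?xf32>
    %add = arith.addf %x, %y : f32
    %sq = arith.mulf %add, %add : f32
    memref.store %sq, %res[%i] : memref<?xf32>
  }
  return
}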

I was writing about tensor of tensors and linalg nesting for the same reason you are: fusion. However, linalg generic is a cumbersome op due to its limitations and verbosity.

We are now looking into loop tiling and fusion by annotating arbitrary loops, instead of trying to force special constructs to have special fusion rules.

This also raises too many questions. Are the sub-memrefs contiguous in memory? Can I assume pre-fetches work? If I can’t answer questions like that, then doing tiling becomes impossible. At some point I must understand where all this data is.

My reasoning for using vectors instead is that a virtual vector type can be placed in the same register bank by a register allocator, while memrefs can't be guaranteed to be allocated in contiguous memory unless we add some form of arena allocation.

See:

We've all done that. :smiley: I wrote those pages after I realised how deep it goes and wanted to do a brain dump in the hope that one day it would help the next person.

So the biggest issue with memref-based fusion is dependency analysis. I am not sure the fusion would be any easier even with a strong memory dependency analysis. Basically, what a strong memory dependency analysis gives you will always be at par with, or worse than, the dependency analysis you get by just using tensor semantics. So I would suggest using tensor semantics and then using bufferization to get the final memref version.
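As a rough sketch of that flow (assuming your element types can live on tensors, which is the other half of this thread, and noting that pass names and options may differ between MLIR versions): fuse the elementwise ops while still on tensors, then bufferize, then lower.

Command: mlir-opt input.mlir -linalg-fuse-elementwise-ops -one-shot-bufferize="bufferize-function-boundaries" -convert-linalg-to-loops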

Thank you all for your answers; I've learned a lot. I'll make some more attempts, such as doing fusion on scf, supporting custom types on tensors, and so on.