Approach to fuse specific loop iterations in LinAlg

I have a question about approaches to fuse specific loop iterations in LinAlg; any pointer is appreciated. For example, in the code snippet below, I’d like to fuse only the outermost dimension (i.e., the 2x dimension) but not the remaining iterations. What would be the recommended way of doing this in LinAlg, if it is supported? One possibility I see is to tile and fuse that dimension with tile size 1. But driving fusion through tiling does not feel very natural, so I want to understand whether there is a better approach in LinAlg’s design. Thanks.

  • Before fusion
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func @test_before_fusion(%arg0: tensor<2x512xbf16>, %arg1: tensor<2x512xbf16>, %arg2: tensor<2x512x512xbf16>, %arg3: tensor<2x512x512xbf16>) -> tensor<2x512x512xbf16> {
  %0 = linalg.init_tensor [2, 512] : tensor<2x512xbf16>
  %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<2x512xbf16>, tensor<2x512xbf16>) outs(%0 : tensor<2x512xbf16>) {
  ^bb0(%arg4: bf16, %arg5: bf16, %arg6: bf16):
    %6 = arith.addf %arg4, %arg5 : bf16
    linalg.yield %6 : bf16
  } -> tensor<2x512xbf16>
  %2 = linalg.init_tensor [2, 512, 512] : tensor<2x512x512xbf16>
  %3 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg2, %arg3 : tensor<2x512x512xbf16>, tensor<2x512x512xbf16>) outs(%2 : tensor<2x512x512xbf16>) {
  ^bb0(%arg4: bf16, %arg5: bf16, %arg6: bf16):
    %6 = arith.subf %arg4, %arg5 : bf16
    linalg.yield %6 : bf16
  } -> tensor<2x512x512xbf16>
  %4 = linalg.init_tensor [2, 512, 512] : tensor<2x512x512xbf16>
  %5 = linalg.generic {indexing_maps = [#map2, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %3 : tensor<2x512xbf16>, tensor<2x512x512xbf16>) outs(%4 : tensor<2x512x512xbf16>) {
  ^bb0(%arg4: bf16, %arg5: bf16, %arg6: bf16):
    %6 = arith.mulf %arg4, %arg5 : bf16
    linalg.yield %6 : bf16
  } -> tensor<2x512x512xbf16>
  return %5 : tensor<2x512x512xbf16>
}
  • After fusion
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d0)>
func @test_after_fusion(%arg0: tensor<2x512xbf16>, %arg1: tensor<2x512xbf16>, %arg2: tensor<2x512x512xbf16>, %arg3: tensor<2x512x512xbf16>) -> tensor<2x512x512xbf16> {
  %0 = linalg.init_tensor [2, 512, 512] : tensor<2x512x512xbf16>
  %cst = arith.constant 0.000000e+00 : bf16
  %1 = linalg.fill(%cst, %0) : bf16, tensor<2x512x512xbf16> -> tensor<2x512x512xbf16> 
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %2 = scf.for %arg4 = %c0 to %c2 step %c1 iter_args(%arg5 = %1) -> (tensor<2x512x512xbf16>) {
    %3 = tensor.extract_slice %arg0[%arg4, 0] [1, 512] [1, 1] : tensor<2x512xbf16> to tensor<512xbf16>
    %4 = tensor.extract_slice %arg1[%arg4, 0] [1, 512] [1, 1] : tensor<2x512xbf16> to tensor<512xbf16>
    %5 = linalg.init_tensor [512] : tensor<512xbf16>
    %6 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} ins(%3, %4 : tensor<512xbf16>, tensor<512xbf16>) outs(%5 : tensor<512xbf16>) {
    ^bb0(%arg6: bf16, %arg7: bf16, %arg8: bf16):
      %14 = arith.addf %arg6, %arg7 : bf16
      linalg.yield %14 : bf16
    } -> tensor<512xbf16>
    %7 = tensor.extract_slice %arg2[%arg4, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<2x512x512xbf16> to tensor<512x512xbf16>
    %8 = tensor.extract_slice %arg3[%arg4, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<2x512x512xbf16> to tensor<512x512xbf16>
    %9 = linalg.init_tensor [512, 512] : tensor<512x512xbf16>
    %10 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%7, %8 : tensor<512x512xbf16>, tensor<512x512xbf16>) outs(%9 : tensor<512x512xbf16>) {
    ^bb0(%arg6: bf16, %arg7: bf16, %arg8: bf16):
      %14 = arith.subf %arg6, %arg7 : bf16
      linalg.yield %14 : bf16
    } -> tensor<512x512xbf16>
    %11 = linalg.init_tensor [512, 512] : tensor<512x512xbf16>
    %12 = linalg.generic {indexing_maps = [#map2, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%6, %10 : tensor<512xbf16>, tensor<512x512xbf16>) outs(%11 : tensor<512x512xbf16>) {
    ^bb0(%arg6: bf16, %arg7: bf16, %arg8: bf16):
      %14 = arith.mulf %arg6, %arg7 : bf16
      linalg.yield %14 : bf16
    } -> tensor<512x512xbf16>
    %13 = tensor.insert_slice %12 into %arg5[%arg4, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<512x512xbf16> into tensor<2x512x512xbf16>
    scf.yield %13 : tensor<2x512x512xbf16>
  }
  return %2 : tensor<2x512x512xbf16>
}

Hi Rui,

Linalg has two types of fusion currently (depending on how one counts :)).

Option one is elementwise fusion, which combines two dependent pointwise generic operations into a single pointwise generic operation. Elementwise fusion does not involve loops and always fully fuses the iteration spaces. I believe this is probably not what you are looking for.

Option two is indeed tile and fuse, and it is the only way to fuse just the outermost loop dimension.

There is no way to directly fuse only the outer loop dimension of two Linalg operations, since the result would be an iteration space that cannot be expressed with Linalg operations. Only after tiling do we have the loops needed to express the fusion.
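As a rough illustration of what tile size 1 on the outer dimension means, here is a small Python model (not Linalg code) of the reference semantics of the three linalg.generic ops in test_before_fusion and of the loop structure in test_after_fusion. It assumes small integer shapes in place of the 2x512x512 bf16 tensors, and all function names are hypothetical. Each loop iteration recomputes the producers only on the slice that its consumer tile reads, which is the structure tile-and-fuse with tile size 1 produces.

```python
# Reference semantics of the ops in test_before_fusion on nested lists.
# Shapes: a0, a1 are B x N; a2, a3 and the result are B x N x M.

def add2d(a, b):
    # %1: elementwise addf over (d0, d1)
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def sub3d(a, b):
    # %3: elementwise subf over (d0, d1, d2)
    return [[[x - y for x, y in zip(ra, rb)] for ra, rb in zip(pa, pb)]
            for pa, pb in zip(a, b)]


def mul_bcast(v, t):
    # %5: mulf where v is read through (d0, d1, d2) -> (d0, d1), i.e. broadcast along d2
    return [[[v[i][j] * x for x in t[i][j]] for j in range(len(t[i]))]
            for i in range(len(t))]


def unfused(a0, a1, a2, a3):
    return mul_bcast(add2d(a0, a1), sub3d(a2, a3))


def tile_and_fuse_outer(a0, a1, a2, a3):
    # One loop iteration per index of the outermost dimension (tile size 1 on d0).
    out = []
    for i in range(len(a2)):
        s1 = add2d([a0[i]], [a1[i]])  # producer %1 on slice i only
        s3 = sub3d([a2[i]], [a3[i]])  # producer %3 on slice i only
        out.append(mul_bcast(s1, s3)[0])  # consumer on slice i, stored at position i
    return out


B, N, M = 2, 3, 4
a0 = [[i + j for j in range(N)] for i in range(B)]
a1 = [[2 * i - j for j in range(N)] for i in range(B)]
a2 = [[[i * j + k for k in range(M)] for j in range(N)] for i in range(B)]
a3 = [[[k - j for k in range(M)] for j in range(N)] for i in range(B)]
assert tile_and_fuse_outer(a0, a1, a2, a3) == unfused(a0, a1, a2, a3)
```

Both versions compute the same values; they differ only in loop structure and in how much of each intermediate is live at a time.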

One alternative that may support more fusion scenarios is to transform Linalg operations to loops (for example affine loops) and then use affine loop fusion to get closer to what you are looking for.

Best,
Tobias


The design is indeed specified as fusing tiles (or more generally N-D subsets depending on the data type) of data accessed by operations.

The traditional Allen/Cocke loop fusion is indeed the limit of an N-D subset with all sizes equal to 1.

If you really want the implementation detail of the fusion itself to occur on loops, you may want to lower to loops and try, e.g., affine fusion (assuming affine evolves to support tensors and to interoperate better with Linalg).

You will note, however, that “lowering Linalg to loops” is implemented… by tiling by 1 on the dimensions you want to reduce to loops :)

Speaking of the design, in the fullness of time, tiling, fusion, lowering to loops and other transformations are all instances/parameterizations of a single generic rewrite:

subset(linalg.generic(X)) -> linalg.generic(subset(X))

This rewrite introduces loops.
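A minimal Python sketch of that rewrite, assuming a 2-D elementwise op and nested lists as tensors (the names generic, subset and tiled are hypothetical): computing the op on a subset of its inputs gives the corresponding subset of its result, so a tiled version just loops over subsets. A tile size equal to the full extent leaves that dimension untiled (a single iteration), tile size 1 on the outer dimension gives the single loop from the original question, and tile size 1 everywhere degenerates to plain scalar loops.

```python
from itertools import product


def generic(body, *inputs):
    # Elementwise 2-D "linalg.generic": identity indexing maps, all loops parallel.
    rows, cols = len(inputs[0]), len(inputs[0][0])
    return [[body(*(x[i][j] for x in inputs)) for j in range(cols)]
            for i in range(rows)]


def subset(t, r0, c0, nr, nc):
    # Like tensor.extract_slice: the nr x nc block at offset (r0, c0), unit strides.
    return [row[c0:c0 + nc] for row in t[r0:r0 + nr]]


def tiled(body, tile, *inputs):
    # Apply subset(generic(X)) -> generic(subset(X)) once per tile; the loops
    # over tile offsets are the loops this rewrite introduces.
    rows, cols = len(inputs[0]), len(inputs[0][0])
    t0, t1 = tile
    out = [[None] * cols for _ in range(rows)]
    for r0, c0 in product(range(0, rows, t0), range(0, cols, t1)):
        nr, nc = min(t0, rows - r0), min(t1, cols - c0)
        block = generic(body, *(subset(x, r0, c0, nr, nc) for x in inputs))
        for i in range(nr):  # write the tile back, like tensor.insert_slice
            out[r0 + i][c0:c0 + nc] = block[i]
    return out


def add(x, y):
    return x + y


A = [[5 * i + j for j in range(5)] for i in range(3)]
B = [[j - i for j in range(5)] for i in range(3)]
ref = generic(add, A, B)

# The rewrite itself: a slice of the result equals the op applied to slices.
assert subset(ref, 1, 2, 2, 3) == generic(add, subset(A, 1, 2, 2, 3), subset(B, 1, 2, 2, 3))
assert tiled(add, (1, 5), A, B) == ref  # tile only the outer dim: one loop
assert tiled(add, (1, 1), A, B) == ref  # tile by 1 everywhere: scalar loops
assert tiled(add, (2, 3), A, B) == ref  # partial tiles at the boundaries
```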


Thank you for your responses @gysit @nicolasvasilache! I have a more general question about LinAlg on which your input would be very helpful. I’m also tagging @MaheshRavishankar, since I see IREE has been using LinAlg extensively, so hopefully Mahesh can share his experience as well.

What I’m trying to better understand is the pros and cons of driving loop transformations on LinAlg. The scenario I want to avoid is duplicating the same set of analyses and transformations on LinAlg and then again on a lower-level abstraction. The workloads I focus on are AI/ML at the moment. For example, if I do tiling/fusion/interchange on LinAlg, I want to avoid a situation where LinAlg cannot capture all the tiling/fusion/interchange scenarios I need to support, so that I have to analyze and transform these loops again on Affine or SCF.

Though I like some benefits of LinAlg’s higher level of abstraction, its tradeoffs for loop transformations are not very clear to me. I would appreciate it if you could share your experience with tiling/fusion/interchange on LinAlg. Have you hit cases where LinAlg was insufficient to drive the loop transformations you wanted? Thanks!

My personal experience is that a structured abstraction, as it is currently implemented (i.e. linalg.generic), is useful because it simplifies the design space and control, but it is not sufficient by itself.
It is complemented by transformations that are better done at the loop, vector, async and other levels, as well as by other (future) structured abstractions.

There are identified cases for which we know we need better abstractions and that regular loops can capture (e.g. fusing reshape, concat, gather and other things that look like reindexings); details are important to confidently stick the label “supported” on a particular solution.

Then there are sparse tensors and other upcoming things that also rely on data representations, inspector-executor and other such techniques, but these are too early to discuss right now.

I’ll let @MaheshRavishankar comment on the general suitability for AI/ML workloads.

This is very helpful information! If possible, I would love to learn a bit more about the limitations you experienced with the operations you mentioned.

  • For reshapes, is propagating them across the graph, with the aim of canonicalizing them into either types or indexing maps, something plausible? I don’t have good visibility into what hidden rocks there could be if going down this path with LinAlg. For example, propagating the expand_shape operation in the following example into %arg0’s type removes the expand_shape operation. The propagation could stop at certain operations that block it.
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

func @reshape_before_propagation(%arg0: tensor<2x256x512xbf16>, %arg1: tensor<2x256x2x256xbf16>) -> tensor<2x256x2x256xbf16> {
  %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] : tensor<2x256x512xbf16> into tensor<2x256x2x256xbf16>
  %1 = linalg.init_tensor [2, 256, 2, 256] : tensor<2x256x2x256xbf16>
  %2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0, %arg1 : tensor<2x256x2x256xbf16>, tensor<2x256x2x256xbf16>) outs(%1 : tensor<2x256x2x256xbf16>) {
  ^bb0(%arg2: bf16, %arg3: bf16, %arg4: bf16):
    %3 = arith.addf %arg2, %arg3 : bf16
    linalg.yield %3 : bf16
  } -> tensor<2x256x2x256xbf16>
  return %2 : tensor<2x256x2x256xbf16>
}

func @reshape_after_propagation(%arg0: tensor<2x256x2x256xbf16>, %arg1: tensor<2x256x2x256xbf16>) -> tensor<2x256x2x256xbf16> {
  %0 = linalg.init_tensor [2, 256, 2, 256] : tensor<2x256x2x256xbf16>
  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<2x256x2x256xbf16>, tensor<2x256x2x256xbf16>) outs(%0 : tensor<2x256x2x256xbf16>) {
  ^bb0(%arg2: bf16, %arg3: bf16, %arg4: bf16):
    %2 = arith.addf %arg2, %arg3 : bf16
    linalg.yield %2 : bf16
  } -> tensor<2x256x2x256xbf16>
  return %1 : tensor<2x256x2x256xbf16>
}

  • For split/concat, is the difficulty that some form of explicit control flow would be needed to represent fusion into the concatenated iteration space, which is not something LinAlg is designed for?

  • Gather is tough. Not sure what to do with it. :)
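For the reshape bullet, the underlying property can be checked on a toy model: tensor.expand_shape with reassociation [[0], [1], [2, 3]] is pure reindexing, so for an elementwise op with identity indexing maps it commutes with the op and can be moved across it. The Python sketch below uses small shapes (2x3x4 and 2x3x2x2 instead of 2x256x512 and 2x256x2x256) and hypothetical helper names; ops with broadcasting or transposing indexing maps would additionally need their maps rewritten.

```python
def expand_last(t, inner):
    # tensor.expand_shape with reassociation [[0], [1], [2, 3]]: the last dim of
    # size K becomes (K // inner, inner); element [a][b][c][d] = t[a][b][c * inner + d].
    return [[[row[k:k + inner] for k in range(0, len(row), inner)] for row in plane]
            for plane in t]


def collapse_last(t):
    # tensor.collapse_shape with the same reassociation: merge the last two dims.
    return [[[x for chunk in row for x in chunk] for row in plane] for plane in t]


def add_elementwise(a, b):
    # Elementwise addf with identity indexing maps, on nested lists of any rank.
    if isinstance(a, list):
        return [add_elementwise(x, y) for x, y in zip(a, b)]
    return a + b


x = [[[100 * i + 10 * j + k for k in range(4)] for j in range(3)] for i in range(2)]  # 2x3x4
y = expand_last([[[k * j - i for k in range(4)] for j in range(3)] for i in range(2)], 2)  # 2x3x2x2

before = add_elementwise(expand_last(x, 2), y)                # reshape feeds the op
after = expand_last(add_elementwise(x, collapse_last(y)), 2)  # reshape moved past the op
assert before == after
assert collapse_last(expand_last(x, 2)) == x
```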

This is possible in limited cases. A version of this is implemented here.

As @gysit mentioned, there are two kinds of fusion in Linalg: one is elementwise fusion and the other is tile + fuse. The file I linked above implements the elementwise fusion. Broadly, it does three different things:

  1. Fuse a linalg.generic with another linalg.generic to create a new linalg.generic operation. This is akin to loop fusion where the result is a perfectly nested loop.
  2. Fuse tensor.collapse_shape → linalg.generic and linalg.generic → tensor.expand_shape. This is done by expanding the dimensionality of the linalg.generic op in both cases, and is called “FuseByExpansion”.
  3. Fuse tensor.expand_shape → linalg.generic and linalg.generic → tensor.collapse_shape. This is done by collapsing the dimensionality of the linalg.generic op in both cases, and is called “FuseByCollapsing”. Here one needs to be careful that the fusion does not produce indexing maps that are not “projected permutations”, because that affects subsequent analyses and optimizations.
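The indexing-map bookkeeping behind items 2 and 3 can be sketched in Python, under the simplifying assumption that every map is a projected permutation written as a tuple of loop dimensions (so (0, 1) stands for (d0, d1, d2) -> (d0, d1)). This is an illustration of the idea, not the code in the file linked above; the function names are made up, and the real patterns have further preconditions that this sketch ignores.

```python
from itertools import accumulate


def expand_maps(maps, expansion):
    # expansion[d] = number of loops that original loop d is split into.
    # Each use of d in a map is replaced by its new loops, in order.
    offsets = [0] + list(accumulate(expansion))[:-1]
    return [tuple(offsets[d] + k for d in m for k in range(expansion[d])) for m in maps]


def collapse_maps(maps, groups):
    # groups: consecutive loops to merge, e.g. [[0], [1], [2, 3]].
    # Returns None if some map does not use a group contiguously and in order,
    # since the collapsed map would then not be a projected permutation.
    group_of = {d: g for g, dims in enumerate(groups) for d in dims}
    collapsed = []
    for m in maps:
        new_map, i = [], 0
        while i < len(m):
            g = group_of[m[i]]
            size = len(groups[g])
            if tuple(m[i:i + size]) != tuple(groups[g]):
                return None
            new_map.append(g)
            i += size
        collapsed.append(tuple(new_map))
    return collapsed


# FuseByExpansion-style: split loop d2 of (d0, d1, d2) into two loops.
assert expand_maps([(0, 1, 2), (0, 1)], [1, 1, 2]) == [(0, 1, 2, 3), (0, 1)]
assert expand_maps([(2, 0)], [1, 1, 2]) == [(2, 3, 0)]
# FuseByCollapsing-style: merge d2 and d3 of a 4-loop op.
assert collapse_maps([(0, 1, 2, 3), (0, 1)], [[0], [1], [2, 3]]) == [(0, 1, 2), (0, 1)]
assert collapse_maps([(0, 1, 3, 2)], [[0], [1], [2, 3]]) is None  # transposed group
assert collapse_maps([(0, 1, 2)], [[0], [1], [2, 3]]) is None     # uses d2 without d3
```

The two rejected cases are exactly the ones where the collapsed operand would have to be indexed with a floordiv/mod of the merged loop, i.e. by a map that is no longer a projected permutation.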

All three are implemented in that file. There are some other patterns in that file as well that are meant to be deprecated; see this post for more details and background. I haven’t pushed on this as much as I would have liked. I implemented the needed patterns but haven’t deprecated the things that should be, because I am waiting to decouple their uses from IREE.

This is precisely what the FuseByExpansion implementation does. It is already used in IREE, works quite well, and improves the ability to fuse. After some recent study in IREE, we found that we also needed “FuseByCollapsing” to propagate reshapes further toward the edges when they get “blocked” by certain operations (the Discourse post I linked above has more details).

Gather is currently implemented as a linalg.generic when it is lowered from, say, the MHLO dialect in TF, the TOSA dialect, or Torch-MLIR. Since it’s a linalg.generic operation, it fuses with other operations the way you would expect. We have seen cases in some BERT models where multiple gathers get fused because of elementwise operation fusion.
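A toy Python model of why a gather expressed as a linalg.generic fuses like any other elementwise producer. It assumes the lowering reads the table inside the body at an index computed from the loop indices (in IR, typically linalg.index plus a tensor.extract); the helper generic below is hypothetical and just applies a body to every point of a 2-D iteration space.

```python
def generic(shape, body):
    # Toy linalg.generic over a 2-D parallel iteration space; body gets the
    # loop indices, as linalg.index would provide inside the region.
    rows, cols = shape
    return [[body(i, j) for j in range(cols)] for i in range(rows)]


table = [[10 * r + c for c in range(4)] for r in range(5)]  # 5x4 embedding-like table
ids = [3, 0, 3]                                            # gather indices
bias = [[c for c in range(4)] for _ in ids]                # 3x4 second operand
shape = (len(ids), 4)

# Unfused: the gather materializes its result, then an elementwise add consumes it.
gathered = generic(shape, lambda i, j: table[ids[i]][j])
unfused = generic(shape, lambda i, j: gathered[i][j] + bias[i][j])

# Elementwise fusion inlines the producer body into the consumer body,
# so the gathered tensor is never materialized.
fused = generic(shape, lambda i, j: table[ids[i]][j] + bias[i][j])
assert fused == unfused
```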

I just provided a high-level overview in my response. I can go into a lot more detail, but that would take a really long time. I can give you more targeted info if you need it, but from your questions I think a lot of the functionality you are looking for is already implemented in the ElementwiseFusion approach to fusion in Linalg. If there is something missing, I would love to know more about your use case. I’m only asking because the kinds of questions you ask and the approaches you mention match well with what we’ve tried in Linalg and used in IREE, so it seems like there might be some synergy here.

@MaheshRavishankar, thank you very much for your detailed response! In my scenario, both elementwise fusion and outer-dimension fusion are important, targeting different granularities of parallelism. I really appreciate your pointers to IREE’s approach to reshapes and am very glad to know it has been working well for you. Your information on gather is also super helpful, and I’ll dig into it more.

I also want to take this chance to seek your input on dynamic dimensions. Tagging @nicolasvasilache and @gysit as well. I’m exploring LinAlg’s current support for dynamic dimensions and noticed something I don’t fully understand. In the following example, I expect the two LinAlg operations to be fused into one, but it looks like LinalgElementwiseOpFusion thinks the dynamic dimensions are different and so repeats the first operation. I’m not sure whether I misunderstand the semantics or this is simply a missed opportunity. As this example is very basic, it makes me wonder what your experience has been working with dynamic dimensions in LinAlg so far. Thanks in advance.

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  func @test_fuse_dynamic_dim(%arg0: tensor<?x512x1024xbf16>, %arg1: tensor<?x512x1024xbf16>, %arg2: tensor<?x512x1024xbf16>) -> tensor<?x512x1024xbf16> {
    %c0 = arith.constant 0 : index
    %0 = tensor.dim %arg0, %c0 : tensor<?x512x1024xbf16>
    %1 = linalg.init_tensor [%0, 512, 1024] : tensor<?x512x1024xbf16>
    %2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x512x1024xbf16>, tensor<?x512x1024xbf16>) outs(%1 : tensor<?x512x1024xbf16>) {
    ^bb0(%arg3: bf16, %arg4: bf16, %arg5: bf16):
      %6 = arith.addf %arg3, %arg4 : bf16
      linalg.yield %6 : bf16
    } -> tensor<?x512x1024xbf16>
    %c0_0 = arith.constant 0 : index
    %3 = tensor.dim %2, %c0_0 : tensor<?x512x1024xbf16>
    %4 = linalg.init_tensor [%3, 512, 1024] : tensor<?x512x1024xbf16>
    %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2, %arg2 : tensor<?x512x1024xbf16>, tensor<?x512x1024xbf16>) outs(%4 : tensor<?x512x1024xbf16>) {
    ^bb0(%arg3: bf16, %arg4: bf16, %arg5: bf16):
      %6 = arith.subf %arg3, %arg4 : bf16
      linalg.yield %6 : bf16
    } -> tensor<?x512x1024xbf16>
    return %5 : tensor<?x512x1024xbf16>
  }
}

// -----// IR Dump After LinalgElementwiseOpFusion //----- //
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  func @test_fuse_dynamic_dim(%arg0: tensor<?x512x1024xbf16>, %arg1: tensor<?x512x1024xbf16>, %arg2: tensor<?x512x1024xbf16>) -> tensor<?x512x1024xbf16> {
    %c0 = arith.constant 0 : index
    %0 = tensor.dim %arg0, %c0 : tensor<?x512x1024xbf16>
    %1 = linalg.init_tensor [%0, 512, 1024] : tensor<?x512x1024xbf16>
    %2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x512x1024xbf16>, tensor<?x512x1024xbf16>) outs(%1 : tensor<?x512x1024xbf16>) {
    ^bb0(%arg3: bf16, %arg4: bf16, %arg5: bf16):
      %6 = arith.addf %arg3, %arg4 : bf16
      linalg.yield %6 : bf16
    } -> tensor<?x512x1024xbf16>
    %3 = tensor.dim %2, %c0 : tensor<?x512x1024xbf16>
    %4 = linalg.init_tensor [%3, 512, 1024] : tensor<?x512x1024xbf16>
    %5 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor<?x512x1024xbf16>, tensor<?x512x1024xbf16>, tensor<?x512x1024xbf16>) outs(%4 : tensor<?x512x1024xbf16>) {
    ^bb0(%arg3: bf16, %arg4: bf16, %arg5: bf16, %arg6: bf16):
      %6 = arith.addf %arg3, %arg4 : bf16
      %7 = arith.subf %6, %arg5 : bf16
      linalg.yield %7 : bf16
    } -> tensor<?x512x1024xbf16>
    return %5 : tensor<?x512x1024xbf16>
  }
}

If you look at the second linalg.generic, you can see that the op is fused. The original op is not removed because of its use in the tensor.dim operation. If you run -resolve-shaped-type-result-dims, that dim gets resolved in terms of the operands of the first linalg.generic, and the first linalg.generic is then DCE-ed.

All Linalg transformations work with dynamic shapes, including elementwise fusion, and it is used in IREE under the same conditions. Support for dynamic shapes is pretty much a prerequisite for transformations in Linalg (and in IREE as well).
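The mechanism can be sketched with a toy SSA program in Python. The op names mirror the IR above, but the data structure and the functions resolve_dims and dce are hypothetical, and the %c0 operand of tensor.dim is omitted. The sketch relies on one fact: a linalg.generic on tensors returns a result with the same shape as its outs operand, so a tensor.dim of the result can be rewritten as a tensor.dim of that operand, after which nothing uses the first generic and dead-code elimination removes it.

```python
# Toy SSA program: (result, op name, operands), mirroring the fused IR above.
# The last operand of each linalg.generic is its outs (init) tensor.
prog = [
    ("%0", "tensor.dim", ["%arg0"]),
    ("%1", "linalg.init_tensor", ["%0"]),
    ("%2", "linalg.generic", ["%arg0", "%arg1", "%1"]),
    ("%3", "tensor.dim", ["%2"]),
    ("%4", "linalg.init_tensor", ["%3"]),
    ("%5", "linalg.generic", ["%arg0", "%arg1", "%arg2", "%4"]),
]


def resolve_dims(prog):
    # tensor.dim of a linalg.generic result -> tensor.dim of its outs operand,
    # since the result has the same shape as that operand.
    defs = {res: (op, operands) for res, op, operands in prog}
    resolved = []
    for res, op, operands in prog:
        src = defs.get(operands[0]) if op == "tensor.dim" else None
        if src is not None and src[0] == "linalg.generic":
            operands = [src[1][-1]]
        resolved.append((res, op, operands))
    return resolved


def dce(prog, roots):
    # Keep only ops whose results are (transitively) needed by the roots.
    live, kept = set(roots), []
    for res, op, operands in reversed(prog):
        if res in live:
            kept.append((res, op, operands))
            live.update(operands)
    return kept[::-1]


def names(p):
    return [res for res, _, _ in p]


assert names(dce(prog, ["%5"])) == ["%0", "%1", "%2", "%3", "%4", "%5"]  # %2 kept alive by %3
assert names(dce(resolve_dims(prog), ["%5"])) == ["%0", "%1", "%3", "%4", "%5"]
```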

To clarify, AFAICS, these transformations are already implemented in elementwise op fusion, though they might not be easily accessible. So I’m happy to help with getting the details and interface worked out.


@MaheshRavishankar, many thanks for the tip and for sharing your experience with dynamic shapes! I was able to remove the redundant operation in my earlier example with -resolve-shaped-type-result-dims, which you kindly shared.

Here is another example where I have trouble fusing the outermost dimension into the desired form with LinAlg. In test_before_fusion, all LinAlg operations share the same outermost dimension, so I’d like to fuse all of them along it. But it appears LinAlg does not support this in the presence of extract_slice and insert_slice. I can only get the outermost iteration fused into something like test_after_fusion, whereas desired_test_after_fusion is the form I’d like to get out of fusion. Is this type of fusion supported by IREE on LinAlg?

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

func @test_before_fusion(%arg0: tensor<?x512x1024xbf16>, %arg1: tensor<?x512x512xbf16>, %arg2: tensor<?x512x512xbf16>, %arg3: tensor<?x512x1024xbf16>) -> tensor<?x512x1024xbf16> {
  %cst = arith.constant 0.000000e+00 : bf16
  %c0 = arith.constant 0 : index
  %0 = tensor.dim %arg0, %c0 : tensor<?x512x1024xbf16>
  %1 = tensor.extract_slice %arg0[0, 0, 0] [%0, 512, 512] [1, 1, 1] : tensor<?x512x1024xbf16> to tensor<?x512x512xbf16>
  %2 = linalg.init_tensor [%0, 512, 512] : tensor<?x512x512xbf16>
  %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %arg1 : tensor<?x512x512xbf16>, tensor<?x512x512xbf16>) outs(%2 : tensor<?x512x512xbf16>) {
  ^bb0(%arg4: bf16, %arg5: bf16, %arg6: bf16):
    %11 = arith.addf %arg4, %arg5 : bf16
    linalg.yield %11 : bf16
  } -> tensor<?x512x512xbf16>
  %4 = tensor.extract_slice %arg0[0, 0, 512] [%0, 512, 512] [1, 1, 1] : tensor<?x512x1024xbf16> to tensor<?x512x512xbf16>
  %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %arg2 : tensor<?x512x512xbf16>, tensor<?x512x512xbf16>) outs(%2 : tensor<?x512x512xbf16>) {
  ^bb0(%arg4: bf16, %arg5: bf16, %arg6: bf16):
    %11 = arith.addf %arg4, %arg5 : bf16
    linalg.yield %11 : bf16
  } -> tensor<?x512x512xbf16>
  %6 = linalg.init_tensor [%0, 512, 1024] : tensor<?x512x1024xbf16>
  %7 = linalg.fill(%cst, %6) : bf16, tensor<?x512x1024xbf16> -> tensor<?x512x1024xbf16> 
  %8 = tensor.insert_slice %3 into %7[0, 0, 0] [%0, 512, 512] [1, 1, 1] : tensor<?x512x512xbf16> into tensor<?x512x1024xbf16>
  %9 = tensor.insert_slice %5 into %8[0, 0, 512] [%0, 512, 512] [1, 1, 1] : tensor<?x512x512xbf16> into tensor<?x512x1024xbf16>
  %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%9, %arg3 : tensor<?x512x1024xbf16>, tensor<?x512x1024xbf16>) outs(%6 : tensor<?x512x1024xbf16>) {
  ^bb0(%arg4: bf16, %arg5: bf16, %arg6: bf16):
    %11 = arith.mulf %arg4, %arg5 : bf16
    linalg.yield %11 : bf16
  } -> tensor<?x512x1024xbf16>
  return %10 : tensor<?x512x1024xbf16>
}

func @test_after_fusion(%arg0: tensor<?x512x1024xbf16>, %arg1: tensor<?x512x512xbf16>, %arg2: tensor<?x512x512xbf16>, %arg3: tensor<?x512x1024xbf16>) -> tensor<?x512x1024xbf16> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : bf16
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg0, %c0 : tensor<?x512x1024xbf16>
  %1 = tensor.extract_slice %arg0[0, 0, 0] [%0, 512, 512] [1, 1, 1] : tensor<?x512x1024xbf16> to tensor<?x512x512xbf16>
  %2 = linalg.init_tensor [%0, 512, 512] : tensor<?x512x512xbf16>
  %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %arg1 : tensor<?x512x512xbf16>, tensor<?x512x512xbf16>) outs(%2 : tensor<?x512x512xbf16>) {
  ^bb0(%arg4: bf16, %arg5: bf16, %arg6: bf16):
    %12 = arith.addf %arg4, %arg5 : bf16
    linalg.yield %12 : bf16
  } -> tensor<?x512x512xbf16>
  %4 = tensor.extract_slice %arg0[0, 0, 512] [%0, 512, 512] [1, 1, 1] : tensor<?x512x1024xbf16> to tensor<?x512x512xbf16>
  %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %arg2 : tensor<?x512x512xbf16>, tensor<?x512x512xbf16>) outs(%2 : tensor<?x512x512xbf16>) {
  ^bb0(%arg4: bf16, %arg5: bf16, %arg6: bf16):
    %12 = arith.addf %arg4, %arg5 : bf16
    linalg.yield %12 : bf16
  } -> tensor<?x512x512xbf16>
  %6 = linalg.init_tensor [%0, 512, 1024] : tensor<?x512x1024xbf16>
  %7 = linalg.fill(%cst, %6) : bf16, tensor<?x512x1024xbf16> -> tensor<?x512x1024xbf16> 
  %8 = tensor.insert_slice %3 into %7[0, 0, 0] [%0, 512, 512] [1, 1, 1] : tensor<?x512x512xbf16> into tensor<?x512x1024xbf16>
  %9 = tensor.insert_slice %5 into %8[0, 0, 512] [%0, 512, 512] [1, 1, 1] : tensor<?x512x512xbf16> into tensor<?x512x1024xbf16>
  %10 = tensor.dim %9, %c0 : tensor<?x512x1024xbf16>
  %11 = scf.for %arg4 = %c0 to %10 step %c1 iter_args(%arg5 = %6) -> (tensor<?x512x1024xbf16>) {
    %12 = tensor.extract_slice %9[%arg4, 0, 0] [1, 512, 1024] [1, 1, 1] : tensor<?x512x1024xbf16> to tensor<1x512x1024xbf16>
    %13 = tensor.extract_slice %arg3[%arg4, 0, 0] [1, 512, 1024] [1, 1, 1] : tensor<?x512x1024xbf16> to tensor<1x512x1024xbf16>
    %14 = tensor.extract_slice %arg5[%arg4, 0, 0] [1, 512, 1024] [1, 1, 1] : tensor<?x512x1024xbf16> to tensor<1x512x1024xbf16>
    %15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12, %13 : tensor<1x512x1024xbf16>, tensor<1x512x1024xbf16>) outs(%14 : tensor<1x512x1024xbf16>) {
    ^bb0(%arg6: bf16, %arg7: bf16, %arg8: bf16):
      %17 = arith.mulf %arg6, %arg7 : bf16
      linalg.yield %17 : bf16
    } -> tensor<1x512x1024xbf16>
    %16 = tensor.insert_slice %15 into %arg5[%arg4, 0, 0] [1, 512, 1024] [1, 1, 1] : tensor<1x512x1024xbf16> into tensor<?x512x1024xbf16>
    scf.yield %16 : tensor<?x512x1024xbf16>
  }
  return %11 : tensor<?x512x1024xbf16>
}

func @desired_test_after_fusion(%arg0: tensor<?x512x1024xbf16>, %arg1: tensor<?x512x512xbf16>, %arg2: tensor<?x512x512xbf16>, %arg3: tensor<?x512x1024xbf16>) -> tensor<?x512x1024xbf16> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : bf16
  %c1 = arith.constant 1 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x512x1024xbf16>
  %output = linalg.init_tensor [%dim, 512, 1024] : tensor<?x512x1024xbf16>  
  %11 = scf.for %arg4 = %c0 to %dim step %c1 iter_args(%arg5 = %output) -> (tensor<?x512x1024xbf16>) {      
    %arg0_slice = tensor.extract_slice %arg0[%arg4, 0, 0] [1, 512, 1024] [1, 1, 1] : tensor<?x512x1024xbf16> to tensor<1x512x1024xbf16>
    %arg1_slice = tensor.extract_slice %arg1[%arg4, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<?x512x512xbf16> to tensor<1x512x512xbf16>
    %arg2_slice = tensor.extract_slice %arg2[%arg4, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<?x512x512xbf16> to tensor<1x512x512xbf16>
    %arg3_slice = tensor.extract_slice %arg3[%arg4, 0, 0] [1, 512, 1024] [1, 1, 1] : tensor<?x512x1024xbf16> to tensor<1x512x1024xbf16>

    %1 = tensor.extract_slice %arg0_slice[0, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<1x512x1024xbf16> to tensor<1x512x512xbf16>
    %2 = linalg.init_tensor [1, 512, 512] : tensor<1x512x512xbf16>
    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %arg1_slice : tensor<1x512x512xbf16>, tensor<1x512x512xbf16>) outs(%2 : tensor<1x512x512xbf16>) {
    ^bb0(%arg6: bf16, %arg7: bf16, %arg8: bf16):
      %12 = arith.addf %arg6, %arg7 : bf16
      linalg.yield %12 : bf16
    } -> tensor<1x512x512xbf16>

    %4 = tensor.extract_slice %arg0_slice[0, 0, 512] [1, 512, 512] [1, 1, 1] : tensor<1x512x1024xbf16> to tensor<1x512x512xbf16>
    %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %arg2_slice : tensor<1x512x512xbf16>, tensor<1x512x512xbf16>) outs(%2 : tensor<1x512x512xbf16>) {
    ^bb0(%arg6: bf16, %arg7: bf16, %arg8: bf16):
      %12 = arith.addf %arg6, %arg7 : bf16
      linalg.yield %12 : bf16
    } -> tensor<1x512x512xbf16>

    %6 = linalg.init_tensor [1, 512, 1024] : tensor<1x512x1024xbf16>
    %7 = linalg.fill(%cst, %6) : bf16, tensor<1x512x1024xbf16> -> tensor<1x512x1024xbf16> 
    %8 = tensor.insert_slice %3 into %7[0, 0, 0] [1, 512, 512] [1, 1, 1] : tensor<1x512x512xbf16> into tensor<1x512x1024xbf16>
    %9 = tensor.insert_slice %5 into %8[0, 0, 512] [1, 512, 512] [1, 1, 1] : tensor<1x512x512xbf16> into tensor<1x512x1024xbf16>

    %15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%9, %arg3_slice : tensor<1x512x1024xbf16>, tensor<1x512x1024xbf16>) outs(%6 : tensor<1x512x1024xbf16>) {
    ^bb0(%arg6: bf16, %arg7: bf16, %arg8: bf16):
      %17 = arith.mulf %arg6, %arg7 : bf16
      linalg.yield %17 : bf16
    } -> tensor<1x512x1024xbf16>
    %16 = tensor.insert_slice %15 into %arg5[%arg4, 0, 0] [1, 512, 1024] [1, 1, 1] : tensor<1x512x1024xbf16> into tensor<?x512x1024xbf16>
    scf.yield %16 : tensor<?x512x1024xbf16>
  }
  return %11 : tensor<?x512x1024xbf16>
}

As Tobias mentioned, there are two kinds of fusion in Linalg:

  1. Elementwise fusion
  2. Tile + fuse

Your “after” snippet has an scf.for. That is generated not by elementwise fusion but by tile and fuse. I have been referring to elementwise fusion.

But looking at the snippet, that kind of fusion with slices isn’t supported today. Contributions would be most welcome. D122437 “[mlir] Bubble up tensor.extract_slice above linalg operation” might help a bit, but it does not account for the tensor.insert_slice. With that handled, we should be able to generate the code you desire.

I only wanted to fuse the outermost dimension, so I used tile + fuse for this particular example.

Thank you very much for pointing me to this change. It looks promising to extend it to cover the case.

I can confirm that the tensor.insert_slice operations in the example are not supported by tile and fuse. If I understand correctly, they concatenate the result tensors of the two producer ops. This pattern does not work since fusion currently needs a one-to-one mapping between consumer and producer loops; in particular, we currently fail to analyze this mapping.

In the example, one could probably work around this limitation by expanding the last iteration dimension of the linalg.generic that consumes the concatenated tensor to 2x512 instead of 1024. We could then remove the tensor.insert_slice operations and pass %3 and %5 directly to the linalg.generic, which should enable fusion. Inside the linalg.generic, a select operation could then read from the first or second tensor depending on the value of the leading expanded dimension (the 2x), which can be accessed using the linalg.index operation. I do not claim this is a clean way to address the problem. It is just hiding the problem (control flow :)) in the body of a linalg.generic.
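A Python model of this workaround on small shapes (W stands in for 512; all names are hypothetical and integers replace bf16). The original form concatenates the two producer results along the last dimension before the elementwise multiply; the rewritten consumer iterates over an extra dimension e in {0, 1} and selects between the producers by the value of e, as linalg.index would provide. Because each producer is then read through a plain (d0, d1, d3) access, the consumer-to-producer loop mapping is one-to-one again.

```python
def concat_then_consume(p3, p5, arg3):
    # Original form: the two tensor.insert_slice ops concatenate %3 and %5
    # along the last dim, then the consumer %10 multiplies elementwise.
    cat = [[r3 + r5 for r3, r5 in zip(s3, s5)] for s3, s5 in zip(p3, p5)]
    return [[[x * y for x, y in zip(rc, ra)] for rc, ra in zip(sc, sa)]
            for sc, sa in zip(cat, arg3)]


def expanded_with_select(p3, p5, arg3):
    # Workaround: the consumer's last dim (size 2 * W) is expanded to (2, W).
    # Both producers are direct inputs read at (d0, d1, d3); the body selects
    # one of them based on the expanded index e (linalg.index in IR).
    # arg3 and the result are addressed as if reshaped to B x N x 2 x W.
    B, N, W = len(p3), len(p3[0]), len(p3[0][0])
    out = [[[None] * (2 * W) for _ in range(N)] for _ in range(B)]
    for b in range(B):
        for n in range(N):
            for e in range(2):
                for w in range(W):
                    lhs = p3[b][n][w] if e == 0 else p5[b][n][w]  # the select
                    out[b][n][e * W + w] = lhs * arg3[b][n][e * W + w]
    return out


B, N, W = 2, 2, 3
p3 = [[[b + n + w for w in range(W)] for n in range(N)] for b in range(B)]
p5 = [[[10 * (b + 1) - w for w in range(W)] for n in range(N)] for b in range(B)]
arg3 = [[[k + 1 for k in range(2 * W)] for n in range(N)] for b in range(B)]
assert expanded_with_select(p3, p5, arg3) == concat_then_consume(p3, p5, arg3)
```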

Does this sound like it could solve the problem or did I miss something here?


Thanks very much for sharing this tip! Hiding control flow inside linalg.generic should work around the tensor.insert_slice issue in my example and in similar cases in my experiments. :) Thanks again!