Assertion failure when lowering a simple tensor dialect example

I am trying to lower a simple MLIR tensor example: a matmul, where C[i][j] += A[i][k] * B[k][j], tiled only along the i dimension.

simple.mlir
func.func @mm_tiled(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index
  %ret = scf.for %i = %c0 to %c2 step %c1 iter_args(%c_out = %C) -> tensor<?x?xf32> {
    // Row offset of the current tile: %i * 2.
    %i_init = arith.shli %i, %c1 : index
    %c_update = tensor.extract_slice %c_out[%i_init, 0] [%c2, %c4] [1, 1] :
      tensor<?x?xf32> to tensor<?x?xf32>
    %a_update = tensor.extract_slice %A[%i_init, 0] [%c2, %c4] [1, 1] :
      tensor<?x?xf32> to tensor<?x?xf32>
    %ret0 = linalg.matmul ins(%a_update, %B : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%c_update: tensor<?x?xf32>) -> tensor<?x?xf32>
    %ret = tensor.insert_slice %ret0 into %c_out[%i_init, 0] [%c2, %c4] [1, 1] :
      tensor<?x?xf32> into tensor<?x?xf32>
    scf.yield %ret: tensor<?x?xf32>
  }
  return %ret : tensor<?x?xf32>
}

I am using the following command-line args:

mlir-opt --mlir-disable-threading simple.mlir -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map" \
   -drop-equivalent-buffer-results -buffer-deallocation -resolve-shaped-type-result-dims -convert-linalg-to-loops -resolve-ranked-shaped-type-result-dims -expand-strided-metadata -convert-tensor-to-linalg -normalize-memrefs -lower-affine -convert-scf-to-cf -memref-expand -finalize-memref-to-llvm \
    -convert-scf-to-cf -convert-func-to-llvm -convert-cf-to-llvm \
   -convert-arith-to-llvm -convert-index-to-llvm  \
   -reconcile-unrealized-casts

This triggers the assertion failure: Assertion `other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch"' failed.

Detailed output:
mlir-opt: /home/mvaidya/source/repos/MLIR_Workspace/llvm-project/mlir/lib/Analysis/FlatLinearValueConstraints.cpp:160: mlir::LogicalResult mlir::FlatLinearConstraints::composeMatchingMap(mlir::AffineMap): Assertion `other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: mlir-opt --mlir-disable-threading simple.mlir "-one-shot-bufferize=allow-return-allocs bufferize-function-boundaries unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map" -drop-equivalent-buffer-results -buffer-deallocation -resolve-shaped-type-result-dims -convert-linalg-to-loops -resolve-ranked-shaped-type-result-dims -expand-strided-metadata -convert-tensor-to-linalg -normalize-memrefs -lower-affine -convert-scf-to-cf -memref-expand -finalize-memref-to-llvm -convert-scf-to-cf -convert-func-to-llvm -convert-cf-to-llvm -convert-arith-to-llvm -convert-index-to-llvm -reconcile-unrealized-casts
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  mlir-opt  0x0000562935be9f5d llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 61
1  mlir-opt  0x0000562935bea3db
2  mlir-opt  0x0000562935be8676 llvm::sys::RunSignalHandlers() + 134
3  mlir-opt  0x0000562935beabf5
4  libc.so.6 0x00007f46bd242520
5  libc.so.6 0x00007f46bd296a7c pthread_kill + 300
6  libc.so.6 0x00007f46bd242476 raise + 22
7  libc.so.6 0x00007f46bd2287f3 abort + 211
8  libc.so.6 0x00007f46bd22871b
9  libc.so.6 0x00007f46bd239e96
10 mlir-opt  0x000056293a0591dc mlir::FlatLinearConstraints::composeMatchingMap(mlir::AffineMap) + 268
11 mlir-opt  0x0000562935dd3409 mlir::affine::normalizeMemRefType(mlir::MemRefType, unsigned int) + 617
12 mlir-opt  0x0000562937a00a39
13 mlir-opt  0x0000562937a0054e
14 mlir-opt  0x0000562937a0049d
15 mlir-opt  0x0000562935cdfd0c
16 mlir-opt  0x0000562935cdfcbe void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) + 462
17 mlir-opt  0x0000562935cdfc6d void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) + 381
18 mlir-opt  0x0000562935cdfc6d void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) + 381
19 mlir-opt  0x0000562937a00432
20 mlir-opt  0x0000562937a003cd
21 mlir-opt  0x00005629379ff0b0
22 mlir-opt  0x00005629379fd8c7
23 mlir-opt  0x00005629379fd24e
24 mlir-opt  0x000056293a0cd95b
25 mlir-opt  0x000056293a0cd8f5
26 mlir-opt  0x0000562935b095a9
27 mlir-opt  0x000056293a0d096d
28 mlir-opt  0x000056293a0c8e56 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 646
29 mlir-opt  0x000056293a0c9414 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 388
30 mlir-opt  0x000056293a0caeac mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) + 108
31 mlir-opt  0x000056293a0cadc2 mlir::PassManager::run(mlir::Operation*) + 1090
32 mlir-opt  0x000056293a0b1289
33 mlir-opt  0x000056293a0b0f29
34 mlir-opt  0x000056293a0b0d1b
35 mlir-opt  0x000056293a0b0c9d
36 mlir-opt  0x000056293a221979
37 mlir-opt  0x000056293a220f55 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, bool, bool) + 149
38 mlir-opt  0x000056293a0adba9 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) + 345
39 mlir-opt  0x000056293a0ae07e mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) + 1150
40 mlir-opt  0x0000562935a7ab26 main + 134
41 libc.so.6 0x00007f46bd229d90
42 libc.so.6 0x00007f46bd229e40 __libc_start_main + 128
43 mlir-opt  0x0000562935a7a7e5 _start + 37
[1]    59067 IOT instruction (core dumped)  mlir-opt --mlir-disable-threading simple.mlir  -drop-equivalent-buffer-result

The llvm commit I am using is 806dea46be4c49dc587b98dab5e4d9d242a6abdb.

Is there a recommended DAG of passes that can be followed to lower the tensor dialect to the various low-level dialects (in particular, the llvm dialect)? Thanks in advance!

cc: @matthias-springer (based on recommendation from @maheshravishankar).
+miheer

Actually, @matthias-springer, any pointers?


Please ignore the last comment. I was using an old mlir-opt binary, which was masking the issue.

This crash happens during -normalize-memrefs. Do you need that pass? If you drop it, your example compiles.
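
For reference, this is the same invocation with only -normalize-memrefs removed and every other flag kept verbatim (untested sketch):

mlir-opt --mlir-disable-threading simple.mlir -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map" \
   -drop-equivalent-buffer-results -buffer-deallocation -resolve-shaped-type-result-dims -convert-linalg-to-loops -resolve-ranked-shaped-type-result-dims -expand-strided-metadata -convert-tensor-to-linalg -lower-affine -convert-scf-to-cf -memref-expand -finalize-memref-to-llvm \
   -convert-scf-to-cf -convert-func-to-llvm -convert-cf-to-llvm \
   -convert-arith-to-llvm -convert-index-to-llvm \
   -reconcile-unrealized-casts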

Here is a potential fix for the crash: ⚙ D150250 [mlir][memref] Make result normalization aware of the number of symbols. I’m not familiar with this part of the code base, so I’d appreciate it if someone else could review it.

The general flow is -one-shot-bufferize to convert tensor IR to memref IR, followed by various passes that convert to the LLVM dialect. Here is an example: llvm-project/mlir/test/Integration/Dialect/Linalg/CPU/test-one-shot-bufferize.mlir
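
As a rough, untested sketch of that flow (using only passes that already appear in the command above), a minimal tensor-to-LLVM pipeline could look something like this:

mlir-opt simple.mlir \
  -one-shot-bufferize="bufferize-function-boundaries" \
  -convert-linalg-to-loops \
  -convert-scf-to-cf \
  -expand-strided-metadata \
  -finalize-memref-to-llvm \
  -convert-func-to-llvm \
  -convert-arith-to-llvm \
  -convert-cf-to-llvm \
  -reconcile-unrealized-casts

The linked integration test is the better reference for the exact set of flags used in-tree.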


Thank you @matthias-springer. I was wondering if it would be useful for newcomers to have this set of passes documented, maybe as a small DAG at the beginning of this page?