I am trying to lower a simple MLIR tensor example: a matmul tiled only along the i dimension, where C[i][j] += A[i][k] * B[k][j].
simple.mlir
// Matmul C += A * B on dynamically shaped tensors, tiled along the i (row)
// dimension only. Each loop iteration processes one row tile and threads the
// updated C tensor through the loop as an iter_arg.
func.func @mm_tiled(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
// Two iterations (%i = 0, 1); %c_out carries the partially updated C.
%ret = scf.for %i = %c0 to %c2 step %c1 iter_args(%c_out = %C) -> tensor<?x?xf32> {
// Row offset of the current tile: %i << 1, i.e. 2 * %i (the tile has 2 rows).
%i_init = arith.shli %i, %c1 : index
// 2x4 tile of C starting at row %i_init, column 0, unit strides. The sizes are
// passed as SSA values, so the result type stays fully dynamic.
%c_update = tensor.extract_slice %c_out[%i_init, 0] [%c2, %c4] [1, 1] :
tensor<?x?xf32> to tensor<?x?xf32>
// Matching 2x4 tile of A (same row offset).
%a_update = tensor.extract_slice %A[%i_init, 0] [%c2, %c4] [1, 1] :
tensor<?x?xf32> to tensor<?x?xf32>
// Multiply the A tile by all of %B, accumulating into the C tile.
// NOTE(review): with 2x4 tiles of both A and C this presumably expects %B to be 4x4 -- confirm.
%ret0 = linalg.matmul ins(%a_update, %B : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%c_update: tensor<?x?xf32>) -> tensor<?x?xf32>
// Write the updated tile back into C at the same offset and size.
%ret = tensor.insert_slice %ret0 into %c_out[%i_init, 0] [%c2, %c4] [1, 1] :
tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %ret: tensor<?x?xf32>
}
return %ret : tensor<?x?xf32>
}
I am running it with the following command line:
# Bufferize simple.mlir with one-shot-bufferize, then run the listed lowering
# passes in order, ending with the *-to-llvm conversions.
# NOTE(review): the stack trace below ends in mlir::affine::normalizeMemRefType,
# so the failing pass is presumably -normalize-memrefs -- confirm.
mlir-opt --mlir-disable-threading simple.mlir -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map" \
-drop-equivalent-buffer-results -buffer-deallocation -resolve-shaped-type-result-dims -convert-linalg-to-loops -resolve-ranked-shaped-type-result-dims -expand-strided-metadata -convert-tensor-to-linalg -normalize-memrefs -lower-affine -convert-scf-to-cf -memref-expand -finalize-memref-to-llvm \
-convert-scf-to-cf -convert-func-to-llvm -convert-cf-to-llvm \
-convert-arith-to-llvm -convert-index-to-llvm \
-reconcile-unrealized-casts
This triggers an assertion failure: `other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch"`.
Detailed output:
mlir-opt: /home/mvaidya/source/repos/MLIR_Workspace/llvm-project/mlir/lib/Analysis/FlatLinearValueConstraints.cpp:160: mlir::LogicalResult mlir::FlatLinearConstraints::composeMatchingMap(mlir::AffineMap): Assertion `other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0. Program arguments: mlir-opt --mlir-disable-threading simple.mlir "-one-shot-bufferize=allow-return-allocs bufferize-function-boundaries unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map" -drop-equivalent-buffer-results -buffer-deallocation -resolve-shaped-type-result-dims -convert-linalg-to-loops -resolve-ranked-shaped-type-result-dims -expand-strided-metadata -convert-tensor-to-linalg -normalize-memrefs -lower-affine -convert-scf-to-cf -memref-expand -finalize-memref-to-llvm -convert-scf-to-cf -convert-func-to-llvm -convert-cf-to-llvm -convert-arith-to-llvm -convert-index-to-llvm -reconcile-unrealized-casts
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0 mlir-opt 0x0000562935be9f5d llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 61
1 mlir-opt 0x0000562935bea3db
2 mlir-opt 0x0000562935be8676 llvm::sys::RunSignalHandlers() + 134
3 mlir-opt 0x0000562935beabf5
4 libc.so.6 0x00007f46bd242520
5 libc.so.6 0x00007f46bd296a7c pthread_kill + 300
6 libc.so.6 0x00007f46bd242476 raise + 22
7 libc.so.6 0x00007f46bd2287f3 abort + 211
8 libc.so.6 0x00007f46bd22871b
9 libc.so.6 0x00007f46bd239e96
10 mlir-opt 0x000056293a0591dc mlir::FlatLinearConstraints::composeMatchingMap(mlir::AffineMap) + 268
11 mlir-opt 0x0000562935dd3409 mlir::affine::normalizeMemRefType(mlir::MemRefType, unsigned int) + 617
12 mlir-opt 0x0000562937a00a39
13 mlir-opt 0x0000562937a0054e
14 mlir-opt 0x0000562937a0049d
15 mlir-opt 0x0000562935cdfd0c
16 mlir-opt 0x0000562935cdfcbe void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) + 462
17 mlir-opt 0x0000562935cdfc6d void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) + 381
18 mlir-opt 0x0000562935cdfc6d void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) + 381
19 mlir-opt 0x0000562937a00432
20 mlir-opt 0x0000562937a003cd
21 mlir-opt 0x00005629379ff0b0
22 mlir-opt 0x00005629379fd8c7
23 mlir-opt 0x00005629379fd24e
24 mlir-opt 0x000056293a0cd95b
25 mlir-opt 0x000056293a0cd8f5
26 mlir-opt 0x0000562935b095a9
27 mlir-opt 0x000056293a0d096d
28 mlir-opt 0x000056293a0c8e56 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 646
29 mlir-opt 0x000056293a0c9414 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 388
30 mlir-opt 0x000056293a0caeac mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) + 108
31 mlir-opt 0x000056293a0cadc2 mlir::PassManager::run(mlir::Operation*) + 1090
32 mlir-opt 0x000056293a0b1289
33 mlir-opt 0x000056293a0b0f29
34 mlir-opt 0x000056293a0b0d1b
35 mlir-opt 0x000056293a0b0c9d
36 mlir-opt 0x000056293a221979
37 mlir-opt 0x000056293a220f55 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, bool, bool) + 149
38 mlir-opt 0x000056293a0adba9 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) + 345
39 mlir-opt 0x000056293a0ae07e mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) + 1150
40 mlir-opt 0x0000562935a7ab26 main + 134
41 libc.so.6 0x00007f46bd229d90
42 libc.so.6 0x00007f46bd229e40 __libc_start_main + 128
43 mlir-opt 0x0000562935a7a7e5 _start + 37
[1] 59067 IOT instruction (core dumped) mlir-opt --mlir-disable-threading simple.mlir -drop-equivalent-buffer-result
The llvm commit I am using is 806dea46be4c49dc587b98dab5e4d9d242a6abdb.
Is there a recommended sequence (or DAG) of passes for lowering the tensor dialect to lower-level dialects, in particular the LLVM dialect? Thanks in advance!
cc: @matthias-springer (based on recommendation from @maheshravishankar).
+miheer