Hello,
I am trying to use the dialect conversion framework to apply a conversion pattern to the internal control flow inside a function. This is part of implementing “detensoring” for linalg-on-tensor ops, which means converting op instances that take and produce only 0D tensors into equivalent ops that work directly on the underlying tensor element types. This is being implemented in these 2 patches: D96271 ([MLIR][LinAlg] Start detensoring implementation.) and D97148 ({WIP: PLZ DON'T REVIEW YET}[MLIR][LinAlg] Detensorize interal CF.).
To that end, we would like to avoid detensoring across function boundaries. This means that we would like to apply type conversion to the signatures of all basic blocks inside a function except for the entry block. Take the following (contrived) example:
func @main(%arg0: tensor<i32>) -> tensor<i32> attributes {iree.module.export} {
  br ^bb1(%arg0: tensor<i32>)
^bb1(%0: tensor<i32>):
  return %0 : tensor<i32>
}
we would like to convert it to something similar to:
func @main(%arg0: tensor<i32>) -> tensor<i32> attributes {iree.module.export} {
  %ex_arg0 = tensor.extract %arg0[] : tensor<i32>
  br ^bb1(%ex_arg0: i32)
^bb1(%0: i32):
  %1 = tensor.from_elements %0 : tensor<1xi32>
  %2 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
  return %2 : tensor<i32>
}
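To make the intended rule concrete, here is a minimal, self-contained sketch of it in plain C++. It does not use any MLIR API: `ToyType`, `ToyBlockSignature`, `detensorizeType` and `detensorizeInternalBlocks` are hypothetical names invented for illustration, and a block is modelled only by the list of its argument types. It only shows the two ideas above: a 0D tensor type maps to its element type, and the entry block's signature is left untouched.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Toy stand-in for a scalar or ranked tensor type (NOT mlir::Type).
struct ToyType {
  std::string elementType;     // e.g. "i32"
  bool isTensor;               // false means a plain scalar
  std::vector<int64_t> shape;  // empty shape on a tensor means a 0D tensor

  bool operator==(const ToyType &other) const {
    return elementType == other.elementType && isTensor == other.isTensor &&
           shape == other.shape;
  }
};

// 0D tensors become their element type; every other type is unchanged.
ToyType detensorizeType(const ToyType &type) {
  if (type.isTensor && type.shape.empty())
    return ToyType{type.elementType, false, {}};
  return type;
}

// A block is modelled only by its argument types.
using ToyBlockSignature = std::vector<ToyType>;

// Convert every block signature except the entry block (index 0), so the
// function boundary keeps its tensor types.
void detensorizeInternalBlocks(std::vector<ToyBlockSignature> &blocks) {
  assert(!blocks.empty() && "a function body has at least an entry block");
  for (std::size_t i = 1; i < blocks.size(); ++i)
    for (ToyType &argType : blocks[i])
      argType = detensorizeType(argType);
}
```

Applied to the example above, the entry block keeps its `tensor<i32>` argument while `^bb1` ends up with an `i32` argument; the `tensor.extract` and `tensor.from_elements` ops that bridge the two are what the materializations are expected to provide.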
I am trying to use the dialect conversion framework to handle the control-flow conversion in D97148. (Note that the conversion within BB boundaries is implemented in D96271.)
The difficulty I am currently facing is in handling the function’s entry block. In particular, this block should be left as is, without converting its signature. I do this by passing an “identity” TypeConverter::SignatureConversion instance to ConversionPatternRewriter::convertRegionTypes(...), as you can see in Detensorize.cpp:198 (copied here):
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const override {
  ...
  // "Identity" signature conversion: the entry block keeps its original
  // argument types.
  TypeConverter::SignatureConversion result(type.getNumInputs());
  result.addInputs(type.getInputs());
  SmallVector<Type, 1> newResults;
  if (failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op),
                                         *typeConverter, &result))) {
    rewriter.cancelRootUpdate(op);
    return failure();
  }
  ...
}
The above code achieves the desired goal of not converting the entry BB’s arguments. Also, a target materialization is added in order to extract %arg0's element and pass the extracted value to the converted br op. However, during OperationConverter::finalize(...) -> OperationConverter::legalizeConvertedArgumentTypes(...) -> ArgConverter::materializeLiveConversions(...), the framework tries to create a source materialization for %arg0 in the entry BB. Looking at the implementation, the last method in that sequence (i.e. ArgConverter::materializeLiveConversions) invokes the source materialization hook with a non-empty value only if the argReplacementValue is different from the original value (see DialectConversion.cpp on the main branch of llvm/llvm-project on GitHub). This makes me suspect that I didn’t set up the framework properly in some way, because I believe a source materialization shouldn’t be needed in this situation.
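For reference, the crash shown in the output is the assertion inside my own materialization callback, which assumes it always receives exactly one input. Below is a small standalone sketch of the difference between asserting on that assumption and declining when it does not hold. It is plain C++ with no MLIR types: `ToyValue`, `materializeAsserting` and `materializeGuarded` are hypothetical stand-ins, and the bool-plus-out-parameter shape is only an illustration, not the framework's callback signature. A guard like this would only turn the abort into a reported failure; it would not explain why the hook is reached with no inputs in the first place.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for an SSA value: a name plus a printed type (NOT mlir::Value).
struct ToyValue {
  std::string name;
  std::string type;
};

// Mirrors the shape of the failing callback: it assumes exactly one input
// and aborts otherwise.
ToyValue materializeAsserting(const std::string &resultType,
                              const std::vector<ToyValue> &inputs) {
  assert(inputs.size() == 1 && "expected exactly one input");
  return ToyValue{"materialized(" + inputs[0].name + ")", resultType};
}

// Defensive variant: report failure to the caller instead of aborting when
// the input count is not the expected one.
bool materializeGuarded(const std::string &resultType,
                        const std::vector<ToyValue> &inputs,
                        ToyValue &result) {
  if (inputs.size() != 1)
    return false;
  result = ToyValue{"materialized(" + inputs[0].name + ")", resultType};
  return true;
}
```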
Below is the debug output:
//===-------------------------------------------===//
Legalizing operation : 'func'(0x7fd297104510) {
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'func -> ()' {
** Insert : 'tensor.from_elements'(0x7fd2971168d8)
** Insert : 'linalg.tensor_reshape'(0x7fd297116c18)
//===-------------------------------------------===//
Legalizing operation : 'func'(0x7fd297104510) {
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.from_elements'(0x7fd2971168d8) {
%1 = "tensor.from_elements"(%0) : (i32) -> tensor<1xi32>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'linalg.tensor_reshape'(0x7fd297116c18) {
%2 = "linalg.tensor_reshape"(%1) {reassociation = []} : (tensor<1xi32>) -> tensor<i32>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
} -> SUCCESS
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'std.br'(0x7fd29710b2e0) {
"std.br"(<<UNKNOWN SSA VALUE>>)[^bb1] : (tensor<i32>) -> ()
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'std.br -> ()' {
** Insert : 'tensor.extract'(0x7fd297204088)
//===-------------------------------------------===//
Legalizing operation : 'std.br'(0x7fd29710b2e0) {
"std.br"(%0)[^bb1] : (i32) -> ()
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.extract'(0x7fd297204088) {
%0 = "tensor.extract"(<<UNKNOWN SSA VALUE>>) : (tensor<i32>) -> i32
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
} -> SUCCESS
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'std.return'(0x7fd29710b4b0) {
"std.return"(<<UNKNOWN SSA VALUE>>) : (tensor<i32>) -> ()
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
Assertion failed: (inputs.size() == 1), function operator(), file /Users/ergawy/work/llvm-project/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp, line 91.
PLEASE submit a bug report to https://bugs.llvm.org/ and include the crash backtrace.
Stack dump:
0. Program arguments: /Users/ergawy/work/llvm-project/build/bin/mlir-opt /Users/ergawy/work/llvm-project/mlir/test/Dialect/Linalg/detensorized_while.mlir -linalg-detensorize -func-detensorize -debug -print-ir-after-all
1. Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0 mlir-opt 0x00000001089cfa1b llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 43
1 mlir-opt 0x00000001089ce728 llvm::sys::RunSignalHandlers() + 248
2 mlir-opt 0x00000001089d0077 SignalHandler(int) + 295
3 libsystem_platform.dylib 0x00007fff6d7495fd _sigtramp + 29
4 libsystem_platform.dylib 0x0000000000000b40 _sigtramp + 18446603338679809376
5 libsystem_c.dylib 0x00007fff6d61f808 abort + 120
6 libsystem_c.dylib 0x00007fff6d61eac6 err + 0
7 mlir-opt 0x000000010a4e1b01 std::__1::__function::__func<std::__1::function<llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)> mlir::TypeConverter::wrapMaterialization<mlir::Ty
pe, (anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>((anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lamb
da0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)&&)::'lambda'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location), std::__1::allocator<std::__1::function<llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir:
:Type, mlir::ValueRange, mlir::Location)> mlir::TypeConverter::wrapMaterialization<mlir::Type, (anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Lo
cation)>((anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)&&)::'lambda'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>,
llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>::operator()(mlir::OpBuilder&, mlir::Type&&, mlir::ValueRange&&, mlir::Location&&) (.cold.3) + 33
8 mlir-opt 0x0000000108ba6058 std::__1::__function::__func<std::__1::function<llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)> mlir::TypeConverter::wrapMaterialization<mlir::Ty
pe, (anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>((anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lamb
da0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)&&)::'lambda'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location), std::__1::allocator<std::__1::function<llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir:
:Type, mlir::ValueRange, mlir::Location)> mlir::TypeConverter::wrapMaterialization<mlir::Type, (anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Lo
cation)>((anonymous namespace)::DetensorizeTypeConverter::DetensorizeTypeConverter()::'lambda0'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)&&)::'lambda'(mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>,
llvm::Optional<mlir::Value> (mlir::OpBuilder&, mlir::Type, mlir::ValueRange, mlir::Location)>::operator()(mlir::OpBuilder&, mlir::Type&&, mlir::ValueRange&&, mlir::Location&&) + 232
9 mlir-opt 0x00000001092857ab (anonymous namespace)::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) + 3499
10 mlir-opt 0x0000000109286e09 mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget&, mlir::FrozenRewritePatternList const&, llvm::DenseSet<mlir::Operation*, llvm::DenseMapInfo<mlir::Operation*> >*) + 73
11 mlir-opt 0x0000000108ba7c76 (anonymous namespace)::FuncDetensorize::runOnFunction() + 1590
12 mlir-opt 0x000000010921cad0 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 512
13 mlir-opt 0x000000010921cf65 mlir::detail::OpToOpPassAdaptor::runPipeline(llvm::iterator_range<llvm::pointee_iterator<std::__1::unique_ptr<mlir::Pass, std::__1::default_delete<mlir::Pass> >*, mlir::Pass> >, mlir::Operati
on*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 133
14 mlir-opt 0x0000000109222bc4 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_8::operator()(llvm::MutableArrayRef<mlir::OpPassManager>) const + 452
15 mlir-opt 0x000000010921dc21 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool) + 1761
16 mlir-opt 0x000000010921cc67 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 919
17 mlir-opt 0x000000010921f76a mlir::PassManager::run(mlir::Operation*) + 762
18 mlir-opt 0x00000001091faf9d performActions(llvm::raw_ostream&, bool, bool, llvm::SourceMgr&, mlir::MLIRContext*, mlir::PassPipelineCLParser const&) + 397
19 mlir-opt 0x00000001091f90f0 processBuffer(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer> >, bool, bool, bool, bool, mlir::PassPipelineCLParser const&, mlir::Dia
lectRegistry&) + 304
20 mlir-opt 0x00000001091f9c84 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&, bool) + 2788
21 mlir-opt 0x00000001087bcfbc main + 140
22 libdyld.dylib 0x00007fff6d550cc9 start + 1
Any pointers on where I might have gone wrong?
Kareem