Hi all,
I’m trying to use the `convert-scf-to-spirv` pass in MLIR to lower the `scf.for` in the following toy example to `spirv.mlir.loop`.
```mlir
func.func @forward() {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %c1 = arith.constant 1 : index
  scf.for %arg2 = %c0 to %c32 step %c1 {
    %1 = index.add %arg2, %arg2
  }
  return
}
```
However, I run into the following error:
```
error: failed to materialize conversion for block argument #0 that remained live after conversion, type was 'index', with target type 'i32' origArg: <block argument> of type 'index' at index: 0 newArg: <<NULL VALUE>>see existing live user here: %8 = "index.add"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) : (index, index) -> index
scf.for %arg2 = %c0 to %c32 step %c1 {
^
/nn-mlir/test-scf-to-spirv.mlir:6:14: note: see existing live user here: %8 = "index.add"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) : (index, index) -> index
%1 = index.add %arg2, %arg2
```
I believe this is because `index.add` expects its operands to be of type `index`, but during the conversion of `scf.for` to `spirv.mlir.loop` the `index` induction variable gets converted to `i32`. As a result, there is no `index` value left to pass to `index.add`.
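To make that hypothesis concrete, here is a toy C++ model of the situation. It is not MLIR's API: `ToyValue`, `ToyTypeConverter`, and `convertBlockArgument` are hypothetical names made up for illustration. The model assumes the converter maps `index` to `i32` and that `index.add` is left unconverted, so its operand still expects `index`. Without a registered source materialization, nothing can bridge `i32` back to `index`, and the conversion fails.

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Toy model only (NOT the MLIR API): a "value" is a name plus a type name.
struct ToyValue {
  std::string name;
  std::string type;
};

// Hypothetical stand-in for a type converter.
struct ToyTypeConverter {
  // Source type -> target type, e.g. "index" -> "i32".
  std::map<std::string, std::string> typeMap;
  // Optional "source materialization": given a converted value and the type an
  // unconverted user still expects, build a bridging value (think: a cast).
  std::function<std::optional<ToyValue>(const ToyValue &, const std::string &)>
      sourceMaterialization;

  // Returns the converted type, or the type unchanged if no mapping exists.
  std::string convert(const std::string &type) const {
    auto it = typeMap.find(type);
    return it == typeMap.end() ? type : it->second;
  }
};

// Converts a block argument, then checks that every user left unconverted can
// still be given a value of the type it expects. Returns false on failure.
bool convertBlockArgument(const ToyTypeConverter &converter,
                          const ToyValue &oldArg,
                          const std::vector<std::string> &unconvertedUserTypes) {
  ToyValue newArg{oldArg.name, converter.convert(oldArg.type)};
  for (const std::string &expected : unconvertedUserTypes) {
    if (expected == newArg.type) continue;  // no mismatch, nothing to bridge
    if (!converter.sourceMaterialization ||
        !converter.sourceMaterialization(newArg, expected)) {
      std::cerr << "toy error: cannot bridge '" << newArg.type << "' back to '"
                << expected << "' for " << oldArg.name << "\n";
      return false;
    }
  }
  return true;
}
```

Calling `convertBlockArgument` with `typeMap["index"] = "i32"` and a single unconverted user expecting `index` returns `false`. Once a cast-building callback is assigned to `sourceMaterialization`, it returns `true`. The real dialect-conversion driver has many more moving parts, so this only illustrates the hypothesis; it does not explain why the actual pass still fails.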
However, even after adding the following to `SCFToSPIRVPass::runOnOperation()`:
```cpp
// Materialize any index <-> i32 mismatch with an unrealized_conversion_cast.
auto addUnrealizedCast = [](OpBuilder &builder, Type type,
                            ValueRange inputs, Location loc) {
  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
  return std::optional<Value>(cast.getResult(0));
};
typeConverter.addSourceMaterialization(addUnrealizedCast);
typeConverter.addTargetMaterialization(addUnrealizedCast);
// Keep the casts legal so the conversion target accepts them.
target->addLegalOp<UnrealizedConversionCastOp>();
```
I would expect the resulting IR to look something like this:
```mlir
module {
  spirv.func @forward() "None" {
    %cst0_i32 = spirv.Constant 0 : i32
    %cst32_i32 = spirv.Constant 32 : i32
    %cst1_i32 = spirv.Constant 1 : i32
    spirv.mlir.loop {
      spirv.Branch ^bb1(%cst0_i32 : i32)
    ^bb1(%0: i32):  // 2 preds: ^bb0, ^bb2
      %1 = spirv.SLessThan %0, %cst32_i32 : i32
      spirv.BranchConditional %1, ^bb2, ^bb3
    ^bb2:  // pred: ^bb1
      %cast_to_index = builtin.unrealized_conversion_cast %0 : i32 to index
      %index_add = index.add %cast_to_index, %cast_to_index
      %2 = spirv.IAdd %0, %cst1_i32 : i32
      spirv.Branch ^bb1(%2 : i32)
    ^bb3:  // pred: ^bb1
      spirv.mlir.merge
    }
    spirv.Return
  }
}
```
Instead, it still fails with the same error.
It seems like I’m missing something. I’d appreciate any help! Thanks!