Need help better understanding SCF to SPIR-V lowering

Hi all,

I’m trying to use the convert-scf-to-spirv pass in MLIR to lower the scf.for in the following toy example to spirv.mlir.loop:

func.func @forward() {
    %c0 = arith.constant 0 : index
    %c32 = arith.constant 32 : index
    %c1 = arith.constant 1 : index
    scf.for %arg2 = %c0 to %c32 step %c1 {
        %1 = index.add %arg2, %arg2
    }
    return
}

However, I run into the following error:

error: failed to materialize conversion for block argument #0 that remained live after conversion, type was 'index', with target type 'i32' origArg: <block argument> of type 'index' at index: 0 newArg: <<NULL VALUE>>see existing live user here: %8 = "index.add"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) : (index, index) -> index
    scf.for %arg2 = %c0 to %c32 step %c1 {
/nn-mlir/test-scf-to-spirv.mlir:6:14: note: see existing live user here: %8 = "index.add"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) : (index, index) -> index
        %1 = index.add %arg2, %arg2

I believe this happens because index.add expects operands of type index, but when scf.for is converted to spirv.mlir.loop, the induction variable is converted from index to i32. After that, there is no index-typed value left to pass to index.add.

I also tried adding the following materialization callback

auto addUnrealizedCast = [](OpBuilder &builder, Type type,
                            ValueRange inputs, Location loc) {
  // Bridge `inputs` to `type` with a builtin.unrealized_conversion_cast.
  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
  return std::optional<Value>(cast.getResult(0));
};

to SCFToSPIRVPass::runOnOperation(). With it, I would expect the resulting IR to look something like this:

module {
  spirv.func @forward() "None" {
    %cst0_i32 = spirv.Constant 0 : i32
    %cst32_i32 = spirv.Constant 32 : i32
    %cst1_i32 = spirv.Constant 1 : i32
    spirv.mlir.loop {
      spirv.Branch ^bb1(%cst0_i32 : i32)
    ^bb1(%0: i32):  // 2 preds: ^bb0, ^bb2
      %1 = spirv.SLessThan %0, %cst32_i32 : i32
      spirv.BranchConditional %1, ^bb2, ^bb3
    ^bb2:  // pred: ^bb1
      %cast_to_index = builtin.unrealized_conversion_cast %0 : i32 to index
      %index_add = index.add %cast_to_index, %cast_to_index
      %2 = spirv.IAdd %0, %cst1_i32 : i32
      spirv.Branch ^bb1(%2 : i32)
    ^bb3:  // pred: ^bb1
      spirv.mlir.merge
    }
    spirv.Return
  }
}

However, it still fails with the same error.

It seems I’m clearly missing something. I’d appreciate any help! Thanks!

Hi @hsnbrg,

Sorry I missed this thread. It doesn’t look like we have an index-to-SPIR-V conversion pass; I can only see IndexToLLVM in the tree. The index dialect is a newer addition to MLIR, and we may not have seen inputs that use it in the wild yet. We should add one.

Just to check whether there are any other issues: does the code work if you change index.add to arith.addi?

cc: @antiagainst

I opened a tracking issue: [mlir][spirv] Support index to spir-v dialect conversion (llvm/llvm-project issue #63713 on GitHub)

Yes, what @kuhar said above. We’d need to add patterns to convert the new index ops, which should be fairly straightforward. @hsnbrg, do you want to give it a try?