Help lowering OpenMP dialect to LLVM

I’ve been searching the forums and test files, but haven’t been able to find a concrete example of how to do this yet!

I’m trying to lower this simple parallel for loop into LLVM:

  // Computes arg1[i] = (arg0[i] + arg0[i]) + arg0[i] for every i in [0, dim(arg0, 0)).
  // llvm.emit_c_interface additionally requests a _mlir_ciface_ wrapper that takes
  // each memref as a pointer to its descriptor.
  // NOTE(review): the loop bound is taken from %arg0 only; %arg1 is assumed to be at
  // least as long as %arg0 — confirm with callers.
  func.func @legateMLIRKernel2(%arg0: memref<?xf64>, %arg1: memref<?xf64>) attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    // Number of elements in %arg0; used as the loop upper bound.
    %dim = memref.dim %arg0, %c0 : memref<?xf64>
    // Same value as %c0; kept separate as the loop lower bound.
    %c0_0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // Iterations are independent, so convert-scf-to-openmp can map this loop to
    // an omp.parallel region containing an omp.wsloop.
    scf.parallel (%arg2) = (%c0_0) to (%dim) step (%c1) {
      %0 = memref.load %arg0[%arg2] : memref<?xf64>
      %1 = arith.addf %0, %0 : f64
      %2 = arith.addf %1, %0 : f64
      memref.store %2, %arg1[%arg2] : memref<?xf64>
      scf.yield
    }
    return
  }

When I run this through ./bin/mlir-opt testing.mlir --pass-pipeline="builtin.module(convert-scf-to-openmp, convert-openmp-to-llvm)", I get some reasonable-looking IR:

// Result of: mlir-opt testing.mlir --pass-pipeline="builtin.module(convert-scf-to-openmp, convert-openmp-to-llvm)"
module {
  // Lowered kernel. Each memref<?xf64> argument arrives as five scalars that are
  // repacked into a descriptor struct, following MLIR's memref descriptor layout:
  // (allocated ptr, aligned ptr, offset, sizes[1], strides[1]).
  llvm.func @legateMLIRKernel2(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64) attributes {llvm.emit_c_interface} {
    // Descriptor for the first memref argument, built from %arg0..%arg4.
    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    // Descriptor for the second memref argument, built from %arg5..%arg9.
    %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %7 = llvm.insertvalue %arg5, %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %8 = llvm.insertvalue %arg6, %7[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %9 = llvm.insertvalue %arg7, %8[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %10 = llvm.insertvalue %arg8, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %11 = llvm.insertvalue %arg9, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    // NOTE(review): %12 and %16 are never used below.
    %12 = llvm.mlir.constant(0 : index) : i64
    // Loop upper bound: field [3, 0] of the first descriptor
    // (the memref.dim of %arg0 in the source function).
    %13 = llvm.extractvalue %5[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %14 = llvm.mlir.constant(0 : index) : i64
    %15 = llvm.mlir.constant(1 : index) : i64
    %16 = llvm.mlir.constant(1 : i64) : i64
    // The scf.parallel loop became an omp.parallel region containing an
    // omp.wsloop over [%14, %13) with step %15.
    omp.parallel   {
      omp.wsloop   for  (%arg10) : i64 = (%14) to (%13) step (%15) {
        // stacksave/stackrestore bracket each iteration's body
        // (presumably from an alloca scope introduced during lowering — confirm).
        %17 = llvm.intr.stacksave : !llvm.ptr
        llvm.br ^bb1
      ^bb1:  // pred: ^bb0
        // Load element %arg10 through field 1 of the first descriptor, compute
        // (x + x) + x, and store it at the same index through field 1 of the second.
        %18 = llvm.extractvalue %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
        %19 = llvm.getelementptr %18[%arg10] : (!llvm.ptr, i64) -> !llvm.ptr, f64
        %20 = llvm.load %19 : !llvm.ptr -> f64
        %21 = llvm.fadd %20, %20  : f64
        %22 = llvm.fadd %21, %20  : f64
        %23 = llvm.extractvalue %11[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
        %24 = llvm.getelementptr %23[%arg10] : (!llvm.ptr, i64) -> !llvm.ptr, f64
        llvm.store %22, %24 : f64, !llvm.ptr
        llvm.intr.stackrestore %17 : !llvm.ptr
        llvm.br ^bb2
      ^bb2:  // pred: ^bb1
        omp.yield
      }
      omp.terminator
    }
    llvm.return
  }
  // C-interface wrapper (requested by llvm.emit_c_interface). It receives each
  // memref as a pointer to its descriptor struct, loads and unpacks the fields,
  // and forwards them as scalars to @legateMLIRKernel2.
  llvm.func @_mlir_ciface_legateMLIRKernel2(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {llvm.emit_c_interface} {
    %0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %3 = llvm.extractvalue %0[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %4 = llvm.extractvalue %0[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %5 = llvm.extractvalue %0[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %6 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %7 = llvm.extractvalue %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %8 = llvm.extractvalue %6[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %9 = llvm.extractvalue %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %10 = llvm.extractvalue %6[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %11 = llvm.extractvalue %6[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    llvm.call @legateMLIRKernel2(%1, %2, %3, %4, %5, %7, %8, %9, %10, %11) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> ()
    llvm.return
  }
}

On the command line, I can then pass this through mlir-translate --mlir-to-llvmir, which appears to give me back valid LLVM IR.

I can’t seem to get the same behavior within my C++ application though!

I’m applying at least those same passes to my MLIR fragment:

  // Lower this->module_ to the LLVM dialect, then translate it into an
  // llvm::Module named this->kernelName_.
  mlir::PassManager pm(ctx, this->module_.get()->getName().getStringRef(), mlir::PassManager::Nesting::Implicit);
  {
    // Per-function cleanup/expansion before the dialect conversions.
    pm.addNestedPass<mlir::func::FuncOp>(mlir::createLowerAffinePass());
    pm.addNestedPass<mlir::func::FuncOp>(mlir::arith::createArithExpandOpsPass());
    // scf.parallel -> omp.parallel + omp.wsloop.
    pm.addPass(mlir::createConvertSCFToOpenMPPass());
    // Converts the ops/types inside OpenMP regions to the LLVM dialect.
    pm.addPass(mlir::createConvertOpenMPToLLVMPass());
    pm.addNestedPass<mlir::func::FuncOp>(mlir::memref::createExpandStridedMetadataPass());
    pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
    pm.addPass(mlir::createConvertMathToLLVMPass());
    pm.addPass(mlir::createConvertMathToLibmPass());
    // TODO (rohany): Add in complex to libm passes?
    pm.addPass(mlir::createConvertFuncToLLVMPass());
    pm.addPass(mlir::createReconcileUnrealizedCastsPass());
  }
  // NOTE(review): assert(false) is compiled out when NDEBUG is defined, so in a
  // release build a failed pipeline would fall through to translation
  // unreported. Consider emitting a diagnostic / returning an error instead.
  if (mlir::failed(pm.run(this->module_.get()))) {
    assert(false);
  }
  std::unique_ptr<llvm::LLVMContext> llvmContext = std::make_unique<llvm::LLVMContext>();
  // NOTE(review): translateModuleToLLVMIR can only translate ops whose dialect has
  // an LLVM IR translation interface registered on the MLIRContext. Nothing in
  // this fragment registers one for the OpenMP dialect, which produces
  // "missing `LLVMTranslationDialectInterface` registration ... omp.parallel".
  // Register it before this call, e.g. mlir::registerOpenMPDialectTranslation(*ctx)
  // (assuming ctx is an mlir::MLIRContext*), alongside the builtin/LLVM ones.
  auto llvmModule = mlir::translateModuleToLLVMIR(this->module_.get(), *llvmContext, this->kernelName_);
  // A null result means translation failed. This assert is likewise a no-op
  // under NDEBUG.
  assert(llvmModule);

However, the final assertion (assert(llvmModule);) fails with:

loc("binary_op"): error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: omp.parallel

What is my C++ code missing compared to mlir-translate? I tried looking at the mlir-translate source but couldn’t figure out much of what was going on.

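You can try adding a call to one of the functions below before calling translateModuleToLLVMIR. mlir-translate registers the LLVM IR translation interfaces for all dialects up front, but in your own application you have to register the ones you need yourself: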

void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry);
void mlir::registerOpenMPDialectTranslation(MLIRContext &context);

Thanks, that worked!

Hi @kiranchandramohan, I’m pushing on this further and trying to JIT some of the code that I’ve lowered down to the OpenMP dialect using your help.

The LLVM IR I’m generating looks like it’s been lowered correctly (it contains calls to OpenMP runtime functions like __kmpc_global_thread_num and __kmpc_fork_call). I’m then passing it to the ORC JIT to link and run. However, I’m not convinced that the right thing is happening, because the generated code does not seem sensitive to things like OMP_NUM_THREADS. What are some things I can look at to see what might be going wrong?

Ultimately, I have a slightly exotic goal – I’m JIT-compiling into a system that has a custom implementation of the OpenMP runtime, so I need to get the JIT-compiled code to link against that implementation rather than the standard LLVM implementation. Is this possible?

Never mind, I think I got it to work!

1 Like