I’ve been searching the forums and test files, but haven’t been able to find a concrete example of how to do this yet!
I’m trying to lower this simple parallel for loop into LLVM:
func.func @legateMLIRKernel2(%arg0: memref<?xf64>, %arg1: memref<?xf64>) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%dim = memref.dim %arg0, %c0 : memref<?xf64>
%c0_0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.parallel (%arg2) = (%c0_0) to (%dim) step (%c1) {
%0 = memref.load %arg0[%arg2] : memref<?xf64>
%1 = arith.addf %0, %0 : f64
%2 = arith.addf %1, %0 : f64
memref.store %2, %arg1[%arg2] : memref<?xf64>
scf.yield
}
return
}
When I run this through ./bin/mlir-opt testing.mlir --pass-pipeline="builtin.module(convert-scf-to-openmp, convert-openmp-to-llvm)"
, I get some reasonable looking IR:
module {
llvm.func @legateMLIRKernel2(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64) attributes {llvm.emit_c_interface} {
%0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%7 = llvm.insertvalue %arg5, %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%8 = llvm.insertvalue %arg6, %7[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%9 = llvm.insertvalue %arg7, %8[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%10 = llvm.insertvalue %arg8, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%11 = llvm.insertvalue %arg9, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%12 = llvm.mlir.constant(0 : index) : i64
%13 = llvm.extractvalue %5[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%14 = llvm.mlir.constant(0 : index) : i64
%15 = llvm.mlir.constant(1 : index) : i64
%16 = llvm.mlir.constant(1 : i64) : i64
omp.parallel {
omp.wsloop for (%arg10) : i64 = (%14) to (%13) step (%15) {
%17 = llvm.intr.stacksave : !llvm.ptr
llvm.br ^bb1
^bb1: // pred: ^bb0
%18 = llvm.extractvalue %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%19 = llvm.getelementptr %18[%arg10] : (!llvm.ptr, i64) -> !llvm.ptr, f64
%20 = llvm.load %19 : !llvm.ptr -> f64
%21 = llvm.fadd %20, %20 : f64
%22 = llvm.fadd %21, %20 : f64
%23 = llvm.extractvalue %11[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%24 = llvm.getelementptr %23[%arg10] : (!llvm.ptr, i64) -> !llvm.ptr, f64
llvm.store %22, %24 : f64, !llvm.ptr
llvm.intr.stackrestore %17 : !llvm.ptr
llvm.br ^bb2
^bb2: // pred: ^bb1
omp.yield
}
omp.terminator
}
llvm.return
}
llvm.func @_mlir_ciface_legateMLIRKernel2(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {llvm.emit_c_interface} {
%0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%3 = llvm.extractvalue %0[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%4 = llvm.extractvalue %0[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%5 = llvm.extractvalue %0[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%6 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%7 = llvm.extractvalue %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%8 = llvm.extractvalue %6[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%9 = llvm.extractvalue %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%10 = llvm.extractvalue %6[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%11 = llvm.extractvalue %6[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
llvm.call @legateMLIRKernel2(%1, %2, %3, %4, %5, %7, %8, %9, %10, %11) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> ()
llvm.return
}
}
In the command line, I can then pass this through to mlir-translate --mlir-to-llvmir
, which appears to give me back valid LLVM.
I can’t seem to get the same behavior within my C++ application though!
I’m applying at least those same passes to my MLIR fragment:
mlir::PassManager pm(ctx, this->module_.get()->getName().getStringRef(), mlir::PassManager::Nesting::Implicit);
{
pm.addNestedPass<mlir::func::FuncOp>(mlir::createLowerAffinePass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::arith::createArithExpandOpsPass());
pm.addPass(mlir::createConvertSCFToOpenMPPass());
pm.addPass(mlir::createConvertOpenMPToLLVMPass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::memref::createExpandStridedMetadataPass());
pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
pm.addPass(mlir::createConvertMathToLLVMPass());
pm.addPass(mlir::createConvertMathToLibmPass());
// TODO (rohany): Add in complex to libm passes?
pm.addPass(mlir::createConvertFuncToLLVMPass());
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
}
if (mlir::failed(pm.run(this->module_.get()))) {
assert(false);
}
std::unique_ptr<llvm::LLVMContext> llvmContext = std::make_unique<llvm::LLVMContext>();
auto llvmModule = mlir::translateModuleToLLVMIR(this->module_.get(), *llvmContext, this->kernelName_);
assert(llvmModule);
However, the final assertion (assert(llvmModule);
) fails with:
loc("binary_op"): error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: omp.parallel
What’s missing between my C++ code and mlir-translate
? I tried looking at the mlir-translate source but couldn’t figure out too much what was going on.