How to use the PassManager in the Python bindings

Hi friends,

I am trying out the PassManager in the Python bindings with:

from mlir.ir import Module
from mlir.passmanager import PassManager

module = Module.parse("mlir assembly ...")
PassManager.parse("builtin.module(lower-affine, convert-vector-to-llvm, convert-func-to-llvm, ...)").run(module)
print(module)

However, the printed module is unchanged: it is still the original MLIR assembly with the higher-level dialects. Is there a way to get the lowered assembly through the Python bindings?

That looks about right to me. You can double-check it against the test llvm-project/pass_manager.py at commit 4e295cb1ce3b94446a36a5d903afb17e8e30ec68 (llvm/llvm-project on GitHub).

It depends on what is inside "mlir assembly ...". If you share that, maybe we can help figure out what's wrong. It's possible those passes don't change anything for your IR. Have you tested the same pass pipeline on the same input with mlir-opt?
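To run that mlir-opt comparison from the same Python script, one option is to shell out to the tool. The sketch below is hypothetical: it assumes an `mlir-opt` binary on your PATH that accepts `--pass-pipeline=` with an explicit `builtin.module(...)` anchor, and the helper names `mlir_opt_command` and `run_mlir_opt` are made up for illustration.

```python
import subprocess
from typing import List, Sequence


def mlir_opt_command(input_path: str, passes: Sequence[str],
                     opt: str = "mlir-opt") -> List[str]:
    """Build an mlir-opt argument list that runs `passes` on a .mlir file.

    Hypothetical helper: joins the pass names with commas in one place and
    anchors them on builtin.module, the form used for mlir-opt in this thread.
    """
    pipeline = "builtin.module(" + ",".join(passes) + ")"
    return [opt, f"--pass-pipeline={pipeline}", input_path]


def run_mlir_opt(input_path: str, passes: Sequence[str],
                 opt: str = "mlir-opt") -> str:
    """Run mlir-opt (assumed to be installed) and return the IR it prints."""
    result = subprocess.run(mlir_opt_command(input_path, passes, opt),
                            capture_output=True, text=True, check=True)
    return result.stdout
```

For example, `run_mlir_opt("add.mlir", ["lower-affine", "convert-func-to-llvm"])` (with `add.mlir` a hypothetical file holding your IR) returns the transformed IR as text, which you can then compare with `str(module)` from the Python run.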

Many thanks for the kind help!

Here is the full code:

from mlir.ir import Module, Context
from mlir.passmanager import PassManager

with Context():
    module = Module.parse("module {" +
        "func.func @add(%arg0: memref<?x5xf64>, %arg1: memref<?x5xf64>) {" +
        "%c10 = arith.constant 10 : index" +
        "%c0 = arith.constant 0 : index" +
        "%0 = vector.load %arg0[%c10, %c0] : memref<?x5xf64>, vector<5xf64>" +
        "%1 = vector.load %arg1[%c10, %c0] : memref<?x5xf64>, vector<5xf64>" +
        "%2 = arith.addf %0, %1 : vector<5xf64>" +
        "vector.print %2 : vector<5xf64>" +
        "return}}")
    PassManager.parse(
        "builtin.module(" +
        "lower-affine," +
        "convert-scf-to-cf," +
        "convert-func-to-llvm," +
        "convert-vector-to-llvm," +
        "convert-memref-to-llvm," +
        "convert-cf-to-llvm," +
        "convert-math-to-llvm," +
        "convert-arith-to-llvm," +
        "canonicalize," +
        "reconcile-unrealized-casts" +
        ")").run(module)

print(module)

Here is the printed output:

module {
  func.func @add(%arg0: memref<?x5xf64>, %arg1: memref<?x5xf64>) {
    %c10 = arith.constant 10 : index
    %c0 = arith.constant 0 : index
    %0 = vector.load %arg0[%c10, %c0] : memref<?x5xf64>, vector<5xf64>
    %1 = vector.load %arg1[%c10, %c0] : memref<?x5xf64>, vector<5xf64>
    %2 = arith.addf %0, %1 : vector<5xf64>
    vector.print %2 : vector<5xf64>
    return
  }
}
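One quick way to tell whether a pipeline changed anything is to list which dialects still appear in the printed IR. The helper below, `dialects_in`, is a hypothetical text-based heuristic, not part of the MLIR Python bindings: it only inspects the first token of each line in the pretty-printed (custom) form and ignores ops printed without a dialect prefix, such as `module` and `return`.

```python
import re
from typing import Set

# A dotted op name such as "arith.constant" or "llvm.mlir.constant";
# group 1 captures the dialect prefix.
_OP_NAME = re.compile(r"([A-Za-z_][\w$]*)\.[\w.$]+")


def dialects_in(ir_text: str) -> Set[str]:
    """Return the dialect prefixes of the ops found in printed MLIR text."""
    found = set()
    for line in ir_text.splitlines():
        stmt = line.strip()
        # Skip an "%result = " prefix so we look at the op name itself.
        if stmt.startswith("%") and "=" in stmt:
            stmt = stmt.split("=", 1)[1].strip()
        match = _OP_NAME.match(stmt)
        if match:
            found.add(match.group(1))
    return found
```

Calling `dialects_in(str(module))` on the output above reports `func`, `arith` and `vector`, i.e. nothing was lowered, while the mlir-opt output below yields only `llvm`.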

Here is the output of mlir-opt with the same input and pipeline:

module attributes {llvm.data_layout = ""} {
  llvm.func @printNewline()
  llvm.func @printClose()
  llvm.func @printComma()
  llvm.func @printOpen()
  llvm.func @printF64(f64)
  llvm.func @add(%arg0: !llvm.ptr<f64>, %arg1: !llvm.ptr<f64>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr<f64>, %arg8: !llvm.ptr<f64>, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: i64) {
    %0 = llvm.mlir.constant(4 : index) : i64
    %1 = llvm.mlir.constant(3 : index) : i64
    %2 = llvm.mlir.constant(2 : index) : i64
    %3 = llvm.mlir.constant(1 : index) : i64
    %4 = llvm.mlir.constant(5 : index) : i64
    %5 = llvm.mlir.constant(0 : index) : i64
    %6 = llvm.mlir.constant(10 : index) : i64
    %7 = llvm.mul %6, %4  : i64
    %8 = llvm.add %7, %5  : i64
    %9 = llvm.getelementptr %arg1[%8] : (!llvm.ptr<f64>, i64) -> !llvm.ptr<f64>
    %10 = llvm.bitcast %9 : !llvm.ptr<f64> to !llvm.ptr<vector<5xf64>>
    %11 = llvm.load %10 {alignment = 8 : i64} : !llvm.ptr<vector<5xf64>>
    %12 = llvm.mul %6, %4  : i64
    %13 = llvm.add %12, %5  : i64
    %14 = llvm.getelementptr %arg8[%13] : (!llvm.ptr<f64>, i64) -> !llvm.ptr<f64>
    %15 = llvm.bitcast %14 : !llvm.ptr<f64> to !llvm.ptr<vector<5xf64>>
    %16 = llvm.load %15 {alignment = 8 : i64} : !llvm.ptr<vector<5xf64>>
    %17 = llvm.fadd %11, %16  : vector<5xf64>
    llvm.call @printOpen() : () -> ()
    %18 = llvm.extractelement %17[%5 : i64] : vector<5xf64>
    llvm.call @printF64(%18) : (f64) -> ()
    llvm.call @printComma() : () -> ()
    %19 = llvm.extractelement %17[%3 : i64] : vector<5xf64>
    llvm.call @printF64(%19) : (f64) -> ()
    llvm.call @printComma() : () -> ()
    %20 = llvm.extractelement %17[%2 : i64] : vector<5xf64>
    llvm.call @printF64(%20) : (f64) -> ()
    llvm.call @printComma() : () -> ()
    %21 = llvm.extractelement %17[%1 : i64] : vector<5xf64>
    llvm.call @printF64(%21) : (f64) -> ()
    llvm.call @printComma() : () -> ()
    %22 = llvm.extractelement %17[%0 : i64] : vector<5xf64>
    llvm.call @printF64(%22) : (f64) -> ()
    llvm.call @printClose() : () -> ()
    llvm.call @printNewline() : () -> ()
    llvm.return
  }
}

You should remove the builtin.module(...) nesting from the pipeline string, i.e.:

PassManager.parse("lower-affine, convert-vector-to-llvm, convert-func-to-llvm, ...").run(module)

The pipeline is already implicitly nested on builtin.module, so your string would only work if your IR looked like module { module { ... stuff ... } }. (IMO the implicit builtin.module nesting is confusing, and it has bitten me before as well.)
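If you keep one pipeline string for both mlir-opt (where, in the run above, the explicit `builtin.module(...)` wrapper worked) and the Python bindings (which, per the reply, already nest on builtin.module), a small hypothetical helper can drop the outer wrapper before calling `PassManager.parse`. This is a sketch that assumes the implicit nesting described above; whether your build of the bindings nests implicitly is worth checking against its own pass_manager.py test.

```python
def strip_module_anchor(pipeline: str, anchor: str = "builtin.module") -> str:
    """Remove an outer `anchor(...)` wrapper from a pass pipeline string.

    Hypothetical helper. The string is returned unchanged unless the
    parenthesis opened right after `anchor` is the one that closes the
    whole pipeline.
    """
    p = pipeline.strip()
    prefix = anchor + "("
    if not (p.startswith(prefix) and p.endswith(")")):
        return p
    depth = 0
    for i, ch in enumerate(p[len(anchor):], start=len(anchor)):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            # The outer group closed before the end, e.g. "m(a),m(b)".
            if depth == 0 and i != len(p) - 1:
                return p
    return p[len(prefix):-1].strip()
```

Usage mirrors the original script, e.g. `PassManager.parse(strip_module_anchor(pipeline)).run(module)`. The parenthesis check leaves strings such as `builtin.module(a),builtin.module(b)` untouched, because their first pair of parentheses does not enclose the whole pipeline.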


It works ☺️ Thank you so much!