[RFC] `MPI` Dialect

A lot of interest in more information on the MPI dialect design was expressed at the last ODM. This RFC provides a bigger picture of how we intend to model MPI as an MLIR dialect.

The mpi dialect is meant to serve as an interface dialect that models the MPI standard interface. The MPI dialect itself can be lowered to multiple MPI implementations, hiding differences in their ABIs. The dialect models the functions of the MPI specification as closely to 1:1 as possible, while preserving SSA value semantics where it makes sense and using memref types instead of bare pointers.

For in-depth documentation of the MPI library interface, please refer to official documentation such as the OpenMPI online documentation. Relevant parts of the documentation are linked throughout this RFC.

This RFC does not cover all of the MPI specification; it instead focuses on the following feature sets:

Feature                  State             Comment
Blocking send/recv       PR Ready          Presented at ODM
Nonblocking send/recv    Example IR        Validated internally
Communicators            Example IR
Collectives              Example IR
Lowering                 Example IR, POC
MPI Error codes          Example IR
Handling MPI Status      Example IR

According to A large-scale study of MPI usage in open-source HPC applications, a small subset of all MPI calls makes up the majority of MPI uses. The subset presented in this RFC provides good coverage of large parts of real-world HPC MPI use cases. This does not mean, however, that features absent from this RFC are excluded from the MPI dialect, nor are all features outlined here necessarily planned to be added to the dialect in the near future. The RFC is instead intended to explore and show the decisions made while modelling MPI as an MLIR dialect, and to verify that they make sense and are able to represent real HPC programs.

A collection of open questions is posed at the bottom of this RFC.

Blocking Communication

These are the simplest building blocks of MPI. Our initial PR contains a simple
synchronous send/receive, init, finalize, and an operation to obtain
the process's rank:

func.func @mpi_test(%ref : memref<100xf32>) -> () {
    mpi.init

    %rank = mpi.comm_rank : i32
    // a tag constant is assumed here for illustration
    %tag = arith.constant 0 : i32

    mpi.send(%ref, %rank, %tag) : memref<100xf32>, i32, i32

    mpi.recv(%ref, %rank, %tag) : memref<100xf32>, i32, i32

    mpi.finalize

    func.return
}

For a more detailed look at this initial set of operations, see
the PR, which provides the output of mlir-tblgen -gen-dialect-doc.

The decision to model MPI's pointer+size+type triples as MLIR memrefs was made because we felt that the dialect would fit better into the existing ecosystem of MLIR dialects.

Non-blocking Communication

For non-blocking communication, a new datatype !mpi.request is introduced. This is directly equivalent to the MPI_Request type defined by MPI.

Since MPI_Requests are mutable objects that are always passed by reference, we decided to model them in memrefs and pass them as memref+index. This is consistent with how they are most often used in actual HPC programs (i.e. a stack-allocated array of MPI_Request objects).

With this, the nonblocking version of the blocking example above looks like this:

func.func @mpi_test(%ref : memref<100xf32>) -> () {
    mpi.init

    %rank = mpi.comm_rank : i32
    %tag = arith.constant 0 : i32

    %requests = memref.alloca() : memref<2x!mpi.request>

    mpi.isend (%ref, %rank, %tag) as %requests[0] : memref<100xf32>, i32, i32, memref<2x!mpi.request>

    mpi.irecv (%ref, %rank, %tag) as %requests[1] : memref<100xf32>, i32, i32, memref<2x!mpi.request>

    // either waiting on a single one:
    %status = mpi.wait %requests[0] : memref<2x!mpi.request> -> !mpi.status

    // issue a waitall for all requests
    mpi.waitall %requests : memref<2x!mpi.request>

    mpi.finalize

    func.return
}

Implementing MPI_Wait, MPI_Waitany, MPI_Test, or MPI_Testany would be straightforward when modelled this way.
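
As a rough illustration (hypothetical syntax, these operations are not part of the initial PR), MPI_Test and MPI_Waitany could reuse the same memref+index convention:

// test a single request without blocking (hypothetical op)
%done = mpi.test %requests[0] : memref<2x!mpi.request> -> i1

// wait for any request to complete, returning its index and status (hypothetical op)
%index, %status_any = mpi.waitany %requests : memref<2x!mpi.request> -> i32, !mpi.status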

MPI_REQUEST_NULL:

Modelling MPI_REQUEST_NULL would be done similarly to the way null pointers are handled in the llvm dialect. Since this is an immutable constant value, we are okay with it existing outside of a memref.

%requests = memref.alloca() : memref<2x!mpi.request>
%null_req = mpi.request_null : !mpi.request
memref.store %null_req, %requests[%c0] : memref<2x!mpi.request>

Communicators

MPI communicators are at the heart of many HPC programs. They give rise to interesting structures, allow abstracting away the complexity of selecting communication partners, and provide guaranteed separation for library code. We introduce the !mpi.comm type to model communicators. As an example, here is how we imagine MPI_Comm_split and MPI_Comm_dup working:

%comm_world = mpi.comm_world : !mpi.comm

%split = mpi.comm_split %comm_world by %color, %key : (!mpi.comm, i32, i32) -> !mpi.comm

%dup = mpi.comm_dup %split : !mpi.comm -> !mpi.comm

// other communicator constants can be modelled like this:
%comm_null = mpi.comm_null : !mpi.comm
%comm_self = mpi.comm_self : !mpi.comm

The patch that introduces communicators would add an !mpi.comm argument to every operation that requires a communicator.
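
As an illustration (hypothetical syntax for the communicator-aware variant), the blocking send from the first example would then take the communicator explicitly:

%comm_world = mpi.comm_world : !mpi.comm
mpi.send(%ref, %rank, %tag, %comm_world) : memref<100xf32>, i32, i32, !mpi.comm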

Case Study: Cartesian Topology

We also want to look at how we would model Cartesian communicators:

%comm_world = mpi.comm_world : !mpi.comm
%nodes = mpi.comm_size %comm_world : !mpi.comm -> i32

%dims = memref.alloca() : memref<3xi32>
// initialize to [0,0,2] (index and value constants elided)
memref.store %c0, %dims[%i0] : memref<3xi32>
memref.store %c0, %dims[%i1] : memref<3xi32>
memref.store %c2, %dims[%i2] : memref<3xi32>

// int MPI_Dims_create(int nnodes, int ndims, int dims[])
// ndims will be inferred from the memref size.
// results will be written back into %dims
mpi.dims_create %nodes, %dims : i32, memref<3xi32>

// periods = [true, true, false]
%periods = memref.alloca() : memref<3xi32>
// memref initialization left out for brevity

%reorder = arith.constant true : i1

%comm_cart = mpi.cart_create %comm_world, %dims, %periods, %reorder : (!mpi.comm, memref<3xi32>, memref<3xi32>, i1) -> !mpi.comm

Here are the documentation pages of OpenMPI for reference: MPI_Comm_size, MPI_Dims_create and MPI_Cart_create. Using the created Cartesian communicator would look like this:

// get number of dims
%ndims = mpi.cartdim_get %comm_cart : !mpi.comm -> i32

// allocate a memref to hold the cartesian coordinates:
%ndims_idx = arith.index_cast %ndims : i32 to index
%coords = memref.alloca(%ndims_idx) : memref<?xi32>

// get rank in communicator
%rank = mpi.comm_rank %comm_cart : !mpi.comm -> i32

// translate rank to cartesian coordinates:
mpi.cart_coords %comm_cart, %rank, %coords : !mpi.comm, i32, memref<?xi32>

// compute the destination coordinates (mock):
mock.calc_dest_coords %coords : memref<?xi32>

// translate the coordinates back into the destination rank:
%dest_rank = mpi.cart_rank %comm_cart, %coords : !mpi.comm, memref<?xi32> -> i32

This uses MPI_Cartdim_get, MPI_Comm_rank, MPI_Cart_coords and MPI_Cart_rank.

Notes:

  • MPI_Cart_rank expects the array to have exactly ndims elements, which we can’t universally verify at compile time.

We hope that this illustrates that the concept of MPI Communicators can be broadly mapped to MLIR in a consistent fashion.

One can see that MPI_Group operations could be mapped in a fashion analogous to the topology operations above.
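
A hedged sketch of what this could look like (hypothetical ops and a hypothetical !mpi.group type, not proposed in detail by this RFC):

%group_world = mpi.comm_group %comm_world : !mpi.comm -> !mpi.group
%ranks = memref.alloca() : memref<4xi32>
// initialization of %ranks left out for brevity
%subgroup = mpi.group_incl %group_world, %ranks : !mpi.group, memref<4xi32> -> !mpi.group
%subcomm = mpi.comm_create %comm_world, %subgroup : !mpi.comm, !mpi.group -> !mpi.comm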

Collectives / Operations

The easiest case of an MPI_Allreduce using MPI_SUM can be modelled like this:

%sum = mpi.op sum : !mpi.op
%outref = memref.alloc() : memref<100xf32>

mpi.allreduce %ref with %sum into %outref on %my_comm : memref<100xf32>

// with MPI_IN_PLACE, replace `into` $dest with `in_place`
mpi.allreduce %ref with %sum in_place on %my_comm : memref<100xf32>

A simple MPI_Reduce poses an additional challenge: the result buffer is only written to on the root rank, so we would
not want to allocate a full memref on every rank. Our idea is to allow dynamically sized memref arguments for the destination.

%rank = mpi.comm_rank %my_comm : !mpi.comm -> i32
%root = arith.constant 0 : i32
%is_root = arith.cmpi eq, %rank, %root : i32

// allocate memref only on root rank
%dest = scf.if %is_root -> (memref<?xf32>) {
    %ref = memref.alloc() : memref<100xf32>
    %unsized = memref.cast %ref : memref<100xf32> to memref<?xf32>
    scf.yield %unsized : memref<?xf32>
} else {
    %ref_empty = memref.alloc() : memref<0xf32>
    %unsized_empty = memref.cast %ref_empty : memref<0xf32> to memref<?xf32>
    scf.yield %unsized_empty : memref<?xf32>
}

mpi.reduce %data with %sum into %dest rank %root on %my_comm : memref<100xf32>, !mpi.op, memref<?xf32>, i32, !mpi.comm

// in-place
mpi.reduce %data with %sum in_place rank %root on %my_comm : memref<100xf32>, !mpi.op, i32, !mpi.comm

scf.if %is_root {
    %sized = memref.cast %dest : memref<?xf32> to memref<100xf32>
    // use data
}

MPI_Scatter and MPI_Gather can be modelled in similar ways to these operations.
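
As a hedged sketch (hypothetical syntax; %send_buf, %recv_buf and %gather_buf are placeholder buffers, reusing %root and %my_comm from above), scatter and gather could follow the same source/destination/root/communicator pattern:

// e.g. with 4 ranks and 100 elements per rank
mpi.scatter %send_buf into %recv_buf rank %root on %my_comm : memref<400xf32>, memref<100xf32>, i32, !mpi.comm
mpi.gather %recv_buf into %gather_buf rank %root on %my_comm : memref<100xf32>, memref<400xf32>, i32, !mpi.comm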

The conditional allocation could be provided in a helper operation:

%dest_ref = mpi.allocate_on_rank %rank, %root : memref<100xf32> -> memref<?xf32>

Defining custom MPI_Ops using MPI_Op_create:

// defines an operator that is valid for a single datatype:
mpi.operator @mpi_custom_op (%in: memref<?xf32>, %inout: memref<?xf32>) {
    // runtime assert could be inserted into this function
    // compute operator
}

%commute = arith.constant 1 : i32

%custom_op = mpi.op_create @mpi_custom_op, %commute : i32 -> !mpi.op

MPI requires the following signature for user-supplied functions:

typedef void MPI_User_function(
    void *invec, 
    void *inoutvec,
    int *len,
    MPI_Datatype *datatype
);

Modelling and inspecting MPI_Datatype at runtime as part of a custom op is currently not part of this RFC, but could be added if it is actually needed.

edit: We introduced mpi.operator instead of re-using func.func for the operator body, since the operator body needs special handling by the MPI lowering.

Handling MPI_Status

In order to handle MPI_Status, we would introduce an optional result value of type !mpi.status. MPI_Status is defined to be a struct with at least three fields (MPI_SOURCE, MPI_TAG and MPI_ERROR). Additionally, one can get the number of elements transferred from a status object using the MPI_Get_count function. We provide an accessor operation for these fields and an additional operation for MPI_Get_count.

%status = mpi.recv (%ref, %rank, %tag) : (memref<100xf32>, i32, i32) -> !mpi.status

// access struct members:
%source = mpi.status_get_field %status[MPI_SOURCE] : !mpi.status -> i32
%tag = mpi.status_get_field %status[MPI_TAG] : !mpi.status -> i32
%err = mpi.status_get_field %status[MPI_ERROR] : !mpi.status -> !mpi.retval

// using the MPI_Get_count function to get the element count:
%count = mpi.get_count %status : !mpi.status -> i32

Lowering and Differences in ABI

This part gets into the ABI differences between implementations. We highly recommend the paper on MPI Application Binary Interface Standardization as a primer for this section.

We have implemented an example showing how we lower our initial patch to both MPICH- and OpenMPI-style ABIs (using xDSL for quick prototyping). We target the llvm dialect directly because we need access to low-level concepts like pointers, structs, etc. We hope that the messy output below is argument enough in favour of introducing the MPI dialect abstraction:

// RUN: xdsl-opt %s | xdsl-opt -p "lower-mpi{vendor=mpich}"| filecheck %s --check-prefix=MPICH
// RUN: xdsl-opt %s | xdsl-opt -p "lower-mpi{vendor=ompi}" | filecheck %s --check-prefix=OMPI

"builtin.module"() ({
    func.func @mpi_example(%ref : memref<100xf32>, %dest : i32, %tag : i32) {
        mpi.init

        %rank = mpi.comm.rank : i32

        "mpi.send"(%ref, %dest, %tag) : (memref<100xf32>, i32, i32) -> ()

        "mpi.recv"(%ref, %dest, %tag) : (memref<100xf32>, i32, i32) -> ()

        mpi.finalize

        func.return
    }
}) : () -> ()


// Lowering to OpenMPI's opaque struct pointers:

// OMPI:      builtin.module {
// OMPI-NEXT:   func.func @mpi_example(%ref : memref<100xf32>, %dest : i32, %tag : i32) {
// OMPI-NEXT:     %0 = "llvm.mlir.null"() : () -> !llvm.ptr
// OMPI-NEXT:     %1 = "llvm.call"(%0, %0) {"callee" = @MPI_Init, "fastmathFlags" = #llvm.fastmath<none>} : (!llvm.ptr, !llvm.ptr) -> i32
// OMPI-NEXT:     %rank = "llvm.mlir.addressof"() {"global_name" = @ompi_mpi_comm_world} : () -> !llvm.ptr
// OMPI-NEXT:     %rank_1 = arith.constant 1 : i64
// OMPI-NEXT:     %rank_2 = "llvm.alloca"(%rank_1) {"alignment" = 32 : i64, "elem_type" = i32} : (i64) -> !llvm.ptr
// OMPI-NEXT:     %rank_3 = "llvm.call"(%rank, %rank_2) {"callee" = @MPI_Comm_rank, "fastmathFlags" = #llvm.fastmath<none>} : (!llvm.ptr, !llvm.ptr) -> i32
// OMPI-NEXT:     %rank_4 = "llvm.load"(%rank_2) : (!llvm.ptr) -> i32
// OMPI-NEXT:     %2 = "memref.extract_aligned_pointer_as_index"(%ref) : (memref<100xf32>) -> index
// OMPI-NEXT:     %3 = "arith.index_cast"(%2) : (index) -> i64
// OMPI-NEXT:     %4 = "llvm.inttoptr"(%3) : (i64) -> !llvm.ptr
// OMPI-NEXT:     %5 = arith.constant 100 : i32
// OMPI-NEXT:     %6 = "llvm.mlir.addressof"() {"global_name" = @ompi_mpi_float} : () -> !llvm.ptr
// OMPI-NEXT:     %7 = "llvm.mlir.addressof"() {"global_name" = @ompi_mpi_comm_world} : () -> !llvm.ptr
// OMPI-NEXT:     %8 = "llvm.call"(%4, %5, %6, %dest, %tag, %7) {"callee" = @MPI_Send, "fastmathFlags" = #llvm.fastmath<none>} : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
// OMPI-NEXT:     %9 = "memref.extract_aligned_pointer_as_index"(%ref) : (memref<100xf32>) -> index
// OMPI-NEXT:     %10 = "arith.index_cast"(%9) : (index) -> i64
// OMPI-NEXT:     %11 = "llvm.inttoptr"(%10) : (i64) -> !llvm.ptr
// OMPI-NEXT:     %12 = arith.constant 100 : i32
// OMPI-NEXT:     %13 = "llvm.mlir.addressof"() {"global_name" = @ompi_mpi_float} : () -> !llvm.ptr
// OMPI-NEXT:     %14 = "llvm.mlir.addressof"() {"global_name" = @ompi_mpi_comm_world} : () -> !llvm.ptr
// OMPI-NEXT:     %15 = "llvm.mlir.null"() : () -> !llvm.ptr
// OMPI-NEXT:     %16 = "llvm.call"(%11, %12, %13, %dest, %tag, %14, %15) {"callee" = @MPI_Recv, "fastmathFlags" = #llvm.fastmath<none>} : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// OMPI-NEXT:     %17 = "llvm.call"() {"callee" = @MPI_Finalize, "fastmathFlags" = #llvm.fastmath<none>} : () -> i32
// OMPI-NEXT:     func.return
// OMPI-NEXT:   }
// OMPI-NEXT:   "llvm.mlir.global"() ({
// OMPI-NEXT:   }) {"global_type" = i32, "sym_name" = "ompi_mpi_comm_world", "linkage" = #llvm.linkage<"external">, "addr_space" = 0 : i32} : () -> ()
// OMPI-NEXT:   "llvm.mlir.global"() ({
// OMPI-NEXT:   }) {"global_type" = i32, "sym_name" = "ompi_mpi_float", "linkage" = #llvm.linkage<"external">, "addr_space" = 0 : i32} : () -> ()
// OMPI-NEXT:   "llvm.func"() ({
// OMPI-NEXT:   }) {"sym_name" = "MPI_Init", "function_type" = !llvm.func<i32 (!llvm.ptr, !llvm.ptr)>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> ()
// OMPI-NEXT:   "llvm.func"() ({
// OMPI-NEXT:   }) {"sym_name" = "MPI_Comm_rank", "function_type" = !llvm.func<i32 (!llvm.ptr, !llvm.ptr)>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> ()
// OMPI-NEXT:   "llvm.func"() ({
// OMPI-NEXT:   }) {"sym_name" = "MPI_Send", "function_type" = !llvm.func<i32 (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr)>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> ()
// OMPI-NEXT:   "llvm.func"() ({
// OMPI-NEXT:   }) {"sym_name" = "MPI_Recv", "function_type" = !llvm.func<i32 (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr)>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> ()
// OMPI-NEXT:   "llvm.func"() ({
// OMPI-NEXT:   }) {"sym_name" = "MPI_Finalize", "function_type" = !llvm.func<i32 ()>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> ()
// OMPI-NEXT: }


// Lowering to MPICH's integer constants:

// MPICH:      builtin.module {
// MPICH-NEXT:   func.func @mpi_example(%ref : memref<100xf32>, %dest : i32, %tag : i32) {
// MPICH-NEXT:     %0 = "llvm.mlir.null"() : () -> !llvm.ptr
// MPICH-NEXT:     %1 = "llvm.call"(%0, %0) {"callee" = @MPI_Init, "fastmathFlags" = #llvm.fastmath<none>} : (!llvm.ptr, !llvm.ptr) -> i32
// MPICH-NEXT:     %rank = arith.constant 1140850688 : i32
// MPICH-NEXT:     %rank_1 = arith.constant 1 : i64
// MPICH-NEXT:     %rank_2 = "llvm.alloca"(%rank_1) {"alignment" = 32 : i64, "elem_type" = i32} : (i64) -> !llvm.ptr
// MPICH-NEXT:     %rank_3 = "llvm.call"(%rank, %rank_2) {"callee" = @MPI_Comm_rank, "fastmathFlags" = #llvm.fastmath<none>} : (i32, !llvm.ptr) -> i32
// MPICH-NEXT:     %rank_4 = "llvm.load"(%rank_2) : (!llvm.ptr) -> i32
// MPICH-NEXT:     %2 = "memref.extract_aligned_pointer_as_index"(%ref) : (memref<100xf32>) -> index
// MPICH-NEXT:     %3 = "arith.index_cast"(%2) : (index) -> i64
// MPICH-NEXT:     %4 = "llvm.inttoptr"(%3) : (i64) -> !llvm.ptr
// MPICH-NEXT:     %5 = arith.constant 100 : i32
// MPICH-NEXT:     %6 = arith.constant 1275069450 : i32
// MPICH-NEXT:     %7 = arith.constant 1140850688 : i32
// MPICH-NEXT:     %8 = "llvm.call"(%4, %5, %6, %dest, %tag, %7) {"callee" = @MPI_Send, "fastmathFlags" = #llvm.fastmath<none>} : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
// MPICH-NEXT:     %9 = "memref.extract_aligned_pointer_as_index"(%ref) : (memref<100xf32>) -> index
// MPICH-NEXT:     %10 = "arith.index_cast"(%9) : (index) -> i64
// MPICH-NEXT:     %11 = "llvm.inttoptr"(%10) : (i64) -> !llvm.ptr
// MPICH-NEXT:     %12 = arith.constant 100 : i32
// MPICH-NEXT:     %13 = arith.constant 1275069450 : i32
// MPICH-NEXT:     %14 = arith.constant 1140850688 : i32
// MPICH-NEXT:     %15 = arith.constant 1 : i32
// MPICH-NEXT:     %16 = "llvm.call"(%11, %12, %13, %dest, %tag, %14, %15) {"callee" = @MPI_Recv, "fastmathFlags" = #llvm.fastmath<none>} : (!llvm.ptr, i32, i32, i32, i32, i32, i32) -> i32
// MPICH-NEXT:     %17 = "llvm.call"() {"callee" = @MPI_Finalize, "fastmathFlags" = #llvm.fastmath<none>} : () -> i32
// MPICH-NEXT:     func.return
// MPICH-NEXT:   }
// MPICH-NEXT:   "llvm.func"() ({
// MPICH-NEXT:   }) {"sym_name" = "MPI_Init", "function_type" = !llvm.func<i32 (!llvm.ptr, !llvm.ptr)>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> ()
// MPICH-NEXT:   "llvm.func"() ({
// MPICH-NEXT:   }) {"sym_name" = "MPI_Comm_rank", "function_type" = !llvm.func<i32 (i32, !llvm.ptr)>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> ()
// MPICH-NEXT:   "llvm.func"() ({
// MPICH-NEXT:   }) {"sym_name" = "MPI_Send", "function_type" = !llvm.func<i32 (!llvm.ptr, i32, i32, i32, i32, i32)>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> ()
// MPICH-NEXT:   "llvm.func"() ({
// MPICH-NEXT:   }) {"sym_name" = "MPI_Recv", "function_type" = !llvm.func<i32 (!llvm.ptr, i32, i32, i32, i32, i32, i32)>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> ()
// MPICH-NEXT:   "llvm.func"() ({
// MPICH-NEXT:   }) {"sym_name" = "MPI_Finalize", "function_type" = !llvm.func<i32 ()>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> ()
// MPICH-NEXT: }

We slightly prefer supporting multiple implementations through a toggle in the lowering rather than through an MLIR runtime, but don't want to rule anything out yet. The ABI standardisation efforts put forth by Hammond et al. hint at a more unified landscape in the future.

MPI Error Codes

Almost all MPI functions return error codes (C int), which are often ignored. We propose to add an optional result to all operations that can return error codes. This result value will be of type !mpi.retval, which can be queried against various error codes:

%err = mpi.send ...

// Check if returned value is MPI_SUCCESS
%is_success = mpi.retval_check %err = MPI_SUCCESS : !mpi.retval -> i1
%is_err_in_stat = mpi.retval_check %err = MPI_ERR_IN_STATUS : !mpi.retval -> i1

// in order to check against other classes of errors, one must first call
// MPI_Error_class
%err_class = mpi.error_class %err : !mpi.retval -> !mpi.retval

// Check against specific error code
%is_err_rank = mpi.retval_check %err_class = MPI_ERR_RANK : !mpi.retval -> i1

Note:

  • We could also model !mpi.retval as i32 if we wanted to. However, all MPI error classes and codes are library dependent, so modelling it as an int may not be that helpful anyway.

Open Questions:

Operation Naming

We make use of a fairly standard translation from MPI names to MLIR operation names and types, where the first _ is replaced by . and everything is lowercased. That way, MPI_Comm_rank becomes mpi.comm_rank. We also introduce some operations that are needed only because of the MLIR abstraction (e.g. mpi.retval_check). We could prefix these, similarly to how it's done in the LLVM dialect, so that mpi.retval_check becomes mpi.mlir.retval_check.

Supporting more MPI Datatypes

The current version can support many kinds of memref layouts in arguments by mapping them to MPI strided datatypes.
MPI is able to express even more datatypes, like heterogeneous arrays and structs; this is, however, not explored as part of this
RFC.
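
For instance (a hypothetical sketch, assuming the lowering derives a strided MPI datatype from the memref layout), sending a strided subview could look like this:

// every second element of the buffer, described by the memref layout
%view = memref.subview %ref[0] [50] [2] : memref<100xf32> to memref<50xf32, strided<[2]>>
mpi.send(%view, %rank, %tag) : memref<50xf32, strided<[2]>>, i32, i32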

Interaction with Sharding Dialect

We believe that this dialect can serve as a target for higher-level dialects like the sharding dialect recently introduced to MLIR. We are interested in seeing these two dialects connected, possibly through a mid-level dialect that models higher-level communication calls.


Thanks for the introduction. Is there anything that is open to play around with now? Btw, will there be more ops added, such as scatter and gather?

Is there anything that is open to play around with now?

There isn’t much to play around with quite yet.

will there be more ops added, such as scatter and gather?

I didn’t add scatter/gather to the RFC as I felt they were very similar to the collectives showcased and could be modeled in the same way. I am sure that they will be added to the MPI dialect relatively quickly, as they are important operations in many MPI programs. Perhaps it would make sense to add them more explicitly to the RFC. (I’ll add a quick comment right now.)

Thank you for the RFC and for working on this! A couple of questions:

  1. Why not model requests as results of isend, irecv…? I think it would be better to have them as use-def chains.
  2. I’m worried that having lowerings for each vendor might bloat the lowering pass; any reason why this is preferred over a wrapper library?

Why not model requests as results of isend, irecv…? I think it would be better to have them as use-def chains.

We initially modeled them as operation results for exactly that reason, but quickly found that it became cumbersome when used in combination with MPI_Waitall and others. Modelling them using memref+index is a lot closer to how they are used in many MPI programs.

I’m worried that having lowerings for each vendor might bloat the lowering pass; any reason why this is preferred over a wrapper library?

That is a very valid concern. I guess it comes down to a preference between maintaining and compiling a wrapper library for each MPI implementation vs. maintaining a lowering. A slight argument in favor of putting things in a lowering would be speed, but I understand that maintainability is a more important requirement. In the end I am open to suggestions and advice, as I have very limited experience in that regard.

Why? Can’t MPI_Waitall and others be modeled with variadic args?

What I’m thinking is that there are no guarantees on ABI stability even for a particular vendor, thus the lowerings would be tied to a specific major version + vendor of MPI. Updating a library is definitely simpler than a lowering pass.

Why? Can’t MPI_Waitall and others be modeled with variadic args?

Yes, that is certainly an option. But since Waitall expects a pointer to a contiguous array of request objects, the operation would need to allocate memory, which we decided against internally.

Furthermore, modelling it as variadic arguments makes it much more difficult to express a waitall over a runtime-known number of request objects, or a waitall on requests fired off in a loop.
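
For example (a hypothetical sketch reusing the ops proposed above; %n, %c0 and %c1 are assumed index values):

%requests = memref.alloca(%n) : memref<?x!mpi.request>
scf.for %i = %c0 to %n step %c1 {
    // in a real program, the buffer and peer rank would vary per iteration
    mpi.irecv (%ref, %rank, %tag) as %requests[%i] : memref<100xf32>, i32, i32, memref<?x!mpi.request>
}
mpi.waitall %requests : memref<?x!mpi.request>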

What I’m thinking is that there are no guarantees on ABI stability even for a particular vendor, thus the lowerings would be tied to a specific major version + vendor of MPI. Updating a library is definitely simpler than a lowering pass.

That is a very valid point, but it might not be quite as bad as it seems. MPICH has a stable ABI commitment, and OpenMPI is less strictly bound in some sense, as it makes use of external global symbols whose values are determined at link time. Nonetheless, the concern stands.

The overhead of doing an alloca for storing the requests should be minimal; OpenMP and CUDA do this for kernel args.

Fair point; still, I think there should be a better solution.

+1 on adding the dialect, some specific implementation details can be discussed in PRs.


Hey, thanks for the RFC!

To me, it is not yet quite clear how we would benefit from exposing MPI as a dialect if we essentially turn every MPI function call into an operation in the dialect. I see that the MPI IR is much simpler than the lowered IR, which makes it much easier to target MPI from other dialects, so that’s a plus I guess. However, I am not sure whether this really warrants a full MPI dialect, or if we could just solve this through other means of sharing lowering infrastructure.

So essentially, I am wondering if you explored any optimizations on the abstraction level of MPI. Do you think this is even possible, in particular with all the ABI differences that probably make performance very hard to grasp?


With a full-fledged MPI dialect the optimizer has more optimization opportunities.
In theory, you could combine two mpi.put operations into one larger mpi.put. This is harder to achieve through the C library interface.


You could also represent halo operations as a set of semantically rich ops, which can be tiled/fused/moved/simplified/joined with their “parent” ops (stencil, etc.). With direct calls to MPI library functions, it is less clear what the data refers to.

What is less clear to me is how we clearly represent the communication domains. Just channels, or could we have another (group? shard?) op that carries those semantics (and is transformed with its “parents”), so that we can infer (optimally) the channels that need to be created only after all optimizations are done?

Fair point; still, I think there should be a better solution.

We should definitely explore different ways to model the request objects. Preserving the use-def chain would already be a huge gain in expressiveness for this dialect and enable us to do a lot more optimizations. I’m interested in hearing your thoughts on this.

So essentially, I am wondering if you explored any optimizations on the abstraction level of MPI. Do you think this is even possible, in particular with all the ABI differences that probably make performance very hard to grasp?

The easiest optimization to do automatically is probably converting a group of blocking calls into nonblocking calls followed by a single waitall:

    mpi.send(%ref1, %rank1, %tag) : memref<100xf32>, i32, i32
    mpi.send(%ref2, %rank2, %tag) : memref<100xf32>, i32, i32
    mpi.send(%ref3, %rank3, %tag) : memref<100xf32>, i32, i32
    mpi.recv(%ref4, %rank1, %tag) : memref<100xf32>, i32, i32
    mpi.recv(%ref5, %rank2, %tag) : memref<100xf32>, i32, i32
    mpi.recv(%ref6, %rank3, %tag) : memref<100xf32>, i32, i32
    // ...

Becomes:

    %requests = memref.alloca() : memref<6x!mpi.request>
    mpi.isend(%ref1, %rank1, %tag) as %requests[0] : memref<100xf32>, i32, i32, memref<6x!mpi.request>
    mpi.isend(%ref2, %rank2, %tag) as %requests[1] : memref<100xf32>, i32, i32, memref<6x!mpi.request>
    mpi.isend(%ref3, %rank3, %tag) as %requests[2] : memref<100xf32>, i32, i32, memref<6x!mpi.request>
    mpi.irecv(%ref4, %rank1, %tag) as %requests[3] : memref<100xf32>, i32, i32, memref<6x!mpi.request>
    mpi.irecv(%ref5, %rank2, %tag) as %requests[4] : memref<100xf32>, i32, i32, memref<6x!mpi.request>
    mpi.irecv(%ref6, %rank3, %tag) as %requests[5] : memref<100xf32>, i32, i32, memref<6x!mpi.request>
    mpi.waitall %requests : memref<6x!mpi.request>
    // ...

This may not look like much, but in this case we reduce the number of synchronization points from 6 down to 1, which can yield substantial speedups.

If we want to do more complex rewrites, we would need information on which send call is matched by which receive call, which is sadly impossible to determine in general at the MPI level. We are already exploring higher-level message passing dialects that model this relationship, though. Having an MPI dialect in MLIR would open up the possibility of targeting it from such higher-level message passing abstractions. And I guess, as is usual with MLIR, the juiciest optimizations would happen at higher levels than individual MPI calls.


Yes, representing halo operations using richer IR is absolutely something that we are exploring, although these ops would probably not live within the MPI dialect, as it is meant to model the MPI library closely.


I see, thanks for the explanation!

I think the high-level message passing dialect you mentioned will be very interesting as an MLIR dialect, in particular for optimization. And, I think this MPI dialect is an important, incremental step in this direction, so a +1 from me :slight_smile:


Another conceivable optimization is to overlap communication with computation:

  1. replace blocking with non-blocking communication
  2. pull the non-blocking operation up as early as possible
  3. push the wait call down as late as possible
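
A rough before/after sketch of these steps (hypothetical, reusing the ops proposed in the RFC; %halo, %nbr and %tag are placeholder values):

// before: blocking send right next to independent computation
mpi.send(%halo, %nbr, %tag) : memref<100xf32>, i32, i32
// ... compute on the interior, independent of %halo ...

// after: non-blocking send hoisted up, wait pushed below the computation
%requests = memref.alloca() : memref<1x!mpi.request>
mpi.isend (%halo, %nbr, %tag) as %requests[0] : memref<100xf32>, i32, i32, memref<1x!mpi.request>
// ... compute on the interior, independent of %halo ...
mpi.waitall %requests : memref<1x!mpi.request>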

That said, I suggest also including non-blocking collective operations.


If there is not going to be a wrapper with a stable ABI for the different MPI implementations, then we need constructs that allow new implementations to be brought in quickly while avoiding code duplication. I imagine there is a lot of boilerplate there.

Absolutely. Luckily, it seems to be quite easy to parametrize the lowering so that each ABI target only needs to provide four pieces of information:

  1. Procedure to instantiate an MPI_Datatype variable from an MLIR type like f32 or i64
  2. Procedure to materialize named MPI constants
  3. The exact structure of the MPI_Status struct
  4. A list of globals to add to the module

The example lowering for the initial part of the dialect, implemented in xDSL, showcases how we handle points 1, 2 and 4.

Actually, now that I look at the code again, 1) could be implemented as a subset of 2), which would make the design even easier.

Now, there may be some more cases that we have not yet considered, but our preliminary results show that this could be reasonably straightforward without too much duplication.

I can see the advantage of hiding the ABI specifics when a lowering from a higher-level dialect introduces the mpi dialect. It is less clear to me what the advantage would be of introducing the mpi dialect in the lowering from a front-end like clang or flang. Do you have any advantage in mind?

Yes, this is a very desirable goal to reach, although I fear that this may be somewhat difficult to do in an automated fashion at the MPI level due to the pointer semantics (and our usage of memrefs to model them). A pass like this may need to introduce an additional allocation for the received data. This is why we are also looking into higher-level message passing on tensors, which would allow us to reason about data flow much more easily.

Depending on the amount of high-level information preserved, one could imagine gaining the opportunity for more optimizations, like the one described by @fschlimb above. Alternatively, the frontend could provide a slightly better-structured message passing interface (e.g. through pragmas) that one could translate into MPI calls and optimize in MLIR.

Yep. In our prototype we have a higher-level construct to update halos which does exactly that.
With sufficient trust in alias analysis and an understanding of the semantics of individual non-blocking operations, this should be doable even without it.
