MPI / SPMD dialect?

Does anyone know of any existing work towards an MPI or any SPMD dialect for MLIR? I’m interested in using MLIR for non-ML use cases in the HPC / data analytics realm.

Functionality that I would be interested in:

  1. Automatic sharding of large memrefs onto a cluster with a known number of ranks, and the corresponding loop splitting
  2. Lowering of existing ops (affine.dma_*, tensor.scatter/gather, etc.) into the corresponding MPI data transfer calls (a rough sketch of what such a lowering could target follows below)
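
To make item 2 a bit more concrete, here is a minimal sketch in plain C + MPI of the kind of code such a lowering could ultimately target. It is not generated by any existing MLIR pass; the global size N, the block distribution, and the use of MPI_Gather are illustrative assumptions only.

#include <mpi.h>
#include <stdlib.h>

enum { N = 1024 };                      /* global size of the distributed "memref" */

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);

  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  /* "Automatic sharding": each rank owns a contiguous block of the array.
     For simplicity this sketch assumes N is divisible by the number of ranks. */
  int chunk = N / size;
  double *local = malloc(chunk * sizeof(double));

  /* "Loop splitting": each rank runs only its slice [rank*chunk, (rank+1)*chunk). */
  for (int i = 0; i < chunk; ++i)
    local[i] = (double)(rank * chunk + i);

  /* A data-transfer op (here: gathering the sharded array onto rank 0)
     maps onto a single MPI collective call. */
  double *global = (rank == 0) ? malloc(N * sizeof(double)) : NULL;
  MPI_Gather(local, chunk, MPI_DOUBLE, global, chunk, MPI_DOUBLE, 0, MPI_COMM_WORLD);

  free(local);
  free(global);
  MPI_Finalize();
  return 0;
}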

I haven’t been able to find anything specifically around SPMD. I’ve seen other questions on this Discourse asking for similar things. I’m thinking of going ahead and writing my own, but I would prefer to reuse existing projects or collaborate on any in development.

We have a rough version of an MPI dialect with some things implemented internally in our lab, though it’s implemented in xDSL: https://github.com/xdslproject/xdsl/blob/main/xdsl/dialects/mpi.py

We lower it to LLVM using some dark magic with constant values and such; let me know if you have any questions!
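
For readers wondering what the "constant values" are about: MPI handles such as MPI_COMM_WORLD and MPI_DOUBLE are opaque and ABI-dependent, so a code generator that does not go through mpi.h ends up hard-coding the values used by one particular MPI implementation. The small C program below merely illustrates that the handle representations differ between implementations; it is not part of the xDSL lowering itself.

#include <mpi.h>
#include <stdio.h>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  /* MPI_Comm is a plain integer handle in MPICH-style ABIs but a pointer to
     an internal object in Open MPI, which is why a lowering that emits raw
     LLVM calls has to bake in implementation-specific constants. */
  printf("sizeof(MPI_Comm)     = %zu\n", sizeof(MPI_Comm));
  printf("sizeof(MPI_Datatype) = %zu\n", sizeof(MPI_Datatype));
  MPI_Finalize();
  return 0;
}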

This looks very interesting. Can you explain the relation between xDSL and MLIR?

See the xDSL presentation at EuroLLVM a few months ago: 2023 EuroLLVM - Prototyping MLIR in Python - YouTube

I would be very interested in seeing this materialize in MLIR.

@AntonLydike any plans to upstream some of your work?

Upstreaming this is currently under discussion, but not immediately planned. We are working on a prototype to show that this is indeed something useful, which should hopefully be done soon. I’m sure I will have updates by the time the LLVM dev meeting comes around!

@mehdi_amini maybe we can have an ODM (open design meeting) on that in the near future? We have a deadline coming up in two weeks, but maybe after that?

Are you looking for something closely adhering to the MPI standard or something more high-level that works on tensors?

Ping me whenever you feel like you’re ready!

I’m actually not using MLIR for machine learning so I’d prefer something a bit lower-level.

We are also cooking up a dialect to represent exchanges of rectangular regions of memrefs; let me know if you are interested in that!
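
For anyone following along, here is a rough sketch in plain C + MPI of what an exchange of a rectangular region can boil down to at the lowest level. This is only an illustration with made-up sizes, not the dialect being described: two ranks swap the 4x4 sub-block a[4..7][8..11] of a 16x16 array using a subarray datatype.

#include <mpi.h>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);

  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  double a[16][16];
  for (int i = 0; i < 16; ++i)
    for (int j = 0; j < 16; ++j)
      a[i][j] = rank;                     /* mark each element with its owner */

  /* Describe the rectangular region a[4..7][8..11] as an MPI datatype. */
  int sizes[2]    = {16, 16};
  int subsizes[2] = {4, 4};
  int starts[2]   = {4, 8};
  MPI_Datatype block;
  MPI_Type_create_subarray(2, sizes, subsizes, starts, MPI_ORDER_C,
                           MPI_DOUBLE, &block);
  MPI_Type_commit(&block);

  /* Ranks 0 and 1 swap exactly that region, in place. */
  if (size >= 2 && rank < 2) {
    int peer = 1 - rank;
    MPI_Sendrecv_replace(&a[0][0], 1, block, peer, 0, peer, 0,
                         MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  }

  MPI_Type_free(&block);
  MPI_Finalize();
  return 0;
}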

Yes, very much!

From my perspective I’d also be interested in seeing the same abstractions allow tensors and DPS (destination-passing style) if possible: at the ML graph level, tensors have proven critical for transformations and slicing between host and accelerator.

We’ll have the next open meeting (11/2) on this topic with a proposal from @AntonLydike. Mark your calendars :slight_smile:

We are also working on a set of dialects which shard and distribute tensors/arrays using MLIR. Our distribution semantics are higher-level than MPI, which makes it possible, in principle, to plug in other communication frameworks as well.

@tathougies our work is actually focused on HPC: for example, we have optimizations which transparently achieve good scalability on stencil-like codes (e.g. ones using shifted views). It also differs from typical MLIR tensors in that we allow in-place updates. We have a high-level dialect intended to eventually cover the Python array API standard (2022.12).
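
To give a feel for the communication pattern that shifted views on stencil-like codes typically reduce to, here is a rough sketch in plain C + MPI (an illustration only, not our dialect or its lowering): each rank owns a block of a 1-D array plus one halo cell on each side, refreshes the halos with MPI_Sendrecv, and then applies a local 3-point stencil. The block size LOCAL is an arbitrary choice for the sketch.

#include <mpi.h>

enum { LOCAL = 1000 };                    /* owned cells per rank (arbitrary) */

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);

  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  double u[LOCAL + 2] = {0};              /* u[0] and u[LOCAL+1] are halo cells */
  double v[LOCAL + 2] = {0};
  int left  = (rank > 0)        ? rank - 1 : MPI_PROC_NULL;
  int right = (rank < size - 1) ? rank + 1 : MPI_PROC_NULL;

  /* Refresh the halos: first owned cell goes left, last owned cell goes right. */
  MPI_Sendrecv(&u[1], 1, MPI_DOUBLE, left, 0,
               &u[LOCAL + 1], 1, MPI_DOUBLE, right, 0,
               MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  MPI_Sendrecv(&u[LOCAL], 1, MPI_DOUBLE, right, 1,
               &u[0], 1, MPI_DOUBLE, left, 1,
               MPI_COMM_WORLD, MPI_STATUS_IGNORE);

  /* Local 3-point stencil (the "shifted views" u[i-1] and u[i+1]),
     applied to the owned cells only. */
  for (int i = 1; i <= LOCAL; ++i)
    v[i] = 0.5 * (u[i - 1] + u[i + 1]);

  MPI_Finalize();
  return 0;
}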

We have started the process of getting this into open source so that we can collect feedback. If you are interested, I can try to accelerate this process as much as possible.

Would it be closer to [RFC] Sharding Framework Design for Device Mesh instead?

I took a look at the Python array-API, and I feel there are some differences.

The mesh sharding framework is basically two-level (although it is not finalized yet):

  1. The mesh.shard level: this is essentially a DTensor-style annotation, and it should be compatible with the DTensor representation in deep learning frameworks (for example DTensor in PyTorch, which is also an annotation).

  2. The SPMD IR level: this represents the real computation and communication, and most sharding / collective-communication (CCL) optimizations should be implemented here.

The Python array-API, on the other hand, seems to provide a user API for representing exact slices, which can participate in the real computation and communication. Its functionality is close to the SPMD IR level, but I don’t feel it uses the SPMD format.

Its functionality is close to the SPMD IR level, but I don’t feel it uses the SPMD format.

@yaochengji which “SPMD format” and “SPMD IR” are you referring to?

Take the example of XLA GSPMD, where a convolution op is sharded along the H dimension (assume its layout is NHWC).

The SPMD IR / format looks like this:

HloModule module, entry_computation_layout={(f32[128,224,224,3]{3,2,1,0}, f32[7,7,3,64]{3,2,1,0})->f32[128,56,112,64]{3,2,1,0}}

ENTRY %entry_spmd (param: f32[128,224,224,3], param.1: f32[7,7,3,64]) -> f32[128,56,112,64] {
  %iota = s32[128,117,224,3]{3,2,1,0} iota(), iota_dimension=1
  %constant.6 = s32[2]{0} constant({0, 1})
  %partition-id = u32[] partition-id()
  %dynamic-slice.2 = s32[1]{0} dynamic-slice(s32[2]{0} %constant.6, u32[] %partition-id), dynamic_slice_sizes={1}
  %reshape.1 = s32[] reshape(s32[1]{0} %dynamic-slice.2)
  %constant.10 = s32[] constant(112)
  %multiply = s32[] multiply(s32[] %reshape.1, s32[] %constant.10)
  %broadcast = s32[128,117,224,3]{3,2,1,0} broadcast(s32[] %multiply), dimensions={}
  %add = s32[128,117,224,3]{3,2,1,0} add(s32[128,117,224,3]{3,2,1,0} %iota, s32[128,117,224,3]{3,2,1,0} %broadcast)
  %constant.13 = s32[] constant(3)
  %broadcast.1 = s32[128,117,224,3]{3,2,1,0} broadcast(s32[] %constant.13), dimensions={}
  %compare = pred[128,117,224,3]{3,2,1,0} compare(s32[128,117,224,3]{3,2,1,0} %add, s32[128,117,224,3]{3,2,1,0} %broadcast.1), direction=GE
  %constant.14 = s32[] constant(227)
  %broadcast.2 = s32[128,117,224,3]{3,2,1,0} broadcast(s32[] %constant.14), dimensions={}
  %compare.1 = pred[128,117,224,3]{3,2,1,0} compare(s32[128,117,224,3]{3,2,1,0} %add, s32[128,117,224,3]{3,2,1,0} %broadcast.2), direction=LT
  %and = pred[128,117,224,3]{3,2,1,0} and(pred[128,117,224,3]{3,2,1,0} %compare, pred[128,117,224,3]{3,2,1,0} %compare.1)
  %param = f32[128,224,224,3]{3,2,1,0} parameter(0), sharding={replicated}
  %constant = s32[] constant(0)
  %constant.1 = s32[2]{0} constant({0, 112})
  %dynamic-slice = s32[1]{0} dynamic-slice(s32[2]{0} %constant.1, u32[] %partition-id), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(s32[1]{0} %dynamic-slice)
  %dynamic-slice.1 = f32[128,112,224,3]{3,2,1,0} dynamic-slice(f32[128,224,224,3]{3,2,1,0} %param, s32[] %constant, s32[] %reshape, s32[] %constant, s32[] %constant), dynamic_slice_sizes={128,112,224,3}
  %lhs.copy.1 = f32[128,112,224,3]{3,2,1,0} copy(f32[128,112,224,3]{3,2,1,0} %dynamic-slice.1)
  %slice = f32[128,3,224,3]{3,2,1,0} slice(f32[128,112,224,3]{3,2,1,0} %lhs.copy.1), slice={[0:128], [109:112], [0:224], [0:3]}
  %collective-permute = f32[128,3,224,3]{3,2,1,0} collective-permute(f32[128,3,224,3]{3,2,1,0} %slice), channel_id=1, source_target_pairs={{0,1}}
  %slice.1 = f32[128,2,224,3]{3,2,1,0} slice(f32[128,112,224,3]{3,2,1,0} %lhs.copy.1), slice={[0:128], [0:2], [0:224], [0:3]}
  %collective-permute.1 = f32[128,2,224,3]{3,2,1,0} collective-permute(f32[128,2,224,3]{3,2,1,0} %slice.1), channel_id=2, source_target_pairs={{1,0}}
  %concatenate = f32[128,117,224,3]{3,2,1,0} concatenate(f32[128,3,224,3]{3,2,1,0} %collective-permute, f32[128,112,224,3]{3,2,1,0} %lhs.copy.1, f32[128,2,224,3]{3,2,1,0} %collective-permute.1), dimensions={1}
  %constant.4 = f32[] constant(0)
  %broadcast.3 = f32[128,117,224,3]{3,2,1,0} broadcast(f32[] %constant.4), dimensions={}
  %select = f32[128,117,224,3]{3,2,1,0} select(pred[128,117,224,3]{3,2,1,0} %and, f32[128,117,224,3]{3,2,1,0} %concatenate, f32[128,117,224,3]{3,2,1,0} %broadcast.3)
  %param.1 = f32[7,7,3,64]{3,2,1,0} parameter(1), sharding={replicated}
  %rhs.copy.1 = f32[7,7,3,64]{3,2,1,0} copy(f32[7,7,3,64]{3,2,1,0} %param.1)
  ROOT %convolution = f32[128,56,112,64]{3,2,1,0} convolution(f32[128,117,224,3]{3,2,1,0} %select, f32[7,7,3,64]{3,2,1,0} %rhs.copy.1), window={size=7x7 stride=2x2 pad=0_0x3_3}, dim_labels=b01f_01io->b01f
}

And the corresponding DTensor annotation format is:

HloModule module

ENTRY entry {
  %lhs = f32[128,224,224,3] parameter(0)
  %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
    sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[7,7,3,64] parameter(1)
  %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
    sharding={replicated}
  ROOT %conv = f32[128,112,112,64] convolution(
    f32[128,224,224,3] %lhs.copy,
    f32[7,7,3,64] %rhs.copy),
    window={size=7x7 stride=2x2 pad=3_3x3_3},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,2,1,1]0,1}
}

The RFC for the MPI dialect is now up: [RFC] `MPI` Dialect