[RFC] Sharding Framework Design for Device Mesh

We could treat the MHLO/StableHLO ones as a restricted target/instance. And so yes, a reference op here could be good.

In the ShardingPartitionPass I see that operations implementing ShardingInterface are handled by simply changing their result types:

  } else if (llvm::isa<ShardingInterface>(op)) {
    ValueRange results = op->getResults();
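    // Rewrite each result type in place to its sharded (per-device) shape;
    // the op itself is otherwise left unchanged.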
    for (Value res : results) {
      auto oldType = res.getType().dyn_cast<RankedTensorType>();
      if (!oldType) {
        op->emitOpError() << "fail to get ranked tensor type";
        return signalPassFailure();
      }
      FailureOr<RankedTensorType> newType =
          getShardedType(oldType, clusterShape);
      if (failed(newType)) {
        op->emitOpError() << "fail to get sharded type";
        return signalPassFailure();
      }
      res.setType(*newType);
    }
  }

Isn’t this too restrictive? How can you be sure that any operation would be partitionable like that? Maybe there could be operations that have to be decomposed into multiple operations when partitioned.
How would a convolution be handled here for example?

Maybe the ShardingInterface could have another method, partition, that knows how to handle the given operation. Then there could be a concept hierarchy with this default implementation at the root.

Yeah, you are right, adding a partition method to ShardingInterface is a good idea.

The design and code implementation for the sharding partition section are less mature than the other sections. I’ll refine them prior to the presentation.

There are also the Mesh dialect collective operations. The partitioning algorithm should be extensible so that it can target collectives from another dialect.
Maybe the two concepts deserve two separate operation interfaces (a rough use-site sketch follows the list below):

  1. ShardingInterface is concerned with sharding annotations.
  2. PartitioningInterface is concerned with lowering a fully annotated program to collectives.
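
To make this concrete, the use site in the partition pass I have in mind would look roughly like the sketch below (all names here, such as PartitioningInterface, partition, builder, and clusterShape, are hypothetical and only for illustration):

  // Hypothetical sketch: none of these names exist in the current patch.
  if (auto iface = llvm::dyn_cast<PartitioningInterface>(op)) {
    // The op decides how to rewrite itself into its per-shard form. A default
    // implementation could keep today's behavior of only resharding the result
    // types, while ops like convolution could override it to also emit
    // collectives (e.g. a halo exchange) from whatever dialect the caller
    // targets.
    if (failed(iface.partition(builder, clusterShape)))
      return signalPassFailure();
  }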

I have a slightly different perspective. I believe that if an operation needs to be decomposed during sharding, it would be more effective for the decomposition to occur in the materialization pass (I’d like to extract the materialization stage of the original propagation pass into a separate pass in the future design). This is because the Mesh CCL optimization passes might then also have the opportunity to optimize the communication resulting from the decomposition.

Given this approach, the logic of the sharding partition pass could remain consistent with its current state.
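
Roughly, the resulting flow would then look something like this (the pass names below are just placeholders for the stages, not actual pass names):

  // Hypothetical staging, for illustration only.
  -sharding-propagation      // propagate mesh.annotate annotations
  -sharding-materialization  // per-op materialization; decompositions such as
                             // a halo exchange become explicit collectives
  -mesh-ccl-optimization     // optimize the now-explicit communication
  -sharding-partition        // shrink tensor types to their per-device shards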

Makes sense. Then the sharding attribute on tensor values may need to carry more complexity.
With convolution, for example, should the tensor sharding encode the halo / border overlap between adjacent tensor shards? Or should it be decomposed during materialization into halo exchange + convolution? For more info see the GSPMD paper.

Good questions.

With convolution, for example, should the tensor sharding encode the halo / border overlap between adjacent tensor shards?

In the current design, both mesh.annotate and the sharding attribute are not intended to be precise. For instance, both overlapping and non-overlapping configurations could be valid with the same annotation. Their primary use cases cater to frontend users who don’t have much knowledge of how an op is implemented and wish to manually specify the sharding. On the other hand, the combination of ShardingOption and materialization is designed to describe the sharding of an operation with precision. The primary audience for this feature is developers working on auto-sharding algorithms.

Or should it be decomposed during materialization into halo exchange + convolution?

I believe it should be decomposed during materialization, because communication analysis and optimization will benefit if all communication is expressed explicitly.

@stellaraccident, do you think a working E2E sharding + lowering to SPMD with CCL ops + runtime implementation of collectives with some backend (CPU for example) is a requirement for this to land?
Can we leave the IR at the mesh dialect level and let others lower to their particular CCL runtime?

I know ultimately we would like to have a CCL dialect in (upstream) MLIR, but this PR will get quite big if we include it.

No, not at all. I think that this has been done enough times that we just need to come to consensus that we want to do it and what the scope/design is. Then the work proceeds incrementally.

Are we talking about this at the open design meeting this week?

Yes, that is my understanding.

Wouldn’t this conflict with other uses of the encoding attribute? This could potentially be behind an interface, allowing the storage scheme to be left unspecified.
I have a more general question: why don’t types allow arbitrary attributes the same way operations do? Or at least the tensor types.

It would, but these are also probably “stage bound”, so you wouldn’t have them throughout. But indeed, doing it via a type interface is the easiest. Alternatively, folks can have their own dictionary attribute for the encoding and do arbitrary extensions (with a query mechanism that iterates over the dictionary); this has the cost of additional attribute storage per type, which an interface avoids. (Folks know this is one of my pet peeves with the encoding attribute: it creates a bit of a first-mover advantage. There has been very little overlap so far, as usage has mostly been disjoint, but a type interface removes the issue and enables customized lowering paths without enforcing a less efficient encoding.)

Yes, this is what I was referring to at the very top of the thread. Jacques, it might just be a slow Monday morning, but I’m having trouble seeing how a type interface would help in this case? Perhaps you could elaborate?

Actually, good point. I was thinking of the analysis part; for the propagation part it won’t help, since propagation needs to create new annotations, and to create them it would need to know what to create / what to propagate. Now, if one had something like ShapedType’s cloneWith, then it would work, and that could be an attribute interface again. There is an easy “fix” in that you have an encoding that holds the sharding and then captures the rest as another attribute, in a nested-doll fashion :) Today these aren’t really being mixed, and getting all frameworks to agree on what they want in the tensor type, or creating a sum of all types, seems unappealing. Just slapping in a dictionary for the encoding and allowing arbitrary entries along a flow is rather easy to do, and then you have a helper function that walks the dictionary and uses interfaces etc. to return the correct sharding or set it. The same thing as with a dictionary could be done with interfaces, which I think will be cheaper (but I haven’t measured).
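
For illustration, a minimal sketch of such a helper, assuming the encoding may be a plain DictionaryAttr nesting the sharding under some agreed-upon key (both the helper and the "mesh.shard" key are made up for this example):

  // Illustrative only: the key name and the helper are assumptions, not
  // existing API.
  static mlir::Attribute getShardingFromEncoding(mlir::RankedTensorType type) {
    mlir::Attribute encoding = type.getEncoding();
    // Nested-doll case: the encoding is a dictionary shared by several
    // clients, and the sharding is just one entry in it.
    if (auto dict = llvm::dyn_cast_or_null<mlir::DictionaryAttr>(encoding))
      return dict.get("mesh.shard");
    // Otherwise assume the encoding (if present) is the sharding itself.
    return encoding;
  }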

Are we talking about this at the open design meeting this week?

The presentation will cover the Mesh dialect, E2E sharding, and lowering to SPMD, but a CCL runtime implementation will not be included.

And I quite agree that a runtime implementation should be included in the future to complete the landing.

I have been looking at a convolution example, and it is unclear to me what the operations should mean after the materialization pass.

Here is example IR and its transformations.

Example input:

func.func public @main(
  %arg0: tensor<8x480x640x3xf32>,
  %arg1: tensor<4x5x3x16xf32>
) -> (tensor<8x477x636x16xf32>) {
  %0 = stablehlo.convolution(%arg0, %arg1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {
      stride = [1, 1],
      pad = [[0, 0], [0, 0]],
      lhs_dilate = [1, 1],
      rhs_dilate = [1, 1],
      reverse = [0, 0]
    } {
    batch_group_count = 1 : i64,
    feature_group_count = 1 : i64,
    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
  } : (tensor<8x480x640x3xf32>, tensor<4x5x3x16xf32>) -> tensor<8x477x636x16xf32>
  return %0 : tensor<8x477x636x16xf32>
}

Annotated:

func.func public @main(
  %arg0: tensor<8x480x640x3xf32>,
  %arg1: tensor<4x5x3x16xf32>
) -> (tensor<8x477x636x16xf32>) {
  %0 = mesh.annotate %arg0 {
    sharding = [[], [], [0], []],
    required = true,
    as_result = true
  } : tensor<8x480x640x3xf32> -> tensor<8x480x640x3xf32>
  %1 = stablehlo.convolution(%0, %arg1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {
      stride = [1, 1],
      pad = [[0, 0], [0, 0]],
      lhs_dilate = [1, 1],
      rhs_dilate = [1, 1],
      reverse = [0, 0]
    } {
    batch_group_count = 1 : i64,
    feature_group_count = 1 : i64,
    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
  } : (tensor<8x480x640x3xf32>, tensor<4x5x3x16xf32>) -> tensor<8x477x636x16xf32>
  return %1 : tensor<8x477x636x16xf32>
}

Indexing maps:

loop names: [    batch  image_h  image_w image_channels      kernel_h      kernel_w out_channels ]
loop types: [ parallel parallel parallel  reduction_sum reduction_sum reduction_sum     parallel ]
indexing maps:
(b, ih, iw, ic, kh, kw, oc) -> (b, ih, iw, ic)
(b, ih, iw, ic, kh, kw, oc) -> (kh, kw, ic, oc)
(b, ih, iw, ic, kh, kw, oc) -> (b, ih, iw, oc)

After sharding propagation:

func.func public @main(
  %arg0: tensor<8x480x640x3xf32>,
  %arg1: tensor<4x5x3x16xf32>
) -> (tensor<8x477x636x16xf32>) {
  %0 = mesh.annotate %arg0 {
    sharding = [[], [], [0], []],
    required = true,
    as_result = true
  } : tensor<8x480x640x3xf32> -> tensor<8x480x640x3xf32>
  %1 = mesh.annotate %0 {
    sharding = [[], [], [0], []]
  } : tensor<8x480x640x3xf32> -> tensor<8x480x640x3xf32>

  %2 = mesh.annotate %arg1 {
    sharding = [[], [], [], []],
    required = true,
    as_result = true
  } : tensor<4x5x3x16xf32> -> tensor<4x5x3x16xf32>
  %3 = mesh.annotate %2 {
    sharding = [[], [], [], []]
  } : tensor<4x5x3x16xf32> -> tensor<4x5x3x16xf32>

  %4 = stablehlo.convolution(%1, %3)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {
      stride = [1, 1],
      pad = [[0, 0], [0, 0]],
      lhs_dilate = [1, 1],
      rhs_dilate = [1, 1],
      reverse = [0, 0]
    } {
    batch_group_count = 1 : i64,
    feature_group_count = 1 : i64,
    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
    sharding = [[], [], [], [0], [], [], []]
  } : (tensor<8x480x640x3xf32>, tensor<4x5x3x16xf32>) -> tensor<8x477x636x16xf32>

  %5 = mesh.annotate %4 {
    sharding = [[], [], [0], []],
    required = true,
    as_result = true
  } : tensor<8x477x636x16xf32> -> tensor<8x477x636x16xf32>
  %6 = mesh.annotate %5 {
    sharding = [[], [], [0], []]
  } : tensor<8x477x636x16xf32> -> tensor<8x477x636x16xf32>

  return %6 : tensor<8x477x636x16xf32>
}

After materialization:

func.func public @main(
  %arg0: tensor<8x480x640x3xf32, #mesh.shard<[[], [], [0], []]>>,
  %arg1: tensor<4x5x3x16xf32, #mesh.shard<[[], [], [], []]>>
) -> (tensor<8x477x636x16xf32, #mesh.shard<[[], [], [0], []]>>) {
  // Halo exchange.
  %0 = ...
  %1 = ...

  %2 = stablehlo.convolution(%0, %1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {
      stride = [1, 1],
      pad = [[0, 0], [0, 0]],
      lhs_dilate = [1, 1],
      rhs_dilate = [1, 1],
      reverse = [0, 0]
    } {
    batch_group_count = 1 : i64,
    feature_group_count = 1 : i64,
    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
    sharding = [[], [], [], [0], [], [], []]
  } : (
      tensor<8x480x640x3xf32, #mesh.shard<[[], [], [0], []]>>,
      tensor<4x5x3x16xf32, #mesh.shard<[[], [], [], []]>>
  ) -> tensor<8x477x636x16xf32, #mesh.shard<[[], [], [0], []]>>

  return %2 : tensor<8x477x636x16xf32, #mesh.shard<[[], [], [0], []]>>
}

As you can see, at the end the operation should be decomposed into halo exchange + convolution.
But this convolution does not really mean a sharded convolution; it is actually a collection of convolutions on multiple tensor shards. It almost seems there should be a new operation there to represent this.
Should the sharding operator attribute change the meaning of the operation?

In my understanding, not only convolution but any sharded op is actually a collection of ops on multiple tensor shards, because a distributed tensor refers to a collection of sharded tensors. Could you clarify why the sharded conv (halo exchange not included) is different from other ops?

Should the sharding operator attribute change the meaning of the operation?

The sharding operation attribute won’t include the halo exchange information. That information lives within the materialize method of the conv op.

Regarding the encoding tensor field: in the call today someone expressed reservations about introducing an attribute dictionary in the tensor type. Is the main argument against this memory footprint and lookup speed? It would be convenient to be able to store arbitrary attributes without changing the tensor type’s source code.
We can keep the encoding member and potentially deprecate it to avoid breaking downstream code.

The alternative is to introduce a new sharding member.

I mentioned that you can already put a DictionaryAttr here. My concern is that regardless of what you put there, there is a problem of “contract”: encoding is designed as an opaque field, and different clients adding different information without knowing about each other seems… unsafe?
It’s not clear to me how to keep the encoding consistent that way: how do you propagate the various pieces of information from the dictionary through transformations?

Operations can have arbitrary attributes, and they suffer from the same problem of how to handle them during transformations. Ultimately a transformation has to handle these unknown attributes: they can either be copied/propagated or dropped, since nothing else can be done with their unknown semantics.