[RFC] Sharding Framework Design for Device Mesh

The original design doc can be found in the byteir repo, and the initial PR is there as well.

As many of you are aware, machine learning models continue to grow in size and complexity, and the need for efficient and flexible distribution strategies is becoming paramount. The device mesh concept, combined with a robust sharding framework, can significantly enhance performance, scalability, and flexibility across various hardware setups.

I’m reaching out to this community for two primary reasons:

  1. Collaboration: If you’re already working on a sharding framework or see potential in this one, I’d love to collaborate. There’s immense value in pooling our collective expertise to make the sharding framework in MLIR robust and widely applicable.

  2. Feedback: Whether you’re an expert in the field or someone with a keen interest, your feedback is invaluable. Constructive criticism, suggestions, or even pointing out potential pitfalls can significantly enhance the quality and applicability of this framework.

Mesh Dialect

The mesh dialect contains a set of attributes, operations, interfaces, and transformations that are useful for representing and optimizing computation on a device mesh.

MeshShardingAttr

Attribute that extends tensor type to distributed tensor type.

Syntax:

#mesh.shard<
  ::llvm::ArrayRef<::mlir::ArrayAttr>   # axes
>

The mesh.shard attribute is an array composed of int64_t sub-arrays. The outer array’s maximum size is the rank of the related tensor plus one. For the i-th sub-array, if its value is [x, y]:

  • When i < rank, it indicates that the tensor’s i-th dimension is sharded along the x and y axes of the device mesh.
  • When i == rank, it signifies that the tensor represents a partial sum along the x and y axes. More partial types could be introduced if needed, e.g. partial-max, partial-min.

Example:

// the tensor is sharded on the first dimension along axis 0
tensor<4x8xf32, #mesh.shard<[[0]]>>

// the tensor is sharded on the first dimension along axis 0 and it is also
// a partial-sum along axis 1.
tensor<4x8xf32, #mesh.shard<[[0], [], [1]]>>

Parameters:

Parameter C++ type Description
axes ::llvm::ArrayRef<::mlir::ArrayAttr>

ShardingIteratorType Enum

Currently there are only three sharding iterator types:

  • parallel: there should be an all-gather along the tensor dimension to get the full tensor.
  • reduction_sum: there should be an all-reduce-sum along the tensor dimension to get the full tensor. Other types of reduction could be introduced when needed, even a generic reduction type where a payload body indicating what exactly the reduction does needs to be included.
  • invalid: the dimension cannot be sharded.

mesh.cluster (mesh::ClusterOp)

Representing a mesh cluster

Syntax:

operation ::= `mesh.cluster` $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)` attr-dict

The mesh.cluster operation is a symbol operation that identifies a specific mesh cluster, which can be used for distributed computations across a mesh topology. The operation has three attributes:

  1. sym_name: This attribute uniquely identifies the name of the mesh cluster. This name serves as a symbolic reference to the cluster throughout the MLIR module, allowing for consistent referencing and easier debugging.

  2. rank: This attribute specifies the number of axes of the cluster. The rank indicates the dimensionality of the mesh cluster and can be used to determine the layout and the addressing space of the computation distributed across the mesh.

  3. dim_sizes: This attribute represents the device assignment along the axes of the cluster. Each integer in the array corresponds to the number of devices along a specific axis. If an integer value is <= 0, it implies that the number of devices along that axis is unknown. This flexibility allows for dynamic device assignment or configurations where the exact number of devices might not be determined at compile time.

Example:

// A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])

// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
mesh.cluster @mesh1(rank = 2, dim_sizes = [4])

// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])

// a func op running on @mesh0
func.func(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> attributes
                                                { mesh_cluster = @mesh0 } {
  ...
}

Interfaces: Symbol

Attributes:

Attribute MLIR Type Description
sym_name ::mlir::StringAttr string attribute
rank ::mlir::IntegerAttr 64-bit signless integer attribute
dim_sizes ::mlir::ArrayAttr 64-bit integer array attribute

mesh.idx (mesh::IdxOp)

Get the index of the current device along a specified mesh axis.

Syntax:

operation ::= `mesh.idx` attr-dict `:` type($result)

It is used in the SPMD form of the IR.

Constraints:

  1. The axis must be non-negative and less than the total number of mesh axes.
  2. Its parent op must be a FuncOp with a mesh_cluster attribute.
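A usage sketch (the attribute spelling and the i64 result type here are assumptions; the walkthrough example later in this post abbreviates it as mesh.idx(0)):

// inside a func.func that carries { mesh_cluster = @mesh0 }
%idx = mesh.idx {axis = 0} : i64   // index of this device along mesh axis 0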

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
axis ::mlir::IntegerAttr index attribute

Results:

Result Description
result Integer-like type with unknown platform-dependent bit width

mesh.size (mesh::SizeOp)

Get the number of devices along a specified mesh axis.

Syntax:

operation ::= `mesh.size` attr-dict `:` type($result)

It is used in the SPMD form of the IR.

Constraints:

  1. The axis must be non-negative and less than the total number of mesh axes.
  2. Its parent op must be a FuncOp with a mesh_cluster attribute.
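A usage sketch, analogous to mesh.idx (attribute spelling and result type are assumptions):

// inside a func.func that carries { mesh_cluster = @mesh0 }
%size = mesh.size {axis = 0} : i64   // number of devices along mesh axis 0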

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
axis ::mlir::IntegerAttr index attribute

Results:

Result Description
result Integer-like type with unknown platform-dependent bit width

mesh.annotate (mesh::AnnotateOp)

Annotate on how a tensor is sharded across a mesh cluster.

Syntax:

operation ::= `mesh.annotate` $input attr-dict `:` type($input) `->` type($output)

The mesh.annotate operation is designed to specify and guide the sharding
behavior of a tensor value across a mesh topology. It offers both strict
requirements and hints for the sharding process, allowing for flexibility
in distributed computations. This operation has one operand and three
attributes:

  1. input: This operand represents the tensor value that needs to be
    annotated for sharding.

  2. sharding: An array of int64 arrays with a maximum size equal to the
    rank of the input tensor plus one. Each element of the outer array
    corresponds to a dimension of the input tensor, except for the last element
    which signifies the tensor as a partial-sum. Each inner int64 array lists
    the axes to shard on. An axis will be sharded along at most one input
    dimension. If an axis is not present in any of the inner arrays, it
    indicates that the tensor will be replicated along that axis in the mesh.

  3. required: A boolean attribute. When set to true, it mandates the
    compiler to adhere to the sharding annotation specified. If set to false,
    the sharding annotation serves merely as a hint, allowing the compiler
    discretion in optimizing the sharding strategy.

  4. as_result: A boolean attribute addressing the scenario when a tensor’s
    sharding annotation differs based on its context of use (either as a result
    or an operand). If true, the annotation applies to the operation that
    defines the tensor value. If false, the annotation pertains to specific
    users of the tensor value, indicating how it should be considered when used
    as an operand in subsequent operations.

Example:

// The first mesh.annotate op applies to op0, the second mesh.annotate op
// applies to op1, the third mesh.annotate op applies to op2
%0 = op0 ...
%1 = mesh.annotate %0 {sharding = [[0], [1]], required = true,
        as_result = true} : tensor<2x5xf32> -> tensor<2x5xf32>
%2 = mesh.annotate %1 {sharding = [[0]], required = true,
        as_result = false} : tensor<2x5xf32> -> tensor<2x5xf32>
%3 = op1(%2) : ...
%4 = mesh.annotate %1 {sharding = [[1]], required = true,
        as_result = false} : tensor<2x5xf32> -> tensor<2x5xf32>
%5 = op2(%4) : ...

// The mesh.annotate op applies to op0; op1's operand has no
// annotation
%0 = op0 ...
%1 = mesh.annotate %0 {sharding = [[0], [1]], required = true,
        as_result = true} : tensor<2x5xf32> -> tensor<2x5xf32>
%2 = op1(%1) : ...

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultShape

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
sharding ::mlir::ArrayAttr array attribute
required ::mlir::BoolAttr bool attribute
as_result ::mlir::BoolAttr bool attribute

Operands:

Operand Description
input Multi-dimensional array with a fixed number of dimensions

Results:

Result Description
output Multi-dimensional array with a fixed number of dimensions

mesh.all_gather (mesh::AllGatherOp)

All-gather op in device mesh

Syntax:

operation ::= `mesh.all_gather` $src attr-dict `:` type($src) `->` type($result)

The operation is designed to facilitate all-gather computations specifically
within the context of a device mesh. It works with tensors that are
distributed across the device mesh, and these tensors are essentially the
builtin ranked tensors extended with the MeshShardingAttr. It has one
attribute:

  1. mesh_axis: An array of int64 arrays, representing the axes of the device
    mesh where the all-gather operation will be applied.

Example:

%1 = mesh.all_gather %0 {mesh_axis = [[0], [1]]} :
  tensor<2x4xf32, #mesh.shard<[[0], [1]]>> -> tensor<2x4xf32>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
mesh_axis ::mlir::ArrayAttr array attribute

Operands:

Operand Description
src Multi-dimensional array with a fixed number of dimensions

Results:

Result Description
result Multi-dimensional array with a fixed number of dimensions

mesh.all_reduce (mesh::AllReduceOp)

All-reduce op in device mesh

Syntax:

operation ::= `mesh.all_reduce` $src attr-dict `:` type($src) `->` type($result)

The operation is designed to facilitate all-reduce computations specifically
within the context of a device mesh. It works with tensors that are
distributed across the device mesh, and these tensors are essentially the
builtin ranked tensors extended with the MeshShardingAttr. It has two
attributes:

  1. mesh_axis: An int64 array representing the axes of the device mesh
    where the all-reduce operation will be applied.

  2. reduction: Indicates the reduction method.

Example:

%1 = mesh.all_reduce %0 {reduction = "sum", mesh_axis = [0]} :
  tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0, 1]]>> -> tensor<2x4x8xf32, #mesh.shard<[[], [], [], [1]]>>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
mesh_axis ::mlir::ArrayAttr 64-bit integer array attribute
reduction ::mlir::StringAttr string attribute

Operands:

Operand Description
src Multi-dimensional array with a fixed number of dimensions

Results:

Result Description
result Multi-dimensional array with a fixed number of dimensions

mesh.all_to_all (mesh::AllToAllOp)

TODO

mesh.local_split (mesh::LocalSplitOp)

Split a ranked tensor locally

Syntax:

operation ::= `mesh.local_split` $src attr-dict `:` type($src) `->` type($result)

The operation represents splitting a ranked tensor locally within the
context of a device mesh. It works with tensors that are distributed across
the device mesh, and these tensors are essentially the builtin ranked
tensors extended with the MeshShardingAttr. It has one attribute:

  1. sharding: An array of int64 arrays with a maximum size equal to the
    rank of the src tensor. Each element of the outer array corresponds to a
    dimension of the src tensor. Each inner int64 array lists the axes to
    split on. An axis will be sharded along at most one dimension, and it
    should not appear in the MeshShardingAttr of the src tensor.

Example:

%1 = mesh.local_split %0 {sharding = [[], [], [0]]} : tensor<2x4x8xf32, #mesh.shard<[[], [1], []]>>
       -> tensor<2x4x8xf32, #mesh.shard<[[], [1], [0]]>>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferDTensorInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
sharding ::mlir::ArrayAttr array attribute

Operands:

Operand Description
src Multi-dimensional array with a fixed number of dimensions

Results:

Result Description
result Multi-dimensional array with a fixed number of dimensions

mesh.reduce_scatter (mesh::ReduceScatterOp)

Reduce-scatter op in device mesh

Syntax:

operation ::= `mesh.reduce_scatter` $src attr-dict `:` type($src) `->` type($result)

The operation is designed to facilitate reduce-scatter computations specifically
within the context of a device mesh. It works with tensors that are
distributed across the device mesh, and these tensors are essentially the
builtin ranked tensors extended with the MeshShardingAttr. It has three
attributes:

  1. mesh_axis: An int64 array representing the axes of the device mesh
    where the reduce-scatter operation will be applied.

  2. reduction: Indicates the reduction method.

  3. tensor_axis: Indicates the axis to scatter.

Example:

%1 = mesh.reduce_scatter %0 {mesh_axis = [0], reduction = "sum", tensor_axis = 2 : i64} :
   tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0]]>> -> tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferDTensorInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
mesh_axis ::mlir::ArrayAttr 64-bit integer array attribute
reduction ::mlir::StringAttr string attribute
tensor_axis ::mlir::IntegerAttr 64-bit signless integer attribute

Operands:

Operand Description
src Multi-dimensional array with a fixed number of dimensions

Results:

Result Description
result Multi-dimensional array with a fixed number of dimensions

ShardingInterface

The ShardingInterface is an interface within MLIR that enables operations to
provide necessary information for sharding. This interface is primarily used for
sharding propagation. The interface is composed of four methods, out of which
two must be overridden for each operation type. The remaining two methods
provide default implementations which suffice for a majority of operation types.

Methods

getLoopIteratorTypes()

  • Description: This method provides a list of iterator types, one for each
    loop of the operation.
  • Return Type: SmallVector<mlir::mesh::ShardingIteratorType>
  • Details: The iterator types can be one of the following:
    • parallel: If the loop is sharded based on this iterator type, a subsequent
      all-gather is required after the sharded operation to produce the complete
      tensor.
    • reduction: When sharded on this loop, a subsequent all-reduce operation is
      required after the sharded operation to generate the complete tensor.
    • invalid: This signifies that the loop cannot undergo sharding.

getIndexingMaps()

  • Description: Offers the indexing maps associated with the current operation.
  • Return Type: SmallVector<AffineMap>
  • Details: These are affine maps, translating between loop iterators and tensor
    indices. Affine maps are formed from linear combinations and constants. The
    indexing maps of the operation results are restricted to projected
    permutations.
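For example, mirroring the walkthrough later in this post, a batched matmul expressed as mhlo.dot_general has loop types [parallel, parallel, parallel, reduction_sum] and the following indexing maps:

(d0, d1, d2, d3) -> (d0, d1, d3)   // lhs operand
(d0, d1, d2, d3) -> (d3, d2)       // rhs operand
(d0, d1, d2, d3) -> (d0, d1, d2)   // result (a projected permutation)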

getShardingOption(OpBuilder &b)

  • Description: Given that certain operands or results of the operation may be
    annotated, this method leverages this information to deduce how the operation
    should be sharded.
  • Return Type: FailureOr<ShardingOption>
  • Details: ShardingOption is represented as an array of int64 arrays. The
    sub-array at the i-th position signifies the mesh axes the i-th loop will be
    sharded on.
  • Default implementation logic:
    1. Check for Existing Attribute: If the operation already possesses a
      ShardingOption attribute, return this attribute immediately.
    2. Initialization: Instantiate an empty `ShardingOption`. This should be an
      array containing int64 sub-arrays, each corresponding to a loop in the
      operation.
    3. Results Annotation Handling:
      • Iterate over all the results of the operation. If a result has an
        annotation:
        • Map the tensor dimensions to loop iterators.
        • Set the corresponding axes based on the mapped loop iterators.
        • In cases where there’s a conflict with previously set axes, it implies
          an invalid sharding annotation. In such instances, flag this
          inconsistency for subsequent error handling or correction.
    4. Operands Annotation Handling:
      • Iterate over all the operands of the operation, using the information
        from:
        • Reduction iterator loops and
        • Unhandled parallel iterator loops
      • Validate the remaining iterator loops. If discrepancies arise during
        validation, take appropriate corrective actions or raise errors.
    5. Replication of Mesh Axes: Any mesh axes that haven’t been addressed or
      mapped during the above steps should be treated as replicated axes.
    6. Return Logic:
      • If the constructed or modified ShardingOption is valid, return it.
      • If inconsistencies or errors were detected, return `failure()`.
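A small worked example of the deduction (assumed shapes and annotations, mirroring the matmul case in the walkthrough below):

// A matmul with loop types [parallel, parallel, reduction_sum] and result
// indexing map (d0, d1, d2) -> (d0, d1).
// If its result is annotated with sharding = [[0], []], loop d0 maps to mesh
// axis 0, so the deduced ShardingOption is [[0], [], []]
// (d1 and the reduction loop d2 are left unsharded).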

setShardingAnnotations(OpBuilder &b, const ShardingOption &option)

  • Description: Based on a given ShardingOption, this method annotates those
    operands and results which previously lacked sharding annotations.
  • Return Type: LogicalResult
  • Details: The primary role is to propagate sharding annotations throughout
    the operation based on the provided sharding options.
  • Default implementation logic:
    1. Results Annotation Handling: Given the constraints of the result indexing
      maps, which are limited to projected permutations, there can only be a
      single DimId across all the result indexing maps.
    • For parallel loop iterators: Establish and assign the corresponding axes
      based on the mapped loop iterators.
    • For reduction loops: Append additional axes to the end of the existing
      annotations to indicate their association with the reduction loops.
    2. Operands Annotation Handling: Operand annotations pose a more intricate
      challenge compared to results due to the possibility that they might not
      strictly adhere to projected permutations.
      • Here, we constrain the results of the operand’s indexing maps to the
        representation format c_i * d_i + c_j * d_j + …. In this
        representation:
        • c_i and c_j denote constants. If a constant has a value of one, it may
          be omitted from the representation.
        • d_i and d_j represent DimIds.
      • In situations where the representation contains multiple DimIds:
        Sharding can only be applied to at most one of them. This constraint
        ensures that the operand annotations don’t introduce excessive complexity
        and retain predictability in their sharding behavior.
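An assumed illustration of this constraint on an operand indexing-map result:

(d0, d1, d2) -> (d0 * 4 + d2)   // two DimIds appear (d0 and d2); at most one of them may be sharded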

Sharding Propagation Pass

The sharding propagation pass aims to address two primary objectives:

  1. Sharding Annotation Completion: Computational graphs often have incomplete
    sharding annotations. This pass is designed to fill in these gaps.
  2. Distributed Tensor Materialization: Once the computational graph is fully
    annotated, this pass will convert it into distributed tensors and incorporate
    concrete communication operations.

Implementation Logic:

  1. Backward Sharding Propagation:
  • Traverse all operations that implement the `ShardingInterface`, iterating
    in reverse order.
  • For each operation, invoke the getShardingOption and
    setShardingAnnotations methods.
  2. Forward Sharding Propagation:
  • Traverse all operations that implement the `ShardingInterface`, but this
    time in forward order.
  • Similarly, for each operation, call the getShardingOption and
    setShardingAnnotations methods.
  3. Annotation Operations Handling: Process all annotation operations in reverse
    order (an IR sketch of the operand-annotation case follows this list):
    • Result Annotations (as_result = true): Extend the type of the annotated
      value by incorporating a MeshShardingAttr. This attribute is derived from
      the annotation operation itself.
    • Operand Annotations (as_result = false): Introduce additional communication
      operations. The final produced value will replace the result of the
      original annotation operation. Note: At this stage, the logic for
      communication creation can be kept straightforward. Further
      canonicalization and optimization of these communications can be executed
      later. The process can be categorized into three stages:
      • All-Reduce: If any reduction sharding axes are absent in the
        current annotation operation relative to its operand’s defining operation
        (which should also be an annotation operation with as_result = true),
        an all-reduce operation should be initialized.
      • All-Gather: Create an all-gather operation to reconstruct the
        complete tensor.
      • Local-Split: Launch a local-split operation to derive the final
        sharded tensor.
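A rough IR sketch of the operand-annotation case (the shapes and annotations are assumed; no all-reduce is needed here since no partial-sum axis is involved):

// before materialization
%1 = mesh.annotate %0 {sharding = [[0], []], required = true, as_result = true}
       : tensor<4x8xf32> -> tensor<4x8xf32>
%2 = mesh.annotate %1 {sharding = [[], [0]], required = true, as_result = false}
       : tensor<4x8xf32> -> tensor<4x8xf32>

// after materialization, %0 carries the sharded type and the operand
// annotation becomes an all-gather followed by a local-split
%1 = mesh.all_gather %0 {mesh_axis = [[0], []]} :
  tensor<4x8xf32, #mesh.shard<[[0]]>> -> tensor<4x8xf32>
%2 = mesh.local_split %1 {sharding = [[], [0]]} :
  tensor<4x8xf32> -> tensor<4x8xf32, #mesh.shard<[[], [0]]>>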

Collective Communication Optimization Passes

After the sharding propagation pass, collective communication optimization aims
to further streamline and optimize communication operations. Some passes are
listed below as examples:

All-Reduce Folder Pass

  • Purpose: To consolidate successive all-reduce operations for efficiency.

  • Description: This pass identifies scenarios where one all-reduce operation feeds
    directly into another. When detected, the mesh axes being reduced over are
    merged, folding the two operations into a single all-reduce and reducing
    redundancy.
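For example (an assumed sketch), two stacked all-reduce operations over different mesh axes can be folded into a single all-reduce over both axes:

%1 = mesh.all_reduce %0 {reduction = "sum", mesh_axis = [0]} :
  tensor<4x8xf32, #mesh.shard<[[], [], [0, 1]]>> -> tensor<4x8xf32, #mesh.shard<[[], [], [1]]>>
%2 = mesh.all_reduce %1 {reduction = "sum", mesh_axis = [1]} :
  tensor<4x8xf32, #mesh.shard<[[], [], [1]]>> -> tensor<4x8xf32>

// folds into

%2 = mesh.all_reduce %0 {reduction = "sum", mesh_axis = [0, 1]} :
  tensor<4x8xf32, #mesh.shard<[[], [], [0, 1]]>> -> tensor<4x8xf32>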

All-Reduce Reassociate Pass

  • Purpose: To streamline multiple all-reduce operations acting on elementwise
    operations.

  • Description: This pass identifies patterns where multiple all-reduce
    operations are applied to the results of elementwise operations. Upon detection,
    the pass reassociates these operations to reduce the number of collective
    communications. For instance, the sequence add(all-reduce(x), all-reduce(y))
    would be transformed into all-reduce(add(x,y)).
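In mesh IR this might look as follows (an assumed sketch, using mhlo.add as the elementwise payload op):

%2 = mesh.all_reduce %0 {reduction = "sum", mesh_axis = [0]} :
  tensor<4x8xf32, #mesh.shard<[[], [], [0]]>> -> tensor<4x8xf32>
%3 = mesh.all_reduce %1 {reduction = "sum", mesh_axis = [0]} :
  tensor<4x8xf32, #mesh.shard<[[], [], [0]]>> -> tensor<4x8xf32>
%4 = mhlo.add %2, %3 : tensor<4x8xf32>

// becomes

%2 = mhlo.add %0, %1 : tensor<4x8xf32, #mesh.shard<[[], [], [0]]>>
%3 = mesh.all_reduce %2 {reduction = "sum", mesh_axis = [0]} :
  tensor<4x8xf32, #mesh.shard<[[], [], [0]]>> -> tensor<4x8xf32>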

Reduce-Scatter Reassociate Pass

  • Purpose: To optimize multiple reduce-scatter operations that act on
    elementwise operations.

  • Description: This pass detects patterns where multiple reduce-scatter
    operations are applied to the results of elementwise operations. When such
    patterns are identified, the pass reassociates these operations to consolidate
    the collective communications. As an example, a sequence like
    add(reduce-scatter(x), reduce-scatter(y)) would be reshaped into
    reduce-scatter(add(x,y)).

All-Gather Move Down Pass

  • Purpose: To reposition all-gather operations for improved efficiency in the
    computational flow.

  • Description: This pass targets scenarios where an all-gather operation
    precedes operations that have a parallel loop type for the gathered dimension.
    In such situations, the all-gather operation is shifted downwards in the
    sequence.
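A sketch of the pattern (assumed, with a generic unary mhlo op as the payload):

%1 = mesh.all_gather %0 {mesh_axis = [[], [0]]} :
  tensor<4x8xf32, #mesh.shard<[[], [0]]>> -> tensor<4x8xf32>
%2 = "mhlo.abs"(%1) : (tensor<4x8xf32>) -> tensor<4x8xf32>

// becomes

%1 = "mhlo.abs"(%0) : (tensor<4x8xf32, #mesh.shard<[[], [0]]>>)
       -> tensor<4x8xf32, #mesh.shard<[[], [0]]>>
%2 = mesh.all_gather %1 {mesh_axis = [[], [0]]} :
  tensor<4x8xf32, #mesh.shard<[[], [0]]>> -> tensor<4x8xf32>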

Sharding Mutations

While the logic of the sharding propagation pass is designed for simplicity, it doesn’t always yield the optimal outcome. An optional sharding mutation can be introduced to modify the sharding result of the IR. The sharding mutation is usually determined by an analysis of the communication and computation in the current sharded IR.

Sharding Partition

This pass transforms distributed tensors into concrete per-device tensors. Additionally, it converts mesh CCL operations into concrete CCL ops with explicit device IDs. Depending on the scenario, we can opt for one of two types of result IR:

  1. The physical weights/arguments are already partitioned across different devices, eliminating the need to retain this information in the IR. In this case,
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func(%arg0: tensor<8xf32, #mesh.shard<[[0]]>>) -> () attributes { mesh_cluster = @mesh0 } {
  "use"(%arg0) ...
  ...
}

// will be converted to 

mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func(%arg0: tensor<4xf32>) -> () attributes { mesh_cluster = @mesh0 } {
  "use"(%arg0) ...
  ...
}

  1. The physical weights/arguments haven’t been partitioned for individual devices, necessitating knowledge of the actual slice information. In this case,
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func(%arg0: tensor<8xf32, #mesh.shard<[[0]]>>) -> () attributes { mesh_cluster = @mesh0 } {
  "use"(%arg0) ...
  ...
}

// will be converted to 

mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func(%arg0: tensor<8xf32>) -> () attributes { mesh_cluster = @mesh0 } {
  %idx = mesh.idx(0)
  %c4 = arith.constant 4 : i64
  %start = arith.muli %idx, %c4 : i64
  %arg0_slice = "mhlo.dynamic_slice"(%arg0, %start) {
    slice_sizes = dense<[4]> : tensor<1xi64>
  } : (tensor<8xf32>, i64) -> tensor<4xf32>
  "use"(%arg0_slice) ...
  ...
}

End2End Walkthrough Example

MLP 1D weight stationary with all sharding annotations on tensors

[2211.05102] Efficiently Scaling Transformer Inference, figure 2(a)

  1. Original IR
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func @mlp(%arg0: tensor<2x4x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x8xf32>) -> tensor<2x4x8xf32> attributes { mesh_cluster = @mesh0 } {
  %0 = mesh.annotate %arg0 {required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %1 = "mhlo.dot_general"(%0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
                                      rhs_contracting_dimensions = [0]>,
                                      precision_config =
                                      [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
  %2 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
  %3 = mhlo.maximum %1, %2 : tensor<2x4x32xf32>
  %4 = "mhlo.dot_general"(%3, %arg2) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
                                      rhs_contracting_dimensions = [0]>,
                                      precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
  %5 = mesh.annotate %4 {required = true, sharding = [[], [], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %6 = mesh.annotate %5 {as_result = false, required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  return %6 : tensor<2x4x8xf32>
}
  2. Loop types and indexing maps
%3 = mhlo.maximum %1, %2 : tensor<2x4x32xf32>
loop types: [parallel parallel parallel ]
indexing maps:
(d0, d1, d2) -> (d0, d1, d2)
(d0, d1, d2) -> (d0, d1, d2)
(d0, d1, d2) -> (d0, d1, d2)
%1 = "mhlo.dot_general"(%0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
                                      rhs_contracting_dimensions = [0]>,
                                      precision_config =
                                      [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
loop types: [parallel parallel parallel reduction_sum ]
indexing maps:
(d0, d1, d2, d3) -> (d0, d1, d3)
(d0, d1, d2, d3) -> (d3, d2)
(d0, d1, d2, d3) -> (d0, d1, d2)
%4 = "mhlo.dot_general"(%3, %arg2) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
                                      rhs_contracting_dimensions = [0]>,
                                      precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
loop types: [parallel parallel parallel reduction_sum ]
indexing maps:
(d0, d1, d2, d3) -> (d0, d1, d3)
(d0, d1, d2, d3) -> (d3, d2)
(d0, d1, d2, d3) -> (d0, d1, d2)
  3. After annotation completion
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x8xf32>) -> tensor<2x4x8xf32> attributes { mesh_cluster = @mesh0 } {
  %0 = mesh.annotate %arg1 {as_result = false, required = false, sharding = [[], [0]]} : tensor<8x32xf32> -> tensor<8x32xf32>
  %1 = mesh.annotate %arg2 {as_result = false, required = false, sharding = [[0]]} : tensor<32x8xf32> -> tensor<32x8xf32>
  %2 = mesh.annotate %arg0 {required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %3 = mesh.annotate %2 {as_result = false, required = false, sharding = []} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %4 = "mhlo.dot_general"(%3, %0) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [0]]} : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
  %5 = mesh.annotate %4 {required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %6 = mesh.annotate %5 {as_result = false, required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %7 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
  %8 = mesh.annotate %7 {required = false, sharding = []} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %9 = mesh.annotate %8 {as_result = false, required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %10 = mhlo.maximum %6, %9 {sharding = [[], [], [0]]} : tensor<2x4x32xf32>
  %11 = mesh.annotate %10 {required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %12 = mesh.annotate %11 {as_result = false, required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %13 = "mhlo.dot_general"(%12, %1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [], [0]]} : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
  %14 = mesh.annotate %13 {required = true, sharding = [[], [], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %15 = mesh.annotate %14 {as_result = false, required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  return %15 : tensor<2x4x8xf32>
}
  4. After sharding materialization
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

  func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>, %arg1: tensor<8x32xf32, #mesh.shard<[[], [0]]>>, %arg2: tensor<32x8xf32, #mesh.shard<[[0]]>>) -> tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>> attributes { mesh_cluster = @mesh0 } {
    %0 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
    %1 = mesh.all_gather %arg0 {mesh_axis = [[], [], [0]], tensor_axis = [2]} : tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>> -> tensor<2x4x8xf32>
    %2 = "mhlo.dot_general"(%1, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [0]]} : (tensor<2x4x8xf32>, tensor<8x32xf32, #mesh.shard<[[], [0]]>>) -> tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>
    %3 = mesh.local_split %0 {sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>
    %4 = mhlo.maximum %2, %3 {sharding = [[], [], [0]]} : tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>
    %5 = "mhlo.dot_general"(%4, %arg2) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [], [0]]} : (tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>, tensor<32x8xf32, #mesh.shard<[[0]]>>) -> tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0]]>>
    %6 = mesh.reduce_scatter %5 {mesh_axis = [0], reduction = "sum", tensor_axis = 2 : i64} : tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0]]>> -> tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>
    return %6 : tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>
  }
  5. After sharding partition
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x4xf32>, %arg1: tensor<8x16xf32>, %arg2: tensor<16x8xf32>) -> tensor<2x4x4xf32> attributes {mesh_cluster = @mesh0} {
  %0 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
  %1 = "mhlo.all_gather"(%arg0) {all_gather_dim = 2 : i64, channel_handle = #mhlo.channel_handle<handle = 0, type = 0>, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<2x4x4xf32>) -> tensor<2x4x8xf32>
  %2 = "mhlo.dot_general"(%1, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [0]]} : (tensor<2x4x8xf32>, tensor<8x16xf32>) -> tensor<2x4x16xf32>
  %3 = mhlo.constant dense<0.000000e+00> : tensor<2x4x16xf32>
  %4 = mhlo.maximum %2, %3 {sharding = [[], [], [0]]} : tensor<2x4x16xf32>
  %5 = "mhlo.dot_general"(%4, %arg2) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [], [0]]} : (tensor<2x4x16xf32>, tensor<16x8xf32>) -> tensor<2x4x8xf32>
  %6 = "mhlo.reduce_scatter"(%5) ({
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
    %7 = mhlo.add %arg3, %arg4 : tensor<f32>
    mhlo.return %7 : tensor<f32>
  }) {channel_handle = #mhlo.channel_handle<handle = 0, type = 0>, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, scatter_dimension = 2 : i64} : (tensor<2x4x8xf32>) -> tensor<2x4x4xf32>
  return %6 : tensor<2x4x4xf32>
}

MLP 1D weight stationary with sharding option on operations

  1. Original IR
mesh.cluster @mesh0(rank = 1, dim_sizes = [8])

func.func @mlp_1d_weight_stationary_with_sharding_on_operation(%arg0: tensor<2x4x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x8xf32>) -> tensor<2x4x8xf32> attributes { mesh_cluster = @mesh0 } {
  %0 = mesh.annotate %arg0 {required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %1 = "mhlo.dot_general"(%0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], 
                                      rhs_contracting_dimensions = [0]>, 
                                      precision_config = 
                                      [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
  %2 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
  %3 = mhlo.maximum %1, %2 : tensor<2x4x32xf32>
  %4 = "mhlo.dot_general"(%3, %arg2) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
                                      rhs_contracting_dimensions = [0]>, 
                                      precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    sharding = [[], [], [], [0]]
  } : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
  %6 = mesh.annotate %4 {as_result = false, required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  return %6 : tensor<2x4x8xf32>
}
  2. The result of the pass is the same as in the first example

Q & A

What are the differences and connections between MeshShardingAttr, mesh.annotate, and operation ShardingOption?

  1. MeshShardingAttr and mesh.annotate:
  • Purpose: Both aim to represent distributed tensors.
  • Differences: MeshShardingAttr serves as an optional encoding for
    RankedTensorType, offering a more concise expression. In contrast,
    mesh.annotate introduces an additional operation, ensuring that information
    isn’t lost after executing a pass.
  2. Operation ShardingOption:
  • This pertains to the sharding annotations of an operation. It more precisely
    depicts how an operation is sharded. Due to its need for a deeper
    understanding of operation computations, it isn’t typically exposed to
    end-users.

Why isn’t the framework built based on the upstream TilingInterface system?

While utilizing the tiling interface might seem like a logical choice, it
actually introduces certain intricacies. This interface would encode specific
slices directly into the IR. However, during the sharding optimization phase,
this specific slice information isn’t necessary. Including it could
inadvertently complicate the implementation.

For instance, sharding naturally denotes that each device has an equal logical
division. If we then dive deeper into lower-level operations like
tensor.extract_slice/insert_slice, this even distribution information could
potentially be lost, which is counterintuitive for sharding.

How do we represent MOE (expert parallel)?

To represent MOE (expert parallel), a tensor would introduce an additional
expert dimension. Alongside, there should be a concept of a “virtual axis”.
This comes into play especially when the number of experts is less than the
number of devices on the physical axis. The approach is somewhat akin to how
it’s done with tensor parallel.

How does the sharding framework compare to XLA’s GSPMD?

  1. Unified Interface: For the majority of operations, only getIndexingMaps and
    getLoopIteratorTypes need to be implemented.
  2. Strategy Independence in Propagation: The propagation phase doesn’t employ
    sharding optimization strategies.
  3. Sharding Option at Operation Level: This provides an option for sharding at
    the operation level, making it more convenient for automatic parallel algorithms
    to set sharding strategies with precision.
  4. Explicit Communication Post Propagation: The results after propagation
    explicitly depict communication, facilitating efficient analysis and
    optimization.

Thank you for the detailed proposal. I’m still reviewing in detail but strongly +1 on the need and coming together to get this built out. We had discussed sponsoring some work in this area but hadn’t gotten past internal discussions.

@harsh-nod @sogartar

For a feature of this magnitude, I would probably go ahead and add a dedicated sharding attribute to TensorType (or extend it with an attribute dict).


For a feature of this magnitude, I would probably go ahead and add a dedicated sharding attribute to TensorType (or extend it with an attribute dict).

Yeah, currently it is a dedicated sharding attribute named MeshShardingAttr.
Although it has only an ArrayAttr for now, it could be enhanced if more reduction types or features need to be supported.

Did you extend TensorType itself, or are you just using the existing encoding attribute? I’m referring to better factoring the TensorType itself, but this is a minor point to your overall proposal (it is just the thing I noticed straight off).

I haven’t reviewed your proposal, but I’d echo @stellaraccident on the need for something like this and some community collaborating on this here! It’s long overdue.

What about scheduling a presentation at an open-meeting on the design?

I am using the encoding of RankedTensorType.

I’ve opted not to extend TensorType because it imposes a significant burden on other dialects, such as mhlo, to support this new type. And I will retain the encoding attribute within the mesh transformations.

I’d be more than happy to schedule a presentation at an open meeting to discuss the design in detail. Please let me know a suitable time, and I’ll make the necessary arrangements.


Thursday next week is open (or the week after) if you’re up to!

Thursday next week works for me, thanks


This looks good. I am surprised the PR is only 3.6K lines of code compared to what is in XLA. @yaochengji how tested would you say your current implementation is?
I like that you are modeling the device mesh abstraction explicitly with the assignment of mesh axes to tensor dimensions. I wonder if we would ever need to have a finer-grained assignment as in XLA, where each tensor tile can be mapped to device(s) explicitly.
It is an interesting choice to bake the reduction into the tensor value type. What would you say is the advantage of this approach?
Have you thought about how this would play with other auto-sharding algorithms in place of sharding propagation? The ShardingInterface seems to be designed with sharding propagation in mind. For example, something like Alpa.


I am surprised the PR is only 3.6K lines of code compared to what is in XLA. @yaochengji how tested would you say your current implementation is?

This PR only contains draft code and is intended for demonstration purposes. As such, it hasn’t been extensively tested. I’ve only tested the sharding propagation pass using figures 2.a and 2.b from the paper on efficiently scaling transformer inference and the sharding partition pass on figure 2.a.

I wonder if we would ever need to have a finer-grained assignment as in XLA, where each tensor tile can be mapped to device(s) explicitly.

Opting to specify the mesh axis only on the mesh’s CCL operations, rather than the device IDs, simplifies the implementation of passes at the mesh level. Furthermore, this approach can support a dynamic number of device IDs, facilitating symbolic analysis based on the number of devices. After the sharding partition, the mesh’s CCL operations will be converted to standard CCL operations with designated device IDs.

Have you thought about how this would play with other auto-sharding algorithms in place of sharding propagation?

Auto-sharding algorithms such as Alpa conduct sharding analysis using their own data structures and generate optimized sharding annotations for tensors. Subsequently, they depend on XLA for the remaining tasks, including sharding propagation. My design can accommodate this approach, and it also offers an alternative method for auto-sharding where the data structures here can be reused for the analysis.
A typical workflow might look like this:

  1. An initial sharding strategy is set.
  2. Sharding propagation and general optimization are applied.
  3. Metrics for guiding the next sharding mutation are computed. These might include:
    a. Calculating communication volume from the explicitly represented mesh’s CCL operations. This can be symbolic, viewed from the perspective of shapes or the number of devices.
    b. FLOPs of redundant computation.
    c. Memory required for a single device’s computation.
  4. Sharding mutation is applied, followed by general optimization. The sharding mutation can take the form of a transform operation in MLIR: transform.mesh.mutate %op, [new_op_sharding_option]. Users could mutate an operation based on either the sharding annotations on tensors or the sharding options on operations. The latter is more recommended, as detailed in the Q&A section.
  5. Repeat step 3 until an optimal sharding result is achieved.
  6. Execute sharding partition so that each device receives a sharded computation graph.

I see that from the Mesh dialect you lower to MHLO. Isn’t it better to target StableHLO (GitHub: openxla/stablehlo) instead?

If this gets upstreamed to the MLIR repo, this part will have to live somewhere else, as there should be no dependency on StableHLO. Maybe these should be rewrite patterns in the StableHLO repo.

In

The indexing maps of the operation results are restricted to projected permutations.

what do you mean by projected permutations?

I see that from the Mesh dialect you lower to MHLO. Isn’t it better to target StableHLO (GitHub: openxla/stablehlo) instead?

Because the pass pipeline would be

general_graph_passes -> mesh_passes -> general_graph_passes -> ...

StableHLO is primarily used as an input IR, while the majority of the passes are implemented on MHLO. If I opt for StableHLO as the dialect that implements the ShardingInterface and as the lowering target for the mesh CCL operations, this would necessitate additional conversions both before and after the mesh passes in the aforementioned pipeline.

general_graph_passes_on_mhlo -> convert_to_stablehlo -> mesh_passes -> 
convert_to_mhlo -> general_graph_passes_on_mhlo -> ...

If this gets upstreamed to the MLIR repo this part will have to live somewhere else. As there should be no dependency on StableHLO. Maybe these should be rewrite patterns in the StableHLO repo.

We don’t have to choose MHLO/StableHLO if we want to get it upstreamed to the MLIR repo. But I don’t think there are existing dialects that could be used for the payload ops during mesh transformations as conveniently as MHLO.

  1. Tosa: doesn’t have as many optimizations as MHLO.
  2. Linalg/Tensor: linalg.generic implements DestinationStyleOpInterface, which complicates handling of the init operands.

Meanwhile, MLIR repo lacks CCL operations.

I think the dialect is indeed (or should be) orthogonal to this. So using MHLO or StableHLO upstream seems fine. I think here it should be about the core representations and passes that are independent of the dialect with which they are used.

In tree, it may be good to have an end-to-end flow with TOSA. It would at least enable testing many of the same things, and avoid detecting issues only very late.

I’m looking forward to the presentation!

what do you mean by projected permutations?

If you look at the code, you will soon know that.

I could also provide some examples here

(d0, d1, d2) -> (d1, d0) // a projected permutation
(d0, d1, d2) -> (d2, d0, d1) // a projected permutation
(d0, d1, d2) -> (d0 + d1) // not a projected permutation

TOSA does not have collective operations.

Yes, the Mesh dialect is orthogonal to the payload dialect, such as MHLO.

Additionally, the CCL operations after lowering from Mesh don’t have to reside in the same dialect as the payload dialect.

I think it’s just time to have reference CCL ops built out as part of this contribution. I’d even be fine if they were part of the mesh dialect to start if we don’t want to commit to a proper namespace for them.


I see 2 drawbacks of StableHLO/MHLO collectives:

  1. No support for dynamic number of devices. The set of devices is baked into an operation attribute.
  2. The different modes of operation (cross-replica, cross-partition, cross-replica-and-partition, use_global_device_ids) are just unnecessary complications, probably a legacy from XLA.