The original post of the design doc can be found in the byteir repo, along with the initial PR.
As many of you are aware, as machine learning models continue to grow in size and complexity, the need for efficient and flexible distribution strategies becomes paramount. The device mesh concept, combined with a robust sharding framework, can significantly improve performance, scalability, and flexibility across various hardware setups.
I’m reaching out to this community for two primary reasons:

Collaboration: If you're already working on a sharding framework or see potential in this one, I'd love to collaborate. There's immense value in pooling our collective expertise to make a sharding framework in MLIR robust and widely applicable.

Feedback: Whether you’re an expert in the field or someone with a keen interest, your feedback is invaluable. Constructive criticism, suggestions, or even pointing out potential pitfalls can significantly enhance the quality and applicability of this framework.
Mesh Dialect
The mesh dialect contains a set of attributes, operations, interfaces and transformations that are useful for representing and optimizing computation on a device mesh.
MeshShardingAttr
An attribute that extends a tensor type to a distributed tensor type.
Syntax:
#mesh.shard<
::llvm::ArrayRef<::mlir::ArrayAttr> # axes
>
The mesh.shard attribute is an array composed of int64_t sub-arrays. The outer array's maximum size is the rank of the related tensor plus one. For the i-th sub-array, if its value is [x, y]:
- When i < rank, it indicates that the tensor's i-th dimension is sharded along the x and y axes of the device mesh.
- When i == rank, it signifies that the tensor represents a partial sum along the x and y axes. More partial types could be introduced if needed, e.g. partial-max and partial-min.
Example:
// the tensor is sharded on the first dimension along axis 0
tensor<4x8xf32, #mesh.shard<[[0]]>>
// the tensor is sharded on the first dimension along axis 0 and it is also
// a partial sum along axis 1.
tensor<4x8xf32, #mesh.shard<[[0], [], [1]]>>
Parameters:
Parameter | C++ type | Description
axes | ::llvm::ArrayRef<::mlir::ArrayAttr> |
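To make the encoding concrete, here is a small Python sketch (an illustration, not part of the proposal) that reads a #mesh.shard axes array and computes the per-device shape of the tensor. It assumes every mesh axis size is known and evenly divides the tensor dimension it shards; `split_spec` and `local_shape` are hypothetical helper names.

```python
from math import prod
from typing import List, Sequence, Tuple

def split_spec(axes: Sequence[Sequence[int]], rank: int) -> Tuple[List[List[int]], List[int]]:
    """Split a #mesh.shard axes array into per-dimension axes and partial-sum axes."""
    if len(axes) > rank + 1:
        raise ValueError("outer array may have at most rank + 1 entries")
    dims = [list(axes[i]) if i < len(axes) else [] for i in range(rank)]
    partial = list(axes[rank]) if len(axes) > rank else []
    return dims, partial

def local_shape(shape: Sequence[int], axes: Sequence[Sequence[int]],
                mesh_shape: Sequence[int]) -> List[int]:
    """Per-device shape: each dimension is divided by the product of its mesh axis sizes.

    Partial-sum axes do not change the shape; they only mean each device
    holds one addend of the full value.
    """
    dims, _ = split_spec(axes, len(shape))
    result = []
    for size, mesh_axes in zip(shape, dims):
        parts = prod(mesh_shape[a] for a in mesh_axes)  # prod of nothing is 1
        if size % parts != 0:
            raise ValueError(f"dimension of size {size} is not divisible by {parts}")
        result.append(size // parts)
    return result

# tensor<4x8xf32, #mesh.shard<[[0]]>> on a mesh of shape [2, 4]
print(local_shape([4, 8], [[0]], [2, 4]))           # [2, 8]
# tensor<4x8xf32, #mesh.shard<[[0], [], [1]]>>: dim 0 on axis 0, partial sum on axis 1
print(local_shape([4, 8], [[0], [], [1]], [2, 4]))  # [2, 8]
print(split_spec([[0], [], [1]], 2))                # ([[0], []], [1])
```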
ShardingIteratorType Enum
Currently there are only three sharding iterator types:
- parallel: an all-gather along the tensor dimension is needed to obtain the full tensor.
- reduction_sum: an all-reduce (sum) along the tensor dimension is needed to obtain the full tensor. Other types of reduction could be introduced when needed, even a generic reduction type, in which case a payload body describing exactly what the reduction is would need to be included.
- invalid: the dimension cannot be sharded.
mesh.cluster (mesh::ClusterOp)
Representing a mesh cluster
Syntax:
operation ::= `mesh.cluster` $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)` attr-dict
The mesh.cluster operation is a symbol operation that identifies a specific mesh cluster, which can be used for distributed computations across a mesh topology. The operation has three attributes:

- sym_name: This attribute uniquely identifies the name of the mesh cluster. This name serves as a symbolic reference to the cluster throughout the MLIR module, allowing for consistent referencing and easier debugging.
- rank: This attribute specifies the number of axes of the cluster. The rank indicates the dimensionality of the mesh cluster and can be used to determine the layout and the addressing space of the computation distributed across the mesh.
- dim_sizes: This attribute represents the device assignment along the axes of the cluster. Each integer in the array corresponds to the number of devices along a specific axis. If an integer value is <= 0, the number of devices along that axis is unknown. This flexibility allows for dynamic device assignment, or for configurations where the exact number of devices might not be known at compile time.
Example:
// A device mesh cluster with 3 axes; the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])
// A device mesh cluster with 2 axes; the total device number is unknown
// The first dimension size is 4 and the second is unknown
mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
// A device mesh cluster with 2 axes; the total device number is unknown
// The first dimension size is unknown and the second is 4
mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
// a func op running on @mesh0
func.func(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> attributes
{ mesh_cluster = @mesh0 } {
...
}
Interfaces: Symbol
Attributes:
Attribute | MLIR Type | Description
sym_name | ::mlir::StringAttr | string attribute
rank | ::mlir::IntegerAttr | 64-bit signless integer attribute
dim_sizes | ::mlir::ArrayAttr | 64-bit integer array attribute
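The dim_sizes rules (a size <= 0 means unknown; per the @mesh1 example, omitted trailing sizes are unknown too) can be captured in a few lines. This Python sketch only models the attribute semantics described above; `normalized_dim_sizes` and `total_devices` are hypothetical names.

```python
from math import prod
from typing import List, Optional, Sequence

def normalized_dim_sizes(rank: int, dim_sizes: Sequence[int]) -> List[Optional[int]]:
    """Pad dim_sizes to `rank` entries; None marks an axis whose size is unknown."""
    if len(dim_sizes) > rank:
        raise ValueError("dim_sizes has more entries than rank")
    padded = list(dim_sizes) + [0] * (rank - len(dim_sizes))  # missing sizes are unknown
    return [s if s > 0 else None for s in padded]             # <= 0 also means unknown

def total_devices(rank: int, dim_sizes: Sequence[int]) -> Optional[int]:
    """Total device count, or None if any axis size is unknown."""
    sizes = normalized_dim_sizes(rank, dim_sizes)
    if any(s is None for s in sizes):
        return None
    return prod(sizes)

print(total_devices(3, [4, 8, 12]))     # 384  (@mesh0)
print(total_devices(2, [4]))            # None (@mesh1)
print(normalized_dim_sizes(2, [0, 4]))  # [None, 4]
```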
mesh.idx (mesh::IdxOp)
Get the index of the current device along the specified mesh axis.
Syntax:
operation ::= `mesh.idx` attr-dict `:` type($result)
It is used in the SPMD form of the IR. Constraints:
- The axis must be non-negative and less than the total number of mesh axes.
- Its parent op must be a FuncOp with a mesh_cluster attribute.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | index attribute
Results:
Result | Description
result | Integer-like type with unknown platform-dependent bit width
mesh.size (mesh::SizeOp)
Get the number of devices along the specified mesh axis.
Syntax:
operation ::= `mesh.size` attr-dict `:` type($result)
It is used in the SPMD form of the IR.
Constraints:
- The axis must be non-negative and less than the total number of mesh axes.
- Its parent op must be a FuncOp with a mesh_cluster attribute.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | index attribute
Results:
Result | Description
result | Integer-like type with unknown platform-dependent bit width
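The proposal does not say how a device's linear ID maps to its position on the mesh. Assuming a row-major layout (the last axis varies fastest), the values that mesh.idx and mesh.size would produce on a given device can be modeled as below; `device_coords`, `mesh_idx` and `mesh_size` are hypothetical helpers, not dialect APIs.

```python
from typing import List, Sequence

def device_coords(linear_id: int, mesh_shape: Sequence[int]) -> List[int]:
    """Row-major decomposition of a linear device ID into per-axis indices (an assumption)."""
    coords = []
    for size in reversed(mesh_shape):
        coords.append(linear_id % size)
        linear_id //= size
    if linear_id != 0:
        raise ValueError("device ID out of range for this mesh")
    return list(reversed(coords))

def check_axis(mesh_shape: Sequence[int], axis: int) -> None:
    # Mirrors the constraint: axis must be non-negative and less than the number of axes.
    if not 0 <= axis < len(mesh_shape):
        raise ValueError("axis out of range")

def mesh_idx(linear_id: int, mesh_shape: Sequence[int], axis: int) -> int:
    """Models mesh.idx: this device's index along `axis`."""
    check_axis(mesh_shape, axis)
    return device_coords(linear_id, mesh_shape)[axis]

def mesh_size(mesh_shape: Sequence[int], axis: int) -> int:
    """Models mesh.size: the number of devices along `axis`."""
    check_axis(mesh_shape, axis)
    return mesh_shape[axis]

# On @mesh0 (dim_sizes = [4, 8, 12]), device 100 sits at (1, 0, 4): 1*96 + 0*12 + 4 = 100.
print(device_coords(100, [4, 8, 12]))  # [1, 0, 4]
print(mesh_idx(100, [4, 8, 12], 2))    # 4
print(mesh_size([4, 8, 12], 1))        # 8
```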
mesh.annotate (mesh::AnnotateOp)
Annotate how a tensor is sharded across a mesh cluster.
Syntax:
operation ::= `mesh.annotate` $input attr-dict `:` type($input) `->` type($output)
The mesh.annotate operation is designed to specify and guide the sharding
behavior of a tensor value across a mesh topology. It offers both strict
requirements and hints for the sharding process, allowing for flexibility
in distributed computations. This operation has one operand and three
attributes:

- input: This operand represents the tensor value that needs to be annotated for sharding.
- sharding: An array of int64 arrays with a maximum size equal to the rank of the input tensor plus one. Each element of the outer array corresponds to a dimension of the input tensor, except for the element at index rank, which marks the tensor as a partial sum. Each inner int64 array lists the axes to shard on. An axis will be sharded along at most one input dimension. If an axis is not present in any of the inner arrays, the tensor is replicated along that axis of the mesh.
- required: A boolean attribute. When set to true, it mandates that the compiler adhere to the specified sharding annotation. If set to false, the sharding annotation serves merely as a hint, giving the compiler discretion in optimizing the sharding strategy.
- as_result: A boolean attribute addressing the scenario where a tensor's sharding annotation differs based on its context of use (as a result or as an operand). If true, the annotation applies to the operation that defines the tensor value. If false, the annotation pertains to specific users of the tensor value, indicating how it should be treated when used as an operand in subsequent operations.
Example:
// The first mesh.annotate op applies to op0, the second mesh.annotate op
// applies to op1, the third mesh.annotate op applies to op2
%0 = op0 ...
%1 = mesh.annotate %0 {sharding = [[0], [1]], required = true,
as_result = true} : tensor<2x5xf32> -> tensor<2x5xf32>
%2 = mesh.annotate %1 {sharding = [[0]], required = true,
as_result = false} : tensor<2x5xf32> -> tensor<2x5xf32>
%3 = op1(%2) : ...
%4 = mesh.annotate %1 {sharding = [[1]], required = true,
as_result = false} : tensor<2x5xf32> -> tensor<2x5xf32>
%5 = op2(%4) : ...
// The mesh.annotate op applies to op0; op1's operand has no
// annotation
%0 = op0 ...
%1 = mesh.annotate %0 {sharding = [[0], [1]], required = true,
as_result = true} : tensor<2x5xf32> -> tensor<2x5xf32>
%2 = op1(%1) : ...
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultShape
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description
sharding | ::mlir::ArrayAttr | array attribute
required | ::mlir::BoolAttr | bool attribute
as_result | ::mlir::BoolAttr | bool attribute
Operands:
Operand | Description
input | Multi-dimensional array with a fixed number of dimensions
Results:
Result | Description
output | Multi-dimensional array with a fixed number of dimensions
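The structural rules for the sharding attribute of mesh.annotate (at most rank + 1 entries, each mesh axis used at most once, unlisted axes replicated) can be checked mechanically. A minimal Python sketch, assuming axes are plain integers and the mesh rank is known; `check_sharding` and `replicated_axes` are illustrative names, and treating the partial-sum entry as also consuming its axes is an assumption.

```python
from typing import List, Sequence

def check_sharding(sharding: Sequence[Sequence[int]], tensor_rank: int, mesh_rank: int) -> None:
    """Raise ValueError if `sharding` breaks the mesh.annotate rules described above."""
    if len(sharding) > tensor_rank + 1:
        raise ValueError("sharding may have at most rank + 1 entries")
    seen = set()
    for entry in sharding:  # the entry at index tensor_rank is the partial-sum entry
        for axis in entry:
            if not 0 <= axis < mesh_rank:
                raise ValueError(f"mesh axis {axis} out of range")
            if axis in seen:
                # Assumption: no axis appears twice anywhere, including the partial entry.
                raise ValueError(f"mesh axis {axis} used more than once")
            seen.add(axis)

def replicated_axes(sharding: Sequence[Sequence[int]], mesh_rank: int) -> List[int]:
    """Mesh axes not listed anywhere in `sharding`; the tensor is replicated along them."""
    used = {axis for entry in sharding for axis in entry}
    return [axis for axis in range(mesh_rank) if axis not in used]

check_sharding([[0], [1]], tensor_rank=2, mesh_rank=2)  # valid: dim 0 on axis 0, dim 1 on axis 1
print(replicated_axes([[0]], mesh_rank=2))              # [1]
```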
mesh.all_gather (mesh::AllGatherOp)
All-gather op in a device mesh
Syntax:
operation ::= `mesh.all_gather` $src attr-dict `:` type($src) `->` type($result)
The operation is designed to facilitate all-gather computations specifically within the context of a device mesh. It works with tensors that are distributed across the device mesh; these tensors are essentially builtin ranked tensors extended with the MeshShardingAttr. It has one attribute:
- mesh_axis: An array of int64 arrays, representing the axes of the device mesh where the all-gather operation will be applied.
Example:
%1 = mesh.all_gather %0 {mesh_axis = [[0], [1]]} :
tensor<2x4xf32, #mesh.shard<[[0], [1]]>> -> tensor<2x4xf32>
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description
mesh_axis | ::mlir::ArrayAttr | array attribute
Operands:
Operand | Description
src | Multi-dimensional array with a fixed number of dimensions
Results:
Result | Description
result | Multi-dimensional array with a fixed number of dimensions
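Functionally, an all-gather along one mesh axis concatenates the shards held by the devices on that axis, in device order, and gives every device the full result. A sketch of those semantics over a 1-D mesh, with plain Python lists standing in for per-device 1-D tensors (an illustration of the collective, not its lowering):

```python
from typing import List

def all_gather(shards: List[List[float]]) -> List[List[float]]:
    """shards[d] is the local shard on device d; every device receives the concatenation."""
    full = [x for shard in shards for x in shard]  # concatenate in device order
    return [list(full) for _ in shards]            # each device gets its own copy

# tensor<4xf32, #mesh.shard<[[0]]>> on a 2-device mesh: each device holds 2 elements.
print(all_gather([[1.0, 2.0], [3.0, 4.0]]))
# [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]
```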
mesh.all_reduce (mesh::AllReduceOp)
All-reduce op in a device mesh
Syntax:
operation ::= `mesh.all_reduce` $src attr-dict `:` type($src) `->` type($result)
The operation is designed to facilitate all-reduce computations specifically within the context of a device mesh. It works with tensors that are distributed across the device mesh; these tensors are essentially builtin ranked tensors extended with the MeshShardingAttr. It has two attributes:
- mesh_axis: An int64 array representing the axes of the device mesh where the all-reduce operation will be applied.
- reduction: Indicates the reduction method (e.g. "sum").
Example:
%1 = mesh.all_reduce %0 {reduction = "sum", mesh_axis = [0]} :
tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0, 1]]>> -> tensor<2x4x8xf32, #mesh.shard<[[], [], [], [1]]>>
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description
mesh_axis | ::mlir::ArrayAttr | 64-bit integer array attribute
reduction | ::mlir::StringAttr | string attribute
Operands:
Operand | Description
src | Multi-dimensional array with a fixed number of dimensions
Results:
Result | Description
result | Multi-dimensional array with a fixed number of dimensions
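Two things can be read off the all_reduce example: numerically, every device ends up with the element-wise sum of the addends held along the reduced axis, and type-wise, the reduced axes disappear from the partial-sum entry of the sharding. A Python sketch of both for a "sum" reduction; the helper names are hypothetical.

```python
from typing import List, Sequence

def all_reduce_sum(partials: List[List[float]]) -> List[List[float]]:
    """partials[d] is the addend held by device d; every device receives the element-wise sum."""
    total = [sum(values) for values in zip(*partials)]
    return [list(total) for _ in partials]

def all_reduce_result_sharding(sharding: List[List[int]], tensor_rank: int,
                               mesh_axis: Sequence[int]) -> List[List[int]]:
    """Drop the reduced mesh axes from the partial-sum entry (index tensor_rank)."""
    result = [list(entry) for entry in sharding]
    if len(result) > tensor_rank:
        result[tensor_rank] = [a for a in result[tensor_rank] if a not in mesh_axis]
    return result

print(all_reduce_sum([[1.0, 2.0], [10.0, 20.0]]))  # [[11.0, 22.0], [11.0, 22.0]]
# Matches the example: #mesh.shard<[[], [], [], [0, 1]]> -> #mesh.shard<[[], [], [], [1]]>
print(all_reduce_result_sharding([[], [], [], [0, 1]], 3, [0]))  # [[], [], [], [1]]
```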
mesh.all_to_all (mesh::AllToAllOp)
TODO
mesh.local_split (mesh::LocalSplitOp)
Split a ranked tensor locally
Syntax:
operation ::= `mesh.local_split` $src attr-dict `:` type($src) `->` type($result)
The operation represents splitting a ranked tensor locally, specifically within the context of a device mesh. It works with tensors that are distributed across the device mesh; these tensors are essentially builtin ranked tensors extended with the MeshShardingAttr. It has one attribute:
- sharding: An array of int64 arrays with a maximum size equal to the rank of the src tensor. Each element of the outer array corresponds to a dimension of the src tensor. Each inner int64 array lists the axes to split on. An axis will be sharded along at most one dimension, and it should not appear in the MeshShardingAttr of the src tensor.
Example:
%1 = mesh.local_split %0 {sharding = [[], [], [0]]} : tensor<2x4x8xf32, #mesh.shard<[[], [1], []]>>
-> tensor<2x4x8xf32, #mesh.shard<[[], [1], [0]]>>
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferDTensorInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description
sharding | ::mlir::ArrayAttr | array attribute
Operands:
Operand | Description
src | Multi-dimensional array with a fixed number of dimensions
Results:
Result | Description
result | Multi-dimensional array with a fixed number of dimensions
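local_split needs no communication: each device simply keeps its own slice. The sketch below shows the result-type rule from the example (the split axes are merged into the source sharding and must not already be in use) and the half-open range a device would keep along one split dimension. It assumes an even split; `local_split_sharding` and `local_slice` are hypothetical names.

```python
from typing import List, Sequence, Tuple

def local_split_sharding(src: Sequence[Sequence[int]],
                         split: Sequence[Sequence[int]]) -> List[List[int]]:
    """Merge the split axes into the source sharding, dimension by dimension."""
    used = {a for entry in src for a in entry}
    result = []
    for i in range(max(len(src), len(split))):
        old = list(src[i]) if i < len(src) else []
        new = list(split[i]) if i < len(split) else []
        if used.intersection(new):
            raise ValueError("a split axis already appears in the source sharding")
        result.append(old + new)
    return result

def local_slice(dim_size: int, num_devices: int, idx: int) -> Tuple[int, int]:
    """Half-open [start, end) range that device `idx` keeps, assuming an even split."""
    if dim_size % num_devices != 0:
        raise ValueError("uneven split is not modeled here")
    chunk = dim_size // num_devices
    return idx * chunk, (idx + 1) * chunk

# Example above: #mesh.shard<[[], [1], []]> split with [[], [], [0]]
print(local_split_sharding([[], [1], []], [[], [], [0]]))  # [[], [1], [0]]
print(local_slice(8, 2, 1))                                # (4, 8)
```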
mesh.reduce_scatter (mesh::ReduceScatterOp)
Reduce-scatter op in a device mesh
Syntax:
operation ::= `mesh.reduce_scatter` $src attr-dict `:` type($src) `->` type($result)
The operation is designed to facilitate reduce-scatter computations specifically within the context of a device mesh. It works with tensors that are distributed across the device mesh; these tensors are essentially builtin ranked tensors extended with the MeshShardingAttr. It has three attributes:
- mesh_axis: An int64 array representing the axes of the device mesh where the reduce-scatter operation will be applied.
- reduction: Indicates the reduction method.
- tensor_axis: Indicates the tensor axis to scatter.
Example:
%1 = mesh.reduce_scatter %0 {mesh_axis = [0], reduction = "sum", tensor_axis = 2 : i64} :
tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0]]>> -> tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferDTensorInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description
mesh_axis | ::mlir::ArrayAttr | 64-bit integer array attribute
reduction | ::mlir::StringAttr | string attribute
tensor_axis | ::mlir::IntegerAttr | 64-bit signless integer attribute
Operands:
Operand | Description
src | Multi-dimensional array with a fixed number of dimensions
Results:
Result | Description
result | Multi-dimensional array with a fixed number of dimensions
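A reduce-scatter with "sum" gives the same result as an all-reduce followed by each device keeping its own chunk along tensor_axis, which is presumably why the walkthrough below ends up with a single mesh.reduce_scatter where a partial sum is turned into a sharded dimension. A Python sketch over a 1-D mesh, with lists standing in for 1-D tensors:

```python
from typing import List

def reduce_scatter_sum(partials: List[List[float]]) -> List[List[float]]:
    """partials[d] is device d's addend for the whole (1-D) tensor.

    The element-wise sum is split into equal chunks along the scattered
    axis, and device d keeps chunk d.
    """
    n = len(partials)
    total = [sum(values) for values in zip(*partials)]
    if len(total) % n != 0:
        raise ValueError("scattered dimension must be divisible by the number of devices")
    chunk = len(total) // n
    return [total[d * chunk:(d + 1) * chunk] for d in range(n)]

print(reduce_scatter_sum([[1, 2, 3, 4], [10, 20, 30, 40]]))  # [[11, 22], [33, 44]]
```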
ShardingInterface
The ShardingInterface is an interface within MLIR that enables operations to provide the information needed for sharding. This interface is primarily used for sharding propagation. The interface is composed of four methods, two of which must be overridden for each operation type. The remaining two methods provide default implementations that suffice for the majority of operation types.
Methods
getLoopIteratorTypes()
- Description: This method provides a list of iterator types, one for each loop of the operation.
- Return Type: SmallVector<mlir::mesh::ShardingIteratorType>
- Details: The iterator types can be one of the following:
  - parallel: If the loop is sharded based on this iterator type, a subsequent all-gather is required after the sharded operation to produce the complete tensor.
  - reduction: When sharded on this loop, a subsequent all-reduce operation is essential after the sharded operation to generate the complete tensor.
  - invalid: This signifies that the loop cannot be sharded.
getIndexingMaps()
- Description: Offers the indexing maps associated with the current operation.
- Return Type: SmallVector
- Details: These are affine maps, translating between loop iterators and tensor indices. Affine maps are formed from linear combinations and constants. The indexing maps of the operation results are restricted to projected permutations.
getShardingOption(OpBuilder &b)
- Description: Given that certain operands or results of the operation may be annotated, this method leverages this information to deduce how the operation should be sharded.
- Return Type: FailureOr
- Details: ShardingOption is represented as an array of int64 arrays. The sub-array at the i-th position signifies the mesh axes the i-th loop will be sharded on.
- Default implementation logic:
  - Check for Existing Attribute: If the operation already possesses a ShardingOption attribute, return this attribute immediately.
  - Initialization: Instantiate an empty ShardingOption. This should be an array containing int64 sub-arrays, each corresponding to a loop in the operation.
  - Results Annotation Handling: Iterate over all the results of the operation. If a result has an annotation:
    - Map the tensor dimensions to loop iterators.
    - Set the corresponding axes based on the mapped loop iterators.
    - A conflict with previously set axes implies an invalid sharding annotation. In such instances, flag this inconsistency for subsequent error handling or correction.
  - Operands Annotation Handling:
    - Iterate over all the operands of the operation, using their annotations to determine the reduction iterator loops and the parallel iterator loops not yet handled by the results.
    - Validate the remaining iterator loops. If discrepancies arise during validation, take appropriate corrective actions or raise errors.
  - Replication of Mesh Axes: Any mesh axes that haven't been addressed or mapped during the above steps should be treated as replicated axes.
  - Return Logic:
    - If the constructed or modified ShardingOption is valid, return it.
    - If inconsistencies or errors were detected, return failure().
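The results-then-operands order of the default getShardingOption logic can be sketched in Python as follows. This is a simplified model, not the C++ implementation: indexing maps are given as projected permutations (map[k] is the loop that indexes tensor dimension k), the partial-sum entry of result annotations and the operand validation step are omitted, and `infer_sharding_option` is a hypothetical name.

```python
from typing import List, Optional, Sequence

Sharding = List[List[int]]

def infer_sharding_option(num_loops: int,
                          result_maps: Sequence[Sequence[int]],
                          result_shardings: Sequence[Optional[Sharding]],
                          operand_maps: Sequence[Sequence[int]],
                          operand_shardings: Sequence[Optional[Sharding]]) -> Optional[Sharding]:
    """Per-loop mesh axes, or None for the failure() case."""
    option: List[Optional[List[int]]] = [None] * num_loops  # None = not decided yet

    # Results: map each annotated tensor dimension to its loop; conflicts are errors.
    for loop_of_dim, sharding in zip(result_maps, result_shardings):
        if sharding is None:
            continue
        for dim, loop in enumerate(loop_of_dim):
            axes = list(sharding[dim]) if dim < len(sharding) else []
            if option[loop] is not None and option[loop] != axes:
                return None
            option[loop] = axes

    # Operands: fill loops the results did not decide (e.g. reduction loops).
    for loop_of_dim, sharding in zip(operand_maps, operand_shardings):
        if sharding is None:
            continue
        for dim, loop in enumerate(loop_of_dim):
            if option[loop] is None and dim < len(sharding) and sharding[dim]:
                option[loop] = list(sharding[dim])

    # Loops nobody decided stay unsharded, i.e. replicated along every mesh axis.
    return [axes if axes is not None else [] for axes in option]

# Second dot_general of the walkthrough: loops (d0, d1, d2, d3), d3 is the contraction,
# and only the rhs operand is annotated (dim 0 on mesh axis 0).
print(infer_sharding_option(4, [(0, 1, 2)], [None], [(0, 1, 3), (3, 2)], [None, [[0]]]))
# [[], [], [], [0]]
```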
setShardingAnnotations(OpBuilder &b, const ShardingOption &option)
- Description: Based on a given ShardingOption, this method annotates those operands and results that previously lacked sharding annotations.
- Return Type: LogicalResult
- Details: The primary role is to propagate sharding annotations throughout the operation based on the provided sharding option.
- Default implementation logic:
  - Results Annotation Handling: Given that the result indexing maps are limited to projected permutations, each result expression of the result indexing maps is a single DimId.
    - For parallel loop iterators: Establish and assign the corresponding axes based on the mapped loop iterators.
    - For reduction loops: Append additional axes to the end of the existing annotations to indicate their association with the reduction loops.
  - Operands Annotation Handling: Operand annotations pose a more intricate challenge than results, because operand indexing maps might not strictly adhere to projected permutations. Here, we constrain each result of an operand's indexing map to the form c_i * d_i + c_j * d_j + ..., where:
    - c_i and c_j denote constants. If a constant has a value of one, it may be omitted from the representation.
    - d_i and d_j represent DimIds.
    In situations where the representation contains multiple DimIds, sharding can be applied to at most one of them. This constraint ensures that the operand annotations don't introduce excessive complexity and keep their sharding behavior predictable.
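For the projected-permutation case, the default setShardingAnnotations logic reduces to reading each tensor dimension's axes off its loop, plus appending the reduction-loop axes as the partial-sum entry of each result. A Python sketch under those simplifications (the c_i * d_i + c_j * d_j operand form is not modeled; `annotations_from_option` and `trim` are hypothetical names). Applied to the two dot_general ops of the walkthrough below, it reproduces the annotations shown after annotation completion.

```python
from typing import List, Sequence, Tuple

def trim(sharding: List[List[int]]) -> List[List[int]]:
    """Drop trailing empty entries, matching how the walkthrough prints annotations."""
    out = [list(e) for e in sharding]
    while out and not out[-1]:
        out.pop()
    return out

def annotations_from_option(option: Sequence[Sequence[int]],
                            loop_types: Sequence[str],
                            operand_maps: Sequence[Sequence[int]],
                            result_maps: Sequence[Sequence[int]]) -> Tuple[list, list]:
    """Operand and result shardings implied by a per-loop sharding option.

    Maps are projected permutations: map[k] is the loop indexing dimension k.
    """
    reduction_axes = [a for loop, axes in enumerate(option)
                      if loop_types[loop] == "reduction_sum" for a in axes]
    operands = [trim([list(option[loop]) for loop in m]) for m in operand_maps]
    results = []
    for m in result_maps:
        sharding = [list(option[loop]) for loop in m]
        if reduction_axes:
            sharding.append(reduction_axes)  # partial-sum entry at index rank
        results.append(trim(sharding))
    return operands, results

types = ["parallel", "parallel", "parallel", "reduction_sum"]
maps = dict(operand_maps=[(0, 1, 3), (3, 2)], result_maps=[(0, 1, 2)])
print(annotations_from_option([[], [], [0], []], types, **maps))
# ([[], [[], [0]]], [[[], [], [0]]])
print(annotations_from_option([[], [], [], [0]], types, **maps))
# ([[[], [], [0]], [[0]]], [[[], [], [], [0]]])
```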
Sharding Propagation Pass
The sharding propagation pass aims to address two primary objectives:
- Sharding Annotation Completion: Computational graphs often have incomplete sharding annotations. This pass is designed to fill in these gaps.
- Distributed Tensor Materialization: Once the computational graph is fully annotated, this pass converts it into distributed tensors and inserts concrete communication operations.
Implementation Logic:
- Backward Sharding Propagation:
  - Traverse all operations that implement the ShardingInterface, iterating in reverse order.
  - For each operation, invoke the getShardingOption and setShardingAnnotations methods.
- Forward Sharding Propagation:
  - Traverse all operations that implement the ShardingInterface, but this time in forward (non-reversed) order.
  - Similarly, for each operation, call the getShardingOption and setShardingAnnotations methods.
- Annotation Operations Handling: Process all annotation operations in reverse order.
  - Result Annotations (as_result = true): Extend the type of the annotated value by incorporating a MeshShardingAttr. This attribute is derived from the annotation operation itself.
  - Operand Annotations (as_result = false): Introduce additional communication operations. The final produced value replaces the result of the original annotation operation. Note: at this stage, the logic for communication creation can be kept straightforward; further canonicalization and optimization of these communications can be performed later. The process can be divided into three stages:
    - AllReduce: If any reduction sharding axes are absent in the current annotation operation relative to its operand's defining operation (which should also be an annotation operation with as_result = true), an all_reduce operation should be created.
    - AllGather: Create an all_gather operation to reconstruct the complete tensor.
    - LocalSplit: Create a local_split operation to derive the final sharded tensor.
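The three-stage handling of an as_result = false annotation can be sketched as a function that compares the producer's sharding with the requested one and lists the communication to emit. This Python sketch is illustrative only; `split_entries` and `plan_resharding` are hypothetical, and it adds one simplification not in the description above: no gather/split is emitted when the per-dimension sharding already matches. The three calls mirror the %14 -> %15, %2 -> %3 and %8 -> %9 pairs of the walkthrough, whose materialized IR appears to show a reduce_scatter (all_reduce plus local_split combined), an all_gather, and a local_split.

```python
from typing import List, Sequence, Tuple

def split_entries(sharding: Sequence[Sequence[int]], rank: int) -> Tuple[List[List[int]], List[int]]:
    """Per-dimension entries (padded to `rank`) plus the partial-sum entry."""
    dims = [list(sharding[i]) if i < len(sharding) else [] for i in range(rank)]
    partial = list(sharding[rank]) if len(sharding) > rank else []
    return dims, partial

def plan_resharding(src: Sequence[Sequence[int]], dst: Sequence[Sequence[int]],
                    rank: int) -> List[tuple]:
    """Communication for an as_result = false annotation `dst` whose operand
    comes from an as_result = true annotation `src`."""
    src_dims, src_partial = split_entries(src, rank)
    dst_dims, dst_partial = split_entries(dst, rank)
    ops = []
    # AllReduce: partial-sum axes that the requested sharding no longer carries.
    reduce_axes = [a for a in src_partial if a not in dst_partial]
    if reduce_axes:
        ops.append(("all_reduce", reduce_axes))
    # Simplification: skip gather/split when per-dimension sharding already matches.
    if src_dims != dst_dims:
        if any(src_dims):  # AllGather: rebuild the full tensor
            ops.append(("all_gather", src_dims))
        if any(dst_dims):  # LocalSplit: keep this device's slice
            ops.append(("local_split", dst_dims))
    return ops

print(plan_resharding([[], [], [], [0]], [[], [], [0]], 3))
# [('all_reduce', [0]), ('local_split', [[], [], [0]])]
print(plan_resharding([[], [], [0]], [], 3))  # [('all_gather', [[], [], [0]])]
print(plan_resharding([], [[], [], [0]], 3))  # [('local_split', [[], [], [0]])]
```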
Collective Communication Optimization Passes
After the sharding propagation pass, collective communication optimization aims to further streamline and optimize communication operations. Some example passes are listed below:
AllReduce Folder Pass

Purpose: To consolidate successive all-reduce operations for efficiency.

Description: This pass identifies scenarios where one all-reduce operation feeds directly into another. When detected, the mesh axes to reduce are expanded so that the pair folds into a single all-reduce, reducing redundancy.
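A sketch of the folding rule on a hypothetical mini-IR of nested tuples. Beyond the description above, it assumes two guards: the reductions must match, and the two axis sets must be disjoint (a second "sum" over an axis that was already reduced would multiply the value by that axis size, so such a pair is left alone). `fold_all_reduce` is an illustrative name, not a pass API.

```python
from typing import Tuple, Union

# Hypothetical mini-IR: a value is a name (str) or
# ("all_reduce", reduction, mesh_axes, operand).
Expr = Union[str, Tuple]

def fold_all_reduce(expr: Expr) -> Expr:
    """Fold all_reduce(all_reduce(x, A), B) into all_reduce(x, A + B) when safe."""
    if isinstance(expr, str) or expr[0] != "all_reduce":
        return expr
    _, reduction, axes, operand = expr
    operand = fold_all_reduce(operand)  # fold chains bottom-up
    if (not isinstance(operand, str) and operand[0] == "all_reduce"
            and operand[1] == reduction and not set(operand[2]) & set(axes)):
        return ("all_reduce", reduction, tuple(operand[2]) + tuple(axes), operand[3])
    return ("all_reduce", reduction, tuple(axes), operand)

chain = ("all_reduce", "sum", (2,), ("all_reduce", "sum", (1,), ("all_reduce", "sum", (0,), "x")))
print(fold_all_reduce(chain))  # ('all_reduce', 'sum', (0, 1, 2), 'x')
```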
AllReduce Reassociate Pass

Purpose: To streamline multiple all-reduce operations whose results feed element-wise operations.
Description: This pass identifies patterns where an element-wise operation is applied to the results of multiple all-reduce operations. Upon detection, the pass reassociates these operations to reduce the number of collective communications. For instance, the sequence add(allreduce(x), allreduce(y)) would be transformed into allreduce(add(x,y)).
ReduceScatter Reassociate Pass

Purpose: To optimize multiple reduce-scatter operations whose results feed element-wise operations.
Description: This pass detects patterns where an element-wise operation is applied to the results of multiple reduce-scatter operations. When such patterns are identified, the pass reassociates these operations to consolidate the collective communications. As an example, a sequence like add(reducescatter(x), reducescatter(y)) would be rewritten as reducescatter(add(x,y)).
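Both reassociate passes rely on the sum reduction distributing over the element-wise op: the sum of per-device sums equals the sum of the per-device additions. A numeric Python check for add on a 2-device mesh follows. Note that this holds for add but not for an op such as multiply, so, as an assumption not stated in the proposal, the passes presumably need to restrict which element-wise ops they match.

```python
from typing import List

def all_reduce_sum(per_device: List[List[int]]) -> List[int]:
    """Element-wise sum over devices: the value every device holds after the all-reduce."""
    return [sum(values) for values in zip(*per_device)]

def reduce_scatter_sum(per_device: List[List[int]], idx: int) -> List[int]:
    """The chunk of the all-reduced value that device `idx` keeps after a reduce-scatter."""
    total = all_reduce_sum(per_device)
    chunk = len(total) // len(per_device)
    return total[idx * chunk:(idx + 1) * chunk]

def add(a: List[int], b: List[int]) -> List[int]:
    """The element-wise op being reassociated."""
    return [p + q for p, q in zip(a, b)]

x = [[1, 2], [3, 4]]      # addends of x held by device 0 and device 1
y = [[10, 20], [30, 40]]  # addends of y
xy = [add(a, b) for a, b in zip(x, y)]  # local add on each device, no communication

# add(allreduce(x), allreduce(y)) vs allreduce(add(x, y)): one collective instead of two.
print(add(all_reduce_sum(x), all_reduce_sum(y)), all_reduce_sum(xy))  # [44, 66] [44, 66]

# add(reducescatter(x), reducescatter(y)) vs reducescatter(add(x, y)), per device.
for d in range(2):
    print(add(reduce_scatter_sum(x, d), reduce_scatter_sum(y, d)), reduce_scatter_sum(xy, d))
# [44] [44]
# [66] [66]
```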
AllGather Move Down Pass

Purpose: To reposition all-gather operations for improved efficiency in the computational flow.
Description: This pass targets scenarios where an all-gather operation precedes operations whose loop along the gathered dimension is of parallel type. In such situations, the all-gather operation is moved down past those operations.
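Moving an all-gather below an op that is parallel along the gathered dimension preserves the result because the op can run on each shard independently; the benefit is that the op then processes only the local shard. A numeric check in Python with ReLU (max(v, 0)) as the element-wise op on a 2-device 1-D mesh:

```python
from typing import Callable, List

def all_gather(shards: List[List[float]]) -> List[float]:
    """Concatenate per-device shards (the value each device receives)."""
    return [v for shard in shards for v in shard]

def elementwise(fn: Callable[[float], float], xs: List[float]) -> List[float]:
    return [fn(v) for v in xs]

def relu(v: float) -> float:
    return max(v, 0.0)

shards = [[-1.0, 2.0], [3.0, -4.0]]

# Before: gather first, then apply the op to the full tensor on every device.
before = elementwise(relu, all_gather(shards))
# After moving the all-gather down: each device applies the op to its shard, then gathers.
after = all_gather([elementwise(relu, s) for s in shards])
print(before == after, after)  # True [0.0, 2.0, 3.0, 0.0]
```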
Sharding Mutations
While the logic of the sharding propagation pass is designed for simplicity, it doesn't always yield the optimal outcome. An optional sharding mutation step can be introduced to modify the sharding of the IR. The mutation is usually determined by an analysis of the communication and computation in the current sharded IR.
Sharding Partition
This pass transforms distributed tensors into concrete per-device tensors. Additionally, it converts mesh CCL operations into more concrete CCL ops with device IDs. Depending on the scenario, we can opt for one of two types of result IR:
- The physical weights/arguments are already partitioned across different devices, eliminating the need to retain this information in the IR. In this case,
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])
func.func(%arg0: tensor<8xf32, #mesh.shard<[[0]]>>) -> () attributes { mesh_cluster = @mesh0 } {
"use"(%arg0) ...
...
}
// will be converted to
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])
func.func(%arg0: tensor<4xf32>) -> () attributes { mesh_cluster = @mesh0 } {
"use"(%arg0) ...
...
}
- The physical weights/arguments haven't been partitioned for individual devices, so the actual slice information is needed. In this case,
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])
func.func(%arg0: tensor<8xf32, #mesh.shard<[[0]]>>) -> () attributes { mesh_cluster = @mesh0 } {
"use"(%arg0) ...
...
}
// will be converted to
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])
func.func(%arg0: tensor<8xf32>) -> () attributes { mesh_cluster = @mesh0 } {
%idx = mesh.idx(0)
%c4 = arith.constant 4 : i64
%start = arith.muli %idx, %c4 : i64
%arg0_slice = "mhlo.dynamic_slice"(%arg0, %start) {
slice_sizes = dense<[4]> : tensor<1xi64>
} : (tensor<8xf32>, i64) -> tensor<4xf32>
"use"(%arg0_slice) ...
...
}
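In the second form, each device recomputes its own slice from mesh.idx. The offset arithmetic in that example can be modeled directly; this Python sketch mirrors it for a 1-D mesh and an evenly divisible tensor (both assumptions), with comments pointing at the corresponding IR values. `partition_arg` is a hypothetical name.

```python
from typing import List

def partition_arg(full: List[float], num_devices: int, idx: int) -> List[float]:
    """What device `idx` uses for a tensor<Nxf32, #mesh.shard<[[0]]>> argument
    when the full, unpartitioned value is passed in."""
    n = len(full)
    if n % num_devices != 0:
        raise ValueError("uneven sharding is not modeled here")
    size = n // num_devices          # the constant %c4 in the example (8 / 2)
    start = idx * size               # %start = %idx * %c4
    return full[start:start + size]  # the dynamic_slice with slice_sizes = [4]

arg0 = [float(i) for i in range(8)]
print(partition_arg(arg0, 2, 0))  # [0.0, 1.0, 2.0, 3.0]
print(partition_arg(arg0, 2, 1))  # [4.0, 5.0, 6.0, 7.0]
```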
End2End Walkthrough Example
MLP 1D weight stationary with all sharding annotations on tensors
[2211.05102] Efficiently Scaling Transformer Inference, figure 2(a)
 Original IR
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])
func.func @mlp(%arg0: tensor<2x4x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x8xf32>) -> tensor<2x4x8xf32> attributes { mesh_cluster = @mesh0 } {
%0 = mesh.annotate %arg0 {required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
%1 = "mhlo.dot_general"(%0, %arg1) {
dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [0]>,
precision_config =
[#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
} : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
%2 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
%3 = mhlo.maximum %1, %2 : tensor<2x4x32xf32>
%4 = "mhlo.dot_general"(%3, %arg2) {
dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [0]>,
precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
} : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
%5 = mesh.annotate %4 {required = true, sharding = [[], [], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
%6 = mesh.annotate %5 {as_result = false, required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
return %6 : tensor<2x4x8xf32>
}
 Loop types and indexing maps
%3 = mhlo.maximum %1, %2 : tensor<2x4x32xf32>
loop types: [parallel parallel parallel ]
indexing maps:
(d0, d1, d2) -> (d0, d1, d2)
(d0, d1, d2) -> (d0, d1, d2)
(d0, d1, d2) -> (d0, d1, d2)
%1 = "mhlo.dot_general"(%0, %arg1) {
dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [0]>,
precision_config =
[#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
} : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
loop types: [parallel parallel parallel reduction_sum ]
indexing maps:
(d0, d1, d2, d3) -> (d0, d1, d3)
(d0, d1, d2, d3) -> (d3, d2)
(d0, d1, d2, d3) -> (d0, d1, d2)
%4 = "mhlo.dot_general"(%3, %arg2) {
dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [0]>,
precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
} : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
loop types: [parallel parallel parallel reduction_sum ]
indexing maps:
(d0, d1, d2, d3) -> (d0, d1, d3)
(d0, d1, d2, d3) -> (d3, d2)
(d0, d1, d2, d3) -> (d0, d1, d2)
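For a dot_general without batch dimensions, the loop types and indexing maps above can be derived mechanically from the dot_dimension_numbers. The Python sketch below is hypothetical (the proposal does not show how the interface is implemented for mhlo ops); it assumes the loop order "lhs free dims, rhs free dims, then one loop per contracting pair", which reproduces the listing above.

```python
from typing import List, Sequence, Tuple

def dot_general_loops(lhs_rank: int, rhs_rank: int,
                      lhs_contracting: Sequence[int],
                      rhs_contracting: Sequence[int]) -> Tuple[List[str], List[Tuple[int, ...]]]:
    """Loop types and [lhs, rhs, result] maps; map[k] = loop index of dimension k."""
    lhs_free = [d for d in range(lhs_rank) if d not in lhs_contracting]
    rhs_free = [d for d in range(rhs_rank) if d not in rhs_contracting]
    n_free = len(lhs_free) + len(rhs_free)
    loop_types = ["parallel"] * n_free + ["reduction_sum"] * len(lhs_contracting)

    def operand_map(rank: int, free: List[int], contracting: Sequence[int],
                    first_free_loop: int) -> Tuple[int, ...]:
        loops = []
        for d in range(rank):
            if d in contracting:
                loops.append(n_free + list(contracting).index(d))  # shared contraction loop
            else:
                loops.append(first_free_loop + free.index(d))
        return tuple(loops)

    lhs_map = operand_map(lhs_rank, lhs_free, lhs_contracting, 0)
    rhs_map = operand_map(rhs_rank, rhs_free, rhs_contracting, len(lhs_free))
    result_map = tuple(range(n_free))  # result dims: lhs free dims, then rhs free dims
    return loop_types, [lhs_map, rhs_map, result_map]

def fmt(m: Tuple[int, ...], n_loops: int) -> str:
    """Print a map in the affine-map notation used above."""
    dims = ", ".join(f"d{i}" for i in range(n_loops))
    return f"({dims}) -> (" + ", ".join(f"d{i}" for i in m) + ")"

# dot_general(tensor<2x4x8xf32>, tensor<8x32xf32>), lhs_contracting = [2], rhs_contracting = [0]
types, maps = dot_general_loops(3, 2, [2], [0])
print(types)  # ['parallel', 'parallel', 'parallel', 'reduction_sum']
for m in maps:
    print(fmt(m, len(types)))
# (d0, d1, d2, d3) -> (d0, d1, d3)
# (d0, d1, d2, d3) -> (d3, d2)
# (d0, d1, d2, d3) -> (d0, d1, d2)
```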
 After annotation completion
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])
func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x8xf32>) -> tensor<2x4x8xf32> attributes { mesh_cluster = @mesh0 } {
%0 = mesh.annotate %arg1 {as_result = false, required = false, sharding = [[], [0]]} : tensor<8x32xf32> -> tensor<8x32xf32>
%1 = mesh.annotate %arg2 {as_result = false, required = false, sharding = [[0]]} : tensor<32x8xf32> -> tensor<32x8xf32>
%2 = mesh.annotate %arg0 {required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
%3 = mesh.annotate %2 {as_result = false, required = false, sharding = []} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
%4 = "mhlo.dot_general"(%3, %0) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [0]]} : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
%5 = mesh.annotate %4 {required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
%6 = mesh.annotate %5 {as_result = false, required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
%7 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
%8 = mesh.annotate %7 {required = false, sharding = []} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
%9 = mesh.annotate %8 {as_result = false, required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
%10 = mhlo.maximum %6, %9 {sharding = [[], [], [0]]} : tensor<2x4x32xf32>
%11 = mesh.annotate %10 {required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
%12 = mesh.annotate %11 {as_result = false, required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
%13 = "mhlo.dot_general"(%12, %1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [], [0]]} : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
%14 = mesh.annotate %13 {required = true, sharding = [[], [], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
%15 = mesh.annotate %14 {as_result = false, required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
return %15 : tensor<2x4x8xf32>
}
 After sharding materialization
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])
func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>, %arg1: tensor<8x32xf32, #mesh.shard<[[], [0]]>>, %arg2: tensor<32x8xf32, #mesh.shard<[[0]]>>) -> tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>> attributes { mesh_cluster = @mesh0 } {
%0 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
%1 = mesh.all_gather %arg0 {mesh_axis = [[], [], [0]], tensor_axis = [2]} : tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>> -> tensor<2x4x8xf32>
%2 = "mhlo.dot_general"(%1, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [0]]} : (tensor<2x4x8xf32>, tensor<8x32xf32, #mesh.shard<[[], [0]]>>) -> tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>
%3 = mesh.local_split %0 {sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>
%4 = mhlo.maximum %2, %3 {sharding = [[], [], [0]]} : tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>
%5 = "mhlo.dot_general"(%4, %arg2) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [], [0]]} : (tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>, tensor<32x8xf32, #mesh.shard<[[0]]>>) -> tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0]]>>
%6 = mesh.reduce_scatter %5 {mesh_axis = [0], reduction = "sum", tensor_axis = 2 : i64} : tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0]]>> -> tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>
return %6 : tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>
}
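As a sanity check of the materialized IR above, the plain-Python sketch below simulates it on the two-device mesh @mesh0. Each device holds its shards of %arg0, %arg1 and %arg2, all-gathers the activation, runs both matmuls with the max-with-zero in between, and the resulting partial sums are combined by a sum reduce-scatter. The helpers (matmul, relu, all_gather_cols, reduce_scatter_cols, ...) are hypothetical stand-ins for the mhlo and mesh ops, not part of the framework, and the 2x4 leading dimensions are flattened into 8 rows for brevity. Small integer inputs keep the comparison with the unsharded MLP exact.

```python
# Hypothetical simulation of the materialized @mlp_1d_weight_stationary on a
# 1-D mesh of 2 devices. Plain Python lists; the 2x4 leading dims are
# flattened into 8 rows.

MESH_SIZE = 2
ROWS, D_IN, D_HIDDEN = 8, 8, 32


def matmul(a, b):
    """Dense (m x k) @ (k x n) product on lists of rows."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(a):
    """Elementwise max(x, 0), i.e. mhlo.maximum against a zero constant."""
    return [[max(x, 0) for x in row] for row in a]


def split_cols(a, n):
    """Evenly split a matrix along its last dimension into n shards."""
    w = len(a[0]) // n
    return [[row[d * w:(d + 1) * w] for row in a] for d in range(n)]


def split_rows(a, n):
    """Evenly split a matrix along its first dimension into n shards."""
    h = len(a) // n
    return [a[d * h:(d + 1) * h] for d in range(n)]


def all_gather_cols(shards):
    """Concatenate per-device shards along the last dimension."""
    return [[x for s in shards for x in s[i]] for i in range(len(shards[0]))]


def reduce_scatter_cols(partials, n):
    """Sum partial results elementwise, then hand device d the d-th column shard."""
    rows, cols = len(partials[0]), len(partials[0][0])
    total = [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]
    return split_cols(total, n)


def make(rows, cols, seed):
    """Deterministic small-integer matrix so the comparison is exact."""
    return [[(seed * 7 + i * 3 + j * 5) % 11 - 5 for j in range(cols)] for i in range(rows)]


x = make(ROWS, D_IN, 1)          # %arg0, global view
w1 = make(D_IN, D_HIDDEN, 2)     # %arg1, global view
w2 = make(D_HIDDEN, D_IN, 3)     # %arg2, global view
expected = matmul(relu(matmul(x, w1)), w2)  # unsharded reference

x_shards = split_cols(x, MESH_SIZE)    # #mesh.shard<[[], [], [0]]>
w1_shards = split_cols(w1, MESH_SIZE)  # #mesh.shard<[[], [0]]>
w2_shards = split_rows(w2, MESH_SIZE)  # #mesh.shard<[[0]]>

partials = []
for d in range(MESH_SIZE):
    x_full = all_gather_cols(x_shards)        # %1 = mesh.all_gather
    h = relu(matmul(x_full, w1_shards[d]))    # %2, %3 (local zero shard) and %4
    partials.append(matmul(h, w2_shards[d]))  # %5, a partial sum over mesh axis 0

out_shards = reduce_scatter_cols(partials, MESH_SIZE)  # %6 = mesh.reduce_scatter
result = all_gather_cols(out_shards)  # reassembled only to compare with the reference
print(result == expected)  # True
```

Each device's h covers only its 16 hidden columns, so h @ w2_shards[d] is just one contribution to the final product; summing the contributions is exactly what the partial-sum entry in [[], [], [], [0]] on %5 records.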
After sharding partition:
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])
func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x4xf32>, %arg1: tensor<8x16xf32>, %arg2: tensor<16x8xf32>) -> tensor<2x4x4xf32> attributes {mesh_cluster = @mesh0} {
%0 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
%1 = "mhlo.all_gather"(%arg0) {all_gather_dim = 2 : i64, channel_handle = #mhlo.channel_handle<handle = 0, type = 0>, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<2x4x4xf32>) -> tensor<2x4x8xf32>
%2 = "mhlo.dot_general"(%1, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [0]]} : (tensor<2x4x8xf32>, tensor<8x16xf32>) -> tensor<2x4x16xf32>
%3 = mhlo.constant dense<0.000000e+00> : tensor<2x4x16xf32>
%4 = mhlo.maximum %2, %3 {sharding = [[], [], [0]]} : tensor<2x4x16xf32>
%5 = "mhlo.dot_general"(%4, %arg2) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [], [0]]} : (tensor<2x4x16xf32>, tensor<16x8xf32>) -> tensor<2x4x8xf32>
%6 = "mhlo.reduce_scatter"(%5) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%7 = mhlo.add %arg3, %arg4 : tensor<f32>
mhlo.return %7 : tensor<f32>
}) {channel_handle = #mhlo.channel_handle<handle = 0, type = 0>, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, scatter_dimension = 2 : i64} : (tensor<2x4x8xf32>) -> tensor<2x4x4xf32>
return %6 : tensor<2x4x4xf32>
}
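The per-device shapes in the partitioned function follow mechanically from the shardings in the materialized one: each tensor dimension split along some mesh axes is divided by the product of those axes' sizes, while the extra partial-sum entry at index rank leaves the shape unchanged. A minimal sketch of that rule, assuming every sharded dimension divides evenly (the function name local_shape is hypothetical):

```python
from math import prod


def local_shape(global_shape, sharding, mesh_dims):
    """Per-device shape of a tensor carrying a #mesh.shard-style annotation.

    sharding[i] lists the mesh axes that split dimension i. An extra entry at
    index len(global_shape) marks partial-sum axes and does not change the shape.
    """
    shape = list(global_shape)
    for dim, axes in enumerate(sharding[:len(global_shape)]):
        factor = prod(mesh_dims[a] for a in axes)  # 1 for an unsharded dimension
        if shape[dim] % factor != 0:
            raise ValueError(f"dimension {dim} of size {shape[dim]} is not divisible by {factor}")
        shape[dim] //= factor
    return shape


mesh0 = [2]  # mesh.cluster @mesh0(rank = 1, dim_sizes = [2])
print(local_shape([2, 4, 8], [[], [], [0]], mesh0))      # [2, 4, 4]   %arg0
print(local_shape([8, 32], [[], [0]], mesh0))            # [8, 16]     %arg1
print(local_shape([32, 8], [[0]], mesh0))                # [16, 8]     %arg2
print(local_shape([2, 4, 32], [[], [], [0]], mesh0))     # [2, 4, 16]  %2 and %4
print(local_shape([2, 4, 8], [[], [], [], [0]], mesh0))  # [2, 4, 8]   %5 (partial sum)
```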
MLP 1D weight stationary with sharding option on operations
Original IR:
mesh.cluster @mesh0(rank = 1, dim_sizes = [8])
func.func @mlp_1d_weight_stationary_with_sharding_on_operation(%arg0: tensor<2x4x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x8xf32>) -> tensor<2x4x8xf32> attributes { mesh_cluster = @mesh0 } {
%0 = mesh.annotate %arg0 {required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
%1 = "mhlo.dot_general"(%0, %arg1) {
dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [0]>,
precision_config =
[#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
} : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
%2 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
%3 = mhlo.maximum %1, %2 : tensor<2x4x32xf32>
%4 = "mhlo.dot_general"(%3, %arg2) {
dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [0]>,
precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
sharding = [[], [], [], [0]]
} : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
%6 = mesh.annotate %4 {as_result = false, required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
return %6 : tensor<2x4x8xf32>
}
The result of the pass is the same as in the first example.
Q & A
What are the differences and connections between MeshShardingAttr, mesh.annotate, and operation ShardingOption?
- MeshShardingAttr and mesh.annotate:
  - Purpose: Both represent distributed tensors.
  - Differences: MeshShardingAttr is an optional encoding on RankedTensorType and offers a more concise expression. mesh.annotate, in contrast, introduces an additional operation, which ensures that the sharding information isn’t lost after a pass is executed.
- Operation ShardingOption:
  - This is the sharding annotation of an operation itself, and it describes more precisely how that operation is sharded. Because using it requires a deeper understanding of the operation's computation, it isn’t typically exposed to end users.
Why isn’t the framework built on the upstream TilingInterface?
While the TilingInterface might seem like the natural choice, it introduces certain complications. The interface encodes concrete slices directly into the IR, but this per-slice information isn’t needed during the sharding optimization phase, and carrying it would needlessly complicate the implementation.
For instance, sharding inherently means that every device holds an equal logical partition of the data. Once the computation is expressed through lower-level operations such as tensor.extract_slice/tensor.insert_slice, this even-distribution information could be lost, which runs counter to what sharding is meant to express.
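To illustrate why explicit slices are unnecessary at this stage: with an even split, each device's slice can be recomputed on demand from the sharding, the mesh shape and the device's mesh coordinates, so nothing per-device has to live in the IR. The sketch below is an illustration only; the function name device_slice is hypothetical, and the convention that the first listed mesh axis is the most significant when a dimension is split along several axes is our assumption.

```python
from math import prod


def device_slice(global_shape, sharding, mesh_dims, coord):
    """(offsets, sizes) of the slice owned by the device at mesh coordinate `coord`.

    Assumes every sharded dimension divides evenly and (hypothetically) that,
    when a dimension is split along several mesh axes, the first listed axis
    is the most significant (row-major order). Partial-sum entries at index
    len(global_shape) are ignored because they do not split any dimension.
    """
    offsets, sizes = [], []
    for dim, size in enumerate(global_shape):
        axes = sharding[dim] if dim < len(sharding) else []
        parts = prod(mesh_dims[a] for a in axes)
        index = 0
        for a in axes:  # linearize this device's position along the split axes
            index = index * mesh_dims[a] + coord[a]
        chunk = size // parts
        offsets.append(index * chunk)
        sizes.append(chunk)
    return offsets, sizes


# %arg1: tensor<8x32xf32, #mesh.shard<[[], [0]]>> on dim_sizes = [2]
for d in range(2):
    print(d, device_slice([8, 32], [[], [0]], [2], [d]))
# 0 ([0, 0], [8, 16])
# 1 ([0, 16], [8, 16])
```

Since the sizes are identical on every device and the offsets are a pure function of the device coordinate, the sharding attribute alone carries all the information; materializing extract_slice ops would only restate it per device.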
How do we represent MoE (expert parallelism)?
To represent MoE (expert parallelism), tensors would carry an additional expert dimension. A concept of a “virtual axis” is also needed, which comes into play especially when the number of experts is smaller than the number of devices along the physical axis. The approach is similar to how tensor parallelism is represented.
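The post does not spell out how a virtual axis maps onto devices, so the following is only one possible reading, sketched under our own assumptions: when a physical axis has more devices than there are experts, it is viewed as an expert axis of size num_experts times an inner virtual axis of size num_devices // num_experts, so each expert lives on a contiguous group of devices that can further split that expert's work, much like tensor parallelism. The function expert_placement and its contiguous row-major grouping are hypothetical.

```python
def expert_placement(num_devices, num_experts):
    """Map each device on one physical mesh axis to (expert, virtual_index).

    Hypothetical scheme: the physical axis is reshaped row-major into
    [num_experts, group], where group = num_devices // num_experts devices
    share one expert.
    """
    if num_devices % num_experts != 0:
        raise ValueError("number of devices must be a multiple of the number of experts")
    group = num_devices // num_experts
    return [(d // group, d % group) for d in range(num_devices)]


print(expert_placement(8, 4))
# [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)]
```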
How does the sharding framework compare to XLA’s GSPMD?
- Unified interface: For the majority of operations, only getIndexingMaps and getLoopIteratorTypes need to be implemented.
- Strategy-independent propagation: The propagation phase doesn’t employ any sharding optimization strategy.
- Sharding option at the operation level: Sharding can be specified per operation, making it more convenient for automatic parallelization algorithms to set sharding strategies precisely.
- Explicit communication after propagation: The IR produced by propagation expresses communication explicitly, which facilitates efficient analysis and optimization.
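Why getIndexingMaps and getLoopIteratorTypes can suffice is easiest to see in a small sketch: once an operation's sharding is expressed over its loops, the sharding of every operand and of the result follows from the indexing maps, and a sharded reduction loop becomes a partial-sum entry on the result. Here indexing maps are simplified to tuples of loop indices (projected permutations only), the loop-indexed dict representation is an assumption, and the function names are hypothetical rather than the framework's API. The inputs reproduce the two dot_general operations of the MLP example.

```python
PARALLEL, REDUCTION = "parallel", "reduction"


def tensor_sharding(indexing_map, loop_sharding):
    """Operand sharding: dimension i inherits the mesh axes of loop indexing_map[i]."""
    spec = [list(loop_sharding.get(loop, [])) for loop in indexing_map]
    while spec and not spec[-1]:
        spec.pop()  # drop trailing unsharded dims, as in #mesh.shard<[[0]]>
    return spec


def result_sharding(result_map, iterator_types, loop_sharding):
    """Result sharding; axes on sharded reduction loops become a partial-sum entry."""
    spec = [list(loop_sharding.get(loop, [])) for loop in result_map]
    partial = [a for loop, kind in enumerate(iterator_types)
               if kind == REDUCTION for a in loop_sharding.get(loop, [])]
    if partial:
        spec.append(partial)  # the entry at index == rank marks a partial sum
    while spec and not spec[-1]:
        spec.pop()
    return spec


# Both dot_general ops: loops (d0, d1, d2, d3), with d3 the contraction loop.
ITERATORS = [PARALLEL, PARALLEL, PARALLEL, REDUCTION]
LHS_MAP, RHS_MAP, OUT_MAP = (0, 1, 3), (3, 2), (0, 1, 2)

first = {2: [0]}   # first dot_general: output-column loop d2 on mesh axis 0
print(tensor_sharding(LHS_MAP, first))             # []
print(tensor_sharding(RHS_MAP, first))             # [[], [0]]
print(result_sharding(OUT_MAP, ITERATORS, first))  # [[], [], [0]]

second = {3: [0]}  # second dot_general: contraction loop d3 on mesh axis 0
print(tensor_sharding(LHS_MAP, second))             # [[], [], [0]]
print(tensor_sharding(RHS_MAP, second))             # [[0]]
print(result_sharding(OUT_MAP, ITERATORS, second))  # [[], [], [], [0]]
```

The printed shardings match the operand and result annotations of %2 and %5 in the materialized IR, which is consistent with reading the operation-level sharding lists ([[], [], [0]] and [[], [], [], [0]]) as indexed by loops.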