Consider the example of XLA GSPMD, where a convolution op is sharded along the H dimension (assume its layout is NHWC).
The SPMD-partitioned IR looks like the following:
HloModule module, entry_computation_layout={(f32[128,224,224,3]{3,2,1,0}, f32[7,7,3,64]{3,2,1,0})->f32[128,56,112,64]{3,2,1,0}}
ENTRY %entry_spmd (param: f32[128,224,224,3], param.1: f32[7,7,3,64]) -> f32[128,56,112,64] {
%iota = s32[128,117,224,3]{3,2,1,0} iota(), iota_dimension=1
%constant.6 = s32[2]{0} constant({0, 1})
%partition-id = u32[] partition-id()
%dynamic-slice.2 = s32[1]{0} dynamic-slice(s32[2]{0} %constant.6, u32[] %partition-id), dynamic_slice_sizes={1}
%reshape.1 = s32[] reshape(s32[1]{0} %dynamic-slice.2)
%constant.10 = s32[] constant(112)
%multiply = s32[] multiply(s32[] %reshape.1, s32[] %constant.10)
%broadcast = s32[128,117,224,3]{3,2,1,0} broadcast(s32[] %multiply), dimensions={}
%add = s32[128,117,224,3]{3,2,1,0} add(s32[128,117,224,3]{3,2,1,0} %iota, s32[128,117,224,3]{3,2,1,0} %broadcast)
%constant.13 = s32[] constant(3)
%broadcast.1 = s32[128,117,224,3]{3,2,1,0} broadcast(s32[] %constant.13), dimensions={}
%compare = pred[128,117,224,3]{3,2,1,0} compare(s32[128,117,224,3]{3,2,1,0} %add, s32[128,117,224,3]{3,2,1,0} %broadcast.1), direction=GE
%constant.14 = s32[] constant(227)
%broadcast.2 = s32[128,117,224,3]{3,2,1,0} broadcast(s32[] %constant.14), dimensions={}
%compare.1 = pred[128,117,224,3]{3,2,1,0} compare(s32[128,117,224,3]{3,2,1,0} %add, s32[128,117,224,3]{3,2,1,0} %broadcast.2), direction=LT
%and = pred[128,117,224,3]{3,2,1,0} and(pred[128,117,224,3]{3,2,1,0} %compare, pred[128,117,224,3]{3,2,1,0} %compare.1)
%param = f32[128,224,224,3]{3,2,1,0} parameter(0), sharding={replicated}
%constant = s32[] constant(0)
%constant.1 = s32[2]{0} constant({0, 112})
%dynamic-slice = s32[1]{0} dynamic-slice(s32[2]{0} %constant.1, u32[] %partition-id), dynamic_slice_sizes={1}
%reshape = s32[] reshape(s32[1]{0} %dynamic-slice)
%dynamic-slice.1 = f32[128,112,224,3]{3,2,1,0} dynamic-slice(f32[128,224,224,3]{3,2,1,0} %param, s32[] %constant, s32[] %reshape, s32[] %constant, s32[] %constant), dynamic_slice_sizes={128,112,224,3}
%lhs.copy.1 = f32[128,112,224,3]{3,2,1,0} copy(f32[128,112,224,3]{3,2,1,0} %dynamic-slice.1)
%slice = f32[128,3,224,3]{3,2,1,0} slice(f32[128,112,224,3]{3,2,1,0} %lhs.copy.1), slice={[0:128], [109:112], [0:224], [0:3]}
%collective-permute = f32[128,3,224,3]{3,2,1,0} collective-permute(f32[128,3,224,3]{3,2,1,0} %slice), channel_id=1, source_target_pairs={{0,1}}
%slice.1 = f32[128,2,224,3]{3,2,1,0} slice(f32[128,112,224,3]{3,2,1,0} %lhs.copy.1), slice={[0:128], [0:2], [0:224], [0:3]}
%collective-permute.1 = f32[128,2,224,3]{3,2,1,0} collective-permute(f32[128,2,224,3]{3,2,1,0} %slice.1), channel_id=2, source_target_pairs={{1,0}}
%concatenate = f32[128,117,224,3]{3,2,1,0} concatenate(f32[128,3,224,3]{3,2,1,0} %collective-permute, f32[128,112,224,3]{3,2,1,0} %lhs.copy.1, f32[128,2,224,3]{3,2,1,0} %collective-permute.1), dimensions={1}
%constant.4 = f32[] constant(0)
%broadcast.3 = f32[128,117,224,3]{3,2,1,0} broadcast(f32[] %constant.4), dimensions={}
%select = f32[128,117,224,3]{3,2,1,0} select(pred[128,117,224,3]{3,2,1,0} %and, f32[128,117,224,3]{3,2,1,0} %concatenate, f32[128,117,224,3]{3,2,1,0} %broadcast.3)
%param.1 = f32[7,7,3,64]{3,2,1,0} parameter(1), sharding={replicated}
%rhs.copy.1 = f32[7,7,3,64]{3,2,1,0} copy(f32[7,7,3,64]{3,2,1,0} %param.1)
ROOT %convolution = f32[128,56,112,64]{3,2,1,0} convolution(f32[128,117,224,3]{3,2,1,0} %select, f32[7,7,3,64]{3,2,1,0} %rhs.copy.1), window={size=7x7 stride=2x2 pad=0_0x3_3}, dim_labels=b01f_01io->b01f
}
The corresponding DTensor-style annotated format is:
HloModule module
ENTRY entry {
%lhs = f32[128,224,224,3] parameter(0)
%lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
sharding={devices=[1,2,1,1]0,1}
%rhs = f32[7,7,3,64] parameter(1)
%rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
sharding={replicated}
ROOT %conv = f32[128,112,112,64] convolution(
f32[128,224,224,3] %lhs.copy,
f32[7,7,3,64] %rhs.copy),
window={size=7x7 stride=2x2 pad=3_3x3_3},
dim_labels=b01f_01io->b01f,
sharding={devices=[1,2,1,1]0,1}
}