How to optimize linalg.fill + linalg.generic

Today I wrote a new conversion for the mhlo.pad op (mhlo.pad → linalg.fill + linalg.generic). The input is:

module {
  func @main(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x114x114xf32> {
    // Padding value: 0xFF800000 is the IEEE-754 f32 bit pattern for -inf.
    %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
    // Pad dims 2 and 3 by one element on each side (112 + 1 + 1 = 114);
    // dims 0 and 1 are unchanged and there is no interior padding.
    %1 = "mhlo.pad"(%arg0, %0) {
      edge_padding_high = dense<[0, 0, 1, 1]> : tensor<4xi64>,
      edge_padding_low = dense<[0, 0, 1, 1]> : tensor<4xi64>,
      interior_padding = dense<0> : tensor<4xi64>
    } : (tensor<1x64x112x112xf32>, tensor<f32>) -> tensor<1x64x114x114xf32>
    return %1 : tensor<1x64x114x114xf32>
  }
}

The resulting linalg.fill + linalg.generic IR:

// Identity map for the input operand. The generic's loops therefore run over
// the 1x64x112x112 input shape.
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// Output map: shifts d2 and d3 by +1, so each input element is written inside
// the one-element border of the 1x64x114x114 result.
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + 1, d3 + 1)>
module {
  func @main(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x114x114xf32> {
    // Padding value (f32 bit pattern for -inf).
    %cst = constant 0xFF800000 : f32
    %0 = linalg.init_tensor [1, 64, 114, 114] : tensor<1x64x114x114xf32>
    // Step 1: fill the entire 1x64x114x114 result with the padding value.
    %1 = linalg.fill(%0, %cst) : tensor<1x64x114x114xf32>, f32 -> tensor<1x64x114x114xf32>
    // Step 2: copy the input into the interior of the filled tensor.
    // The interior elements are written a second time here (first by the fill).
    %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x64x112x112xf32>) outs(%1 : tensor<1x64x114x114xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
      // Forward the input element; the current output value %arg2 is ignored.
      linalg.yield %arg1 : f32
    } -> tensor<1x64x114x114xf32>
    return %2 : tensor<1x64x114x114xf32>
  }
}

The final affine IR is as follows:

// Offset by one element: maps an input index to its position in the padded buffer.
#map = affine_map<(d0) -> (d0 + 1)>
module {
  func @main(%arg0: memref<1x64x112x112xf32>) -> memref<1x64x114x114xf32> {
    // Padding value (f32 bit pattern for -inf).
    %cst = constant 0xFF800000 : f32
    %0 = memref.alloc() : memref<1x64x114x114xf32>
    // Lowered linalg.fill: store the padding value into every element of the
    // 1x64x114x114 buffer.
    affine.for %arg1 = 0 to 1 {
      affine.for %arg2 = 0 to 64 {
        affine.for %arg3 = 0 to 114 {
          affine.for %arg4 = 0 to 114 {
            affine.store %cst, %0[%arg1, %arg2, %arg3, %arg4] : memref<1x64x114x114xf32>
          }
        }
      }
    }
    // Lowered linalg.generic: copy each input element to [i, j, k + 1, l + 1].
    // The 112x112 interior of every channel is stored twice: once above, once here.
    affine.for %arg1 = 0 to 1 {
      affine.for %arg2 = 0 to 64 {
        affine.for %arg3 = 0 to 112 {
          affine.for %arg4 = 0 to 112 {
            %1 = affine.load %arg0[%arg1, %arg2, %arg3, %arg4] : memref<1x64x112x112xf32>
            %2 = affine.apply #map(%arg3)
            %3 = affine.apply #map(%arg4)
            affine.store %1, %0[%arg1, %arg2, %2, %3] : memref<1x64x114x114xf32>
          }
        }
      }
    }
    return %0 : memref<1x64x114x114xf32>
  }
}

I think the final affine IR is not good enough. Most elements of the result are written twice: first by the fill loop nest, then again by the loop nest that copies the input. How can I optimize this? I think I should use a single linalg.generic instead of linalg.fill + linalg.generic. However, I don't know how to construct the indexing_maps if I use only one linalg.generic. Please help me, thanks!

Can anyone help me? :sob: