I would just have my own hardware-specific dialect and convert the linalg.*conv*
→ linalg.generic
To be honest, our first preference would be to explore and reuse as much as possible of what is already present and offered by MLIR — which is presumably proven/tested too. If nothing works out, then we might go the suggested route.
“I still believe that this pass is not what you are looking for.”
Exactly — we don't want the conv getting lowered to mulf/addf.
I think I should present it with a correct representative example.
So our seed targeted pattern is
#map0 = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, 0)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
func @main() -> tensor<1x8x8x1xf32> {
%cst = constant 0x7FC00000 : f32
%cst_0 = constant dense<[[[8.405740e-01], [0.107482761], [0.885160744]], [[0.879221558], [0.272046864], [0.219075441]], [[0.853731691], [0.786423742], [0.132054776]]]> : tensor<3x3x1xf32>
%cst_1 = constant dense<[[[[0.61719793], [0.793435096], [0.121170968], [0.573920846], [0.445487499], [0.226183072], [0.973448157], [0.851443469], [0.565707207], [5.227910e-01]], [[0.00731075834], [0.557578623], [0.707973182], [0.959501981], [0.358185142], [0.699126303], [0.466926485], [0.413297445], [0.0725673139], [0.921178698]], [[0.30189532], [0.731407762], [0.370444775], [0.848782122], [0.871007978], [0.748299241], [0.102783702], [0.551843107], [0.816904246], [0.0332700573]], [[0.343397468], [0.982405781], [0.0607045554], [0.541470528], [0.726823389], [0.808600127], [0.256532729], [0.252898484], [0.110048898], [0.478431761]], [[0.979851245], [0.111159958], [0.724086046], [0.982171118], [0.211429045], [0.678284585], [0.0563224852], [0.837513744], [0.657312452], [0.536515653]], [[0.505726278], [0.648696065], [0.999981284], [0.00183360698], [0.745425224], [0.40943867], [0.333478332], [0.644900322], [0.19442305], [0.178031594]], [[0.681895732], [0.0306128804], [0.390298098], [0.893880724], [0.859292924], [0.445982367], [0.303117335], [0.769601822], [0.510563135], [0.387194842]], [[0.426417202], [0.49324289], [0.107445173], [0.790808141], [0.891636371], [0.0934373661], [0.42853874], [0.835353791], [0.0698968098], [0.611316978]], [[0.999428808], [0.938870192], [0.778351902], [0.973116576], [0.0702748299], [0.479627848], [0.455716878], [0.779465734], [0.614541173], [0.207848176]], [[0.847904145], [0.130607277], [0.678844749], [0.03101651], [0.964369654], [0.600809455], [0.0234705787], [0.356526107], [0.636354804], [0.281212419]]]]> : tensor<1x10x10x1xf32>
%cst_2 = constant 0.000000e+00 : f32
%0 = linalg.init_tensor [1, 8, 8, 1] : tensor<1x8x8x1xf32>
%1 = linalg.fill(%cst_2, %0) : f32, tensor<1x8x8x1xf32> -> tensor<1x8x8x1xf32>
%2 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%cst_1, %cst_0 : tensor<1x10x10x1xf32>, tensor<3x3x1xf32>) outs(%1 : tensor<1x8x8x1xf32>) -> tensor<1x8x8x1xf32>
%3 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<1x8x8x1xf32>) outs(%0 : tensor<1x8x8x1xf32>) {
^bb0(%arg0: f32, %arg1: f32): // no predecessors
%4 = cmpf ogt, %cst_2, %arg0 : f32
%5 = select %4, %cst_2, %arg0 : f32
%6 = cmpf uno, %cst_2, %arg0 : f32
%7 = select %6, %cst, %5 : f32
linalg.yield %7 : f32
} -> tensor<1x8x8x1xf32>
return %3 : tensor<1x8x8x1xf32>
}
}
wherein, on applying the LinalgFusionOfTensorOps pass, we expect
%2 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc
and
%3 = linalg.generic { %4 = cmpf ogt, %cst_2, … }
both of them to end up inside a single linalg region — or, put another way, %3, which houses the cmp/select ops, would additionally house the linalg.depthwise_conv_2d_input_nhwc_filter_hwc op.
Our hypothetical output would be something like this:
#map0 = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, 0)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map2 = affine_map<(d0, d1, d2 ) -> (d0, d1, d2 )>
module {
func @main() -> tensor<1x8x8x1xf32> {
%cst = constant 0x7FC00000 : f32
%cst_0 = constant dense<[[[8.405740e-01], [0.107482761], [0.885160744]], [[0.879221558], [0.272046864], [0.219075441]], [[0.853731691], [0.786423742], [0.132054776]]]> : tensor<3x3x1xf32>
%cst_1 = constant dense<[[[[0.61719793], [0.793435096], [0.121170968], [0.573920846], [0.445487499], [0.226183072], [0.973448157], [0.851443469], [0.565707207], [5.227910e-01]], [[0.00731075834], [0.557578623], [0.707973182], [0.959501981], [0.358185142], [0.699126303], [0.466926485], [0.413297445], [0.0725673139], [0.921178698]], [[0.30189532], [0.731407762], [0.370444775], [0.848782122], [0.871007978], [0.748299241], [0.102783702], [0.551843107], [0.816904246], [0.0332700573]], [[0.343397468], [0.982405781], [0.0607045554], [0.541470528], [0.726823389], [0.808600127], [0.256532729], [0.252898484], [0.110048898], [0.478431761]], [[0.979851245], [0.111159958], [0.724086046], [0.982171118], [0.211429045], [0.678284585], [0.0563224852], [0.837513744], [0.657312452], [0.536515653]], [[0.505726278], [0.648696065], [0.999981284], [0.00183360698], [0.745425224], [0.40943867], [0.333478332], [0.644900322], [0.19442305], [0.178031594]], [[0.681895732], [0.0306128804], [0.390298098], [0.893880724], [0.859292924], [0.445982367], [0.303117335], [0.769601822], [0.510563135], [0.387194842]], [[0.426417202], [0.49324289], [0.107445173], [0.790808141], [0.891636371], [0.0934373661], [0.42853874], [0.835353791], [0.0698968098], [0.611316978]], [[0.999428808], [0.938870192], [0.778351902], [0.973116576], [0.0702748299], [0.479627848], [0.455716878], [0.779465734], [0.614541173], [0.207848176]], [[0.847904145], [0.130607277], [0.678844749], [0.03101651], [0.964369654], [0.600809455], [0.0234705787], [0.356526107], [0.636354804], [0.281212419]]]]> : tensor<1x10x10x1xf32>
%cst_2 = constant 0.000000e+00 : f32
%0 = linalg.init_tensor [1, 8, 8, 1] : tensor<1x8x8x1xf32>
%1 = linalg.fill(%cst_2, %0) : f32, tensor<1x8x8x1xf32> -> tensor<1x8x8x1xf32>
%2 = linalg.complex {indexing_maps = [#map0, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_1, %cst_0 : tensor<1x10x10x1xf32>, tensor<3x3x1xf32> ) outs(%1 : tensor<1x8x8x1xf32>) {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors
%3 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%cst_1, %cst_0 : tensor<1x10x10x1xf32>, tensor<3x3x1xf32>) outs(%1 : tensor<1x8x8x1xf32>) -> tensor<1x8x8x1xf32>
%4 = linalg.relu %3
linalg.yield %4 : f32
} -> tensor<1x8x8x1xf32>
return %2 : tensor<1x8x8x1xf32>
}
}
%2 = linalg.complex {
%3 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc
%4 = linalg.relu %3
}
So our so-called linalg.complex is an op similar to linalg.generic, but it has neither indexing_maps
nor iterator_types. (Right now it carries indexing_maps = [#map0, #map2, #map1] and iterator_types = ["parallel", "parallel", "parallel", "parallel"], but if needed we would prefer to strip them.)
And it is region-based, now housing linalg.depthwise_conv_2d_input_nhwc_filter_hwc, whose output is consumed by the relu op (and, by the way, the conv output is not a scalar — essentially nothing here is a scalar).
Further, linalg.relu is equivalent to the elementwise compare & select ops below, or to mhlo.maximum.
%4 = cmpf ogt, %cst_2, %arg0 : f32
%5 = select %4, %cst_2, %arg0 : f32
%6 = cmpf uno, %cst_2, %arg0 : f32
%7 = select %6, %cst, %5 : f32
The motive is to get linalg.depthwise_conv_2d_input_nhwc_filter_hwc and relu inside the same region.