While manipulating memrefs, I have discovered two limitations that seem quite severe. The first one concerns memref.cast, which does not allow “safe” casting between fixed shapes of equal “volume” (the same total number of elements). For instance, the following is rejected:
%14 = memref.cast %0 : memref<1x1x512x2048xf32> to memref<512x2048xf32>
Is there a reason for this?
The only way I found to express it was to go through unranked (and unsafe) memrefs:
%14 = memref.cast %0 : memref<1x1x512x2048xf32> to memref<*xf32>
%15 = memref.cast %14 : memref<*xf32> to memref<512x2048xf32>
The second problem is that even this workaround is rejected during canonicalization. For instance, the following module is rejected:
#map5 = affine_map<(d0) -> (d0)>
#map6 = affine_map<(d0) -> (d0 floordiv 64)>
#map7 = affine_map<(d0) -> (d0 floordiv 8)>
#map8 = affine_map<(d0) -> (d0 floordiv 7)>
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 561 : i32}} {
  memref.global "private" constant @__constant_3x3x512x512xf32 : memref<3x3x512x512xf32> = dense<1.000000e+00>
  memref.global "private" constant @__constant_1x1x512x2048xf32 : memref<1x1x512x2048xf32> = dense<1.000000e+00>
  func @resnet(%arg0: memref<1x224x224x3xf32>, %arg1: memref<1x7x7x512xf32>, %arg2: memref<1x7x7x2048xf32>, %arg3: memref<1x7x7x512xf32>) {
    %0 = memref.get_global @__constant_1x1x512x2048xf32 : memref<1x1x512x2048xf32>
    %2 = memref.get_global @__constant_3x3x512x512xf32 : memref<3x3x512x512xf32>
    %14 = memref.cast %0 : memref<1x1x512x2048xf32> to memref<*xf32>
    %15 = memref.cast %14 : memref<*xf32> to memref<512x2048xf32>
    // %15 = memref.alloc(): memref<512x2048xf32>
    %c0 = constant 0 : index
    %c2048 = constant 2048 : index
    %c512 = constant 512 : index
    %18 = memref.alloca() : memref<64x8xf32>
    affine.for %arg4 = #map5(%c0) to #map6(%c2048) {
      affine.for %arg5 = #map5(%c0) to #map7(%c512) {
        affine.for %arg6 = 0 to 64 {
          affine.for %arg7 = 0 to 8 {
            %19 = affine.load %15[%arg6 + %arg4 * 64, %arg7 + %arg5 * 8] : memref<512x2048xf32>
            affine.store %19, %18[%arg6, %arg7] : memref<64x8xf32>
          }
        }
      }
    }
    return
  }
}
I don’t understand the exact reason for these limitations.
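For reference, here is a minimal sketch of how the same rank-reducing reshape might be written without memref.cast. It assumes that the MLIR build in use provides memref.collapse_shape and memref.reinterpret_cast, which I have not checked against the version that produced the errors above. The function names @collapse_sketch and @reinterpret_sketch are hypothetical; the types are the ones from the example.

```mlir
// Assumption: memref.collapse_shape is available in this MLIR version.
// Source dimensions 0, 1 and 2 (1 x 1 x 512) are merged into result
// dimension 0, and source dimension 3 (2048) becomes result dimension 1.
func @collapse_sketch(%0: memref<1x1x512x2048xf32>) -> memref<512x2048xf32> {
  %15 = memref.collapse_shape %0 [[0, 1, 2], [3]]
      : memref<1x1x512x2048xf32> into memref<512x2048xf32>
  return %15 : memref<512x2048xf32>
}

// Assumption: memref.reinterpret_cast is available in this MLIR version.
// The offset, sizes and strides are spelled out explicitly; strides
// [2048, 1] correspond to the row-major layout of a 512x2048 buffer.
func @reinterpret_sketch(%0: memref<1x1x512x2048xf32>) -> memref<512x2048xf32> {
  %15 = memref.reinterpret_cast %0 to
      offset: [0], sizes: [512, 2048], strides: [2048, 1]
      : memref<1x1x512x2048xf32> to memref<512x2048xf32>
  return %15 : memref<512x2048xf32>
}
```

Both variants keep the result ranked, so the unranked memref<*xf32> intermediate is not needed. Whether either of them passes canonicalization in the setting above is something I have not verified.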