How to write into a memref with a stride?

Hi,

I want to write into a memref with a stride using the ‘memref’ and ‘vector’ dialects, as shown below. However, vector.store does not seem to support strided writes, and I couldn’t find a similar example for vector.transfer_write in test-transfer-write.mlir. Could someone give me some advice?

func.func private @printMemrefF32(memref<*xf32>)

func.func @main() {
    %c0 = arith.constant 0 : index
    %f0 = arith.constant 0.0 : f32
    %f1 = arith.constant 1.0 : f32

    %A = memref.alloc() : memref<4x1xf32>
    %C = memref.alloc() : memref<4x4xf32>

    linalg.fill ins(%f1 : f32) outs(%A : memref<4x1xf32>)
    linalg.fill ins(%f0 : f32) outs(%C : memref<4x4xf32>)

    %a_view = memref.subview %A[0, 0][4, 1][1, 1] : memref<4x1xf32> to memref<4x1xf32, strided<[1, 1]>>
    %c_view = memref.subview %C[0, 0][4, 1][1, 1] : memref<4x4xf32> to memref<4x1xf32, strided<[1, 1]>>
    %a = vector.load %a_view[%c0, %c0] : memref<4x1xf32, strided<[1, 1]>>, vector<4xf32>
    %b = vector.splat %f1 : vector<4xf32>
    %c = vector.load %c_view[%c0, %c0] : memref<4x1xf32, strided<[1, 1]>>, vector<4xf32>

    %result = vector.fma %a, %b, %c : vector<4xf32>

    vector.store %result, %C[%c0, %c0] : memref<4x4xf32>, vector<4xf32>

    %C_ = memref.cast %C : memref<4x4xf32> to memref<*xf32>
    call @printMemrefF32(%C_) : (memref<*xf32>) -> ()

    memref.dealloc %A : memref<4x1xf32>
    memref.dealloc %C : memref<4x4xf32>
    return 
}

The output of the above program is:

[[2,   2,   2,   2], 
 [1,   1,   1,   1], 
 [1,   1,   1,   1], 
 [1,   1,   1,   1]]

But what I want is:

[[2,   1,   1,   1], 
 [2,   1,   1,   1], 
 [2,   1,   1,   1], 
 [2,   1,   1,   1]]

There are several issues:

Your stride specification in the memref type looks incorrect. Strides in the type are specified as the number of elements in linear storage that one has to step over to get to the next element along a dimension. So memref<4x1xf32, strided<[1, 1]>> means the first index of the view iterates over consecutive storage elements, that is, over the first row. For %c_view, which is meant to be a column of the row-major 4x4 %C, consecutive rows are 4 elements apart, so you’d need something like strided<[4, 1]>.

The store goes directly to %C, not into a view, so it uses the default row-major layout and writes into the first row.
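A minimal sketch of both points, assuming the row-major memref<4x4xf32> %C from the question; @store_column0 is a hypothetical helper name, not something from this thread. It builds a column view with the strides described above and stores through that view. It uses vector.transfer_write with a permutation_map instead of vector.store, because vector.store writes a 1-D vector along the innermost dimension, which has size 1 in this view. The lowering note in the comments is an assumption to check against your MLIR version.

```mlir
// Hypothetical helper (not from the thread): writes %v into column 0 of a
// row-major 4x4 buffer. Lowering assumption: run -convert-vector-to-scf
// before -convert-vector-to-llvm so this transfer becomes a scalar loop.
func.func @store_column0(%C: memref<4x4xf32>, %v: vector<4xf32>) {
  %c0 = arith.constant 0 : index
  // View of column 0. One step along the first index skips a whole row of
  // %C, i.e. 4 elements of linear storage, hence strided<[4, 1]>.
  %col = memref.subview %C[0, 0] [4, 1] [1, 1]
      : memref<4x4xf32> to memref<4x1xf32, strided<[4, 1]>>
  // Store through the view, not through %C. The permutation map sends the
  // single vector dimension to d0 (the rows), so %v[i] lands in %col[i, 0],
  // which aliases %C[i, 0].
  vector.transfer_write %v, %col[%c0, %c0]
      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
      : vector<4xf32>, memref<4x1xf32, strided<[4, 1]>>
  return
}
```

In the question's program, replacing the vector.store line with `call @store_column0(%C, %result) : (memref<4x4xf32>, vector<4xf32>) -> ()` would target the first column of %C instead of the first row.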

Thanks for your reply and correction. I removed the unnecessary views to make my example simpler.

func.func private @printMemrefF32(memref<*xf32>)

func.func @main() {
    %c0 = arith.constant 0 : index
    %f1 = arith.constant 1.0 : f32

    %A = memref.alloc() : memref<4x1xf32>
    %C = memref.alloc() : memref<4x4xf32>

    linalg.fill ins(%f1 : f32) outs(%A : memref<4x1xf32>)
    linalg.fill ins(%f1 : f32) outs(%C : memref<4x4xf32>)

    %a = vector.load %A[%c0, %c0] : memref<4x1xf32>, vector<4xf32>
    %b = vector.splat %f1 : vector<4xf32>
    %c = vector.load %C[%c0, %c0] : memref<4x4xf32>, vector<4xf32>

    %result = vector.fma %a, %b, %c : vector<4xf32>

    vector.store %result, %C[%c0, %c0] : memref<4x4xf32>, vector<4xf32>

    %C_ = memref.cast %C : memref<4x4xf32> to memref<*xf32>
    call @printMemrefF32(%C_) : (memref<*xf32>) -> ()

    memref.dealloc %A : memref<4x1xf32>
    memref.dealloc %C : memref<4x4xf32>
    return 
}

This still writes into %C, the memref with the default row-major layout, so naturally it writes the first row. To write a column, you need to either create a properly strided view or use memref.transpose.
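A sketch of the memref.transpose route applied to the second program, offered as one possible approach rather than a verified recipe. %Ct is a transposed alias of %C, so row 0 of %Ct is column 0 of %C, and a transfer along its innermost dimension touches C[0..3, 0]. Two changes are my own additions: %c is read through the same transposed view, and %A is read through memref.collapse_shape, because loading a vector<4xf32> from memref<4x1xf32> reads past the size-1 innermost dimension. As before, the assumption is that -convert-vector-to-scf runs before -convert-vector-to-llvm, since the innermost stride of %Ct is 4, not 1.

```mlir
func.func private @printMemrefF32(memref<*xf32>)

// Assumed lowering: -convert-vector-to-scf before -convert-vector-to-llvm,
// so transfers along the stride-4 dimension become scalar loops.
func.func @main() {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32
  %f1 = arith.constant 1.0 : f32

  %A = memref.alloc() : memref<4x1xf32>
  %C = memref.alloc() : memref<4x4xf32>
  linalg.fill ins(%f1 : f32) outs(%A : memref<4x1xf32>)
  linalg.fill ins(%f1 : f32) outs(%C : memref<4x4xf32>)

  // %Ct[i, j] aliases %C[j, i]: same buffer, strides swapped to [1, 4].
  %Ct = memref.transpose %C (i, j) -> (j, i)
      : memref<4x4xf32> to memref<4x4xf32, strided<[1, 4]>>

  // %A's four elements are contiguous; view them as a 1-D memref.
  %A1d = memref.collapse_shape %A [[0, 1]] : memref<4x1xf32> into memref<4xf32>
  %a = vector.load %A1d[%c0] : memref<4xf32>, vector<4xf32>
  %b = vector.splat %f1 : vector<4xf32>
  // Row 0 of %Ct is column 0 of %C.
  %c = vector.transfer_read %Ct[%c0, %c0], %f0 {in_bounds = [true]}
      : memref<4x4xf32, strided<[1, 4]>>, vector<4xf32>

  %result = vector.fma %a, %b, %c : vector<4xf32>

  // Writes %result[i] to %Ct[0, i], i.e. to %C[i, 0].
  vector.transfer_write %result, %Ct[%c0, %c0] {in_bounds = [true]}
      : vector<4xf32>, memref<4x4xf32, strided<[1, 4]>>

  %C_ = memref.cast %C : memref<4x4xf32> to memref<*xf32>
  call @printMemrefF32(%C_) : (memref<*xf32>) -> ()

  memref.dealloc %A : memref<4x1xf32>
  memref.dealloc %C : memref<4x4xf32>
  return
}
```

If this lowers in your pipeline, the first column should become 1 * 1 + 1 = 2 while the rest of %C stays 1, which matches the desired output above. An alternative without a transposed view is a vector.transfer_write directly on %C with permutation_map = affine_map<(d0, d1) -> (d0)>.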