Memref.tensor_store vs bufferization.to_memref + memref.copy

Hi everyone, my goal is to use the execution engine to JIT-compile, from C++, MLIR kernels generated by a DSL.
To do so, I plan to have a wrapper that takes in the raw pointers, converts them to tensors, and finally calls a kernel. For instance:

func @wrap_8(%a: f32, %x: memref<8xf32>, %y: memref<8xf32>, %r: memref<8xf32>) {
  %tx = bufferization.to_tensor %x : memref<8xf32>
  %ty = bufferization.to_tensor %y : memref<8xf32>
  %cast_tx = tensor.cast %tx : tensor<8xf32> to tensor<?xf32>
  %cast_ty = tensor.cast %ty : tensor<8xf32> to tensor<?xf32>
  %ta = tensor.from_elements %a : tensor<f32>
  %ret = call @daxpy(%ta, %cast_tx, %cast_ty): (tensor<f32>, tensor<?xf32>, tensor<?xf32> )->tensor<?xf32>
  %cast_ret = tensor.cast %ret : tensor<?xf32> to tensor<8xf32>
  // this works
  %newref = bufferization.to_memref %cast_ret : memref<8xf32>
  memref.copy %newref, %r : memref<8xf32> to memref<8xf32>
  // I cannot lower this:
  // memref.tensor_store %cast_ret, %r : memref<8xf32>
  return
}

The above works, but I was expecting to be able to use memref.tensor_store instead of bufferization.to_memref followed by memref.copy to store the result generated by the kernel.

Using memref.tensor_store produces valid MLIR, but I have not been able to then lower it to the LLVM dialect. Any pointers on how to do that would be much appreciated!

For now I manually wrote a pass to remove it…
I also wonder whether there is a better solution.
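Concretely, my pass just rewrites each tensor_store into the equivalent to_memref + copy pair from the working version above, roughly (a sketch; %t and %dest are illustrative names):

```mlir
// Before: store a tensor value into a pre-allocated memref.
memref.tensor_store %t, %dest : memref<8xf32>

// After: materialize a buffer for the tensor, then copy it over.
%buf = bufferization.to_memref %t : memref<8xf32>
memref.copy %buf, %dest : memref<8xf32> to memref<8xf32>
```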
