Proper way to use memref.transpose

I have a custom lowering for a custom transpose op. The op is special in that the last dimension is kept as is, while all the other dimensions are reversed.

mlir::Value createMemRefTransposeOp(mlir::PatternRewriter &rewriter,
                                    mlir::Location loc, mlir::Value tensor,
                                    mlir::OpResult result) {
  std::vector<unsigned int> perms = {};
  auto n_dim = tensor.getType().cast<MemRefType>().getShape().size();
  // reverse all dimensions except the last one
  for (int i = n_dim - 2; i >= 0; i--)
    perms.push_back(i);
  perms.push_back(n_dim - 1);
  AffineMapAttr perm = AffineMapAttr::get(
      AffineMap::getPermutationMap(perms, rewriter.getContext()));
  mlir::memref::TransposeOp transposeOp =
      rewriter.create<mlir::memref::TransposeOp>(loc, tensor, perm);
  // we allocate a new buffer and copy the strided memref into it to produce
  // a totally new transposed memref
  mlir::memref::AllocOp allocOp = rewriter.create<mlir::memref::AllocOp>(
      loc, result.getType().cast<MemRefType>());
  rewriter.create<mlir::memref::CopyOp>(loc, transposeOp.getResult(),
                                        allocOp.getResult());
  return allocOp.getResult();
}
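To double-check which permutation the loop above builds, here is a plain C++ sketch of the same logic with the MLIR types stripped out. The helper name keepLastPermutation is hypothetical and is not part of MLIR. For a rank-3 memref it yields {1, 0, 2}, i.e. (d0, d1, d2) -> (d1, d0, d2), which is the map used in the IR example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the loop in createMemRefTransposeOp:
// reverse every dimension except the innermost one, which stays in place.
std::vector<unsigned> keepLastPermutation(std::size_t rank) {
  assert(rank >= 1 && "rank must be at least 1");
  std::vector<unsigned> perm;
  perm.reserve(rank);
  // Leading dimensions in reverse order: rank-2, rank-3, ..., 0.
  for (std::size_t i = rank - 1; i-- > 0;)
    perm.push_back(static_cast<unsigned>(i));
  // The innermost dimension is kept as is.
  perm.push_back(static_cast<unsigned>(rank - 1));
  return perm;
}
```

For rank 4 the same helper gives {2, 1, 0, 3}, so only the leading dimensions are affected.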

Here is an example of the IR produced after the pass:

#map = affine_map<(d0, d1, d2) -> (d1 * 8194 + d0 * 4097 + d2)>
module  {
  func @main(%arg0: memref<3x2x4097xi64>) -> memref<2x3x4097xi64> {
    %0 = memref.transpose %arg0 (d0, d1, d2) -> (d1, d0, d2) : memref<3x2x4097xi64> to memref<2x3x4097xi64, #map>
    %1 = memref.alloc() : memref<2x3x4097xi64>
    memref.copy %0, %1 : memref<2x3x4097xi64, #map> to memref<2x3x4097xi64>
    return %1 : memref<2x3x4097xi64>
  }
}
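For reference, the #map layout can be read off as explicit strides. The sketch below assumes %arg0 has the identity (row-major) layout and that result dimension i of memref.transpose takes the size and stride of source dimension perm[i]; the helper names rowMajorStrides and transposeLayout are hypothetical. Under those assumptions the result memref<2x3x4097xi64, #map> has strides [4097, 8194, 1], consistent with the affine map d1 * 8194 + d0 * 4097 + d2.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Strides (in elements) of a contiguous row-major memref.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (std::size_t i = shape.size(); i-- > 1;)
    strides[i - 1] = strides[i] * shape[i];
  return strides;
}

struct StridedLayout {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
};

// Assumption: result dimension i reads source dimension perm[i], so it
// inherits that dimension's size and stride (no data is moved).
StridedLayout transposeLayout(const std::vector<int64_t>& srcShape,
                              const std::vector<unsigned>& perm) {
  assert(srcShape.size() == perm.size());
  std::vector<int64_t> srcStrides = rowMajorStrides(srcShape);
  StridedLayout out;
  for (unsigned p : perm) {
    out.shape.push_back(srcShape[p]);
    out.strides.push_back(srcStrides[p]);
  }
  return out;
}
```

Note that the innermost stride stays 1 because the last dimension is not permuted.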

When lowered and executed, this IR segfaults. After investigating, I found that the main cause is the strides of the source memref (%0) in the copy: they are set to [4097, 8194, 8194], whereas in my understanding they should be [4097, 8194, 1]. With a stride of 8194 on the innermost dimension, the copy computes indices outside the memory region of the source memref.
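To make the out-of-bounds claim concrete: the largest offset a strided view can address is the sum of (size - 1) * stride over its dimensions. The sketch below uses hypothetical helper names and assumes a zero base offset, non-negative strides and no zero-sized dimensions. With strides [4097, 8194, 1] the copy stays within the 3 * 2 * 4097 = 24582 elements of %arg0 (maximum offset 24581), whereas with [4097, 8194, 8194] it reaches offset 33583109.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Largest linear offset (in elements) that a strided view can touch,
// assuming a zero base offset, non-negative strides and non-zero sizes.
int64_t maxOffset(const std::vector<int64_t>& shape,
                  const std::vector<int64_t>& strides) {
  assert(shape.size() == strides.size());
  int64_t offset = 0;
  for (std::size_t i = 0; i < shape.size(); ++i)
    offset += (shape[i] - 1) * strides[i];
  return offset;
}

// Number of elements in a buffer of the given shape.
int64_t numElements(const std::vector<int64_t>& shape) {
  int64_t n = 1;
  for (int64_t d : shape) n *= d;
  return n;
}
```

Any offset greater than numElements(shape) - 1 lies past the end of the source allocation.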

Before I start treating the wrongly set strides as a bug, I want to make sure that I'm using memref.transpose as intended.