How to generate Conv2d in NHWC format

Hi, I'm trying to get torch-mlir to generate a conv2d in NHWC format.

The code I used:

```python
import torch
import torchvision
import torch_mlir

net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
net = net.to(memory_format=torch.channels_last)  # convert parameters to channels_last
a = torch.zeros([1, 3, 256, 256])
a = a.to(memory_format=torch.channels_last)      # convert the example input as well
net.train(False)                                 # inference mode, same as net.eval()
module = torch_mlir.compile(net, a, output_type="linalg-on-tensors")
with open("res18.mlir", "wb") as f:
    f.write(module.operation.get_asm(large_elements_limit=0).encode())
```

However, setting memory_format=torch.channels_last does not seem to have any effect in torch-mlir: the output still contains linalg.conv_2d_nchw_fchw rather than an NHWC variant such as linalg.conv_2d_nhwc_hwcf.
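A likely explanation, offered as a hedged sketch rather than a statement about torch-mlir internals: torch.channels_last only changes the strides (the physical memory layout) of a tensor, while its logical shape stays N, C, H, W. A compiler that works from the logical shapes of the traced ops would therefore still see an NCHW convolution. The snippet below uses only plain PyTorch (a small stand-in nn.Conv2d instead of the full ResNet-18, to avoid downloading weights) to show that shapes are unchanged, only strides differ, and that the results are numerically identical in both layouts.

```python
import torch
import torch.nn as nn

# channels_last changes only how the data is laid out in memory (the strides);
# the logical shape seen by tracers and compilers is still N, C, H, W.
x = torch.zeros(1, 3, 256, 256)
x_cl = x.to(memory_format=torch.channels_last)

print(x.shape, x.stride())        # torch.Size([1, 3, 256, 256]) (196608, 65536, 256, 1)
print(x_cl.shape, x_cl.stride())  # torch.Size([1, 3, 256, 256]) (196608, 1, 768, 3)
print(x_cl.is_contiguous(memory_format=torch.channels_last))  # True

# The same holds for module parameters: the conv weight keeps its
# (out_channels, in_channels, kH, kW) shape after conversion.
conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
conv = conv.to(memory_format=torch.channels_last)
print(conv.weight.shape)  # torch.Size([64, 3, 7, 7])

# The math is the same in both layouts, so the captured op is still
# described in terms of NCHW input and OIHW weight shapes.
torch.manual_seed(0)
inp = torch.randn(1, 3, 32, 32)
out_nchw = conv(inp.contiguous())
out_cl = conv(inp.to(memory_format=torch.channels_last))
print(torch.allclose(out_nchw, out_cl, atol=1e-4))  # True

# A tensor whose *logical* shape is NHWC requires an explicit permute.
x_nhwc = x.permute(0, 2, 3, 1)
print(x_nhwc.shape)  # torch.Size([1, 256, 256, 3])
```

In other words, getting an NHWC convolution out of the compiler probably requires a layout change at the graph or MLIR level rather than a PyTorch memory-format hint; whether torch-mlir offers such an option is not something this sketch establishes.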