Write a pass to change data layout from NCHW to NHWC

Hi,
I want to write a pass that translates all my tensor types from NCHW to NHWC. I want to use a TypeConverter to do this, and my code is below, but it didn't work. Does anyone have any suggestions? Thanks a lot.

template <>
LogicalResult ConvertOnnxOp<ONNXConvOp>::matchAndRewrite(ONNXConvOp op,
    OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
  mlir::Value X = adaptor.X();
  mlir::Value W = adaptor.W();
  mlir::Value B = adaptor.B();
  mlir::ArrayAttr padsAttr = adaptor.padsAttr();
  mlir::ArrayAttr stridesAttr = adaptor.stridesAttr();
  mlir::ArrayAttr dilationsAttr = adaptor.dilationsAttr();

  auto XTy = X.getType().dyn_cast<TensorType>();
  auto WTy = W.getType().dyn_cast<TensorType>();

  if (!XTy || !WTy)
    return op.emitError("Only Tensor types supported in TOSA");
  
  if(!padsAttr)
    padsAttr = rewriter.getI64ArrayAttr({0, 0, 0, 0});

  if(!stridesAttr)
    stridesAttr = rewriter.getI64ArrayAttr({1, 1});

  if(!dilationsAttr)
    dilationsAttr = rewriter.getI64ArrayAttr({1, 1});

  if (B.getType().isa<NoneType>()) {
    // Materialize an all-zero bias of shape {OC}, where OC is the number of
    // output channels (the leading dimension of the weights).
    int64_t OC = WTy.getShape()[0];
    SmallVector<int64_t> BArray(OC, 0);
    auto B_type = RankedTensorType::get({OC}, rewriter.getI64Type());
    auto BAttr =
        DenseElementsAttr::get(B_type, llvm::ArrayRef<int64_t>(BArray));
    B = rewriter.create<tosa::ConstOp>(op->getLoc(), B_type, BAttr);
  }

  rewriter.replaceOpWithNewOp<tosa::Conv2DOp>(
      op, op.getType(), X, W, B, padsAttr, stridesAttr, dilationsAttr);
  return success();
}


// NHWC dimension i comes from NCHW dimension NCHW2NHWC[i].
int32_t NCHW2NHWC[4] = {0, 2, 3, 1};
void FormatLoweringPass::runOnOperation() {
  MLIRContext *context = &getContext();
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);

  TypeConverter typeConverter;

  typeConverter.addConversion(
        [](RankedTensorType type)-> Type {
          if (type.getRank() != 4) return type;
          auto shape = type.getShape();

          SmallVector<int64_t> new_shape;

          for (size_t i = 0; i < shape.size(); ++i) {
            new_shape.push_back(shape[NCHW2NHWC[i]]);
          }

          return  RankedTensorType::get(new_shape, type.getElementType());
        });
 
  target.addLegalDialect<tosa::TosaDialect, func::FuncDialect>();

#define INSERT_ONNXOP_PATTERN(OnnxOp)                                          \
  target.addIllegalOp<OnnxOp>();                                               \
  patterns.add<ConvertOnnxOp<OnnxOp>>(typeConverter, context);
  INSERT_ONNXOP_PATTERN(ONNXReluOp);
  INSERT_ONNXOP_PATTERN(ONNXAddOp);
  INSERT_ONNXOP_PATTERN(ONNXAbsOp);
  INSERT_ONNXOP_PATTERN(ONNXAveragePoolOp);
  INSERT_ONNXOP_PATTERN(ONNXConvOp);
#undef INSERT_ONNXOP_PATTERN

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

You’ll want to provide some input/output or fully executable code to get better help, as your current description is vague.

For example, how are you using the type converter right now? Do you have any patterns? Do you also want the signature of the entry point (like the func.func signature) to change?

I would wildly guess that you aren’t changing the signature of the entry point, and then nothing is changing, or you are getting a type mismatch later with some op inside the function.

It looks like you’re trying to do a) ONNX → TOSA conversion and b) layout conversion at the same time. Is this because TOSA only supports NHWC? My uninformed opinion is that you’ll make better progress by not mixing those two things. The layout rewrite doesn’t need to be done with a type converter. In fact, that’s really not what the type converter is intended for (but someone please correct me if I’m wrong here). You’re trying to use the type converter to say one format is illegal and another is legal, but the Tensor type doesn’t have a named format semantic. Instead, you should be able to do everything inside your “Conv2D” converter. If the type of the ONNX conv op is NCHW, then inside matchAndRewrite you should emit Transpose operations on the input, weights, and outputs. That should give a legal conversion. Then run your pass, canonicalize, and inspect the result. If there are inefficiencies (e.g. due to the transposes you inserted), you can write more canonicalizations or a new pass to correct those.

Thanks for your suggestion. Yes, TOSA only supports NHWC. Though inserting transposes would work, they cost hardware execution time, so I want to:

  1. reshape all my tensor shapes, i.e. translate NCHW to NHWC directly;
  2. change the model input/output shapes from NCHW to NHWC;
  3. reorder all the weights from NCHW to NHWC.

All three steps above would save the transpose time, right? Any suggestions on how to achieve this without inserting transpose ops? Thanks a lot.

The TypeConverter does not automatically change types, as far as I know. I am 80% sure that you need to actually use the TypeConverter inside your patterns. Searching for TypeConverter in the conversions to the LLVM Dialect will provide you with many examples that might give a better sense of what you need to do.

One thing to note is that tensors do not have a data layout at this time: they are SSA values without a concrete materialization in memory. The operations that take the tensors as operands interpret their content and dimensions as the op specification says.

After bufferization, different layouts may be materialized in memory.

The rule of thumb is that the op semantics define the order of iterations (i.e. control-flow loops) and the data type defines the layout in memory (i.e. which bytes are close to one another in linear memory).

TOSA on tensors seems like too high-level a place to properly treat such layout considerations.

This may change in the future if/when tensors gain a layout attribute; there will be tradeoffs related to bufferization, what can be represented in place vs. what requires a copy, etc.