Index problem after rewriting AffineLoad/StoreOp

Hi, everyone. I’m trying something with the latest version of MLIR, and I ran into a problem after rewriting AffineLoadOp/AffineStoreOp. Here is what I’ve done: I created a MyType, which is basically just another name for i64, and I’m trying to lower MyType to i64 in a pass. I created a TypeConverter that does the unrealized cast for both scalar and shaped types, and I used the FunctionOpInterface, CallOpTypeConversion and ReturnOpTypeConversion patterns. Right now the pass can lower MyType to i64 in func.call and also in affine.load/store.

The problem is with memref<?xMyType> in affine.load/store. If I read from or write to an index like [%c], everything is fine. But for an index like [%c + 1], the index becomes [%c] after lowering, which is definitely wrong. The rewrite patterns I’m using are as follows:

struct AffineLoadRewriting final : public OpConversionPattern<AffineLoadOp> {
public:
    using OpConversionPattern<AffineLoadOp>::OpConversionPattern;

    AffineLoadRewriting(
        TypeConverter &typeConverter,
        MLIRContext* context,
        PatternBenefit benefit)
            : OpConversionPattern<AffineLoadOp>(
                typeConverter,
                context,
                benefit){};

    LogicalResult matchAndRewrite(
        AffineLoadOp op,
        AffineLoadOpAdaptor adaptor,
        ConversionPatternRewriter &rewriter) const override
    {
        auto type = op.getResult().getType();
        if (!type.template isa<MyType>()) return failure();

        rewriter.replaceOpWithNewOp<AffineLoadOp>(
            op,
            adaptor.getMemref(),
            adaptor.getIndices());

        return success();
    }
};

struct AffineStoreRewriting final : public OpConversionPattern<AffineStoreOp> {
public:
    using OpConversionPattern<AffineStoreOp>::OpConversionPattern;

    AffineStoreRewriting(
        TypeConverter &typeConverter,
        MLIRContext* context,
        PatternBenefit benefit)
            : OpConversionPattern<AffineStoreOp>(
                typeConverter,
                context,
                benefit){};

    LogicalResult matchAndRewrite(
        AffineStoreOp op,
        AffineStoreOpAdaptor adaptor,
        ConversionPatternRewriter &rewriter) const override
    {
        auto type = op.getValue().getType();
        if (!type.template isa<MyType>()) return failure();

        rewriter.replaceOpWithNewOp<AffineStoreOp>(
            op,
            adaptor.getValue(),
            adaptor.getMemref(),
            adaptor.getIndices());

        return success();
    }
};

The testbench I’m using is:

#le_3 = affine_set<(i): (3 - i >= 0)>
#ge_4 = affine_set<(i): (i - 4 >= 0)>
!scalar = !mydialect.mytype

module {
    func.func @kernel(%a : memref<8x!scalar>, %b : memref<8x!scalar>) attributes {llvm.emit_c_interface} {
        affine.for %c1 = 0 to 8 {
            affine.if #le_3(%c1) {
                %1 = affine.load %a[%c1 + 4] : memref<8x!scalar>
                affine.store %1, %b[%c1] : memref<8x!scalar>
            }
            affine.if #ge_4(%c1) {
                %1 = affine.load %a[%c1 - 4] : memref<8x!scalar>
                affine.store %1, %b[%c1] : memref<8x!scalar>
            }
        }

        return
    }
}

After applying the lowering pass the result is:

#map = affine_map<(d0) -> (d0)>
#set = affine_set<(d0) : (-d0 + 3 >= 0)>
#set1 = affine_set<(d0) : (d0 - 4 >= 0)>
module {
  func.func @kernel(%arg0: memref<8xi64>, %arg1: memref<8xi64>) attributes {llvm.emit_c_interface} {
    affine.for %arg2 = 0 to 8 {
      affine.if #set(%arg2) {
        %0 = affine.load %arg0[%arg2] : memref<8xi64>
        affine.store %0, %arg1[%arg2] : memref<8xi64>
      }
      affine.if #set1(%arg2) {
        %0 = affine.load %arg0[%arg2] : memref<8xi64>
        affine.store %0, %arg1[%arg2] : memref<8xi64>
      }
    }
    return
  }
}

I guess this is where the problem occurs during lowering: the affine_map is not carried over to the new op, which ends up with an identity map instead.

//===-------------------------------------------===//
Legalizing operation : 'affine.load'(0x56005a39b750) {
  %0 = "affine.load"(<<UNKNOWN SSA VALUE>>, %arg2) {map = affine_map<(d0) -> (d0 + 4)>} : (memref<8x!mydialect.mytype>, index) -> !mydialect.mytype

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'affine.load -> ()' {
    ** Insert  : 'affine.load'(0x56005a37b0b0)
    ** Replace : 'affine.load'(0x56005a39b750)

    //===-------------------------------------------===//
    Legalizing operation : 'affine.load'(0x56005a37b0b0) {
      %0 = "affine.load"(%arg0, %arg2) {map = affine_map<(d0) -> (d0)>} : (memref<8xi64>, index) -> i64

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//
  } -> SUCCESS : pattern applied successfully
} -> SUCCESS
//===-------------------------------------------===//

Could anyone tell me why this kind of index is not properly preserved? And how exactly should I rewrite the AffineLoad/StoreOp? Thank you all ;)

It seems you’re dropping the map attribute. You should use the builder which takes an AffineMap, e.g. the one declared in AffineOps.td (llvm/llvm-project on GitHub, at commit 81f1f6db40abc2cc7f964bf450a2e9f78f14a8a8).
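
A minimal sketch of that fix, reusing the two patterns from the question. The builder that takes only (memref, indices) creates an identity map, which is why [%c1 + 4] turned into [%arg2]; passing op.getAffineMap() explicitly keeps the original map. This is a fragment rather than a complete program: it assumes the same MLIR headers and using-declarations as the original patterns and has not been compiled against a specific MLIR revision.

```cpp
// Inside AffineLoadRewriting::matchAndRewrite: forward the original map.
rewriter.replaceOpWithNewOp<AffineLoadOp>(
    op,
    adaptor.getMemref(),    // memref with the converted element type
    op.getAffineMap(),      // original map, e.g. (d0) -> (d0 + 4)
    adaptor.getIndices());  // operands of that map

// Inside AffineStoreRewriting::matchAndRewrite: same idea.
rewriter.replaceOpWithNewOp<AffineStoreOp>(
    op,
    adaptor.getValue(),     // converted value to store
    adaptor.getMemref(),
    op.getAffineMap(),      // original map of the store
    adaptor.getIndices());
```

With this change, the loads in the test case above should come out as affine.load %arg0[%arg2 + 4] and affine.load %arg0[%arg2 - 4] on memref<8xi64>.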

Ohh! Yes! That’s exactly what I should do! Thanks so much, I’ve been stuck on this for a long time.

BTW, if you don’t want affine.load/store to remain after lowering, you can use the same approach as the -lower-affine pass to lower these ops into memref.load/store ;)
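
A rough sketch of what that could look like for the load pattern, modeled on how -lower-affine handles affine.load. The details are assumptions: expandAffineMap is the MLIR utility that materializes an affine map as arithmetic ops, but its header and namespace differ between MLIR versions (in recent trees it is declared in mlir/Dialect/Affine/Utils.h), so check it against your checkout. This fragment is untested and is meant as a replacement body for AffineLoadRewriting::matchAndRewrite.

```cpp
LogicalResult matchAndRewrite(
    AffineLoadOp op,
    AffineLoadOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const override
{
    // Materialize the map results (e.g. %c1 + 4) as explicit index values.
    SmallVector<Value, 8> mapOperands(adaptor.getIndices());
    auto expanded = expandAffineMap(
        rewriter, op.getLoc(), op.getAffineMap(), mapOperands);
    if (!expanded) return failure();

    // Plain memref.load on the converted memref with the computed indices.
    rewriter.replaceOpWithNewOp<memref::LoadOp>(
        op, adaptor.getMemref(), *expanded);
    return success();
}
```

The store pattern would follow the same shape with memref::StoreOp and adaptor.getValue(). Note that the expansion creates arith ops, so the arith and memref dialects would have to be legal in the conversion target.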