Hi, everyone. I’m experimenting with the latest version of MLIR and ran into a problem after rewriting AffineLoadOp/AffineStoreOp. Here is what I’ve done: I created a MyType, which is basically just another name for i64, and I want to lower MyType to i64 in a pass. I created a typeConverter to do the unrealizedCast for both scalar and shaped types, and used the FunctionOpInterface, CallOpTypeConversion and ReturnOpTypeConversion patterns. Right now the pass can lower MyType to i64 in func.call and also in affine.load/store.
The problem is with memref<?xMyType> in affine.load/store. If I read from or write to an index like [%c], everything is fine. But with an index like [%c + 1], the index becomes [%c] after the lowering, which is definitely wrong. The rewrite patterns I’m using are as follows:
struct AffineLoadRewriting final : public OpConversionPattern<AffineLoadOp> {
public:
  using OpConversionPattern<AffineLoadOp>::OpConversionPattern;
  AffineLoadRewriting(
      TypeConverter &typeConverter,
      MLIRContext* context,
      PatternBenefit benefit)
      : OpConversionPattern<AffineLoadOp>(
            typeConverter,
            context,
            benefit){};
  LogicalResult matchAndRewrite(
      AffineLoadOp op,
      AffineLoadOpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override
  {
    auto type = op.getResult().getType();
    if (!type.template isa<MyType>()) return failure();
    rewriter.replaceOpWithNewOp<AffineLoadOp>(
        op,
        adaptor.getMemref(),
        adaptor.getIndices());
    return success();
  }
};
struct AffineStoreRewriting final : public OpConversionPattern<AffineStoreOp> {
public:
  using OpConversionPattern<AffineStoreOp>::OpConversionPattern;
  AffineStoreRewriting(
      TypeConverter &typeConverter,
      MLIRContext* context,
      PatternBenefit benefit)
      : OpConversionPattern<AffineStoreOp>(
            typeConverter,
            context,
            benefit){};
  LogicalResult matchAndRewrite(
      AffineStoreOp op,
      AffineStoreOpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override
  {
    auto type = op.getValue().getType();
    if (!type.template isa<MyType>()) return failure();
    rewriter.replaceOpWithNewOp<AffineStoreOp>(
        op,
        adaptor.getValue(),
        adaptor.getMemref(),
        adaptor.getIndices());
    return success();
  }
};
The test case I’m using is:
#le_3 = affine_set<(i): (3 - i >= 0)>
#ge_4 = affine_set<(i): (i - 4 >= 0)>
!scalar = !mydialect.mytype
module {
  func.func @kernel(%a : memref<8x!scalar>, %b : memref<8x!scalar>) attributes {llvm.emit_c_interface} {
    affine.for %c1 = 0 to 8 {
      affine.if #le_3(%c1) {
        %1 = affine.load %a[%c1 + 4] : memref<8x!scalar>
        affine.store %1, %b[%c1] : memref<8x!scalar>
      }
      affine.if #ge_4(%c1) {
        %1 = affine.load %a[%c1 - 4] : memref<8x!scalar>
        affine.store %1, %b[%c1] : memref<8x!scalar>
      }
    }
    return
  }
}
After applying the lowering pass, the result is:
#map = affine_map<(d0) -> (d0)>
#set = affine_set<(d0) : (-d0 + 3 >= 0)>
#set1 = affine_set<(d0) : (d0 - 4 >= 0)>
module {
  func.func @kernel(%arg0: memref<8xi64>, %arg1: memref<8xi64>) attributes {llvm.emit_c_interface} {
    affine.for %arg2 = 0 to 8 {
      affine.if #set(%arg2) {
        %0 = affine.load %arg0[%arg2] : memref<8xi64>
        affine.store %0, %arg1[%arg2] : memref<8xi64>
      }
      affine.if #set1(%arg2) {
        %0 = affine.load %arg0[%arg2] : memref<8xi64>
        affine.store %0, %arg1[%arg2] : memref<8xi64>
      }
    }
    return
  }
}
I guess this is where the problem occurs during lowering; the affine_map is not lowered correctly:
//===-------------------------------------------===//
Legalizing operation : 'affine.load'(0x56005a39b750) {
%0 = "affine.load"(<<UNKNOWN SSA VALUE>>, %arg2) {map = affine_map<(d0) -> (d0 + 4)>} : (memref<8x!mydialect.mytype>, index) -> !mydialect.mytype
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'affine.load -> ()' {
** Insert : 'affine.load'(0x56005a37b0b0)
** Replace : 'affine.load'(0x56005a39b750)
//===-------------------------------------------===//
Legalizing operation : 'affine.load'(0x56005a37b0b0) {
%0 = "affine.load"(%arg0, %arg2) {map = affine_map<(d0) -> (d0)>} : (memref<8xi64>, index) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
} -> SUCCESS
//===-------------------------------------------===//
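In the log above, the replacement op already carries the map (d0) -> (d0), while its operands were remapped as expected. That would be consistent with the map attribute not being forwarded when the new op is built, assuming the builder overload that takes only a memref and indices falls back to an identity map. The sketch below is a standalone C++ toy model of that hypothesis, not MLIR code; ToyAccess, rebuildFromOperandsOnly, rebuildWithMap and resolveIndex are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <vector>

// Toy stand-in for an affine access: index = operand + offset.
// In MLIR the offset would live in the op's affine map attribute,
// while the operand would be an SSA value such as %c1.
struct ToyAccess {
  int offset;                 // plays the role of the map, e.g. (d0) -> (d0 + 4)
  std::vector<int> operands;  // plays the role of the map operands
};

// Rebuild the access from its operands only. The offset is not forwarded,
// so it falls back to 0, which models an identity map.
ToyAccess rebuildFromOperandsOnly(const ToyAccess &old) {
  return ToyAccess{0, old.operands};
}

// Rebuild the access while forwarding the original offset (the "map").
ToyAccess rebuildWithMap(const ToyAccess &old) {
  return ToyAccess{old.offset, old.operands};
}

// Compute the effective index of a one-dimensional access.
int resolveIndex(const ToyAccess &access) {
  assert(!access.operands.empty() && "expected at least one operand");
  return access.operands.front() + access.offset;
}
```

In this model, an access with offset 4 and operand 3 resolves to 7, but resolves to 3 once it is rebuilt from operands only. If the hypothesis holds, the real patterns would need to pass the original op's map (for example via op.getAffineMap(), assuming that accessor and a builder overload accepting an AffineMap exist in the MLIR version in use) together with adaptor.getIndices().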
Could anyone tell me why this kind of index is not properly resolved? And how exactly should I rewrite the AffineLoadOp/AffineStoreOp? Thank you all ;)