Hi there, MLIR community. I've been struggling with ops and patterns that have optional operands and thought I'd share my experience. I wrote an op with an optional operand using TableGen that looks like this:
def Conv2Op : Op<FrontendDialect, "conv2d", []> {
  let summary = "2D Convolution operation";
  let description = [{}];
  let arguments = (ins AnyTensor:$input, AnyTensor:$filter, Optional<AnyTensor>:$bias,
                       I32ArrayAttr:$kernel, I32ArrayAttr:$stride,
                       I32ArrayAttr:$padding, I32ArrayAttr:$dilation,
                       AnyI32Attr:$num_groups
  );
  let results = (outs AnyTensor:$result);
  let assemblyFormat = "$input `,` $filter `,` $bias attr-dict `:` type($input) `,` type($filter) `,` type($bias) `,` type($result)";
  let hasCanonicalizer = 1;
}
As you can see, “bias” is an optional tensor. After compilation, the generated builder looks like this:
void Conv2Op::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value input, ::mlir::Value filter, /*optional*/::mlir::Value bias, ::mlir::ArrayAttr kernel, ::mlir::ArrayAttr stride, ::mlir::ArrayAttr padding, ::mlir::ArrayAttr dilation, ::mlir::IntegerAttr num_groups) {
  odsState.addOperands(input);
  odsState.addOperands(filter);
  if (bias)
    odsState.addOperands(bias);
  odsState.addAttribute(kernelAttrName(odsState.name), kernel);
  odsState.addAttribute(strideAttrName(odsState.name), stride);
  odsState.addAttribute(paddingAttrName(odsState.name), padding);
  odsState.addAttribute(dilationAttrName(odsState.name), dilation);
  odsState.addAttribute(num_groupsAttrName(odsState.name), num_groups);
  odsState.addTypes(result);
}
This implies that if I don't want a bias tensor, I should pass nullptr instead.
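Concretely, building a conv without a bias from C++ then looks roughly like this (just a sketch: builder, loc, resultType and the attribute variables are placeholder names of mine):

auto conv = builder.create<Conv2Op>(
    loc, resultType, input, filter, /*bias=*/::mlir::Value(),
    kernelAttr, strideAttr, paddingAttr, dilationAttr, numGroupsAttr);
// A default-constructed Value is null, so the `if (bias)` check above simply skips it.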
I then wrote a TableGen (DRR) pattern that matches a Conv2Op that has a bias tensor and splits it into Conv2Op + AddOp, with nullptr as the bias in the result pattern. It looks like this:
def nullArg : NativeCodeCall<"nullptr">;

def ConvWithBias : Pat<
  (Conv2Op:$res1 $input, $filter, $bias, $kernel, $stride, $padding, $dilation, $num_groups),
  (AddOp:$res3 ($bias),
    (Conv2Op:$res2 $input, $filter, (nullArg), $kernel, $stride, $padding, $dilation, $num_groups)),
  [(NotNullPred $bias)]>;
There are a couple of problems with the generated code for this pattern.
- The generated pattern does not use Conv2Op::build explicitly. Instead it goes through a generic builder for the op and then dynamically casts the result back to Conv2Op. That builder dereferences every operand, so it causes a segmentation fault when I pass nullptr. The generated code looks like this:
res2 = rewriter.create<::mlir::fe::Conv2Op>(odsLoc, tblgen_values, tblgen_attrs);
which calls this code in include/mlir/IR/Builders.h:
/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpTy create(Location location, Args &&...args) {
  OperationState state(location,
                       getCheckRegisteredInfo<OpTy>(location.getContext()));
  OpTy::build(*this, state, std::forward<Args>(args)...);
  auto *op = createOperation(state);
  auto result = dyn_cast<OpTy>(op);
  assert(result && "builder didn't return the right type");
  return result;
}
One way to work around this is to use a NativeCodeCall that calls the op's own builder (see the sketch after this list), although that approach has its own problems. I would be happy if the generated code for Pat<> could call the correct builder seamlessly, or at least accept null operands.
- When passing the “$bias” tensor to the result pattern in:
AddOp:$res3 ($bias)
From my reading of the generated code, this operand should be referred to as:
auto nativeVar_0 = *bias.begin();
But because this is an optional operand, it generates this instead:
auto nativeVar_0 = bias;
which causes a compilation error when this value is pushed into the tblgen_values list:
tblgen_values.push_back(nativeVar_0);
To work around this, I had to manually inject code into the pattern, like this:
def biasDeref : NativeCodeCall<"*bias.begin()">;
def ConvWithBias : Pat<
… (AddOp:$res3 (biasDeref), …>;
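Coming back to the first problem, the NativeCodeCall workaround I mentioned looks roughly like this. It is only a sketch: CreateConvNoBias is a name I made up, and the $N placeholders have to line up with whatever arguments are actually passed to it in the result pattern.

// Sketch: go through the generated per-operand builder, which accepts a null optional operand.
def CreateConvNoBias : NativeCodeCall<"$_builder.create<Conv2Op>($_loc, $0.getType(), $1, $2, /*bias=*/::mlir::Value(), $3, $4, $5, $6, $7).getResult()">;

The inner (Conv2Op:$res2 … (nullArg) …) in the result pattern is then replaced with (CreateConvNoBias $res1, $input, $filter, $kernel, $stride, $padding, $dilation, $num_groups).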
For an infrastructure that aims to simplify pattern matching, I wish it were easier to work with optional operands.
Right now, as I see it, the best option is to write the pattern in C++ instead of TableGen, roughly as sketched below.
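Here is roughly what that C++ pattern would look like. Again, only a sketch under my own assumptions: I'm using the prefixed accessors (getBias() and friends; older MLIR generates bias(), kernelAttr(), …) and I'm assuming AddOp can be built from (result type, lhs, rhs), so adjust both to the real definitions.

#include "mlir/IR/PatternMatch.h"

struct DecomposeConvWithBias : public mlir::OpRewritePattern<Conv2Op> {
  using mlir::OpRewritePattern<Conv2Op>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(Conv2Op op,
                                      mlir::PatternRewriter &rewriter) const override {
    // Only rewrite convolutions that actually carry a bias operand.
    mlir::Value bias = op.getBias();
    if (!bias)
      return mlir::failure();

    // Re-create the convolution without the bias; the generated per-operand
    // builder accepts a null Value for the optional operand.
    auto conv = rewriter.create<Conv2Op>(
        op.getLoc(), op.getType(), op.getInput(), op.getFilter(),
        /*bias=*/mlir::Value(), op.getKernelAttr(), op.getStrideAttr(),
        op.getPaddingAttr(), op.getDilationAttr(), op.getNumGroupsAttr());

    // Add the bias back on top (assumed AddOp builder signature).
    rewriter.replaceOpWithNewOp<AddOp>(op, op.getType(), conv.getResult(), bias);
    return mlir::success();
  }
};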