Hi all,
I have the following definition of a custom attribute:
// Sharding description for a tensor: for each axis 0..4 an optional `beginN`
// and `strideN` array attribute. An absent parameter means that axis is not
// sharded.
//
// MLIR attributes are immutable and uniqued inside their MLIRContext, so the
// `setBegin` / `setStride` helpers declared below never write into the shared
// storage. They rebind the handle they are called on to a newly uniqued
// attribute that carries the updated parameter; other handles that referred
// to the previous attribute are left untouched.
//
// NOTE(review): the indexed accessors `getBegin(i)` / `getStride(i)` and the
// helper `getIntArrayValue` used in `isValid` are not declared in this
// definition; presumably they are provided elsewhere (e.g. by ShardingTrait)
// — confirm.
def Xml_ShardingAttr : Xp_Attr<Xml_Dialect, "XSharding", [ShardingTrait]> {
  let summary = "sharding info for tensor";
  let description = [{store sharding info for tensor, param beginX means start index at axis X,
param strideX means stride at axis X.
X should < tensor shape dims.
Optional means no sharding at X axis.}];
  let parameters = (ins
    OptionalParameter<"::mlir::ArrayAttr">:$begin0,
    OptionalParameter<"::mlir::ArrayAttr">:$stride0,
    OptionalParameter<"::mlir::ArrayAttr">:$begin1,
    OptionalParameter<"::mlir::ArrayAttr">:$stride1,
    OptionalParameter<"::mlir::ArrayAttr">:$begin2,
    OptionalParameter<"::mlir::ArrayAttr">:$stride2,
    OptionalParameter<"::mlir::ArrayAttr">:$begin3,
    OptionalParameter<"::mlir::ArrayAttr">:$stride3,
    OptionalParameter<"::mlir::ArrayAttr">:$begin4,
    OptionalParameter<"::mlir::ArrayAttr">:$stride4
  );
  let assemblyFormat = "`<` struct(params) `>`";
  let builders = [
    // Builds the attribute with every parameter absent.
    AttrOrTypeBuilder<(ins), [{
      return $_get($_ctxt);
    }]>,
    // Builds an attribute that carries the same parameters as `cloneAttr`.
    // A null `cloneAttr` yields the all-absent attribute.
    AttrOrTypeBuilder<(ins "XShardingAttr":$cloneAttr), [{
      if (cloneAttr == nullptr) return $_get($_ctxt);
      // Construct directly from the source parameters instead of mutating the
      // shared all-absent instance returned by `$_get($_ctxt)`.
      return $_get($_ctxt,
                   cloneAttr.getBegin0(), cloneAttr.getStride0(),
                   cloneAttr.getBegin1(), cloneAttr.getStride1(),
                   cloneAttr.getBegin2(), cloneAttr.getStride2(),
                   cloneAttr.getBegin3(), cloneAttr.getStride3(),
                   cloneAttr.getBegin4(), cloneAttr.getStride4());
    }]>
  ];
  let extraClassDeclaration = [{
    /// Rebinds this handle to an attribute whose `begin<axis>` is built from
    /// `beginIdx` (as an i64 array). An empty vector leaves the handle as is.
    void setBegin(size_t axis, const std::vector<int64_t> &beginIdx);
    /// Rebinds this handle to an attribute whose `begin<axis>` is `beginAttr`.
    /// An axis above 4 is logged and ignored; a null `beginAttr` is ignored.
    void setBegin(size_t axis, ::mlir::ArrayAttr beginAttr);
    /// Rebinds this handle to an attribute whose `stride<axis>` is built from
    /// `stride` (as an i64 array). An empty vector leaves the handle as is.
    void setStride(size_t axis, const std::vector<int64_t> &stride);
    /// Rebinds this handle to an attribute whose `stride<axis>` is `strideAttr`.
    /// An axis above 4 is logged and ignored; a null `strideAttr` is ignored.
    void setStride(size_t axis, ::mlir::ArrayAttr strideAttr);
    /// Returns false if any begin or stride value on any axis is negative.
    /// A null attribute is reported as valid.
    bool isValid();
  }];
  let extraClassDefinition = [{
    void $cppClass::setBegin(size_t axis, ::mlir::ArrayAttr beginAttr) {
      if (axis > 4) {
        LOGE("invalid axis %zu\n", axis);
        return;
      }
      if (beginAttr == nullptr) return;
      // Attribute storage is uniqued and shared, so it must not be written
      // in place. Collect the current values, replace the requested axis and
      // re-unique; `*this` then refers to the updated attribute.
      ::mlir::ArrayAttr begins[5] = {getBegin0(), getBegin1(), getBegin2(),
                                     getBegin3(), getBegin4()};
      begins[axis] = beginAttr;
      *this = $cppClass::get(getContext(),
                             begins[0], getStride0(),
                             begins[1], getStride1(),
                             begins[2], getStride2(),
                             begins[3], getStride3(),
                             begins[4], getStride4());
    }
    void $cppClass::setBegin(size_t axis, const std::vector<int64_t> &beginIdx) {
      if (beginIdx.empty()) return;
      Builder builder(this->getContext());
      setBegin(axis, builder.getI64ArrayAttr(beginIdx));
    }
    void $cppClass::setStride(size_t axis, ::mlir::ArrayAttr strideAttr) {
      if (axis > 4) {
        LOGE("invalid axis %zu\n", axis);
        return;
      }
      if (strideAttr == nullptr) return;
      // Same approach as setBegin: rebuild and re-unique rather than mutate.
      ::mlir::ArrayAttr strides[5] = {getStride0(), getStride1(), getStride2(),
                                      getStride3(), getStride4()};
      strides[axis] = strideAttr;
      *this = $cppClass::get(getContext(),
                             getBegin0(), strides[0],
                             getBegin1(), strides[1],
                             getBegin2(), strides[2],
                             getBegin3(), strides[3],
                             getBegin4(), strides[4]);
    }
    void $cppClass::setStride(size_t axis, const std::vector<int64_t> &stride) {
      if (stride.empty()) return;
      Builder builder(this->getContext());
      setStride(axis, builder.getI64ArrayAttr(stride));
    }
    bool $cppClass::isValid() {
      if (getImpl() == nullptr) return true;
      bool isValid{true};
      for (size_t i = 0; i < 5; i++) {
        // Any negative begin index invalidates the attribute.
        std::vector<int64_t> begin = getIntArrayValue<int64_t>(this->getBegin(i));
        std::for_each(begin.begin(), begin.end(), [&](int64_t b) {
          if (b < 0) isValid = false;
        });
        // Any negative stride invalidates the attribute.
        std::vector<int64_t> stride = getIntArrayValue<int64_t>(this->getStride(i));
        std::for_each(stride.begin(), stride.end(), [&](int64_t s) {
          if (s < 0) isValid = false;
        });
      }
      return isValid;
    }
  }];
}
As you can see, I use the default printing and parsing hooks generated by TableGen. In C++ this works without problems. In the Python bindings, however, I bind the attribute as follows:
/// Registers the Python class `XShardingAttr` on module `m`.
///
/// Members exposed to Python:
///   - `get1(context, begin0, stride0, ..., begin4, stride4)`: classmethod that
///     forwards all ten parameters to `xShardingAttrGet0`.
///   - `get(context)`: classmethod that forwards to `xShardingAttrGet1`.
///   - `setBegin(axis, begin)`: forwards to `xShardingSetBegin`.
///   - `begin(axis)`: returns the result of `xShardingGetBegin`.
///
/// NOTE(review): the Python name `get1` maps to the C function `...Get0` and
/// `get` maps to `...Get1`; the numbering is kept as-is for compatibility.
static void populateDialectXMLSubmodule(py::module &m) {
  mlir_attribute_subclass(m, "XShardingAttr", xmlAttrIsAShardingAttr)
      .def_classmethod(
          "get1",
          [](py::object cls, MlirContext context,
             MlirAttribute begin0, MlirAttribute stride0,
             MlirAttribute begin1, MlirAttribute stride1,
             MlirAttribute begin2, MlirAttribute stride2,
             MlirAttribute begin3, MlirAttribute stride3,
             MlirAttribute begin4, MlirAttribute stride4) {
            return cls(xShardingAttrGet0(context, begin0, stride0, begin1,
                                         stride1, begin2, stride2, begin3,
                                         stride3, begin4, stride4));
          })
      .def_classmethod(
          "get",
          [](py::object cls, MlirContext context) {
            return cls(xShardingAttrGet1(context));
          })
      // NOTE(review): MLIR attributes are immutable; whether this call can
      // change what `self` refers to depends on how xShardingSetBegin is
      // implemented (not shown here) — confirm.
      .def(
          "setBegin",
          [](MlirAttribute self, int64_t axis, MlirAttribute begin) {
            xShardingSetBegin(self, axis, begin);
          })
      // Bound as a regular method rather than a read-only property: a
      // property getter is invoked with `self` only, so a getter that also
      // needs `axis` could never be called through attribute access.
      .def(
          "begin",
          [](MlirAttribute self, int64_t axis) {
            return xShardingGetBegin(self, axis);
          });
}
I expose two constructors for the attribute “XShardingAttr”, named get1 and get. Here is how I use them in Python:
import xxmlir

# NOTE(review): Context, ArrayAttr, IntegerAttr, IntegerType and XShardingAttr
# are not imported in this snippet; presumably they come from the project's
# Python bindings — confirm and import them explicitly.

# Everything that creates or prints attributes runs inside the context block.
with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    xxmlir.register_dialects(ctx)

    # One array attribute of four signed 32-bit integers, reused below for
    # every begin/stride parameter.
    y = ArrayAttr.get(
        [IntegerAttr.get(IntegerType.get_signed(32), x) for x in [1, 2, 3, 4]]
    )
    print(y)

    # `get` takes only the context.
    xshard0 = XShardingAttr.get(ctx)

    # `get1` takes the context followed by ten attributes, in the order
    # begin0, stride0, ..., begin4, stride4.
    xshard = XShardingAttr.get1(ctx, y, y, y, y, y, y, y, y, y, y)
    print(xshard)
And this is the error:
Can anyone tell me how to fix this error and get the printing hook to work correctly? Thanks in advance!