Custom Attribute printing hook in Python bindings

Hi all,
I have the following definition of a custom attribute:

def Xml_ShardingAttr : Xp_Attr<Xml_Dialect, "XSharding", [ShardingTrait]> {
    let summary = "sharding info for tensor";
    let description = [{Stores sharding info for a tensor, one begin/stride pair per axis X
      (X in 0..4): param beginX is the start index along axis X and param
      strideX is the stride along axis X. X must be less than the number of
      tensor dimensions. An absent (null) beginX/strideX means the tensor is
      not sharded along axis X. isValid() requires every begin/stride entry
      to be non-negative.}];
    let parameters = (ins
      OptionalParameter<"::mlir::ArrayAttr">:$begin0,
      OptionalParameter<"::mlir::ArrayAttr">:$stride0,
      OptionalParameter<"::mlir::ArrayAttr">:$begin1,
      OptionalParameter<"::mlir::ArrayAttr">:$stride1,
      OptionalParameter<"::mlir::ArrayAttr">:$begin2,
      OptionalParameter<"::mlir::ArrayAttr">:$stride2,
      OptionalParameter<"::mlir::ArrayAttr">:$begin3,
      OptionalParameter<"::mlir::ArrayAttr">:$stride3,
      OptionalParameter<"::mlir::ArrayAttr">:$begin4,
      OptionalParameter<"::mlir::ArrayAttr">:$stride4
    );
    let assemblyFormat = "`<` struct(params) `>`";
    let builders = [
      // Parameterless attribute: every begin/stride parameter is null.
      AttrOrTypeBuilder<(ins), [{
        return $_get($_ctxt);
      }]>,
      // Copy builder. Attributes are uniqued and immutable, so the copy is
      // built from cloneAttr's parameters. It must not be built by writing into
      // the storage returned by $_get($_ctxt): that storage is shared by every
      // parameterless XShardingAttr in the context, and the parameters are part
      // of the uniquing key.
      AttrOrTypeBuilder<(ins "XShardingAttr":$cloneAttr), [{
        if (cloneAttr == nullptr) return $_get($_ctxt);
        return $_get($_ctxt,
                     cloneAttr.getBegin0(), cloneAttr.getStride0(),
                     cloneAttr.getBegin1(), cloneAttr.getStride1(),
                     cloneAttr.getBegin2(), cloneAttr.getStride2(),
                     cloneAttr.getBegin3(), cloneAttr.getStride3(),
                     cloneAttr.getBegin4(), cloneAttr.getStride4());
      }]>
    ];
    let extraClassDeclaration = [{
      // NOTE(review): setBegin/setStride overwrite fields of this attribute's
      // uniqued storage. Those fields are the attribute's parameters (its
      // uniquing key), so every other handle to the same attribute observes the
      // change and the context's uniquing table can go stale. Prefer building a
      // new attribute with get(...) and the desired parameters.
      void setBegin(size_t axis, const std::vector<int64_t> &beginIdx);
      void setBegin(size_t axis, ::mlir::ArrayAttr beginAttr);
      void setStride(size_t axis, const std::vector<int64_t> &stride);
      void setStride(size_t axis, ::mlir::ArrayAttr strideAttr);
      // True when no begin/stride entry on any axis is negative.
      bool isValid();
    }];
    let extraClassDefinition = [{
      // Stores beginAttr as the begin parameter of `axis` (0..4). An
      // out-of-range axis is logged and ignored. A null beginAttr is ignored,
      // so this setter cannot clear a parameter.
      void $cppClass::setBegin(size_t axis, ::mlir::ArrayAttr beginAttr) {
        if (axis > 4) {
          LOGE("invalid axis %zu\n", axis);
          return;
        }
        if (beginAttr == nullptr) return;
        if (axis == 0) getImpl()->begin0 = beginAttr;
        else if (axis == 1) getImpl()->begin1 = beginAttr;
        else if (axis == 2) getImpl()->begin2 = beginAttr;
        else if (axis == 3) getImpl()->begin3 = beginAttr;
        else if (axis == 4) getImpl()->begin4 = beginAttr;
      }

      // Convenience overload: wraps beginIdx in an i64 ArrayAttr. An empty
      // vector is ignored.
      void $cppClass::setBegin(size_t axis, const std::vector<int64_t> &beginIdx) {
        if (beginIdx.empty()) return;
        Builder builder(this->getContext());
        setBegin(axis, builder.getI64ArrayAttr(beginIdx));
      }

      // Stores strideAttr as the stride parameter of `axis` (0..4). An
      // out-of-range axis is logged and ignored. A null strideAttr is ignored.
      void $cppClass::setStride(size_t axis, ::mlir::ArrayAttr strideAttr) {
        if (axis > 4) {
          LOGE("invalid axis %zu\n", axis);
          return;
        }
        if (strideAttr == nullptr) return;
        if (axis == 0) getImpl()->stride0 = strideAttr;
        else if (axis == 1) getImpl()->stride1 = strideAttr;
        else if (axis == 2) getImpl()->stride2 = strideAttr;
        else if (axis == 3) getImpl()->stride3 = strideAttr;
        else if (axis == 4) getImpl()->stride4 = strideAttr;
      }

      // Convenience overload: wraps stride in an i64 ArrayAttr. An empty
      // vector is ignored.
      void $cppClass::setStride(size_t axis, const std::vector<int64_t> &stride) {
        if (stride.empty()) return;
        Builder builder(this->getContext());
        setStride(axis, builder.getI64ArrayAttr(stride));
      }

      bool $cppClass::isValid() {
        // A null attribute is treated as valid.
        if (getImpl() == nullptr) return true;
        auto hasNegative = [](const std::vector<int64_t> &values) {
          return std::any_of(values.begin(), values.end(),
                             [](int64_t v) { return v < 0; });
        };
        for (size_t i = 0; i < 5; i++) {
          if (hasNegative(getIntArrayValue<int64_t>(this->getBegin(i))) ||
              hasNegative(getIntArrayValue<int64_t>(this->getStride(i))))
            return false;
        }
        return true;
      }
    }];
}

As you can see, I use the default printing and parsing hooks generated by TableGen, and they work smoothly from C++. In the Python bindings, I bind the attribute like this:

// Registers the Python `XShardingAttr` subclass of `Attribute` on module `m`.
static void populateDialectXMLSubmodule(py::module &m) {
  mlir_attribute_subclass(m, "XShardingAttr", xmlAttrIsAShardingAttr)
      // XShardingAttr.get1(context, begin0, stride0, ..., begin4, stride4):
      // builds the attribute with all ten begin/stride parameters.
      .def_classmethod(
          "get1",
          [](py::object cls, MlirContext context, MlirAttribute begin0,
             MlirAttribute stride0, MlirAttribute begin1, MlirAttribute stride1,
             MlirAttribute begin2, MlirAttribute stride2, MlirAttribute begin3,
             MlirAttribute stride3, MlirAttribute begin4,
             MlirAttribute stride4) {
            return cls(xShardingAttrGet0(context, begin0, stride0, begin1,
                                         stride1, begin2, stride2, begin3,
                                         stride3, begin4, stride4));
          })
      // XShardingAttr.get(context): builds the parameterless attribute.
      .def_classmethod(
          "get",
          [](py::object cls, MlirContext context) {
            return cls(xShardingAttrGet1(context));
          })
      // attr.setBegin(axis, begin): forwards to the C API setter.
      .def(
          "setBegin",
          [](MlirAttribute self, int64_t axis, MlirAttribute begin) {
            xShardingSetBegin(self, axis, begin);
          })
      // attr.begin(axis): the begin parameter for `axis`. This must be a
      // method rather than a read-only property: a Python property getter is
      // called with `self` only, so a getter that also takes `axis` raised
      // TypeError on every access.
      .def(
          "begin",
          [](MlirAttribute self, int64_t axis) {
            return xShardingGetBegin(self, axis);
          });
}

I expose two construction functions for the attribute “XShardingAttr”: `get1` and `get`. Here is how I use them in Python:

import xxmlir
# Upstream MLIR IR classes used below.
from mlir.ir import ArrayAttr, Context, IntegerAttr, IntegerType

# NOTE(review): XShardingAttr must also be in scope here. Presumably it is
# exposed by the extension module whose populateDialectXMLSubmodule registers
# it; import it from there (TODO confirm the module path).

with Context() as ctx:
  ctx.allow_unregistered_dialects = True
  # Register the downstream dialect(s) with this context; assumed to include
  # the dialect that defines XShardingAttr.
  xxmlir.register_dialects(ctx)

  # One ArrayAttr of signed 32-bit integers, reused for every begin/stride.
  y = ArrayAttr.get([IntegerAttr.get(IntegerType.get_signed(32), x) for x in [1, 2, 3, 4]])
  print(y)
  xshard0 = XShardingAttr.get(ctx)  # parameterless builder
  xshard = XShardingAttr.get1(ctx, y, y, y, y, y, y, y, y, y, y)  # all ten parameters
  print(xshard)

And this is the error:

Can anyone tell me how to tackle this error and let the printing hook work correctly? Thanks in advance!

You mentioned that it works in C++, do you have an opt tool like mlir-opt that can print the attributes as expected? I’m not sure what is different about Python here, so I’d expect the same issue using C++ based tools. I’m assuming xxmlir.register_dialects(ctx) will register the dialect where this attribute is defined.

I think we’ll need to see more about how your dialect is set up in C++. Take a look at the options here: Defining Dialect Attributes and Types - MLIR.

You mentioned using the default printing and parsing hooks generated by ODS, but the error suggests that your dialect doesn’t define printAttribute: https://github.com/llvm/llvm-project/blob/76f20099a5ab72a261661ecb545dceed52e5592d/mlir/include/mlir/IR/Dialect.h#L103. I think you either need to define printAttribute and call the generated hook from it, or use useDefaultAttributePrinterParser. If you use useDefaultAttributePrinterParser, you’ll also need to define a mnemonic on all your attributes, and I don’t see a mnemonic in the example you shared.

Thanks for answering, Mike.

To your first question: yes, I have an opt tool named xx-mlir-opt. Below is how I defined the custom attribute, followed by the dialect definition, where I set useDefaultAttributePrinterParser. It works nicely in C++.
image
And I think this is the mnemonic?


And here is the ODS generated printer and parser

// Generated by mlir-tblgen from the Xml_ShardingAttr ODS definition; change the
// .td file rather than editing this declaration by hand.
class XShardingAttr : public ::mlir::Attribute::AttrBase<XShardingAttr, ::mlir::Attribute, detail::XShardingAttrStorage, ::mlir::AttributeTrait::ShardingTrait> {
public:
  using Base::Base;
  // Members declared through `extraClassDeclaration` in the .td file.
  void setBegin(size_t axis, const std::vector<int64_t> &beginIdx);
  void setBegin(size_t axis, ::mlir::ArrayAttr beginAttr);
  void setStride(size_t axis, const std::vector<int64_t> &stride);
  void setStride(size_t axis, ::mlir::ArrayAttr strideAttr);
  bool isValid();
public:
  // Default builder taking every parameter, in declaration order.
  static XShardingAttr get(::mlir::MLIRContext *context, ::mlir::ArrayAttr begin0, ::mlir::ArrayAttr stride0, ::mlir::ArrayAttr begin1, ::mlir::ArrayAttr stride1, ::mlir::ArrayAttr begin2, ::mlir::ArrayAttr stride2, ::mlir::ArrayAttr begin3, ::mlir::ArrayAttr stride3, ::mlir::ArrayAttr begin4, ::mlir::ArrayAttr stride4);
  // Custom builders from the `builders` list.
  static XShardingAttr get(::mlir::MLIRContext *context);
  static XShardingAttr get(::mlir::MLIRContext *context, XShardingAttr cloneAttr);
  // Mnemonic used by the generated dialect printer/parser. It is exactly the
  // `name` passed to Xp_Attr ("XSharding"), not a lower-cased form.
  static constexpr ::llvm::StringLiteral getMnemonic() {
    return {"XSharding"};
  }

  // Parser/printer generated from `assemblyFormat`.
  static ::mlir::Attribute parse(::mlir::AsmParser &odsParser, ::mlir::Type odsType);
  void print(::mlir::AsmPrinter &odsPrinter) const;
  // Accessors for the optional begin/stride parameters; an unset parameter is
  // a null ArrayAttr (the builders compare these against nullptr).
  ::mlir::ArrayAttr getBegin0() const;
  ::mlir::ArrayAttr getStride0() const;
  ::mlir::ArrayAttr getBegin1() const;
  ::mlir::ArrayAttr getStride1() const;
  ::mlir::ArrayAttr getBegin2() const;
  ::mlir::ArrayAttr getStride2() const;
  ::mlir::ArrayAttr getBegin3() const;
  ::mlir::ArrayAttr getStride3() const;
  ::mlir::ArrayAttr getBegin4() const;
  ::mlir::ArrayAttr getStride4() const;
};

One more thing, I found the default generated printer and parser in my xxxAttr.cpp.inc
/// Parse an attribute registered to this dialect.
/// Generated by mlir-tblgen (found in the generated .cpp.inc): it dispatches on
/// the attribute tag through generatedAttributeParser and emits an
/// "unknown attribute" diagnostic if no attribute of this dialect claims it.
::mlir::Attribute XmlDialect::parseAttribute(::mlir::DialectAsmParser &parser,
                                      ::mlir::Type type) const {
  ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
  ::llvm::StringRef attrTag;
  {
    ::mlir::Attribute attr;
    // A result is present when the generated dispatcher recognised attrTag;
    // whatever it produced is returned as-is.
    auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
    if (parseResult.has_value())
      return attr;
  }
  // No attribute in this dialect matched the tag read into attrTag.
  parser.emitError(typeLoc) << "unknown attribute `"
      << attrTag << "` in dialect `" << getNamespace() << "`";
  return {};
}
/// Print an attribute registered to this dialect.
/// Generated by mlir-tblgen: delegates to generatedAttributePrinter.
/// NOTE(review): if that printer does not handle `attr`, this function prints
/// nothing and emits no diagnostic.
void XmlDialect::printAttribute(::mlir::Attribute attr,
                         ::mlir::DialectAsmPrinter &printer) const {
  if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
    return;
  // Fallthrough: attr was not handled by the generated printer.
}

Are these generated printer and parser hooks not registered in Python? How can I use them from Python to correctly print and parse the custom-built attributes?

Does the xx-mlir-opt tool print/parse your custom attribute correctly?

This is how it will be named in generated C++, not the mnemonic. The mnemonic is a field of the declaration: Defining Dialect Attributes and Types - MLIR

// Explicit mnemonic set inside an AttrDef body; the generated dialect
// printer/parser use it as the attribute's tag.
let mnemonic = "xsharding";

Or something like that. I think you need to define this for the generated printer/parser to work.

I defined the mnemonic field in another base td.

// Base class for this dialect's attributes. It reuses the `name` argument as
// the mnemonic, so the mnemonic is exactly that string: Xml_ShardingAttr passes
// "XSharding" and therefore gets the mnemonic "XSharding" (not "xsharding").
class Xp_Attr<Dialect dialect, string name, list<Trait> traits = []>
    : AttrDef<dialect, name, traits> {
    let mnemonic = name;
}

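So I think the mnemonic field is filled in automatically. With `let mnemonic = name;`, the mnemonic is exactly the `name` argument, “XSharding”, as `getMnemonic()` in the generated class shows, rather than “xsharding”.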
Yes, it works in xx-mlir-opt: the attribute is parsed and printed correctly there. Only in the Python environment does it raise that error.

How do you build the python bindings? There are known issues when downstream dialect bindings live in a separate shared object from the upstream dialect bindings. Basically, each shared object includes its own copy of MLIRContext / DialectRegistry and sees only a part of dialects. The common solutions to that is to either put everything in one shared object or link both shared objects dynamically to libMLIR.so.

Thanks for sharing. I build the upstream LLVM/MLIR separately from my downstream dialects and their Python bindings. I know CIRCT builds its Python bindings together with LLVM in a unified way.
But I don’t clearly understand the solution you proposed:

The common solutions to that is to either put everything in one shared object or link both shared objects dynamically to libMLIR.so.

Can you give a simple example? Thanks!!!

The idea is to do exactly the same as CIRCT.

IREE also seems to have a monolithic extension: https://github.com/openxla/iree/blob/main/llvm-external-projects/iree-dialects/python/CMakeLists.txt#L68-L105.

Thanks! I have tried to learn how CIRCT builds its Python bindings, but I couldn’t understand its approach well enough to port it into my own project. CIRCT is a big project, and I can’t pick out the essential building blocks from its CMakeLists.txt.

Would you please elaborate on that? Thanks a lot!

I have no idea how CIRCT is organized :wink: The IREE example is probably smaller.

Okay, Thank you so much

Just sending a pointer, this is where CIRCT sets up its monolithic extension, same as the IREE example above: https://github.com/llvm/circt/blob/d03a80ccde146b4186a97232d4ef8b2eb5a32348/lib/Bindings/Python/CMakeLists.txt#L142-L167

Thank you, Mike. I’ll try this out.