Revisiting ownership and lifetime in the python bindings

Context:

Per discussion on the above PR and other observations, I think we need to revise our approach to ownership and validity in the python bindings. The pattern we have been following has been to use the py::keep_alive facility provided by pybind11 and avoid keeping explicit references back to the python-land hierarchy. This has resulted in a couple of problems:

  • It really only models direct parent-child relationships where the parent must always be an argument of the function (incl. this/self). This makes the API awkward and divergent from the C++ API and ownership model where, typically, objects are dependent on the MlirContext regardless of how they are created (i.e. whether the creation method takes an implicit MlirContext or can derive it through other means). To get around this, we introduce py::keep_alive on intermediate objects, which has the desired effect but is wasteful and imprecise, extending the lifetime of various temporaries and introducing more accounting lists and activities behind the scenes.
  • Without being able to track back from an MlirContext (which can always be derived) to a PyContext (the python-land wrapper), we are missing a place to perform other accounting, interning and the like. This makes it impossible to provide robust lifetime management for mutable objects, since there is no python-API side management point where we can add logic to, say, invalidate all mutable python wrappers after a bulk-mutation, or intern (Operation|Value)->python wrappers in a way that lets us invalidate wrappers in the presence of python-initiated erase calls.

I’m proposing that we make all objects in the Python API able to directly retain a reference to the owning PyContext (which wraps an MlirContext), eschewing the py::keep_alive mechanism. I’m further proposing that we introduce new base classes for immutable/uniqued context-owned objects and mutable client-unowned objects. The former have no restrictions on being accessed for their lifetime. For the latter, we would enforce a “generation check” at the PyContext level that will let us perform bulk invalidation when crossing boundaries where we can no longer be certain that pointers will not be left dangling (a good example is invoking the PassManager).

This also gives us a place to add more interning/accounting as needed with the goal of supporting arbitrary python mutations in a way that is safe/cannot produce dangling pointers, although elaborating on that is left to followups that introduce such mutations.

Here is some rough sample code of the class hierarchy:

// Holds a C++ PyContext and associated py::object, making
// it convenient to have an auto-releasing C++-side reference
// to a PyContext.
class PyContextRef {
public:
  PyContextRef(PyContext &referrent, py::object object)
    : referrent(referrent), object(std::move(object)) {
  }
  PyContext *operator->() { return &referrent; }
  PyContext &referrent;
private:
  py::object object;
};

// Bound context, the root of the ownership hierarchy.
// MlirContexts are interned so that there can only be one
// PyContext for an MlirContext, and it is always possible to
// track back from an arbitrary MlirContext to its associated
// PyContext.
class PyContext {
public:
  // Returns a PyContext for an associated MlirContext, creating if
  // necessary.
  static PyContextRef forContext(MlirContext context) {
    // ... details to avoid spurious copies/conversions ...
    py::handle existing = ... lookup ...;
    if (!existing) {
      // create
    }
    return PyContextRef(nativeRef, pyRef);
  } 
  ~PyContext() {
    liveContexts.erase(context.ptr);
    mlirContextDestroy(context);
  }
  
  // The current "mutation generation". Mutable objects managed by the
  // context can only be validly accessed when the generation has not
  // changed.
  int getScopedGeneration() { return scopedGeneration; }
  
  // Invalidates all scoped objects. Should be called for any bulk
  // mutation that leaves the IR in an unknown state where pointers
  // may be left dangling.
  void invalidate() {
    scopedGeneration += 1;
  }
  
  MlirContext context;
private:
  PyContext(MlirContext context, py::handle handle)
    : context(context) {
    liveContexts[context.ptr] = handle;
  }  
  int scopedGeneration = 0;
  
  // Interns the mapping of live MlirContext::ptr to PyContext instances,
  // preserving the relationship that an MlirContext maps to a single
  // PyContext wrapper. This could be replaced in the future with an
  // extension mechanism on the MlirContext for stashing user pointers.
  // Note that this holds a handle, which does not imply ownership.
  // Mappings will be removed when the context is destructed.
  static llvm::DenseMap<void *, py::handle> liveContexts;
};

// Base class for all objects that directly or indirectly depend on an
// MlirContext. The lifetime of the context will extend at least to the
// lifetime of these instances.
// Immutable objects that depend on a context extend this directly.
class BaseContextObject {
public:
  BaseContextObject(MlirContext context) : context(PyContext::forContext(context)) {}
  
  PyContext &getContext() { return context.referrent; }
  
private:
  PyContextRef context;
};

class PyAttribute : public BaseContextObject { ... };

// Base class for any "scoped" objects owned by some top level object
// associated with a context. This includes all mutable objects in the
// operation hierarchy whose state is indeterminate when a bulk IR mutation
// takes place. Such objects must guard all accesses with `checkValid()`
// which will ensure that the context has not had its managed, mutable
// object invalidated.
class ScopedContextObject : public BaseContextObject {
public:
  ScopedContextObject(MlirContext context) 
    : BaseContextObject(context), 
      currentGeneration(getContext().getScopedGeneration()) {}
  
  // Checks that the object is still valid to access.
  void checkValid() {
    if (currentGeneration != getContext().getScopedGeneration())
      throw py::raiseException(...);
  }
  
private:
  int currentGeneration;
};

class PyOperation : public ScopedContextObject { ... };
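
To make the intended behavior concrete, here is a minimal pure-Python model of the generation check as it might look from the user's side. It is only a sketch: `Context`, `ScopedObject`, `check_valid` and the op names are stand-ins invented for illustration, not the actual binding API.

```python
class Context:
    """Stand-in for PyContext: owns the mutation generation counter."""

    def __init__(self):
        self._scoped_generation = 0

    @property
    def scoped_generation(self):
        return self._scoped_generation

    def invalidate(self):
        # Would be called after any opaque bulk mutation (e.g. a pass pipeline).
        self._scoped_generation += 1


class ScopedObject:
    """Stand-in for ScopedContextObject: remembers its creation generation."""

    def __init__(self, context, name):
        self.context = context  # strong reference keeps the context alive
        self._name = name
        self._generation = context.scoped_generation

    def check_valid(self):
        if self._generation != self.context.scoped_generation:
            raise RuntimeError("object was invalidated by a bulk IR mutation")

    @property
    def name(self):
        self.check_valid()  # every accessor guards itself
        return self._name


ctx = Context()
op = ScopedObject(ctx, "std.addf")
first_read = op.name                   # valid: nothing has been invalidated yet
ctx.invalidate()                       # what a PassManager run would trigger
try:
    op.name
    stale_access_rejected = False
except RuntimeError:
    stale_access_rejected = True
fresh = ScopedObject(ctx, "std.mulf")  # created after the bump, so valid again
```

The check is a single integer comparison per access, and invalidation is O(1) regardless of how many wrappers are live, which is the main appeal over tracking each wrapper individually.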

Alternatives considered:

  • Using a C++-side shared_ptr/shared_from_this to wrap the MLIRContext. I took this approach with the IREE and npcomp iteration of these bindings, and it turns into a real mess, making some of the lifetime issues slightly better but still leaving corner cases. Also shared_ptr is… not great… and then you end up with double reference counting. Also, the magic that makes it work only applies if directly wrapping the C++ API, not the C API.
  • Using a context manager and thread-local tracking of activations, restricting visibility of mutable IR objects to within the with context: block. From an API standpoint, I would rather avoid this if possible since it makes the API hard to use. On the implementation side, it also tends to require even heavier accounting, creates implicit bindings to the current thread, etc. I was able to convince myself that the above approach gives us the tools we need and we don’t need to use the bigger hammer of a with context manager to scope IR access. (Note: we may still want a context manager for higher-level APIs like EDSCs, etc., but it gets more natural at that level.)

Any ideas/preferences? @ftynse @mehdi_amini

SGTM in general, it also sounds simpler as a model than what we currently have with keep_alive.

My main concern is avoiding wrong assumptions about validity. That is, we should either take extreme care and guarantee that a Python object is valid when it says so, or have a big warning everywhere that invalidation only happens for “big” transformations. As a specific example, consider an op with a region containing other ops. We clear the region by calling the Python equivalent of region.getBlocks().clear(), but we may still have references to ops from that region and values defined by those ops. Do we invalidate everything on every object erasure (can be a lot of generations)?

On a technical side, mlir::Region and mlir::Block do not have a pointer to the context themselves. Rather, they get the context from the parent operation if any. Detached regions and blocks don’t have access to the context at all; arguably they cannot be invalidated by passes, but they can be invalidated by Python-triggered mutations. Region and block creation and their relation to the context is something that actually bothers me in the C++ API, but I don’t have a concrete proposal to change that atm.

Regarding context managers, I kind of like the idea of using them to construct nested regions if we want an equivalent of OpBuilder or EDSC (that is being increasingly based on OpBuilder) one day.

Thanks. I think the ownership change is a strict improvement and I’ll go forward with it.

It would be nice to not have a completely disjoint ownership model on the python side for these just because the c++ side was trying to save a pointer (which is not relevant here). Rooting these on a python context helps and leaves options open for the future, imo.

Agreed - that is a nice way to expose that kind of higher level API. I’d just like to avoid the baggage for low level access to the IR if possible.

I’ve been thinking the same way and oscillating on this point. I still don’t think it is possible to have a full featured, usable binding of the IR that is completely safe from the python side – at least without some fairly onerous accounting that the C/C++ side has no concept of.

As far as mutations are concerned, it is fairly easy to be safe with respect to reading+adding IR objects. The problems start to arise with bulk transformations and fine-grained mutation/erasing. I would like a mechanism to make the former safe because it is so common (i.e. accumulate some IR, run a pass, do something else), and the mechanism here can work for that.

For the latter, I have been leaning towards requiring some further action before allowing unsafe, in-place modification of the IR (i.e. raise an error if the action has not been taken). We could have such an “allow unsafe” mechanism in a couple of ways:

  • ctx.allow_unsafe_mutations()
  • with ctx.allow_unsafe_mutations():
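
As a rough sketch of the second form, a scoped opt-in could look something like the following. Everything here is hypothetical: the `Context` and `Operation` classes, the `erase` method and the error message are placeholders, and only the `allow_unsafe_mutations` name comes from the options above.

```python
from contextlib import contextmanager


class Context:
    def __init__(self):
        self._unsafe_depth = 0  # > 0 while unsafe mutations are allowed

    @contextmanager
    def allow_unsafe_mutations(self):
        # A depth counter (rather than a bool) lets the blocks nest.
        self._unsafe_depth += 1
        try:
            yield self
        finally:
            self._unsafe_depth -= 1

    @property
    def unsafe_mutations_allowed(self):
        return self._unsafe_depth > 0


class Operation:
    def __init__(self, context):
        self.context = context
        self.erased = False

    def erase(self):
        # In-place mutation is refused unless the user has opted in.
        if not self.context.unsafe_mutations_allowed:
            raise RuntimeError(
                "erase() requires `with ctx.allow_unsafe_mutations():`")
        self.erased = True


ctx = Context()
op = Operation(ctx)
try:
    op.erase()
    rejected_outside_block = False
except RuntimeError:
    rejected_outside_block = True

with ctx.allow_unsafe_mutations():
    op.erase()  # allowed inside the opt-in block
```

The plain-call form (`ctx.allow_unsafe_mutations()` with no `with`) would presumably flip the same state for the remaining lifetime of the context; the scoped form keeps the unsafe window visible in the source.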

In other words, it is going to be extremely hard to get all of these lifetime issues right for advanced IR manipulation, and I’d rather focus on the things that can be made nice/safe and have such a big opt-in to enable the rest once we start adding them.

Sure, I was more warning that the bindings will have to keep track of context explicitly without trying to get it from MlirRegion / MlirBlock.

I have not yet convinced myself either way. Thinking about the issue, there do not seem to be that many “core” IR manipulations that can invalidate things. First, context-owned objects such as attributes and types are essentially immortal if they hold a back reference to the context. The objects that have their backing pointers invalidated are ops, regions, blocks and values. And there do not seem to be that many mutations that change the pointers. Erasing objects is a clear invalidation of the object and everything that it contains (nested regions etc. and defined values). Moving operations/blocks/regions around does not invalidate. One cannot add new results to an operation without recreating the operation, so it’s hard to invalidate a value. (It’s possible to erase block arguments, which are the other type of values, but we can track that.) With regions and blocks, it sounds like the only invalidating action is erasing them. It will get more problematic if we expose higher-level APIs on top of core IR, for example pattern rewriters. Then we won’t be able to catch all the invalidating mutations.
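
To illustrate the "erase invalidates everything it contains" rule, here is a toy model of the invalidation walk. The `Node` class and its `kind`/`children` fields are made up for this sketch and collapse ops, regions, blocks and values into a single type.

```python
class Node:
    """Hypothetical wrapper for an op, region, block or value."""

    def __init__(self, kind, children=()):
        self.kind = kind
        self.children = list(children)
        self.valid = True

    def erase(self):
        # Erasing invalidates this object and, recursively, everything in it.
        self.valid = False
        for child in self.children:
            child.erase()


result = Node("value")
inner_op = Node("op", [result])      # inner op defines one value
block = Node("block", [inner_op])
region = Node("region", [block])
outer_op = Node("op", [region])
sibling_op = Node("op")              # not nested under the erased region

region.erase()
```

After the walk, the region, its block, the nested op and its result are all marked invalid, while the enclosing op and the unrelated sibling are untouched. Moving a node would simply re-parent it without touching `valid`.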

Regarding big vs small mutations, how does one define “big”? Erasing a module does not sound as big as running a pass pipeline, but may ultimately invalidate more objects.

“Big” is probably the wrong word. “Opaque mutation”? We can’t make any guarantees about liveness, even if we had pointers back from every c++ object to every python object.

You may be right, and if so, the required mechanism for fine grained invalidation has to be python level tracking back of native pointer to python instance for mutable objects. This could be implemented as a map on the PyContext associating native pointer to live python object. We might be able to get away with just doing this for operations and values at the context level and maintaining local maps on the operation for each. We could add such tracking pretty incrementally given the small number of types involved, and the point that we have to make the decision is when adding one of the problematic mutations.
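
A minimal sketch of what such a map could look like, assuming weak references so that the map never keeps a wrapper alive on its own. Native pointers are modeled as plain integers, and `wrap_operation`/`notify_erased` are hypothetical names.

```python
import weakref


class Context:
    def __init__(self):
        # native pointer (modeled as an int) -> live wrapper, weakly held
        self._live_operations = weakref.WeakValueDictionary()

    def wrap_operation(self, ptr):
        # Intern: at most one live wrapper per native pointer.
        wrapper = self._live_operations.get(ptr)
        if wrapper is None:
            wrapper = Operation(self, ptr)
            self._live_operations[ptr] = wrapper
        return wrapper

    def notify_erased(self, ptr):
        # Fine-grained invalidation: only the wrapper for this pointer is hit.
        wrapper = self._live_operations.pop(ptr, None)
        if wrapper is not None:
            wrapper._valid = False


class Operation:
    def __init__(self, context, ptr):
        self.context = context  # strong reference back to the owning context
        self._ptr = ptr
        self._valid = True

    @property
    def ptr(self):
        if not self._valid:
            raise RuntimeError("operation has been erased")
        return self._ptr


ctx = Context()
a = ctx.wrap_operation(0x1000)
b = ctx.wrap_operation(0x1000)  # same native pointer, same wrapper
c = ctx.wrap_operation(0x2000)
ctx.notify_erased(0x1000)       # a (and b) become invalid; c is unaffected
```

Because every wrapper for a given pointer is the same Python object, marking it invalid once is enough, no matter how many references user code still holds.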

I still haven’t thought through each of those mutations fully to understand just how gross the accounting updates will be (i.e. there would seem to be some reachability barriers that would simplify invalidation walks, but I’d need to spend more time considering).

FWIW, I agree with everything here @stellaraccident; I’m also particularly interested in the with ctx.allow_unsafe_mutations(): “escape hatch”, which could provide a convenient way to skip the generation invalidation when “you know what you’re doing”.

First pass here: https://reviews.llvm.org/D87886

It was subtler than my first attempt with respect to interoping with the pybind ownership tracking, but I poked at it in some detail and believe this is correct.