Calling a generic function that accepts any tensor

I am writing a JIT compiler using MLIR. In it, I need to call (back) into a library. I would like some of these functions/callbacks to accept any dense, unstrided tensor (i.e., any shape and any element type). At the same time, my JIT-compiled function will return tensors that are possibly strided. I have solved how to cope with the latter, but I am stuck on passing arbitrary tensors to an external function from within MLIR. Even though I can encode the element type into the function signature, I have no good intuition for how to make this work generically, without providing functions for every possible combination of (reasonable) values for ‘rank’ and ‘type’; that would give me at least 100 variants.

What’s the best way to get this done?

We have a small runtime library here that calls into a vendor-optimized one (tpp-sandbox/standalone-rt at main · plaidml/tpp-sandbox · GitHub). Note that tensors are a high-level concept and have no memory associated with them, so you should pass memrefs. Also have a look at how Linalg lowers to function calls (llvm-project/LinalgToStandard.cpp at 00874c48ea4d291908517afaab50d1dcbfb016c3 · llvm/llvm-project · GitHub).

Thanks @chelini for your quick response. From what I see, tpp-sandbox uses a different function per element type. If possible, I would like to avoid that.

In any case, I am not sure I follow your suggestion. Probably I am missing something fundamental and/or basic.

I generate MLIR on the fly and then compile and run it. The code includes calls to a C function, so I generate something like:

func.func private @_myfunc(tensor<##x##>) -> i64
%5 = call @_myfunc(%2) : (tensor<?xindex>) -> i64
%6 = call @_myfunc(%3) : (tensor<?x?xi32>) -> i64

I do not know which function type to use for the prototype of _myfunc (##x## is just a placeholder for illustration), which means I also do not know how to define the actual _myfunc.

I believe you need to define your “Tensor Type” (think of a C struct defining the ABI for your “Tensor”) and then have your JIT compiler emit calls to your library that respect this ABI contract.
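To make the suggestion concrete, here is a minimal C sketch of such a contract. Everything in it is hypothetical (the `ElemType` tags, the `GenericTensor` struct and its field order are assumptions, not an existing MLIR or library API); the point is that one struct carrying an element-type tag and a rank covers every shape and element type, so the library needs one entry point instead of one per (rank, type) combination.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical element-type tags; the JIT and the library must agree on them. */
typedef enum { ELEM_I32 = 0, ELEM_I64 = 1, ELEM_F32 = 2, ELEM_F64 = 3 } ElemType;

/* Hypothetical ABI for a dense, row-major tensor of any rank and element type. */
typedef struct {
  ElemType elem_type;
  int64_t rank;
  void *data;           /* first element of a dense buffer */
  const int64_t *sizes; /* `rank` entries */
} GenericTensor;

static size_t elem_size(ElemType t) {
  switch (t) {
  case ELEM_I32:
  case ELEM_F32:
    return 4;
  case ELEM_I64:
  case ELEM_F64:
    return 8;
  }
  assert(0 && "unknown element type");
  return 0;
}

/* Number of bytes covered by a dense tensor (a rank-0 tensor has one element). */
static size_t tensor_nbytes(const GenericTensor *t) {
  assert(t != NULL && t->rank >= 0);
  size_t n = 1;
  for (int64_t i = 0; i < t->rank; ++i)
    n *= (size_t)t->sizes[i];
  return n * elem_size(t->elem_type);
}
```

The JIT side would then only need to materialize this struct (e.g. on the stack) and pass a pointer to it.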


I see, thanks @mehdi_amini

My hope was that I would not need to rewrite the memref calling-convention logic and could instead reuse what is already there. Since I could pass rank and type separately, it would be fine to use the LLVM-struct convention that the auto-generated C wrappers use. Is there a way to tell the existing passes to use this C-wrapper calling convention for certain (or even all) function calls instead of passing individual arguments (ptr, ptr, intptr, int[], int[])? Right now the C wrappers for my JIT function as well as for the prototypes are generated nicely, but the call ops stick to the standard calling convention.

You can probably reuse the memref convention for the buffer part of your tensor, but you also mentioned you want to pass the element type. You could do it by lowering your Tensor to two pieces: a custom type representing the element type and an unranked memref.
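A minimal C sketch of the receiving side of that idea, assuming the default (unpacked) convention in which an unranked memref is passed as two values, the rank and a pointer to a ranked descriptor, and assuming a 64-bit target where that descriptor starts with two pointers and an `int64_t` offset, followed by the sizes and then the strides. The element-type piece is reduced to a hypothetical `elem_bytes` argument here; a real custom type could carry a richer tag.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* A ranked descriptor is {allocated ptr, aligned ptr, int64 offset,
   int64 sizes[rank], int64 strides[rank]}; skip the first three fields
   (64-bit target assumed). */
static const int64_t *ranked_sizes(const void *desc) {
  return (const int64_t *)((const char *)desc + 2 * sizeof(void *) +
                           sizeof(int64_t));
}

/* Hypothetical library entry point: one signature for every rank.
   Returns the logical size in bytes; this equals the buffer size only
   for dense identity layouts. */
int64_t lib_logical_bytes(int64_t elem_bytes, int64_t rank,
                          const void *descriptor) {
  assert(rank >= 0 && elem_bytes > 0 && descriptor != NULL);
  const int64_t *sizes = ranked_sizes(descriptor);
  int64_t n = 1;
  for (int64_t i = 0; i < rank; ++i)
    n *= sizes[i];
  return n * elem_bytes;
}
```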

As for the calling convention (if I understood your question correctly), emitting the C interface is controlled by an attribute on the function itself (llvm.emit_c_interface).
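For reference, a sketch of what the C side looks like once the C interface is used. My understanding is that when a declaration such as `func.func private @_myfunc(memref<?xi64>) -> i64` carries `llvm.emit_c_interface`, convert-func-to-llvm gives it a body that packs the unpacked arguments into a descriptor and calls `_mlir_ciface__myfunc` with a pointer to it, so that is the symbol the library defines. The struct name `MemRef1D_i64` is made up for illustration. This is still one function per rank and element type, which is what the original question wants to avoid; it only shows the convention.

```c
#include <assert.h>
#include <stdint.h>

/* C mirror of the descriptor of memref<?xi64>:
   {allocated, aligned, offset, sizes[1], strides[1]}. */
typedef struct {
  int64_t *allocated;
  int64_t *aligned;
  int64_t offset;
  int64_t sizes[1];
  int64_t strides[1];
} MemRef1D_i64;

/* Receives the descriptor by pointer and sums the elements,
   honouring offset and stride. */
int64_t _mlir_ciface__myfunc(const MemRef1D_i64 *m) {
  assert(m != 0);
  int64_t sum = 0;
  for (int64_t i = 0; i < m->sizes[0]; ++i)
    sum += m->aligned[m->offset + i * m->strides[0]];
  return sum;
}
```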

Wrappers are only what their names indicate – wrappers added post-hoc to simplify interfacing with C. They were never intended as a proper calling convention so internally MLIR will not be using them. One specific reason is the aliasing information that can only be added to an argument as a whole.

Is there a strong reason why using individual args is problematic for your use case? They are just the unpacked content of the struct that the C wrappers have.
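To illustrate that the two forms carry the same information, here is a C sketch for a 2-D `f32` memref: the unpacked signature lists the descriptor fields in order (allocated pointer, aligned pointer, offset, then all sizes, then all strides), and the packed form simply forwards them. The function and struct names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Packed form, as a C wrapper sees a 2-D f32 memref. */
typedef struct {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t sizes[2];
  int64_t strides[2];
} MemRef2D_f32;

/* Unpacked form: the same fields as individual arguments, in the same
   order (all sizes first, then all strides). Computes the trace. */
float trace_unpacked(float *allocated, float *aligned, int64_t offset,
                     int64_t size0, int64_t size1,
                     int64_t stride0, int64_t stride1) {
  (void)allocated; /* only needed for deallocation */
  assert(aligned != 0);
  int64_t n = size0 < size1 ? size0 : size1;
  float t = 0.0f;
  for (int64_t i = 0; i < n; ++i)
    t += aligned[offset + i * stride0 + i * stride1];
  return t;
}

/* The packed form just forwards its fields. */
float trace_packed(const MemRef2D_f32 *m) {
  return trace_unpacked(m->allocated, m->aligned, m->offset,
                        m->sizes[0], m->sizes[1],
                        m->strides[0], m->strides[1]);
}
```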

Indeed, separating the arguments is not the issue; the element type is. I need to provide a function prototype (FuncOp), and the CallOps are then verified against it. I would not know which element type to use for the prototype, since the various passes fail if the element types do not match. Is it somehow possible to have multiple FuncOps with different signatures? Or is there an AnyType? The actual library function could then accept void* instead of a specific type, and I could handle it from there.

You can use !llvm.ptr (or !llvm.ptr<i8> if your version hasn’t switched to opaque pointers yet), this will be equivalent to void* after lowerings on both ends. This may look a bit frankensteinish, but this was already exercised by the sparse compiler tests.
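On the library side, such a pointer arrives as a plain `void *`. Below is a minimal sketch of a type-agnostic routine that needs only the raw pointer and a byte count; a 64-bit FNV-1a hash is used purely as an example, and the name `lib_fnv1a` is hypothetical. Computing the byte count correctly from the memref's sizes, strides and offset remains the caller's job.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Type-erased entry point: only bytes and a length, no element type. */
uint64_t lib_fnv1a(const void *data, int64_t nbytes) {
  assert(nbytes >= 0 && (data != NULL || nbytes == 0));
  const unsigned char *bytes = (const unsigned char *)data;
  uint64_t h = 0xcbf29ce484222325ULL; /* FNV-1a 64-bit offset basis */
  for (int64_t i = 0; i < nbytes; ++i) {
    h ^= bytes[i];
    h *= 0x100000001b3ULL; /* FNV-1a 64-bit prime */
  }
  return h;
}
```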


I might be missing something; I don’t think I follow.
The types in the function prototype must match the types of the CallOp. The CallOp will pass a memref/tensor, so I must provide/generate a FuncOp that also accepts a memref/tensor. The ptr in the final calling convention is not controlled by any of my passes; I use the standard passes.

In XLA runtime we pass arbitrary memrefs to C++ handlers by:

  1. Encoding them as a LLVM struct with a memory layout that we control: tensorflow/ at master · tensorflow/tensorflow · GitHub

  2. On the C++ run-time side we just reinterpret cast a pointer and convert it to a user-friendly C++ type: tensorflow/custom_call.h at master · tensorflow/tensorflow · GitHub
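A C analogue of these two steps, as a sketch only: the `EncodedMemRef` layout below is an assumption for illustration and is not the layout XLA actually uses. The compiler side would emit an LLVM struct with exactly this layout; the run-time side receives an opaque pointer, casts it, and builds a friendlier view.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical layout emitted by the compiler as an LLVM struct;
   both sides must agree on it. */
typedef struct {
  int32_t elem_type;
  int32_t rank;
  void *data;
  const int64_t *sizes;
} EncodedMemRef;

/* User-friendly view built on the run-time side. */
typedef struct {
  int32_t elem_type;
  void *data;
  int64_t rank;
  int64_t num_elements;
} TensorView;

/* Reinterpret the opaque pointer according to the agreed layout. */
TensorView decode_memref(const void *opaque) {
  const EncodedMemRef *e = (const EncodedMemRef *)opaque;
  assert(e != 0 && e->rank >= 0);
  TensorView v = {e->elem_type, e->data, e->rank, 1};
  for (int32_t i = 0; i < e->rank; ++i)
    v.num_elements *= e->sizes[i];
  return v;
}
```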

Well, you can’t get non-standard convention from a standard lowering pass. If you need more information than the existing convention provides, you will have to implement that yourself. I think you only need to reimplement convert-func-to-llvm so that it passes more information across the function boundary, and then reconstructs the regular memref descriptor that other operations expect. @ezhulenev’s example does something similar but probably needs to be wrapped into RewritePatterns or some other IR rewriting mechanism.

Thanks @ezhulenev, from a quick look at the code I get the impression that the MemRefDescriptor is a custom thing which relies on manual/explicit memory allocation. Is this a valid understanding?

Sure, I am aware of that.
I am looking for recommendations for getting this done quickly without reinventing the wheel. I just assumed something this simple-looking would be available off the shelf. Basically, all I need is to extract the raw pointer from the memref and cast it to some generic pointer (like void* or i8*). Requiring an extra pass feels like overkill.

The memref descriptor is defined by MLIR, here: llvm-project/MemRefBuilder.h at main · llvm/llvm-project · GitHub. It represents a (strided) memref in terms of LLVM dialect types; its exact form is described in LLVM IR Target - MLIR. The allocation and management of the data pointed to by the memref is an orthogonal concern: the data can be manually allocated/deallocated on the heap, allocated on the function’s stack, or allocated at some finer-level granularity through ops with the automatic allocation scope trait. The descriptor itself almost always resides in virtual registers. The only exception is returning an unranked descriptor, in which case it must go through manually deallocatable memory.
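To illustrate that layout, here is a C mirror of the descriptor for a 2-D `f64` memref with a strided layout, and the address computation it implies (the struct and function names are made up for the example). The descriptor is a plain value: building one needs no allocation of its own, however the underlying data was allocated.

```c
#include <assert.h>
#include <stdint.h>

/* C mirror of a 2-D f64 memref descriptor. */
typedef struct {
  double *allocated;  /* pointer returned by the allocator, used to free */
  double *aligned;    /* pointer the data is accessed through */
  int64_t offset;     /* in elements, relative to `aligned` */
  int64_t sizes[2];
  int64_t strides[2]; /* in elements */
} MemRef2D_f64;

/* Element (i, j) lives at aligned + offset + i*strides[0] + j*strides[1]. */
double load2d(const MemRef2D_f64 *m, int64_t i, int64_t j) {
  assert(i >= 0 && i < m->sizes[0] && j >= 0 && j < m->sizes[1]);
  return m->aligned[m->offset + i * m->strides[0] + j * m->strides[1]];
}
```

For example, the 2x2 window starting at row 1, column 1 of a 4x4 row-major buffer has offset 5, sizes {2, 2} and strides {4, 1}, while sharing the buffer's allocated and aligned pointers.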

This is not as simple as it might look if you need to design it for the full generality of MLIR. For example, the MLIR type system is open, including what can be used as memref or tensor element type, so we would need to design some runtime representation of an arbitrary type that can be sent to a library, accompanied by some glue code specific to each target library.

It is simpler in restricted use cases, e.g., when the library accepts only a fixed, predefined list of element types. The MLIResque way of connecting to such a library via LLVM would be to:

  1. define a !mydialect.typed_tensor type that conceptually corresponds to the collection of data at the LLVM level that you would like to have;

  2. define a mydialect.memref_to_typed_tensor operation and its lowering to the LLVM dialect that populates whatever LLVM-level data structure the library expects, given its operands and attributes/types;

  3. define a mydialect.librarycall operation that accepts !mydialect.typed_tensor along with regular types and lowers to a function call with the right signature.

Operations (2) and (3) can be combined into a single one, but that would likely be too complex. This will require a custom dialect and a custom partial lowering pass; both are cheap in MLIR. We lean towards having more and more lightweight passes for improved modularity.
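As a sketch of the library side of step (3) under the restricted-types assumption, here is a single C entry point that dispatches on an element tag. The `TypedTensor` layout and tag values are hypothetical stand-ins for whatever the lowering of `mydialect.memref_to_typed_tensor` would be defined to populate, and the buffer is assumed to be dense.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical tags: the fixed, predefined list the library supports. */
enum { TT_I32 = 0, TT_I64 = 1, TT_F32 = 2, TT_F64 = 3 };

/* Hypothetical LLVM-level data behind !mydialect.typed_tensor. */
typedef struct {
  int32_t elem_tag;
  void *data; /* dense buffer */
  int64_t num_elements;
} TypedTensor;

/* Load element i, whatever its type, as a double. */
static double load_as_double(const TypedTensor *t, int64_t i) {
  switch (t->elem_tag) {
  case TT_I32: return ((const int32_t *)t->data)[i];
  case TT_I64: return (double)((const int64_t *)t->data)[i];
  case TT_F32: return ((const float *)t->data)[i];
  case TT_F64: return ((const double *)t->data)[i];
  default: assert(0 && "unsupported element type"); return 0.0;
  }
}

/* One entry point for every supported element type. */
double lib_sum(const TypedTensor *t) {
  double s = 0.0;
  for (int64_t i = 0; i < t->num_elements; ++i)
    s += load_as_double(t, i);
  return s;
}
```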

I have to post this at least once per month: a memref is not (just) a pointer; it is dangerous to treat it as if it were, and this will backfire. With that disclaimer in mind, there is the recently added memref.extract_strided_metadata operation, which extracts the base buffer in the form of a 0-D memref. Adding a new downstream operation that converts that to a pointer, along with the corresponding lowering, is likely the cheapest quick-and-dirty option. Note that it still involves a custom dialect and a partial lowering pass.
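Why the offset matters on this route, sketched in C: the base buffer returned by memref.extract_strided_metadata starts at the allocation, not at the first element of the view, so a library receiving the base pointer must apply the offset and strides (both counted in elements) itself. The raw base pointer is assumed to come from the hypothetical downstream conversion op mentioned above, and `element_address` is a made-up helper.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Address of the element at `index` (rank entries) given the pieces
   that extract_strided_metadata yields: base, offset, strides.
   Assumes the resulting linear index is non-negative. */
void *element_address(void *base, int64_t offset, int64_t rank,
                      const int64_t *strides, const int64_t *index,
                      size_t elem_bytes) {
  assert(base != NULL && rank >= 0);
  int64_t linear = offset; /* in elements */
  for (int64_t r = 0; r < rank; ++r)
    linear += index[r] * strides[r];
  assert(linear >= 0);
  return (char *)base + (size_t)linear * elem_bytes;
}
```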

@ftynse Thanks so much for the detailed explanation. Even though I definitely do not grasp all the implications of a fully generic solution, I am aware that a memref is not just a pointer. Still, it seems useful to have a safe and clean way to convert a ‘compatible’ memref into something like (void*, #bytes). Many libraries out there accept such generic, untyped buffers, and it would be nice to be able to call them without too much hassle.
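A sketch of such a conversion in C, assuming the aligned pointer, offset, sizes and strides are available (for instance via memref.extract_strided_metadata plus a custom pointer-extraction op): it yields (void*, #bytes) only when the view is a dense row-major block and refuses otherwise. The check is deliberately conservative and rejects some layouts that are in fact contiguous, such as unusual strides on size-1 dimensions. The name `as_raw_buffer` is hypothetical.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* On success, sets *ptr to the first element and *nbytes to the size of
   the dense block and returns true; returns false for other layouts. */
bool as_raw_buffer(void *aligned, int64_t offset, int64_t rank,
                   const int64_t *sizes, const int64_t *strides,
                   size_t elem_bytes, void **ptr, size_t *nbytes) {
  assert(aligned != NULL && ptr != NULL && nbytes != NULL);
  int64_t expected = 1; /* stride a dense row-major layout would have */
  for (int64_t r = rank - 1; r >= 0; --r) {
    if (strides[r] != expected)
      return false;
    expected *= sizes[r];
  }
  /* `expected` now holds the total number of elements. */
  *ptr = (char *)aligned + (size_t)offset * elem_bytes;
  *nbytes = (size_t)expected * elem_bytes;
  return true;
}
```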

I also talked to @chelini. I will try the memref.extract_strided_metadata business. I’ll keep you posted.