Add an expanded Load/Store Op in memref dialect

Currently, the Load/Store ops in the std dialect only accept multidim indices, and the linear index is computed during the std->llvm lowering. The unnecessary linear_index → multidim_index conversion (and vice versa) can be quite expensive in many scenarios and causes severe performance issues.

For example, consider a fusion kernel with this pattern on a GPU device:
mhlo.Add → mhlo.Reshape → mhlo.Add
The multidim_index is not actually needed at any point during the computation,
yet we are forced into a series of transformations like linear_index → multidim_index → linear_index due to the restriction of the Load/Store semantics in the std dialect.
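To make the cost concrete, here is a Python sketch (mine, not from the thread) of the index arithmetic that this round trip implies; a contiguous row-major buffer is assumed and all function names are illustrative:

```python
def delinearize(linear, shape):
    """linear index -> multidim index (row-major); one div + one rem per dim."""
    idx = []
    for extent in reversed(shape):
        idx.append(linear % extent)
        linear //= extent
    return list(reversed(idx))

def linearize(idx, shape):
    """multidim index -> linear index; one mul + one add per dim."""
    out = 0
    for i, extent in zip(idx, shape):
        out = out * extent + i
    return out

# The fused kernel only ever needs the linear index, yet the current
# std Load/Store semantics force the full round trip for every element:
lin = 7
assert linearize(delinearize(lin, [4, 4]), [4, 4]) == lin
```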

My current proposal is to expand the definition of the Load/Store ops: the linear_index and multidim_index are both optional, but at least one must be provided (except for scalar buffer operations).
The user can choose to use linear_index or multidim_index during a lowering or optimization pass.

let description = [{
  A variant of the "load" op with an optional linear index or multidim
  index. Either the linear or the multidim index may be empty, but at
  least one of them must be provided, unless it is a load from a scalar
  tensor.

    %3 = load_linidx %0[][%1, %1] : memref<4x4xi32>
    %3 = load_linidx %0[%2][%1, %1] : memref<4x4xi32>
    %3 = load_linidx %0[%2][] : memref<4x4xi32>
}];

let description = [{
  A variant of the "store" op with an optional linear index or multidim
  index. Either the linear or the multidim index may be empty, but at
  least one of them must be provided, unless it is a store to a scalar
  tensor.

    store_linidx %v, %A[][%i, %j] : memref<4x128xf32>
    store_linidx %v, %A[%m][%i, %j] : memref<4x128xf32>
    store_linidx %v, %A[%m][] : memref<4x128xf32>
}];

These can achieve a similar codegen result to the trick xla::IrArray::Index uses,
and will reduce the amount of index calculation.
In our observations, this helps the performance of DISC considerably.

Something unclear to me here is how you define this linear index. How much does this assume that the memref is contiguous? There seem to be assumptions made here that don't generalize to memref.
I may also not be following what you're trying to solve here, but the solution looks more like a workaround right now than a proper design of the layers involved.

Finally, if what you want is just a pure 1D linear offset, can't you just transform the memrefs by reshaping them into a 1D memref?

For the question "How much does this assume that the memref is contiguous?":
the user who generates the linear index must take care that the memref is contiguous.

The substantive problem is that, for the multidim_index → linear_index computation, the user has no choice but to leave it to the std->llvm pass.

Reshaping the memref into 1D is a reasonable alternative, but not enough for me. During the lowering, I need semantics that can represent "I have the linear_index, and I also have the multidim_index; which one to use is up to the subsequent lowering actions." This trick saves redundant mul+add/div+rem computations, which cannot easily be recovered by lower-level algebraic simplification passes.

This is a trick already used in XLA. But in MLIR, it seems to me that the index calculation cannot be fully controlled from hlo → std. That is the problem I'm looking for a better solution to. The memref.load semantics should support the case where the user already has the linear_index, or both.

I have thought about optimizations along these lines as well, and considered lowering load and store into two operations instead of adding a linear index argument:

  1. compute_offset, which takes an n-dimensional index and a (potentially static) shape and computes the offset into the linear backing store of the memref.
  2. offset_load, which performs the load.

Then you can replace the compute_offset with a linearized offset that you already have at hand, like the thread/block id in the GPU case. This is along the lines of what is described in [1] and [2].
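A rough Python model of that two-op split (the names compute_offset/offset_load follow the proposal above; the stride-based signature is my assumption for how a non-contiguous layout could be handled):

```python
def compute_offset(index, strides):
    """Fold an n-D index and per-dimension strides into a linear offset."""
    return sum(i * s for i, s in zip(index, strides))

def offset_load(backing_store, offset):
    """Load directly from the linear backing store of a memref."""
    return backing_store[offset]

# A 4x128 row-major memref has strides [128, 1]. When a linearized
# offset (e.g. derived from thread/block ids) is already at hand, the
# compute_offset step can simply be replaced by that value.
store = list(range(4 * 128))
assert offset_load(store, compute_offset([2, 5], [128, 1])) == 2 * 128 + 5
```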

But as @mehdi_amini said, this gets more complicated with non-standard layouts. So maybe the compute_offset also needs a map and some modelling of strides. That would mean that it only becomes meaningful at the memref level (on tensors, we have no layout). So one can only do this split once the program is on memrefs, at which point we might already have lost the information that some index corresponds to a linearized form.

So if you need it on tensors, my suggestion would be to keep it in a specialized dialect. Then you have tight control over the lowering to memref and can ensure that the layout is valid. When it gets lowered to memrefs, you can split it into the two operations as described above.
Having a memref.offset_load in the memref dialect to give direct access to the backing store of a memref seems a reasonable abstraction to me.


Thanks for the attached materials.

I don’t need it on tensors, only for memrefs.
But my question is: can we have an op definition that can potentially carry both the multidim_index and the linear_index (offset)? In the end we only need one of them, but such a representation helps during optimization and simplification by explicitly having both in the IR.

If you only need it on memrefs, then this should live in the memref dialect, which makes the question of layouts simpler.

I’d still prefer to explicitly split index computations off using a separate op, but mainly because I would want that representation for other optimizations anyway. That op could take both indices, though. Or, if that is too special a case, the index computation op could live in a different dialect.

Do you have an example for why you need to keep around both, linear and n-dimensional index? Mostly out of curiosity.

An example:

pattern_1: add → mul → reshape → sub (root)

pattern_2: add → pad → mul → reshape → concatenate (root)

For fusion codegen, we generally lower the nodes from the root to the leaf nodes (substantially the same as XLA; refer to the attached figure).
For pattern_1, after the ‘reshape’ is lowered, we could have both the linear_index and the multidim_index, but only the linear_index will be used, since no other leaf node needs the multidim_index afterwards. In this case, keeping the linear_index for ‘reshape’ benefits performance.
For pattern_2, after the ‘reshape’ is lowered, we could again have both, but the multidim_index will be needed later due to the characteristics of PadOp. In this case, keeping the multidim_index for ‘reshape’ benefits performance.
So, at the time the ‘reshape’ is lowered, it helps if both index representations are kept.
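A small Python sketch of why each consumer prefers a different index form (the pad handling is heavily simplified: interior padding is ignored, and all names are illustrative):

```python
def fuse_reshape(linear_index):
    """For a reshape of a contiguous buffer, the linear index passes
    through unchanged: no div/rem or mul/add is needed at all."""
    return linear_index

def fuse_pad(multidim_index, low_padding, operand_shape):
    """A pad consumer must inspect each dimension of the multidim index
    to decide whether the element comes from the operand or from the
    padding region (interior padding ignored for brevity)."""
    operand_index = [i - lo for i, lo in zip(multidim_index, low_padding)]
    if all(0 <= i < d for i, d in zip(operand_index, operand_shape)):
        return operand_index
    return None  # None means: use the padding value

# reshape: keep the linear index as-is
assert fuse_reshape(42) == 42
# pad: [0, 3] falls in the low-padding region of dim 0
assert fuse_pad([0, 3], [1, 0], [4, 4]) is None
assert fuse_pad([2, 3], [1, 0], [4, 4]) == [1, 3]
```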

Hi @linearhit

Thanks for posting about this topic, I’d love to see more details/analyses about current behavior and limitations.

Taking your examples:

%3 = load_linidx %0[%2][%1, %1] : memref<4x4xi32>

IIUC, your proposal wants to introduce a double representation for indexing, with the benefit of carrying the information that "linear index %2" is statically known to be the same as "multi-dim index [%1, %1] applied to memref<4x4xi32>". In principle this is interesting, but it also raises many, many questions that I don’t want to unpack on you too early.

This is by design, to allow separation of concerns, in a first approximation.
The idea was essentially to avoid unpacking the extra complexity in the std dialect and above.
The extra complexity here would come from handling strides in the general non-contiguous case.
One of the goals was to avoid adding a memref.stride operation prematurely and instead use semantically charged ops that compose (load/store indexing + memref type + subview).
This is a simplicity vs flexibility tradeoff: in the early days the focus was on getting anything working by restricting to a well-defined subclass of problems.

In principle, this sounds like a good place in space-time to extend the support, esp. if you bring concrete data to back this up.

Today the linear <-> multidim relationship is captured by load/store indexing + the memref type + subview operations. It is definitely possible to unpack this more at levels above the MLIR LLVM dialect, but it is a complexity tradeoff.

One expectation was that these types of DCE / CSE / strength-reduction / PRE optimizations would be done by LLVM.
I would be interested in seeing a more detailed study with examples that illustrate the problem in isolation from the rest (for example, here is an example of the level of detail I’d welcome to better understand the problem).

I am interested in seeing the fundamental information that your proposed abstraction allows reasoning on that is lost or hard to recover from std ops. For instance: is this a case of missing CSE / SR / PRE at the MLIR LLVM dialect level, or is there something more profound that you are capturing?

Thanks for your example, I think I’d like to dig a little deeper here. Based on my prior experience with similar problems, I have the impression that what you really want is a higher-order representation of linearization/delinearization in some basis?

%0:5 = arith.delinearize %idx basis(%a, %b, %c, %d, %e) : index ...
%1 = arith.linearize (%0#0, %0#1, %0#2, %0#3, %0#4) basis(%a, %b, %c, %d, %e) : index ...

The invariants would be that %idx and %0#0 … %0#4 are within the expected bounds; otherwise the behavior would be undefined.

This type of abstraction is the basis of “unranked” codegen and of simple library implementations of simple ops that work through reshapes. One essentially performs:

while (linear_index < lim) {
  idx1 = delinearize_index(linear_index, basis1)
  idx2 = delinearize_index(linear_index, basis2)
  idx3 = delinearize_index(linear_index, basis3)
  out[idx3] = f(in1[idx1], in2[idx2])
}

For instance, Torch7 (and even earlier frameworks) uses this (pardon the C++ pasting…).

I can see how this type of abstraction would make it easier to optimize subcomputations across multiple linearizations/delinearizations, especially if you perform some unrolling and some dimensions vary only a little. This could even be vectorized easily and turned into gather/scatter where appropriate.

Is this close enough to the underlying thinking driving this?

I can understand the underlying considerations, but I still feel this is some kind of ‘issue’ for me. I feel that it is much easier to control the index-calculation lowering in XLA, but much more complex in MLIR. MLIR should provide expressiveness that is flexible enough, but it should also provide simplicity when the user doesn’t need all that flexibility. The very expressive MemRef model is, in my understanding, an example of this.

Yes, an explicit linearize/delinearize abstraction is indeed a viable solution. We may need a few more op definitions, such as partial_linearize, but that’s definitely doable. I will need some more time to weigh the additional complexity against the benefit it brings. Thanks for the suggestion.

I am not sure how sharing historical context can be an issue? :slight_smile:
As I said just in the next sentence, this is a good time to extend and we’re welcoming ideas along this line of thought, thanks for sharing!

It would be great to unpack this thought a little more: would this partial_linearize be an abstraction needed for (1) added expressiveness, (2) simpler progressive optimization, or (3) something else?

Right, and conversely, it would be great to see examples of what the memref-level abstraction you have spelled out allows as well as the tradeoffs and complexities.
For now I have just inferred based on partial context.

I am particularly interested in what the transformations look like on the particular representation (i.e. taking a structured transformations-first look). Ideally this would be as much in isolation as possible from other concepts but there are tradeoffs everywhere.

For instance, atm I see memory abstractions as somewhat orthogonal. However, to avoid unpacking strides at the std level, taking a memref operand to carry the sizes and strides opaquely may be useful.
If we push along this direction, I can imagine we will want memref.subview, memref.transpose, linearize/delinearize, and strided layouts to form a class of concepts that are transformed and canonicalized together and vectorize nicely.

It is also possible this is not good enough and we would actually be better off unpacking the strides and do all this completely independently of memref types.

Another idea we have been thinking about is representing more symbolic information in the buffer type (i.e. relationships between sizes and strides). This has another set of tradeoffs and complexities.

Since you seem to already have a functional implementation in your repo, could you please point us at some examples that show the abstraction you are proposing, the concrete transformations it enables, and how they look in IR?

Thanks much!

I understand that for pattern_2 the multi-dimensional index is unavoidable. What I do not fully understand is why mixing multi-dimensional and linear indices in the same larger computation is a problem. One wants to avoid computing the same index multiple times (multiple delinearize operations), but if the index computation is a high-level operation like the compute_offset (or linearize) one can CSE these away.

I am not able to follow with just the abstract pattern, do you have sample of IR to illustrate? Or even “pseudo IR” that shows why you need both at the same time?

An intermediate state during our “InputInlineFusion” pass:
[Screenshot 2021-05-21, 9:17 AM: IR before fusing %4]

Here is an intermediate state during fusion. We are about to fuse %4 into the loop. After %4 is fused it should be something like this:
[Screenshot 2021-05-21, 9:30 AM: IR after fusing %4]

If we could keep both the linear_index and the multidim_index, we could choose which one to use during subsequent actions without worrying about what the leaf nodes are.

In my understanding this is the same reason why xla::llvm_ir::IrArray::Index has both linear_ and dims_ as class members.

I agree this is not the only approach to get this result; it might be regarded as a kind of syntactic sugar. But I also feel it is intuitive that users should have the right to express in the IR that they already know the offset and can guarantee that it is correct.

Can you expand on why it is useful to have both indexings when fusing %4, then? I feel like I’m still missing some pieces here.

If we are going to fuse an hlo.pad later, the multidim index is more helpful;
if we are going to fuse an hlo.reshape later, the linear_index is more helpful;
if we are going to fuse an elementwise add later, either is acceptable, and both should be kept.

I understand that for pattern_2 the multi-dimensional index is unavoidable. What I do not fully understand is why mixing multi-dimensional and linear indices in the same larger computation is a problem. One wants to avoid computing the same index multiple times (multiple delinearize operations), but if the index computation is a high-level operation like the compute_offset (or linearize ) one can CSE these away.

CSE is good enough for static-shape scenarios, but for dynamic-shape scenarios it places higher requirements on the work of the preceding passes. For example:
memref_1 <?x3x2xf32>
memref_2 <?x3x2xf32>
The two ‘?’s are in fact the same, but CSE can only know this if the preceding passes have done their work excellently. That is the goal, but it is not always true today. (The ‘shape constraint’ work is on the way, but that is off-topic for this thread.)

If such problems can be fully solved, I believe compute_offset + offset_load can achieve the same effect as the load_lin_idx proposed in this thread.

I’m going by memory here (CC @akuegel): in the past we’ve tried improving instcombine to cancel out delinearize / linearize pairs generated by XLA, and it was quite difficult even with static shapes. With dynamic shapes I think it would be basically impossible.

However, we don’t need to rely on LLVM here. We could introduce linearize and delinearize operations as Stephan mentioned and write a fairly straightforward MLIR pass to cancel these out when possible. This should let us have a simpler representation while still getting the linear-index-based optimizations.

So the loop body for add → mul → reshape → sub (root) could initially be naively lowered to:

%linear_index = ... // from thread & block id

// Lowering for sub
%multidim_index_0 = delinearize(%linear_index, %shape0)

// sub recursively invokes generator for reshape

// reshape is from shape1 to shape0
%lin_index_0 = linearize(%multidim_index_0, %shape0)
%multidim_index_1 = delinearize(%lin_index_0, %shape1)

// reshape recursively invokes generator for mul

// mul recursively invokes generator for add

// add recursively invokes generator for param

%lin_index_1 = linearize(%multidim_index_1, %shape1)
%value = load_offset ..., %lin_index_1

// actually do the scalar compute for add, mul and sub

Now %lin_index_0 can be replaced with %linear_index easily (linearize(delinearize(val, shape), shape) => val). Similarly, %lin_index_1 can be replaced with %lin_index_0.
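The cancellation described above can be modeled as a tiny recursive peephole on expression trees (a Python sketch of mine; a real pass would of course match MLIR ops rather than tuples):

```python
def simplify(expr):
    """Cancel linearize(delinearize(v, shape), shape) -> v.
    Index expressions are modeled as nested tuples (op, operand, shape);
    leaves are plain strings such as '%linear_index'."""
    if not isinstance(expr, tuple):
        return expr
    op, operand, shape = expr
    operand = simplify(operand)  # simplify bottom-up first
    if (op == 'linearize' and isinstance(operand, tuple)
            and operand[0] == 'delinearize' and operand[2] == shape):
        return operand[1]  # same shape on both sides: the round trip is a no-op
    return (op, operand, shape)

# %lin_index_0 = linearize(delinearize(%linear_index, shape0), shape0)
lin_index_0 = ('linearize', ('delinearize', '%linear_index', 'shape0'), 'shape0')
assert simplify(lin_index_0) == '%linear_index'

# %lin_index_1 cancels too, once %lin_index_0 has been simplified
lin_index_1 = ('linearize', ('delinearize', lin_index_0, 'shape1'), 'shape1')
assert simplify(lin_index_1) == '%linear_index'
```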

Thanks, Sanjoy, for the example; that makes the question quite clear.
I agree that introducing linearize/delinearize is a feasible solution to the issue.
(BTW, is this work already ongoing somewhere?)

But we still need a memref.offset_load/store; are we aligned on this?

BTW2, as you said:

in the past we’ve tried improving instcombine to cancel out delinearize / linearize pairs generated by XLA, and it was quite difficult even with static shapes.

Can you explain a little bit more about the difficulty?

Not that I know of. I proposed it because it seemed like a natural solution to the problem, and I’ve used it successfully in the past (outside of compilers).

If you send a patch, I’ll be happy to review!

linearize/delinearize plus memref.offset_load/store seem much more manageable to me than a load/store variant that optionally carries both indices. I’m supportive of adding these!
