Add an expanded Load/Store Op in memref dialect

I am not sure how sharing historical context can be an issue? :slight_smile:
As I said just in the next sentence, this is a good time to extend and we’re welcoming ideas along this line of thought, thanks for sharing!

Would be great to unpack this thought a little more: would this partial_linearize be an abstraction needed for (1) added expressiveness, (2) simpler progressive optimization, or (3) something else?

Right, and conversely, it would be great to see examples of what the memref-level abstraction you have spelled out allows as well as the tradeoffs and complexities.
For now I have just inferred based on partial context.

I am particularly interested in what the transformations look like on the particular representation (i.e. taking a structured transformations-first look). Ideally this would be as much in isolation as possible from other concepts but there are tradeoffs everywhere.

For instance, atm I see memory abstractions as somewhat orthogonal. However, to avoid unpacking strides at the std level, taking a memref operand to carry the sizes and strides opaquely may be useful.
If we push along this direction, I can imagine we will want memref.subview, memref.transpose, linearize/delinearize and strided layouts to form a class of concepts that are transformed and canonicalized together and vectorize nicely.

It is also possible this is not good enough and we would actually be better off unpacking the strides and doing all this completely independently of memref types.

Another idea we have been thinking about is representing more symbolic information in the buffer type (i.e. relationships between sizes and strides). This has another set of tradeoffs and complexities.

Since you seem to already have functional implementations in your repo, could you please point as some examples that show the abstraction you are proposing and the concrete transformations it enables/how they look like in IR?

Thanks much!

I understand that for pattern_2 the multi-dimensional index is unavoidable. What I do not fully understand is why mixing multi-dimensional and linear indices in the same larger computation is a problem. One wants to avoid computing the same index multiple times (multiple delinearize operations), but if the index computation is a high-level operation like the compute_offset (or linearize) one can CSE these away.

I am not able to follow with just the abstract pattern; do you have a sample of IR to illustrate? Or even “pseudo IR” that shows why you need both at the same time?

An intermediate state during our “InputInlineFusion” pass:
[screenshot: intermediate IR during InputInlineFusion, before %4 is fused]

Here is an intermediate state during fusion. We are about to fuse %4 into the loop. After %4 is fused it should be something like this:
[screenshot: IR after %4 has been fused into the loop]

If we could have both the linear_index and the multidim_index, we could choose which one to use during the subsequent actions, without worrying about what the leaf nodes are.

In my understanding this is the same reason why xla::llvm_ir::IrArray::Index has both linear_ and dims_ as class members.

I agree this is not the only approach to get the same result. It might be regarded as a kind of syntactic sugar. But it also feels intuitive that users should have the right to express the offset directly in the IR if they already know it and can guarantee that it is correct.

Can you expand on why it is useful to have the two indexings when fusing %4, then? I feel like I’m still missing some pieces here.

If we are going to fuse an hlo.pad in the future, the multidim index is more helpful;
if we are going to fuse an hlo.reshape in the future, the linear_index is more helpful;
if we are going to fuse an elementwise add in the future, either index works and both should be kept (a rough sketch follows below).
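
To make it concrete, here is a rough pseudo-IR sketch of what I have in mind; the op name and form are only for illustration, not a final proposal:

%linear_index = ...  // from thread & block id
%multidim_index = delinearize(%linear_index, %shape)
// hypothetical expanded load: both operands address the same element
%value = expanded_load %buffer, linear = %linear_index, multidim = %multidim_index
// fusing an hlo.pad later only has to rewrite the multidim operand;
// fusing an hlo.reshape later only has to rewrite the linear operand;
// an elementwise op can leave both untouched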

I understand that for pattern_2 the multi-dimensional index is unavoidable. What I do not fully understand is why mixing multi-dimensional and linear indices in the same larger computation is a problem. One wants to avoid computing the same index multiple times (multiple delinearize operations), but if the index computation is a high-level operation like the compute_offset (or linearize ) one can CSE these away.

CSE is good enough for static-shape scenarios, but for dynamic shapes it places much higher requirements on the preceding passes. For example:
memref_1 : memref<?x3x2xf32>
memref_2 : memref<?x3x2xf32>
The two ‘?’ are in fact the same size, but CSE can only know that if the previous passes have done their job perfectly. That is the goal, but it is not always true today (the ‘shape constraint’ work is on the way, which is off-topic for this thread).
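
In pseudo IR (compute_offset here stands for the high-level offset computation mentioned above, just for illustration):

%d1 = dim %memref_1, %c0 : memref<?x3x2xf32>
%d2 = dim %memref_2, %c0 : memref<?x3x2xf32>
// %d1 and %d2 hold the same value at runtime, but they are distinct SSA
// values, so CSE alone cannot prove the two offsets below are equal
%offset_1 = compute_offset(%i, %j, %k)[%d1, %c3, %c2]
%offset_2 = compute_offset(%i, %j, %k)[%d2, %c3, %c2]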

If such problems can be fully solved, I believe compute_offset+offset_load can achieve the same effect as load_lin_idx proposed in this thread.

I’m going by memory here (CC @akuegel): in the past we’ve tried improving instcombine to cancel out delinearize / linearize pairs generated by XLA, and it was quite difficult even with static shapes. With dynamic shapes I think it would be basically impossible.

However, we don’t need to rely on LLVM here. We could introduce linearize and delinearize operations like Stephan mentioned and write a fairly straightforward MLIR pass to cancel these out when possible. This should let us have a simpler representation while getting the linear index based optimizations.

So the loop body for add → mul → reshape → sub (root) could initially be naively lowered to:

%linear_index = ... // from thread & block id

// Lowering for sub
%multidim_index_0 = delinearize(%linear_index, %shape0)

// sub recursively invokes generator for reshape

// reshape is from shape1 to shape0
%lin_index_0 = linearize(%multidim_index_0, %shape0)
%multidim_index_1 = delinearize(%lin_index_0, %shape1)

// reshape recursively invokes generator for mul

// mul recursively invokes generator for add

// add recursively invokes generator for param

%lin_index_1 = linearize(%multidim_index_1, %shape1)
%value = load_offset ..., %lin_index_1

// actually do the scalar compute for add, mul and sub

Now %lin_index_0 can be replaced with %linear_index easily (linearize(delinearize(val, shape), shape) => val). Similarly, %lin_index_1 can be replaced with %lin_index_0.
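
Once those replacements are made, the delinearize/linearize pairs become dead and can be DCE'd, so (assuming nothing else in the body uses the multi-dimensional indices) the snippet above boils down to roughly:

%linear_index = ... // from thread & block id
%value = load_offset ..., %linear_index
// actually do the scalar compute for add, mul and sub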

Thanks Sanjoy for the example; that makes the question quite clear.
I agree that introducing linearize/delinearize is a feasible solution to the issue.
(BTW, is this work already ongoing somewhere?)

But we still need a memref.offset_load/store; are we aligned on this?

BTW2, as you said:

in the past we’ve tried improving instcombine to cancel out delinearize / linearize pairs generated by XLA, and it was quite difficult even with static shapes.

Can you explain a little bit more about the difficulty?

Not that I know of, I proposed it because it seemed like a natural solution to the problem and I’ve used this successfully in the past (outside of compilers).

If you send a patch, I’ll be happy to review!

linearize/delinearize and memref.offset_load/store seem much more manageable to me than the variant of the store that optionally carries both. I’m supportive of adding these!


I have yet to see, however, how memref.offset_load/store is an improvement over reshape + load; otherwise it is just another abstraction to do the same thing?

Is the reshape always “free” (as in will just generate a cast)? Or is it a special case of “reshape to 1D”? I can’t say I’m sure I understand how to reshape memref<?x?xf32> to memref<?xf32> with our current memref.reshape?
Unless you’re thinking of using memref.reinterpret_cast somehow?

That would be linalg.reshape which is being refactored and split into memref.expand_reshape and memref.collapse_reshape as per this RFC.

Depending on where allocation occurs, there is also the possibility of alloc + memref.view + subview 2-D + subview 1-D. This may need to evolve depending on what alias analysis looks like for the OP’s transformations.


@linearhit another thing I am wondering: since you operate on buffers, what does your analysis look like to allow the fusions to occur (e.g. what about control flow and alias analysis to avoid operations that operate on the same buffer bypassing each other)?

Have you thought of (or tried) applying similar transformations in the tensor domain where SSA use-def chains give you many nice guarantees ?

An added bonus I would see in operating in the tensor domain is that you could have more control over the buffering scheme. The layout in memory could even potentially be chosen such that many of the linearizations/delinearizations are statically known to reduce to pointer increments.

Regarding the buffer-related concern: lhlo_fusion happens right after mhlo is transformed to lmhlo; at that point no buffer optimizations or control-flow lowering have been done yet. So we can guarantee that each buffer has only one writer in the control-flow region, which is substantially similar to SSA, and that is OK for now.
We will revisit fusion in the future in order to support ‘shape constraint’ and other features of the shape dialect, and I think it is quite possible that fusion will be moved back to HLO by then. There were some historical reasons for having a fusion pass on LHLO; sooner or later we will involve the shape dialect and reconsider it then.

Do you have any suggestions on which dialect to put linearize/delinearize into?
Is memref::IndexLinearizeOp & memref::IndexDelinearizeOp a good proposal?

memref::LinearizeIndexOp and memref::DelinearizeIndexOp sound good to me. But did you conclude on the offset_load/store ops? Like @nicolasvasilache mentioned, why do we need these when you can do reshape to 1-D + load? That looks natural and fits into the logical reshape abstractions that already exist. Also, when lowered, wouldn’t you get identical IR (ptr + offset ultimately)?

This reshape as I understood it is just a logical reshape of the shape. So it won’t by itself lead to memory traffic but only gets pulled into the access subscripts.
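
For the contiguous case, the sequence would look something like this (a sketch, using the memref.collapse_reshape name from the RFC mentioned above; the reassociation syntax is approximate):

// logical collapse of the two dimensions into one; no data movement
%flat = memref.collapse_reshape %buf [[0, 1]] : memref<?x?xf32> into memref<?xf32>
%v = memref.load %flat[%linear_index] : memref<?xf32>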

I’d expect to see a tensor equivalent at some point but memref::LinearizeIndexOp and memref::DelinearizeIndexOp SGTM too.

In the contiguous case (i.e. canonical strides), this is true and seems relatively easy.

Still, I’d expect the abstraction to work with strided memref with dynamic strides (whether we want to unpack the values at LLVM or before still TBD).

This is where it gets trickier: memref.expand_reshape and memref.collapse_reshape can only manipulate contiguous dimensions (i.e. there is no way to represent, say, a 4-D non-contiguous subarray as a 1-D strided memref).
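
For example (a sketch; the layout notation is schematic):

// a 3x2 subview of a 4x4 buffer: strides [4, 1], rows are not contiguous
%sub = memref.subview %buf[0, 0] [3, 2] [1, 1]
    : memref<4x4xf32> to memref<3x2xf32, offset: 0, strides: [4, 1]>
// collapsing dims [0, 1] would need one stride that covers both the in-row
// step (1) and the row-to-row jump (4); no 1-D strided memref can express
// that, so a collapse_reshape here has to be rejected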

For such cases, I expect memref.alloc + memref.view + memref.subview n-D + memref.subview 1-D + memref::LinearizeIndexOp + memref::DelinearizeIndexOp to do the job.

Still, I don’t expect @linearhit to have such cases yet, given the IR I have seen so far.

An alternative is to unpack strides as SSA values and manipulate them more directly with e.g. memref.stride and affine maps such as

affine_map<(i)[M, N, K] -> (i floordiv (N * K), (i mod (N * K)) floordiv K, i mod K)>

but I am not confident this complexity, once unpacked, will be easy to recover. It seems to me that representing it in a more controllable, structured form is the whole point of the exercise, so unpacking would defeat the purpose.


“Reshape to 1-D and then load/store” should be OK for my case. Would it be better to have explicit semantics, though? Not sure what others think, but I feel ‘reshape’ is somewhat counter-intuitive here. It’s definitely acceptable anyway.