[RFC] Add an op for nullptr semantics in the memref dialect

Hi everyone, I’m an undergraduate student submitting an RFC to the community for the first time. I’m happy to be here to talk with you all. Feel free to comment below if you think there’s a problem. Thanks!

Overview

My previous research was on RISC-V AI chips. For matrix multiplication or convolution, some chips can incorporate the bias in hardware, for example ucb-bar/gemmini: Berkeley’s Spatial Array Generator (github.com). When issuing a matrix multiplication on such a chip, you can pass in a nullptr to mean “no bias”. When targeting the chip through MLIR, however, you have to pass in a bias even if you don’t have one, because the memref dialect cannot express a null memref, and this causes unnecessary performance loss. The underlying reason is that some chips rely on addresses for data transfer: this chip is driven through the RoCC interface, its instruction set extends the RISC-V R-type format, and the rs1 and rs2 fields may hold addresses.

memref.null

The name is provisional. This op produces a memref whose address is nullptr, so taking its address yields nullptr. Here is a concrete test:

// test.mlir
func.func @main() {
  %1 = memref.null : memref<4x4xf32>
  %2 = memref.extract_aligned_pointer_as_index %1 : memref<4x4xf32> -> index
  vector.print %2 : index
  return
}
// mlir-opt test.mlir -convert-func-to-llvm -finalize-memref-to-llvm -reconcile-unrealized-casts
module {
  llvm.func @main() {
    %0 = llvm.mlir.constant(4 : index) : i64
    %1 = llvm.mlir.constant(4 : index) : i64
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = llvm.mlir.constant(16 : index) : i64
    %4 = llvm.mlir.zero : !llvm.ptr
    %5 = llvm.getelementptr %4[16] : (!llvm.ptr) -> !llvm.ptr, f32
    %6 = llvm.ptrtoint %5 : !llvm.ptr to i64
    %7 = llvm.mlir.zero : !llvm.ptr
    %8 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %11 = llvm.mlir.constant(0 : index) : i64
    %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %13 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %14 = llvm.insertvalue %1, %13[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %15 = llvm.insertvalue %1, %14[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %16 = llvm.insertvalue %2, %15[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %17 = llvm.ptrtoint %7 : !llvm.ptr to i64
    llvm.return
  }
}
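To make the lowered IR above easier to follow, here is a rough C model of the rank-2 descriptor `!llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>`, mirroring its field order (allocated pointer, aligned pointer, offset, sizes, strides). The names `MemRef2DF32`, `memref_null_4x4` and `extract_aligned_pointer_as_index` are hypothetical; the field values are taken from the lowered IR (both pointers from `llvm.mlir.zero`, offset 0, sizes 4x4, strides 4 and 1). It is a sketch of the idea, not the actual lowering code.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical C mirror of the rank-2 f32 memref descriptor. */
typedef struct {
  float *allocated;   /* field 0: allocated pointer */
  float *aligned;     /* field 1: aligned pointer */
  int64_t offset;     /* field 2: offset in elements */
  int64_t sizes[2];   /* field 3: sizes */
  int64_t strides[2]; /* field 4: strides */
} MemRef2DF32;

/* Models the proposed memref.null : memref<4x4xf32>:
   null pointers, offset 0, static sizes, row-major strides. */
static MemRef2DF32 memref_null_4x4(void) {
  MemRef2DF32 m = {NULL, NULL, 0, {4, 4}, {4, 1}};
  return m;
}

/* Models memref.extract_aligned_pointer_as_index:
   ptrtoint of field 1 of the descriptor. */
static uintptr_t extract_aligned_pointer_as_index(const MemRef2DF32 *m) {
  return (uintptr_t)m->aligned;
}
```

On mainstream targets, converting a null pointer to `uintptr_t` gives 0, so `extract_aligned_pointer_as_index(&m)` on the result of `memref_null_4x4()` returns 0, which is what `vector.print` is expected to show in the test.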

Advantages

  1. It enriches the semantics of memref. memref provides memory semantics, but they are not complete. I think adding this to memref is a modest change, not a radical one.
  2. It extends the usefulness of memref.extract_aligned_pointer_as_index, which can then yield a null address.
  3. Many companies now use MLIR to build their AI compilers for a wide variety of chips, and MLIR is useful as compiler infrastructure. I think providing nullptr functionality at the memref level is broadly useful; Gemmini shouldn’t be a special case, and I shouldn’t be the only one to benefit.

It sounds generally reasonable to me to have a way to produce a null memref (it’s already possible by unrealized_conversion_casting an appropriately-typed null pointer if the lowering target is known to be, e.g., LLVM).

It would be nice to take this opportunity to better specify the semantics of null memrefs. I suppose accessing any element in such a memref should be treated as undefined behavior. Do we want the load to result in poison?

Separately, what should memref.null lower to? A memref descriptor built from a null pointer, but with what offset, sizes, and strides?

One thing that crossed my mind is to use it in DPS (destination-passing style) for the shape semantics. Currently, we emit a tensor.empty(), which becomes a memref.alloc(), but if we had a memref.null, we could emit that instead when the DPS operation does not access it.

If the operation does read from or write to it, bufferization would still emit an allocation, and replacing that allocation with null (by other means) would certainly yield poison.

Null pointers are useful at run time for comparing against null (e.g., for branching), but the shape is lost. If we want shape/type checks at run time too (e.g., for RTTI), we need to store that data somewhere alongside the null pointer. That would be a dynamic list for unranked memrefs, though, and could potentially be updated at run time for dynamic shapes.
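As a rough illustration of this point for the ranked case, the C sketch below models a descriptor whose pointer is null while its sizes and strides are still present, so runtime code can branch on null and still query the shape. The names are hypothetical, and having the guarded read return a default value is a choice made for the sketch only; it is not what anyone in this thread proposes for a load from a null memref (UB or poison has been suggested). The unranked case is not covered, since its rank and shape list are dynamic.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical C mirror of the rank-2 f32 memref descriptor. */
typedef struct {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t sizes[2];
  int64_t strides[2];
} MemRef2DF32;

/* Runtime null check: only looks at the aligned pointer. */
static int memref_is_null(const MemRef2DF32 *m) {
  return m->aligned == NULL;
}

/* Shape metadata remains readable even when the pointer is null. */
static int64_t memref_num_elements(const MemRef2DF32 *m) {
  return m->sizes[0] * m->sizes[1];
}

/* Branches on null instead of dereferencing it; otherwise reads the
   element at (i, j) using offset and strides. */
static float load_or_default(const MemRef2DF32 *m, int64_t i, int64_t j,
                             float dflt) {
  if (memref_is_null(m))
    return dflt;
  return m->aligned[m->offset + i * m->strides[0] + j * m->strides[1]];
}
```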

Why “null” and not “undef” or “poison”? There is a slight semantic difference here, and it’s not clear to me when a “null” memref is desirable. In particular, the example constructs a “null” memref with a shape (I understand the assumptions about the runtime representation of memref, but still, this gets into semantics, as @ftynse mentioned).
The motivation in the RFC isn’t immediately clear enough to me to really navigate these options.

Why wouldn’t load of nullptr (or poison) be just immediate UB?


Here are my thoughts. I can’t think of that much, and it may even be a bit superficial.
@mehdi_amini Here’s what I’ve personally encountered; it’s also my motivation.

matmul %1, %2, %3 {...} : memref, memref, memref  // input0, input1, bias

// lowering process
%addr0 = memref.extract_aligned_pointer_as_index %1
%addr1 = memref.extract_aligned_pointer_as_index %2
%addr2 = memref.extract_aligned_pointer_as_index %3

// set the addresses into registers.
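The step above can be sketched in C under stated assumptions: each aligned pointer is turned into an integer address, and pairs of addresses become the rs1/rs2 operands of RoCC commands. The struct `RoccOperands`, the function `lower_matmul_addrs`, and the rule that a bias address of 0 means “no bias” are all hypothetical here, not Gemmini’s actual encoding; the point is only that a null bias memref naturally produces address 0.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical C mirror of the rank-2 f32 memref descriptor. */
typedef struct {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t sizes[2];
  int64_t strides[2];
} MemRef2DF32;

/* Hypothetical source operands of one R-type RoCC command. */
typedef struct {
  uint64_t rs1;
  uint64_t rs2;
} RoccOperands;

/* Same role as memref.extract_aligned_pointer_as_index. */
static uint64_t aligned_addr(const MemRef2DF32 *m) {
  return (uint64_t)(uintptr_t)m->aligned;
}

/* Packs matmul(input0, input1, bias) addresses into two hypothetical
   commands. A null bias memref puts 0 into cmds[1].rs1, which the
   assumed hardware interprets as "no bias". */
static void lower_matmul_addrs(const MemRef2DF32 *in0, const MemRef2DF32 *in1,
                               const MemRef2DF32 *bias, RoccOperands cmds[2]) {
  cmds[0].rs1 = aligned_addr(in0);
  cmds[0].rs2 = aligned_addr(in1);
  cmds[1].rs1 = aligned_addr(bias);
  cmds[1].rs2 = 0; /* unused in this sketch */
}
```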
  1. @ftynse I chose the name memref.null only because I couldn’t think of a better one. memref.undef or memref.poison might be more appropriate.
  2. We probably don’t need to care about offset/sizes/strides; their values can be set however the user wants. I’m happy to take suggestions if anyone has good ones.
  3. It might be possible to analyze reads from it and report errors. Alternatively, the program could simply run as-is and leave the responsibility to the user, which would also be consistent with its semantics.
  4. Why add this op to the memref dialect?
    i. memref is about memory; adding this op to the memref dialect is a mild change, because it relates directly to memref and is not very aggressive.
    ii. There are other ways to get null pointers, but aren’t they a bit hacky?
    iii. As mentioned earlier, many AI chips now use MLIR, and I wouldn’t be the only beneficiary of such a feature. As @rengolin said, it might be useful elsewhere. Honestly, there were some things he said that I didn’t understand.

Finally, thank you for leaving comments!

At that level of abstraction, I vote for an optional<memref>. It is either a memref or not. Nullptr at that level seems weird to me.

Honestly, I thought about doing that. But there would be other issues as well: you might be passing in other SSA values that just aren’t shown in the example here. Given how the op is defined, I can envision such functionality living in the memref dialect.

Wouldn’t this imply a ternary state? Seems a bit more than just undef, poison or null.

Happy to clarify, here or on discord. Though, I was more wondering about possibilities than making hard statements. Quite possibly, what I said makes no sense.


What does LLVM IR have for pointers? Are there any lessons we can draw from that?

I don’t have a specific reason one way or another, just asking.

We can already construct a null memref along these lines: (1) create a null pointer; (2) shove it into the LLVM struct compatible with the descriptor; (3) unrealized_conversion_cast it to memref. This guarantees a null memref upon lowering. So I’m afraid the argument that it should be a non-nullable reference may already have been lost.

Is the code you intended to paste incomplete? I only see 4 lines of code and 2 comments. So I still don’t see the motivating example?

undef and poison are available in the UB dialect, aren’t these enough here?

I don’t necessarily agree: you don’t get a “null memref” upon lowering, you get a structure with a nullptr. The subtle difference is that we’re at a different level of abstraction and in a different type system. To me, the fact that a pointer can be null does not define what a “null memref” means. Using unrealized_conversion_cast or another low-level primitive to build a “memref” from a “null pointer” falls in the same category.

Sorry, I don’t actually know the semantics of poison. I went to read about it and wrote some examples, but it’s very different from memref.null as defined here.

Let me show the whole lowering process.

matmul %1, %2, %3, %4 {...} : memref, memref, memref, memref

 ==>
// lower matmul process
%addr0 = memref.extract_aligned_pointer_as_index %1
%addr1 = memref.extract_aligned_pointer_as_index %2
%addr2 = memref.extract_aligned_pointer_as_index %3
%addr3 = memref.extract_aligned_pointer_as_index %4
...
hardware.intr.config_addr_ab(%addr0, %addr1)
hardware.intr.config_addr_dc(%addr2, %addr3)
==> 
// lower to LLVM dialect
 %176 = llvm.extractvalue %22[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
 %177 = llvm.ptrtoint %176 : !llvm.ptr to i64
 %178 = llvm.extractvalue %39[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
 %179 = llvm.ptrtoint %178 : !llvm.ptr to i64
 %180 = llvm.extractvalue %56[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
 %181 = llvm.ptrtoint %180 : !llvm.ptr to i64
 %182 = llvm.extractvalue %73[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
// The values of %177 and %179 are what actually go into the rs1 and rs2 registers.
// Later this is translated to LLVM IR, where it is an intrinsic, and then
// to assembly. I think this example makes it clear.
 hardware.intr.config_addr_ab(%177, %179)
 hardware.intr.config_addr_dc(...,...)

You can see that null is inserted into the memref descriptor.

 %7 = llvm.mlir.zero : !llvm.ptr
 %8 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
 %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 

The following program prints 0. It might be worth thinking about what should happen when the user accesses the elements of this memref.

func.func @main() {
  %1 = memref_exp.null : memref<4x4xf32>
  %2 = memref.extract_aligned_pointer_as_index %1 : memref<4x4xf32> -> index
  // CHECK: 0
  vector.print %2 : index
  return
}

I must be blind, but in the snippet you show above this sentence I don’t see anything, actually.
You start with

matmul %1, %2, %3, %4 {...} : memref, memref, memref, memref

OK so we have a Matmul operating on memref, I don’t see where a null pointer would come from here?

Then you lower the Matmul to some HW dialect-specific IR?

// lower matmul process
%addr0 = memref.extract_aligned_pointer_as_index %1
%addr1 = memref.extract_aligned_pointer_as_index %2
%addr2 = memref.extract_aligned_pointer_as_index %3
%addr3 = memref.extract_aligned_pointer_as_index %4
...
hardware.intr.config_addr_ab(%addr0, %addr1)
hardware.intr.config_addr_dc(%addr2, %addr3)

Finally in LLVM IR:

// lower to LLVM dialect
 %176 = llvm.extractvalue %22[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
 %177 = llvm.ptrtoint %176 : !llvm.ptr to i64
 %178 = llvm.extractvalue %39[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
 %179 = llvm.ptrtoint %178 : !llvm.ptr to i64
 %180 = llvm.extractvalue %56[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
 %181 = llvm.ptrtoint %180 : !llvm.ptr to i64
 %182 = llvm.extractvalue %73[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
// The values of %177 and %179 are what actually go into the rs1 and rs2 registers.
// Later this is translated to LLVM IR, where it is an intrinsic, and then
// to assembly. I think this example makes it clear.
 hardware.intr.config_addr_ab(%177, %179)
 hardware.intr.config_addr_dc(...,...)

Now we’ve just lowered the previous IR to LLVM descriptor manipulation; at this level there is no memref left anyway.

There is still no “null memref” involved here?

Sorry, I didn’t show that much at first either. As you say, the key issue is actually the subtle difference that we’re at a different level of abstraction and in a different type system: is such an abstraction needed in the memref dialect?

Honestly, I need it here, and I think the community might need it too. There are surely other ways to implement it, but they might be a bit of a hack for the project as a whole.

Sorry, I can’t make promises.

Sorry, I’m only now seeing your latest reply.
It’s true that memref.null isn’t used here, because this is how it was done before: the bias was handled by manually allocating a block of memory in the pass and filling it with 0 using linalg.fill.
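A rough C model of what this workaround costs, compared with passing a null bias, is sketched below. It assumes, hypothetically, that the accelerator reads every bias element when given a non-null bias address and skips the bias step entirely when given null. `calloc` stands in for memref.alloc followed by linalg.fill with 0, and the returned element count is only a proxy for memory traffic; `zero_bias_buffer` and `apply_bias` are illustrative names, not part of any real API.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

/* Workaround: materialize an all-zero m x n bias buffer.
   The caller must free() the result. */
static float *zero_bias_buffer(int64_t m, int64_t n) {
  return calloc((size_t)(m * n), sizeof(float));
}

/* Hypothetical bias step: a null bias is skipped; otherwise every
   element is read and added to c. Returns the number of bias
   elements read. */
static int64_t apply_bias(float *c, const float *bias, int64_t m, int64_t n) {
  if (bias == NULL)
    return 0;
  for (int64_t i = 0; i < m * n; ++i)
    c[i] += bias[i];
  return m * n;
}
```

With a zero-filled buffer the result is numerically the same as skipping the bias, but the buffer still has to be allocated, filled, and read, which is the overhead a null memref would avoid.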