Random number generation in MLIR?

Hello!

I’m just checking if there are any ongoing efforts or ideas to represent the generation of random numbers in MLIR (e.g., initializing a random number generator with a seed, generating a random number, etc.). I couldn’t find anything in the documentation, the MLIR codebase, or Discourse.

Generating a random number is an operation with specific side effects that might prevent many optimizations if they are not modeled properly. I’m particularly interested in finding a representation that allows optimizations to reason about these side effects. For example, under some circumstances, we should be able to fuse or vectorize two loops even if one of them contains an operation that generates random numbers. I’m sure there are subtleties to take into account, though. Anybody with expertise on this front? Any ideas?

Thanks!
Diego

I’m not aware of any ongoing efforts, but my experience is that it is best to split random number generation into two parts: one stateful operation that produces a “key” and updates some internal state, and a pure operation that reads the key and generates an array of random numbers (or does other loopy things). The latter can be vectorized, fused, etc. just fine, and the former is a trivial amount of computation, so fusing it isn’t super important. If you’re running on a GPU, the former could even be run on the host, avoiding the kernel launch overhead.
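For concreteness, a minimal sketch of that split in MLIR, using hypothetical op names (this is not an existing dialect):

// Stateful: reads and advances some internal RNG state to produce a key.
%key = "random.next_key"() : () -> i32
// Pure: the result is a deterministic function of %key, so it can be fused,
// vectorized, or rematerialized freely.
%vals = "random.uniform"(%key) : (i32) -> tensor<1024xf32>

Only the first op has side effects; everything downstream of the key is side-effect free.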

Agreed with Sanjoy; the stateful HLO rng is one of the things I was least happy to be made to implement. Having the side effect be an explicit state representation is much better. The downside is that you have to thread the state around all uses, but a global ordering is needed for fully reproducible runs anyway, so it doesn’t change what can be represented. And since it is treated as a data dependency, the ordering may be maintained more strictly than needed, but I can’t recall us hitting an issue with that when coming from a higher representation.

I did an exploration a couple of months ago into what random number generation could look like in MLIR. Overall, I came to a few similar conclusions:

If we have an explicit way to capture/store state, then supporting various RNG algorithms is relatively trivial. They could eventually be lowered straight to the linalg dialect and trivially fuse with other operations. This is particularly useful for operations like dropout, since the random values can be generated only at the point they are written (no intermediate buffers).
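To illustrate the generate-on-write point, here is a rough loop-level sketch (a linalg lowering would be analogous), assuming a counter-based generator where each element is a pure function of the key and its position; random.hash is a placeholder for e.g. a Philox/Threefry-style mixing function:

// No state is carried across iterations, so the loop is trivially parallel
// and fuses with its consumer (e.g. applying a dropout mask on write).
affine.for %i = 0 to 1024 {
  %rnd = "random.hash"(%key, %i) : (i32, index) -> i32
  // ... compare %rnd against the dropout threshold and write the masked
  // element directly, with no intermediate buffer of random values.
}

Whether the per-element hash is exposed as an op or only appears during lowering would be a design choice for such a dialect.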

Likely we would want to support generalized lowerings for the most common parallel RNGs. Some hardware accelerators support these same RNGs as native operations, so supporting a handful of the most common cases would allow consistent execution between accelerators and CPUs/GPUs.

Storing / retrieving state is the only significant issue I see. There is no equivalent behavior in existing MLIR dialects and adding support for it seems likely to be abused. Is there any plan for stateful variables within MLIR?

If threaded along, it is just a value and a stateful variable is not needed.

Ah okay, I think I understand what you mean. A lot of our work on the IREE side does use explicit variables, so we would need to handle that outside of our variable system. Feels feasible?

A good first step would be looking at how we would expect a random dialect to be laid out and how we would support lowering different random algorithms.

Thank you all for the replies! The suggestions make sense to me. Currently, I’m more interested in the RNG representation at the loop (Affine/SCF) or Standard level than at the graph level, although I think the considerations would be similar.

Regarding the stateful and pure ops and the threading of those, at loop/std level, are you referring to something like the example below? (I’m using iter_args but I guess the threading could also happen through memory).

func @random(%seed: i32) {
  %init_key = "random.init_rng"(%seed) : (i32) -> i32

  // The RNG state is threaded through the loop as an iter_args value.
  %i_key = affine.for %i = 0 to 10 iter_args(%curr_key = %init_key) -> (i32) {
    %rnd, %next_key = "random.get_random_number"(%curr_key) : (i32) -> (i32, i32)
    // %rnd use.
    affine.yield %next_key : i32
  }

  // The final key of the first loop seeds the second one.
  %j_key = affine.for %j = 0 to 50 iter_args(%curr_key = %i_key) -> (i32) {
    %rnd, %next_key = "random.get_random_number"(%curr_key) : (i32) -> (i32, i32)
    // %rnd use.
    affine.yield %next_key : i32
  }

  return
}

+1

Indeed, so both the value and the updated state are produced, and the state is carried along. That is also nice in case you want multiple of these in parallel (independent seeds), etc. Simple reasoning, clear dependencies. And yes, it could be done via memory, but I’d probably defer that to lowering; whether the state ends up in registers, memory, or special units can then all depend on the lowering.
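For the parallel case, one could imagine a hypothetical splitting op that derives independent streams from a single key, so that separate loops (or threads) don’t have to share one carried state:

// Hypothetical op: derive two independent, reproducible streams from one key.
%key_a, %key_b = "random.split"(%init_key) : (i32) -> (i32, i32)

Each derived key could then be threaded through its own loop exactly as in the example above.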

Agree, I like it! It models the RNG side effects as simple dependences, which is something we know how to reason about.

Agree. Sorry, I was just thinking out loud. Affine analysis doesn’t have support for loop-carried dependences through iter_args, so we would have to use memory if we wanted this to work out of the box with Affine. Indeed, using memory or SSA values is something that can be changed at any stage of the pipeline.
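For reference, a rough sketch of what the memory-based threading of the first loop might look like, assuming the state lives in a single-element memref so that affine dependence analysis can see the loop-carried dependence:

%state = memref.alloc() : memref<1xi32>
%init_key = "random.init_rng"(%seed) : (i32) -> i32
affine.store %init_key, %state[0] : memref<1xi32>

affine.for %i = 0 to 10 {
  %curr_key = affine.load %state[0] : memref<1xi32>
  %rnd, %next_key = "random.get_random_number"(%curr_key) : (i32) -> (i32, i32)
  // %rnd use.
  affine.store %next_key, %state[0] : memref<1xi32>
}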

Regarding the vector version of random.get_random_number, I think we should keep the key scalar, right? The input scalar key would be the one used to generate the random number for the first vector lane. The output scalar key would be the resulting state after generating the random number for the last lane. Does this make sense? Example:

%rnd0, %next_key0 = "random.get_random_number"(%curr_key) : (i32) -> (vector<4xi32>, i32)
%rnd1, %next_key1 = "random.get_random_number"(%next_key0) : (i32) -> (vector<4xi32>, i32)
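In loop form, and assuming the trip count divides evenly by the vector width, the vectorized loop would then thread the scalar key exactly as in the scalar example, just producing a vector of values per iteration:

%last_key = affine.for %i = 0 to 40 step 4 iter_args(%curr_key = %init_key) -> (i32) {
  %rnd, %next_key = "random.get_random_number"(%curr_key) : (i32) -> (vector<4xi32>, i32)
  // %rnd holds the random values for lanes %i .. %i + 3.
  affine.yield %next_key : i32
}

The key remains a loop-carried scalar, so the sequential dependence on the state is preserved while the value computation itself is vectorized.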

I think I may not have enough expertise to come up with a proposal that takes into account all the requirements for different abstraction levels, RNG scenarios and algorithms, and hardware with native support. However, I’m interested in getting this working, investigating what is needed to enable fusion and vectorization at the Affine level, and hopefully adding support for it. Would it make sense to implement a Random dialect with the basic ideas we have discussed in this thread, so that we can use it to build some expertise? Something simple and constrained to the basic example I showed, which we can extend in the future once we have more use cases. WDYT?

Thanks,
Diego