As I find myself hacking out a ReLU recognizer in the sparsifier to get some PyTorch models working for sparse tensor inputs, I started to wonder why we don’t have proper min/max/abs (and thus ReLU) recognition as an MLIR pass? (Or at least I could not find one; we do have the opposite direction in arith-expand, which maps these higher-level ops back to lower-level ops to enable lowering to LLVM.)
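For context, this is roughly what that lowering direction looks like (a minimal sketch based on my reading of the expansion patterns; showing the signed integer case):

```mlir
// Before arith-expand: the higher-level op.
%m = arith.maxsi %a, %b : i32

// After arith-expand (roughly): expanded into compare + select
// so it can be lowered further towards LLVM.
%c = arith.cmpi sgt, %a, %b : i32
%m = arith.select %c, %a, %b : i32
```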
Pretty much every compiler I have ever worked on benefited greatly from a unified analyzer that extracts min/max/abs (and others) from conditional constructs. The rewriting is pretty straightforward, with the caveat that for floating-point, care has to be taken to preserve the semantics around special cases such as +0, -0, NaN, and Inf.
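To make the proposed direction concrete, here is a minimal sketch of what such a recognizer could match for the ReLU case (assuming the usual compare-and-select spelling; the exact cmpf predicate and max-op flavor would have to be chosen carefully to honor the -0.0/NaN semantics mentioned above):

```mlir
// A ReLU written out as compare + select (e.g. what a frontend might emit).
%zero = arith.constant 0.0 : f32
%cmp  = arith.cmpf ugt, %x, %zero : f32
%sel  = arith.select %cmp, %x, %zero : f32

// What a min/max recognizer could rewrite this into. Note that
// arith.maximumf and arith.maxnumf differ in NaN handling, so the
// chosen op (and whether the rewrite is legal at all) depends on the
// predicate and operand order of the original compare/select.
%max = arith.maximumf %x, %zero : f32
```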
I don’t have the bandwidth currently to start a side project on doing this in a generally useful way (hence the quick recognizer for ReLU(x) as Max(x, 0) alluded to above), but this seems like a really fun starter project for somebody who wants to contribute to MLIR.
Thoughts?