With the new polynomial dialect getting settled in, I wanted to start work on a pass that replaces transcendental and other non-polynomial math ops (like sin or relu) with polynomial approximations.
I’m looking for early feedback on a design sketch. I plan to start working on a prototype in the HEIR project (GitHub: google/heir, a compiler for homomorphic encryption) until the design kinks are worked out and it’s ready to upstream.
In its simplest form, the pass would scan for individual ops with an ApproximableByPolynomials trait and run a solver to replace them with polynomial evaluations of their input. E.g., a math.sin op would be replaced by a polynomial.evaluate op that takes the same input and carries the approximating polynomial, fit over [-pi/2, pi/2] and written in the monomial basis, as a static attribute.
A more complex invocation might replace combinations of multiple ops with a single polynomial evaluation, but I haven’t quite figured out how the pass would determine which ops to combine.
The main considerations are:
What is going to compute the polynomial approximation? I plan to add a sophisticated (dependency-heavy) solver out of tree, and a simple, dependency-free solver in tree, and I need a mechanism for the out-of-tree solver to replace the default solver when configuring the pass.
Why do the approximations need to be computed in the compiler? In our out-of-tree project, which involves a limited computational model and lots of cryptography, there is a trade-off to be made among the choice of cryptographic parameters (larger parameters slow the program down but add more security), the degree of polynomial approximation that can be evaluated (you need larger parameters to evaluate larger-degree polynomials), and the accuracy of the result (you need a larger degree to get higher accuracy). We expect to hide much of this from the user by having them specify only security requirements and an accuracy tolerance, which implies the compiler needs to pick the best polynomial approximation dynamically.
How to choose a domain? I don’t have a good answer for this. Polynomial approximations only make sense on a limited domain. In our out-of-tree project we’re hoping to use an analysis pass to determine bounds on the input ranges of ops that are approximated (a rough sketch of what such a range analysis could look like follows these considerations). The example above has no indication on the sin op that its domain is [-pi/2, pi/2], and I don’t know if there is an analysis analogous to IntegerRangeAnalysis for fixed/floating point ops.
What basis to use? The above example shows the polynomial approximation in the monomial basis, but it’s generally more numerically stable and efficient to do polynomial approximations in the Chebyshev basis. It may be worthwhile to store both the Chebyshev coefficients and the monomial-basis form on the evaluate op, so that lowerings have access to both, though either form can be recomputed from the other on the fly relatively easily.
How would this op lower? The evaluate op would lower to a loop that evaluates the polynomial. This can be done with Clenshaw’s algorithm, which supports both the monomial and Chebyshev bases (in the monomial basis it reduces to Horner’s method).
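For concreteness, here is a rough sketch (plain C++, scalar, not MLIR) of the recurrences such a lowering would implement; the function names and the coefficient layout are mine, purely for illustration.

```cpp
#include <vector>

// Illustrative only: p(x) = sum_k coeffs[k] * x^k (monomial basis),
// evaluated with Horner's rule, the degenerate case of Clenshaw.
double evalMonomial(const std::vector<double> &coeffs, double x) {
  double acc = 0.0;
  for (auto it = coeffs.rbegin(); it != coeffs.rend(); ++it)
    acc = acc * x + *it;
  return acc;
}

// Illustrative only: p(x) = sum_k coeffs[k] * T_k(x) (Chebyshev basis,
// x assumed to lie in [-1, 1]), evaluated with Clenshaw's recurrence.
double evalChebyshev(const std::vector<double> &coeffs, double x) {
  double b1 = 0.0, b2 = 0.0;  // b_{k+1} and b_{k+2}
  for (size_t k = coeffs.size(); k-- > 1;) {
    double b0 = 2.0 * x * b1 - b2 + coeffs[k];
    b2 = b1;
    b1 = b0;
  }
  return x * b1 - b2 + coeffs[0];  // assumes at least one coefficient
}
```

In the pass, the loop would presumably become an scf.for over the coefficient attribute (or a fully unrolled chain of arith ops when the degree is small).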
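And on the domain consideration above, here is a minimal sketch of the kind of forward range analysis I have in mind, assuming a simple interval lattice and per-op transfer functions; FloatRange and the functions below are hypothetical, not an existing MLIR analysis.

```cpp
#include <algorithm>

// Hypothetical lattice value: known bounds on a floating-point SSA value,
// loosely analogous to the ConstantIntRanges used by IntegerRangeAnalysis.
struct FloatRange {
  double lo, hi;
};

// Illustrative transfer functions. A real analysis would attach these to
// ops via an interface and propagate them with the dataflow framework;
// only the interval arithmetic is shown here.
FloatRange addRanges(FloatRange a, FloatRange b) {
  return {a.lo + b.lo, a.hi + b.hi};
}

FloatRange mulByConstant(FloatRange a, double c) {
  double x = a.lo * c, y = a.hi * c;
  return {std::min(x, y), std::max(x, y)};
}

// Example: if an earlier op is known to produce values in [0, 1], then a
// scale by pi followed by a shift of -pi/2 gives a downstream sin an input
// range of [-pi/2, pi/2], which is exactly the domain the approximation
// pass needs.
```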
Some questions I have for the community:
What is the right way to override an in-tree solver with an out-of-tree solver? I thought I would just make an abstract base class for the solver API with pure virtual methods, implement them out of tree, and have a pass configuration option to set the solver instance (see the sketch after these questions). I’m not sure how one would attach this to an op via existing interface mechanisms, and because I hope to apply this to combinations of ops, it seems like a bad design choice to limit it to a single op by tying it to the op trait/interface mechanism.
Is there appetite for a sophisticated in-tree solver? The classic algorithm is the barycentric Remez exchange, and I’d be willing to implement it in tree if folks think it would be useful. If not, what sort of in-tree solver would be worthwhile to implement? A degree-1 polynomial approximation solver would be quite simple, but probably useless. (The sketch after these questions shows what a slightly more useful dependency-free solver could look like.)
How should I handle the problem of identifying a suitable domain of approximation?
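To make the first two questions concrete, here is a minimal sketch of the kind of abstract base class I have in mind, together with a simple, dependency-free solver that could serve as the in-tree default. All names here (PolynomialApproximator, ChebyshevInterpolationSolver, the approximate signature) are placeholders for illustration, not a concrete API proposal, and Chebyshev interpolation is just one possible choice for the default.

```cpp
#include <cmath>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical solver API: given a scalar function and a domain, produce
// Chebyshev-basis coefficients of an approximating polynomial. An
// out-of-tree Remez-based solver would subclass this and be injected
// through a pass option.
class PolynomialApproximator {
public:
  virtual ~PolynomialApproximator() = default;
  virtual std::vector<double>
  approximate(const std::function<double(double)> &func, double lo, double hi,
              int64_t degree) const = 0;
};

// Simple, dependency-free candidate for the in-tree default: interpolate
// at Chebyshev nodes. Not minimax-optimal like Remez, but close, and only
// a few dozen lines.
class ChebyshevInterpolationSolver : public PolynomialApproximator {
public:
  std::vector<double>
  approximate(const std::function<double(double)> &func, double lo, double hi,
              int64_t degree) const override {
    const double pi = std::acos(-1.0);
    int64_t n = degree + 1;
    // Sample func at Chebyshev nodes mapped from [-1, 1] onto [lo, hi].
    std::vector<double> samples(n);
    for (int64_t k = 0; k < n; ++k) {
      double t = std::cos(pi * (k + 0.5) / n);
      samples[k] = func(0.5 * (hi - lo) * t + 0.5 * (hi + lo));
    }
    // Discrete Chebyshev transform: coefficients of sum_j c[j] * T_j(t),
    // where t = (2x - lo - hi) / (hi - lo).
    std::vector<double> coeffs(n);
    for (int64_t j = 0; j < n; ++j) {
      double sum = 0.0;
      for (int64_t k = 0; k < n; ++k)
        sum += samples[k] * std::cos(pi * j * (k + 0.5) / n);
      coeffs[j] = (j == 0 ? 1.0 : 2.0) * sum / n;
    }
    return coeffs;
  }
};
```

The pass itself would then just hold a reference (or a factory) of type PolynomialApproximator supplied through a configuration option, so the out-of-tree solver is a drop-in replacement.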
I should add, polynomial.evaluate could be useful outside the context of polynomial approximations as well, in which case the polynomial could also be an SSA value instead of an attribute. I was considering having two different ops, polynomial.evaluate for SSA-value inputs and polynomial.approximate_eval for a statically-known polynomial and an SSA-valued argument. Or I could merge these into one op whose custom parser hides the fact that the two forms are stored in separate fields, similar to the indexing [] syntax on memref/affine ops.
In general, an easy solution for this is to work around the problem by not making the pass configurable. That is, you can implement your entire transformation as a “utility” that is configured by injecting the solver in some way.
The “pass” is just a thin wrapper that sets up the solver and invokes the standalone utility function (the pass can be nothing more than a few lines of code).
That way the upstream pass is -polynomial-approx-with-naive-solver and downstream you just have to implement another pass -polynomial-approx-with-complex-solver (I’ll let you bike-shed the name :))
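Roughly (reusing the hypothetical PolynomialApproximator / ChebyshevInterpolationSolver names from the sketch in the original post, and eliding the rewrite logic itself), this is what the structure could look like; it is a sketch, not working code:

```cpp
#include "mlir/Pass/Pass.h"

// Standalone utility: walks `root` and rewrites approximable ops using
// whatever solver it is handed. All of the real logic lives here, so it
// can be reused by any pass or by a downstream tool. (Definition elided.)
void approximateWithSolver(mlir::Operation *root,
                           const PolynomialApproximator &solver);

// Thin upstream pass: nothing more than "construct the naive solver and
// call the utility". A downstream project defines its own pass that
// injects its heavyweight solver instead.
struct PolynomialApproxWithNaiveSolverPass
    : public mlir::PassWrapper<PolynomialApproxWithNaiveSolverPass,
                               mlir::OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      PolynomialApproxWithNaiveSolverPass)

  void runOnOperation() override {
    ChebyshevInterpolationSolver solver;
    approximateWithSolver(getOperation(), solver);
  }
};
```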
Another (more involved) way of solving the pluggability aspect would be to rely on an attribute interface: the pass could look up the enclosing module and check whether it has a polynomial.approx-solver attribute implementing a PolynomialSolverAttrInterface; if so, use it to get the solver implementation, and otherwise fall back to the naive one.
That way an out-of-tree user can inject any solver by attaching their own attribute to the module.
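A sketch of the lookup, assuming a hypothetical PolynomialSolverAttrInterface with a getSolver() method (neither exists today), and assuming the pass runs on ops nested inside a module:

```cpp
#include "mlir/IR/BuiltinOps.h"

// Pick the solver for `op`: if the enclosing module carries a solver
// attribute implementing the (hypothetical) PolynomialSolverAttrInterface,
// use it; otherwise fall back to the naive in-tree solver.
const PolynomialApproximator &
resolveSolver(mlir::Operation *op, const PolynomialApproximator &fallback) {
  auto module = op->getParentOfType<mlir::ModuleOp>();
  if (!module)
    return fallback;
  auto solverAttr = module->getAttrOfType<PolynomialSolverAttrInterface>(
      "polynomial.approx-solver");
  return solverAttr ? solverAttr.getSolver() : fallback;
}
```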
Is such a solver useful for anything else? (We have the Presburger library, for example, which was added because it is useful for quite a few application domains.)
Otherwise, is a “sophisticated” solver a lot of code? Would it stay contained within the polynomial dialect and not affect anyone else?
It looks like you currently let the user specify it in the attribute, which seems reasonable. Perhaps I’m missing it, but is this asking for a more general case?
I think you’ve answered your own question about (2) and Mehdi’s suggestion on (1) is good.
Can’t every sin/cos polynomial approximation be “moved” to the domain [-pi/2, pi/2]? Perhaps you can encode this by just tracking the quadrant (modulo 2*pi). Other functions will probably have other kinds of symmetry that you can exploit.
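Concretely, this is the usual argument-reduction trick (plain C++, illustrative only, ignoring the precision care needed for very large arguments); the interesting part for the pass is where this quadrant bookkeeping ends up in the IR:

```cpp
#include <cmath>

// Illustrative only: reduce sin(x) to sin or cos on [-pi/4, pi/4] by
// tracking which multiple of pi/2 is closest to x. The same idea lets an
// approximation valid on [-pi/2, pi/2] cover the whole real line.
double reducedSin(double x) {
  const double halfPi = std::acos(-1.0) / 2;
  long long k = std::llround(x / halfPi);   // nearest multiple of pi/2
  double r = x - k * halfPi;                // r lies in [-pi/4, pi/4]
  switch (((k % 4) + 4) % 4) {              // quadrant of x (mod 2*pi)
  case 0:  return std::sin(r);              // here: the sin polynomial
  case 1:  return std::cos(r);              // here: the cos polynomial
  case 2:  return -std::sin(r);
  default: return -std::cos(r);
  }
}
```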
Would it make sense for this to be a new type? The recomputation becomes part of the lowering into native types.
Initially I’d say this would hinder progress more than help. You may want to dissociate the development of such a solver (complex beasts by nature) from the rest of MLIR.
Also, a framework that lets other efforts plug in their own solvers would be good to have upstream, so this can help in more ways than one.
I think it depends on the transforms that are done to the polynomials. Initial values provided by the user are probably fine, but as the solver works its way through, it may (and I’m guessing here) start to change the domains when it converts to other functions (e.g. differentiation/integration, change of variables, simplification).
This may be as trivial as having a tuple of values (for N-dimensional ranges), but it could be harder…
I will float the following idea for extensibility. We have delayed registration for interfaces. We can have ops define a promise of implementing a PolynomialApproximationOpInterface that is fulfilled by the simple solver implementation upstream (registered by mlir-opt) and by the complex solver implementation downstream (a different implementation registered by downstream-opt or another relevant registrar). This will technically work, but I’m not sure if this is the direction we want to take with interfaces.
I think this is a great idea. For ML workloads, oftentimes a hand-written .ll implementing the same thing (e.g. a polynomial approximation of sin/exp) is patched in, etc. Having a way to do it at the MLIR dialect level would perhaps be a cleaner option.