Thanks, Mehdi, for starting this conversation. On our project, we’ve struggled with the lack of high-level primitives in this area. Without a conversation like this going on, common infra/ops end up being developed outside of MLIR, with no place to anchor the pieces that would make sense to move core-ward. Also, having lived through a number of rounds of debates on opinionated op-sets, I appreciate the tone you have set with respect to creating a space to work this out. While my most recent background is with XLA/HLO, I’ve worked with a number of others at different layers and would like the chance to distill out some common ideas and infra. There are a lot of priors on this, but in my experience, we have to re-learn the best way to express them in MLIR, which gives us a lot more flexibility than older representations (with all of the positives and negatives that flexibility brings).

Jumping a bit into the technical discussion, I think that one of the issues with HLO specifically is that it combines opinions on a few different axes that don’t necessarily need to be combined in new work.

- Static shapes (versus a full story for dynamic)
- Implicit vs explicit broadcast semantics
- Functional control flow (vs CFG)
- Preference for “primitive” vs high-level math ops
- Preference for explicit reductions vs high-level aggregate ops

When bundled into a single dialect and codegen path, all of these opinions get taken at once, and there are reasonable arguments for alternatives to each (and others). Based on the discussions on the mailing list and our recent experience, it is actually #1 (and by extension #2, since you have to handle that) that benefits the most right now from some common infra, and I’m not sure that needs to be tied to the op story (it can support different sets of high-level ops).

In our work, for #3–5 we end up taking different opinions at different parts of the pipeline, and I think that MLIR lets us do this: we don’t necessarily need to converge on one “best” set of “dnn math ops”. Since there is a small set of canonical ways to express them, we might opt to define them at multiple levels (i.e. have high-level “nn” ops like softmax and relu alongside what they are implemented in terms of). My opinion for this working group would be to get the lowest common level specified in MLIR itself, possibly leaving the higher levels to frontend-specific frameworks. In practice, this has worked reasonably well for TensorFlow/XLA.

For dynamic shapes specifically, there is a lot bound up in that with respect to a “source-level” representation. Discourse isn’t letting me post links to GitHub, where we are working on a sample `shape` dialect, but here is the general direction of types/ops we think would be helpful. I suspect that we need to define some new shape-related types and corresponding ops, and that gets into broadcasting quickly, which perhaps should be considered at the same level:

```
%shp0 = shape.get_ranked_shape %arg0 : tensor<?x5xf32> -> !shape.ranked_shape<?x5xindex>
%shp1 = shape.get_ranked_shape %arg1 : tensor<5xf32> -> !shape.ranked_shape<5xindex>
%dim0 = "compute_broadcasted_shape"(%shp0, %shp1) : (!shape.ranked_shape<?x5xindex>, !shape.ranked_shape<5xindex>) -> (index) // Can be codegened directly
%1 = shape.ranked_broadcast_in_dim %arg1, %dim0 { broadcast_dimensions = dense<1> : tensor<1xi64> } : tensor<5xf32> -> tensor<?x5xf32>
```
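To make the broadcasting step concrete, here is a sketch of what something like `compute_broadcasted_shape` would have to do over ranked shapes. This is a hypothetical NumPy-style helper (not part of the proposed `shape` dialect), with `None` standing in for a dynamic `?` extent:

```python
def broadcast_dim(a, b):
    """Combine a pair of dimension extents under NumPy broadcasting rules.

    Extents are ints, or None for a dynamic ('?') dimension."""
    if a == b:
        return a          # equal (including both dynamic)
    if a == 1:
        return b
    if b == 1:
        return a
    if a is None or b is None:
        # One side dynamic, the other static (!= 1): assume the static
        # extent; a real lowering must insert a runtime compatibility check.
        return b if a is None else a
    raise ValueError(f"incompatible broadcast extents {a} and {b}")

def broadcast_ranked_shapes(lhs, rhs):
    # Right-align the shorter shape, NumPy-style, padding with 1s.
    rank = max(len(lhs), len(rhs))
    lhs = [1] * (rank - len(lhs)) + list(lhs)
    rhs = [1] * (rank - len(rhs)) + list(rhs)
    return [broadcast_dim(a, b) for a, b in zip(lhs, rhs)]
```

Mirroring the IR above, `broadcast_ranked_shapes([None, 5], [5])` yields `[None, 5]`, i.e. `?x5` — the static extent resolves the broadcast, and only the leading dynamic dim survives to runtime.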

To wrap up, I think it would be great to have the working group break the problem up a bit and focus semi-independently on:

- Dynamic shape related infra (probably with sub-points for high-level representations and things like shape inference)
- Structural primitives (reductions, control flow, etc)
- Small set of primitive math ops

It doesn’t take very complicated examples before each of these meaningfully needs resolution.

## Examples

## Add op broadcasting types

```
// op-carried
%24 = "xla_hlo.add"(%23, %4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<10xf32>) -> tensor<?x10xf32>
// explicit
%8 = "xla_hlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<16xf32>) -> tensor<?x16xf32>
%9 = xla_hlo.add %7, %8 : tensor<?x16xf32>
```

## tanh vs sigmoid

```
// TANH
%10 = "xla_hlo.tanh"(%9) : (tensor<?x16xf32>) -> tensor<?x16xf32>
// SIGMOID
%0 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
%13 = "xla_hlo.broadcast"(%0) {broadcast_sizes = dense<[-1, 16]> : tensor<2xi64>} : (tensor<f32>) -> tensor<?x16xf32>
%14 = xla_hlo.mul %12, %13 : tensor<?x16xf32>
%15 = "xla_hlo.tanh"(%14) : (tensor<?x16xf32>) -> tensor<?x16xf32>
%16 = xla_hlo.mul %15, %13 : tensor<?x16xf32>
%17 = xla_hlo.add %16, %13 : tensor<?x16xf32>
```
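The sigmoid expansion above relies on the identity `sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5` — the classic tanh-based lowering (mul by a broadcast 0.5, tanh, mul by 0.5, add 0.5). A quick NumPy check of the identity:

```python
import numpy as np

x = np.linspace(-6.0, 6.0, 101)
sigmoid = 1.0 / (1.0 + np.exp(-x))
# The expansion in the IR above: mul by 0.5, tanh, mul by 0.5, add 0.5.
via_tanh = 0.5 * np.tanh(0.5 * x) + 0.5
assert np.allclose(sigmoid, via_tanh)
```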

## XLA softmax

```
%36 = "xla_hlo.reduce"(%35, %1) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
  %45 = xla_hlo.max %arg1, %arg2 : tensor<f32>
  "xla_hlo.return"(%45) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
%37 = "xla_hlo.broadcast_in_dim"(%35) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x10xf32>) -> tensor<?x10xf32>
%38 = "xla_hlo.broadcast_in_dim"(%36) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>) -> tensor<?x10xf32>
%39 = xla_hlo.sub %37, %38 : tensor<?x10xf32>
%40 = "xla_hlo.exp"(%39) : (tensor<?x10xf32>) -> tensor<?x10xf32>
%41 = "xla_hlo.reduce"(%40, %2) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
  %45 = xla_hlo.add %arg1, %arg2 : tensor<f32>
  "xla_hlo.return"(%45) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
```
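The reduce/broadcast sequence above computes the max-subtracted exponentials; together with the final sum-reduce and divide (visible in the full mlp below), it is the standard numerically stable softmax. In NumPy terms (a sketch of the pattern, not the XLA lowering itself):

```python
import numpy as np

def softmax(x):
    # reduce-max along dim 1, broadcast back, subtract (for stability)
    m = x.max(axis=1, keepdims=True)
    e = np.exp(x - m)
    # reduce-sum along dim 1, broadcast back, divide
    return e / e.sum(axis=1, keepdims=True)
```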

## “simplified” mlp (tanh, no softmax)

```
func @predict_tanh_no_softmax(%arg0: tensor<?x16xf32>) -> tensor<?x10xf32> attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 16 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.signature.is_stateful} {
%0 = flow.variable.load @h2_bias : tensor<16xf32>
%1 = flow.variable.load @out_bias : tensor<10xf32>
%2 = flow.variable.load @h1_bias : tensor<16xf32>
%3 = flow.variable.load @h2_weights : tensor<16x16xf32>
%4 = flow.variable.load @out_weights : tensor<16x10xf32>
%5 = flow.variable.load @h1_weights : tensor<16x16xf32>
%6 = "xla_hlo.dot"(%arg0, %5) : (tensor<?x16xf32>, tensor<16x16xf32>) -> tensor<?x16xf32>
%7 = "xla_hlo.add"(%6, %2) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x16xf32>, tensor<16xf32>) -> tensor<?x16xf32>
%8 = "xla_hlo.tanh"(%7) : (tensor<?x16xf32>) -> tensor<?x16xf32>
%9 = "xla_hlo.dot"(%8, %3) : (tensor<?x16xf32>, tensor<16x16xf32>) -> tensor<?x16xf32>
%10 = "xla_hlo.add"(%9, %0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x16xf32>, tensor<16xf32>) -> tensor<?x16xf32>
%11 = "xla_hlo.tanh"(%10) : (tensor<?x16xf32>) -> tensor<?x16xf32>
%12 = "xla_hlo.dot"(%11, %4) : (tensor<?x16xf32>, tensor<16x10xf32>) -> tensor<?x10xf32>
%13 = "xla_hlo.add"(%12, %1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<10xf32>) -> tensor<?x10xf32>
%14 = "xla_hlo.tanh"(%13) : (tensor<?x10xf32>) -> tensor<?x10xf32>
return %14 : tensor<?x10xf32>
}
```
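For orientation, the function above is just three dot/bias-add/tanh layers. A NumPy sketch with placeholder weights (the parameter values are made up; only the shapes match the variables in the IR):

```python
import numpy as np

rng = np.random.default_rng(0)
# Placeholder parameters with the same shapes as the flow.variable loads above.
h1_w, h1_b = rng.normal(size=(16, 16)), np.zeros(16)
h2_w, h2_b = rng.normal(size=(16, 16)), np.zeros(16)
out_w, out_b = rng.normal(size=(16, 10)), np.zeros(10)

def predict_tanh_no_softmax(x):            # x: (?, 16)
    h1 = np.tanh(x @ h1_w + h1_b)          # dot, broadcasted bias add, tanh
    h2 = np.tanh(h1 @ h2_w + h2_b)
    return np.tanh(h2 @ out_w + out_b)     # (?, 10)
```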

## full mlp

```
func @predict(%arg0: tensor<?x16xf32>) -> tensor<?x10xf32> attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 16 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.signature.is_stateful} {
%0 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
%1 = xla_hlo.constant dense<0xFF800000> : tensor<f32>
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%3 = flow.variable.load @h2_bias : tensor<16xf32>
%4 = flow.variable.load @out_bias : tensor<10xf32>
%5 = flow.variable.load @h1_bias : tensor<16xf32>
%6 = flow.variable.load @h2_weights : tensor<16x16xf32>
%7 = flow.variable.load @out_weights : tensor<16x10xf32>
%8 = flow.variable.load @h1_weights : tensor<16x16xf32>
%9 = "xla_hlo.dot"(%arg0, %8) : (tensor<?x16xf32>, tensor<16x16xf32>) -> tensor<?x16xf32>
%10 = "xla_hlo.add"(%9, %5) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x16xf32>, tensor<16xf32>) -> tensor<?x16xf32>
%11 = "xla_hlo.broadcast"(%0) {broadcast_sizes = dense<[-1, 16]> : tensor<2xi64>} : (tensor<f32>) -> tensor<?x16xf32>
%12 = xla_hlo.mul %10, %11 : tensor<?x16xf32>
%13 = "xla_hlo.tanh"(%12) : (tensor<?x16xf32>) -> tensor<?x16xf32>
%14 = xla_hlo.mul %13, %11 : tensor<?x16xf32>
%15 = xla_hlo.add %14, %11 : tensor<?x16xf32>
%16 = "xla_hlo.dot"(%15, %6) : (tensor<?x16xf32>, tensor<16x16xf32>) -> tensor<?x16xf32>
%17 = "xla_hlo.add"(%16, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x16xf32>, tensor<16xf32>) -> tensor<?x16xf32>
%18 = "xla_hlo.broadcast"(%0) {broadcast_sizes = dense<[-1, 16]> : tensor<2xi64>} : (tensor<f32>) -> tensor<?x16xf32>
%19 = xla_hlo.mul %17, %18 : tensor<?x16xf32>
%20 = "xla_hlo.tanh"(%19) : (tensor<?x16xf32>) -> tensor<?x16xf32>
%21 = xla_hlo.mul %20, %18 : tensor<?x16xf32>
%22 = xla_hlo.add %21, %18 : tensor<?x16xf32>
%23 = "xla_hlo.dot"(%22, %7) : (tensor<?x16xf32>, tensor<16x10xf32>) -> tensor<?x10xf32>
%24 = "xla_hlo.add"(%23, %4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<10xf32>) -> tensor<?x10xf32>
%25 = "xla_hlo.broadcast"(%0) {broadcast_sizes = dense<[-1, 10]> : tensor<2xi64>} : (tensor<f32>) -> tensor<?x10xf32>
%26 = xla_hlo.mul %24, %25 : tensor<?x10xf32>
%27 = "xla_hlo.tanh"(%26) : (tensor<?x10xf32>) -> tensor<?x10xf32>
%28 = xla_hlo.mul %27, %25 : tensor<?x10xf32>
%29 = xla_hlo.add %28, %25 : tensor<?x10xf32>
%30 = "xla_hlo.reduce"(%29, %1) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
  %35 = xla_hlo.max %arg1, %arg2 : tensor<f32>
  "xla_hlo.return"(%35) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
%31 = "xla_hlo.sub"(%29, %30) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<?xf32>) -> tensor<?x10xf32>
%32 = "xla_hlo.exp"(%31) : (tensor<?x10xf32>) -> tensor<?x10xf32>
%33 = "xla_hlo.reduce"(%32, %2) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
  %35 = xla_hlo.add %arg1, %arg2 : tensor<f32>
  "xla_hlo.return"(%35) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
%34 = "xla_hlo.div"(%32, %33) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<?xf32>) -> tensor<?x10xf32>
return %34 : tensor<?x10xf32>
}
```