[RFC] Remove tosa.fully_connected operator

Posting here for MLIR-side visibility: RFC: Remove the FULLY_CONNECTED TOSA operator (TOSA - Discourse)

The proposal eliminates tosa.fully_connected, replacing it with a tosa.transpose + tosa.matmul + tosa.add sequence. Using TensorFlow Lite as the input dialect, for example:

Current conversion:

module  {
  func @test_fullyconnected(%arg0: tensor<14x19xf32>, %arg1: tensor<28x19xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> {
    %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<14x19xf32>, tensor<28x19xf32>, tensor<28xf32>) -> tensor<14x28xf32>
    return %0 : tensor<14x28xf32>
  }
}

Proposed (raw):

module  {
  func @test_fullyconnected(%arg0: tensor<14x19xf32>, %arg1: tensor<28x19xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> {
    %0 = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
    %1 = "tosa.reshape"(%arg0) {new_shape = [1, 14, 19]} : (tensor<14x19xf32>) -> tensor<1x14x19xf32>
    %2 = "tosa.transpose"(%arg1, %0) : (tensor<28x19xf32>, tensor<2xi32>) -> tensor<19x28xf32>
    %3 = "tosa.reshape"(%2) {new_shape = [1, 19, 28]} : (tensor<19x28xf32>) -> tensor<1x19x28xf32>
    %4 = "tosa.matmul"(%1, %3) : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
    %5 = "tosa.add"(%4, %arg2) : (tensor<1x14x28xf32>, tensor<28xf32>) -> tensor<1x14x28xf32>
    %6 = "tosa.reshape"(%5) {new_shape = [14, 28]} : (tensor<1x14x28xf32>) -> tensor<14x28xf32>
    return %6 : tensor<14x28xf32>
  }
}
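
For intuition, here is a small NumPy sketch (my own illustration, not part of the RFC) checking that the decomposed sequence computes the same result as the fully-connected semantics, using the shapes from the example above:

import numpy as np

x = np.random.randn(14, 19).astype(np.float32)   # input,  tensor<14x19xf32>
w = np.random.randn(28, 19).astype(np.float32)   # weight, tensor<28x19xf32>
b = np.random.randn(28).astype(np.float32)       # bias,   tensor<28xf32>

# fully_connected semantics: out[n, oc] = sum_ic x[n, ic] * w[oc, ic] + b[oc]
fc = x @ w.T + b

# Decomposed form mirroring the proposed lowering:
wt = np.transpose(w, [1, 0])                     # tosa.transpose -> 19x28
x3 = x.reshape(1, 14, 19)                        # tosa.reshape   -> 1x14x19
wt3 = wt.reshape(1, 19, 28)                      # tosa.reshape   -> 1x19x28
mm = np.matmul(x3, wt3)                          # tosa.matmul    -> 1x14x28
added = mm + b                                   # tosa.add, bias broadcast
out = added.reshape(14, 28)                      # tosa.reshape   -> 14x28

assert np.allclose(fc, out, atol=1e-5)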

With a supporting optimization letting tosa.matmul express an implicit batch size of 1 in the dialect form, the synthetic 2D->3D->2D reshapes can also be eliminated:

module  {
  func @test_fullyconnected(%arg0: tensor<14x19xf32>, %arg1: tensor<28x19xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> {
    %0 = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
    %1 = "tosa.transpose"(%arg1, %0) : (tensor<28x19xf32>, tensor<2xi32>) -> tensor<19x28xf32>
    %2 = "tosa.matmul"(%arg0, %1) : (tensor<14x19xf32>, tensor<19x28xf32>) -> tensor<14x28xf32>
    %3 = "tosa.add"(%2, %arg2) : (tensor<14x28xf32>, tensor<28xf32>) -> tensor<14x28xf32>
    return %3 : tensor<14x28xf32>
  }
}
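
The reshape elimination is sound because a batched matmul with batch size 1 computes exactly the plain 2D product. A quick NumPy check (again, just an illustration):

import numpy as np

a = np.random.randn(14, 19).astype(np.float32)
b = np.random.randn(19, 28).astype(np.float32)

# A rank-3 matmul with a synthetic batch of 1 ...
batched = np.matmul(a.reshape(1, 14, 19), b.reshape(1, 19, 28))
# ... matches the rank-2 matmul the optimized form uses.
assert np.allclose(batched.reshape(14, 28), np.matmul(a, b))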

This simplifies existing legalizations by removing what is essentially a specialization of the matmul op, and it also fixes the current lack of broadcastability on the bias add, which the tosa.add now expresses.
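
Concretely, the bias add becomes an ordinary trailing-dimension broadcast rather than a shape baked into the op definition. In NumPy terms (a sketch of the broadcast behavior, not TOSA's exact broadcasting rules):

import numpy as np

acc = np.zeros((14, 28), dtype=np.float32)   # matmul result
bias = np.ones(28, dtype=np.float32)         # tensor<28xf32>

# bias[oc] is added to every row of the accumulator, which is
# what the tosa.add in the lowering expresses.
out = acc + bias
assert out.shape == (14, 28)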

Impact:

  • Changes the dialect op definition to permit both rank-2 and rank-3 inputs for tosa.matmul.
  • Changes to the frontend legalizations in the TensorFlow repo (Torch->TOSA for this op is currently TBD).

Thanks in advance for your feedback.
