The Python utility that translates TensorFlow functions into MLIR does not seem to produce the expected code. Consider the following script, which builds a small Keras model and then converts it to MLIR:
import tensorflow as tf

# Shapes, types, and sizes
data_shape = [3,4]
batch_size = 1
data_type = tf.dtypes.int32
io_shape = data_shape.copy().insert(0,batch_size)
# Model
x_in = tf.keras.Input(shape=data_shape,batch_size=batch_size,dtype=data_type)
y_in = tf.keras.Input(shape=data_shape,batch_size=batch_size,dtype=data_type)
z_out = tf.keras.layers.Add()([x_in,y_in])
keras_model = tf.keras.Model(inputs=[x_in,y_in], outputs=[z_out])
# Convert to MLIR
@tf.function
def myfun(x_in, y_in):
    return keras_model([x_in, y_in])
cf = myfun.get_concrete_function(tf.TensorSpec(io_shape, data_type),
                                 tf.TensorSpec(io_shape, data_type))
mlir_code = tf.mlir.experimental.convert_function(cf)
print(mlir_code)
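For what it's worth, I can also inspect the concrete function directly; structured_input_signature should echo back the TensorSpecs used for tracing (a quick sanity check, assuming I am reading the API correctly):

# The TensorSpecs the concrete function was traced with, and its
# symbolic output tensors; unknown shapes here would match the
# unknown shapes in the MLIR.
print(cf.structured_input_signature)
print(cf.outputs)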
In the printed MLIR (shown below with non-essential attributes stripped), the tensor shapes seem to have been lost, and I don't understand why:
module attributes {
  func @__inference_myfun_23(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
    %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
    %1 = "tf.Identity"(%0) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
    return %1 : tensor<*xi32>
  }
}
In a few cases the shape is correctly synthesized, but when that happens TF also emits a warning (so the output is correct only when my input is nonsensical…). To be clear, I expected every tensor in the output to have the shape tensor<1x3x4xi32>, not tensor<*xi32>.
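For reference, this is roughly what I expected the converter to print (the module above, hand-edited with the static shapes; not actual converter output):

module attributes {
  func @__inference_myfun_23(%arg0: tensor<1x3x4xi32>, %arg1: tensor<1x3x4xi32>) -> tensor<1x3x4xi32> {
    %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<1x3x4xi32>, tensor<1x3x4xi32>) -> tensor<1x3x4xi32>
    %1 = "tf.Identity"(%0) {device = ""} : (tensor<1x3x4xi32>) -> tensor<1x3x4xi32>
    return %1 : tensor<1x3x4xi32>
  }
}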