I have a normal TOSA mlir file with no any element-wise op. It describes a " conv2d + linear + linear "
neural network model :
module attributes {torch.debug_module_name = "simple"} {
func.func @forward(%arg0: tensor<1x3x16x16xf32>) -> tensor<1x8x16x4xf32> {
%0 = "tosa.const"() <{value = dense<"0x2E068BBCED3DD7BCF309EABD77E9EF3CF43ECD3DFEF7C4BCAC13F2BCFCC2973D30CB82BDF97143BE920C293E086004BE5C88E6BD1755CC3B754D193D7FF68C3DE81B3A3C4DBA24BE9BD4B2BDE4C01F3E1EECF03D9867B0BD5EBB2DBE5CB9B5BDFC531E3D85570CBE46E6323ED4B2B9BBD50506BEBA8307BE1DAFC93DB66DAC3DF8A4E339A88EECBC7D06C73CB29AC33D01B615BD083C283EDA5A353ECA5501BD871307BCC9D4A93A78EC9BBD58B2D3BD731935BE27761D3E7A7C2D3E8C24D73DB53229BE054D9ABD8ED6823DE1BED53D926C35BE9728CBBDFC20423E2EB0CE3CE65622BE8D12EE3DAFFF293E64B6C43D62A7A6BD7B0601BE78A108BEA5E1A03DBE1420BE84E917BEC5EDDDBB364037BEB62453BD1C5728BD7D1A77BD54EA233E3A6748BD3AA8BA3DB6CEFF3D57E0BEBD040312BE45200D3EEC4A2FBEC7C382BCD2A9293ED477503D73D008BD34AAF0BB7D6B15BE3B5F233DF2A0093DE1F42DBDD7BED33D9617303EEE8CEBBD8D7F2CBDB29C1FBED143733D244C833D8C0598BDC479F23DB510EABDA1CE9D3D45AE293DA6DC06BECC4A14BE7ED732BE6E57C6BD9B15DC3D6528133E3BEB543D80422DBD9B04383E104D2FBEE0EE07BE2B4D193E3AF2D03D9BEA5B3DDBED71BDF7FCA3BD9EA68CBD65D4133EABB724BE6F2A823D5E282EBEF3A2CC3D153D57BD601F25BD4899DB3D5D8C273D73068ABDA1F633BE019BEDBD87E782BDE4D00CBE548B1D3E4510D43DBFF58BBD1D07233BC61E133D1A0A3D3E2E9DEF3D3A0BA73C7233073E75BFB73DFA0043BC1DD6C33DF26705BEC3A706BECF1A04BE010E1A3D64BE363E301A233E9CDC6A3DB992133B97C70DBE4989BB3D3F1C2BBE68FB993CE034E1BD15ABD23DE2AEB93CC58E2BBD69A8CE3DB72EBF3DB7D7C53DA8C7DABDA0603BBE2ED043BE6370133E8C75B43CB76A17BE1DB47EBC4ED84E3DB68370BDCB8034BE0956253E80B8313DB0FF8C3D83C813BD80B45B3C7C7ED2BD8EBD83BD5EB0B6BD0FA283BC8F7E1F3ED400CF3DCBE87ABD3ED0243EBF49273E2F41DCBBB8898B3C914DC43DD3F713BE1B69093E33E42BBEC0A442BE8ABFAC3DD8F137BBC8EDE63B16700EBEE01CDD3D7A24A73DAE02CF3D4275D2BD7EAE173E61AD26BE46CCF3BD0490013E9B688BBDA9393DBD33504B3DDA2735BB8B27423E391B79BD81A305BE13A9873D2DC4993C689534BE52BA703D"> : tensor<8x3x3x3xf32>}> : () -> tensor<8x3x3x3xf32>
%1 = "tosa.const"() <{value = dense<"0xD4DA02BE3EAC34BEB22279BEA8966DBEBECF0F3E72417D3E40FF21BD787D4C3E9CFD59BEE47D67BE50640D3E1052413DC21506BE084418BEAE684CBE5C5D37BE6050E1BDEE7D493EC090403D20EDEC3DD0A83BBE10BB643E14322E3E2053023EA260473E1013453DC6F33EBE96303B3E546BCC3D6837123DE65B1F3EE892933DC08246BDC853AB3DA46A5CBE804CAFBD8AE619BE688A4ABD188B8A3D585B283DA0E52DBE12A4063EC06E5DBEE01954BD502E5F3D70E8713EE83D58BEF492E83DE8A9583D28AA253D946596BD2090B3BDECB2513E00A0D0BA086145BEDA2F7F3E6C66593EEA452ABEE8E038BEF09C94BCB0A1C53C9038573D6C40A73DB054A53C0808E3BD80FE62BB94BD16BE12B378BE746BEEBD0C9AC6BD8CCF853DC081983D18B16D3E3E11443E6075A3BD7A715BBEE682453EF4155F3E940A443EFA5D493E968053BEC0B9C93BCC97763E90862DBEF09677BE64F3C3BD50EE5FBE5020E7BC8E01593EEC2387BD4C763F3E2030773D98F9363E30D01F3EF0D6853D60A8903C6A121E3EC084323CC0FC693DDECB143E7C72143E940D2E3EC08AD03DBCD9153E402946BD94D92FBE7CD2063E4E03633E2C3B9BBD1A55273EA0B0693E3891B93D22457F3EA413873DC42F743EDEBD00BEC44AD5BD54724B3E66453F3E90CCBDBC88E26FBD6EDD1CBE72F601BE948483BD505F7EBD980F9A3D98F7E9BDA0CB26BC"> : tensor<1x16x8xf32>}> : () -> tensor<1x16x8xf32>
%2 = "tosa.const"() <{value = dense<[[[0.0987183302, -0.137849271, 0.203331098, 0.105150819], [0.29641974, 1.364540e-01, -0.339602977, -0.199884921], [-0.341286719, -0.282808691, -0.131707594, 0.26574105], [0.168666959, -0.25722453, -0.195234984, -0.22690931], [0.030892713, 0.152626127, 0.180179372, -0.198123857], [0.211126029, 0.061977107, -0.154441521, -0.0343345925], [-0.212056011, 0.139302328, 0.295544952, -0.0595564879], [-0.0775878354, -0.25885573, 0.219060048, -0.174453765]]]> : tensor<1x8x4xf32>}> : () -> tensor<1x8x4xf32>
%3 = "tosa.const"() <{value = dense<1.000010e+00> : tensor<8x1x1xf32>}> : () -> tensor<8x1x1xf32>
%4 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
%5 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
%6 = "tosa.const"() <{value = dense<[-0.0717707872, -0.0423891321, 0.00969634484, -0.0271288622, -0.130245402, -0.150103226, -0.0827921405, -0.0530781448]> : tensor<8xf32>}> : () -> tensor<8xf32>
%7 = "tosa.const"() <{value = dense<[-0.132966399, -0.10632199, -6.313750e-02, 0.111462414, -0.0230877101, -2.71707773E-4, -0.102463722, -0.227972597]> : tensor<8xf32>}> : () -> tensor<8xf32>
%8 = "tosa.const"() <{value = dense<[-0.121429406, 0.0909240693, 0.0867559984, 0.169947132]> : tensor<4xf32>}> : () -> tensor<4xf32>
%9 = "tosa.transpose"(%arg0, %5) : (tensor<1x3x16x16xf32>, tensor<4xi32>) -> tensor<1x16x16x3xf32>
%10 = "tosa.conv2d"(%9, %0, %6) <{dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>}> : (tensor<1x16x16x3xf32>, tensor<8x3x3x3xf32>, tensor<8xf32>) -> tensor<1x16x16x8xf32>
%11 = "tosa.transpose"(%10, %4) : (tensor<1x16x16x8xf32>, tensor<4xi32>) -> tensor<1x8x16x16xf32>
%12 = "tosa.rsqrt"(%3) : (tensor<8x1x1xf32>) -> tensor<8x1x1xf32>
%13 = "tosa.mul"(%11, %12) <{shift = 0 : i32}> : (tensor<1x8x16x16xf32>, tensor<8x1x1xf32>) -> tensor<1x8x16x16xf32>
%14 = "tosa.clamp"(%13) <{max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<1x8x16x16xf32>) -> tensor<1x8x16x16xf32>
%15 = "tosa.reshape"(%14) <{new_shape = array<i64: 1, 128, 16>}> : (tensor<1x8x16x16xf32>) -> tensor<1x128x16xf32>
%16 = "tosa.matmul"(%15, %1) : (tensor<1x128x16xf32>, tensor<1x16x8xf32>) -> tensor<1x128x8xf32>
%17 = "tosa.reshape"(%16) <{new_shape = array<i64: 1, 8, 16, 8>}> : (tensor<1x128x8xf32>) -> tensor<1x8x16x8xf32>
%18 = "tosa.add"(%17, %7) : (tensor<1x8x16x8xf32>, tensor<8xf32>) -> tensor<1x8x16x8xf32>
%19 = "tosa.clamp"(%18) <{max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<1x8x16x8xf32>) -> tensor<1x8x16x8xf32>
%20 = "tosa.reshape"(%19) <{new_shape = array<i64: 1, 128, 8>}> : (tensor<1x8x16x8xf32>) -> tensor<1x128x8xf32>
%21 = "tosa.matmul"(%20, %2) : (tensor<1x128x8xf32>, tensor<1x8x4xf32>) -> tensor<1x128x4xf32>
%22 = "tosa.reshape"(%21) <{new_shape = array<i64: 1, 8, 16, 4>}> : (tensor<1x128x4xf32>) -> tensor<1x8x16x4xf32>
%23 = "tosa.add"(%22, %8) : (tensor<1x8x16x4xf32>, tensor<4xf32>) -> tensor<1x8x16x4xf32>
return %23 : tensor<1x8x16x4xf32>
}
}
When I lower it to linalg with:
mlir-opt \
--pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-tensor, tosa-to-arith, tosa-to-linalg))" \
tosa_opt.mlir -o linalg_tensor.mlir \
--mlir-print-ir-after-all 2>&1 | cat > before_linalg_intermediate.mlir
It failed:
tosa_opt.mlir:16:11: error: 'tosa.reshape' op Cannot reshape 8 elements into 1
%13 = "tosa.mul"(%11, %12) <{shift = 0 : i32}> : (tensor<1x8x16x16xf32>, tensor<8x1x1xf32>) -> tensor<1x8x16x16xf32>
^
tosa_opt.mlir:16:11: note: see current operation: %27 = "tosa.reshape"(%25) <{new_shape = array<i64>}> : (tensor<8x1x1xf32>) -> tensor<f32>
Then I exchange the order of lowering pass:
mlir-opt \
--pass-pipeline="builtin.module(func.func(tosa-to-tensor, tosa-to-arith, tosa-to-linalg,tosa-to-linalg-named))" \
tosa_opt.mlir -o linalg_tensor.mlir \
--mlir-print-ir-after-all 2>&1 | cat > before_linalg_intermediate.mlir
The error changes:
tosa_opt.mlir:13:10: error: 'tosa.conv2d' op attribute 'pad' failed to satisfy constraint: i64 dense array attribute with exactly 4 elements
%1 = "tosa.conv2d"(%0, %cst, %cst_5) {dilation = [1, 1], pad = [ 1, 1, 1, 1], stride =[ 1, 1]} : (tensor<1x16x16x3xf32>, tensor<8x3x3x3xf32>, tensor<8xf32>) -> tensor<1x16x16x8xf32>
^
tosa_opt.mlir:13:10: note: see current operation: %10 = "tosa.conv2d"(%9, %0, %6) {dilation = [1, 1], pad = [1, 1, 1, 1], stride = [1, 1]} : (tensor<1x16x16x3xf32>, tensor<8x3x3x3xf32>, tensor<8xf32>) -> tensor<1x16x16x8xf32>
I’m bothered in this TOSA->Linalg lowering process. How can I lower this TOSA mlir? Thanks!