Hi, I’m trying to lower a tensor program to the LLVM dialect, and some of the arguments of the resulting function seem redundant and unused.
The `tensor<2x3xf32>` is lowered to `!llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>`, and I’d like to know the role of each element.
I think that `%arg1` is a 1D array storing the actual data, and that `%arg3` and `%arg4` represent the shape of the tensor. But I have no idea why `%arg0`, `%arg2`, `%arg5`, and `%arg6` exist.
Also, I think `%7 = llvm.getelementptr %1[6] : (!llvm.ptr) -> !llvm.ptr, f32` looks a bit odd, since `%1` is a zero (null) pointer. Is it okay to apply a GEP operation to a null pointer like this?
Any help or tips would be appreciated.
input .mlir
module {
  // Broadcasts element [0, 0] of the 2x2 input into a newly created 2x3 tensor.
  func.func @forward(%arg0: tensor<2x2xf32>) -> tensor<2x3xf32> {
    // Read the fill value from the input at position [0, 0].
    %row = arith.constant 0 : index
    %col = arith.constant 0 : index
    %fill_value = tensor.extract %arg0[%row, %col] : tensor<2x2xf32>

    // Create the 2x3 destination tensor and fill every element with that value.
    %dest = tensor.empty() : tensor<2x3xf32>
    %filled = linalg.fill ins(%fill_value : f32) outs(%dest : tensor<2x3xf32>) -> tensor<2x3xf32>
    return %filled : tensor<2x3xf32>
  }
}
my lowering script
# Lower the MLIR file given as the first argument ($1) down to the LLVM dialect.
# mlir-opt applies the passes in the order listed: tensors are bufferized first,
# linalg ops become parallel loops and then OpenMP, the remaining dialects are
# converted to LLVM, and canonicalize/cse/reconcile-unrealized-casts clean up.
# "$1" is quoted so that a path containing spaces or glob characters is passed
# to mlir-opt as a single argument.
../../build/bin/mlir-opt "$1" \
    --one-shot-bufferize=bufferize-function-boundaries \
    --expand-strided-metadata \
    --convert-linalg-to-parallel-loops \
    --convert-scf-to-openmp \
    --lower-affine \
    --convert-math-to-llvm \
    --finalize-memref-to-llvm \
    --convert-scf-to-cf \
    --convert-func-to-llvm \
    --convert-openmp-to-llvm \
    --convert-index-to-llvm \
    --canonicalize \
    --cse \
    --reconcile-unrealized-casts
output .mlir
module {
// Declaration (no body) of @malloc; it is called below to allocate the buffer that is returned.
llvm.func @malloc(i64) -> !llvm.ptr
// Lowered @forward. The body reads only %arg1, %arg2, %arg5 and %arg6; %arg0, %arg3 and %arg4 are never referenced.
// NOTE(review): the seven arguments look like the unpacked memref descriptor of the tensor<2x2xf32> input
// (allocated ptr, aligned ptr, offset, 2 sizes, 2 strides) — confirm against MLIR's memref calling convention.
llvm.func @forward(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
// Constants: %0 = 64 (used for the rounding below), %1 = null pointer, %2 = 0, %3 = 2, %4 = 1, %5 = 3.
%0 = llvm.mlir.constant(64 : index) : i64
%1 = llvm.mlir.zero : !llvm.ptr
%2 = llvm.mlir.constant(0 : index) : i64
%3 = llvm.mlir.constant(2 : index) : i64
%4 = llvm.mlir.constant(1 : index) : i64
%5 = llvm.mlir.constant(3 : index) : i64
// Undefined struct value that the insertvalue chain below fills in field by field.
%6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// GEP on the null pointer %1 with index 6 and element type f32, followed by ptrtoint: this yields the
// byte size of 6 f32 elements as an integer. %7 is never loaded from or stored to, so no memory is accessed.
%7 = llvm.getelementptr %1[6] : (!llvm.ptr) -> !llvm.ptr, f32
%8 = llvm.ptrtoint %7 : !llvm.ptr to i64
// Allocate that size plus 64 extra bytes; the extra bytes leave room for the rounding below.
%9 = llvm.add %8, %0 : i64
%10 = llvm.call @malloc(%9) : (i64) -> !llvm.ptr
// Round the malloc'd address up to a multiple of 64: %15 = (%11 + 63) - ((%11 + 63) urem 64).
%11 = llvm.ptrtoint %10 : !llvm.ptr to i64
%12 = llvm.sub %0, %4 : i64
%13 = llvm.add %11, %12 : i64
%14 = llvm.urem %13, %0 : i64
%15 = llvm.sub %13, %14 : i64
%16 = llvm.inttoptr %15 : i64 to !llvm.ptr
// Populate the returned struct: [0] = raw malloc pointer, [1] = 64-aligned pointer, [2] = 0,
// [3] = {2, 3}, [4] = {3, 1}.
// NOTE(review): presumably [2] is the offset, [3] the sizes and [4] the strides of the 2x3 result — confirm.
%17 = llvm.insertvalue %10, %6[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%18 = llvm.insertvalue %16, %17[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%19 = llvm.insertvalue %2, %18[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%20 = llvm.insertvalue %3, %19[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%21 = llvm.insertvalue %5, %20[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%22 = llvm.insertvalue %5, %21[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%23 = llvm.insertvalue %4, %22[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// Load one f32 from the input: start %arg2 elements past %arg1, then add %arg5 * 0 + %arg6 * 0 elements.
%24 = llvm.getelementptr %arg1[%arg2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%25 = llvm.mul %arg5, %2 : i64
%26 = llvm.mul %arg6, %2 : i64
%27 = llvm.add %25, %26 : i64
%28 = llvm.getelementptr %24[%27] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%29 = llvm.load %28 : !llvm.ptr -> f32
// OpenMP parallel worksharing loop nest over (%arg7, %arg8) from (0, 0) to (2, 3) with step (1, 1).
omp.parallel {
omp.wsloop {
omp.loop_nest (%arg7, %arg8) : i64 = (%2, %2) to (%3, %5) step (%4, %4) {
// Save the stack pointer for this iteration; it is restored before the branch to ^bb2.
%30 = llvm.intr.stacksave : !llvm.ptr
llvm.br ^bb1
^bb1: // pred: ^bb0
// Store the loaded value %29 at element index %arg7 * 3 + %arg8 of the aligned buffer %16.
%31 = llvm.mul %arg7, %5 : i64
%32 = llvm.add %31, %arg8 : i64
%33 = llvm.getelementptr %16[%32] : (!llvm.ptr, i64) -> !llvm.ptr, f32
llvm.store %29, %33 : f32, !llvm.ptr
llvm.intr.stackrestore %30 : !llvm.ptr
llvm.br ^bb2
^bb2: // pred: ^bb1
omp.yield
}
omp.terminator
}
omp.terminator
}
// Return the fully populated struct describing the newly allocated buffer.
llvm.return %23 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
}
}