I wrote softmax in linalg as follows:
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0, 0)>
module {
func.func @forward(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
%cst = arith.constant dense<0.000000e+00> : tensor<3x1xf32>
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x2xf32>) outs(%cst : tensor<3x1xf32>) {
^bb0(%in: f32, %out: f32):
%3 = math.exp %in : f32
%4 = arith.addf %3, %out : f32
linalg.yield %4 : f32
} -> tensor<3x1xf32>
%1 = tensor.empty() : tensor<3x2xf32>
%2 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %0 : tensor<3x2xf32>, tensor<3x1xf32>) outs(%1 : tensor<3x2xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%3 = math.exp %in : f32
%4 = arith.divf %3, %in_0 : f32
linalg.yield %4 : f32
} -> tensor<3x2xf32>
return %2 : tensor<3x2xf32>
}
}
I use the following pass to convert to LLVM IR:
mlir-opt --llvm-request-c-wrappers \
--linalg-bufferize --tensor-bufferize --arith-bufferize --arith-expand \
--convert-linalg-to-affine-loops \
-func-bufferize -finalizing-bufferize -convert-bufferization-to-memref \
--convert-to-llvm
mlir-translate --mlir-to-llvmir
Here is my LLVM IR:
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
@__constant_3x1xf32 = private constant [3 x [1 x float]] zeroinitializer
declare ptr @malloc(i64)
declare void @free(ptr)
define { ptr, ptr, i64, [2 x i64], [2 x i64] } @forward(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6) {
%8 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %0, 0
%9 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, ptr %1, 1
%10 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %9, i64 %2, 2
%11 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %10, i64 %3, 3, 0
%12 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %11, i64 %5, 4, 0
%13 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, i64 %4, 3, 1
%14 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %13, i64 %6, 4, 1
%15 = call ptr @malloc(i64 add (i64 ptrtoint (ptr getelementptr (float, ptr null, i64 3) to i64), i64 64))
%16 = ptrtoint ptr %15 to i64
%17 = add i64 %16, 63
%18 = urem i64 %17, 64
%19 = sub i64 %17, %18
%20 = inttoptr i64 %19 to ptr
%21 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %15, 0
%22 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %21, ptr %20, 1
%23 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %22, i64 0, 2
%24 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %23, i64 3, 3, 0
%25 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %24, i64 1, 3, 1
%26 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %25, i64 1, 4, 0
%27 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %26, i64 1, 4, 1
%28 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %27, 1
%29 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %27, 2
%30 = getelementptr float, ptr %28, i64 %29
call void @llvm.memcpy.p0.p0.i64(ptr %30, ptr @__constant_3x1xf32, i64 mul (i64 ptrtoint (ptr getelementptr (float, ptr null, i32 1) to i64), i64 3), i1 false)
br label %31
31: ; preds = %54, %7
%32 = phi i64 [ %55, %54 ], [ 0, %7 ]
%33 = icmp slt i64 %32, 3
br i1 %33, label %34, label %56
34: ; preds = %31
br label %35
35: ; preds = %38, %34
%36 = phi i64 [ %53, %38 ], [ 0, %34 ]
%37 = icmp slt i64 %36, 2
br i1 %37, label %38, label %54
38: ; preds = %35
%39 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %14, 1
%40 = mul i64 %32, 2
%41 = add i64 %40, %36
%42 = getelementptr float, ptr %39, i64 %41
%43 = load float, ptr %42, align 4
%44 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %27, 1
%45 = add i64 %32, 0
%46 = getelementptr float, ptr %44, i64 %45
%47 = load float, ptr %46, align 4
%48 = call float @llvm.exp.f32(float %43)
%49 = fadd float %48, %47
%50 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %27, 1
%51 = add i64 %32, 0
%52 = getelementptr float, ptr %50, i64 %51
store float %49, ptr %52, align 4
%53 = add i64 %36, 1
br label %35
54: ; preds = %35
%55 = add i64 %32, 1
br label %31
56: ; preds = %31
%57 = call ptr @malloc(i64 add (i64 ptrtoint (ptr getelementptr (float, ptr null, i64 6) to i64), i64 64))
%58 = ptrtoint ptr %57 to i64
%59 = add i64 %58, 63
%60 = urem i64 %59, 64
%61 = sub i64 %59, %60
%62 = inttoptr i64 %61 to ptr
%63 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %57, 0
%64 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %63, ptr %62, 1
%65 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %64, i64 0, 2
%66 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %65, i64 3, 3, 0
%67 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %66, i64 2, 3, 1
%68 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %67, i64 2, 4, 0
%69 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %68, i64 1, 4, 1
br label %70
70: ; preds = %94, %56
%71 = phi i64 [ %95, %94 ], [ 0, %56 ]
%72 = icmp slt i64 %71, 3
br i1 %72, label %73, label %96
73: ; preds = %70
br label %74
74: ; preds = %77, %73
%75 = phi i64 [ %93, %77 ], [ 0, %73 ]
%76 = icmp slt i64 %75, 2
br i1 %76, label %77, label %94
77: ; preds = %74
%78 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %14, 1
%79 = mul i64 %71, 2
%80 = add i64 %79, %75
%81 = getelementptr float, ptr %78, i64 %80
%82 = load float, ptr %81, align 4
%83 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %27, 1
%84 = add i64 %71, 0
%85 = getelementptr float, ptr %83, i64 %84
%86 = load float, ptr %85, align 4
%87 = call float @llvm.exp.f32(float %82)
%88 = fdiv float %87, %86
%89 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %69, 1
%90 = mul i64 %71, 2
%91 = add i64 %90, %75
%92 = getelementptr float, ptr %89, i64 %91
store float %88, ptr %92, align 4
%93 = add i64 %75, 1
br label %74
94: ; preds = %74
%95 = add i64 %71, 1
br label %70
96: ; preds = %70
ret { ptr, ptr, i64, [2 x i64], [2 x i64] } %69
}
define void @_mlir_ciface_forward(ptr %0, ptr %1) {
%3 = load { ptr, ptr, i64, [2 x i64], [2 x i64] }, ptr %1, align 8
%4 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 0
%5 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 1
%6 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 2
%7 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 3, 0
%8 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 3, 1
%9 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 4, 0
%10 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 4, 1
%11 = call { ptr, ptr, i64, [2 x i64], [2 x i64] } @forward(ptr %4, ptr %5, i64 %6, i64 %7, i64 %8, i64 %9, i64 %10)
store { ptr, ptr, i64, [2 x i64], [2 x i64] } %11, ptr %0, align 8
ret void
}
; Function Attrs: nocallback nofree nounwind willreturn memory(argmem: readwrite)
declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg) #0
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare float @llvm.exp.f32(float) #1
attributes #0 = { nocallback nofree nounwind willreturn memory(argmem: readwrite) }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
!llvm.module.flags = !{!0}
!0 = !{i32 2, !"Debug Info Version", i32 3}
But when I call it by c++, I got wrong result. Here is my code:
extern "C" {
void _mlir_ciface_forward(memref::MemRefDescriptor<float, 2> *C,
memref::MemRefDescriptor<float, 2> *A);
}
int main() {
const int32_t M = 3;
const int32_t N = 2;
float matA[M][N] = {{1, 1}, {2, 1}, {3, 1}};
float matC[M][N];
std::array<int64_t, 2> aDim = {M, N};
std::array<int64_t, 2> cDim = {M, N};
memref::MemRef<float, 2> A((float *)matA, aDim);
memref::MemRef<float, 2> C((float *)matC, cDim);
_mlir_ciface_forward(&C.memRefDesc, &A.memRefDesc);
std::cout << "result:\n";
utility::printMatrix(C);
}
I used memref to encapsulate MemRefDescriptor,C’s result is incorrect. Looks like there is some offset. C should be a 3*2 matrix, when I output some more data, I find that the correct data is behind.Like this:
result:
| 0.0000 0.0000 |
| 0.0000 0.0000 |
| 0.0000 0.0000 |
| 0.0000 0.0000 |
| 0.0000 0.0000 |
| 0.0000 0.0000 |
| 0.5000 0.5000 |
| 0.7311 0.2689 |
| 0.8808 0.1192 |
Can anyone help, thanks!