Why get wrong result when call MLIR function by c++?

I wrote softmax in linalg as follows:

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0, 0)>
module {
  func.func @forward(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
    %cst = arith.constant dense<0.000000e+00> : tensor<3x1xf32>
    %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x2xf32>) outs(%cst : tensor<3x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = math.exp %in : f32
      %4 = arith.addf %3, %out : f32
      linalg.yield %4 : f32
    } -> tensor<3x1xf32>
    %1 = tensor.empty() : tensor<3x2xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %0 : tensor<3x2xf32>, tensor<3x1xf32>) outs(%1 : tensor<3x2xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %3 = math.exp %in : f32
      %4 = arith.divf %3, %in_0 : f32
      linalg.yield %4 : f32
    } -> tensor<3x2xf32>
    return %2 : tensor<3x2xf32>
  }
}

I use the following pass to convert to LLVM IR:

mlir-opt --llvm-request-c-wrappers \
	--linalg-bufferize --tensor-bufferize --arith-bufferize --arith-expand \
	--convert-linalg-to-affine-loops \
	-func-bufferize -finalizing-bufferize -convert-bufferization-to-memref \
	--convert-to-llvm
mlir-translate --mlir-to-llvmir

Here is my LLVM IR:

; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"

@__constant_3x1xf32 = private constant [3 x [1 x float]] zeroinitializer

declare ptr @malloc(i64)

declare void @free(ptr)

define { ptr, ptr, i64, [2 x i64], [2 x i64] } @forward(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6) {
  %8 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %0, 0
  %9 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, ptr %1, 1
  %10 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %9, i64 %2, 2
  %11 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %10, i64 %3, 3, 0
  %12 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %11, i64 %5, 4, 0
  %13 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, i64 %4, 3, 1
  %14 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %13, i64 %6, 4, 1
  %15 = call ptr @malloc(i64 add (i64 ptrtoint (ptr getelementptr (float, ptr null, i64 3) to i64), i64 64))
  %16 = ptrtoint ptr %15 to i64
  %17 = add i64 %16, 63
  %18 = urem i64 %17, 64
  %19 = sub i64 %17, %18
  %20 = inttoptr i64 %19 to ptr
  %21 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %15, 0
  %22 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %21, ptr %20, 1
  %23 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %22, i64 0, 2
  %24 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %23, i64 3, 3, 0
  %25 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %24, i64 1, 3, 1
  %26 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %25, i64 1, 4, 0
  %27 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %26, i64 1, 4, 1
  %28 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %27, 1
  %29 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %27, 2
  %30 = getelementptr float, ptr %28, i64 %29
  call void @llvm.memcpy.p0.p0.i64(ptr %30, ptr @__constant_3x1xf32, i64 mul (i64 ptrtoint (ptr getelementptr (float, ptr null, i32 1) to i64), i64 3), i1 false)
  br label %31

31:                                               ; preds = %54, %7
  %32 = phi i64 [ %55, %54 ], [ 0, %7 ]
  %33 = icmp slt i64 %32, 3
  br i1 %33, label %34, label %56

34:                                               ; preds = %31
  br label %35

35:                                               ; preds = %38, %34
  %36 = phi i64 [ %53, %38 ], [ 0, %34 ]
  %37 = icmp slt i64 %36, 2
  br i1 %37, label %38, label %54

38:                                               ; preds = %35
  %39 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %14, 1
  %40 = mul i64 %32, 2
  %41 = add i64 %40, %36
  %42 = getelementptr float, ptr %39, i64 %41
  %43 = load float, ptr %42, align 4
  %44 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %27, 1
  %45 = add i64 %32, 0
  %46 = getelementptr float, ptr %44, i64 %45
  %47 = load float, ptr %46, align 4
  %48 = call float @llvm.exp.f32(float %43)
  %49 = fadd float %48, %47
  %50 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %27, 1
  %51 = add i64 %32, 0
  %52 = getelementptr float, ptr %50, i64 %51
  store float %49, ptr %52, align 4
  %53 = add i64 %36, 1
  br label %35

54:                                               ; preds = %35
  %55 = add i64 %32, 1
  br label %31

56:                                               ; preds = %31
  %57 = call ptr @malloc(i64 add (i64 ptrtoint (ptr getelementptr (float, ptr null, i64 6) to i64), i64 64))
  %58 = ptrtoint ptr %57 to i64
  %59 = add i64 %58, 63
  %60 = urem i64 %59, 64
  %61 = sub i64 %59, %60
  %62 = inttoptr i64 %61 to ptr
  %63 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %57, 0
  %64 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %63, ptr %62, 1
  %65 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %64, i64 0, 2
  %66 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %65, i64 3, 3, 0
  %67 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %66, i64 2, 3, 1
  %68 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %67, i64 2, 4, 0
  %69 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %68, i64 1, 4, 1
  br label %70

70:                                               ; preds = %94, %56
  %71 = phi i64 [ %95, %94 ], [ 0, %56 ]
  %72 = icmp slt i64 %71, 3
  br i1 %72, label %73, label %96

73:                                               ; preds = %70
  br label %74

74:                                               ; preds = %77, %73
  %75 = phi i64 [ %93, %77 ], [ 0, %73 ]
  %76 = icmp slt i64 %75, 2
  br i1 %76, label %77, label %94

77:                                               ; preds = %74
  %78 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %14, 1
  %79 = mul i64 %71, 2
  %80 = add i64 %79, %75
  %81 = getelementptr float, ptr %78, i64 %80
  %82 = load float, ptr %81, align 4
  %83 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %27, 1
  %84 = add i64 %71, 0
  %85 = getelementptr float, ptr %83, i64 %84
  %86 = load float, ptr %85, align 4
  %87 = call float @llvm.exp.f32(float %82)
  %88 = fdiv float %87, %86
  %89 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %69, 1
  %90 = mul i64 %71, 2
  %91 = add i64 %90, %75
  %92 = getelementptr float, ptr %89, i64 %91
  store float %88, ptr %92, align 4
  %93 = add i64 %75, 1
  br label %74

94:                                               ; preds = %74
  %95 = add i64 %71, 1
  br label %70

96:                                               ; preds = %70
  ret { ptr, ptr, i64, [2 x i64], [2 x i64] } %69
}

define void @_mlir_ciface_forward(ptr %0, ptr %1) {
  %3 = load { ptr, ptr, i64, [2 x i64], [2 x i64] }, ptr %1, align 8
  %4 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 0
  %5 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 1
  %6 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 2
  %7 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 3, 0
  %8 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 3, 1
  %9 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 4, 0
  %10 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 4, 1
  %11 = call { ptr, ptr, i64, [2 x i64], [2 x i64] } @forward(ptr %4, ptr %5, i64 %6, i64 %7, i64 %8, i64 %9, i64 %10)
  store { ptr, ptr, i64, [2 x i64], [2 x i64] } %11, ptr %0, align 8
  ret void
}

; Function Attrs: nocallback nofree nounwind willreturn memory(argmem: readwrite)
declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg) #0

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare float @llvm.exp.f32(float) #1

attributes #0 = { nocallback nofree nounwind willreturn memory(argmem: readwrite) }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

!llvm.module.flags = !{!0}

!0 = !{i32 2, !"Debug Info Version", i32 3}

But when I call it by c++, I got wrong result. Here is my code:

extern "C" {
void _mlir_ciface_forward(memref::MemRefDescriptor<float, 2> *C,
                     memref::MemRefDescriptor<float, 2> *A);
}

int main() {

  const int32_t M = 3;
  const int32_t N = 2;

  float matA[M][N] = {{1, 1}, {2, 1}, {3, 1}};

  float matC[M][N];

  std::array<int64_t, 2> aDim = {M, N};
  std::array<int64_t, 2> cDim = {M, N};

  memref::MemRef<float, 2> A((float *)matA, aDim);
  memref::MemRef<float, 2> C((float *)matC, cDim);

  _mlir_ciface_forward(&C.memRefDesc, &A.memRefDesc);

  std::cout << "result:\n";
  utility::printMatrix(C);
}

I used memref to encapsulate MemRefDescriptor,C’s result is incorrect. Looks like there is some offset. C should be a 3*2 matrix, when I output some more data, I find that the correct data is behind.Like this:

result:
| 0.0000 0.0000 |
| 0.0000 0.0000 |
| 0.0000 0.0000 |
| 0.0000 0.0000 |
| 0.0000 0.0000 |
| 0.0000 0.0000 |
| 0.5000 0.5000 |
| 0.7311 0.2689 |
| 0.8808 0.1192 |

Can anyone help, thanks!

Can you point to the definitions of the utility functions/classes you used?

It’s just a simple output function.

I wanted to look at the definition of MemRefDescriptor and utility::printMatrix.

template <typename elementType, size_t memrefRank>
struct MemRefDescriptor {
  elementType *allocated;
  elementType *aligned;
  int64_t offset;
  int64_t sizes[memrefRank];
  int64_t strides[memrefRank];
};
template <typename T>
void printMatrix(const memref::MemRefDescriptor<T, 2> &mat) {
  const int numRows = mat.sizes[0];
  const int numCols = mat.sizes[1];

  for (int i = 0; i < numRows; ++i) {
    printf("| ");

    for (int j = 0; j < numCols; ++j) {
      printValue(mat.allocated[i * numCols + j]);
    }
    printf("|\n");
  }
}

Is there something wrong?

The memref descriptor in runtime utils is defined here. I think the second pointer is what one should be using to get data. Also you are not using the offset field in printValue.
Hope this helps.

It really helps, thanks!

1 Like