Consider this MLIR code with an indirect call to a device function:
// Device-side module: holds a plain helper function (@simple1) and the kernel
// (@simple) that invokes the helper through a function reference.
gpu.module @kernel {
// Helper: for every index in [0, %r_size) with step 1, stores the index
// (cast to i64) into %in_a at that index. %in_b and %in_c are accepted but
// never read or written here.
func.func @simple1(%r_size : index,
%in_a : memref<?xi64>,
%in_b : memref<?xi64>,
%in_c : memref<?xi64>) {
%ci0 = arith.constant 0 : index
// Loop step; note the value is 1 despite the "12" in the name.
%ci12 = arith.constant 1 : index
scf.for %idx0 = %ci0 to %r_size step %ci12 {
%idx0_i64 = arith.index_cast %idx0 : index to i64
memref.store %idx0_i64, %in_a[%idx0] : memref<?xi64>
}
return
}
// Kernel entry point: forwards all four of its arguments to @simple1 via an
// indirect call instead of a direct one.
gpu.func @simple(
%r_size : index,
%in_a : memref<?xi64>,
%in_b : memref<?xi64>,
%in_c : memref<?xi64>
) kernel {
// Direct-call variant, kept for comparison; per the report this form works.
// func.call @simple1(%r_size, %in_a, %in_b, %in_c) : (index, memref<?xi64>, memref<?xi64>, memref<?xi64>) -> ()
// Materialize a reference to @simple1; the reported diagnostic is attached to this op.
%15 = func.constant @simple1 : (index, memref<?xi64>, memref<?xi64>, memref<?xi64>) -> ()
// Call through the function value with the kernel's own arguments.
func.call_indirect %15(%r_size, %in_a, %in_b, %in_c) : (index, memref<?xi64>, memref<?xi64>, memref<?xi64>) -> ()
gpu.return
}
}
// Host driver: allocates three buffers, copies them into gpu.alloc'd buffers,
// launches @kernel::@simple once, copies the first buffer back, prints it and
// returns 0.
func.func @main() -> i64 {
%ci1 = arith.constant 1 : index
%c0 = arith.constant 0 : i64
%size = arith.constant 100 : index
// Three 100-element i64 buffers; nothing in this function writes to them
// before they are copied below.
%a = memref.alloc(%size) : memref<?xi64>
%b = memref.alloc(%size) : memref<?xi64>
%c = memref.alloc(%size) : memref<?xi64>
// Unranked views of the same buffers, used by gpu.host_register and by
// @printMemrefI64 further down.
%a_unranked = memref.cast %a : memref<?xi64> to memref<*xi64>
%b_unranked = memref.cast %b : memref<?xi64> to memref<*xi64>
%c_unranked = memref.cast %c : memref<?xi64> to memref<*xi64>
gpu.host_register %a_unranked : memref<*xi64>
gpu.host_register %b_unranked : memref<*xi64>
gpu.host_register %c_unranked : memref<*xi64>
// Buffers allocated through the gpu dialect, same element count as %a/%b/%c.
%tmp_a = gpu.alloc(%size) : memref<?xi64>
%tmp_b = gpu.alloc(%size) : memref<?xi64>
%tmp_c = gpu.alloc(%size) : memref<?xi64>
// Copy %a, %b, %c into %tmp_a, %tmp_b, %tmp_c. Each copy lists the previous
// copy's token as a dependency; the last copy is not marked async.
%token_a = gpu.memcpy async %tmp_a, %a : memref<?xi64>, memref<?xi64>
%token_b = gpu.memcpy async [%token_a] %tmp_b, %b : memref<?xi64>, memref<?xi64>
gpu.memcpy [%token_b] %tmp_c, %c : memref<?xi64>, memref<?xi64>
// Launch the kernel with a 1x1x1 grid of 1x1x1 blocks (a single thread).
gpu.launch_func @kernel::@simple blocks in (%ci1, %ci1, %ci1) threads in (%ci1, %ci1, %ci1)
args(%size : index, %tmp_a : memref<?xi64>,%tmp_b : memref<?xi64>, %tmp_c : memref<?xi64>)
// Copy %tmp_a (the buffer @simple1 stores into) back into %a and print it.
gpu.memcpy %a, %tmp_a : memref<?xi64>, memref<?xi64>
call @printMemrefI64(%a_unranked) : (memref<*xi64>) -> ()
return %c0 : i64
}
// External (body-less) declarations. @printMemrefI64 is called from @main;
// @printI64 is declared but not referenced in the code shown here.
func.func private @printI64(i64)
func.func private @printMemrefI64(memref<*xi64>)
} // END gpu.container_module
This code produces the following error:
error: 'func.constant' op reference to undefined function 'simple1'
However, when I use a direct call instead (commented out), it works as expected.
How should function references to device functions be created?
The intended use case is to pass a function reference as a callback to a precompiled device function from a linked bitcode library.