@rohany The following will do it:
@@ -37,11 +43,69 @@ gpu::SerializeToBlobPass::SerializeToBlobPass(TypeID passID)
// Copy constructor: only the `OperationPass<gpu::GPUModuleOp>` base is
// copy-constructed from `other`; nothing else is copied explicitly here.
gpu::SerializeToBlobPass::SerializeToBlobPass(const SerializeToBlobPass &other)
: OperationPass<gpu::GPUModuleOp>(other) {}
+/// Link the LLVM IR / bitcode file at `filename` into `llvmModule`.
+///
+/// Only the definitions that `llvmModule` actually references are pulled in
+/// (`Linker::Flags::LinkOnlyNeeded`). The globals the linker reports as newly
+/// linked are then internalized, so they do not become part of the module's
+/// externally visible interface.
+///
+/// Returns failure, after printing a message to `llvm::errs()`, if the file
+/// cannot be parsed or if linking fails.
+// This code has been adapted and reused from XLA:
+// https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc;drc=262777e9f9304c7df6b694934af819c820954ef5;l=334.
+static LogicalResult linkBitcode(StringRef filename, llvm::Module &llvmModule) {
+  llvm::SMDiagnostic diagnosticErr;
+  // The bitcode is parsed into the destination module's context, which the
+  // linker requires.
+  std::unique_ptr<llvm::Module> bitcodeModule =
+      llvm::parseIRFile(filename, diagnosticErr, llvmModule.getContext());
+  if (!bitcodeModule) {
+    llvm::errs() << "Error loading IR module: " << filename << '\n';
+    // Also forward the parser's own diagnostic (location and reason), so the
+    // cause of the failure is visible, not just the fact that it happened.
+    diagnosticErr.print("linkBitcode", llvm::errs());
+    return failure();
+  }
+
+  // Ignore the data layout of the module we're importing. This avoids a
+  // warning from the linker.
+  bitcodeModule->setDataLayout(llvmModule.getDataLayout());
+
+  llvm::Linker linker(llvmModule);
+  // Linker::linkInModule returns true on error.
+  if (linker.linkInModule(
+          std::move(bitcodeModule), llvm::Linker::Flags::LinkOnlyNeeded,
+          [](llvm::Module &m, const llvm::StringSet<> &gvs) {
+            // Keep unnamed globals and anything not in `gvs` as-is, and
+            // internalize the rest.
+            llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) {
+              return !gv.hasName() || (gvs.count(gv.getName()) == 0);
+            });
+          })) {
+    llvm::errs() << "Error linking bitcode module from " << filename << '\n';
+    return failure();
+  }
+
+  return success();
+}
+
llvmModule.setDataLayout(targetMachine.createDataLayout());
+ // Link in CUDA's libdevice bitcode file which has NVVM bitcode for common
+ // math primitives and bit-manipulation functions.
+ // TODO: Replace this hardcoded path with a cmake provided value.
+ // TODO: In the future, this should be removed in favor of any linking support
+ // that may be added to the LLVM NVPTX backend.
+ const std::string libdevicePath =
+ "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc";
+ if (failed(linkBitcode(libdevicePath, llvmModule)))
+ return std::nullopt;
+