Through all these days, I implemented a DFS method to find all related data. But now I ran into a problem that I cannot correctly recognize related data / functions. What I need to do is to find those functions which produces the input for the current function. In this sample, the 7th, 10th, 14th argument are the input, the 14th argument is also the output of GemmEx.
Sample:
; ModuleID = '/mnt/data/home/mzw/workspace/test_space/llvm_test/gemm_pass2_test/dataflow_test/short_test.hip'
source_filename = "/mnt/data/home/mzw/workspace/test_space/llvm_test/gemm_pass2_test/dataflow_test/short_test.hip"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"
%"class.std::ios_base::Init" = type { i8 }
@_ZStL8__ioinit = internal global %"class.std::ios_base::Init" zeroinitializer, align 1
@__dso_handle = external hidden global i8
@llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 65535, void ()* @_GLOBAL__sub_I_short_test.hip, i8* null }]
declare dso_local void @_ZNSt8ios_base4InitC1Ev(%"class.std::ios_base::Init"* nonnull dereferenceable(1)) unnamed_addr #0
; Function Attrs: nounwind
declare dso_local void @_ZNSt8ios_base4InitD1Ev(%"class.std::ios_base::Init"* nonnull dereferenceable(1)) unnamed_addr #1
; Function Attrs: nofree nounwind
declare dso_local i32 @__cxa_atexit(void (i8*)*, i8*, i8*) local_unnamed_addr #2
; Function Attrs: norecurse uwtable mustprogress
define dso_local i32 @main() local_unnamed_addr #3 {
%1 = alloca float*, align 8
%2 = alloca float*, align 8
%3 = alloca float*, align 8
%4 = alloca i8*, align 8
%5 = alloca float, align 4
%6 = alloca float, align 4
%7 = bitcast float** %1 to i8*
call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %7) #6
%8 = bitcast float** %2 to i8*
call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %8) #6
%9 = bitcast float** %3 to i8*
call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %9) #6
%10 = bitcast float** %1 to i8**
%11 = call i32 @hipMalloc(i8** nonnull %10, i64 4)
%12 = bitcast float** %2 to i8**
%13 = call i32 @hipMalloc(i8** nonnull %12, i64 4)
%14 = bitcast float** %3 to i8**
%15 = call i32 @hipMalloc(i8** nonnull %14, i64 4)
%16 = bitcast i8** %4 to i8*
call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %16) #6
store i8* null, i8** %4, align 8, !tbaa !2
%17 = call i32 @hipblasCreate(i8** nonnull %4)
%18 = bitcast float* %5 to i8*
call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %18) #6
%19 = bitcast float* %6 to i8*
call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %19) #6
%20 = load i8*, i8** %4, align 8, !tbaa !2
%21 = load i8*, i8** %10, align 8, !tbaa !2
%22 = load i8*, i8** %12, align 8, !tbaa !2
%23 = load i8*, i8** %14, align 8, !tbaa !2
%24 = call i32 @hipblasGemmEx(i8* %20, i32 111, i32 111, i32 1, i32 1, i32 1, i8* nonnull %18, i8* %21, i32 151, i32 1, i8* %22, i32 151, i32 1, i8* nonnull %19, i8* %23, i32 151, i32 1, i32 151, i32 160)
%25 = load i8*, i8** %4, align 8, !tbaa !2
%26 = load i8*, i8** %12, align 8, !tbaa !2
%27 = load i8*, i8** %10, align 8, !tbaa !2
%28 = call i32 @hipblasGemmEx(i8* %25, i32 111, i32 111, i32 1, i32 1, i32 1, i8* nonnull %18, i8* %26, i32 151, i32 1, i8* %26, i32 151, i32 1, i8* nonnull %19, i8* %27, i32 151, i32 1, i32 151, i32 160)
call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %19) #6
call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %18) #6
call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %16) #6
call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %9) #6
call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %8) #6
call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %7) #6
ret i32 0
}
; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #4
declare dso_local i32 @hipMalloc(i8**, i64) local_unnamed_addr #0
declare dso_local i32 @hipblasCreate(i8**) local_unnamed_addr #0
declare dso_local i32 @hipblasGemmEx(i8*, i32, i32, i32, i32, i32, i8*, i8*, i32, i32, i8*, i32, i32, i8*, i8*, i32, i32, i32, i32) local_unnamed_addr #0
; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #4
; Function Attrs: uwtable
define internal amdgpu_kernel void @_GLOBAL__sub_I_short_test.hip() #5 section ".text.startup" {
tail call void @_ZNSt8ios_base4InitC1Ev(%"class.std::ios_base::Init"* nonnull dereferenceable(1) @_ZStL8__ioinit)
%1 = tail call i32 @__cxa_atexit(void (i8*)* bitcast (void (%"class.std::ios_base::Init"*)* @_ZNSt8ios_base4InitD1Ev to void (i8*)*), i8* getelementptr inbounds (%"class.std::ios_base::Init", %"class.std::ios_base::Init"* @_ZStL8__ioinit, i64 0, i32 0), i8* nonnull @__dso_handle) #6
ret void
}
attributes #0 = { "frame-pointer"="none" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
attributes #1 = { nounwind "frame-pointer"="none" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
attributes #2 = { nofree nounwind }
attributes #3 = { norecurse uwtable mustprogress "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
attributes #4 = { argmemonly nofree nosync nounwind willreturn }
attributes #5 = { uwtable "device-init" "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
attributes #6 = { nounwind }
!llvm.module.flags = !{!0}
!llvm.ident = !{!1}
!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{!"clang version 13.0.0 (https://github.com/RadeonOpenCompute/llvm-project roc-4.3.0 21295 f2943f684437d2c1143a56e418d29fc6b3314072)"}
!2 = !{!3, !3, i64 0}
!3 = !{!"any pointer", !4, i64 0}
!4 = !{!"omnipotent char", !5, i64 0}
!5 = !{!"Simple C++ TBAA"}
My Core DFS code is contained below:
struct dataflow : public ModulePass{
static char ID;
dataflow();
~dataflow(){};
bool runOnModule(Module & M) override;
void dfs(Value *,Function *, int,std::unordered_map<Value*,bool>&);
void jump_out_of_parent(Value *, Function *, int,std::unordered_map<Value*,bool>&);
template<typename T>bool is_in_list(std::vector<T>,T);
int gemmex_id;
std::unordered_map<int,Instruction*> gemm_id_call_inst_map;
std::unordered_map<Instruction*,int> gemm_call_inst_int_map;
std::unordered_map<int,std::vector<int>> related_gemm_id_map;
//below one is for general call inst, above are for gemm
std::vector<CallInst*> call_inst_list;
std::string modify_matrix_mem_func_name;
};
dataflow::dataflow() : ModulePass(ID){
gemmex_id = 0;
modify_matrix_mem_func_name = "dududu";
}
template<typename T>
bool dataflow::is_in_list(std::vector<T> l, T target)
{
for(T e: l)
{
if(target == e) return true;
}
return false;
}
//NOTE: The related_gemm_id_map we get here contains all related gemm without consindering
//1)whether is runned acutally 2) whether it's runned before it(we only consider the before gemm every time we optimize)
void dataflow::dfs(Value * called_arg, Function * caller_func, int target_id, std::unordered_map<Value*,bool> & dfsed_value_map)
{
//TO.DO.: How to avoid repeating the same instruction? //DONE
if(dfsed_value_map.find(called_arg) != dfsed_value_map.end() && dfsed_value_map[called_arg]) return;
else
dfsed_value_map[called_arg] = true;
errs()<<"Current target arg is "<<*called_arg<<"\n";
//"user" means that this argument is used as argument/operand somewhere
for(auto user = called_arg->user_begin(), user_end = called_arg->user_end();
user != user_end; user++) //User means that this arg is used as operand in these instructions
{
//If this arg is just right as the argument of GemmEx
if(Instruction * inst = dyn_cast<Instruction>(*user))
{
errs()<<*inst<<"\n";
//errs()<<"Facing instruction of "<<*inst<<"\n";
if(isa<CallInst>(inst))
{
CallInst * call_inst = dyn_cast<CallInst>(inst);
Function * called_func = call_inst->getCalledFunction();
if(called_func && called_func->getName() == "hipblasGemmEx")
{
int cur_id = gemm_call_inst_int_map[inst];
errs()<<"We found the call_inst of GemmEx: "<<*call_inst<<"\n";
//we only care about those gemmex that accept this arg as output
//NOTE: This makes us wont add gemmex itself into its dependence list
if(called_arg==call_inst->getOperand(14))
{
//TO.DO.: Avoid searching the same gemm //DONE
if(is_in_list<int>(related_gemm_id_map[target_id],cur_id))
{
//do nothing
//errs()<<"We occur the same GemmEx with id "<<cur_id<<"\n";
}
else
{
errs()<<"The "<<target_id<<"th GemmEx depends on "<<cur_id<<"th GemmEx\n";
related_gemm_id_map[target_id].push_back(cur_id);
//dfs(call_inst->getOperand(7),caller_func, cur_id, dfsed_value_map);
//dfs(call_inst->getOperand(10),caller_func, cur_id, dfsed_value_map);
//dfs(call_inst->getOperand(14),caller_func, cur_id, dfsed_value_map);
}
}
else
continue;
}
else if(called_func && called_func->getName() == modify_matrix_mem_func_name)
{
//TO.DO.: when we run into something like ReadFile() that can modify the Matrix to be a new data matrix
//what should we do?
}
else
{
//we dont care about other functions
//10-26:But if we met a function containing the GemmEx, we wont dig in.
//Oppositely, it will start from the contained GemmEx and jump out to find this GemmEx
//QUES.: But in this way, we cannot find the dependency from called func to current GemmEx
//like {testfunc(AAC),GemmEx(ABC)} we cannot know the GemmEx depends on the one in testfunc
//we only can know testfunc depends on GemmEx's C
continue;
}
}
else
{
//we assume we only have load/store in this branch
//errs()<<"Now we have met the load/store inst\n";
Value * ret_v = dynamic_cast<Value*>(inst);
errs()<<*ret_v<<"\n";
dfs(ret_v,caller_func, target_id, dfsed_value_map);
}
}
}
//"use" means that this argument is def/not as an operand somewhere.
for(auto use = called_arg->use_begin(), use_end = called_arg->use_end(); use != use_end; use++)
{
if(Instruction * inst = dyn_cast<Instruction>(*use))
{
for(auto i = 0; i < inst->getNumOperands(); i++)
{
Value * related_v = inst->getOperand(i);
if(related_v == called_arg) continue;
else dfs(related_v,caller_func, target_id, dfsed_value_map);
}
}
}
//For those whose argument is passed through the arguments of parent functions
//In fact, we should check whether it is in parent's arguments whenever we are handling a new Value
//So that we can jump out of parent function, get the all coresponding call_inst of parent function
//and dfs on the coresponding passed arguments of call_inst
//TO.DO.:
jump_out_of_parent(called_arg,caller_func,target_id,dfsed_value_map);
}
void dataflow::jump_out_of_parent(Value * target_arg, Function * parent_func, int target_id, std::unordered_map<Value*,bool>& dfsed_value_map)
{
//errs()<<parent_func->getName()<<"\n";
/*
if(parent_func->getName().str() == "main") return;
if(parent_func->getName().str() == "_Z8testfuncPv18hipblasOperation_tS0_iiiS_S_17hipblasDatatype_tiS_S1_iS_S_S1_iS1_17hipblasGemmAlgo_t")
{
errs()<<"jumping out of testfunc\n";
errs()<<"It has total "<<parent_func->arg_size()<<" arguments\n";
}
*/
//NOTE: We cannot use getNumOperands to get the argument list size of a function def
for(size_t i = 0; i < parent_func->arg_size(); i++)
{
Value * arg = parent_func->getArg(i);
//errs()<<"The function "<<parent_func->getName().str()<<" "<<i<<"th argument is "<<*arg<<"\n";
if(target_arg == arg)
{
//TO.DO.: Loop over call_inst_list and locate all corresponding passed-in argument, dfs on them
for(int j = 0; j < call_inst_list.size(); j++)
{
Function * called_func = call_inst_list[j]->getCalledFunction();
if(called_func && called_func == parent_func)
{
//errs()<<"We now jump out of test func\n";
CallInst * call_inst = call_inst_list[j];
Value * target_passed_arg = call_inst->getArgOperand(i);
Function * new_parent_func = call_inst->getParent()->getParent();
dfs(target_passed_arg,new_parent_func,target_id, dfsed_value_map);
}
}
}
}
}
bool dataflow::runOnModule(Module &M)
{
//TO.DO.: In official version, we should only care about functions outside the tool_library
//NOTE: We only collect all id of call_inst and gemmex here. Because in dfs, no any in-order ensured
for(auto func = M.getFunctionList().begin(), end_func = M.getFunctionList().end();
func != end_func; func++)
{
//errs()<<"Now we are facing declare of function "<<func->getName()<<"\n";
for(auto bb = func->begin(); bb != func->end(); bb++)
{
for(auto inst = bb->begin(); inst != bb->end(); inst++)
{
if(CallInst * call_inst = dyn_cast<CallInst>(inst))
{
call_inst_list.push_back(call_inst);
Function * called_func = call_inst->getCalledFunction();
if(called_func && called_func->getName() == "hipblasGemmEx")
{
errs()<<"We get the "<<++gemmex_id<<"th called GemmEx function in "<<*call_inst<<"\n";
gemm_id_call_inst_map[gemmex_id]=call_inst;
gemm_call_inst_int_map[call_inst] = gemmex_id;
}
}
}
}
}
errs()<<"Now we finish collecting all call_inst(including gemmex)\n";
//Now we are dealing with related-gemm
gemmex_id = 0;
for(auto func = M.getFunctionList().begin(), end_func = M.getFunctionList().end();
func != end_func; func++)
{
//errs()<<"Now we are facing declare of function "<<func->getName()<<"\n";
for(auto bb = func->begin(); bb != func->end(); bb++)
{
for(auto inst = bb->begin(); inst != bb->end(); inst++)
{
//only focus on GemmEx
if(CallInst * call_inst = dyn_cast<CallInst>(inst))
{
Function * called_func = call_inst->getCalledFunction();
if(called_func && called_func->getName() == "hipblasGemmEx")
{
gemmex_id++;
errs()<<"Now we use "<<gemmex_id<<"th GemmEx as target GemmEx\n";
//get MatrixA argument defined before.
Value * Matrix1_argv = call_inst->getArgOperand(7);
//For those who has a load / related operation previous
Function * caller_func = dyn_cast<Function>(func);
std::unordered_map<Value*,bool> dfsed_value_map;
dfs(Matrix1_argv,caller_func,gemmex_id,dfsed_value_map);
Value * Matrix2_argv = call_inst->getArgOperand(10);
dfsed_value_map.clear();
dfs(Matrix2_argv,caller_func,gemmex_id,dfsed_value_map);
Value * Matrix3_argv = call_inst->getArgOperand(14);
dfsed_value_map.clear();
dfs(Matrix3_argv,caller_func,gemmex_id,dfsed_value_map);
}
}
}
}
}
//running above, we are able to get the whole data-related gemm
//filter those gemm are not actually runned or not runned before target-gemm
for(auto it = related_gemm_id_map.begin(); it != related_gemm_id_map.end(); it++)
{
std::cout<<"The related GemmEx id of "<<it->first<<" contains: ";
for(auto id : it->second) std::cout<<id<<" ";
std::cout<<std::endl;
}
return false;
}