[RFC] Ignore tail call flag when comparing instructions for sinking in `SimplifyCFGPass`

Hello LLVM developers,

I think I have found a small improvement to SimplifyCFGPass.
Please let me know whether it is worth a patch.

Consider the following example:

extern void escape(void*);
extern int magic(void);
extern void tail(void);

void foofn() {
  int x;
  if (magic()) {
    tail();
  } else {
    escape(&x);
    tail();
  }
}

Note that it is safe to rewrite this function by sinking the call to tail as follows:

void foofn() {
  int x;
  if (!magic())
    escape(&x);
  tail();
}

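However, this does not currently happen. TailCallElimPass does not mark the second call to `tail` as a tail call, because the preceding call to `escape` leaks the address of the local variable `x`. As a result, the two calls to `tail` carry different tail-call markers: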

*** IR Dump After TailCallElimPass on foofn ***
; Function Attrs: nounwind uwtable
define dso_local void @foofn() local_unnamed_addr #0 {
entry:
  %x = alloca i32, align 4
  call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %x) #3
  %call = tail call i32 @magic() #3
  %tobool.not = icmp eq i32 %call, 0
  br i1 %tobool.not, label %if.else, label %if.then

if.then:                                          ; preds = %entry
  tail call void @tail() #3
  br label %if.end

if.else:                                          ; preds = %entry
  call void @escape(ptr noundef nonnull %x) #3
  call void @tail() #3
  br label %if.end

if.end:                                           ; preds = %if.else, %if.then
  call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %x) #3
  ret void
}

When SimplifyCFGPass decides which instructions to sink into a common successor, it uses Instruction::isSameOperationAs to identify sinking candidates. However, this function returns false for two call instructions when only one of them is marked as a tail call (it compares CallInst::isTailCall()).

Instruction::isSameOperationAs can be extended to optionally ignore the tail-call marker, as shown below:

diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
index adcfee5db03a..8d336d95d2a5 100644
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -715,7 +715,8 @@ public:
     CompareIgnoringAlignment = 1<<0,
     /// Check for equivalence treating a type and a vector of that type
     /// as equivalent.
-    CompareUsingScalarTypes = 1<<1
+    CompareUsingScalarTypes = 1<<1,
+    CompareIgnoringTailCall = 1<<2
   };
 
   /// This function determines if the specified instruction executes the same
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 007e518a1a81..0a6f9ff1e8a2 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -459,7 +459,8 @@ const char *Instruction::getOpcodeName(unsigned OpCode) {
 /// kept in sync with FunctionComparator::cmpOperations in
 /// lib/Transforms/IPO/MergeFunctions.cpp.
 static bool haveSameSpecialState(const Instruction *I1, const Instruction *I2,
-                                 bool IgnoreAlignment = false) {
+                                 bool IgnoreAlignment = false,
+                                 bool IgnoreTailCall = false) {
   assert(I1->getOpcode() == I2->getOpcode() &&
          "Can not compare special state of different instructions");
 
@@ -482,7 +483,8 @@ static bool haveSameSpecialState(const Instruction *I1, const Instruction *I2,
   if (const CmpInst *CI = dyn_cast<CmpInst>(I1))
     return CI->getPredicate() == cast<CmpInst>(I2)->getPredicate();
   if (const CallInst *CI = dyn_cast<CallInst>(I1))
-    return CI->isTailCall() == cast<CallInst>(I2)->isTailCall() &&
+    return (CI->isTailCall() == cast<CallInst>(I2)->isTailCall()
+            || IgnoreTailCall) &&
            CI->getCallingConv() == cast<CallInst>(I2)->getCallingConv() &&
            CI->getAttributes() == cast<CallInst>(I2)->getAttributes() &&
            CI->hasIdenticalOperandBundleSchema(*cast<CallInst>(I2));
@@ -561,6 +563,7 @@ bool Instruction::isSameOperationAs(const Instruction *I,
                                     unsigned flags) const {
   bool IgnoreAlignment = flags & CompareIgnoringAlignment;
   bool UseScalarTypes  = flags & CompareUsingScalarTypes;
+  bool IgnoreTailCall  = flags & CompareIgnoringTailCall;
 
   if (getOpcode() != I->getOpcode() ||
       getNumOperands() != I->getNumOperands() ||
@@ -578,7 +581,7 @@ bool Instruction::isSameOperationAs(const Instruction *I,
         getOperand(i)->getType() != I->getOperand(i)->getType())
       return false;
 
-  return haveSameSpecialState(this, I, IgnoreAlignment);
+  return haveSameSpecialState(this, I, IgnoreAlignment, IgnoreTailCall);
 }
 
 bool Instruction::isUsedOutsideOfBlock(const BasicBlock *BB) const {
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 49ecd988dba7..971948d9b72a 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1816,7 +1816,7 @@ static bool canSinkInstructions(
 
   const Instruction *I0 = Insts.front();
   for (auto *I : Insts)
-    if (!I->isSameOperationAs(I0))
+    if (!I->isSameOperationAs(I0, Instruction::CompareIgnoringTailCall))
       return false;
 
   // All instructions in Insts are known to be the same opcode. If they have a

With this modification, the final IR for the example above is as follows (clang -S -O2 -emit-llvm -o t.ll t.c):

define dso_local void @foofn() local_unnamed_addr #0 {
entry:
  %x = alloca i32, align 4
  call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %x) #3
  %call = tail call i32 @magic() #3
  %tobool.not = icmp eq i32 %call, 0
  br i1 %tobool.not, label %if.else, label %if.end

if.else:                                          ; preds = %entry
  call void @escape(ptr noundef nonnull %x) #3
  br label %if.end

if.end:                                           ; preds = %entry, %if.else
  call void @tail() #3
  call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %x) #3
  ret void
}

What do you think? Did I miss something? Should I proceed to Phabricator with it?

Thanks,
Eduard

This generally sounds reasonable to me, with two caveats:

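  1. You need to explicitly drop the tail marker from the sunk call if the markers mismatch: a `tail` marker implies that the callee does not access the caller's allocas, which is not guaranteed on the path through `escape`. You probably just got lucky that the call without the tail marker was the one that got sunk.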
  2. This is not legal for musttail calls. This looks like a pre-existing bug in haveSameSpecialState(), which treats a tail call and a musttail call as the same, because it only compares isTailCall() (true for both).

Noted, thank you. Will proceed tomorrow.