ReduceOperandsToArgs.cpp 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. //===----------------------------------------------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "ReduceOperandsToArgs.h"
  9. #include "Delta.h"
  10. #include "llvm/ADT/Sequence.h"
  11. #include "llvm/IR/InstIterator.h"
  12. #include "llvm/IR/InstrTypes.h"
  13. #include "llvm/IR/Instructions.h"
  14. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  15. #include "llvm/Transforms/Utils/Cloning.h"
  16. using namespace llvm;
  17. static bool canReplaceFunction(Function *F) {
  18. return all_of(F->uses(), [](Use &Op) {
  19. if (auto *CI = dyn_cast<CallBase>(Op.getUser()))
  20. return &CI->getCalledOperandUse() == &Op;
  21. return false;
  22. });
  23. }
  24. static bool canReduceUse(Use &Op) {
  25. Value *Val = Op.get();
  26. Type *Ty = Val->getType();
  27. // Only replace operands that can be passed-by-value.
  28. if (!Ty->isFirstClassType())
  29. return false;
  30. // Don't pass labels/metadata as arguments.
  31. if (Ty->isLabelTy() || Ty->isMetadataTy())
  32. return false;
  33. // No need to replace values that are already arguments.
  34. if (isa<Argument>(Val))
  35. return false;
  36. // Do not replace literals.
  37. if (isa<ConstantData>(Val))
  38. return false;
  39. // Do not convert direct function calls to indirect calls.
  40. if (auto *CI = dyn_cast<CallBase>(Op.getUser()))
  41. if (&CI->getCalledOperandUse() == &Op)
  42. return false;
  43. return true;
  44. }
  45. /// Goes over OldF calls and replaces them with a call to NewF.
  46. static void replaceFunctionCalls(Function *OldF, Function *NewF) {
  47. SmallVector<CallBase *> Callers;
  48. for (Use &U : OldF->uses()) {
  49. auto *CI = cast<CallBase>(U.getUser());
  50. assert(&U == &CI->getCalledOperandUse());
  51. assert(CI->getCalledFunction() == OldF);
  52. Callers.push_back(CI);
  53. }
  54. // Call arguments for NewF.
  55. SmallVector<Value *> Args(NewF->arg_size(), nullptr);
  56. // Fill up the additional parameters with undef values.
  57. for (auto ArgIdx : llvm::seq<size_t>(OldF->arg_size(), NewF->arg_size())) {
  58. Type *NewArgTy = NewF->getArg(ArgIdx)->getType();
  59. Args[ArgIdx] = UndefValue::get(NewArgTy);
  60. }
  61. for (CallBase *CI : Callers) {
  62. // Preserve the original function arguments.
  63. for (auto Z : zip_first(CI->args(), Args))
  64. std::get<1>(Z) = std::get<0>(Z);
  65. // Also preserve operand bundles.
  66. SmallVector<OperandBundleDef> OperandBundles;
  67. CI->getOperandBundlesAsDefs(OperandBundles);
  68. // Create the new function call.
  69. CallBase *NewCI;
  70. if (auto *II = dyn_cast<InvokeInst>(CI)) {
  71. NewCI = InvokeInst::Create(NewF, cast<InvokeInst>(II)->getNormalDest(),
  72. cast<InvokeInst>(II)->getUnwindDest(), Args,
  73. OperandBundles, CI->getName());
  74. } else {
  75. assert(isa<CallInst>(CI));
  76. NewCI = CallInst::Create(NewF, Args, OperandBundles, CI->getName());
  77. }
  78. NewCI->setCallingConv(NewF->getCallingConv());
  79. // Do the replacement for this use.
  80. if (!CI->use_empty())
  81. CI->replaceAllUsesWith(NewCI);
  82. ReplaceInstWithInst(CI, NewCI);
  83. }
  84. }
  85. /// Add a new function argument to @p F for each use in @OpsToReplace, and
  86. /// replace those operand values with the new function argument.
  87. static void substituteOperandWithArgument(Function *OldF,
  88. ArrayRef<Use *> OpsToReplace) {
  89. if (OpsToReplace.empty())
  90. return;
  91. SetVector<Value *> UniqueValues;
  92. for (Use *Op : OpsToReplace)
  93. UniqueValues.insert(Op->get());
  94. // Determine the new function's signature.
  95. SmallVector<Type *> NewArgTypes;
  96. llvm::append_range(NewArgTypes, OldF->getFunctionType()->params());
  97. size_t ArgOffset = NewArgTypes.size();
  98. for (Value *V : UniqueValues)
  99. NewArgTypes.push_back(V->getType());
  100. FunctionType *FTy =
  101. FunctionType::get(OldF->getFunctionType()->getReturnType(), NewArgTypes,
  102. OldF->getFunctionType()->isVarArg());
  103. // Create the new function...
  104. Function *NewF =
  105. Function::Create(FTy, OldF->getLinkage(), OldF->getAddressSpace(),
  106. OldF->getName(), OldF->getParent());
  107. // In order to preserve function order, we move NewF behind OldF
  108. NewF->removeFromParent();
  109. OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF);
  110. // Preserve the parameters of OldF.
  111. ValueToValueMapTy VMap;
  112. for (auto Z : zip_first(OldF->args(), NewF->args())) {
  113. Argument &OldArg = std::get<0>(Z);
  114. Argument &NewArg = std::get<1>(Z);
  115. NewArg.setName(OldArg.getName()); // Copy the name over...
  116. VMap[&OldArg] = &NewArg; // Add mapping to VMap
  117. }
  118. // Adjust the new parameters.
  119. ValueToValueMapTy OldValMap;
  120. for (auto Z : zip_first(UniqueValues, drop_begin(NewF->args(), ArgOffset))) {
  121. Value *OldVal = std::get<0>(Z);
  122. Argument &NewArg = std::get<1>(Z);
  123. NewArg.setName(OldVal->getName());
  124. OldValMap[OldVal] = &NewArg;
  125. }
  126. SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned.
  127. CloneFunctionInto(NewF, OldF, VMap, CloneFunctionChangeType::LocalChangesOnly,
  128. Returns, "", /*CodeInfo=*/nullptr);
  129. // Replace the actual operands.
  130. for (Use *Op : OpsToReplace) {
  131. Value *NewArg = OldValMap.lookup(Op->get());
  132. auto *NewUser = cast<Instruction>(VMap.lookup(Op->getUser()));
  133. NewUser->setOperand(Op->getOperandNo(), NewArg);
  134. }
  135. // Replace all OldF uses with NewF.
  136. replaceFunctionCalls(OldF, NewF);
  137. // Rename NewF to OldF's name.
  138. std::string FName = OldF->getName().str();
  139. OldF->replaceAllUsesWith(ConstantExpr::getBitCast(NewF, OldF->getType()));
  140. OldF->eraseFromParent();
  141. NewF->setName(FName);
  142. }
  143. static void reduceOperandsToArgs(Oracle &O, Module &Program) {
  144. SmallVector<Use *> OperandsToReduce;
  145. for (Function &F : make_early_inc_range(Program.functions())) {
  146. if (!canReplaceFunction(&F))
  147. continue;
  148. OperandsToReduce.clear();
  149. for (Instruction &I : instructions(&F)) {
  150. for (Use &Op : I.operands()) {
  151. if (!canReduceUse(Op))
  152. continue;
  153. if (O.shouldKeep())
  154. continue;
  155. OperandsToReduce.push_back(&Op);
  156. }
  157. }
  158. substituteOperandWithArgument(&F, OperandsToReduce);
  159. }
  160. }
  161. void llvm::reduceOperandsToArgsDeltaPass(TestRunner &Test) {
  162. outs() << "*** Converting operands to function arguments ...\n";
  163. return runDeltaPass(Test, reduceOperandsToArgs);
  164. }