ReduceOperandsToArgs.cpp 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. //===----------------------------------------------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "ReduceOperandsToArgs.h"
  9. #include "Delta.h"
  10. #include "Utils.h"
  11. #include "llvm/ADT/Sequence.h"
  12. #include "llvm/IR/Constants.h"
  13. #include "llvm/IR/InstIterator.h"
  14. #include "llvm/IR/InstrTypes.h"
  15. #include "llvm/IR/Instructions.h"
  16. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  17. #include "llvm/Transforms/Utils/Cloning.h"
  18. using namespace llvm;
  19. static bool canReplaceFunction(Function *F) {
  20. return all_of(F->uses(), [](Use &Op) {
  21. if (auto *CI = dyn_cast<CallBase>(Op.getUser()))
  22. return &CI->getCalledOperandUse() == &Op;
  23. return false;
  24. });
  25. }
  26. static bool canReduceUse(Use &Op) {
  27. Value *Val = Op.get();
  28. Type *Ty = Val->getType();
  29. // Only replace operands that can be passed-by-value.
  30. if (!Ty->isFirstClassType())
  31. return false;
  32. // Don't pass labels/metadata as arguments.
  33. if (Ty->isLabelTy() || Ty->isMetadataTy())
  34. return false;
  35. // No need to replace values that are already arguments.
  36. if (isa<Argument>(Val))
  37. return false;
  38. // Do not replace literals.
  39. if (isa<ConstantData>(Val))
  40. return false;
  41. // Do not convert direct function calls to indirect calls.
  42. if (auto *CI = dyn_cast<CallBase>(Op.getUser()))
  43. if (&CI->getCalledOperandUse() == &Op)
  44. return false;
  45. return true;
  46. }
  47. /// Goes over OldF calls and replaces them with a call to NewF.
  48. static void replaceFunctionCalls(Function *OldF, Function *NewF) {
  49. SmallVector<CallBase *> Callers;
  50. for (Use &U : OldF->uses()) {
  51. auto *CI = cast<CallBase>(U.getUser());
  52. assert(&U == &CI->getCalledOperandUse());
  53. assert(CI->getCalledFunction() == OldF);
  54. Callers.push_back(CI);
  55. }
  56. // Call arguments for NewF.
  57. SmallVector<Value *> Args(NewF->arg_size(), nullptr);
  58. // Fill up the additional parameters with default values.
  59. for (auto ArgIdx : llvm::seq<size_t>(OldF->arg_size(), NewF->arg_size())) {
  60. Type *NewArgTy = NewF->getArg(ArgIdx)->getType();
  61. Args[ArgIdx] = getDefaultValue(NewArgTy);
  62. }
  63. for (CallBase *CI : Callers) {
  64. // Preserve the original function arguments.
  65. for (auto Z : zip_first(CI->args(), Args))
  66. std::get<1>(Z) = std::get<0>(Z);
  67. // Also preserve operand bundles.
  68. SmallVector<OperandBundleDef> OperandBundles;
  69. CI->getOperandBundlesAsDefs(OperandBundles);
  70. // Create the new function call.
  71. CallBase *NewCI;
  72. if (auto *II = dyn_cast<InvokeInst>(CI)) {
  73. NewCI = InvokeInst::Create(NewF, cast<InvokeInst>(II)->getNormalDest(),
  74. cast<InvokeInst>(II)->getUnwindDest(), Args,
  75. OperandBundles, CI->getName());
  76. } else {
  77. assert(isa<CallInst>(CI));
  78. NewCI = CallInst::Create(NewF, Args, OperandBundles, CI->getName());
  79. }
  80. NewCI->setCallingConv(NewF->getCallingConv());
  81. // Do the replacement for this use.
  82. if (!CI->use_empty())
  83. CI->replaceAllUsesWith(NewCI);
  84. ReplaceInstWithInst(CI, NewCI);
  85. }
  86. }
  87. /// Add a new function argument to @p F for each use in @OpsToReplace, and
  88. /// replace those operand values with the new function argument.
  89. static void substituteOperandWithArgument(Function *OldF,
  90. ArrayRef<Use *> OpsToReplace) {
  91. if (OpsToReplace.empty())
  92. return;
  93. SetVector<Value *> UniqueValues;
  94. for (Use *Op : OpsToReplace)
  95. UniqueValues.insert(Op->get());
  96. // Determine the new function's signature.
  97. SmallVector<Type *> NewArgTypes;
  98. llvm::append_range(NewArgTypes, OldF->getFunctionType()->params());
  99. size_t ArgOffset = NewArgTypes.size();
  100. for (Value *V : UniqueValues)
  101. NewArgTypes.push_back(V->getType());
  102. FunctionType *FTy =
  103. FunctionType::get(OldF->getFunctionType()->getReturnType(), NewArgTypes,
  104. OldF->getFunctionType()->isVarArg());
  105. // Create the new function...
  106. Function *NewF =
  107. Function::Create(FTy, OldF->getLinkage(), OldF->getAddressSpace(),
  108. OldF->getName(), OldF->getParent());
  109. // In order to preserve function order, we move NewF behind OldF
  110. NewF->removeFromParent();
  111. OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF);
  112. // Preserve the parameters of OldF.
  113. ValueToValueMapTy VMap;
  114. for (auto Z : zip_first(OldF->args(), NewF->args())) {
  115. Argument &OldArg = std::get<0>(Z);
  116. Argument &NewArg = std::get<1>(Z);
  117. NewArg.setName(OldArg.getName()); // Copy the name over...
  118. VMap[&OldArg] = &NewArg; // Add mapping to VMap
  119. }
  120. // Adjust the new parameters.
  121. ValueToValueMapTy OldValMap;
  122. for (auto Z : zip_first(UniqueValues, drop_begin(NewF->args(), ArgOffset))) {
  123. Value *OldVal = std::get<0>(Z);
  124. Argument &NewArg = std::get<1>(Z);
  125. NewArg.setName(OldVal->getName());
  126. OldValMap[OldVal] = &NewArg;
  127. }
  128. SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned.
  129. CloneFunctionInto(NewF, OldF, VMap, CloneFunctionChangeType::LocalChangesOnly,
  130. Returns, "", /*CodeInfo=*/nullptr);
  131. // Replace the actual operands.
  132. for (Use *Op : OpsToReplace) {
  133. Value *NewArg = OldValMap.lookup(Op->get());
  134. auto *NewUser = cast<Instruction>(VMap.lookup(Op->getUser()));
  135. NewUser->setOperand(Op->getOperandNo(), NewArg);
  136. }
  137. // Replace all OldF uses with NewF.
  138. replaceFunctionCalls(OldF, NewF);
  139. // Rename NewF to OldF's name.
  140. std::string FName = OldF->getName().str();
  141. OldF->replaceAllUsesWith(ConstantExpr::getBitCast(NewF, OldF->getType()));
  142. OldF->eraseFromParent();
  143. NewF->setName(FName);
  144. }
  145. static void reduceOperandsToArgs(Oracle &O, ReducerWorkItem &WorkItem) {
  146. Module &Program = WorkItem.getModule();
  147. SmallVector<Use *> OperandsToReduce;
  148. for (Function &F : make_early_inc_range(Program.functions())) {
  149. if (!canReplaceFunction(&F))
  150. continue;
  151. OperandsToReduce.clear();
  152. for (Instruction &I : instructions(&F)) {
  153. for (Use &Op : I.operands()) {
  154. if (!canReduceUse(Op))
  155. continue;
  156. if (O.shouldKeep())
  157. continue;
  158. OperandsToReduce.push_back(&Op);
  159. }
  160. }
  161. substituteOperandWithArgument(&F, OperandsToReduce);
  162. }
  163. }
  164. void llvm::reduceOperandsToArgsDeltaPass(TestRunner &Test) {
  165. runDeltaPass(Test, reduceOperandsToArgs,
  166. "Converting operands to function arguments");
  167. }