ReduceOperands.cpp 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. //===----------------------------------------------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "ReduceOperands.h"
  9. #include "llvm/IR/Constants.h"
  10. #include "llvm/IR/InstIterator.h"
  11. #include "llvm/IR/InstrTypes.h"
  12. #include "llvm/IR/Operator.h"
  13. #include "llvm/IR/PatternMatch.h"
  14. #include "llvm/IR/Type.h"
  15. using namespace llvm;
  16. using namespace PatternMatch;
  17. static void
  18. extractOperandsFromModule(Oracle &O, ReducerWorkItem &WorkItem,
  19. function_ref<Value *(Use &)> ReduceValue) {
  20. Module &Program = WorkItem.getModule();
  21. for (auto &F : Program.functions()) {
  22. for (auto &I : instructions(&F)) {
  23. if (PHINode *Phi = dyn_cast<PHINode>(&I)) {
  24. for (auto &Op : Phi->incoming_values()) {
  25. if (!O.shouldKeep()) {
  26. if (Value *Reduced = ReduceValue(Op))
  27. Phi->setIncomingValueForBlock(Phi->getIncomingBlock(Op), Reduced);
  28. }
  29. }
  30. continue;
  31. }
  32. for (auto &Op : I.operands()) {
  33. if (Value *Reduced = ReduceValue(Op)) {
  34. if (!O.shouldKeep())
  35. Op.set(Reduced);
  36. }
  37. }
  38. }
  39. }
  40. }
  41. static bool isOne(Use &Op) {
  42. auto *C = dyn_cast<Constant>(Op);
  43. return C && C->isOneValue();
  44. }
  45. static bool isZero(Use &Op) {
  46. auto *C = dyn_cast<Constant>(Op);
  47. return C && C->isNullValue();
  48. }
  49. static bool isZeroOrOneFP(Value *Op) {
  50. const APFloat *C;
  51. return match(Op, m_APFloat(C)) &&
  52. ((C->isZero() && !C->isNegative()) || C->isExactlyValue(1.0));
  53. }
  54. static bool shouldReduceOperand(Use &Op) {
  55. Type *Ty = Op->getType();
  56. if (Ty->isLabelTy() || Ty->isMetadataTy())
  57. return false;
  58. // TODO: be more precise about which GEP operands we can reduce (e.g. array
  59. // indexes)
  60. if (isa<GEPOperator>(Op.getUser()))
  61. return false;
  62. if (auto *CB = dyn_cast<CallBase>(Op.getUser())) {
  63. if (&CB->getCalledOperandUse() == &Op)
  64. return false;
  65. }
  66. return true;
  67. }
  68. static bool switchCaseExists(Use &Op, ConstantInt *CI) {
  69. SwitchInst *SI = dyn_cast<SwitchInst>(Op.getUser());
  70. if (!SI)
  71. return false;
  72. return SI->findCaseValue(CI) != SI->case_default();
  73. }
  74. void llvm::reduceOperandsOneDeltaPass(TestRunner &Test) {
  75. auto ReduceValue = [](Use &Op) -> Value * {
  76. if (!shouldReduceOperand(Op))
  77. return nullptr;
  78. Type *Ty = Op->getType();
  79. if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
  80. // Don't duplicate an existing switch case.
  81. if (switchCaseExists(Op, ConstantInt::get(IntTy, 1)))
  82. return nullptr;
  83. // Don't replace existing ones and zeroes.
  84. return (isOne(Op) || isZero(Op)) ? nullptr : ConstantInt::get(IntTy, 1);
  85. }
  86. if (Ty->isFloatingPointTy())
  87. return isZeroOrOneFP(Op) ? nullptr : ConstantFP::get(Ty, 1.0);
  88. if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
  89. if (isOne(Op) || isZero(Op) || isZeroOrOneFP(Op))
  90. return nullptr;
  91. Type *ElementType = VT->getElementType();
  92. Constant *C;
  93. if (ElementType->isFloatingPointTy()) {
  94. C = ConstantFP::get(ElementType, 1.0);
  95. } else if (IntegerType *IntTy = dyn_cast<IntegerType>(ElementType)) {
  96. C = ConstantInt::get(IntTy, 1);
  97. } else {
  98. return nullptr;
  99. }
  100. return ConstantVector::getSplat(VT->getElementCount(), C);
  101. }
  102. return nullptr;
  103. };
  104. runDeltaPass(
  105. Test,
  106. [ReduceValue](Oracle &O, ReducerWorkItem &WorkItem) {
  107. extractOperandsFromModule(O, WorkItem, ReduceValue);
  108. },
  109. "Reducing Operands to one");
  110. }
  111. void llvm::reduceOperandsZeroDeltaPass(TestRunner &Test) {
  112. auto ReduceValue = [](Use &Op) -> Value * {
  113. if (!shouldReduceOperand(Op))
  114. return nullptr;
  115. // Don't duplicate an existing switch case.
  116. if (auto *IntTy = dyn_cast<IntegerType>(Op->getType()))
  117. if (switchCaseExists(Op, ConstantInt::get(IntTy, 0)))
  118. return nullptr;
  119. // Don't replace existing zeroes.
  120. return isZero(Op) ? nullptr : Constant::getNullValue(Op->getType());
  121. };
  122. runDeltaPass(
  123. Test,
  124. [ReduceValue](Oracle &O, ReducerWorkItem &Program) {
  125. extractOperandsFromModule(O, Program, ReduceValue);
  126. },
  127. "Reducing Operands to zero");
  128. }
  129. void llvm::reduceOperandsNaNDeltaPass(TestRunner &Test) {
  130. auto ReduceValue = [](Use &Op) -> Value * {
  131. Type *Ty = Op->getType();
  132. if (!Ty->isFPOrFPVectorTy())
  133. return nullptr;
  134. // Prefer 0.0 or 1.0 over NaN.
  135. //
  136. // TODO: Preferring NaN may make more sense because FP operations are more
  137. // universally foldable.
  138. if (match(Op.get(), m_NaN()) || isZeroOrOneFP(Op.get()))
  139. return nullptr;
  140. if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
  141. return ConstantVector::getSplat(VT->getElementCount(),
  142. ConstantFP::getQNaN(VT->getElementType()));
  143. }
  144. return ConstantFP::getQNaN(Ty);
  145. };
  146. runDeltaPass(
  147. Test,
  148. [ReduceValue](Oracle &O, ReducerWorkItem &Program) {
  149. extractOperandsFromModule(O, Program, ReduceValue);
  150. },
  151. "Reducing Operands to NaN");
  152. }