Float2Int.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. //===- Float2Int.cpp - Demote floating point ops to work on integers ------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements the Float2Int pass, which aims to demote floating
  10. // point operations to work on integers, where that is losslessly possible.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/InitializePasses.h"
  14. #include "llvm/Support/CommandLine.h"
  15. #include "llvm/Transforms/Scalar/Float2Int.h"
  16. #include "llvm/ADT/APInt.h"
  17. #include "llvm/ADT/APSInt.h"
  18. #include "llvm/ADT/SmallVector.h"
  19. #include "llvm/Analysis/GlobalsModRef.h"
  20. #include "llvm/IR/Constants.h"
  21. #include "llvm/IR/IRBuilder.h"
  22. #include "llvm/IR/InstIterator.h"
  23. #include "llvm/IR/Instructions.h"
  24. #include "llvm/IR/Module.h"
  25. #include "llvm/Pass.h"
  26. #include "llvm/Support/Debug.h"
  27. #include "llvm/Support/raw_ostream.h"
  28. #include "llvm/Transforms/Scalar.h"
  29. #include <deque>
  30. #include <functional> // For std::function
  31. #define DEBUG_TYPE "float2int"
  32. using namespace llvm;
// The algorithm is simple. Start at instructions that convert from the
// float to the int domain: fptoui, fptosi and fcmp. Walk up the def-use
// graph, using an equivalence data structure to unify graphs that interfere.
//
// Mappable instructions are those with an integer counterpart that, given
// integer domain inputs, produces an integer output; fadd, for example.
//
// If a non-mappable instruction is seen, this entire def-use graph is marked
// as non-transformable. If we see an instruction that converts from the
// integer domain to the FP domain (uitofp, sitofp), we terminate our walk.
  43. /// The largest integer type worth dealing with.
  44. static cl::opt<unsigned>
  45. MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,
  46. cl::desc("Max integer bitwidth to consider in float2int"
  47. "(default=64)"));
  48. namespace {
  49. struct Float2IntLegacyPass : public FunctionPass {
  50. static char ID; // Pass identification, replacement for typeid
  51. Float2IntLegacyPass() : FunctionPass(ID) {
  52. initializeFloat2IntLegacyPassPass(*PassRegistry::getPassRegistry());
  53. }
  54. bool runOnFunction(Function &F) override {
  55. if (skipFunction(F))
  56. return false;
  57. const DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  58. return Impl.runImpl(F, DT);
  59. }
  60. void getAnalysisUsage(AnalysisUsage &AU) const override {
  61. AU.setPreservesCFG();
  62. AU.addRequired<DominatorTreeWrapperPass>();
  63. AU.addPreserved<GlobalsAAWrapperPass>();
  64. }
  65. private:
  66. Float2IntPass Impl;
  67. };
  68. }
// Storage for the legacy pass ID; FunctionPass(ID) in the constructor above
// passes it to the base class.
char Float2IntLegacyPass::ID = 0;
// Register the legacy pass under the command-line name "float2int".
INITIALIZE_PASS(Float2IntLegacyPass, "float2int", "Float to int", false, false)
  71. // Given a FCmp predicate, return a matching ICmp predicate if one
  72. // exists, otherwise return BAD_ICMP_PREDICATE.
  73. static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) {
  74. switch (P) {
  75. case CmpInst::FCMP_OEQ:
  76. case CmpInst::FCMP_UEQ:
  77. return CmpInst::ICMP_EQ;
  78. case CmpInst::FCMP_OGT:
  79. case CmpInst::FCMP_UGT:
  80. return CmpInst::ICMP_SGT;
  81. case CmpInst::FCMP_OGE:
  82. case CmpInst::FCMP_UGE:
  83. return CmpInst::ICMP_SGE;
  84. case CmpInst::FCMP_OLT:
  85. case CmpInst::FCMP_ULT:
  86. return CmpInst::ICMP_SLT;
  87. case CmpInst::FCMP_OLE:
  88. case CmpInst::FCMP_ULE:
  89. return CmpInst::ICMP_SLE;
  90. case CmpInst::FCMP_ONE:
  91. case CmpInst::FCMP_UNE:
  92. return CmpInst::ICMP_NE;
  93. default:
  94. return CmpInst::BAD_ICMP_PREDICATE;
  95. }
  96. }
  97. // Given a floating point binary operator, return the matching
  98. // integer version.
  99. static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
  100. switch (Opcode) {
  101. default: llvm_unreachable("Unhandled opcode!");
  102. case Instruction::FAdd: return Instruction::Add;
  103. case Instruction::FSub: return Instruction::Sub;
  104. case Instruction::FMul: return Instruction::Mul;
  105. }
  106. }
  107. // Find the roots - instructions that convert from the FP domain to
  108. // integer domain.
  109. void Float2IntPass::findRoots(Function &F, const DominatorTree &DT) {
  110. for (BasicBlock &BB : F) {
  111. // Unreachable code can take on strange forms that we are not prepared to
  112. // handle. For example, an instruction may have itself as an operand.
  113. if (!DT.isReachableFromEntry(&BB))
  114. continue;
  115. for (Instruction &I : BB) {
  116. if (isa<VectorType>(I.getType()))
  117. continue;
  118. switch (I.getOpcode()) {
  119. default: break;
  120. case Instruction::FPToUI:
  121. case Instruction::FPToSI:
  122. Roots.insert(&I);
  123. break;
  124. case Instruction::FCmp:
  125. if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) !=
  126. CmpInst::BAD_ICMP_PREDICATE)
  127. Roots.insert(&I);
  128. break;
  129. }
  130. }
  131. }
  132. }
  133. // Helper - mark I as having been traversed, having range R.
  134. void Float2IntPass::seen(Instruction *I, ConstantRange R) {
  135. LLVM_DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n");
  136. auto IT = SeenInsts.find(I);
  137. if (IT != SeenInsts.end())
  138. IT->second = std::move(R);
  139. else
  140. SeenInsts.insert(std::make_pair(I, std::move(R)));
  141. }
  142. // Helper - get a range representing a poison value.
  143. ConstantRange Float2IntPass::badRange() {
  144. return ConstantRange::getFull(MaxIntegerBW + 1);
  145. }
  146. ConstantRange Float2IntPass::unknownRange() {
  147. return ConstantRange::getEmpty(MaxIntegerBW + 1);
  148. }
  149. ConstantRange Float2IntPass::validateRange(ConstantRange R) {
  150. if (R.getBitWidth() > MaxIntegerBW + 1)
  151. return badRange();
  152. return R;
  153. }
// The most obvious way to structure the search is a depth-first, eager
// search from each root. However, that requires direct recursion and so
// can only handle small instruction sequences. Instead, we split the search
// up into two phases:
//   - walkBackwards:  A worklist-driven walk of the use-def graph starting
//                     from the roots. Populate "SeenInsts" with interesting
//                     instructions and poison values if they're obvious and
//                     cheap to compute. Calculate the equivalence set
//                     structure while we're here too.
//   - walkForwards:   Iterate over SeenInsts in reverse order, so we visit
//                     defs before their uses. Calculate the real range info.
  165. // Breadth-first walk of the use-def graph; determine the set of nodes
  166. // we care about and eagerly determine if some of them are poisonous.
  167. void Float2IntPass::walkBackwards() {
  168. std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());
  169. while (!Worklist.empty()) {
  170. Instruction *I = Worklist.back();
  171. Worklist.pop_back();
  172. if (SeenInsts.find(I) != SeenInsts.end())
  173. // Seen already.
  174. continue;
  175. switch (I->getOpcode()) {
  176. // FIXME: Handle select and phi nodes.
  177. default:
  178. // Path terminated uncleanly.
  179. seen(I, badRange());
  180. break;
  181. case Instruction::UIToFP:
  182. case Instruction::SIToFP: {
  183. // Path terminated cleanly - use the type of the integer input to seed
  184. // the analysis.
  185. unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
  186. auto Input = ConstantRange::getFull(BW);
  187. auto CastOp = (Instruction::CastOps)I->getOpcode();
  188. seen(I, validateRange(Input.castOp(CastOp, MaxIntegerBW+1)));
  189. continue;
  190. }
  191. case Instruction::FNeg:
  192. case Instruction::FAdd:
  193. case Instruction::FSub:
  194. case Instruction::FMul:
  195. case Instruction::FPToUI:
  196. case Instruction::FPToSI:
  197. case Instruction::FCmp:
  198. seen(I, unknownRange());
  199. break;
  200. }
  201. for (Value *O : I->operands()) {
  202. if (Instruction *OI = dyn_cast<Instruction>(O)) {
  203. // Unify def-use chains if they interfere.
  204. ECs.unionSets(I, OI);
  205. if (SeenInsts.find(I)->second != badRange())
  206. Worklist.push_back(OI);
  207. } else if (!isa<ConstantFP>(O)) {
  208. // Not an instruction or ConstantFP? we can't do anything.
  209. seen(I, badRange());
  210. }
  211. }
  212. }
  213. }
  214. // Walk forwards down the list of seen instructions, so we visit defs before
  215. // uses.
  216. void Float2IntPass::walkForwards() {
  217. for (auto &It : reverse(SeenInsts)) {
  218. if (It.second != unknownRange())
  219. continue;
  220. Instruction *I = It.first;
  221. std::function<ConstantRange(ArrayRef<ConstantRange>)> Op;
  222. switch (I->getOpcode()) {
  223. // FIXME: Handle select and phi nodes.
  224. default:
  225. case Instruction::UIToFP:
  226. case Instruction::SIToFP:
  227. llvm_unreachable("Should have been handled in walkForwards!");
  228. case Instruction::FNeg:
  229. Op = [](ArrayRef<ConstantRange> Ops) {
  230. assert(Ops.size() == 1 && "FNeg is a unary operator!");
  231. unsigned Size = Ops[0].getBitWidth();
  232. auto Zero = ConstantRange(APInt::getZero(Size));
  233. return Zero.sub(Ops[0]);
  234. };
  235. break;
  236. case Instruction::FAdd:
  237. case Instruction::FSub:
  238. case Instruction::FMul:
  239. Op = [I](ArrayRef<ConstantRange> Ops) {
  240. assert(Ops.size() == 2 && "its a binary operator!");
  241. auto BinOp = (Instruction::BinaryOps) I->getOpcode();
  242. return Ops[0].binaryOp(BinOp, Ops[1]);
  243. };
  244. break;
  245. //
  246. // Root-only instructions - we'll only see these if they're the
  247. // first node in a walk.
  248. //
  249. case Instruction::FPToUI:
  250. case Instruction::FPToSI:
  251. Op = [I](ArrayRef<ConstantRange> Ops) {
  252. assert(Ops.size() == 1 && "FPTo[US]I is a unary operator!");
  253. // Note: We're ignoring the casts output size here as that's what the
  254. // caller expects.
  255. auto CastOp = (Instruction::CastOps)I->getOpcode();
  256. return Ops[0].castOp(CastOp, MaxIntegerBW+1);
  257. };
  258. break;
  259. case Instruction::FCmp:
  260. Op = [](ArrayRef<ConstantRange> Ops) {
  261. assert(Ops.size() == 2 && "FCmp is a binary operator!");
  262. return Ops[0].unionWith(Ops[1]);
  263. };
  264. break;
  265. }
  266. bool Abort = false;
  267. SmallVector<ConstantRange,4> OpRanges;
  268. for (Value *O : I->operands()) {
  269. if (Instruction *OI = dyn_cast<Instruction>(O)) {
  270. assert(SeenInsts.find(OI) != SeenInsts.end() &&
  271. "def not seen before use!");
  272. OpRanges.push_back(SeenInsts.find(OI)->second);
  273. } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) {
  274. // Work out if the floating point number can be losslessly represented
  275. // as an integer.
  276. // APFloat::convertToInteger(&Exact) purports to do what we want, but
  277. // the exactness can be too precise. For example, negative zero can
  278. // never be exactly converted to an integer.
  279. //
  280. // Instead, we ask APFloat to round itself to an integral value - this
  281. // preserves sign-of-zero - then compare the result with the original.
  282. //
  283. const APFloat &F = CF->getValueAPF();
  284. // First, weed out obviously incorrect values. Non-finite numbers
  285. // can't be represented and neither can negative zero, unless
  286. // we're in fast math mode.
  287. if (!F.isFinite() ||
  288. (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) &&
  289. !I->hasNoSignedZeros())) {
  290. seen(I, badRange());
  291. Abort = true;
  292. break;
  293. }
  294. APFloat NewF = F;
  295. auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven);
  296. if (Res != APFloat::opOK || NewF != F) {
  297. seen(I, badRange());
  298. Abort = true;
  299. break;
  300. }
  301. // OK, it's representable. Now get it.
  302. APSInt Int(MaxIntegerBW+1, false);
  303. bool Exact;
  304. CF->getValueAPF().convertToInteger(Int,
  305. APFloat::rmNearestTiesToEven,
  306. &Exact);
  307. OpRanges.push_back(ConstantRange(Int));
  308. } else {
  309. llvm_unreachable("Should have already marked this as badRange!");
  310. }
  311. }
  312. // Reduce the operands' ranges to a single range and return.
  313. if (!Abort)
  314. seen(I, Op(OpRanges));
  315. }
  316. }
  317. // If there is a valid transform to be done, do it.
  318. bool Float2IntPass::validateAndTransform() {
  319. bool MadeChange = false;
  320. // Iterate over every disjoint partition of the def-use graph.
  321. for (auto It = ECs.begin(), E = ECs.end(); It != E; ++It) {
  322. ConstantRange R(MaxIntegerBW + 1, false);
  323. bool Fail = false;
  324. Type *ConvertedToTy = nullptr;
  325. // For every member of the partition, union all the ranges together.
  326. for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
  327. MI != ME; ++MI) {
  328. Instruction *I = *MI;
  329. auto SeenI = SeenInsts.find(I);
  330. if (SeenI == SeenInsts.end())
  331. continue;
  332. R = R.unionWith(SeenI->second);
  333. // We need to ensure I has no users that have not been seen.
  334. // If it does, transformation would be illegal.
  335. //
  336. // Don't count the roots, as they terminate the graphs.
  337. if (!Roots.contains(I)) {
  338. // Set the type of the conversion while we're here.
  339. if (!ConvertedToTy)
  340. ConvertedToTy = I->getType();
  341. for (User *U : I->users()) {
  342. Instruction *UI = dyn_cast<Instruction>(U);
  343. if (!UI || SeenInsts.find(UI) == SeenInsts.end()) {
  344. LLVM_DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n");
  345. Fail = true;
  346. break;
  347. }
  348. }
  349. }
  350. if (Fail)
  351. break;
  352. }
  353. // If the set was empty, or we failed, or the range is poisonous,
  354. // bail out.
  355. if (ECs.member_begin(It) == ECs.member_end() || Fail ||
  356. R.isFullSet() || R.isSignWrappedSet())
  357. continue;
  358. assert(ConvertedToTy && "Must have set the convertedtoty by this point!");
  359. // The number of bits required is the maximum of the upper and
  360. // lower limits, plus one so it can be signed.
  361. unsigned MinBW = std::max(R.getLower().getMinSignedBits(),
  362. R.getUpper().getMinSignedBits()) + 1;
  363. LLVM_DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n");
  364. // If we've run off the realms of the exactly representable integers,
  365. // the floating point result will differ from an integer approximation.
  366. // Do we need more bits than are in the mantissa of the type we converted
  367. // to? semanticsPrecision returns the number of mantissa bits plus one
  368. // for the sign bit.
  369. unsigned MaxRepresentableBits
  370. = APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1;
  371. if (MinBW > MaxRepresentableBits) {
  372. LLVM_DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n");
  373. continue;
  374. }
  375. if (MinBW > 64) {
  376. LLVM_DEBUG(
  377. dbgs() << "F2I: Value requires more than 64 bits to represent!\n");
  378. continue;
  379. }
  380. // OK, R is known to be representable. Now pick a type for it.
  381. // FIXME: Pick the smallest legal type that will fit.
  382. Type *Ty = (MinBW > 32) ? Type::getInt64Ty(*Ctx) : Type::getInt32Ty(*Ctx);
  383. for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
  384. MI != ME; ++MI)
  385. convert(*MI, Ty);
  386. MadeChange = true;
  387. }
  388. return MadeChange;
  389. }
  390. Value *Float2IntPass::convert(Instruction *I, Type *ToTy) {
  391. if (ConvertedInsts.find(I) != ConvertedInsts.end())
  392. // Already converted this instruction.
  393. return ConvertedInsts[I];
  394. SmallVector<Value*,4> NewOperands;
  395. for (Value *V : I->operands()) {
  396. // Don't recurse if we're an instruction that terminates the path.
  397. if (I->getOpcode() == Instruction::UIToFP ||
  398. I->getOpcode() == Instruction::SIToFP) {
  399. NewOperands.push_back(V);
  400. } else if (Instruction *VI = dyn_cast<Instruction>(V)) {
  401. NewOperands.push_back(convert(VI, ToTy));
  402. } else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) {
  403. APSInt Val(ToTy->getPrimitiveSizeInBits(), /*isUnsigned=*/false);
  404. bool Exact;
  405. CF->getValueAPF().convertToInteger(Val,
  406. APFloat::rmNearestTiesToEven,
  407. &Exact);
  408. NewOperands.push_back(ConstantInt::get(ToTy, Val));
  409. } else {
  410. llvm_unreachable("Unhandled operand type?");
  411. }
  412. }
  413. // Now create a new instruction.
  414. IRBuilder<> IRB(I);
  415. Value *NewV = nullptr;
  416. switch (I->getOpcode()) {
  417. default: llvm_unreachable("Unhandled instruction!");
  418. case Instruction::FPToUI:
  419. NewV = IRB.CreateZExtOrTrunc(NewOperands[0], I->getType());
  420. break;
  421. case Instruction::FPToSI:
  422. NewV = IRB.CreateSExtOrTrunc(NewOperands[0], I->getType());
  423. break;
  424. case Instruction::FCmp: {
  425. CmpInst::Predicate P = mapFCmpPred(cast<CmpInst>(I)->getPredicate());
  426. assert(P != CmpInst::BAD_ICMP_PREDICATE && "Unhandled predicate!");
  427. NewV = IRB.CreateICmp(P, NewOperands[0], NewOperands[1], I->getName());
  428. break;
  429. }
  430. case Instruction::UIToFP:
  431. NewV = IRB.CreateZExtOrTrunc(NewOperands[0], ToTy);
  432. break;
  433. case Instruction::SIToFP:
  434. NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy);
  435. break;
  436. case Instruction::FNeg:
  437. NewV = IRB.CreateNeg(NewOperands[0], I->getName());
  438. break;
  439. case Instruction::FAdd:
  440. case Instruction::FSub:
  441. case Instruction::FMul:
  442. NewV = IRB.CreateBinOp(mapBinOpcode(I->getOpcode()),
  443. NewOperands[0], NewOperands[1],
  444. I->getName());
  445. break;
  446. }
  447. // If we're a root instruction, RAUW.
  448. if (Roots.count(I))
  449. I->replaceAllUsesWith(NewV);
  450. ConvertedInsts[I] = NewV;
  451. return NewV;
  452. }
  453. // Perform dead code elimination on the instructions we just modified.
  454. void Float2IntPass::cleanup() {
  455. for (auto &I : reverse(ConvertedInsts))
  456. I.first->eraseFromParent();
  457. }
  458. bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
  459. LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
  460. // Clear out all state.
  461. ECs = EquivalenceClasses<Instruction*>();
  462. SeenInsts.clear();
  463. ConvertedInsts.clear();
  464. Roots.clear();
  465. Ctx = &F.getParent()->getContext();
  466. findRoots(F, DT);
  467. walkBackwards();
  468. walkForwards();
  469. bool Modified = validateAndTransform();
  470. if (Modified)
  471. cleanup();
  472. return Modified;
  473. }
  474. namespace llvm {
  475. FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); }
  476. PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &AM) {
  477. const DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  478. if (!runImpl(F, DT))
  479. return PreservedAnalyses::all();
  480. PreservedAnalyses PA;
  481. PA.preserveSet<CFGAnalyses>();
  482. return PA;
  483. }
  484. } // End namespace llvm