Float2Int.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. //===- Float2Int.cpp - Demote floating point ops to work on integers ------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements the Float2Int pass, which aims to demote floating
  10. // point operations to work on integers, where that is losslessly possible.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/Transforms/Scalar/Float2Int.h"
  14. #include "llvm/ADT/APInt.h"
  15. #include "llvm/ADT/APSInt.h"
  16. #include "llvm/ADT/SmallVector.h"
  17. #include "llvm/Analysis/GlobalsModRef.h"
  18. #include "llvm/IR/Constants.h"
  19. #include "llvm/IR/Dominators.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include "llvm/IR/Module.h"
  22. #include "llvm/InitializePasses.h"
  23. #include "llvm/Pass.h"
  24. #include "llvm/Support/CommandLine.h"
  25. #include "llvm/Support/Debug.h"
  26. #include "llvm/Support/raw_ostream.h"
  27. #include "llvm/Transforms/Scalar.h"
  28. #include <deque>
  29. #define DEBUG_TYPE "float2int"
  30. using namespace llvm;
  // The algorithm is simple. Start at instructions that convert from the
  // float to the int domain: fptoui, fptosi and fcmp. Walk up the def-use
  // graph, using an equivalence data structure to unify graphs that interfere.
  //
  // Mappable instructions are those with an integer counterpart that, given
  // integer domain inputs, produces an integer output; fadd, for example.
  //
  // If a non-mappable instruction is seen, this entire def-use graph is marked
  // as non-transformable. If we see an instruction that converts from the
  // integer domain to the FP domain (uitofp, sitofp), we terminate our walk.
  41. /// The largest integer type worth dealing with.
  42. static cl::opt<unsigned>
  43. MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,
  44. cl::desc("Max integer bitwidth to consider in float2int"
  45. "(default=64)"));
  46. namespace {
  47. struct Float2IntLegacyPass : public FunctionPass {
  48. static char ID; // Pass identification, replacement for typeid
  49. Float2IntLegacyPass() : FunctionPass(ID) {
  50. initializeFloat2IntLegacyPassPass(*PassRegistry::getPassRegistry());
  51. }
  52. bool runOnFunction(Function &F) override {
  53. if (skipFunction(F))
  54. return false;
  55. const DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  56. return Impl.runImpl(F, DT);
  57. }
  58. void getAnalysisUsage(AnalysisUsage &AU) const override {
  59. AU.setPreservesCFG();
  60. AU.addRequired<DominatorTreeWrapperPass>();
  61. AU.addPreserved<GlobalsAAWrapperPass>();
  62. }
  63. private:
  64. Float2IntPass Impl;
  65. };
  66. }
  67. char Float2IntLegacyPass::ID = 0;
  68. INITIALIZE_PASS(Float2IntLegacyPass, "float2int", "Float to int", false, false)
  69. // Given a FCmp predicate, return a matching ICmp predicate if one
  70. // exists, otherwise return BAD_ICMP_PREDICATE.
  71. static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) {
  72. switch (P) {
  73. case CmpInst::FCMP_OEQ:
  74. case CmpInst::FCMP_UEQ:
  75. return CmpInst::ICMP_EQ;
  76. case CmpInst::FCMP_OGT:
  77. case CmpInst::FCMP_UGT:
  78. return CmpInst::ICMP_SGT;
  79. case CmpInst::FCMP_OGE:
  80. case CmpInst::FCMP_UGE:
  81. return CmpInst::ICMP_SGE;
  82. case CmpInst::FCMP_OLT:
  83. case CmpInst::FCMP_ULT:
  84. return CmpInst::ICMP_SLT;
  85. case CmpInst::FCMP_OLE:
  86. case CmpInst::FCMP_ULE:
  87. return CmpInst::ICMP_SLE;
  88. case CmpInst::FCMP_ONE:
  89. case CmpInst::FCMP_UNE:
  90. return CmpInst::ICMP_NE;
  91. default:
  92. return CmpInst::BAD_ICMP_PREDICATE;
  93. }
  94. }
  95. // Given a floating point binary operator, return the matching
  96. // integer version.
  97. static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
  98. switch (Opcode) {
  99. default: llvm_unreachable("Unhandled opcode!");
  100. case Instruction::FAdd: return Instruction::Add;
  101. case Instruction::FSub: return Instruction::Sub;
  102. case Instruction::FMul: return Instruction::Mul;
  103. }
  104. }
  105. // Find the roots - instructions that convert from the FP domain to
  106. // integer domain.
  107. void Float2IntPass::findRoots(Function &F, const DominatorTree &DT) {
  108. for (BasicBlock &BB : F) {
  109. // Unreachable code can take on strange forms that we are not prepared to
  110. // handle. For example, an instruction may have itself as an operand.
  111. if (!DT.isReachableFromEntry(&BB))
  112. continue;
  113. for (Instruction &I : BB) {
  114. if (isa<VectorType>(I.getType()))
  115. continue;
  116. switch (I.getOpcode()) {
  117. default: break;
  118. case Instruction::FPToUI:
  119. case Instruction::FPToSI:
  120. Roots.insert(&I);
  121. break;
  122. case Instruction::FCmp:
  123. if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) !=
  124. CmpInst::BAD_ICMP_PREDICATE)
  125. Roots.insert(&I);
  126. break;
  127. }
  128. }
  129. }
  130. }
  131. // Helper - mark I as having been traversed, having range R.
  132. void Float2IntPass::seen(Instruction *I, ConstantRange R) {
  133. LLVM_DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n");
  134. auto IT = SeenInsts.find(I);
  135. if (IT != SeenInsts.end())
  136. IT->second = std::move(R);
  137. else
  138. SeenInsts.insert(std::make_pair(I, std::move(R)));
  139. }
  140. // Helper - get a range representing a poison value.
  141. ConstantRange Float2IntPass::badRange() {
  142. return ConstantRange::getFull(MaxIntegerBW + 1);
  143. }
  144. ConstantRange Float2IntPass::unknownRange() {
  145. return ConstantRange::getEmpty(MaxIntegerBW + 1);
  146. }
  147. ConstantRange Float2IntPass::validateRange(ConstantRange R) {
  148. if (R.getBitWidth() > MaxIntegerBW + 1)
  149. return badRange();
  150. return R;
  151. }
  152. // The most obvious way to structure the search is a depth-first, eager
  153. // search from each root. However, that require direct recursion and so
  154. // can only handle small instruction sequences. Instead, we split the search
  155. // up into two phases:
  156. // - walkBackwards: A breadth-first walk of the use-def graph starting from
  157. // the roots. Populate "SeenInsts" with interesting
  158. // instructions and poison values if they're obvious and
  159. // cheap to compute. Calculate the equivalance set structure
  160. // while we're here too.
  161. // - walkForwards: Iterate over SeenInsts in reverse order, so we visit
  162. // defs before their uses. Calculate the real range info.
  163. // Breadth-first walk of the use-def graph; determine the set of nodes
  164. // we care about and eagerly determine if some of them are poisonous.
  165. void Float2IntPass::walkBackwards() {
  166. std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());
  167. while (!Worklist.empty()) {
  168. Instruction *I = Worklist.back();
  169. Worklist.pop_back();
  170. if (SeenInsts.find(I) != SeenInsts.end())
  171. // Seen already.
  172. continue;
  173. switch (I->getOpcode()) {
  174. // FIXME: Handle select and phi nodes.
  175. default:
  176. // Path terminated uncleanly.
  177. seen(I, badRange());
  178. break;
  179. case Instruction::UIToFP:
  180. case Instruction::SIToFP: {
  181. // Path terminated cleanly - use the type of the integer input to seed
  182. // the analysis.
  183. unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
  184. auto Input = ConstantRange::getFull(BW);
  185. auto CastOp = (Instruction::CastOps)I->getOpcode();
  186. seen(I, validateRange(Input.castOp(CastOp, MaxIntegerBW+1)));
  187. continue;
  188. }
  189. case Instruction::FNeg:
  190. case Instruction::FAdd:
  191. case Instruction::FSub:
  192. case Instruction::FMul:
  193. case Instruction::FPToUI:
  194. case Instruction::FPToSI:
  195. case Instruction::FCmp:
  196. seen(I, unknownRange());
  197. break;
  198. }
  199. for (Value *O : I->operands()) {
  200. if (Instruction *OI = dyn_cast<Instruction>(O)) {
  201. // Unify def-use chains if they interfere.
  202. ECs.unionSets(I, OI);
  203. if (SeenInsts.find(I)->second != badRange())
  204. Worklist.push_back(OI);
  205. } else if (!isa<ConstantFP>(O)) {
  206. // Not an instruction or ConstantFP? we can't do anything.
  207. seen(I, badRange());
  208. }
  209. }
  210. }
  211. }
  212. // Calculate result range from operand ranges.
  213. // Return std::nullopt if the range cannot be calculated yet.
  214. std::optional<ConstantRange> Float2IntPass::calcRange(Instruction *I) {
  215. SmallVector<ConstantRange, 4> OpRanges;
  216. for (Value *O : I->operands()) {
  217. if (Instruction *OI = dyn_cast<Instruction>(O)) {
  218. auto OpIt = SeenInsts.find(OI);
  219. assert(OpIt != SeenInsts.end() && "def not seen before use!");
  220. if (OpIt->second == unknownRange())
  221. return std::nullopt; // Wait until operand range has been calculated.
  222. OpRanges.push_back(OpIt->second);
  223. } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) {
  224. // Work out if the floating point number can be losslessly represented
  225. // as an integer.
  226. // APFloat::convertToInteger(&Exact) purports to do what we want, but
  227. // the exactness can be too precise. For example, negative zero can
  228. // never be exactly converted to an integer.
  229. //
  230. // Instead, we ask APFloat to round itself to an integral value - this
  231. // preserves sign-of-zero - then compare the result with the original.
  232. //
  233. const APFloat &F = CF->getValueAPF();
  234. // First, weed out obviously incorrect values. Non-finite numbers
  235. // can't be represented and neither can negative zero, unless
  236. // we're in fast math mode.
  237. if (!F.isFinite() ||
  238. (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) &&
  239. !I->hasNoSignedZeros()))
  240. return badRange();
  241. APFloat NewF = F;
  242. auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven);
  243. if (Res != APFloat::opOK || NewF != F)
  244. return badRange();
  245. // OK, it's representable. Now get it.
  246. APSInt Int(MaxIntegerBW+1, false);
  247. bool Exact;
  248. CF->getValueAPF().convertToInteger(Int,
  249. APFloat::rmNearestTiesToEven,
  250. &Exact);
  251. OpRanges.push_back(ConstantRange(Int));
  252. } else {
  253. llvm_unreachable("Should have already marked this as badRange!");
  254. }
  255. }
  256. switch (I->getOpcode()) {
  257. // FIXME: Handle select and phi nodes.
  258. default:
  259. case Instruction::UIToFP:
  260. case Instruction::SIToFP:
  261. llvm_unreachable("Should have been handled in walkForwards!");
  262. case Instruction::FNeg: {
  263. assert(OpRanges.size() == 1 && "FNeg is a unary operator!");
  264. unsigned Size = OpRanges[0].getBitWidth();
  265. auto Zero = ConstantRange(APInt::getZero(Size));
  266. return Zero.sub(OpRanges[0]);
  267. }
  268. case Instruction::FAdd:
  269. case Instruction::FSub:
  270. case Instruction::FMul: {
  271. assert(OpRanges.size() == 2 && "its a binary operator!");
  272. auto BinOp = (Instruction::BinaryOps) I->getOpcode();
  273. return OpRanges[0].binaryOp(BinOp, OpRanges[1]);
  274. }
  275. //
  276. // Root-only instructions - we'll only see these if they're the
  277. // first node in a walk.
  278. //
  279. case Instruction::FPToUI:
  280. case Instruction::FPToSI: {
  281. assert(OpRanges.size() == 1 && "FPTo[US]I is a unary operator!");
  282. // Note: We're ignoring the casts output size here as that's what the
  283. // caller expects.
  284. auto CastOp = (Instruction::CastOps)I->getOpcode();
  285. return OpRanges[0].castOp(CastOp, MaxIntegerBW+1);
  286. }
  287. case Instruction::FCmp:
  288. assert(OpRanges.size() == 2 && "FCmp is a binary operator!");
  289. return OpRanges[0].unionWith(OpRanges[1]);
  290. }
  291. }
  292. // Walk forwards down the list of seen instructions, so we visit defs before
  293. // uses.
  294. void Float2IntPass::walkForwards() {
  295. std::deque<Instruction *> Worklist;
  296. for (const auto &Pair : SeenInsts)
  297. if (Pair.second == unknownRange())
  298. Worklist.push_back(Pair.first);
  299. while (!Worklist.empty()) {
  300. Instruction *I = Worklist.back();
  301. Worklist.pop_back();
  302. if (std::optional<ConstantRange> Range = calcRange(I))
  303. seen(I, *Range);
  304. else
  305. Worklist.push_front(I); // Reprocess later.
  306. }
  307. }
  308. // If there is a valid transform to be done, do it.
  309. bool Float2IntPass::validateAndTransform() {
  310. bool MadeChange = false;
  311. // Iterate over every disjoint partition of the def-use graph.
  312. for (auto It = ECs.begin(), E = ECs.end(); It != E; ++It) {
  313. ConstantRange R(MaxIntegerBW + 1, false);
  314. bool Fail = false;
  315. Type *ConvertedToTy = nullptr;
  316. // For every member of the partition, union all the ranges together.
  317. for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
  318. MI != ME; ++MI) {
  319. Instruction *I = *MI;
  320. auto SeenI = SeenInsts.find(I);
  321. if (SeenI == SeenInsts.end())
  322. continue;
  323. R = R.unionWith(SeenI->second);
  324. // We need to ensure I has no users that have not been seen.
  325. // If it does, transformation would be illegal.
  326. //
  327. // Don't count the roots, as they terminate the graphs.
  328. if (!Roots.contains(I)) {
  329. // Set the type of the conversion while we're here.
  330. if (!ConvertedToTy)
  331. ConvertedToTy = I->getType();
  332. for (User *U : I->users()) {
  333. Instruction *UI = dyn_cast<Instruction>(U);
  334. if (!UI || SeenInsts.find(UI) == SeenInsts.end()) {
  335. LLVM_DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n");
  336. Fail = true;
  337. break;
  338. }
  339. }
  340. }
  341. if (Fail)
  342. break;
  343. }
  344. // If the set was empty, or we failed, or the range is poisonous,
  345. // bail out.
  346. if (ECs.member_begin(It) == ECs.member_end() || Fail ||
  347. R.isFullSet() || R.isSignWrappedSet())
  348. continue;
  349. assert(ConvertedToTy && "Must have set the convertedtoty by this point!");
  350. // The number of bits required is the maximum of the upper and
  351. // lower limits, plus one so it can be signed.
  352. unsigned MinBW = std::max(R.getLower().getMinSignedBits(),
  353. R.getUpper().getMinSignedBits()) + 1;
  354. LLVM_DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n");
  355. // If we've run off the realms of the exactly representable integers,
  356. // the floating point result will differ from an integer approximation.
  357. // Do we need more bits than are in the mantissa of the type we converted
  358. // to? semanticsPrecision returns the number of mantissa bits plus one
  359. // for the sign bit.
  360. unsigned MaxRepresentableBits
  361. = APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1;
  362. if (MinBW > MaxRepresentableBits) {
  363. LLVM_DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n");
  364. continue;
  365. }
  366. if (MinBW > 64) {
  367. LLVM_DEBUG(
  368. dbgs() << "F2I: Value requires more than 64 bits to represent!\n");
  369. continue;
  370. }
  371. // OK, R is known to be representable. Now pick a type for it.
  372. // FIXME: Pick the smallest legal type that will fit.
  373. Type *Ty = (MinBW > 32) ? Type::getInt64Ty(*Ctx) : Type::getInt32Ty(*Ctx);
  374. for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
  375. MI != ME; ++MI)
  376. convert(*MI, Ty);
  377. MadeChange = true;
  378. }
  379. return MadeChange;
  380. }
  381. Value *Float2IntPass::convert(Instruction *I, Type *ToTy) {
  382. if (ConvertedInsts.find(I) != ConvertedInsts.end())
  383. // Already converted this instruction.
  384. return ConvertedInsts[I];
  385. SmallVector<Value*,4> NewOperands;
  386. for (Value *V : I->operands()) {
  387. // Don't recurse if we're an instruction that terminates the path.
  388. if (I->getOpcode() == Instruction::UIToFP ||
  389. I->getOpcode() == Instruction::SIToFP) {
  390. NewOperands.push_back(V);
  391. } else if (Instruction *VI = dyn_cast<Instruction>(V)) {
  392. NewOperands.push_back(convert(VI, ToTy));
  393. } else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) {
  394. APSInt Val(ToTy->getPrimitiveSizeInBits(), /*isUnsigned=*/false);
  395. bool Exact;
  396. CF->getValueAPF().convertToInteger(Val,
  397. APFloat::rmNearestTiesToEven,
  398. &Exact);
  399. NewOperands.push_back(ConstantInt::get(ToTy, Val));
  400. } else {
  401. llvm_unreachable("Unhandled operand type?");
  402. }
  403. }
  404. // Now create a new instruction.
  405. IRBuilder<> IRB(I);
  406. Value *NewV = nullptr;
  407. switch (I->getOpcode()) {
  408. default: llvm_unreachable("Unhandled instruction!");
  409. case Instruction::FPToUI:
  410. NewV = IRB.CreateZExtOrTrunc(NewOperands[0], I->getType());
  411. break;
  412. case Instruction::FPToSI:
  413. NewV = IRB.CreateSExtOrTrunc(NewOperands[0], I->getType());
  414. break;
  415. case Instruction::FCmp: {
  416. CmpInst::Predicate P = mapFCmpPred(cast<CmpInst>(I)->getPredicate());
  417. assert(P != CmpInst::BAD_ICMP_PREDICATE && "Unhandled predicate!");
  418. NewV = IRB.CreateICmp(P, NewOperands[0], NewOperands[1], I->getName());
  419. break;
  420. }
  421. case Instruction::UIToFP:
  422. NewV = IRB.CreateZExtOrTrunc(NewOperands[0], ToTy);
  423. break;
  424. case Instruction::SIToFP:
  425. NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy);
  426. break;
  427. case Instruction::FNeg:
  428. NewV = IRB.CreateNeg(NewOperands[0], I->getName());
  429. break;
  430. case Instruction::FAdd:
  431. case Instruction::FSub:
  432. case Instruction::FMul:
  433. NewV = IRB.CreateBinOp(mapBinOpcode(I->getOpcode()),
  434. NewOperands[0], NewOperands[1],
  435. I->getName());
  436. break;
  437. }
  438. // If we're a root instruction, RAUW.
  439. if (Roots.count(I))
  440. I->replaceAllUsesWith(NewV);
  441. ConvertedInsts[I] = NewV;
  442. return NewV;
  443. }
  444. // Perform dead code elimination on the instructions we just modified.
  445. void Float2IntPass::cleanup() {
  446. for (auto &I : reverse(ConvertedInsts))
  447. I.first->eraseFromParent();
  448. }
  449. bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
  450. LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
  451. // Clear out all state.
  452. ECs = EquivalenceClasses<Instruction*>();
  453. SeenInsts.clear();
  454. ConvertedInsts.clear();
  455. Roots.clear();
  456. Ctx = &F.getParent()->getContext();
  457. findRoots(F, DT);
  458. walkBackwards();
  459. walkForwards();
  460. bool Modified = validateAndTransform();
  461. if (Modified)
  462. cleanup();
  463. return Modified;
  464. }
  465. namespace llvm {
  466. FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); }
  467. PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &AM) {
  468. const DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  469. if (!runImpl(F, DT))
  470. return PreservedAnalyses::all();
  471. PreservedAnalyses PA;
  472. PA.preserveSet<CFGAnalyses>();
  473. return PA;
  474. }
  475. } // End namespace llvm