BoundsChecking.cpp 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. //===- BoundsChecking.cpp - Instrumentation for run-time bounds checking --===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "llvm/Transforms/Instrumentation/BoundsChecking.h"
  9. #include "llvm/ADT/Statistic.h"
  10. #include "llvm/ADT/Twine.h"
  11. #include "llvm/Analysis/MemoryBuiltins.h"
  12. #include "llvm/Analysis/ScalarEvolution.h"
  13. #include "llvm/Analysis/TargetFolder.h"
  14. #include "llvm/Analysis/TargetLibraryInfo.h"
  15. #include "llvm/IR/BasicBlock.h"
  16. #include "llvm/IR/Constants.h"
  17. #include "llvm/IR/DataLayout.h"
  18. #include "llvm/IR/Function.h"
  19. #include "llvm/IR/IRBuilder.h"
  20. #include "llvm/IR/InstIterator.h"
  21. #include "llvm/IR/Instruction.h"
  22. #include "llvm/IR/Instructions.h"
  23. #include "llvm/IR/Intrinsics.h"
  24. #include "llvm/IR/Value.h"
  25. #include "llvm/InitializePasses.h"
  26. #include "llvm/Pass.h"
  27. #include "llvm/Support/Casting.h"
  28. #include "llvm/Support/CommandLine.h"
  29. #include "llvm/Support/Debug.h"
  30. #include "llvm/Support/raw_ostream.h"
  31. #include <cstdint>
  32. #include <utility>
  33. using namespace llvm;
  34. #define DEBUG_TYPE "bounds-checking"
  35. static cl::opt<bool> SingleTrapBB("bounds-checking-single-trap",
  36. cl::desc("Use one trap block per function"));
  37. STATISTIC(ChecksAdded, "Bounds checks added");
  38. STATISTIC(ChecksSkipped, "Bounds checks skipped");
  39. STATISTIC(ChecksUnable, "Bounds checks unable to add");
  40. using BuilderTy = IRBuilder<TargetFolder>;
  41. /// Gets the conditions under which memory accessing instructions will overflow.
  42. ///
  43. /// \p Ptr is the pointer that will be read/written, and \p InstVal is either
  44. /// the result from the load or the value being stored. It is used to determine
  45. /// the size of memory block that is touched.
  46. ///
  47. /// Returns the condition under which the access will overflow.
  48. static Value *getBoundsCheckCond(Value *Ptr, Value *InstVal,
  49. const DataLayout &DL, TargetLibraryInfo &TLI,
  50. ObjectSizeOffsetEvaluator &ObjSizeEval,
  51. BuilderTy &IRB, ScalarEvolution &SE) {
  52. uint64_t NeededSize = DL.getTypeStoreSize(InstVal->getType());
  53. LLVM_DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize)
  54. << " bytes\n");
  55. SizeOffsetEvalType SizeOffset = ObjSizeEval.compute(Ptr);
  56. if (!ObjSizeEval.bothKnown(SizeOffset)) {
  57. ++ChecksUnable;
  58. return nullptr;
  59. }
  60. Value *Size = SizeOffset.first;
  61. Value *Offset = SizeOffset.second;
  62. ConstantInt *SizeCI = dyn_cast<ConstantInt>(Size);
  63. Type *IntTy = DL.getIntPtrType(Ptr->getType());
  64. Value *NeededSizeVal = ConstantInt::get(IntTy, NeededSize);
  65. auto SizeRange = SE.getUnsignedRange(SE.getSCEV(Size));
  66. auto OffsetRange = SE.getUnsignedRange(SE.getSCEV(Offset));
  67. auto NeededSizeRange = SE.getUnsignedRange(SE.getSCEV(NeededSizeVal));
  68. // three checks are required to ensure safety:
  69. // . Offset >= 0 (since the offset is given from the base ptr)
  70. // . Size >= Offset (unsigned)
  71. // . Size - Offset >= NeededSize (unsigned)
  72. //
  73. // optimization: if Size >= 0 (signed), skip 1st check
  74. // FIXME: add NSW/NUW here? -- we dont care if the subtraction overflows
  75. Value *ObjSize = IRB.CreateSub(Size, Offset);
  76. Value *Cmp2 = SizeRange.getUnsignedMin().uge(OffsetRange.getUnsignedMax())
  77. ? ConstantInt::getFalse(Ptr->getContext())
  78. : IRB.CreateICmpULT(Size, Offset);
  79. Value *Cmp3 = SizeRange.sub(OffsetRange)
  80. .getUnsignedMin()
  81. .uge(NeededSizeRange.getUnsignedMax())
  82. ? ConstantInt::getFalse(Ptr->getContext())
  83. : IRB.CreateICmpULT(ObjSize, NeededSizeVal);
  84. Value *Or = IRB.CreateOr(Cmp2, Cmp3);
  85. if ((!SizeCI || SizeCI->getValue().slt(0)) &&
  86. !SizeRange.getSignedMin().isNonNegative()) {
  87. Value *Cmp1 = IRB.CreateICmpSLT(Offset, ConstantInt::get(IntTy, 0));
  88. Or = IRB.CreateOr(Cmp1, Or);
  89. }
  90. return Or;
  91. }
  92. /// Adds run-time bounds checks to memory accessing instructions.
  93. ///
  94. /// \p Or is the condition that should guard the trap.
  95. ///
  96. /// \p GetTrapBB is a callable that returns the trap BB to use on failure.
  97. template <typename GetTrapBBT>
  98. static void insertBoundsCheck(Value *Or, BuilderTy &IRB, GetTrapBBT GetTrapBB) {
  99. // check if the comparison is always false
  100. ConstantInt *C = dyn_cast_or_null<ConstantInt>(Or);
  101. if (C) {
  102. ++ChecksSkipped;
  103. // If non-zero, nothing to do.
  104. if (!C->getZExtValue())
  105. return;
  106. }
  107. ++ChecksAdded;
  108. BasicBlock::iterator SplitI = IRB.GetInsertPoint();
  109. BasicBlock *OldBB = SplitI->getParent();
  110. BasicBlock *Cont = OldBB->splitBasicBlock(SplitI);
  111. OldBB->getTerminator()->eraseFromParent();
  112. if (C) {
  113. // If we have a constant zero, unconditionally branch.
  114. // FIXME: We should really handle this differently to bypass the splitting
  115. // the block.
  116. BranchInst::Create(GetTrapBB(IRB), OldBB);
  117. return;
  118. }
  119. // Create the conditional branch.
  120. BranchInst::Create(GetTrapBB(IRB), Cont, Or, OldBB);
  121. }
  122. static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI,
  123. ScalarEvolution &SE) {
  124. if (F.hasFnAttribute(Attribute::NoSanitizeBounds))
  125. return false;
  126. const DataLayout &DL = F.getParent()->getDataLayout();
  127. ObjectSizeOpts EvalOpts;
  128. EvalOpts.RoundToAlign = true;
  129. EvalOpts.EvalMode = ObjectSizeOpts::Mode::ExactUnderlyingSizeAndOffset;
  130. ObjectSizeOffsetEvaluator ObjSizeEval(DL, &TLI, F.getContext(), EvalOpts);
  131. // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory
  132. // touching instructions
  133. SmallVector<std::pair<Instruction *, Value *>, 4> TrapInfo;
  134. for (Instruction &I : instructions(F)) {
  135. Value *Or = nullptr;
  136. BuilderTy IRB(I.getParent(), BasicBlock::iterator(&I), TargetFolder(DL));
  137. if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
  138. if (!LI->isVolatile())
  139. Or = getBoundsCheckCond(LI->getPointerOperand(), LI, DL, TLI,
  140. ObjSizeEval, IRB, SE);
  141. } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
  142. if (!SI->isVolatile())
  143. Or = getBoundsCheckCond(SI->getPointerOperand(), SI->getValueOperand(),
  144. DL, TLI, ObjSizeEval, IRB, SE);
  145. } else if (AtomicCmpXchgInst *AI = dyn_cast<AtomicCmpXchgInst>(&I)) {
  146. if (!AI->isVolatile())
  147. Or =
  148. getBoundsCheckCond(AI->getPointerOperand(), AI->getCompareOperand(),
  149. DL, TLI, ObjSizeEval, IRB, SE);
  150. } else if (AtomicRMWInst *AI = dyn_cast<AtomicRMWInst>(&I)) {
  151. if (!AI->isVolatile())
  152. Or = getBoundsCheckCond(AI->getPointerOperand(), AI->getValOperand(),
  153. DL, TLI, ObjSizeEval, IRB, SE);
  154. }
  155. if (Or)
  156. TrapInfo.push_back(std::make_pair(&I, Or));
  157. }
  158. // Create a trapping basic block on demand using a callback. Depending on
  159. // flags, this will either create a single block for the entire function or
  160. // will create a fresh block every time it is called.
  161. BasicBlock *TrapBB = nullptr;
  162. auto GetTrapBB = [&TrapBB](BuilderTy &IRB) {
  163. if (TrapBB && SingleTrapBB)
  164. return TrapBB;
  165. Function *Fn = IRB.GetInsertBlock()->getParent();
  166. // FIXME: This debug location doesn't make a lot of sense in the
  167. // `SingleTrapBB` case.
  168. auto DebugLoc = IRB.getCurrentDebugLocation();
  169. IRBuilder<>::InsertPointGuard Guard(IRB);
  170. TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn);
  171. IRB.SetInsertPoint(TrapBB);
  172. auto *F = Intrinsic::getDeclaration(Fn->getParent(), Intrinsic::trap);
  173. CallInst *TrapCall = IRB.CreateCall(F, {});
  174. TrapCall->setDoesNotReturn();
  175. TrapCall->setDoesNotThrow();
  176. TrapCall->setDebugLoc(DebugLoc);
  177. IRB.CreateUnreachable();
  178. return TrapBB;
  179. };
  180. // Add the checks.
  181. for (const auto &Entry : TrapInfo) {
  182. Instruction *Inst = Entry.first;
  183. BuilderTy IRB(Inst->getParent(), BasicBlock::iterator(Inst), TargetFolder(DL));
  184. insertBoundsCheck(Entry.second, IRB, GetTrapBB);
  185. }
  186. return !TrapInfo.empty();
  187. }
  188. PreservedAnalyses BoundsCheckingPass::run(Function &F, FunctionAnalysisManager &AM) {
  189. auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  190. auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  191. if (!addBoundsChecking(F, TLI, SE))
  192. return PreservedAnalyses::all();
  193. return PreservedAnalyses::none();
  194. }