BoundsChecking.cpp 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. //===- BoundsChecking.cpp - Instrumentation for run-time bounds checking --===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "llvm/Transforms/Instrumentation/BoundsChecking.h"
  9. #include "llvm/ADT/Statistic.h"
  10. #include "llvm/ADT/Twine.h"
  11. #include "llvm/Analysis/MemoryBuiltins.h"
  12. #include "llvm/Analysis/ScalarEvolution.h"
  13. #include "llvm/Analysis/TargetFolder.h"
  14. #include "llvm/Analysis/TargetLibraryInfo.h"
  15. #include "llvm/IR/BasicBlock.h"
  16. #include "llvm/IR/Constants.h"
  17. #include "llvm/IR/DataLayout.h"
  18. #include "llvm/IR/Function.h"
  19. #include "llvm/IR/IRBuilder.h"
  20. #include "llvm/IR/InstIterator.h"
  21. #include "llvm/IR/InstrTypes.h"
  22. #include "llvm/IR/Instruction.h"
  23. #include "llvm/IR/Instructions.h"
  24. #include "llvm/IR/Intrinsics.h"
  25. #include "llvm/IR/Value.h"
  26. #include "llvm/InitializePasses.h"
  27. #include "llvm/Pass.h"
  28. #include "llvm/Support/Casting.h"
  29. #include "llvm/Support/CommandLine.h"
  30. #include "llvm/Support/Debug.h"
  31. #include "llvm/Support/ErrorHandling.h"
  32. #include "llvm/Support/raw_ostream.h"
  33. #include <cstdint>
  34. #include <utility>
  35. using namespace llvm;
#define DEBUG_TYPE "bounds-checking"

// When set, the GetTrapBB callback in addBoundsChecking hands out the first
// trap block it created for every failing check in the function, instead of
// creating a fresh trap block per check.
static cl::opt<bool> SingleTrapBB("bounds-checking-single-trap",
cl::desc("Use one trap block per function"));

// Per-check counters. ChecksUnable is incremented in getBoundsCheckCond when
// the object size/offset cannot be determined; ChecksAdded and ChecksSkipped
// are maintained by insertBoundsCheck.
STATISTIC(ChecksAdded, "Bounds checks added");
STATISTIC(ChecksSkipped, "Bounds checks skipped");
STATISTIC(ChecksUnable, "Bounds checks unable to add");

// IRBuilder flavour used throughout this file. Every instance is constructed
// with TargetFolder(DL), so operations on constant operands can fold to
// constants; insertBoundsCheck relies on that to recognise conditions that
// are statically known.
using BuilderTy = IRBuilder<TargetFolder>;
  43. /// Gets the conditions under which memory accessing instructions will overflow.
  44. ///
  45. /// \p Ptr is the pointer that will be read/written, and \p InstVal is either
  46. /// the result from the load or the value being stored. It is used to determine
  47. /// the size of memory block that is touched.
  48. ///
  49. /// Returns the condition under which the access will overflow.
  50. static Value *getBoundsCheckCond(Value *Ptr, Value *InstVal,
  51. const DataLayout &DL, TargetLibraryInfo &TLI,
  52. ObjectSizeOffsetEvaluator &ObjSizeEval,
  53. BuilderTy &IRB, ScalarEvolution &SE) {
  54. uint64_t NeededSize = DL.getTypeStoreSize(InstVal->getType());
  55. LLVM_DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize)
  56. << " bytes\n");
  57. SizeOffsetEvalType SizeOffset = ObjSizeEval.compute(Ptr);
  58. if (!ObjSizeEval.bothKnown(SizeOffset)) {
  59. ++ChecksUnable;
  60. return nullptr;
  61. }
  62. Value *Size = SizeOffset.first;
  63. Value *Offset = SizeOffset.second;
  64. ConstantInt *SizeCI = dyn_cast<ConstantInt>(Size);
  65. Type *IntTy = DL.getIntPtrType(Ptr->getType());
  66. Value *NeededSizeVal = ConstantInt::get(IntTy, NeededSize);
  67. auto SizeRange = SE.getUnsignedRange(SE.getSCEV(Size));
  68. auto OffsetRange = SE.getUnsignedRange(SE.getSCEV(Offset));
  69. auto NeededSizeRange = SE.getUnsignedRange(SE.getSCEV(NeededSizeVal));
  70. // three checks are required to ensure safety:
  71. // . Offset >= 0 (since the offset is given from the base ptr)
  72. // . Size >= Offset (unsigned)
  73. // . Size - Offset >= NeededSize (unsigned)
  74. //
  75. // optimization: if Size >= 0 (signed), skip 1st check
  76. // FIXME: add NSW/NUW here? -- we dont care if the subtraction overflows
  77. Value *ObjSize = IRB.CreateSub(Size, Offset);
  78. Value *Cmp2 = SizeRange.getUnsignedMin().uge(OffsetRange.getUnsignedMax())
  79. ? ConstantInt::getFalse(Ptr->getContext())
  80. : IRB.CreateICmpULT(Size, Offset);
  81. Value *Cmp3 = SizeRange.sub(OffsetRange)
  82. .getUnsignedMin()
  83. .uge(NeededSizeRange.getUnsignedMax())
  84. ? ConstantInt::getFalse(Ptr->getContext())
  85. : IRB.CreateICmpULT(ObjSize, NeededSizeVal);
  86. Value *Or = IRB.CreateOr(Cmp2, Cmp3);
  87. if ((!SizeCI || SizeCI->getValue().slt(0)) &&
  88. !SizeRange.getSignedMin().isNonNegative()) {
  89. Value *Cmp1 = IRB.CreateICmpSLT(Offset, ConstantInt::get(IntTy, 0));
  90. Or = IRB.CreateOr(Cmp1, Or);
  91. }
  92. return Or;
  93. }
  94. /// Adds run-time bounds checks to memory accessing instructions.
  95. ///
  96. /// \p Or is the condition that should guard the trap.
  97. ///
  98. /// \p GetTrapBB is a callable that returns the trap BB to use on failure.
  99. template <typename GetTrapBBT>
  100. static void insertBoundsCheck(Value *Or, BuilderTy &IRB, GetTrapBBT GetTrapBB) {
  101. // check if the comparison is always false
  102. ConstantInt *C = dyn_cast_or_null<ConstantInt>(Or);
  103. if (C) {
  104. ++ChecksSkipped;
  105. // If non-zero, nothing to do.
  106. if (!C->getZExtValue())
  107. return;
  108. }
  109. ++ChecksAdded;
  110. BasicBlock::iterator SplitI = IRB.GetInsertPoint();
  111. BasicBlock *OldBB = SplitI->getParent();
  112. BasicBlock *Cont = OldBB->splitBasicBlock(SplitI);
  113. OldBB->getTerminator()->eraseFromParent();
  114. if (C) {
  115. // If we have a constant zero, unconditionally branch.
  116. // FIXME: We should really handle this differently to bypass the splitting
  117. // the block.
  118. BranchInst::Create(GetTrapBB(IRB), OldBB);
  119. return;
  120. }
  121. // Create the conditional branch.
  122. BranchInst::Create(GetTrapBB(IRB), Cont, Or, OldBB);
  123. }
  124. static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI,
  125. ScalarEvolution &SE) {
  126. const DataLayout &DL = F.getParent()->getDataLayout();
  127. ObjectSizeOpts EvalOpts;
  128. EvalOpts.RoundToAlign = true;
  129. ObjectSizeOffsetEvaluator ObjSizeEval(DL, &TLI, F.getContext(), EvalOpts);
  130. // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory
  131. // touching instructions
  132. SmallVector<std::pair<Instruction *, Value *>, 4> TrapInfo;
  133. for (Instruction &I : instructions(F)) {
  134. Value *Or = nullptr;
  135. BuilderTy IRB(I.getParent(), BasicBlock::iterator(&I), TargetFolder(DL));
  136. if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
  137. if (!LI->isVolatile())
  138. Or = getBoundsCheckCond(LI->getPointerOperand(), LI, DL, TLI,
  139. ObjSizeEval, IRB, SE);
  140. } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
  141. if (!SI->isVolatile())
  142. Or = getBoundsCheckCond(SI->getPointerOperand(), SI->getValueOperand(),
  143. DL, TLI, ObjSizeEval, IRB, SE);
  144. } else if (AtomicCmpXchgInst *AI = dyn_cast<AtomicCmpXchgInst>(&I)) {
  145. if (!AI->isVolatile())
  146. Or =
  147. getBoundsCheckCond(AI->getPointerOperand(), AI->getCompareOperand(),
  148. DL, TLI, ObjSizeEval, IRB, SE);
  149. } else if (AtomicRMWInst *AI = dyn_cast<AtomicRMWInst>(&I)) {
  150. if (!AI->isVolatile())
  151. Or = getBoundsCheckCond(AI->getPointerOperand(), AI->getValOperand(),
  152. DL, TLI, ObjSizeEval, IRB, SE);
  153. }
  154. if (Or)
  155. TrapInfo.push_back(std::make_pair(&I, Or));
  156. }
  157. // Create a trapping basic block on demand using a callback. Depending on
  158. // flags, this will either create a single block for the entire function or
  159. // will create a fresh block every time it is called.
  160. BasicBlock *TrapBB = nullptr;
  161. auto GetTrapBB = [&TrapBB](BuilderTy &IRB) {
  162. if (TrapBB && SingleTrapBB)
  163. return TrapBB;
  164. Function *Fn = IRB.GetInsertBlock()->getParent();
  165. // FIXME: This debug location doesn't make a lot of sense in the
  166. // `SingleTrapBB` case.
  167. auto DebugLoc = IRB.getCurrentDebugLocation();
  168. IRBuilder<>::InsertPointGuard Guard(IRB);
  169. TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn);
  170. IRB.SetInsertPoint(TrapBB);
  171. auto *F = Intrinsic::getDeclaration(Fn->getParent(), Intrinsic::trap);
  172. CallInst *TrapCall = IRB.CreateCall(F, {});
  173. TrapCall->setDoesNotReturn();
  174. TrapCall->setDoesNotThrow();
  175. TrapCall->setDebugLoc(DebugLoc);
  176. IRB.CreateUnreachable();
  177. return TrapBB;
  178. };
  179. // Add the checks.
  180. for (const auto &Entry : TrapInfo) {
  181. Instruction *Inst = Entry.first;
  182. BuilderTy IRB(Inst->getParent(), BasicBlock::iterator(Inst), TargetFolder(DL));
  183. insertBoundsCheck(Entry.second, IRB, GetTrapBB);
  184. }
  185. return !TrapInfo.empty();
  186. }
  187. PreservedAnalyses BoundsCheckingPass::run(Function &F, FunctionAnalysisManager &AM) {
  188. auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  189. auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  190. if (!addBoundsChecking(F, TLI, SE))
  191. return PreservedAnalyses::all();
  192. return PreservedAnalyses::none();
  193. }
  194. namespace {
  195. struct BoundsCheckingLegacyPass : public FunctionPass {
  196. static char ID;
  197. BoundsCheckingLegacyPass() : FunctionPass(ID) {
  198. initializeBoundsCheckingLegacyPassPass(*PassRegistry::getPassRegistry());
  199. }
  200. bool runOnFunction(Function &F) override {
  201. auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  202. auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  203. return addBoundsChecking(F, TLI, SE);
  204. }
  205. void getAnalysisUsage(AnalysisUsage &AU) const override {
  206. AU.addRequired<TargetLibraryInfoWrapperPass>();
  207. AU.addRequired<ScalarEvolutionWrapperPass>();
  208. }
  209. };
  210. } // namespace
  211. char BoundsCheckingLegacyPass::ID = 0;
  212. INITIALIZE_PASS_BEGIN(BoundsCheckingLegacyPass, "bounds-checking",
  213. "Run-time bounds checking", false, false)
  214. INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
  215. INITIALIZE_PASS_END(BoundsCheckingLegacyPass, "bounds-checking",
  216. "Run-time bounds checking", false, false)
  217. FunctionPass *llvm::createBoundsCheckingLegacyPass() {
  218. return new BoundsCheckingLegacyPass();
  219. }