KCFI.cpp 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. //===-- KCFI.cpp - Generic KCFI operand bundle lowering ---------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This pass emits generic KCFI indirect call checks for targets that don't
  10. // support lowering KCFI operand bundles in the back-end.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/Transforms/Instrumentation/KCFI.h"
  14. #include "llvm/ADT/Statistic.h"
  15. #include "llvm/IR/Constants.h"
  16. #include "llvm/IR/DiagnosticInfo.h"
  17. #include "llvm/IR/DiagnosticPrinter.h"
  18. #include "llvm/IR/Function.h"
  19. #include "llvm/IR/GlobalObject.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include "llvm/IR/InstIterator.h"
  22. #include "llvm/IR/Instructions.h"
  23. #include "llvm/IR/Intrinsics.h"
  24. #include "llvm/IR/MDBuilder.h"
  25. #include "llvm/IR/Module.h"
  26. #include "llvm/InitializePasses.h"
  27. #include "llvm/Pass.h"
  28. #include "llvm/Target/TargetMachine.h"
  29. #include "llvm/Transforms/Instrumentation.h"
  30. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  31. using namespace llvm;
  32. #define DEBUG_TYPE "kcfi"
  33. STATISTIC(NumKCFIChecks, "Number of kcfi operands transformed into checks");
  34. namespace {
  35. class DiagnosticInfoKCFI : public DiagnosticInfo {
  36. const Twine &Msg;
  37. public:
  38. DiagnosticInfoKCFI(const Twine &DiagMsg,
  39. DiagnosticSeverity Severity = DS_Error)
  40. : DiagnosticInfo(DK_Linker, Severity), Msg(DiagMsg) {}
  41. void print(DiagnosticPrinter &DP) const override { DP << Msg; }
  42. };
  43. } // namespace
  44. PreservedAnalyses KCFIPass::run(Function &F, FunctionAnalysisManager &AM) {
  45. Module &M = *F.getParent();
  46. if (!M.getModuleFlag("kcfi"))
  47. return PreservedAnalyses::all();
  48. // Find call instructions with KCFI operand bundles.
  49. SmallVector<CallInst *> KCFICalls;
  50. for (Instruction &I : instructions(F)) {
  51. if (auto *CI = dyn_cast<CallInst>(&I))
  52. if (CI->getOperandBundle(LLVMContext::OB_kcfi))
  53. KCFICalls.push_back(CI);
  54. }
  55. if (KCFICalls.empty())
  56. return PreservedAnalyses::all();
  57. LLVMContext &Ctx = M.getContext();
  58. // patchable-function-prefix emits nops between the KCFI type identifier
  59. // and the function start. As we don't know the size of the emitted nops,
  60. // don't allow this attribute with generic lowering.
  61. if (F.hasFnAttribute("patchable-function-prefix"))
  62. Ctx.diagnose(
  63. DiagnosticInfoKCFI("-fpatchable-function-entry=N,M, where M>0 is not "
  64. "compatible with -fsanitize=kcfi on this target"));
  65. IntegerType *Int32Ty = Type::getInt32Ty(Ctx);
  66. MDNode *VeryUnlikelyWeights =
  67. MDBuilder(Ctx).createBranchWeights(1, (1U << 20) - 1);
  68. for (CallInst *CI : KCFICalls) {
  69. // Get the expected hash value.
  70. const uint32_t ExpectedHash =
  71. cast<ConstantInt>(CI->getOperandBundle(LLVMContext::OB_kcfi)->Inputs[0])
  72. ->getZExtValue();
  73. // Drop the KCFI operand bundle.
  74. CallBase *Call =
  75. CallBase::removeOperandBundle(CI, LLVMContext::OB_kcfi, CI);
  76. assert(Call != CI);
  77. Call->copyMetadata(*CI);
  78. CI->replaceAllUsesWith(Call);
  79. CI->eraseFromParent();
  80. if (!Call->isIndirectCall())
  81. continue;
  82. // Emit a check and trap if the target hash doesn't match.
  83. IRBuilder<> Builder(Call);
  84. Value *HashPtr = Builder.CreateConstInBoundsGEP1_32(
  85. Int32Ty, Call->getCalledOperand(), -1);
  86. Value *Test = Builder.CreateICmpNE(Builder.CreateLoad(Int32Ty, HashPtr),
  87. ConstantInt::get(Int32Ty, ExpectedHash));
  88. Instruction *ThenTerm =
  89. SplitBlockAndInsertIfThen(Test, Call, false, VeryUnlikelyWeights);
  90. Builder.SetInsertPoint(ThenTerm);
  91. Builder.CreateCall(Intrinsic::getDeclaration(&M, Intrinsic::trap));
  92. ++NumKCFIChecks;
  93. }
  94. return PreservedAnalyses::none();
  95. }