ARMSLSHardening.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. //===- ARMSLSHardening.cpp - Harden Straight Line Misspeculation ----------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file contains a pass to insert code to mitigate side-channel
  10. // vulnerabilities that may arise under straight-line misspeculation.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "ARM.h"
  14. #include "ARMInstrInfo.h"
  15. #include "ARMSubtarget.h"
  16. #include "llvm/CodeGen/IndirectThunks.h"
  17. #include "llvm/CodeGen/MachineBasicBlock.h"
  18. #include "llvm/CodeGen/MachineFunction.h"
  19. #include "llvm/CodeGen/MachineFunctionPass.h"
  20. #include "llvm/CodeGen/MachineInstr.h"
  21. #include "llvm/CodeGen/MachineInstrBuilder.h"
  22. #include "llvm/CodeGen/MachineOperand.h"
  23. #include "llvm/IR/DebugLoc.h"
  24. #include <cassert>
  25. using namespace llvm;
  26. #define DEBUG_TYPE "arm-sls-hardening"
  27. #define ARM_SLS_HARDENING_NAME "ARM sls hardening pass"
  28. namespace {
  29. class ARMSLSHardening : public MachineFunctionPass {
  30. public:
  31. const TargetInstrInfo *TII;
  32. const ARMSubtarget *ST;
  33. static char ID;
  34. ARMSLSHardening() : MachineFunctionPass(ID) {
  35. initializeARMSLSHardeningPass(*PassRegistry::getPassRegistry());
  36. }
  37. bool runOnMachineFunction(MachineFunction &Fn) override;
  38. StringRef getPassName() const override { return ARM_SLS_HARDENING_NAME; }
  39. void getAnalysisUsage(AnalysisUsage &AU) const override {
  40. AU.setPreservesCFG();
  41. MachineFunctionPass::getAnalysisUsage(AU);
  42. }
  43. private:
  44. bool hardenReturnsAndBRs(MachineBasicBlock &MBB) const;
  45. bool hardenIndirectCalls(MachineBasicBlock &MBB) const;
  46. MachineBasicBlock &
  47. ConvertIndirectCallToIndirectJump(MachineBasicBlock &MBB,
  48. MachineBasicBlock::iterator) const;
  49. };
  50. } // end anonymous namespace
  51. char ARMSLSHardening::ID = 0;
  52. INITIALIZE_PASS(ARMSLSHardening, "arm-sls-hardening",
  53. ARM_SLS_HARDENING_NAME, false, false)
  54. static void insertSpeculationBarrier(const ARMSubtarget *ST,
  55. MachineBasicBlock &MBB,
  56. MachineBasicBlock::iterator MBBI,
  57. DebugLoc DL,
  58. bool AlwaysUseISBDSB = false) {
  59. assert(MBBI != MBB.begin() &&
  60. "Must not insert SpeculationBarrierEndBB as only instruction in MBB.");
  61. assert(std::prev(MBBI)->isBarrier() &&
  62. "SpeculationBarrierEndBB must only follow unconditional control flow "
  63. "instructions.");
  64. assert(std::prev(MBBI)->isTerminator() &&
  65. "SpeculationBarrierEndBB must only follow terminators.");
  66. const TargetInstrInfo *TII = ST->getInstrInfo();
  67. assert(ST->hasDataBarrier() || ST->hasSB());
  68. bool ProduceSB = ST->hasSB() && !AlwaysUseISBDSB;
  69. unsigned BarrierOpc =
  70. ProduceSB ? (ST->isThumb() ? ARM::t2SpeculationBarrierSBEndBB
  71. : ARM::SpeculationBarrierSBEndBB)
  72. : (ST->isThumb() ? ARM::t2SpeculationBarrierISBDSBEndBB
  73. : ARM::SpeculationBarrierISBDSBEndBB);
  74. if (MBBI == MBB.end() || !isSpeculationBarrierEndBBOpcode(MBBI->getOpcode()))
  75. BuildMI(MBB, MBBI, DL, TII->get(BarrierOpc));
  76. }
  77. bool ARMSLSHardening::runOnMachineFunction(MachineFunction &MF) {
  78. ST = &MF.getSubtarget<ARMSubtarget>();
  79. TII = MF.getSubtarget().getInstrInfo();
  80. bool Modified = false;
  81. for (auto &MBB : MF) {
  82. Modified |= hardenReturnsAndBRs(MBB);
  83. Modified |= hardenIndirectCalls(MBB);
  84. }
  85. return Modified;
  86. }
  87. bool ARMSLSHardening::hardenReturnsAndBRs(MachineBasicBlock &MBB) const {
  88. if (!ST->hardenSlsRetBr())
  89. return false;
  90. assert(!ST->isThumb1Only());
  91. bool Modified = false;
  92. MachineBasicBlock::iterator MBBI = MBB.getFirstTerminator(), E = MBB.end();
  93. MachineBasicBlock::iterator NextMBBI;
  94. for (; MBBI != E; MBBI = NextMBBI) {
  95. MachineInstr &MI = *MBBI;
  96. NextMBBI = std::next(MBBI);
  97. if (isIndirectControlFlowNotComingBack(MI)) {
  98. assert(MI.isTerminator());
  99. assert(!TII->isPredicated(MI));
  100. insertSpeculationBarrier(ST, MBB, std::next(MBBI), MI.getDebugLoc());
  101. Modified = true;
  102. }
  103. }
  104. return Modified;
  105. }
  106. static const char SLSBLRNamePrefix[] = "__llvm_slsblr_thunk_";
  107. static const struct ThunkNameRegMode {
  108. const char* Name;
  109. Register Reg;
  110. bool isThumb;
  111. } SLSBLRThunks[] = {
  112. {"__llvm_slsblr_thunk_arm_r0", ARM::R0, false},
  113. {"__llvm_slsblr_thunk_arm_r1", ARM::R1, false},
  114. {"__llvm_slsblr_thunk_arm_r2", ARM::R2, false},
  115. {"__llvm_slsblr_thunk_arm_r3", ARM::R3, false},
  116. {"__llvm_slsblr_thunk_arm_r4", ARM::R4, false},
  117. {"__llvm_slsblr_thunk_arm_r5", ARM::R5, false},
  118. {"__llvm_slsblr_thunk_arm_r6", ARM::R6, false},
  119. {"__llvm_slsblr_thunk_arm_r7", ARM::R7, false},
  120. {"__llvm_slsblr_thunk_arm_r8", ARM::R8, false},
  121. {"__llvm_slsblr_thunk_arm_r9", ARM::R9, false},
  122. {"__llvm_slsblr_thunk_arm_r10", ARM::R10, false},
  123. {"__llvm_slsblr_thunk_arm_r11", ARM::R11, false},
  124. {"__llvm_slsblr_thunk_arm_sp", ARM::SP, false},
  125. {"__llvm_slsblr_thunk_arm_pc", ARM::PC, false},
  126. {"__llvm_slsblr_thunk_thumb_r0", ARM::R0, true},
  127. {"__llvm_slsblr_thunk_thumb_r1", ARM::R1, true},
  128. {"__llvm_slsblr_thunk_thumb_r2", ARM::R2, true},
  129. {"__llvm_slsblr_thunk_thumb_r3", ARM::R3, true},
  130. {"__llvm_slsblr_thunk_thumb_r4", ARM::R4, true},
  131. {"__llvm_slsblr_thunk_thumb_r5", ARM::R5, true},
  132. {"__llvm_slsblr_thunk_thumb_r6", ARM::R6, true},
  133. {"__llvm_slsblr_thunk_thumb_r7", ARM::R7, true},
  134. {"__llvm_slsblr_thunk_thumb_r8", ARM::R8, true},
  135. {"__llvm_slsblr_thunk_thumb_r9", ARM::R9, true},
  136. {"__llvm_slsblr_thunk_thumb_r10", ARM::R10, true},
  137. {"__llvm_slsblr_thunk_thumb_r11", ARM::R11, true},
  138. {"__llvm_slsblr_thunk_thumb_sp", ARM::SP, true},
  139. {"__llvm_slsblr_thunk_thumb_pc", ARM::PC, true},
  140. };
  141. // An enum for tracking whether Arm and Thumb thunks have been inserted into the
  142. // current module so far.
  143. enum ArmInsertedThunks { ArmThunk = 1, ThumbThunk = 2 };
  144. inline ArmInsertedThunks &operator|=(ArmInsertedThunks &X,
  145. ArmInsertedThunks Y) {
  146. return X = static_cast<ArmInsertedThunks>(X | Y);
  147. }
  148. namespace {
  149. struct SLSBLRThunkInserter
  150. : ThunkInserter<SLSBLRThunkInserter, ArmInsertedThunks> {
  151. const char *getThunkPrefix() { return SLSBLRNamePrefix; }
  152. bool mayUseThunk(const MachineFunction &MF,
  153. ArmInsertedThunks InsertedThunks) {
  154. if ((InsertedThunks & ArmThunk &&
  155. !MF.getSubtarget<ARMSubtarget>().isThumb()) ||
  156. (InsertedThunks & ThumbThunk &&
  157. MF.getSubtarget<ARMSubtarget>().isThumb()))
  158. return false;
  159. ComdatThunks &= !MF.getSubtarget<ARMSubtarget>().hardenSlsNoComdat();
  160. // FIXME: This could also check if there are any indirect calls in the
  161. // function to more accurately reflect if a thunk will be needed.
  162. return MF.getSubtarget<ARMSubtarget>().hardenSlsBlr();
  163. }
  164. ArmInsertedThunks insertThunks(MachineModuleInfo &MMI, MachineFunction &MF);
  165. void populateThunk(MachineFunction &MF);
  166. private:
  167. bool ComdatThunks = true;
  168. };
  169. } // namespace
  170. ArmInsertedThunks SLSBLRThunkInserter::insertThunks(MachineModuleInfo &MMI,
  171. MachineFunction &MF) {
  172. // FIXME: It probably would be possible to filter which thunks to produce
  173. // based on which registers are actually used in indirect calls in this
  174. // function. But would that be a worthwhile optimization?
  175. const ARMSubtarget *ST = &MF.getSubtarget<ARMSubtarget>();
  176. for (auto T : SLSBLRThunks)
  177. if (ST->isThumb() == T.isThumb)
  178. createThunkFunction(MMI, T.Name, ComdatThunks);
  179. return ST->isThumb() ? ThumbThunk : ArmThunk;
  180. }
  181. void SLSBLRThunkInserter::populateThunk(MachineFunction &MF) {
  182. // FIXME: How to better communicate Register number, rather than through
  183. // name and lookup table?
  184. assert(MF.getName().startswith(getThunkPrefix()));
  185. auto ThunkIt = llvm::find_if(
  186. SLSBLRThunks, [&MF](auto T) { return T.Name == MF.getName(); });
  187. assert(ThunkIt != std::end(SLSBLRThunks));
  188. Register ThunkReg = ThunkIt->Reg;
  189. bool isThumb = ThunkIt->isThumb;
  190. const TargetInstrInfo *TII = MF.getSubtarget<ARMSubtarget>().getInstrInfo();
  191. MachineBasicBlock *Entry = &MF.front();
  192. Entry->clear();
  193. // These thunks need to consist of the following instructions:
  194. // __llvm_slsblr_thunk_(arm/thumb)_rN:
  195. // bx rN
  196. // barrierInsts
  197. Entry->addLiveIn(ThunkReg);
  198. if (isThumb)
  199. BuildMI(Entry, DebugLoc(), TII->get(ARM::tBX))
  200. .addReg(ThunkReg)
  201. .add(predOps(ARMCC::AL));
  202. else
  203. BuildMI(Entry, DebugLoc(), TII->get(ARM::BX))
  204. .addReg(ThunkReg);
  205. // Make sure the thunks do not make use of the SB extension in case there is
  206. // a function somewhere that will call to it that for some reason disabled
  207. // the SB extension locally on that function, even though it's enabled for
  208. // the module otherwise. Therefore set AlwaysUseISBSDB to true.
  209. insertSpeculationBarrier(&MF.getSubtarget<ARMSubtarget>(), *Entry,
  210. Entry->end(), DebugLoc(), true /*AlwaysUseISBDSB*/);
  211. }
  212. MachineBasicBlock &ARMSLSHardening::ConvertIndirectCallToIndirectJump(
  213. MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) const {
  214. // Transform an indirect call to an indirect jump as follows:
  215. // Before:
  216. // |-----------------------------|
  217. // | ... |
  218. // | instI |
  219. // | BLX rN |
  220. // | instJ |
  221. // | ... |
  222. // |-----------------------------|
  223. //
  224. // After:
  225. // |---------- -------------------------|
  226. // | ... |
  227. // | instI |
  228. // | *call* __llvm_slsblr_thunk_mode_xN |
  229. // | instJ |
  230. // | ... |
  231. // |--------------------------------------|
  232. //
  233. // __llvm_slsblr_thunk_mode_xN:
  234. // |-----------------------------|
  235. // | BX rN |
  236. // | barrierInsts |
  237. // |-----------------------------|
  238. //
  239. // The __llvm_slsblr_thunk_mode_xN thunks are created by the
  240. // SLSBLRThunkInserter.
  241. // This function merely needs to transform an indirect call to a direct call
  242. // to __llvm_slsblr_thunk_xN.
  243. MachineInstr &IndirectCall = *MBBI;
  244. assert(isIndirectCall(IndirectCall) && !IndirectCall.isReturn());
  245. int RegOpIdxOnIndirectCall = -1;
  246. bool isThumb;
  247. switch (IndirectCall.getOpcode()) {
  248. case ARM::BLX: // !isThumb2
  249. case ARM::BLX_noip: // !isThumb2
  250. isThumb = false;
  251. RegOpIdxOnIndirectCall = 0;
  252. break;
  253. case ARM::tBLXr: // isThumb2
  254. case ARM::tBLXr_noip: // isThumb2
  255. isThumb = true;
  256. RegOpIdxOnIndirectCall = 2;
  257. break;
  258. default:
  259. llvm_unreachable("unhandled Indirect Call");
  260. }
  261. Register Reg = IndirectCall.getOperand(RegOpIdxOnIndirectCall).getReg();
  262. // Since linkers are allowed to clobber R12 on function calls, the above
  263. // mitigation only works if the original indirect call instruction was not
  264. // using R12. Code generation before must make sure that no indirect call
  265. // using R12 was produced if the mitigation is enabled.
  266. // Also, the transformation is incorrect if the indirect call uses LR, so
  267. // also have to avoid that.
  268. assert(Reg != ARM::R12 && Reg != ARM::LR);
  269. bool RegIsKilled = IndirectCall.getOperand(RegOpIdxOnIndirectCall).isKill();
  270. DebugLoc DL = IndirectCall.getDebugLoc();
  271. MachineFunction &MF = *MBBI->getMF();
  272. auto ThunkIt = llvm::find_if(SLSBLRThunks, [Reg, isThumb](auto T) {
  273. return T.Reg == Reg && T.isThumb == isThumb;
  274. });
  275. assert(ThunkIt != std::end(SLSBLRThunks));
  276. Module *M = MF.getFunction().getParent();
  277. const GlobalValue *GV = cast<GlobalValue>(M->getNamedValue(ThunkIt->Name));
  278. MachineInstr *BL =
  279. isThumb ? BuildMI(MBB, MBBI, DL, TII->get(ARM::tBL))
  280. .addImm(IndirectCall.getOperand(0).getImm())
  281. .addReg(IndirectCall.getOperand(1).getReg())
  282. .addGlobalAddress(GV)
  283. : BuildMI(MBB, MBBI, DL, TII->get(ARM::BL)).addGlobalAddress(GV);
  284. // Now copy the implicit operands from IndirectCall to BL and copy other
  285. // necessary info.
  286. // However, both IndirectCall and BL instructions implictly use SP and
  287. // implicitly define LR. Blindly copying implicit operands would result in SP
  288. // and LR operands to be present multiple times. While this may not be too
  289. // much of an issue, let's avoid that for cleanliness, by removing those
  290. // implicit operands from the BL created above before we copy over all
  291. // implicit operands from the IndirectCall.
  292. int ImpLROpIdx = -1;
  293. int ImpSPOpIdx = -1;
  294. for (unsigned OpIdx = BL->getNumExplicitOperands();
  295. OpIdx < BL->getNumOperands(); OpIdx++) {
  296. MachineOperand Op = BL->getOperand(OpIdx);
  297. if (!Op.isReg())
  298. continue;
  299. if (Op.getReg() == ARM::LR && Op.isDef())
  300. ImpLROpIdx = OpIdx;
  301. if (Op.getReg() == ARM::SP && !Op.isDef())
  302. ImpSPOpIdx = OpIdx;
  303. }
  304. assert(ImpLROpIdx != -1);
  305. assert(ImpSPOpIdx != -1);
  306. int FirstOpIdxToRemove = std::max(ImpLROpIdx, ImpSPOpIdx);
  307. int SecondOpIdxToRemove = std::min(ImpLROpIdx, ImpSPOpIdx);
  308. BL->removeOperand(FirstOpIdxToRemove);
  309. BL->removeOperand(SecondOpIdxToRemove);
  310. // Now copy over the implicit operands from the original IndirectCall
  311. BL->copyImplicitOps(MF, IndirectCall);
  312. MF.moveCallSiteInfo(&IndirectCall, BL);
  313. // Also add the register called in the IndirectCall as being used in the
  314. // called thunk.
  315. BL->addOperand(MachineOperand::CreateReg(Reg, false /*isDef*/, true /*isImp*/,
  316. RegIsKilled /*isKill*/));
  317. // Remove IndirectCallinstruction
  318. MBB.erase(MBBI);
  319. return MBB;
  320. }
  321. bool ARMSLSHardening::hardenIndirectCalls(MachineBasicBlock &MBB) const {
  322. if (!ST->hardenSlsBlr())
  323. return false;
  324. bool Modified = false;
  325. MachineBasicBlock::iterator MBBI = MBB.begin(), E = MBB.end();
  326. MachineBasicBlock::iterator NextMBBI;
  327. for (; MBBI != E; MBBI = NextMBBI) {
  328. MachineInstr &MI = *MBBI;
  329. NextMBBI = std::next(MBBI);
  330. // Tail calls are both indirect calls and "returns".
  331. // They are also indirect jumps, so should be handled by sls-harden-retbr,
  332. // rather than sls-harden-blr.
  333. if (isIndirectCall(MI) && !MI.isReturn()) {
  334. ConvertIndirectCallToIndirectJump(MBB, MBBI);
  335. Modified = true;
  336. }
  337. }
  338. return Modified;
  339. }
  340. FunctionPass *llvm::createARMSLSHardeningPass() {
  341. return new ARMSLSHardening();
  342. }
  343. namespace {
  344. class ARMIndirectThunks : public MachineFunctionPass {
  345. public:
  346. static char ID;
  347. ARMIndirectThunks() : MachineFunctionPass(ID) {}
  348. StringRef getPassName() const override { return "ARM Indirect Thunks"; }
  349. bool doInitialization(Module &M) override;
  350. bool runOnMachineFunction(MachineFunction &MF) override;
  351. void getAnalysisUsage(AnalysisUsage &AU) const override {
  352. MachineFunctionPass::getAnalysisUsage(AU);
  353. AU.addRequired<MachineModuleInfoWrapperPass>();
  354. AU.addPreserved<MachineModuleInfoWrapperPass>();
  355. }
  356. private:
  357. std::tuple<SLSBLRThunkInserter> TIs;
  358. // FIXME: When LLVM moves to C++17, these can become folds
  359. template <typename... ThunkInserterT>
  360. static void initTIs(Module &M,
  361. std::tuple<ThunkInserterT...> &ThunkInserters) {
  362. (void)std::initializer_list<int>{
  363. (std::get<ThunkInserterT>(ThunkInserters).init(M), 0)...};
  364. }
  365. template <typename... ThunkInserterT>
  366. static bool runTIs(MachineModuleInfo &MMI, MachineFunction &MF,
  367. std::tuple<ThunkInserterT...> &ThunkInserters) {
  368. bool Modified = false;
  369. (void)std::initializer_list<int>{
  370. Modified |= std::get<ThunkInserterT>(ThunkInserters).run(MMI, MF)...};
  371. return Modified;
  372. }
  373. };
  374. } // end anonymous namespace
  375. char ARMIndirectThunks::ID = 0;
  376. FunctionPass *llvm::createARMIndirectThunks() {
  377. return new ARMIndirectThunks();
  378. }
  379. bool ARMIndirectThunks::doInitialization(Module &M) {
  380. initTIs(M, TIs);
  381. return false;
  382. }
  383. bool ARMIndirectThunks::runOnMachineFunction(MachineFunction &MF) {
  384. LLVM_DEBUG(dbgs() << getPassName() << '\n');
  385. auto &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
  386. return runTIs(MMI, MF, TIs);
  387. }