BPFAdjustOpt.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. //===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Adjust optimization to make the code more kernel verifier friendly.
  10. //
  11. //===----------------------------------------------------------------------===//
  #include "BPF.h"
  #include "BPFCORE.h"
  #include "BPFTargetMachine.h"
  #include "llvm/ADT/STLExtras.h"
  #include "llvm/IR/Instruction.h"
  #include "llvm/IR/Instructions.h"
  #include "llvm/IR/IntrinsicsBPF.h"
  #include "llvm/IR/Module.h"
  #include "llvm/IR/PatternMatch.h"
  #include "llvm/IR/Type.h"
  #include "llvm/IR/User.h"
  #include "llvm/IR/Value.h"
  #include "llvm/Pass.h"
  #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  25. #define DEBUG_TYPE "bpf-adjust-opt"
  26. using namespace llvm;
  27. using namespace llvm::PatternMatch;
  28. static cl::opt<bool>
  29. DisableBPFserializeICMP("bpf-disable-serialize-icmp", cl::Hidden,
  30. cl::desc("BPF: Disable Serializing ICMP insns."),
  31. cl::init(false));
  32. static cl::opt<bool> DisableBPFavoidSpeculation(
  33. "bpf-disable-avoid-speculation", cl::Hidden,
  34. cl::desc("BPF: Disable Avoiding Speculative Code Motion."),
  35. cl::init(false));
  36. namespace {
  37. class BPFAdjustOpt final : public ModulePass {
  38. public:
  39. static char ID;
  40. BPFAdjustOpt() : ModulePass(ID) {}
  41. bool runOnModule(Module &M) override;
  42. };
  43. class BPFAdjustOptImpl {
  44. struct PassThroughInfo {
  45. Instruction *Input;
  46. Instruction *UsedInst;
  47. uint32_t OpIdx;
  48. PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx)
  49. : Input(I), UsedInst(U), OpIdx(Idx) {}
  50. };
  51. public:
  52. BPFAdjustOptImpl(Module *M) : M(M) {}
  53. bool run();
  54. private:
  55. Module *M;
  56. SmallVector<PassThroughInfo, 16> PassThroughs;
  57. bool adjustICmpToBuiltin();
  58. void adjustBasicBlock(BasicBlock &BB);
  59. bool serializeICMPCrossBB(BasicBlock &BB);
  60. void adjustInst(Instruction &I);
  61. bool serializeICMPInBB(Instruction &I);
  62. bool avoidSpeculation(Instruction &I);
  63. bool insertPassThrough();
  64. };
  65. } // End anonymous namespace
  66. char BPFAdjustOpt::ID = 0;
  67. INITIALIZE_PASS(BPFAdjustOpt, "bpf-adjust-opt", "BPF Adjust Optimization",
  68. false, false)
  69. ModulePass *llvm::createBPFAdjustOpt() { return new BPFAdjustOpt(); }
  70. bool BPFAdjustOpt::runOnModule(Module &M) { return BPFAdjustOptImpl(&M).run(); }
  71. bool BPFAdjustOptImpl::run() {
  72. bool Changed = adjustICmpToBuiltin();
  73. for (Function &F : *M)
  74. for (auto &BB : F) {
  75. adjustBasicBlock(BB);
  76. for (auto &I : BB)
  77. adjustInst(I);
  78. }
  79. return insertPassThrough() || Changed;
  80. }
  // Commit acabad9ff6bf ("[InstCombine] try to canonicalize icmp with
  // trunc op into mask and cmp") added a transformation to
  // convert "(conv)a < power_2_const" to "a & <const>" in certain
  // cases. The bpf kernel verifier has to handle the resulting code
  // conservatively, and this may reject an otherwise legitimate program.
  // Here, we change the related icmp code to a builtin, which will
  // be restored to the original icmp code later, to prevent that
  // InstCombine transformation.
  89. bool BPFAdjustOptImpl::adjustICmpToBuiltin() {
  90. bool Changed = false;
  91. ICmpInst *ToBeDeleted = nullptr;
  92. for (Function &F : *M)
  93. for (auto &BB : F)
  94. for (auto &I : BB) {
  95. if (ToBeDeleted) {
  96. ToBeDeleted->eraseFromParent();
  97. ToBeDeleted = nullptr;
  98. }
  99. auto *Icmp = dyn_cast<ICmpInst>(&I);
  100. if (!Icmp)
  101. continue;
  102. Value *Op0 = Icmp->getOperand(0);
  103. if (!isa<TruncInst>(Op0))
  104. continue;
  105. auto ConstOp1 = dyn_cast<ConstantInt>(Icmp->getOperand(1));
  106. if (!ConstOp1)
  107. continue;
  108. auto ConstOp1Val = ConstOp1->getValue().getZExtValue();
  109. auto Op = Icmp->getPredicate();
  110. if (Op == ICmpInst::ICMP_ULT || Op == ICmpInst::ICMP_UGE) {
  111. if ((ConstOp1Val - 1) & ConstOp1Val)
  112. continue;
  113. } else if (Op == ICmpInst::ICMP_ULE || Op == ICmpInst::ICMP_UGT) {
  114. if (ConstOp1Val & (ConstOp1Val + 1))
  115. continue;
  116. } else {
  117. continue;
  118. }
  119. Constant *Opcode =
  120. ConstantInt::get(Type::getInt32Ty(BB.getContext()), Op);
  121. Function *Fn = Intrinsic::getDeclaration(
  122. M, Intrinsic::bpf_compare, {Op0->getType(), ConstOp1->getType()});
  123. auto *NewInst = CallInst::Create(Fn, {Opcode, Op0, ConstOp1});
  124. BB.getInstList().insert(I.getIterator(), NewInst);
  125. Icmp->replaceAllUsesWith(NewInst);
  126. Changed = true;
  127. ToBeDeleted = Icmp;
  128. }
  129. return Changed;
  130. }
  131. bool BPFAdjustOptImpl::insertPassThrough() {
  132. for (auto &Info : PassThroughs) {
  133. auto *CI = BPFCoreSharedInfo::insertPassThrough(
  134. M, Info.UsedInst->getParent(), Info.Input, Info.UsedInst);
  135. Info.UsedInst->setOperand(Info.OpIdx, CI);
  136. }
  137. return !PassThroughs.empty();
  138. }
  // Prevent the instcombine optimization from combining conditionals
  // that are in the same basic block.
  141. bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) {
  142. // For:
  143. // comp1 = icmp <opcode> ...;
  144. // comp2 = icmp <opcode> ...;
  145. // ... or comp1 comp2 ...
  146. // changed to:
  147. // comp1 = icmp <opcode> ...;
  148. // comp2 = icmp <opcode> ...;
  149. // new_comp1 = __builtin_bpf_passthrough(seq_num, comp1)
  150. // ... or new_comp1 comp2 ...
  151. Value *Op0, *Op1;
  152. // Use LogicalOr (accept `or i1` as well as `select i1 Op0, true, Op1`)
  153. if (!match(&I, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
  154. return false;
  155. auto *Icmp1 = dyn_cast<ICmpInst>(Op0);
  156. if (!Icmp1)
  157. return false;
  158. auto *Icmp2 = dyn_cast<ICmpInst>(Op1);
  159. if (!Icmp2)
  160. return false;
  161. Value *Icmp1Op0 = Icmp1->getOperand(0);
  162. Value *Icmp2Op0 = Icmp2->getOperand(0);
  163. if (Icmp1Op0 != Icmp2Op0)
  164. return false;
  165. // Now we got two icmp instructions which feed into
  166. // an "or" instruction.
  167. PassThroughInfo Info(Icmp1, &I, 0);
  168. PassThroughs.push_back(Info);
  169. return true;
  170. }
  // Prevent the instcombine optimization from combining conditionals
  // that are in different basic blocks (a block and its single predecessor).
  173. bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) {
  174. // For:
  175. // B1:
  176. // comp1 = icmp <opcode> ...;
  177. // if (comp1) goto B2 else B3;
  178. // B2:
  179. // comp2 = icmp <opcode> ...;
  180. // if (comp2) goto B4 else B5;
  181. // B4:
  182. // ...
  183. // changed to:
  184. // B1:
  185. // comp1 = icmp <opcode> ...;
  186. // comp1 = __builtin_bpf_passthrough(seq_num, comp1);
  187. // if (comp1) goto B2 else B3;
  188. // B2:
  189. // comp2 = icmp <opcode> ...;
  190. // if (comp2) goto B4 else B5;
  191. // B4:
  192. // ...
  193. // Check basic predecessors, if two of them (say B1, B2) are using
  194. // icmp instructions to generate conditions and one is the predesessor
  195. // of another (e.g., B1 is the predecessor of B2). Add a passthrough
  196. // barrier after icmp inst of block B1.
  197. BasicBlock *B2 = BB.getSinglePredecessor();
  198. if (!B2)
  199. return false;
  200. BasicBlock *B1 = B2->getSinglePredecessor();
  201. if (!B1)
  202. return false;
  203. Instruction *TI = B2->getTerminator();
  204. auto *BI = dyn_cast<BranchInst>(TI);
  205. if (!BI || !BI->isConditional())
  206. return false;
  207. auto *Cond = dyn_cast<ICmpInst>(BI->getCondition());
  208. if (!Cond || B2->getFirstNonPHI() != Cond)
  209. return false;
  210. Value *B2Op0 = Cond->getOperand(0);
  211. auto Cond2Op = Cond->getPredicate();
  212. TI = B1->getTerminator();
  213. BI = dyn_cast<BranchInst>(TI);
  214. if (!BI || !BI->isConditional())
  215. return false;
  216. Cond = dyn_cast<ICmpInst>(BI->getCondition());
  217. if (!Cond)
  218. return false;
  219. Value *B1Op0 = Cond->getOperand(0);
  220. auto Cond1Op = Cond->getPredicate();
  221. if (B1Op0 != B2Op0)
  222. return false;
  223. if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) {
  224. if (Cond2Op != ICmpInst::ICMP_SLT && Cond1Op != ICmpInst::ICMP_SLE)
  225. return false;
  226. } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) {
  227. if (Cond2Op != ICmpInst::ICMP_SGT && Cond1Op != ICmpInst::ICMP_SGE)
  228. return false;
  229. } else {
  230. return false;
  231. }
  232. PassThroughInfo Info(Cond, BI, 0);
  233. PassThroughs.push_back(Info);
  234. return true;
  235. }
  // Avoid speculative hoisting of certain computations out of
  // a basic block.
  238. bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) {
  239. if (auto *LdInst = dyn_cast<LoadInst>(&I)) {
  240. if (auto *GV = dyn_cast<GlobalVariable>(LdInst->getOperand(0))) {
  241. if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
  242. GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
  243. return false;
  244. }
  245. }
  246. if (!isa<LoadInst>(&I) && !isa<CallInst>(&I))
  247. return false;
  248. // For:
  249. // B1:
  250. // var = ...
  251. // ...
  252. // /* icmp may not be in the same block as var = ... */
  253. // comp1 = icmp <opcode> var, <const>;
  254. // if (comp1) goto B2 else B3;
  255. // B2:
  256. // ... var ...
  257. // change to:
  258. // B1:
  259. // var = ...
  260. // ...
  261. // /* icmp may not be in the same block as var = ... */
  262. // comp1 = icmp <opcode> var, <const>;
  263. // if (comp1) goto B2 else B3;
  264. // B2:
  265. // var = __builtin_bpf_passthrough(seq_num, var);
  266. // ... var ...
  267. bool isCandidate = false;
  268. SmallVector<PassThroughInfo, 4> Candidates;
  269. for (User *U : I.users()) {
  270. Instruction *Inst = dyn_cast<Instruction>(U);
  271. if (!Inst)
  272. continue;
  273. // May cover a little bit more than the
  274. // above pattern.
  275. if (auto *Icmp1 = dyn_cast<ICmpInst>(Inst)) {
  276. Value *Icmp1Op1 = Icmp1->getOperand(1);
  277. if (!isa<Constant>(Icmp1Op1))
  278. return false;
  279. isCandidate = true;
  280. continue;
  281. }
  282. // Ignore the use in the same basic block as the definition.
  283. if (Inst->getParent() == I.getParent())
  284. continue;
  285. // use in a different basic block, If there is a call or
  286. // load/store insn before this instruction in this basic
  287. // block. Most likely it cannot be hoisted out. Skip it.
  288. for (auto &I2 : *Inst->getParent()) {
  289. if (isa<CallInst>(&I2))
  290. return false;
  291. if (isa<LoadInst>(&I2) || isa<StoreInst>(&I2))
  292. return false;
  293. if (&I2 == Inst)
  294. break;
  295. }
  296. // It should be used in a GEP or a simple arithmetic like
  297. // ZEXT/SEXT which is used for GEP.
  298. if (Inst->getOpcode() == Instruction::ZExt ||
  299. Inst->getOpcode() == Instruction::SExt) {
  300. PassThroughInfo Info(&I, Inst, 0);
  301. Candidates.push_back(Info);
  302. } else if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
  303. // traverse GEP inst to find Use operand index
  304. unsigned i, e;
  305. for (i = 1, e = GI->getNumOperands(); i != e; ++i) {
  306. Value *V = GI->getOperand(i);
  307. if (V == &I)
  308. break;
  309. }
  310. if (i == e)
  311. continue;
  312. PassThroughInfo Info(&I, GI, i);
  313. Candidates.push_back(Info);
  314. }
  315. }
  316. if (!isCandidate || Candidates.empty())
  317. return false;
  318. llvm::append_range(PassThroughs, Candidates);
  319. return true;
  320. }
  321. void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) {
  322. if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB))
  323. return;
  324. }
  325. void BPFAdjustOptImpl::adjustInst(Instruction &I) {
  326. if (!DisableBPFserializeICMP && serializeICMPInBB(I))
  327. return;
  328. if (!DisableBPFavoidSpeculation && avoidSpeculation(I))
  329. return;
  330. }
  331. PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) {
  332. return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none()
  333. : PreservedAnalyses::all();
  334. }