//===- MVETailPredication.cpp - MVE Tail Predication ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
/// branches to help accelerate DSP applications. These two extensions,
/// combined with a new form of predication called tail-predication, can be used
/// to provide implicit vector predication within a low-overhead loop.
/// This is implicit because the predicate of active/inactive lanes is
/// calculated by hardware, and thus does not need to be explicitly passed
/// to vector instructions. The instructions responsible for this are the
/// DLSTP and WLSTP instructions, which set up a tail-predicated loop and
/// the total number of data elements processed by the loop. The loop-end
/// LETP instruction is responsible for decrementing and setting the remaining
/// elements to be processed and generating the mask of active lanes.
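///
/// As a rough illustration only (this assembly is not produced by this pass,
/// and the register choices are arbitrary), a tail-predicated hardware loop
/// has the following shape:
///
///       dlstp.32 lr, r2          ; start tail-predicated loop, r2 = elements
///     loop:
///       vldrw.u32 q0, [r0], #16
///       vadd.i32  q0, q0, q1
///       vstrw.32  q0, [r1], #16
///       letp lr, loop            ; decrement elements, set mask, branch back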
///
/// The HardwareLoops pass inserts intrinsics identifying loops that the
/// backend will attempt to convert into a low-overhead loop. The vectorizer is
/// responsible for generating a vectorized loop in which the lanes are
/// predicated upon a get.active.lane.mask intrinsic. This pass looks at these
/// get.active.lane.mask intrinsics and attempts to convert them to VCTP
/// instructions. This will be picked up by the ARM Low-overhead loop pass later
/// in the backend, which performs the final transformation to a DLSTP or WLSTP
/// tail-predicated loop.
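///
/// A minimal sketch of the IR rewrite this pass performs (value names are
/// illustrative only):
///
///   %active = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index,
///                                                               i32 %N)
/// becomes
///   %mask = call <4 x i1> @llvm.arm.mve.vctp32(i32 %elts.remaining)
///
/// where %elts.remaining is a new phi counting down the number of elements
/// left to process, decremented by the vector width each iteration.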
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMSubtarget.h"
#include "ARMTargetTransformInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"

using namespace llvm;

#define DEBUG_TYPE "mve-tail-predication"
#define DESC "Transform predicated vector loops to use MVE tail predication"

cl::opt<TailPredication::Mode> EnableTailPredication(
    "tail-predication", cl::desc("MVE tail-predication pass options"),
    cl::init(TailPredication::Enabled),
    cl::values(clEnumValN(TailPredication::Disabled, "disabled",
                          "Don't tail-predicate loops"),
               clEnumValN(TailPredication::EnabledNoReductions,
                          "enabled-no-reductions",
                          "Enable tail-predication, but not for reduction loops"),
               clEnumValN(TailPredication::Enabled,
                          "enabled",
                          "Enable tail-predication, including reduction loops"),
               clEnumValN(TailPredication::ForceEnabledNoReductions,
                          "force-enabled-no-reductions",
                          "Enable tail-predication, but not for reduction loops, "
                          "and force this which might be unsafe"),
               clEnumValN(TailPredication::ForceEnabled,
                          "force-enabled",
                          "Enable tail-predication, including reduction loops, "
                          "and force this which might be unsafe")));

namespace {
class MVETailPredication : public LoopPass {
  SmallVector<IntrinsicInst*, 4> MaskedInsts;
  Loop *L = nullptr;
  ScalarEvolution *SE = nullptr;
  TargetTransformInfo *TTI = nullptr;
  const ARMSubtarget *ST = nullptr;

public:
  static char ID;

  MVETailPredication() : LoopPass(ID) { }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.setPreservesCFG();
  }

  bool runOnLoop(Loop *L, LPPassManager&) override;

private:
  /// Perform the relevant checks on the loop and convert active lane masks if
  /// possible.
  bool TryConvertActiveLaneMask(Value *TripCount);

  /// Perform several checks on the arguments of @llvm.get.active.lane.mask
  /// intrinsic. E.g., check that the loop induction variable and the element
  /// count are of the form we expect, and also perform overflow checks for
  /// the new expressions that are created.
  bool IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount);

  /// Insert the intrinsic to represent the effect of tail predication.
  void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, Value *TripCount);

  /// Rematerialize the iteration count in exit blocks, which enables
  /// ARMLowOverheadLoops to better optimise away loop update statements inside
  /// hardware-loops.
  void RematerializeIterCount();
};

} // end namespace

bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
  if (skipLoop(L) || !EnableTailPredication)
    return false;

  MaskedInsts.clear();
  Function &F = *L->getHeader()->getParent();
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  ST = &TM.getSubtarget<ARMSubtarget>(F);
  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  this->L = L;

  // The MVE and LOB extensions are combined to enable tail-predication, but
  // there's nothing preventing us from generating VCTP instructions for v8.1m.
  if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
    LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
    return false;
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  if (!Preheader)
    return false;

  auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
    for (auto &I : *BB) {
      auto *Call = dyn_cast<IntrinsicInst>(&I);
      if (!Call)
        continue;

      Intrinsic::ID ID = Call->getIntrinsicID();
      if (ID == Intrinsic::start_loop_iterations ||
          ID == Intrinsic::test_start_loop_iterations)
        return cast<IntrinsicInst>(&I);
    }
    return nullptr;
  };

  // Look for the hardware loop intrinsic that sets the iteration count.
  IntrinsicInst *Setup = FindLoopIterations(Preheader);

  // The test.start.loop.iterations intrinsic could live in the pre-preheader.
  if (!Setup) {
    if (!Preheader->getSinglePredecessor())
      return false;
    Setup = FindLoopIterations(Preheader->getSinglePredecessor());
    if (!Setup)
      return false;
  }

  LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n");

  bool Changed = TryConvertActiveLaneMask(Setup->getArgOperand(0));

  return Changed;
}

// The active lane intrinsic has this form:
//
//    @llvm.get.active.lane.mask(IV, TC)
//
// Here we perform checks that this intrinsic behaves as expected,
// which means:
//
// 1) Check that the TripCount (TC) belongs to this loop (originally).
// 2) The element count (TC) needs to be sufficiently large that the decrement
//    of the element counter doesn't overflow, which means that we need to
//    prove:
//        ceil(ElementCount / VectorWidth) >= TripCount
//    by rounding ElementCount up:
//        (ElementCount + (VectorWidth - 1)) / VectorWidth
//    and evaluating whether the expression
//        ((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount
//    is known to be non-negative.
// 3) The IV must be an induction phi with an increment equal to the
//    vector width.
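//
// As a worked example (illustrative numbers only): with ElementCount = 100 and
// VectorWidth = 4, the rounded-up element count is (100 + 3) / 4 = 25, so the
// checks below expect the hardware-loop TripCount to be 25 and the difference
// 25 - TripCount to be provably zero.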
bool MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask,
                                          Value *TripCount) {
  bool ForceTailPredication =
    EnableTailPredication == TailPredication::ForceEnabledNoReductions ||
    EnableTailPredication == TailPredication::ForceEnabled;

  Value *ElemCount = ActiveLaneMask->getOperand(1);
  bool Changed = false;
  if (!L->makeLoopInvariant(ElemCount, Changed))
    return false;

  auto *EC = SE->getSCEV(ElemCount);
  auto *TC = SE->getSCEV(TripCount);
  int VectorWidth =
      cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
  if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 &&
      VectorWidth != 16)
    return false;

  ConstantInt *ConstElemCount = nullptr;

  // 1) Smoke tests that the original scalar loop TripCount (TC) belongs to
  // this loop. The scalar tripcount corresponds to the number of elements
  // processed by the loop, so we will refer to that from this point on.
  if (!SE->isLoopInvariant(EC, L)) {
    LLVM_DEBUG(dbgs() << "ARM TP: element count must be loop invariant.\n");
    return false;
  }

  if ((ConstElemCount = dyn_cast<ConstantInt>(ElemCount))) {
    ConstantInt *TC = dyn_cast<ConstantInt>(TripCount);
    if (!TC) {
      LLVM_DEBUG(dbgs() << "ARM TP: Constant tripcount expected in "
                           "set.loop.iterations\n");
      return false;
    }

    // Calculate 2 tripcount values and check that they are consistent with
    // each other. The TripCount for a predicated vector loop body is
    // ceil(ElementCount/Width), or floor((ElementCount+Width-1)/Width) as we
    // work it out here.
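    //
    // For example (purely illustrative numbers): ElementCount = 10 and
    // VectorWidth = 4 gives TC2 = (10 + 3) / 4 = 3, which must match the
    // constant passed to set.loop.iterations (TC1) for the conversion to be
    // valid.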
    uint64_t TC1 = TC->getZExtValue();
    uint64_t TC2 =
        (ConstElemCount->getZExtValue() + VectorWidth - 1) / VectorWidth;

    // If the tripcount values are inconsistent, we can't insert the VCTP and
    // trigger tail-predication; keep the intrinsic as a get.active.lane.mask
    // and legalize this.
    if (TC1 != TC2) {
      LLVM_DEBUG(dbgs() << "ARM TP: inconsistent constant tripcount values: "
                        << TC1 << " from set.loop.iterations, and "
                        << TC2 << " from get.active.lane.mask\n");
      return false;
    }
  } else if (!ForceTailPredication) {
    // 2) We need to prove that the sub expression that we create in the
    // tail-predicated loop body, which calculates the remaining elements to be
    // processed, is non-negative, i.e. it doesn't overflow:
    //
    //   ((ElementCount + VectorWidth - 1) / VectorWidth) - TripCount >= 0
    //
    // This is true if:
    //
    //   TripCount == (ElementCount + VectorWidth - 1) / VectorWidth
    //
    // which is what we check here.
    //
    auto *VW = SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth));
    // ElementCount + (VW-1):
    auto *ECPlusVWMinus1 = SE->getAddExpr(EC,
        SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1)));

    // Ceil = (ElementCount + (VW-1)) / VW
    auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, VW);

    // Prevent unused variable warnings with TC
    (void)TC;
    LLVM_DEBUG(
      dbgs() << "ARM TP: Analysing overflow behaviour for:\n";
      dbgs() << "ARM TP: - TripCount = "; TC->dump();
      dbgs() << "ARM TP: - ElemCount = "; EC->dump();
      dbgs() << "ARM TP: - VecWidth = " << VectorWidth << "\n";
      dbgs() << "ARM TP: - (ElemCount+VW-1) / VW = "; Ceil->dump();
    );

    // As an example, almost all the tripcount expressions (produced by the
    // vectoriser) look like this:
    //
    //   TC = ((-4 + (4 * ((3 + %N) /u 4))<nuw>) /u 4)
    //
    // and "(ElementCount + (VW-1)) / VW":
    //
    //   Ceil = ((3 + %N) /u 4)
    //
    // Check for equality of TC and Ceil by calculating the SCEV expression
    // TC - Ceil and testing it for zero.
    //
    const SCEV *Sub =
        SE->getMinusSCEV(SE->getBackedgeTakenCount(L),
                         SE->getUDivExpr(SE->getAddExpr(SE->getMulExpr(Ceil, VW),
                                                        SE->getNegativeSCEV(VW)),
                                         VW));

    // Use context sensitive facts about the path to the loop to refine. This
    // comes up as the backedge taken count can incorporate context sensitive
    // reasoning, and our RHS just above doesn't.
    Sub = SE->applyLoopGuards(Sub, L);

    if (!Sub->isZero()) {
      LLVM_DEBUG(dbgs() << "ARM TP: possible overflow in sub expression.\n");
      return false;
    }
  }

  // 3) Find out if IV is an induction phi. Note that we can't use Loop
  // helpers here to get the induction variable, because the hardware loop is
  // no longer in loopsimplify form, and also the hwloop intrinsic uses a
  // different counter. Using SCEV, we check that the induction is of the
  // form i = i + 4, where the increment must be equal to the VectorWidth.
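  //
  // For example (an illustrative snippet of vectorised IR, value names are
  // made up), the IV typically looks like:
  //
  //   %index      = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  //   %index.next = add i32 %index, 4
  //
  // whose SCEV is the AddRec {0,+,4}<%vector.body>.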
  auto *IV = ActiveLaneMask->getOperand(0);
  auto *IVExpr = SE->getSCEV(IV);
  auto *AddExpr = dyn_cast<SCEVAddRecExpr>(IVExpr);

  if (!AddExpr) {
    LLVM_DEBUG(dbgs() << "ARM TP: induction not an add expr: "; IVExpr->dump());
    return false;
  }
  // Check that this AddRec is associated with this loop.
  if (AddExpr->getLoop() != L) {
    LLVM_DEBUG(dbgs() << "ARM TP: phi not part of this loop\n");
    return false;
  }
  auto *Base = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
  if (!Base || !Base->isZero()) {
    LLVM_DEBUG(dbgs() << "ARM TP: induction base is not 0\n");
    return false;
  }
  auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1));
  if (!Step) {
    LLVM_DEBUG(dbgs() << "ARM TP: induction step is not a constant: ";
               AddExpr->getOperand(1)->dump());
    return false;
  }
  auto StepValue = Step->getValue()->getSExtValue();
  if (VectorWidth == StepValue)
    return true;

  LLVM_DEBUG(dbgs() << "ARM TP: Step value " << StepValue
                    << " doesn't match vector width " << VectorWidth << "\n");

  return false;
}

void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask,
                                             Value *TripCount) {
  IRBuilder<> Builder(L->getLoopPreheader()->getTerminator());
  Module *M = L->getHeader()->getModule();
  Type *Ty = IntegerType::get(M->getContext(), 32);
  unsigned VectorWidth =
      cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();

  // Insert a phi to count the number of elements processed by the loop.
  Builder.SetInsertPoint(L->getHeader()->getFirstNonPHI());
  PHINode *Processed = Builder.CreatePHI(Ty, 2);
  Processed->addIncoming(ActiveLaneMask->getOperand(1), L->getLoopPreheader());

  // Replace @llvm.get.active.lane.mask() with the ARM-specific VCTP intrinsic,
  // and thus represent the effect of tail predication.
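  //
  // After this function runs, the loop body contains something like the
  // following (a sketch for a vector width of 4; value names are illustrative):
  //
  //   %elts = phi i32 [ %N, %preheader ], [ %elts.rem, %latch ]
  //   %mask = call <4 x i1> @llvm.arm.mve.vctp32(i32 %elts)
  //   ...
  //   %elts.rem = sub i32 %elts, 4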
  Builder.SetInsertPoint(ActiveLaneMask);
  ConstantInt *Factor = ConstantInt::get(cast<IntegerType>(Ty), VectorWidth);

  Intrinsic::ID VCTPID;
  switch (VectorWidth) {
  default:
    llvm_unreachable("unexpected number of lanes");
  case 2:  VCTPID = Intrinsic::arm_mve_vctp64; break;
  case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break;
  case 8:  VCTPID = Intrinsic::arm_mve_vctp16; break;
  case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;
  }
  Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
  Value *VCTPCall = Builder.CreateCall(VCTP, Processed);
  ActiveLaneMask->replaceAllUsesWith(VCTPCall);

  // Add the incoming value to the new phi.
  // TODO: This add likely already exists in the loop.
  Value *Remaining = Builder.CreateSub(Processed, Factor);
  Processed->addIncoming(Remaining, L->getLoopLatch());
  LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
                    << *Processed << "\n"
                    << "ARM TP: Inserted VCTP: " << *VCTPCall << "\n");
}

bool MVETailPredication::TryConvertActiveLaneMask(Value *TripCount) {
  SmallVector<IntrinsicInst *, 4> ActiveLaneMasks;
  for (auto *BB : L->getBlocks())
    for (auto &I : *BB)
      if (auto *Int = dyn_cast<IntrinsicInst>(&I))
        if (Int->getIntrinsicID() == Intrinsic::get_active_lane_mask)
          ActiveLaneMasks.push_back(Int);

  if (ActiveLaneMasks.empty())
    return false;

  LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");

  for (auto *ActiveLaneMask : ActiveLaneMasks) {
    LLVM_DEBUG(dbgs() << "ARM TP: Found active lane mask: "
                      << *ActiveLaneMask << "\n");

    if (!IsSafeActiveMask(ActiveLaneMask, TripCount)) {
      LLVM_DEBUG(dbgs() << "ARM TP: Not safe to insert VCTP.\n");
      return false;
    }
    LLVM_DEBUG(dbgs() << "ARM TP: Safe to insert VCTP.\n");
    InsertVCTPIntrinsic(ActiveLaneMask, TripCount);
  }

  // Remove dead instructions and now dead phis.
  for (auto *II : ActiveLaneMasks)
    RecursivelyDeleteTriviallyDeadInstructions(II);
  for (auto *I : L->blocks())
    DeleteDeadPHIs(I);
  return true;
}

Pass *llvm::createMVETailPredicationPass() {
  return new MVETailPredication();
}

char MVETailPredication::ID = 0;

INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)