//===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Performs general IR level optimizations on SVE intrinsics.
//
// This pass performs the following optimizations:
//
// - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g.:
//     %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
//     %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
//     ; (%1 can be replaced with a reinterpret of %2)
//
// - optimizes ptest intrinsics where the operands are being needlessly
//   converted to and from svbool_t.
//
//===----------------------------------------------------------------------===//

#include "AArch64.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
#include <optional>

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"

namespace {
struct SVEIntrinsicOpts : public ModulePass {
  static char ID; // Pass identification, replacement for typeid
  SVEIntrinsicOpts() : ModulePass(ID) {
    initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
                                   SmallSetVector<IntrinsicInst *, 4> &PTrues);
  bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
  bool optimizePredicateStore(Instruction *I);
  bool optimizePredicateLoad(Instruction *I);

  bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);

  /// Operates at function scope, i.e. optimizations are applied locally
  /// within each of the given functions.
  bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
};
} // end anonymous namespace

void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.setPreservesCFG();
}

char SVEIntrinsicOpts::ID = 0;
static const char *name = "SVE intrinsics optimizations";
INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)

ModulePass *llvm::createSVEIntrinsicOptsPass() {
  return new SVEIntrinsicOpts();
}

/// Checks if a ptrue intrinsic call is promoted. The act of promoting a
/// ptrue will introduce zeroing. For example:
///
///     %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
///     %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
///     %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
///
/// %1 is promoted, because it is converted:
///
///     <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
///
/// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
static bool isPTruePromoted(IntrinsicInst *PTrue) {
  // Find all users of this intrinsic that are calls to convert-to-svbool
  // reinterpret intrinsics.
  SmallVector<IntrinsicInst *, 4> ConvertToUses;
  for (User *User : PTrue->users()) {
    if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
      ConvertToUses.push_back(cast<IntrinsicInst>(User));
    }
  }

  // If no such calls were found, this ptrue is not promoted.
  if (ConvertToUses.empty())
    return false;

  // Otherwise, try to find users of the convert-to-svbool intrinsics that are
  // calls to the convert-from-svbool intrinsic, and would result in some lanes
  // being zeroed.
  const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
  for (IntrinsicInst *ConvertToUse : ConvertToUses) {
    for (User *User : ConvertToUse->users()) {
      auto *IntrUser = dyn_cast<IntrinsicInst>(User);
      if (IntrUser && IntrUser->getIntrinsicID() ==
                          Intrinsic::aarch64_sve_convert_from_svbool) {
        const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());

        // Would some lanes become zeroed by the conversion?
        if (IntrUserVTy->getElementCount().getKnownMinValue() >
            PTrueVTy->getElementCount().getKnownMinValue())
          // This is a promoted ptrue.
          return true;
      }
    }
  }

  // If no matching calls were found, this is not a promoted ptrue.
  return false;
}

/// Attempts to coalesce ptrues in a basic block.
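///
/// As a sketch (names and widths are illustrative; see the worked example on
/// optimizePTrueIntrinsicCalls below for the full reasoning), given two
/// ptrues with the same pattern in one block:
///
///     %1 = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
///     %2 = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
///
/// the narrower %2 is replaced with a reinterpret of the most encompassing
/// ptrue %1 via convert.{to,from}.svbool, and its original call is erased.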
bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
    BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
  if (PTrues.size() <= 1)
    return false;

  // Find the ptrue with the most lanes.
  auto *MostEncompassingPTrue = *std::max_element(
      PTrues.begin(), PTrues.end(), [](auto *PTrue1, auto *PTrue2) {
        auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
        auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
        return PTrue1VTy->getElementCount().getKnownMinValue() <
               PTrue2VTy->getElementCount().getKnownMinValue();
      });

  // Remove the most encompassing ptrue, as well as any promoted ptrues,
  // leaving behind only the ptrues to be coalesced.
  PTrues.remove(MostEncompassingPTrue);
  PTrues.remove_if(isPTruePromoted);

  // Hoist MostEncompassingPTrue to the start of the basic block. It is always
  // safe to do this, since ptrue intrinsic calls are guaranteed to have no
  // predecessors.
  MostEncompassingPTrue->moveBefore(BB, BB.getFirstInsertionPt());

  LLVMContext &Ctx = BB.getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());

  auto *MostEncompassingPTrueVTy =
      cast<VectorType>(MostEncompassingPTrue->getType());
  auto *ConvertToSVBool = Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
      {MostEncompassingPTrue});

  bool ConvertFromCreated = false;
  for (auto *PTrue : PTrues) {
    auto *PTrueVTy = cast<VectorType>(PTrue->getType());

    // Only create the converts if the types are not already the same,
    // otherwise just use the most encompassing ptrue.
    if (MostEncompassingPTrueVTy != PTrueVTy) {
      ConvertFromCreated = true;

      Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
      auto *ConvertFromSVBool =
          Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                                  {PTrueVTy}, {ConvertToSVBool});
      PTrue->replaceAllUsesWith(ConvertFromSVBool);
    } else
      PTrue->replaceAllUsesWith(MostEncompassingPTrue);

    PTrue->eraseFromParent();
  }

  // We never used the ConvertTo, so remove it.
  if (!ConvertFromCreated)
    ConvertToSVBool->eraseFromParent();

  return true;
}

/// The goal of this function is to remove redundant calls to the SVE ptrue
/// intrinsic in each basic block within the given functions.
///
/// SVE ptrues have two representations in LLVM IR:
/// - a logical representation -- an arbitrary-width scalable vector of i1s,
///   i.e. <vscale x N x i1>.
/// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element
///   scalable vector of i1s, i.e. <vscale x 16 x i1>.
///
/// The SVE ptrue intrinsic is used to create a logical representation of an SVE
/// predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and P2. If
/// P1 creates a logical SVE predicate that is at least as wide as the logical
/// SVE predicate created by P2, then all of the bits that are true in the
/// physical representation of P2 are necessarily also true in the physical
/// representation of P1. P1 'encompasses' P2; therefore, the intrinsic call to
/// P2 is redundant and can be replaced by an SVE reinterpret of P1 via
/// convert.{to,from}.svbool.
///
/// Currently, this pass only coalesces calls to SVE ptrue intrinsics
/// if they match the following conditions:
///
/// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
///   SV_ALL indicates that all bits of the predicate vector are to be set to
///   true. SV_POW2 indicates that all bits of the predicate vector up to the
///   largest power-of-two are to be set to true.
/// - the result of the call to the intrinsic is not promoted to a wider
///   predicate. In this case, keeping the extra ptrue leads to better codegen
///   -- coalescing here would create an irreducible chain of SVE reinterprets
///   via convert.{to,from}.svbool.
///
/// EXAMPLE:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1, 1, 1, 1, 1>
///     ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
///     ...
///
///     %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1>
///     ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
///     ...
///
/// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
///     %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
///     %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
///
bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (auto *F : Functions) {
    for (auto &BB : *F) {
      SmallSetVector<IntrinsicInst *, 4> SVAllPTrues;
      SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues;

      // For each basic block, collect the used ptrues and try to coalesce them.
      for (Instruction &I : BB) {
        if (I.use_empty())
          continue;

        auto *IntrI = dyn_cast<IntrinsicInst>(&I);
        if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
          continue;

        const auto PTruePattern =
            cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();

        if (PTruePattern == AArch64SVEPredPattern::all)
          SVAllPTrues.insert(IntrI);
        if (PTruePattern == AArch64SVEPredPattern::pow2)
          SVPow2PTrues.insert(IntrI);
      }

      Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
      Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
    }
  }

  return Changed;
}

// This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
// scalable stores as late as possible
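//
// As a sketch of the pattern being matched (assuming the function carries
// vscale_range(2,2), so the fixed-width image of <vscale x 16 x i1> is
// 4 bytes; %pred and %addr are illustrative names):
//
//   %bc  = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
//   %ext = call <4 x i8> @llvm.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8> %bc, i64 0)
//   store <4 x i8> %ext, ptr %addr
//
// which this function rewrites to a single store of the predicate itself:
//
//   store <vscale x 16 x i1> %pred, ptr %addr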
bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  unsigned MinVScale = Attr.getVScaleRangeMin();
  std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  // The transform needs to know the exact runtime length of scalable vectors
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // If we have a store..
  auto *Store = dyn_cast<StoreInst>(I);
  if (!Store || !Store->isSimple())
    return false;

  // ..that is storing a predicate vector's worth of bits..
  if (Store->getOperand(0)->getType() != FixedPredType)
    return false;

  // ..where the value stored comes from a vector extract..
  auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
    return false;

  // ..that is extracting from index 0..
  if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
    return false;

  // ..where the value being extracted comes from a bitcast..
  auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
  if (!BitCast)
    return false;

  // ..and the bitcast is casting from predicate type.
  if (BitCast->getOperand(0)->getType() != PredType)
    return false;

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);

  auto *PtrBitCast = Builder.CreateBitCast(
      Store->getPointerOperand(),
      PredType->getPointerTo(Store->getPointerAddressSpace()));
  Builder.CreateStore(BitCast->getOperand(0), PtrBitCast);

  Store->eraseFromParent();
  if (IntrI->getNumUses() == 0)
    IntrI->eraseFromParent();
  if (BitCast->getNumUses() == 0)
    BitCast->eraseFromParent();

  return true;
}

// This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
// scalable loads as late as possible
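//
// As a sketch of the matched pattern (again assuming vscale_range(2,2);
// %load, %ins, %bc and %addr are illustrative names):
//
//   %load = load <4 x i8>, ptr %addr
//   %ins  = call <vscale x 2 x i8> @llvm.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8> undef, <4 x i8> %load, i64 0)
//   %bc   = bitcast <vscale x 2 x i8> %ins to <vscale x 16 x i1>
//
// which this function rewrites to a single scalable load:
//
//   %pred = load <vscale x 16 x i1>, ptr %addr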
bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  unsigned MinVScale = Attr.getVScaleRangeMin();
  std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  // The transform needs to know the exact runtime length of scalable vectors
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // If we have a bitcast..
  auto *BitCast = dyn_cast<BitCastInst>(I);
  if (!BitCast || BitCast->getType() != PredType)
    return false;

  // ..whose operand is a vector_insert..
  auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
    return false;

  // ..that is inserting into index zero of an undef vector..
  if (!isa<UndefValue>(IntrI->getOperand(0)) ||
      !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
    return false;

  // ..where the value inserted comes from a load..
  auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
  if (!Load || !Load->isSimple())
    return false;

  // ..that is loading a predicate vector's worth of bits..
  if (Load->getType() != FixedPredType)
    return false;

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(Load);

  auto *PtrBitCast = Builder.CreateBitCast(
      Load->getPointerOperand(),
      PredType->getPointerTo(Load->getPointerAddressSpace()));
  auto *LoadPred = Builder.CreateLoad(PredType, PtrBitCast);

  BitCast->replaceAllUsesWith(LoadPred);
  BitCast->eraseFromParent();
  if (IntrI->getNumUses() == 0)
    IntrI->eraseFromParent();
  if (Load->getNumUses() == 0)
    Load->eraseFromParent();

  return true;
}

bool SVEIntrinsicOpts::optimizeInstructions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (auto *F : Functions) {
    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();

    // Traverse the DT with an rpo walk so we see defs before uses, allowing
    // simplification to be done incrementally.
    BasicBlock *Root = DT->getRoot();
    ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
    for (auto *BB : RPOT) {
      for (Instruction &I : make_early_inc_range(*BB)) {
        switch (I.getOpcode()) {
        case Instruction::Store:
          Changed |= optimizePredicateStore(&I);
          break;
        case Instruction::BitCast:
          Changed |= optimizePredicateLoad(&I);
          break;
        }
      }
    }
  }

  return Changed;
}

bool SVEIntrinsicOpts::optimizeFunctions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  Changed |= optimizePTrueIntrinsicCalls(Functions);
  Changed |= optimizeInstructions(Functions);

  return Changed;
}

bool SVEIntrinsicOpts::runOnModule(Module &M) {
  bool Changed = false;
  SmallSetVector<Function *, 4> Functions;

  // Check for SVE intrinsic declarations first so that we only iterate over
  // relevant functions. Where an appropriate declaration is found, store the
  // function(s) where it is used so we can target these only.
  for (auto &F : M.getFunctionList()) {
    if (!F.isDeclaration())
      continue;

    switch (F.getIntrinsicID()) {
    case Intrinsic::vector_extract:
    case Intrinsic::vector_insert:
    case Intrinsic::aarch64_sve_ptrue:
      for (User *U : F.users())
        Functions.insert(cast<Instruction>(U)->getFunction());
      break;
    default:
      break;
    }
  }

  if (!Functions.empty())
    Changed |= optimizeFunctions(Functions);

  return Changed;
}