//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
// Set Load/Store Alignments From Assumptions
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a ScalarEvolution-based transformation to set
// the alignments of loads, stores, and memory intrinsics based on the truth
// expressions of assume intrinsics. The primary motivation is to handle
// complex alignment assumptions that apply to vector loads and stores that
// appear after vectorization and unrolling.
//
//===----------------------------------------------------------------------===//
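
// Example (illustrative, assuming a recent Clang): a source-level hint such as
//   q = __builtin_assume_aligned(p, 32);
// is lowered to an assume call carrying an "align" operand bundle,
//   call void @llvm.assume(i1 true) ["align"(ptr %p, i64 32)]
// which this pass consumes to raise the alignment of loads and stores that
// use %p (exercisable via, e.g., opt -passes=alignment-from-assumptions).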

#include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"

#define AA_NAME "alignment-from-assumptions"
#define DEBUG_TYPE AA_NAME

using namespace llvm;

STATISTIC(NumLoadAlignChanged,
          "Number of loads changed by alignment assumptions");
STATISTIC(NumStoreAlignChanged,
          "Number of stores changed by alignment assumptions");
STATISTIC(NumMemIntAlignChanged,
          "Number of memory intrinsics changed by alignment assumptions");

namespace {
struct AlignmentFromAssumptions : public FunctionPass {
  static char ID; // Pass identification, replacement for typeid
  AlignmentFromAssumptions() : FunctionPass(ID) {
    initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();

    AU.setPreservesCFG();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
  }

  AlignmentFromAssumptionsPass Impl;
};
} // end anonymous namespace

char AlignmentFromAssumptions::ID = 0;
static const char aip_name[] = "Alignment from assumptions";
INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
                      aip_name, false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
                    aip_name, false, false)

FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
  return new AlignmentFromAssumptions();
}

// Given an expression for the (constant) alignment, AlignSCEV, and an
// expression for the displacement between a pointer and the aligned address,
// DiffSCEV, compute the alignment of the displaced pointer if it can be
// reduced to a constant. Using SCEV to compute alignment handles the case
// where DiffSCEV is a recurrence with constant start such that the aligned
// offset is constant. e.g. {16,+,32} % 32 -> 16.
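// A non-zero constant remainder is also useful: e.g. (illustrative) a diff of
// {8,+,32} relative to a 32-byte alignment leaves a constant remainder of 8,
// a power of two, so the displaced pointer is still at least 8-byte aligned.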
static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
                                      const SCEV *AlignSCEV,
                                      ScalarEvolution *SE) {
  // DiffUnits = Diff % int64_t(Alignment)
  const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);

  LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
                    << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");

  if (const SCEVConstant *ConstDUSCEV =
          dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
    int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();

    // If the displacement is an exact multiple of the alignment, then the
    // displaced pointer has the same alignment as the aligned pointer, so
    // return the alignment value.
    if (!DiffUnits)
      return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();

    // If the displacement is not an exact multiple, but the remainder is a
    // constant, then return this remainder (but only if it is a power of 2).
    uint64_t DiffUnitsAbs = std::abs(DiffUnits);
    if (isPowerOf2_64(DiffUnitsAbs))
      return Align(DiffUnitsAbs);
  }

  return std::nullopt;
}

// There is an address given by an offset OffSCEV from AASCEV which has an
// alignment AlignSCEV. Use that information, if possible, to compute a new
// alignment for Ptr.
static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
                             const SCEV *OffSCEV, Value *Ptr,
                             ScalarEvolution *SE) {
  const SCEV *PtrSCEV = SE->getSCEV(Ptr);
  // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes
  // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV
  // may disagree. Trunc/extend so they agree.
  PtrSCEV = SE->getTruncateOrZeroExtend(
      PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType()));
  const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
  if (isa<SCEVCouldNotCompute>(DiffSCEV))
    return Align(1);

  // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
  // sign-extended OffSCEV to i64, so make sure they agree again.
  DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());

  // What we really want to know is the overall offset to the aligned
  // address. This address is displaced by the provided offset.
  DiffSCEV = SE->getAddExpr(DiffSCEV, OffSCEV);

  LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
                    << *AlignSCEV << " and offset " << *OffSCEV
                    << " using diff " << *DiffSCEV << "\n");

  if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
    LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
    return *NewAlignment;
  }

  if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
    // The relative offset to the alignment assumption did not yield a
    // constant, but we should try harder: if we assume that a is 32-byte
    // aligned, then in the loop
    //   for (i = 0; i < 1024; i += 4) r += a[i];
    // not all of the loads from a are 32-byte aligned; instead they alternate
    // between 32-byte and 16-byte alignment. As a result, the new alignment
    // will not be a constant, but can still be improved over the default
    // (of 4) to 16.

    const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
    const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);

    LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
                      << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");

    // Now compute the new alignment using the displacement to the value in
    // the first iteration, and also the alignment using the per-iteration
    // delta. If these are the same, then use that answer. Otherwise, use the
    // smaller one, but only if it divides the larger one.
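    // (Both candidates are Align values and therefore powers of two, so the
    // smaller one always divides the larger one; taking the minimum below is
    // thus always valid.)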
    MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
    MaybeAlign NewIncAlignment =
        getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);

    LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
                      << "\n");
    LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
                      << "\n");

    if (!NewAlignment || !NewIncAlignment)
      return Align(1);

    const Align NewAlign = *NewAlignment;
    const Align NewIncAlign = *NewIncAlignment;
    if (NewAlign > NewIncAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
                        << DebugStr(NewIncAlign) << "\n");
      return NewIncAlign;
    }
    if (NewIncAlign > NewAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                        << "\n");
      return NewAlign;
    }
    assert(NewIncAlign == NewAlign);
    LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                      << "\n");
    return NewAlign;
  }

  return Align(1);
}
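
// Decode the "align" operand bundle at bundle index Idx of the assume call.
// Such bundles take the form (illustrative IR):
//   call void @llvm.assume(i1 true) ["align"(ptr %p, i64 32)]
//   call void @llvm.assume(i1 true) ["align"(ptr %p, i64 32, i64 8)]
// i.e. a pointer, its assumed alignment, and an optional byte offset. On
// success, AAPtr, AlignSCEV, and OffSCEV are filled in accordingly (OffSCEV
// is zero when the offset is absent).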
bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
                                                        unsigned Idx,
                                                        Value *&AAPtr,
                                                        const SCEV *&AlignSCEV,
                                                        const SCEV *&OffSCEV) {
  Type *Int64Ty = Type::getInt64Ty(I->getContext());
  OperandBundleUse AlignOB = I->getOperandBundleAt(Idx);
  if (AlignOB.getTagName() != "align")
    return false;
  assert(AlignOB.Inputs.size() >= 2);
  AAPtr = AlignOB.Inputs[0].get();
  // TODO: Consider accumulating the offset to the base.
  AAPtr = AAPtr->stripPointerCastsSameRepresentation();
  AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get());
  AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
  if (!isa<SCEVConstant>(AlignSCEV))
    // Added to suppress a crash because consumer doesn't expect non-constant
    // alignments in the assume bundle. TODO: Consider generalizing caller.
    return false;
  if (AlignOB.Inputs.size() == 3)
    OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get());
  else
    OffSCEV = SE->getZero(Int64Ty);
  OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
  return true;
}
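
// Apply the alignment fact carried by one "align" bundle of the assume call
// to every load, store, and memory intrinsic that (transitively) uses the
// assumed pointer and for which the assumption is valid in context.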
bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
                                                     unsigned Idx) {
  Value *AAPtr;
  const SCEV *AlignSCEV, *OffSCEV;
  if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV))
    return false;

  // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
  // affect other users.
  if (isa<ConstantData>(AAPtr))
    return false;

  const SCEV *AASCEV = SE->getSCEV(AAPtr);

  // Apply the assumption to all other users of the specified pointer.
  SmallPtrSet<Instruction *, 32> Visited;
  SmallVector<Instruction *, 16> WorkList;
  for (User *J : AAPtr->users()) {
    if (J == ACall)
      continue;

    if (Instruction *K = dyn_cast<Instruction>(J))
      WorkList.push_back(K);
  }

  while (!WorkList.empty()) {
    Instruction *J = WorkList.pop_back_val();
    if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           LI->getPointerOperand(), SE);
      if (NewAlignment > LI->getAlign()) {
        LI->setAlignment(NewAlignment);
        ++NumLoadAlignChanged;
      }
    } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           SI->getPointerOperand(), SE);
      if (NewAlignment > SI->getAlign()) {
        SI->setAlignment(NewAlignment);
        ++NumStoreAlignChanged;
      }
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewDestAlignment =
          getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);

      LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
                        << "\n";);
      if (NewDestAlignment > *MI->getDestAlign()) {
        MI->setDestAlignment(NewDestAlignment);
        ++NumMemIntAlignChanged;
      }

      // For memory transfers, there is also a source alignment that
      // can be set.
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
        Align NewSrcAlignment =
            getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);

        LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
                          << "\n";);

        if (NewSrcAlignment > *MTI->getSourceAlign()) {
          MTI->setSourceAlignment(NewSrcAlignment);
          ++NumMemIntAlignChanged;
        }
      }
    }

    // Now that we've updated that use of the pointer, look for other uses of
    // the pointer to update.
    Visited.insert(J);
    for (User *UJ : J->users()) {
      Instruction *K = cast<Instruction>(UJ);
      if (!Visited.count(K))
        WorkList.push_back(K);
    }
  }

  return true;
}

bool AlignmentFromAssumptions::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  return Impl.runImpl(F, AC, SE, DT);
}

bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
                                           ScalarEvolution *SE_,
                                           DominatorTree *DT_) {
  SE = SE_;
  DT = DT_;

  bool Changed = false;
  for (auto &AssumeVH : AC.assumptions())
    if (AssumeVH) {
      CallInst *Call = cast<CallInst>(AssumeVH);
      for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
        Changed |= processAssumption(Call, Idx);
    }

  return Changed;
}

PreservedAnalyses
AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, AC, &SE, &DT))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<ScalarEvolutionAnalysis>();
  return PA;
}