//===- Scalarizer.cpp - Scalarize vector operations -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass converts vector operations into scalar operations, in order
// to expose optimization opportunities on the individual scalar operations.
// It is mainly intended for targets that do not have vector units, but it
// may also be useful for revectorizing code to different vector widths.
//
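// As an illustrative (hypothetical) example, a single vector add such as:
//
//   %sum = fadd <2 x float> %a, %b
//
// is split into per-element operations along the lines of:
//
//   %a.i0 = extractelement <2 x float> %a, i32 0
//   %b.i0 = extractelement <2 x float> %b, i32 0
//   %sum.i0 = fadd float %a.i0, %b.i0
//   %a.i1 = extractelement <2 x float> %a, i32 1
//   %b.i1 = extractelement <2 x float> %b, i32 1
//   %sum.i1 = fadd float %a.i1, %b.i1
//
// with the vector result recreated from insertelements only if it is still
// needed (see ScalarizerVisitor::finish()).
//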
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/Scalarizer.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstdint>
#include <iterator>
#include <map>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "scalarizer"

static cl::opt<bool> ScalarizeVariableInsertExtract(
    "scalarize-variable-insert-extract", cl::init(true), cl::Hidden,
    cl::desc("Allow the scalarizer pass to scalarize "
             "insertelement/extractelement with variable index"));

// This is disabled by default because having separate loads and stores
// makes it more likely that the -combiner-alias-analysis limits will be
// reached.
static cl::opt<bool>
    ScalarizeLoadStore("scalarize-load-store", cl::init(false), cl::Hidden,
                       cl::desc("Allow the scalarizer pass to scalarize "
                                "loads and stores"));

namespace {

BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) {
  BasicBlock *BB = Itr->getParent();
  if (isa<PHINode>(Itr))
    Itr = BB->getFirstInsertionPt();
  if (Itr != BB->end())
    Itr = skipDebugIntrinsics(Itr);
  return Itr;
}

// Used to store the scattered form of a vector.
using ValueVector = SmallVector<Value *, 8>;

// Used to map a vector Value to its scattered form.  We use std::map
// because we want iterators to persist across insertion and because the
// values are relatively large.
using ScatterMap = std::map<Value *, ValueVector>;

// Lists Instructions that have been replaced with scalar implementations,
// along with a pointer to their scattered forms.
using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;

// Provides a very limited vector-like interface for lazily accessing one
// component of a scattered vector or vector pointer.
class Scatterer {
public:
  Scatterer() = default;

  // Scatter V into Size components.  If new instructions are needed,
  // insert them before BBI in BB.  If CachePtr is nonnull, use it to cache
  // the results.
  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
            ValueVector *cachePtr = nullptr);

  // Return component I, creating a new Value for it if necessary.
  Value *operator[](unsigned I);

  // Return the number of components.
  unsigned size() const { return Size; }

private:
  BasicBlock *BB;
  BasicBlock::iterator BBI;
  Value *V;
  ValueVector *CachePtr;
  PointerType *PtrTy;
  ValueVector Tmp;
  unsigned Size;
};

// FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
// called Name that compares X and Y in the same way as FCI.
struct FCmpSplitter {
  FCmpSplitter(FCmpInst &fci) : FCI(fci) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
  }

  FCmpInst &FCI;
};

// ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
// called Name that compares X and Y in the same way as ICI.
struct ICmpSplitter {
  ICmpSplitter(ICmpInst &ici) : ICI(ici) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
  }

  ICmpInst &ICI;
};

// UnarySplitter(UO)(Builder, X, Name) uses Builder to create
// a unary operator like UO called Name with operand X.
struct UnarySplitter {
  UnarySplitter(UnaryOperator &uo) : UO(uo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op, const Twine &Name) const {
    return Builder.CreateUnOp(UO.getOpcode(), Op, Name);
  }

  UnaryOperator &UO;
};

// BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create
// a binary operator like BO called Name with operands X and Y.
struct BinarySplitter {
  BinarySplitter(BinaryOperator &bo) : BO(bo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
  }

  BinaryOperator &BO;
};

// Information about a load or store that we're scalarizing.
struct VectorLayout {
  VectorLayout() = default;

  // Return the alignment of element I.
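  // For example (purely illustrative): with VecAlign = 16 and ElemSize = 4,
  // elements 0, 1, 2 and 3 get alignments 16, 4, 8 and 4 respectively, since
  // each element's alignment is the largest power of two dividing both the
  // vector alignment and the element's byte offset.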
  Align getElemAlign(unsigned I) {
    return commonAlignment(VecAlign, I * ElemSize);
  }

  // The type of the vector.
  VectorType *VecTy = nullptr;

  // The type of each element.
  Type *ElemTy = nullptr;

  // The alignment of the vector.
  Align VecAlign;

  // The size of each element.
  uint64_t ElemSize = 0;
};

class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
public:
  ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT)
      : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT) {}

  bool visit(Function &F);

  // InstVisitor methods.  They return true if the instruction was scalarized,
  // false if nothing changed.
  bool visitInstruction(Instruction &I) { return false; }
  bool visitSelectInst(SelectInst &SI);
  bool visitICmpInst(ICmpInst &ICI);
  bool visitFCmpInst(FCmpInst &FCI);
  bool visitUnaryOperator(UnaryOperator &UO);
  bool visitBinaryOperator(BinaryOperator &BO);
  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
  bool visitCastInst(CastInst &CI);
  bool visitBitCastInst(BitCastInst &BCI);
  bool visitInsertElementInst(InsertElementInst &IEI);
  bool visitExtractElementInst(ExtractElementInst &EEI);
  bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
  bool visitPHINode(PHINode &PHI);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitCallInst(CallInst &CI);

private:
  Scatterer scatter(Instruction *Point, Value *V);
  void gather(Instruction *Op, const ValueVector &CV);
  bool canTransferMetadata(unsigned Kind);
  void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
  Optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
                                         const DataLayout &DL);
  bool finish();

  template<typename T> bool splitUnary(Instruction &, const T &);
  template<typename T> bool splitBinary(Instruction &, const T &);

  bool splitCall(CallInst &CI);

  ScatterMap Scattered;
  GatherList Gathered;

  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;

  unsigned ParallelLoopAccessMDKind;

  DominatorTree *DT;
};

class ScalarizerLegacyPass : public FunctionPass {
public:
  static char ID;

  ScalarizerLegacyPass() : FunctionPass(ID) {
    initializeScalarizerLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

char ScalarizerLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer",
                      "Scalarize vector operations", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
                    "Scalarize vector operations", false, false)

Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
                     ValueVector *cachePtr)
    : BB(bb), BBI(bbi), V(v), CachePtr(cachePtr) {
  Type *Ty = V->getType();
  PtrTy = dyn_cast<PointerType>(Ty);
  if (PtrTy)
    Ty = PtrTy->getPointerElementType();
  Size = cast<FixedVectorType>(Ty)->getNumElements();
  if (!CachePtr)
    Tmp.resize(Size, nullptr);
  else if (CachePtr->empty())
    CachePtr->resize(Size, nullptr);
  else
    assert(Size == CachePtr->size() && "Inconsistent vector sizes");
}

// Return component I, creating a new Value for it if necessary.
Value *Scatterer::operator[](unsigned I) {
  ValueVector &CV = (CachePtr ? *CachePtr : Tmp);
  // Try to reuse a previous value.
  if (CV[I])
    return CV[I];
  IRBuilder<> Builder(BB, BBI);
  if (PtrTy) {
    Type *ElTy =
        cast<VectorType>(PtrTy->getPointerElementType())->getElementType();
    if (!CV[0]) {
      Type *NewPtrTy = PointerType::get(ElTy, PtrTy->getAddressSpace());
      CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0");
    }
    if (I != 0)
      CV[I] = Builder.CreateConstGEP1_32(ElTy, CV[0], I,
                                         V->getName() + ".i" + Twine(I));
  } else {
    // Search through a chain of InsertElementInsts looking for element I.
    // Record other elements in the cache.  The new V is still suitable
    // for all uncached indices.
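    //
    // For instance (an illustrative chain, not taken from a test), given
    //
    //   %v1 = insertelement <2 x float> %v0, float %x, i32 0
    //   %v2 = insertelement <2 x float> %v1, float %y, i32 1
    //
    // a request for component 0 of %v2 walks past the i32 1 insert (caching
    // %y for component 1 along the way) and returns %x directly, without
    // emitting any extractelement.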
    while (true) {
      InsertElementInst *Insert = dyn_cast<InsertElementInst>(V);
      if (!Insert)
        break;
      ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2));
      if (!Idx)
        break;
      unsigned J = Idx->getZExtValue();
      V = Insert->getOperand(0);
      if (I == J) {
        CV[J] = Insert->getOperand(1);
        return CV[J];
      } else if (!CV[J]) {
        // Only cache the first entry we find for each index we're not actively
        // searching for.  This prevents us from going too far up the chain and
        // caching incorrect entries.
        CV[J] = Insert->getOperand(1);
      }
    }
    CV[I] = Builder.CreateExtractElement(V, Builder.getInt32(I),
                                         V->getName() + ".i" + Twine(I));
  }
  return CV[I];
}

bool ScalarizerLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;
  Module &M = *F.getParent();
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT);
  return Impl.visit(F);
}

FunctionPass *llvm::createScalarizerPass() {
  return new ScalarizerLegacyPass();
}

bool ScalarizerVisitor::visit(Function &F) {
  assert(Gathered.empty() && Scattered.empty());

  // To ensure we replace gathered components correctly we need to do an
  // ordered traversal of the basic blocks in the function.
  ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
  for (BasicBlock *BB : RPOT) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      Instruction *I = &*II;
      bool Done = InstVisitor::visit(I);
      ++II;
      if (Done && I->getType()->isVoidTy())
        I->eraseFromParent();
    }
  }
  return finish();
}

// Return a scattered form of V that can be accessed by Point.  V must be a
// vector or a pointer to a vector.
Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) {
  if (Argument *VArg = dyn_cast<Argument>(V)) {
    // Put the scattered form of arguments in the entry block,
    // so that it can be used everywhere.
    Function *F = VArg->getParent();
    BasicBlock *BB = &F->getEntryBlock();
    return Scatterer(BB, BB->begin(), V, &Scattered[V]);
  }
  if (Instruction *VOp = dyn_cast<Instruction>(V)) {
    // When scalarizing PHI nodes we might try to examine/rewrite InsertElement
    // nodes in predecessors.  If those predecessors are unreachable from entry,
    // then the IR in those blocks could have unexpected properties resulting
    // in infinite loops in Scatterer::operator[].  By simply treating values
    // originating from instructions in unreachable blocks as undef we do not
    // need to analyse them further.
    if (!DT->isReachableFromEntry(VOp->getParent()))
      return Scatterer(Point->getParent(), Point->getIterator(),
                       UndefValue::get(V->getType()));
    // Put the scattered form of an instruction directly after the
    // instruction, skipping over PHI nodes and debug intrinsics.
    BasicBlock *BB = VOp->getParent();
    return Scatterer(
        BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V,
        &Scattered[V]);
  }
  // In the fallback case, just put the scattered form of V before Point and
  // keep the result local to Point.
  return Scatterer(Point->getParent(), Point->getIterator(), V);
}

// Replace Op with the gathered form of the components in CV.  Defer the
// deletion of Op and creation of the gathered form to the end of the pass,
// so that we can avoid creating the gathered form if all uses of Op are
// replaced with uses of CV.
void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) {
  transferMetadataAndIRFlags(Op, CV);

  // If we already have a scattered form of Op (created from ExtractElements
  // of Op itself), replace them with the new form.
  ValueVector &SV = Scattered[Op];
  if (!SV.empty()) {
    for (unsigned I = 0, E = SV.size(); I != E; ++I) {
      Value *V = SV[I];
      if (V == nullptr || SV[I] == CV[I])
        continue;
      Instruction *Old = cast<Instruction>(V);
      if (isa<Instruction>(CV[I]))
        CV[I]->takeName(Old);
      Old->replaceAllUsesWith(CV[I]);
      PotentiallyDeadInstrs.emplace_back(Old);
    }
  }
  SV = CV;
  Gathered.push_back(GatherList::value_type(Op, &SV));
}

// Return true if it is safe to transfer the given metadata tag from
// vector to scalar instructions.
bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
  return (Tag == LLVMContext::MD_tbaa
          || Tag == LLVMContext::MD_fpmath
          || Tag == LLVMContext::MD_tbaa_struct
          || Tag == LLVMContext::MD_invariant_load
          || Tag == LLVMContext::MD_alias_scope
          || Tag == LLVMContext::MD_noalias
          || Tag == ParallelLoopAccessMDKind
          || Tag == LLVMContext::MD_access_group);
}

// Transfer metadata and IR flags from Op to the instructions in CV if it is
// known to be safe to do so.
void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
                                                   const ValueVector &CV) {
  SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
  Op->getAllMetadataOtherThanDebugLoc(MDs);
  for (unsigned I = 0, E = CV.size(); I != E; ++I) {
    if (Instruction *New = dyn_cast<Instruction>(CV[I])) {
      for (const auto &MD : MDs)
        if (canTransferMetadata(MD.first))
          New->setMetadata(MD.first, MD.second);
      New->copyIRFlags(Op);
      if (Op->getDebugLoc() && !New->getDebugLoc())
        New->setDebugLoc(Op->getDebugLoc());
    }
  }
}

// Try to fill in Layout from Ty, returning the layout on success and None
// on failure.  Alignment is the alignment of the vector.
Optional<VectorLayout>
ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
                                   const DataLayout &DL) {
  VectorLayout Layout;
  // Make sure we're dealing with a vector.
  Layout.VecTy = dyn_cast<VectorType>(Ty);
  if (!Layout.VecTy)
    return None;
  // Check that we're dealing with full-byte elements.
  Layout.ElemTy = Layout.VecTy->getElementType();
  if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy))
    return None;
  Layout.VecAlign = Alignment;
  Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy);
  return Layout;
}

// Scalarize one-operand instruction I, using Split(Builder, X, Name)
// to create an instruction like I with operand X and name Name.
template<typename Splitter>
bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
  VectorType *VT = dyn_cast<VectorType>(I.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&I);
  Scatterer Op = scatter(&I, I.getOperand(0));
  assert(Op.size() == NumElems && "Mismatched unary operation");
  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned Elem = 0; Elem < NumElems; ++Elem)
    Res[Elem] = Split(Builder, Op[Elem], I.getName() + ".i" + Twine(Elem));
  gather(&I, Res);
  return true;
}

// Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
// to create an instruction like I with operands X and Y and name Name.
template<typename Splitter>
bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
  VectorType *VT = dyn_cast<VectorType>(I.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&I);
  Scatterer VOp0 = scatter(&I, I.getOperand(0));
  Scatterer VOp1 = scatter(&I, I.getOperand(1));
  assert(VOp0.size() == NumElems && "Mismatched binary operation");
  assert(VOp1.size() == NumElems && "Mismatched binary operation");
  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
    Value *Op0 = VOp0[Elem];
    Value *Op1 = VOp1[Elem];
    Res[Elem] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Elem));
  }
  gather(&I, Res);
  return true;
}

static bool isTriviallyScalariable(Intrinsic::ID ID) {
  return isTriviallyVectorizable(ID);
}

// All of the current scalarizable intrinsics only have one mangled type.
static Function *getScalarIntrinsicDeclaration(Module *M,
                                               Intrinsic::ID ID,
                                               ArrayRef<Type *> Tys) {
  return Intrinsic::getDeclaration(M, ID, Tys);
}

/// If CI is a call to a vector-typed intrinsic function, split it into a
/// scalar call per element if possible for the intrinsic.
bool ScalarizerVisitor::splitCall(CallInst &CI) {
  VectorType *VT = dyn_cast<VectorType>(CI.getType());
  if (!VT)
    return false;

  Function *F = CI.getCalledFunction();
  if (!F)
    return false;

  Intrinsic::ID ID = F->getIntrinsicID();
  if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID))
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  unsigned NumArgs = CI.arg_size();

  ValueVector ScalarOperands(NumArgs);
  SmallVector<Scatterer, 8> Scattered(NumArgs);

  SmallVector<llvm::Type *, 3> Tys;
  Tys.push_back(VT->getScalarType());

  // Assumes that any vector type has the same number of elements as the
  // return vector type, which is true for all current intrinsics.
  for (unsigned I = 0; I != NumArgs; ++I) {
    Value *OpI = CI.getOperand(I);
    if (OpI->getType()->isVectorTy()) {
      Scattered[I] = scatter(&CI, OpI);
      assert(Scattered[I].size() == NumElems && "mismatched call operands");
    } else {
      ScalarOperands[I] = OpI;
      if (hasVectorInstrinsicOverloadedScalarOpd(ID, I))
        Tys.push_back(OpI->getType());
    }
  }

  ValueVector Res(NumElems);
  ValueVector ScalarCallOps(NumArgs);

  Function *NewIntrin = getScalarIntrinsicDeclaration(F->getParent(), ID, Tys);
  IRBuilder<> Builder(&CI);

  // Perform actual scalarization, taking care to preserve any scalar operands.
  for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
    ScalarCallOps.clear();

    for (unsigned J = 0; J != NumArgs; ++J) {
      if (hasVectorInstrinsicScalarOpd(ID, J))
        ScalarCallOps.push_back(ScalarOperands[J]);
      else
        ScalarCallOps.push_back(Scattered[J][Elem]);
    }

    Res[Elem] = Builder.CreateCall(NewIntrin, ScalarCallOps,
                                   CI.getName() + ".i" + Twine(Elem));
  }

  gather(&CI, Res);
  return true;
}

bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
  VectorType *VT = dyn_cast<VectorType>(SI.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&SI);
  Scatterer VOp1 = scatter(&SI, SI.getOperand(1));
  Scatterer VOp2 = scatter(&SI, SI.getOperand(2));
  assert(VOp1.size() == NumElems && "Mismatched select");
  assert(VOp2.size() == NumElems && "Mismatched select");
  ValueVector Res;
  Res.resize(NumElems);

  if (SI.getOperand(0)->getType()->isVectorTy()) {
    Scatterer VOp0 = scatter(&SI, SI.getOperand(0));
    assert(VOp0.size() == NumElems && "Mismatched select");
    for (unsigned I = 0; I < NumElems; ++I) {
      Value *Op0 = VOp0[I];
      Value *Op1 = VOp1[I];
      Value *Op2 = VOp2[I];
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
                                    SI.getName() + ".i" + Twine(I));
    }
  } else {
    Value *Op0 = SI.getOperand(0);
    for (unsigned I = 0; I < NumElems; ++I) {
      Value *Op1 = VOp1[I];
      Value *Op2 = VOp2[I];
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
                                    SI.getName() + ".i" + Twine(I));
    }
  }
  gather(&SI, Res);
  return true;
}

bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) {
  return splitBinary(ICI, ICmpSplitter(ICI));
}

bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) {
  return splitBinary(FCI, FCmpSplitter(FCI));
}

bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) {
  return splitUnary(UO, UnarySplitter(UO));
}

bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
  return splitBinary(BO, BinarySplitter(BO));
}

bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
  VectorType *VT = dyn_cast<VectorType>(GEPI.getType());
  if (!VT)
    return false;

  IRBuilder<> Builder(&GEPI);
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  unsigned NumIndices = GEPI.getNumIndices();

  // The base pointer might be scalar even if it's a vector GEP.  In those
  // cases, splat the pointer into a vector value, and scatter that vector.
  Value *Op0 = GEPI.getOperand(0);
  if (!Op0->getType()->isVectorTy())
    Op0 = Builder.CreateVectorSplat(NumElems, Op0);
  Scatterer Base = scatter(&GEPI, Op0);

  SmallVector<Scatterer, 8> Ops;
  Ops.resize(NumIndices);
  for (unsigned I = 0; I < NumIndices; ++I) {
    Value *Op = GEPI.getOperand(I + 1);

    // The indices might be scalars even if it's a vector GEP.  In those
    // cases, splat the scalar into a vector value, and scatter that vector.
    if (!Op->getType()->isVectorTy())
      Op = Builder.CreateVectorSplat(NumElems, Op);

    Ops[I] = scatter(&GEPI, Op);
  }

  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {
    SmallVector<Value *, 8> Indices;
    Indices.resize(NumIndices);
    for (unsigned J = 0; J < NumIndices; ++J)
      Indices[J] = Ops[J][I];
    Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), Base[I], Indices,
                               GEPI.getName() + ".i" + Twine(I));
    if (GEPI.isInBounds())
      if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I]))
        NewGEPI->setIsInBounds();
  }
  gather(&GEPI, Res);
  return true;
}

bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
  VectorType *VT = dyn_cast<VectorType>(CI.getDestTy());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&CI);
  Scatterer Op0 = scatter(&CI, CI.getOperand(0));
  assert(Op0.size() == NumElems && "Mismatched cast");
  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreateCast(CI.getOpcode(), Op0[I], VT->getElementType(),
                                CI.getName() + ".i" + Twine(I));
  gather(&CI, Res);
  return true;
}

bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
  VectorType *DstVT = dyn_cast<VectorType>(BCI.getDestTy());
  VectorType *SrcVT = dyn_cast<VectorType>(BCI.getSrcTy());
  if (!DstVT || !SrcVT)
    return false;

  unsigned DstNumElems = cast<FixedVectorType>(DstVT)->getNumElements();
  unsigned SrcNumElems = cast<FixedVectorType>(SrcVT)->getNumElements();
  IRBuilder<> Builder(&BCI);
  Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));
  ValueVector Res;
  Res.resize(DstNumElems);

  if (DstNumElems == SrcNumElems) {
    for (unsigned I = 0; I < DstNumElems; ++I)
      Res[I] = Builder.CreateBitCast(Op0[I], DstVT->getElementType(),
                                     BCI.getName() + ".i" + Twine(I));
  } else if (DstNumElems > SrcNumElems) {
    // <M x t1> -> <N*M x t2>.  Convert each t1 to <N x t2> and copy the
    // individual elements to the destination.
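    // For example (illustrative): bitcasting <2 x i64> to <4 x i32> has a
    // FanOut of 2, so each i64 component is bitcast to <2 x i32> and its two
    // i32 halves are copied into the result.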
    unsigned FanOut = DstNumElems / SrcNumElems;
    auto *MidTy = FixedVectorType::get(DstVT->getElementType(), FanOut);
    unsigned ResI = 0;
    for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) {
      Value *V = Op0[Op0I];
      Instruction *VI;
      // Look through any existing bitcasts before converting to <N x t2>.
      // In the best case, the resulting conversion might be a no-op.
      while ((VI = dyn_cast<Instruction>(V)) &&
             VI->getOpcode() == Instruction::BitCast)
        V = VI->getOperand(0);
      V = Builder.CreateBitCast(V, MidTy, V->getName() + ".cast");
      Scatterer Mid = scatter(&BCI, V);
      for (unsigned MidI = 0; MidI < FanOut; ++MidI)
        Res[ResI++] = Mid[MidI];
    }
  } else {
    // <N*M x t1> -> <M x t2>.  Convert each group of <N x t1> into a t2.
    unsigned FanIn = SrcNumElems / DstNumElems;
    auto *MidTy = FixedVectorType::get(SrcVT->getElementType(), FanIn);
    unsigned Op0I = 0;
    for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) {
      Value *V = PoisonValue::get(MidTy);
      for (unsigned MidI = 0; MidI < FanIn; ++MidI)
        V = Builder.CreateInsertElement(V, Op0[Op0I++], Builder.getInt32(MidI),
                                        BCI.getName() + ".i" + Twine(ResI) +
                                            ".upto" + Twine(MidI));
      Res[ResI] = Builder.CreateBitCast(V, DstVT->getElementType(),
                                        BCI.getName() + ".i" + Twine(ResI));
    }
  }
  gather(&BCI, Res);
  return true;
}

bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
  VectorType *VT = dyn_cast<VectorType>(IEI.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&IEI);
  Scatterer Op0 = scatter(&IEI, IEI.getOperand(0));
  Value *NewElt = IEI.getOperand(1);
  Value *InsIdx = IEI.getOperand(2);

  ValueVector Res;
  Res.resize(NumElems);

  if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
    for (unsigned I = 0; I < NumElems; ++I)
      Res[I] = CI->getValue().getZExtValue() == I ? NewElt : Op0[I];
  } else {
    if (!ScalarizeVariableInsertExtract)
      return false;

    for (unsigned I = 0; I < NumElems; ++I) {
      Value *ShouldReplace =
          Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I),
                               InsIdx->getName() + ".is." + Twine(I));
      Value *OldElt = Op0[I];
      Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt,
                                    IEI.getName() + ".i" + Twine(I));
    }
  }

  gather(&IEI, Res);
  return true;
}

bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
  VectorType *VT = dyn_cast<VectorType>(EEI.getOperand(0)->getType());
  if (!VT)
    return false;

  unsigned NumSrcElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&EEI);
  Scatterer Op0 = scatter(&EEI, EEI.getOperand(0));
  Value *ExtIdx = EEI.getOperand(1);

  if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
    Value *Res = Op0[CI->getValue().getZExtValue()];
    gather(&EEI, {Res});
    return true;
  }

  if (!ScalarizeVariableInsertExtract)
    return false;

  Value *Res = UndefValue::get(VT->getElementType());
  for (unsigned I = 0; I < NumSrcElems; ++I) {
    Value *ShouldExtract =
        Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I),
                             ExtIdx->getName() + ".is." + Twine(I));
    Value *Elt = Op0[I];
    Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
                               EEI.getName() + ".upto" + Twine(I));
  }
  gather(&EEI, {Res});
  return true;
}

bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
  VectorType *VT = dyn_cast<VectorType>(SVI.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));
  Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));
  ValueVector Res;
  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I) {
    int Selector = SVI.getMaskValue(I);
    if (Selector < 0)
      Res[I] = UndefValue::get(VT->getElementType());
    else if (unsigned(Selector) < Op0.size())
      Res[I] = Op0[Selector];
    else
      Res[I] = Op1[Selector - Op0.size()];
  }
  gather(&SVI, Res);
  return true;
}

bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
  VectorType *VT = dyn_cast<VectorType>(PHI.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&PHI);
  ValueVector Res;
  Res.resize(NumElems);

  unsigned NumOps = PHI.getNumOperands();
  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreatePHI(VT->getElementType(), NumOps,
                               PHI.getName() + ".i" + Twine(I));

  for (unsigned I = 0; I < NumOps; ++I) {
    Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I));
    BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
    for (unsigned J = 0; J < NumElems; ++J)
      cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
  }
  gather(&PHI, Res);
  return true;
}

bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!LI.isSimple())
    return false;

  Optional<VectorLayout> Layout = getVectorLayout(
      LI.getType(), LI.getAlign(), LI.getModule()->getDataLayout());
  if (!Layout)
    return false;

  unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();
  IRBuilder<> Builder(&LI);
  Scatterer Ptr = scatter(&LI, LI.getPointerOperand());
  ValueVector Res;
  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreateAlignedLoad(Layout->VecTy->getElementType(), Ptr[I],
                                       Align(Layout->getElemAlign(I)),
                                       LI.getName() + ".i" + Twine(I));
  gather(&LI, Res);
  return true;
}

bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!SI.isSimple())
    return false;

  Value *FullValue = SI.getValueOperand();
  Optional<VectorLayout> Layout = getVectorLayout(
      FullValue->getType(), SI.getAlign(), SI.getModule()->getDataLayout());
  if (!Layout)
    return false;

  unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();
  IRBuilder<> Builder(&SI);
  Scatterer VPtr = scatter(&SI, SI.getPointerOperand());
  Scatterer VVal = scatter(&SI, FullValue);

  ValueVector Stores;
  Stores.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {
    Value *Val = VVal[I];
    Value *Ptr = VPtr[I];
    Stores[I] = Builder.CreateAlignedStore(Val, Ptr, Layout->getElemAlign(I));
  }
  transferMetadataAndIRFlags(&SI, Stores);
  return true;
}

bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
  return splitCall(CI);
}

// Delete the instructions that we scalarized.  If a full vector result
// is still needed, recreate it using InsertElements.
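//
// For example (illustrative), a gathered <2 x float> value %sum whose uses
// survive is rebuilt as:
//
//   %sum.upto0 = insertelement <2 x float> poison, float %sum.i0, i32 0
//   %sum.upto1 = insertelement <2 x float> %sum.upto0, float %sum.i1, i32 1
//
// after which the final insertelement takes over the name %sum and replaces
// all remaining uses of the original instruction.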
bool ScalarizerVisitor::finish() {
  // The presence of data in Gathered or Scattered indicates changes
  // made to the Function.
  if (Gathered.empty() && Scattered.empty())
    return false;
  for (const auto &GMI : Gathered) {
    Instruction *Op = GMI.first;
    ValueVector &CV = *GMI.second;
    if (!Op->use_empty()) {
      // The value is still needed, so recreate it using a series of
      // InsertElements.
      Value *Res = PoisonValue::get(Op->getType());
      if (auto *Ty = dyn_cast<VectorType>(Op->getType())) {
        BasicBlock *BB = Op->getParent();
        unsigned Count = cast<FixedVectorType>(Ty)->getNumElements();
        IRBuilder<> Builder(Op);
        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
        for (unsigned I = 0; I < Count; ++I)
          Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I),
                                            Op->getName() + ".upto" + Twine(I));
        Res->takeName(Op);
      } else {
        assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
        Res = CV[0];
        if (Op == Res)
          continue;
      }
      Op->replaceAllUsesWith(Res);
    }
    PotentiallyDeadInstrs.emplace_back(Op);
  }
  Gathered.clear();
  Scattered.clear();

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);

  return true;
}

PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) {
  Module &M = *F.getParent();
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT);
  bool Changed = Impl.visit(F);
  PreservedAnalyses PA;
  PA.preserve<DominatorTreeAnalysis>();
  return Changed ? PA : PreservedAnalyses::all();
}