ComplexDeinterleavingPass.cpp 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889
  1. //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Identification:
  10. // This step is responsible for finding the patterns that can be lowered to
  11. // complex instructions, and building a graph to represent the complex
  12. // structures. Starting from the "Converging Shuffle" (a shuffle that
  13. // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
  14. // operands are evaluated and identified as "Composite Nodes" (collections of
  15. // instructions that can potentially be lowered to a single complex
  16. // instruction). This is performed by checking the real and imaginary components
  17. // and tracking the data flow for each component while following the operand
  18. // pairs. Each node is expected to be validated upon creation, and any
  19. // validation error should halt traversal and prevent further graph
  20. // construction.
  21. //
  22. // Replacement:
  23. // This step traverses the graph built up by identification, delegating to the
  24. // target to validate and generate the correct intrinsics, and plumbs them
  25. // together connecting each end of the new intrinsics graph to the existing
  26. // use-def chain. This step is assumed to finish successfully, as all
  27. // information is expected to be correct by this point.
  28. //
  29. //
  30. // Internal data structure:
  31. // ComplexDeinterleavingGraph:
  32. // Keeps references to all the valid CompositeNodes formed as part of the
  33. // transformation, and every Instruction contained within said nodes. It also
  34. // holds onto a reference to the root Instruction, and the root node that should
  35. // replace it.
  36. //
  37. // ComplexDeinterleavingCompositeNode:
  38. // A CompositeNode represents a single transformation point; each node should
  39. // transform into a single complex instruction (ignoring vector splitting, which
  40. // would generate more instructions per node). They are identified in a
  41. // depth-first manner, traversing and identifying the operands of each
  42. // instruction in the order they appear in the IR.
  43. // Each node maintains a reference to its Real and Imaginary instructions,
  44. // as well as any additional instructions that make up the identified operation
  45. // (Internal instructions should only have uses within their containing node).
  46. // A Node also contains the rotation and operation type that it represents.
  47. // Operands contains pointers to other CompositeNodes, acting as the edges in
  48. // the graph. ReplacementNode is the transformed Value* that has been emitted
  49. // to the IR.
  50. //
  51. // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
  52. // ReplacementNode fields of that Node are relevant, where the ReplacementNode
  53. // should be pre-populated.
  54. //
  55. //===----------------------------------------------------------------------===//
  56. #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
  57. #include "llvm/ADT/Statistic.h"
  58. #include "llvm/Analysis/TargetLibraryInfo.h"
  59. #include "llvm/Analysis/TargetTransformInfo.h"
  60. #include "llvm/CodeGen/TargetLowering.h"
  61. #include "llvm/CodeGen/TargetPassConfig.h"
  62. #include "llvm/CodeGen/TargetSubtargetInfo.h"
  63. #include "llvm/IR/IRBuilder.h"
  64. #include "llvm/InitializePasses.h"
  65. #include "llvm/Target/TargetMachine.h"
  66. #include "llvm/Transforms/Utils/Local.h"
  67. #include <algorithm>
  68. using namespace llvm;
  69. using namespace PatternMatch;
  70. #define DEBUG_TYPE "complex-deinterleaving"
  71. STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
  72. static cl::opt<bool> ComplexDeinterleavingEnabled(
  73. "enable-complex-deinterleaving",
  74. cl::desc("Enable generation of complex instructions"), cl::init(true),
  75. cl::Hidden);
  76. /// Checks the given mask, and determines whether said mask is interleaving.
  77. ///
  78. /// To be interleaving, a mask must alternate between `i` and `i + (Length /
  79. /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
  80. /// 4x vector interleaving mask would be <0, 2, 1, 3>).
  81. static bool isInterleavingMask(ArrayRef<int> Mask);
  82. /// Checks the given mask, and determines whether said mask is deinterleaving.
  83. ///
  84. /// To be deinterleaving, a mask must increment in steps of 2, and either start
  85. /// with 0 or 1.
  86. /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
  87. /// <1, 3, 5, 7>).
  88. static bool isDeinterleavingMask(ArrayRef<int> Mask);
  89. namespace {
  90. class ComplexDeinterleavingLegacyPass : public FunctionPass {
  91. public:
  92. static char ID;
  93. ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
  94. : FunctionPass(ID), TM(TM) {
  95. initializeComplexDeinterleavingLegacyPassPass(
  96. *PassRegistry::getPassRegistry());
  97. }
  98. StringRef getPassName() const override {
  99. return "Complex Deinterleaving Pass";
  100. }
  101. bool runOnFunction(Function &F) override;
  102. void getAnalysisUsage(AnalysisUsage &AU) const override {
  103. AU.addRequired<TargetLibraryInfoWrapperPass>();
  104. AU.setPreservesCFG();
  105. }
  106. private:
  107. const TargetMachine *TM;
  108. };
  109. class ComplexDeinterleavingGraph;
  110. struct ComplexDeinterleavingCompositeNode {
  111. ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
  112. Instruction *R, Instruction *I)
  113. : Operation(Op), Real(R), Imag(I) {}
  114. private:
  115. friend class ComplexDeinterleavingGraph;
  116. using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
  117. using RawNodePtr = ComplexDeinterleavingCompositeNode *;
  118. public:
  119. ComplexDeinterleavingOperation Operation;
  120. Instruction *Real;
  121. Instruction *Imag;
  122. // Instructions that should only exist within this node, there should be no
  123. // users of these instructions outside the node. An example of these would be
  124. // the multiply instructions of a partial multiply operation.
  125. SmallVector<Instruction *> InternalInstructions;
  126. ComplexDeinterleavingRotation Rotation;
  127. SmallVector<RawNodePtr> Operands;
  128. Value *ReplacementNode = nullptr;
  129. void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
  130. void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
  131. bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);
  132. void dump() { dump(dbgs()); }
  133. void dump(raw_ostream &OS) {
  134. auto PrintValue = [&](Value *V) {
  135. if (V) {
  136. OS << "\"";
  137. V->print(OS, true);
  138. OS << "\"\n";
  139. } else
  140. OS << "nullptr\n";
  141. };
  142. auto PrintNodeRef = [&](RawNodePtr Ptr) {
  143. if (Ptr)
  144. OS << Ptr << "\n";
  145. else
  146. OS << "nullptr\n";
  147. };
  148. OS << "- CompositeNode: " << this << "\n";
  149. OS << " Real: ";
  150. PrintValue(Real);
  151. OS << " Imag: ";
  152. PrintValue(Imag);
  153. OS << " ReplacementNode: ";
  154. PrintValue(ReplacementNode);
  155. OS << " Operation: " << (int)Operation << "\n";
  156. OS << " Rotation: " << ((int)Rotation * 90) << "\n";
  157. OS << " Operands: \n";
  158. for (const auto &Op : Operands) {
  159. OS << " - ";
  160. PrintNodeRef(Op);
  161. }
  162. OS << " InternalInstructions:\n";
  163. for (const auto &I : InternalInstructions) {
  164. OS << " - \"";
  165. I->print(OS, true);
  166. OS << "\"\n";
  167. }
  168. }
  169. };
  170. class ComplexDeinterleavingGraph {
  171. public:
  172. using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
  173. using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
  174. explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
  175. private:
  176. const TargetLowering *TL;
  177. Instruction *RootValue;
  178. NodePtr RootNode;
  179. SmallVector<NodePtr> CompositeNodes;
  180. SmallPtrSet<Instruction *, 16> AllInstructions;
  181. NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
  182. Instruction *R, Instruction *I) {
  183. return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
  184. I);
  185. }
  186. NodePtr submitCompositeNode(NodePtr Node) {
  187. CompositeNodes.push_back(Node);
  188. AllInstructions.insert(Node->Real);
  189. AllInstructions.insert(Node->Imag);
  190. for (auto *I : Node->InternalInstructions)
  191. AllInstructions.insert(I);
  192. return Node;
  193. }
  194. NodePtr getContainingComposite(Value *R, Value *I) {
  195. for (const auto &CN : CompositeNodes) {
  196. if (CN->Real == R && CN->Imag == I)
  197. return CN;
  198. }
  199. return nullptr;
  200. }
  201. /// Identifies a complex partial multiply pattern and its rotation, based on
  202. /// the following patterns
  203. ///
  204. /// 0: r: cr + ar * br
  205. /// i: ci + ar * bi
  206. /// 90: r: cr - ai * bi
  207. /// i: ci + ai * br
  208. /// 180: r: cr - ar * br
  209. /// i: ci - ar * bi
  210. /// 270: r: cr + ai * bi
  211. /// i: ci - ai * br
  212. NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
  213. /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  214. /// is partially known from identifyPartialMul, filling in the other half of
  215. /// the complex pair.
  216. NodePtr identifyNodeWithImplicitAdd(
  217. Instruction *I, Instruction *J,
  218. std::pair<Instruction *, Instruction *> &CommonOperandI);
  219. /// Identifies a complex add pattern and its rotation, based on the following
  220. /// patterns.
  221. ///
  222. /// 90: r: ar - bi
  223. /// i: ai + br
  224. /// 270: r: ar + bi
  225. /// i: ai - br
  226. NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
  227. NodePtr identifyNode(Instruction *I, Instruction *J);
  228. Value *replaceNode(RawNodePtr Node);
  229. public:
  230. void dump() { dump(dbgs()); }
  231. void dump(raw_ostream &OS) {
  232. for (const auto &Node : CompositeNodes)
  233. Node->dump(OS);
  234. }
  235. /// Returns false if the deinterleaving operation should be cancelled for the
  236. /// current graph.
  237. bool identifyNodes(Instruction *RootI);
  238. /// Perform the actual replacement of the underlying instruction graph.
  239. /// Returns false if the deinterleaving operation should be cancelled for the
  240. /// current graph.
  241. void replaceNodes();
  242. };
  243. class ComplexDeinterleaving {
  244. public:
  245. ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
  246. : TL(tl), TLI(tli) {}
  247. bool runOnFunction(Function &F);
  248. private:
  249. bool evaluateBasicBlock(BasicBlock *B);
  250. const TargetLowering *TL = nullptr;
  251. const TargetLibraryInfo *TLI = nullptr;
  252. };
  253. } // namespace
  254. char ComplexDeinterleavingLegacyPass::ID = 0;
  255. INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
  256. "Complex Deinterleaving", false, false)
  257. INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
  258. "Complex Deinterleaving", false, false)
  259. PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
  260. FunctionAnalysisManager &AM) {
  261. const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
  262. auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
  263. if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
  264. return PreservedAnalyses::all();
  265. PreservedAnalyses PA;
  266. PA.preserve<FunctionAnalysisManagerModuleProxy>();
  267. return PA;
  268. }
  269. FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
  270. return new ComplexDeinterleavingLegacyPass(TM);
  271. }
  272. bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
  273. const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
  274. auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  275. return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
  276. }
  277. bool ComplexDeinterleaving::runOnFunction(Function &F) {
  278. if (!ComplexDeinterleavingEnabled) {
  279. LLVM_DEBUG(
  280. dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
  281. return false;
  282. }
  283. if (!TL->isComplexDeinterleavingSupported()) {
  284. LLVM_DEBUG(
  285. dbgs() << "Complex deinterleaving has been disabled, target does "
  286. "not support lowering of complex number operations.\n");
  287. return false;
  288. }
  289. bool Changed = false;
  290. for (auto &B : F)
  291. Changed |= evaluateBasicBlock(&B);
  292. return Changed;
  293. }
  294. static bool isInterleavingMask(ArrayRef<int> Mask) {
  295. // If the size is not even, it's not an interleaving mask
  296. if ((Mask.size() & 1))
  297. return false;
  298. int HalfNumElements = Mask.size() / 2;
  299. for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
  300. int MaskIdx = Idx * 2;
  301. if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
  302. return false;
  303. }
  304. return true;
  305. }
  306. static bool isDeinterleavingMask(ArrayRef<int> Mask) {
  307. int Offset = Mask[0];
  308. int HalfNumElements = Mask.size() / 2;
  309. for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
  310. if (Mask[Idx] != (Idx * 2) + Offset)
  311. return false;
  312. }
  313. return true;
  314. }
  315. bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
  316. bool Changed = false;
  317. SmallVector<Instruction *> DeadInstrRoots;
  318. for (auto &I : *B) {
  319. auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
  320. if (!SVI)
  321. continue;
  322. // Look for a shufflevector that takes separate vectors of the real and
  323. // imaginary components and recombines them into a single vector.
  324. if (!isInterleavingMask(SVI->getShuffleMask()))
  325. continue;
  326. ComplexDeinterleavingGraph Graph(TL);
  327. if (!Graph.identifyNodes(SVI))
  328. continue;
  329. Graph.replaceNodes();
  330. DeadInstrRoots.push_back(SVI);
  331. Changed = true;
  332. }
  333. for (const auto &I : DeadInstrRoots) {
  334. if (!I || I->getParent() == nullptr)
  335. continue;
  336. llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
  337. }
  338. return Changed;
  339. }
  /// Identifies the multiply pair that forms the implicit accumulator of a
  /// partial complex multiply matched by identifyPartialMul.
  ///
  /// \p Real and \p Imag must be single-use fmul instructions whose operands
  /// are instructions. After peeling at most one fneg from each multiply's
  /// operands, the two multiplies must share an operand; which side was
  /// negated selects the rotation. \p PartialMatch enters with the outer
  /// multiply's common operand in one slot (first for rotations 0/180, second
  /// for 90/270). This pair's common operand is written to the slot chosen by
  /// this pair's rotation, and both slots must end up filled.
  /// Returns the new CMulPartial node, or nullptr if the pattern does not match.
  340. ComplexDeinterleavingGraph::NodePtr
  341. ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
  342. Instruction *Real, Instruction *Imag,
  343. std::pair<Instruction *, Instruction *> &PartialMatch) {
  344. LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
  345. << "\n");
  // The multiplies are folded into the new node, so nothing else may use them.
  346. if (!Real->hasOneUse() || !Imag->hasOneUse()) {
  347. LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
  348. return nullptr;
  349. }
  350. if (Real->getOpcode() != Instruction::FMul ||
  351. Imag->getOpcode() != Instruction::FMul) {
  352. LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n");
  353. return nullptr;
  354. }
  355. Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
  356. Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
  357. Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
  358. Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
  359. if (!R0 || !R1 || !I0 || !I1) {
  360. LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
  361. return nullptr;
  362. }
  363. // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
  364. // rotations and use the operand.
  // Negs encodes the rotation: bit 0 is set when the real multiply has a
  // negated operand; a negated imaginary operand sets bit 1 and toggles bit 0.
  // The results are: none -> 0, real only -> 1, imaginary only -> 3, both -> 2.
  // The cast below assumes ComplexDeinterleavingRotation values 0..3 mean
  // 0/90/180/270 degrees (dump() prints Rotation * 90) - confirm in the header.
  // Only one fneg per multiply is peeled. Peeled fnegs are kept in FNegs and
  // become internal instructions of the node.
  365. unsigned Negs = 0;
  366. SmallVector<Instruction *> FNegs;
  367. if (R0->getOpcode() == Instruction::FNeg ||
  368. R1->getOpcode() == Instruction::FNeg) {
  369. Negs |= 1;
  370. if (R0->getOpcode() == Instruction::FNeg) {
  371. FNegs.push_back(R0);
  372. R0 = dyn_cast<Instruction>(R0->getOperand(0));
  373. } else {
  374. FNegs.push_back(R1);
  375. R1 = dyn_cast<Instruction>(R1->getOperand(0));
  376. }
  // The fneg operand may not be an instruction.
  377. if (!R0 || !R1)
  378. return nullptr;
  379. }
  380. if (I0->getOpcode() == Instruction::FNeg ||
  381. I1->getOpcode() == Instruction::FNeg) {
  382. Negs |= 2;
  383. Negs ^= 1;
  384. if (I0->getOpcode() == Instruction::FNeg) {
  385. FNegs.push_back(I0);
  386. I0 = dyn_cast<Instruction>(I0->getOperand(0));
  387. } else {
  388. FNegs.push_back(I1);
  389. I1 = dyn_cast<Instruction>(I1->getOperand(0));
  390. }
  391. if (!I0 || !I1)
  392. return nullptr;
  393. }
  394. ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
  // Find the operand shared by both multiplies; the remaining operands are
  // the "uncommon" ones.
  395. Instruction *CommonOperand;
  396. Instruction *UncommonRealOp;
  397. Instruction *UncommonImagOp;
  398. if (R0 == I0 || R0 == I1) {
  399. CommonOperand = R0;
  400. UncommonRealOp = R1;
  401. } else if (R1 == I0 || R1 == I1) {
  402. CommonOperand = R1;
  403. UncommonRealOp = R0;
  404. } else {
  405. LLVM_DEBUG(dbgs() << " - No equal operand\n");
  406. return nullptr;
  407. }
  408. UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  // For 90/270 the real multiply carries the imaginary component of the
  // uncommon input (e.g. 90: r = -ai * bi, i = ai * br), so the two uncommon
  // operands trade places.
  409. if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
  410. Rotation == ComplexDeinterleavingRotation::Rotation_270)
  411. std::swap(UncommonRealOp, UncommonImagOp);
  412. // Between identifyPartialMul and here we need to have found a complete valid
  413. // pair from the CommonOperand of each part.
  414. if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
  415. Rotation == ComplexDeinterleavingRotation::Rotation_180)
  416. PartialMatch.first = CommonOperand;
  417. else
  418. PartialMatch.second = CommonOperand;
  // If both parts wrote the same slot, the other slot is still null.
  419. if (!PartialMatch.first || !PartialMatch.second) {
  420. LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
  421. return nullptr;
  422. }
  // The completed common-operand pair is one complex input, and the uncommon
  // operands form the other.
  423. NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
  424. if (!CommonNode) {
  425. LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
  426. return nullptr;
  427. }
  428. NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
  429. if (!UncommonNode) {
  430. LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
  431. return nullptr;
  432. }
  433. NodePtr Node = prepareCompositeNode(
  434. ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  435. Node->Rotation = Rotation;
  436. Node->addOperand(CommonNode);
  437. Node->addOperand(UncommonNode);
  438. Node->InternalInstructions.append(FNegs);
  439. return submitCompositeNode(Node);
  440. }
  441. ComplexDeinterleavingGraph::NodePtr
  442. ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
  443. Instruction *Imag) {
  444. LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
  445. << "\n");
  446. // Determine rotation
  447. ComplexDeinterleavingRotation Rotation;
  448. if (Real->getOpcode() == Instruction::FAdd &&
  449. Imag->getOpcode() == Instruction::FAdd)
  450. Rotation = ComplexDeinterleavingRotation::Rotation_0;
  451. else if (Real->getOpcode() == Instruction::FSub &&
  452. Imag->getOpcode() == Instruction::FAdd)
  453. Rotation = ComplexDeinterleavingRotation::Rotation_90;
  454. else if (Real->getOpcode() == Instruction::FSub &&
  455. Imag->getOpcode() == Instruction::FSub)
  456. Rotation = ComplexDeinterleavingRotation::Rotation_180;
  457. else if (Real->getOpcode() == Instruction::FAdd &&
  458. Imag->getOpcode() == Instruction::FSub)
  459. Rotation = ComplexDeinterleavingRotation::Rotation_270;
  460. else {
  461. LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
  462. return nullptr;
  463. }
  464. if (!Real->getFastMathFlags().allowContract() ||
  465. !Imag->getFastMathFlags().allowContract()) {
  466. LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
  467. return nullptr;
  468. }
  469. Value *CR = Real->getOperand(0);
  470. Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
  471. if (!RealMulI)
  472. return nullptr;
  473. Value *CI = Imag->getOperand(0);
  474. Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
  475. if (!ImagMulI)
  476. return nullptr;
  477. if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
  478. LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
  479. return nullptr;
  480. }
  481. Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
  482. Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
  483. Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
  484. Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
  485. if (!R0 || !R1 || !I0 || !I1) {
  486. LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
  487. return nullptr;
  488. }
  489. Instruction *CommonOperand;
  490. Instruction *UncommonRealOp;
  491. Instruction *UncommonImagOp;
  492. if (R0 == I0 || R0 == I1) {
  493. CommonOperand = R0;
  494. UncommonRealOp = R1;
  495. } else if (R1 == I0 || R1 == I1) {
  496. CommonOperand = R1;
  497. UncommonRealOp = R0;
  498. } else {
  499. LLVM_DEBUG(dbgs() << " - No equal operand\n");
  500. return nullptr;
  501. }
  502. UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  503. if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
  504. Rotation == ComplexDeinterleavingRotation::Rotation_270)
  505. std::swap(UncommonRealOp, UncommonImagOp);
  506. std::pair<Instruction *, Instruction *> PartialMatch(
  507. (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
  508. Rotation == ComplexDeinterleavingRotation::Rotation_180)
  509. ? CommonOperand
  510. : nullptr,
  511. (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
  512. Rotation == ComplexDeinterleavingRotation::Rotation_270)
  513. ? CommonOperand
  514. : nullptr);
  515. NodePtr CNode = identifyNodeWithImplicitAdd(
  516. cast<Instruction>(CR), cast<Instruction>(CI), PartialMatch);
  517. if (!CNode) {
  518. LLVM_DEBUG(dbgs() << " - No cnode identified\n");
  519. return nullptr;
  520. }
  521. NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
  522. if (!UncommonRes) {
  523. LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
  524. return nullptr;
  525. }
  526. assert(PartialMatch.first && PartialMatch.second);
  527. NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
  528. if (!CommonRes) {
  529. LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
  530. return nullptr;
  531. }
  532. NodePtr Node = prepareCompositeNode(
  533. ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  534. Node->addInstruction(RealMulI);
  535. Node->addInstruction(ImagMulI);
  536. Node->Rotation = Rotation;
  537. Node->addOperand(CommonRes);
  538. Node->addOperand(UncommonRes);
  539. Node->addOperand(CNode);
  540. return submitCompositeNode(Node);
  541. }
  542. ComplexDeinterleavingGraph::NodePtr
  543. ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
  544. LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
  545. // Determine rotation
  546. ComplexDeinterleavingRotation Rotation;
  547. if ((Real->getOpcode() == Instruction::FSub &&
  548. Imag->getOpcode() == Instruction::FAdd) ||
  549. (Real->getOpcode() == Instruction::Sub &&
  550. Imag->getOpcode() == Instruction::Add))
  551. Rotation = ComplexDeinterleavingRotation::Rotation_90;
  552. else if ((Real->getOpcode() == Instruction::FAdd &&
  553. Imag->getOpcode() == Instruction::FSub) ||
  554. (Real->getOpcode() == Instruction::Add &&
  555. Imag->getOpcode() == Instruction::Sub))
  556. Rotation = ComplexDeinterleavingRotation::Rotation_270;
  557. else {
  558. LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
  559. return nullptr;
  560. }
  561. auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
  562. auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
  563. auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
  564. auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
  565. if (!AR || !AI || !BR || !BI) {
  566. LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
  567. return nullptr;
  568. }
  569. NodePtr ResA = identifyNode(AR, AI);
  570. if (!ResA) {
  571. LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
  572. return nullptr;
  573. }
  574. NodePtr ResB = identifyNode(BR, BI);
  575. if (!ResB) {
  576. LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
  577. return nullptr;
  578. }
  579. NodePtr Node =
  580. prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
  581. Node->Rotation = Rotation;
  582. Node->addOperand(ResA);
  583. Node->addOperand(ResB);
  584. return submitCompositeNode(Node);
  585. }
  586. static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
  587. unsigned OpcA = A->getOpcode();
  588. unsigned OpcB = B->getOpcode();
  589. return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
  590. (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
  591. (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
  592. (OpcA == Instruction::Add && OpcB == Instruction::Sub);
  593. }
  594. static bool isInstructionPairMul(Instruction *A, Instruction *B) {
  595. auto Pattern =
  596. m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
  597. return match(A, Pattern) && match(B, Pattern);
  598. }
  /// Identifies the composite node for the (\p Real, \p Imag) pair, reusing a
  /// node already built for the same pair.
  ///
  /// Forms are tried in this order:
  ///  - two deinterleaving shuffles of the same vector (the graph's leaves),
  ///    which become a Shuffle placeholder replaced by that vector;
  ///  - a partial complex multiply (identifyPartialMul), if the target
  ///    supports CMulPartial on the interleaved vector type;
  ///  - a complex add (identifyAdd), if the target supports CAdd.
  /// Returns nullptr when none of these match.
  599. ComplexDeinterleavingGraph::NodePtr
  600. ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
  601. LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
  602. if (NodePtr CN = getContainingComposite(Real, Imag)) {
  603. LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
  604. return CN;
  605. }
  606. auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
  607. auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
  608. if (RealShuffle && ImagShuffle) {
  // Both shuffles must read a single source: operand 1 must be undef or
  // zeroinitializer, and operand 0 must be the same value.
  609. Value *RealOp1 = RealShuffle->getOperand(1);
  610. if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
  611. LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
  612. return nullptr;
  613. }
  614. Value *ImagOp1 = ImagShuffle->getOperand(1);
  615. if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
  616. LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
  617. return nullptr;
  618. }
  619. Value *RealOp0 = RealShuffle->getOperand(0);
  620. Value *ImagOp0 = ImagShuffle->getOperand(0);
  621. if (RealOp0 != ImagOp0) {
  622. LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
  623. return nullptr;
  624. }
  // Relies on isDeinterleavingMask checking the stride across the whole mask.
  // The Last < NumElements check below only bounds the final index.
  625. ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
  626. ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
  627. if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
  628. LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
  629. return nullptr;
  630. }
  // Real components sit at even lanes and imaginary components at odd lanes.
  631. if (RealMask[0] != 0 || ImagMask[0] != 1) {
  632. LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
  633. return nullptr;
  634. }
  635. // Type checking, the shuffle type should be a vector type of the same
  636. // scalar type, but half the size
  637. auto CheckType = [&](ShuffleVectorInst *Shuffle) {
  638. Value *Op = Shuffle->getOperand(0);
  639. auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
  640. auto *OpTy = cast<FixedVectorType>(Op->getType());
  641. if (OpTy->getScalarType() != ShuffleTy->getScalarType())
  642. return false;
  643. if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
  644. return false;
  645. return true;
  646. };
  647. auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
  648. if (!CheckType(Shuffle))
  649. return false;
  650. ArrayRef<int> Mask = Shuffle->getShuffleMask();
  651. int Last = *Mask.rbegin();
  652. Value *Op = Shuffle->getOperand(0);
  653. auto *OpTy = cast<FixedVectorType>(Op->getType());
  654. int NumElements = OpTy->getNumElements();
  655. // Ensure that the deinterleaving shuffle only pulls from the first
  656. // shuffle operand.
  657. return Last < NumElements;
  658. };
  659. if (RealShuffle->getType() != ImagShuffle->getType()) {
  660. LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
  661. return nullptr;
  662. }
  663. if (!CheckDeinterleavingShuffle(RealShuffle)) {
  664. LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
  665. return nullptr;
  666. }
  667. if (!CheckDeinterleavingShuffle(ImagShuffle)) {
  668. LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
  669. return nullptr;
  670. }
  // Leaf node: its replacement is the original interleaved source vector, so
  // no new IR is emitted for it.
  671. NodePtr PlaceholderNode =
  672. prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle,
  673. RealShuffle, ImagShuffle);
  674. PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
  675. return submitCompositeNode(PlaceholderNode);
  676. }
  // A single shuffle paired with a non-shuffle cannot be matched.
  677. if (RealShuffle || ImagShuffle)
  678. return nullptr;
  // Target support is queried on the interleaved type, which has twice the
  // lanes of each component vector. NOTE(review): the cast assumes Real is a
  // fixed-width vector; presumably guaranteed by the fixed-length root shuffle.
  // Verify this if new entry points are added.
  679. auto *VTy = cast<FixedVectorType>(Real->getType());
  680. auto *NewVTy =
  681. FixedVectorType::get(VTy->getScalarType(), VTy->getNumElements() * 2);
  682. if (TL->isComplexDeinterleavingOperationSupported(
  683. ComplexDeinterleavingOperation::CMulPartial, NewVTy) &&
  684. isInstructionPairMul(Real, Imag)) {
  685. return identifyPartialMul(Real, Imag);
  686. }
  687. if (TL->isComplexDeinterleavingOperationSupported(
  688. ComplexDeinterleavingOperation::CAdd, NewVTy) &&
  689. isInstructionPairAdd(Real, Imag)) {
  690. return identifyAdd(Real, Imag);
  691. }
  692. return nullptr;
  693. }
  694. bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
  695. Instruction *Real;
  696. Instruction *Imag;
  697. if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
  698. return false;
  699. RootValue = RootI;
  700. AllInstructions.insert(RootI);
  701. RootNode = identifyNode(Real, Imag);
  702. LLVM_DEBUG({
  703. Function *F = RootI->getFunction();
  704. BasicBlock *B = RootI->getParent();
  705. dbgs() << "Complex deinterleaving graph for " << F->getName()
  706. << "::" << B->getName() << ".\n";
  707. dump(dbgs());
  708. dbgs() << "\n";
  709. });
  710. // Check all instructions have internal uses
  711. for (const auto &Node : CompositeNodes) {
  712. if (!Node->hasAllInternalUses(AllInstructions)) {
  713. LLVM_DEBUG(dbgs() << " - Invalid internal uses\n");
  714. return false;
  715. }
  716. }
  717. return RootNode != nullptr;
  718. }
  719. Value *ComplexDeinterleavingGraph::replaceNode(
  720. ComplexDeinterleavingGraph::RawNodePtr Node) {
  721. if (Node->ReplacementNode)
  722. return Node->ReplacementNode;
  723. Value *Input0 = replaceNode(Node->Operands[0]);
  724. Value *Input1 = replaceNode(Node->Operands[1]);
  725. Value *Accumulator =
  726. Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr;
  727. assert(Input0->getType() == Input1->getType() &&
  728. "Node inputs need to be of the same type");
  729. Node->ReplacementNode = TL->createComplexDeinterleavingIR(
  730. Node->Real, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
  731. assert(Node->ReplacementNode && "Target failed to create Intrinsic call.");
  732. NumComplexTransformations += 1;
  733. return Node->ReplacementNode;
  734. }
  735. void ComplexDeinterleavingGraph::replaceNodes() {
  736. Value *R = replaceNode(RootNode.get());
  737. assert(R && "Unable to find replacement for RootValue");
  738. RootValue->replaceAllUsesWith(R);
  739. }
  740. bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
  741. SmallPtrSet<Instruction *, 16> &AllInstructions) {
  742. if (Operation == ComplexDeinterleavingOperation::Shuffle)
  743. return true;
  744. for (auto *User : Real->users()) {
  745. if (!AllInstructions.contains(cast<Instruction>(User)))
  746. return false;
  747. }
  748. for (auto *User : Imag->users()) {
  749. if (!AllInstructions.contains(cast<Instruction>(User)))
  750. return false;
  751. }
  752. for (auto *I : InternalInstructions) {
  753. for (auto *User : I->users()) {
  754. if (!AllInstructions.contains(cast<Instruction>(User)))
  755. return false;
  756. }
  757. }
  758. return true;
  759. }