//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together, connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imag, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
// //===----------------------------------------------------------------------===// #include "llvm/CodeGen/ComplexDeinterleavingPass.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/InitializePasses.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Utils/Local.h" #include using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "complex-deinterleaving" STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); static cl::opt ComplexDeinterleavingEnabled( "enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden); /// Checks the given mask, and determines whether said mask is interleaving. /// /// To be interleaving, a mask must alternate between `i` and `i + (Length / /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a /// 4x vector interleaving mask would be <0, 2, 1, 3>). static bool isInterleavingMask(ArrayRef Mask); /// Checks the given mask, and determines whether said mask is deinterleaving. /// /// To be deinterleaving, a mask must increment in steps of 2, and either start /// with 0 or 1. /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or /// <1, 3, 5, 7>). 
static bool isDeinterleavingMask(ArrayRef Mask); namespace { class ComplexDeinterleavingLegacyPass : public FunctionPass { public: static char ID; ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) : FunctionPass(ID), TM(TM) { initializeComplexDeinterleavingLegacyPassPass( *PassRegistry::getPassRegistry()); } StringRef getPassName() const override { return "Complex Deinterleaving Pass"; } bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired(); AU.setPreservesCFG(); } private: const TargetMachine *TM; }; class ComplexDeinterleavingGraph; struct ComplexDeinterleavingCompositeNode { ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, Instruction *R, Instruction *I) : Operation(Op), Real(R), Imag(I) {} private: friend class ComplexDeinterleavingGraph; using NodePtr = std::shared_ptr; using RawNodePtr = ComplexDeinterleavingCompositeNode *; public: ComplexDeinterleavingOperation Operation; Instruction *Real; Instruction *Imag; // Instructions that should only exist within this node, there should be no // users of these instructions outside the node. An example of these would be // the multiply instructions of a partial multiply operation. 
SmallVector InternalInstructions; ComplexDeinterleavingRotation Rotation; SmallVector Operands; Value *ReplacementNode = nullptr; void addInstruction(Instruction *I) { InternalInstructions.push_back(I); } void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } bool hasAllInternalUses(SmallPtrSet &AllInstructions); void dump() { dump(dbgs()); } void dump(raw_ostream &OS) { auto PrintValue = [&](Value *V) { if (V) { OS << "\""; V->print(OS, true); OS << "\"\n"; } else OS << "nullptr\n"; }; auto PrintNodeRef = [&](RawNodePtr Ptr) { if (Ptr) OS << Ptr << "\n"; else OS << "nullptr\n"; }; OS << "- CompositeNode: " << this << "\n"; OS << " Real: "; PrintValue(Real); OS << " Imag: "; PrintValue(Imag); OS << " ReplacementNode: "; PrintValue(ReplacementNode); OS << " Operation: " << (int)Operation << "\n"; OS << " Rotation: " << ((int)Rotation * 90) << "\n"; OS << " Operands: \n"; for (const auto &Op : Operands) { OS << " - "; PrintNodeRef(Op); } OS << " InternalInstructions:\n"; for (const auto &I : InternalInstructions) { OS << " - \""; I->print(OS, true); OS << "\"\n"; } } }; class ComplexDeinterleavingGraph { public: using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {} private: const TargetLowering *TL; Instruction *RootValue; NodePtr RootNode; SmallVector CompositeNodes; SmallPtrSet AllInstructions; NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, Instruction *R, Instruction *I) { return std::make_shared(Operation, R, I); } NodePtr submitCompositeNode(NodePtr Node) { CompositeNodes.push_back(Node); AllInstructions.insert(Node->Real); AllInstructions.insert(Node->Imag); for (auto *I : Node->InternalInstructions) AllInstructions.insert(I); return Node; } NodePtr getContainingComposite(Value *R, Value *I) { for (const auto &CN : CompositeNodes) { if (CN->Real == R && CN->Imag == I) return CN; } 
return nullptr; } /// Identifies a complex partial multiply pattern and its rotation, based on /// the following patterns /// /// 0: r: cr + ar * br /// i: ci + ar * bi /// 90: r: cr - ai * bi /// i: ci + ai * br /// 180: r: cr - ar * br /// i: ci - ar * bi /// 270: r: cr + ai * bi /// i: ci - ai * br NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); /// Identify the other branch of a Partial Mul, taking the CommonOperandI that /// is partially known from identifyPartialMul, filling in the other half of /// the complex pair. NodePtr identifyNodeWithImplicitAdd( Instruction *I, Instruction *J, std::pair &CommonOperandI); /// Identifies a complex add pattern and its rotation, based on the following /// patterns. /// /// 90: r: ar - bi /// i: ai + br /// 270: r: ar + bi /// i: ai - br NodePtr identifyAdd(Instruction *Real, Instruction *Imag); NodePtr identifyNode(Instruction *I, Instruction *J); Value *replaceNode(RawNodePtr Node); public: void dump() { dump(dbgs()); } void dump(raw_ostream &OS) { for (const auto &Node : CompositeNodes) Node->dump(OS); } /// Returns false if the deinterleaving operation should be cancelled for the /// current graph. bool identifyNodes(Instruction *RootI); /// Perform the actual replacement of the underlying instruction graph. /// Returns false if the deinterleaving operation should be cancelled for the /// current graph. 
void replaceNodes(); }; class ComplexDeinterleaving { public: ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) : TL(tl), TLI(tli) {} bool runOnFunction(Function &F); private: bool evaluateBasicBlock(BasicBlock *B); const TargetLowering *TL = nullptr; const TargetLibraryInfo *TLI = nullptr; }; } // namespace char ComplexDeinterleavingLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, "Complex Deinterleaving", false, false) INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, "Complex Deinterleaving", false, false) PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, FunctionAnalysisManager &AM) { const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); auto &TLI = AM.getResult(F); if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve(); return PA; } FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { return new ComplexDeinterleavingLegacyPass(TM); } bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); auto TLI = getAnalysis().getTLI(F); return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); } bool ComplexDeinterleaving::runOnFunction(Function &F) { if (!ComplexDeinterleavingEnabled) { LLVM_DEBUG( dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); return false; } if (!TL->isComplexDeinterleavingSupported()) { LLVM_DEBUG( dbgs() << "Complex deinterleaving has been disabled, target does " "not support lowering of complex number operations.\n"); return false; } bool Changed = false; for (auto &B : F) Changed |= evaluateBasicBlock(&B); return Changed; } static bool isInterleavingMask(ArrayRef Mask) { // If the size is not even, it's not an interleaving mask if ((Mask.size() & 1)) return false; int HalfNumElements = Mask.size() / 2; for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 
int MaskIdx = Idx * 2; if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) return false; } return true; } static bool isDeinterleavingMask(ArrayRef Mask) { int Offset = Mask[0]; int HalfNumElements = Mask.size() / 2; for (int Idx = 1; Idx < HalfNumElements; ++Idx) { if (Mask[Idx] != (Idx * 2) + Offset) return false; } return true; } bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { bool Changed = false; SmallVector DeadInstrRoots; for (auto &I : *B) { auto *SVI = dyn_cast(&I); if (!SVI) continue; // Look for a shufflevector that takes separate vectors of the real and // imaginary components and recombines them into a single vector. if (!isInterleavingMask(SVI->getShuffleMask())) continue; ComplexDeinterleavingGraph Graph(TL); if (!Graph.identifyNodes(SVI)) continue; Graph.replaceNodes(); DeadInstrRoots.push_back(SVI); Changed = true; } for (const auto &I : DeadInstrRoots) { if (!I || I->getParent() == nullptr) continue; llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI); } return Changed; } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( Instruction *Real, Instruction *Imag, std::pair &PartialMatch) { LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag << "\n"); if (!Real->hasOneUse() || !Imag->hasOneUse()) { LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); return nullptr; } if (Real->getOpcode() != Instruction::FMul || Imag->getOpcode() != Instruction::FMul) { LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n"); return nullptr; } Instruction *R0 = dyn_cast(Real->getOperand(0)); Instruction *R1 = dyn_cast(Real->getOperand(1)); Instruction *I0 = dyn_cast(Imag->getOperand(0)); Instruction *I1 = dyn_cast(Imag->getOperand(1)); if (!R0 || !R1 || !I0 || !I1) { LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); return nullptr; } // A +/+ has a rotation of 0. 
If any of the operands are fneg, we flip the // rotations and use the operand. unsigned Negs = 0; SmallVector FNegs; if (R0->getOpcode() == Instruction::FNeg || R1->getOpcode() == Instruction::FNeg) { Negs |= 1; if (R0->getOpcode() == Instruction::FNeg) { FNegs.push_back(R0); R0 = dyn_cast(R0->getOperand(0)); } else { FNegs.push_back(R1); R1 = dyn_cast(R1->getOperand(0)); } if (!R0 || !R1) return nullptr; } if (I0->getOpcode() == Instruction::FNeg || I1->getOpcode() == Instruction::FNeg) { Negs |= 2; Negs ^= 1; if (I0->getOpcode() == Instruction::FNeg) { FNegs.push_back(I0); I0 = dyn_cast(I0->getOperand(0)); } else { FNegs.push_back(I1); I1 = dyn_cast(I1->getOperand(0)); } if (!I0 || !I1) return nullptr; } ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; Instruction *CommonOperand; Instruction *UncommonRealOp; Instruction *UncommonImagOp; if (R0 == I0 || R0 == I1) { CommonOperand = R0; UncommonRealOp = R1; } else if (R1 == I0 || R1 == I1) { CommonOperand = R1; UncommonRealOp = R0; } else { LLVM_DEBUG(dbgs() << " - No equal operand\n"); return nullptr; } UncommonImagOp = (CommonOperand == I0) ? I1 : I0; if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || Rotation == ComplexDeinterleavingRotation::Rotation_270) std::swap(UncommonRealOp, UncommonImagOp); // Between identifyPartialMul and here we need to have found a complete valid // pair from the CommonOperand of each part. 
if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180) PartialMatch.first = CommonOperand; else PartialMatch.second = CommonOperand; if (!PartialMatch.first || !PartialMatch.second) { LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); return nullptr; } NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); if (!CommonNode) { LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); return nullptr; } NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); if (!UncommonNode) { LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); return nullptr; } NodePtr Node = prepareCompositeNode( ComplexDeinterleavingOperation::CMulPartial, Real, Imag); Node->Rotation = Rotation; Node->addOperand(CommonNode); Node->addOperand(UncommonNode); Node->InternalInstructions.append(FNegs); return submitCompositeNode(Node); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, Instruction *Imag) { LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag << "\n"); // Determine rotation ComplexDeinterleavingRotation Rotation; if (Real->getOpcode() == Instruction::FAdd && Imag->getOpcode() == Instruction::FAdd) Rotation = ComplexDeinterleavingRotation::Rotation_0; else if (Real->getOpcode() == Instruction::FSub && Imag->getOpcode() == Instruction::FAdd) Rotation = ComplexDeinterleavingRotation::Rotation_90; else if (Real->getOpcode() == Instruction::FSub && Imag->getOpcode() == Instruction::FSub) Rotation = ComplexDeinterleavingRotation::Rotation_180; else if (Real->getOpcode() == Instruction::FAdd && Imag->getOpcode() == Instruction::FSub) Rotation = ComplexDeinterleavingRotation::Rotation_270; else { LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); return nullptr; } if (!Real->getFastMathFlags().allowContract() || !Imag->getFastMathFlags().allowContract()) { LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath 
flags.\n"); return nullptr; } Value *CR = Real->getOperand(0); Instruction *RealMulI = dyn_cast(Real->getOperand(1)); if (!RealMulI) return nullptr; Value *CI = Imag->getOperand(0); Instruction *ImagMulI = dyn_cast(Imag->getOperand(1)); if (!ImagMulI) return nullptr; if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); return nullptr; } Instruction *R0 = dyn_cast(RealMulI->getOperand(0)); Instruction *R1 = dyn_cast(RealMulI->getOperand(1)); Instruction *I0 = dyn_cast(ImagMulI->getOperand(0)); Instruction *I1 = dyn_cast(ImagMulI->getOperand(1)); if (!R0 || !R1 || !I0 || !I1) { LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); return nullptr; } Instruction *CommonOperand; Instruction *UncommonRealOp; Instruction *UncommonImagOp; if (R0 == I0 || R0 == I1) { CommonOperand = R0; UncommonRealOp = R1; } else if (R1 == I0 || R1 == I1) { CommonOperand = R1; UncommonRealOp = R0; } else { LLVM_DEBUG(dbgs() << " - No equal operand\n"); return nullptr; } UncommonImagOp = (CommonOperand == I0) ? I1 : I0; if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || Rotation == ComplexDeinterleavingRotation::Rotation_270) std::swap(UncommonRealOp, UncommonImagOp); std::pair PartialMatch( (Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180) ? CommonOperand : nullptr, (Rotation == ComplexDeinterleavingRotation::Rotation_90 || Rotation == ComplexDeinterleavingRotation::Rotation_270) ? 
CommonOperand : nullptr); NodePtr CNode = identifyNodeWithImplicitAdd( cast(CR), cast(CI), PartialMatch); if (!CNode) { LLVM_DEBUG(dbgs() << " - No cnode identified\n"); return nullptr; } NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); if (!UncommonRes) { LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); return nullptr; } assert(PartialMatch.first && PartialMatch.second); NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); if (!CommonRes) { LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); return nullptr; } NodePtr Node = prepareCompositeNode( ComplexDeinterleavingOperation::CMulPartial, Real, Imag); Node->addInstruction(RealMulI); Node->addInstruction(ImagMulI); Node->Rotation = Rotation; Node->addOperand(CommonRes); Node->addOperand(UncommonRes); Node->addOperand(CNode); return submitCompositeNode(Node); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); // Determine rotation ComplexDeinterleavingRotation Rotation; if ((Real->getOpcode() == Instruction::FSub && Imag->getOpcode() == Instruction::FAdd) || (Real->getOpcode() == Instruction::Sub && Imag->getOpcode() == Instruction::Add)) Rotation = ComplexDeinterleavingRotation::Rotation_90; else if ((Real->getOpcode() == Instruction::FAdd && Imag->getOpcode() == Instruction::FSub) || (Real->getOpcode() == Instruction::Add && Imag->getOpcode() == Instruction::Sub)) Rotation = ComplexDeinterleavingRotation::Rotation_270; else { LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n"); return nullptr; } auto *AR = dyn_cast(Real->getOperand(0)); auto *BI = dyn_cast(Real->getOperand(1)); auto *AI = dyn_cast(Imag->getOperand(0)); auto *BR = dyn_cast(Imag->getOperand(1)); if (!AR || !AI || !BR || !BI) { LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); return nullptr; } NodePtr ResA = identifyNode(AR, 
AI); if (!ResA) { LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); return nullptr; } NodePtr ResB = identifyNode(BR, BI); if (!ResB) { LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); return nullptr; } NodePtr Node = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); Node->Rotation = Rotation; Node->addOperand(ResA); Node->addOperand(ResB); return submitCompositeNode(Node); } static bool isInstructionPairAdd(Instruction *A, Instruction *B) { unsigned OpcA = A->getOpcode(); unsigned OpcB = B->getOpcode(); return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || (OpcA == Instruction::Sub && OpcB == Instruction::Add) || (OpcA == Instruction::Add && OpcB == Instruction::Sub); } static bool isInstructionPairMul(Instruction *A, Instruction *B) { auto Pattern = m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); return match(A, Pattern) && match(B, Pattern); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n"); if (NodePtr CN = getContainingComposite(Real, Imag)) { LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); return CN; } auto *RealShuffle = dyn_cast(Real); auto *ImagShuffle = dyn_cast(Imag); if (RealShuffle && ImagShuffle) { Value *RealOp1 = RealShuffle->getOperand(1); if (!isa(RealOp1) && !isa(RealOp1)) { LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); return nullptr; } Value *ImagOp1 = ImagShuffle->getOperand(1); if (!isa(ImagOp1) && !isa(ImagOp1)) { LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); return nullptr; } Value *RealOp0 = RealShuffle->getOperand(0); Value *ImagOp0 = ImagShuffle->getOperand(0); if (RealOp0 != ImagOp0) { LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); return nullptr; } ArrayRef RealMask = 
RealShuffle->getShuffleMask(); ArrayRef ImagMask = ImagShuffle->getShuffleMask(); if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); return nullptr; } if (RealMask[0] != 0 || ImagMask[0] != 1) { LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); return nullptr; } // Type checking, the shuffle type should be a vector type of the same // scalar type, but half the size auto CheckType = [&](ShuffleVectorInst *Shuffle) { Value *Op = Shuffle->getOperand(0); auto *ShuffleTy = cast(Shuffle->getType()); auto *OpTy = cast(Op->getType()); if (OpTy->getScalarType() != ShuffleTy->getScalarType()) return false; if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) return false; return true; }; auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { if (!CheckType(Shuffle)) return false; ArrayRef Mask = Shuffle->getShuffleMask(); int Last = *Mask.rbegin(); Value *Op = Shuffle->getOperand(0); auto *OpTy = cast(Op->getType()); int NumElements = OpTy->getNumElements(); // Ensure that the deinterleaving shuffle only pulls from the first // shuffle operand. 
return Last < NumElements; }; if (RealShuffle->getType() != ImagShuffle->getType()) { LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); return nullptr; } if (!CheckDeinterleavingShuffle(RealShuffle)) { LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); return nullptr; } if (!CheckDeinterleavingShuffle(ImagShuffle)) { LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); return nullptr; } NodePtr PlaceholderNode = prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle, RealShuffle, ImagShuffle); PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); return submitCompositeNode(PlaceholderNode); } if (RealShuffle || ImagShuffle) return nullptr; auto *VTy = cast(Real->getType()); auto *NewVTy = FixedVectorType::get(VTy->getScalarType(), VTy->getNumElements() * 2); if (TL->isComplexDeinterleavingOperationSupported( ComplexDeinterleavingOperation::CMulPartial, NewVTy) && isInstructionPairMul(Real, Imag)) { return identifyPartialMul(Real, Imag); } if (TL->isComplexDeinterleavingOperationSupported( ComplexDeinterleavingOperation::CAdd, NewVTy) && isInstructionPairAdd(Real, Imag)) { return identifyAdd(Real, Imag); } return nullptr; } bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { Instruction *Real; Instruction *Imag; if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) return false; RootValue = RootI; AllInstructions.insert(RootI); RootNode = identifyNode(Real, Imag); LLVM_DEBUG({ Function *F = RootI->getFunction(); BasicBlock *B = RootI->getParent(); dbgs() << "Complex deinterleaving graph for " << F->getName() << "::" << B->getName() << ".\n"; dump(dbgs()); dbgs() << "\n"; }); // Check all instructions have internal uses for (const auto &Node : CompositeNodes) { if (!Node->hasAllInternalUses(AllInstructions)) { LLVM_DEBUG(dbgs() << " - Invalid internal uses\n"); return false; } } return RootNode != nullptr; } Value *ComplexDeinterleavingGraph::replaceNode( 
ComplexDeinterleavingGraph::RawNodePtr Node) { if (Node->ReplacementNode) return Node->ReplacementNode; Value *Input0 = replaceNode(Node->Operands[0]); Value *Input1 = replaceNode(Node->Operands[1]); Value *Accumulator = Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr; assert(Input0->getType() == Input1->getType() && "Node inputs need to be of the same type"); Node->ReplacementNode = TL->createComplexDeinterleavingIR( Node->Real, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); assert(Node->ReplacementNode && "Target failed to create Intrinsic call."); NumComplexTransformations += 1; return Node->ReplacementNode; } void ComplexDeinterleavingGraph::replaceNodes() { Value *R = replaceNode(RootNode.get()); assert(R && "Unable to find replacement for RootValue"); RootValue->replaceAllUsesWith(R); } bool ComplexDeinterleavingCompositeNode::hasAllInternalUses( SmallPtrSet &AllInstructions) { if (Operation == ComplexDeinterleavingOperation::Shuffle) return true; for (auto *User : Real->users()) { if (!AllInstructions.contains(cast(User))) return false; } for (auto *User : Imag->users()) { if (!AllInstructions.contains(cast(User))) return false; } for (auto *I : InternalInstructions) { for (auto *User : I->users()) { if (!AllInstructions.contains(cast(User))) return false; } } return true; }