//===- ConvergenceUtils.cpp -----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/Analysis/UniformityAnalysis.h" #include "llvm/ADT/GenericUniformityImpl.h" #include "llvm/Analysis/CycleAnalysis.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/InitializePasses.h" using namespace llvm; template <> bool llvm::GenericUniformityAnalysisImpl::hasDivergentDefs( const Instruction &I) const { return isDivergent((const Value *)&I); } template <> bool llvm::GenericUniformityAnalysisImpl::markDefsDivergent( const Instruction &Instr, bool AllDefsDivergent) { return markDivergent(&Instr); } template <> void llvm::GenericUniformityAnalysisImpl::initialize() { for (auto &I : instructions(F)) { if (TTI->isSourceOfDivergence(&I)) { assert(!I.isTerminator()); markDivergent(I); } else if (TTI->isAlwaysUniform(&I)) { addUniformOverride(I); } } for (auto &Arg : F.args()) { if (TTI->isSourceOfDivergence(&Arg)) { markDivergent(&Arg); } } } template <> void llvm::GenericUniformityAnalysisImpl::pushUsers( const Value *V) { for (const auto *User : V->users()) { const auto *UserInstr = dyn_cast(User); if (!UserInstr) continue; if (isAlwaysUniform(*UserInstr)) continue; if (markDivergent(*UserInstr)) { Worklist.push_back(UserInstr); } } } template <> void llvm::GenericUniformityAnalysisImpl::pushUsers( const Instruction &Instr) { assert(!isAlwaysUniform(Instr)); if (Instr.isTerminator()) return; pushUsers(cast(&Instr)); } template <> bool llvm::GenericUniformityAnalysisImpl::usesValueFromCycle( const Instruction &I, const Cycle &DefCycle) const { if (isAlwaysUniform(I)) return false; for (const Use &U : I.operands()) { if (auto *I = dyn_cast(&U)) { if (DefCycle.contains(I->getParent())) return true; } } return false; } // This ensures explicit instantiation of // GenericUniformityAnalysisImpl::ImplDeleter::operator() template class llvm::GenericUniformityInfo; template struct llvm::GenericUniformityAnalysisImplDeleter< llvm::GenericUniformityAnalysisImpl>; //===----------------------------------------------------------------------===// // UniformityInfoAnalysis and related pass implementations //===----------------------------------------------------------------------===// llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F, FunctionAnalysisManager &FAM) { auto &DT = FAM.getResult(F); auto &TTI = FAM.getResult(F); auto &CI = FAM.getResult(F); return UniformityInfo{F, DT, CI, &TTI}; } AnalysisKey UniformityInfoAnalysis::Key; UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS) : OS(OS) {} PreservedAnalyses UniformityInfoPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { OS << "UniformityInfo for function '" << F.getName() << "':\n"; AM.getResult(F).print(OS); return PreservedAnalyses::all(); } //===----------------------------------------------------------------------===// // UniformityInfoWrapperPass Implementation //===----------------------------------------------------------------------===// char UniformityInfoWrapperPass::ID = 0; UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) { initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry()); } INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniforminfo", "Uniform Info Analysis", true, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniforminfo", "Uniform Info Analysis", true, true) void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired(); AU.addRequired(); AU.addRequired(); } bool UniformityInfoWrapperPass::runOnFunction(Function &F) { auto &cycleInfo = getAnalysis().getResult(); auto &domTree = getAnalysis().getDomTree(); auto &targetTransformInfo = getAnalysis().getTTI(F); m_function = &F; m_uniformityInfo = UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo}; return false; } void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const { OS << "UniformityInfo for function '" << m_function->getName() << "':\n"; } void UniformityInfoWrapperPass::releaseMemory() { m_uniformityInfo = UniformityInfo{}; m_function = nullptr; }