123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- //===- UniformityAnalysis.cpp ---------------------------------------------===//
- //
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
- // See https://llvm.org/LICENSE.txt for license information.
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- //
- //===----------------------------------------------------------------------===//
- #include "llvm/Analysis/UniformityAnalysis.h"
- #include "llvm/ADT/GenericUniformityImpl.h"
- #include "llvm/Analysis/CycleAnalysis.h"
- #include "llvm/Analysis/TargetTransformInfo.h"
- #include "llvm/IR/Constants.h"
- #include "llvm/IR/Dominators.h"
- #include "llvm/IR/InstIterator.h"
- #include "llvm/IR/Instructions.h"
- #include "llvm/InitializePasses.h"
- using namespace llvm;
- template <>
- bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
- const Instruction &I) const {
- return isDivergent((const Value *)&I);
- }
- template <>
- bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
- const Instruction &Instr, bool AllDefsDivergent) {
- return markDivergent(&Instr);
- }
- template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
- for (auto &I : instructions(F)) {
- if (TTI->isSourceOfDivergence(&I)) {
- assert(!I.isTerminator());
- markDivergent(I);
- } else if (TTI->isAlwaysUniform(&I)) {
- addUniformOverride(I);
- }
- }
- for (auto &Arg : F.args()) {
- if (TTI->isSourceOfDivergence(&Arg)) {
- markDivergent(&Arg);
- }
- }
- }
- template <>
- void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
- const Value *V) {
- for (const auto *User : V->users()) {
- const auto *UserInstr = dyn_cast<const Instruction>(User);
- if (!UserInstr)
- continue;
- if (isAlwaysUniform(*UserInstr))
- continue;
- if (markDivergent(*UserInstr)) {
- Worklist.push_back(UserInstr);
- }
- }
- }
- template <>
- void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
- const Instruction &Instr) {
- assert(!isAlwaysUniform(Instr));
- if (Instr.isTerminator())
- return;
- pushUsers(cast<Value>(&Instr));
- }
- template <>
- bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
- const Instruction &I, const Cycle &DefCycle) const {
- if (isAlwaysUniform(I))
- return false;
- for (const Use &U : I.operands()) {
- if (auto *I = dyn_cast<Instruction>(&U)) {
- if (DefCycle.contains(I->getParent()))
- return true;
- }
- }
- return false;
- }
- // This ensures explicit instantiation of
- // GenericUniformityAnalysisImpl::ImplDeleter::operator()
- template class llvm::GenericUniformityInfo<SSAContext>;
- template struct llvm::GenericUniformityAnalysisImplDeleter<
- llvm::GenericUniformityAnalysisImpl<SSAContext>>;
- //===----------------------------------------------------------------------===//
- // UniformityInfoAnalysis and related pass implementations
- //===----------------------------------------------------------------------===//
- llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
- FunctionAnalysisManager &FAM) {
- auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
- auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
- auto &CI = FAM.getResult<CycleAnalysis>(F);
- return UniformityInfo{F, DT, CI, &TTI};
- }
- AnalysisKey UniformityInfoAnalysis::Key;
- UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
- : OS(OS) {}
- PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
- FunctionAnalysisManager &AM) {
- OS << "UniformityInfo for function '" << F.getName() << "':\n";
- AM.getResult<UniformityInfoAnalysis>(F).print(OS);
- return PreservedAnalyses::all();
- }
- //===----------------------------------------------------------------------===//
- // UniformityInfoWrapperPass Implementation
- //===----------------------------------------------------------------------===//
- char UniformityInfoWrapperPass::ID = 0;
- UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
- initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
- }
- INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniforminfo",
- "Uniform Info Analysis", true, true)
- INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
- INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
- INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniforminfo",
- "Uniform Info Analysis", true, true)
- void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
- AU.setPreservesAll();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<CycleInfoWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- }
- bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
- auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
- auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &targetTransformInfo =
- getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- m_function = &F;
- m_uniformityInfo =
- UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo};
- return false;
- }
- void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
- OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
- }
- void UniformityInfoWrapperPass::releaseMemory() {
- m_uniformityInfo = UniformityInfo{};
- m_function = nullptr;
- }
|