AssumptionCache.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. //===- AssumptionCache.cpp - Cache finding @llvm.assume calls -------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file contains a pass that keeps track of @llvm.assume and
  10. // @llvm.experimental.guard intrinsics in the functions of a module.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/Analysis/AssumptionCache.h"
  14. #include "llvm/ADT/STLExtras.h"
  15. #include "llvm/ADT/SmallPtrSet.h"
  16. #include "llvm/ADT/SmallVector.h"
  17. #include "llvm/Analysis/AssumeBundleQueries.h"
  18. #include "llvm/Analysis/TargetTransformInfo.h"
  19. #include "llvm/IR/BasicBlock.h"
  20. #include "llvm/IR/Function.h"
  21. #include "llvm/IR/InstrTypes.h"
  22. #include "llvm/IR/Instruction.h"
  23. #include "llvm/IR/Instructions.h"
  24. #include "llvm/IR/PassManager.h"
  25. #include "llvm/IR/PatternMatch.h"
  26. #include "llvm/InitializePasses.h"
  27. #include "llvm/Pass.h"
  28. #include "llvm/Support/Casting.h"
  29. #include "llvm/Support/CommandLine.h"
  30. #include "llvm/Support/ErrorHandling.h"
  31. #include "llvm/Support/raw_ostream.h"
  32. #include <cassert>
  33. #include <utility>
  34. using namespace llvm;
  35. using namespace llvm::PatternMatch;
  36. static cl::opt<bool>
  37. VerifyAssumptionCache("verify-assumption-cache", cl::Hidden,
  38. cl::desc("Enable verification of assumption cache"),
  39. cl::init(false));
  40. SmallVector<AssumptionCache::ResultElem, 1> &
  41. AssumptionCache::getOrInsertAffectedValues(Value *V) {
  42. // Try using find_as first to avoid creating extra value handles just for the
  43. // purpose of doing the lookup.
  44. auto AVI = AffectedValues.find_as(V);
  45. if (AVI != AffectedValues.end())
  46. return AVI->second;
  47. auto AVIP = AffectedValues.insert(
  48. {AffectedValueCallbackVH(V, this), SmallVector<ResultElem, 1>()});
  49. return AVIP.first->second;
  50. }
  51. static void
  52. findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
  53. SmallVectorImpl<AssumptionCache::ResultElem> &Affected) {
  54. // Note: This code must be kept in-sync with the code in
  55. // computeKnownBitsFromAssume in ValueTracking.
  56. auto AddAffected = [&Affected](Value *V, unsigned Idx =
  57. AssumptionCache::ExprResultIdx) {
  58. if (isa<Argument>(V)) {
  59. Affected.push_back({V, Idx});
  60. } else if (auto *I = dyn_cast<Instruction>(V)) {
  61. Affected.push_back({I, Idx});
  62. // Peek through unary operators to find the source of the condition.
  63. Value *Op;
  64. if (match(I, m_BitCast(m_Value(Op))) ||
  65. match(I, m_PtrToInt(m_Value(Op))) || match(I, m_Not(m_Value(Op)))) {
  66. if (isa<Instruction>(Op) || isa<Argument>(Op))
  67. Affected.push_back({Op, Idx});
  68. }
  69. }
  70. };
  71. for (unsigned Idx = 0; Idx != CI->getNumOperandBundles(); Idx++) {
  72. if (CI->getOperandBundleAt(Idx).Inputs.size() > ABA_WasOn &&
  73. CI->getOperandBundleAt(Idx).getTagName() != IgnoreBundleTag)
  74. AddAffected(CI->getOperandBundleAt(Idx).Inputs[ABA_WasOn], Idx);
  75. }
  76. Value *Cond = CI->getArgOperand(0), *A, *B;
  77. AddAffected(Cond);
  78. CmpInst::Predicate Pred;
  79. if (match(Cond, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
  80. AddAffected(A);
  81. AddAffected(B);
  82. if (Pred == ICmpInst::ICMP_EQ) {
  83. // For equality comparisons, we handle the case of bit inversion.
  84. auto AddAffectedFromEq = [&AddAffected](Value *V) {
  85. Value *A;
  86. if (match(V, m_Not(m_Value(A)))) {
  87. AddAffected(A);
  88. V = A;
  89. }
  90. Value *B;
  91. // (A & B) or (A | B) or (A ^ B).
  92. if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) {
  93. AddAffected(A);
  94. AddAffected(B);
  95. // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
  96. } else if (match(V, m_Shift(m_Value(A), m_ConstantInt()))) {
  97. AddAffected(A);
  98. }
  99. };
  100. AddAffectedFromEq(A);
  101. AddAffectedFromEq(B);
  102. } else if (Pred == ICmpInst::ICMP_NE) {
  103. Value *X, *Y;
  104. // Handle (a & b != 0). If a/b is a power of 2 we can use this
  105. // information.
  106. if (match(A, m_And(m_Value(X), m_Value(Y))) && match(B, m_Zero())) {
  107. AddAffected(X);
  108. AddAffected(Y);
  109. }
  110. } else if (Pred == ICmpInst::ICMP_ULT) {
  111. Value *X;
  112. // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4,
  113. // and recognized by LVI at least.
  114. if (match(A, m_Add(m_Value(X), m_ConstantInt())) &&
  115. match(B, m_ConstantInt()))
  116. AddAffected(X);
  117. }
  118. }
  119. if (TTI) {
  120. const Value *Ptr;
  121. unsigned AS;
  122. std::tie(Ptr, AS) = TTI->getPredicatedAddrSpace(Cond);
  123. if (Ptr)
  124. AddAffected(const_cast<Value *>(Ptr->stripInBoundsOffsets()));
  125. }
  126. }
  127. void AssumptionCache::updateAffectedValues(CondGuardInst *CI) {
  128. SmallVector<AssumptionCache::ResultElem, 16> Affected;
  129. findAffectedValues(CI, TTI, Affected);
  130. for (auto &AV : Affected) {
  131. auto &AVV = getOrInsertAffectedValues(AV.Assume);
  132. if (llvm::none_of(AVV, [&](ResultElem &Elem) {
  133. return Elem.Assume == CI && Elem.Index == AV.Index;
  134. }))
  135. AVV.push_back({CI, AV.Index});
  136. }
  137. }
  138. void AssumptionCache::unregisterAssumption(CondGuardInst *CI) {
  139. SmallVector<AssumptionCache::ResultElem, 16> Affected;
  140. findAffectedValues(CI, TTI, Affected);
  141. for (auto &AV : Affected) {
  142. auto AVI = AffectedValues.find_as(AV.Assume);
  143. if (AVI == AffectedValues.end())
  144. continue;
  145. bool Found = false;
  146. bool HasNonnull = false;
  147. for (ResultElem &Elem : AVI->second) {
  148. if (Elem.Assume == CI) {
  149. Found = true;
  150. Elem.Assume = nullptr;
  151. }
  152. HasNonnull |= !!Elem.Assume;
  153. if (HasNonnull && Found)
  154. break;
  155. }
  156. assert(Found && "already unregistered or incorrect cache state");
  157. if (!HasNonnull)
  158. AffectedValues.erase(AVI);
  159. }
  160. erase_value(AssumeHandles, CI);
  161. }
  162. void AssumptionCache::AffectedValueCallbackVH::deleted() {
  163. AC->AffectedValues.erase(getValPtr());
  164. // 'this' now dangles!
  165. }
  166. void AssumptionCache::transferAffectedValuesInCache(Value *OV, Value *NV) {
  167. auto &NAVV = getOrInsertAffectedValues(NV);
  168. auto AVI = AffectedValues.find(OV);
  169. if (AVI == AffectedValues.end())
  170. return;
  171. for (auto &A : AVI->second)
  172. if (!llvm::is_contained(NAVV, A))
  173. NAVV.push_back(A);
  174. AffectedValues.erase(OV);
  175. }
  176. void AssumptionCache::AffectedValueCallbackVH::allUsesReplacedWith(Value *NV) {
  177. if (!isa<Instruction>(NV) && !isa<Argument>(NV))
  178. return;
  179. // Any assumptions that affected this value now affect the new value.
  180. AC->transferAffectedValuesInCache(getValPtr(), NV);
  181. // 'this' now might dangle! If the AffectedValues map was resized to add an
  182. // entry for NV then this object might have been destroyed in favor of some
  183. // copy in the grown map.
  184. }
  185. void AssumptionCache::scanFunction() {
  186. assert(!Scanned && "Tried to scan the function twice!");
  187. assert(AssumeHandles.empty() && "Already have assumes when scanning!");
  188. // Go through all instructions in all blocks, add all calls to @llvm.assume
  189. // to this cache.
  190. for (BasicBlock &B : F)
  191. for (Instruction &I : B)
  192. if (isa<CondGuardInst>(&I))
  193. AssumeHandles.push_back({&I, ExprResultIdx});
  194. // Mark the scan as complete.
  195. Scanned = true;
  196. // Update affected values.
  197. for (auto &A : AssumeHandles)
  198. updateAffectedValues(cast<CondGuardInst>(A));
  199. }
  200. void AssumptionCache::registerAssumption(CondGuardInst *CI) {
  201. // If we haven't scanned the function yet, just drop this assumption. It will
  202. // be found when we scan later.
  203. if (!Scanned)
  204. return;
  205. AssumeHandles.push_back({CI, ExprResultIdx});
  206. #ifndef NDEBUG
  207. assert(CI->getParent() &&
  208. "Cannot a register CondGuardInst not in a basic block");
  209. assert(&F == CI->getParent()->getParent() &&
  210. "Cannot a register CondGuardInst not in this function");
  211. // We expect the number of assumptions to be small, so in an asserts build
  212. // check that we don't accumulate duplicates and that all assumptions point
  213. // to the same function.
  214. SmallPtrSet<Value *, 16> AssumptionSet;
  215. for (auto &VH : AssumeHandles) {
  216. if (!VH)
  217. continue;
  218. assert(&F == cast<Instruction>(VH)->getParent()->getParent() &&
  219. "Cached assumption not inside this function!");
  220. assert(isa<CondGuardInst>(VH) &&
  221. "Cached something other than CondGuardInst!");
  222. assert(AssumptionSet.insert(VH).second &&
  223. "Cache contains multiple copies of a call!");
  224. }
  225. #endif
  226. updateAffectedValues(CI);
  227. }
  228. AssumptionCache AssumptionAnalysis::run(Function &F,
  229. FunctionAnalysisManager &FAM) {
  230. auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
  231. return AssumptionCache(F, &TTI);
  232. }
  233. AnalysisKey AssumptionAnalysis::Key;
  234. PreservedAnalyses AssumptionPrinterPass::run(Function &F,
  235. FunctionAnalysisManager &AM) {
  236. AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
  237. OS << "Cached assumptions for function: " << F.getName() << "\n";
  238. for (auto &VH : AC.assumptions())
  239. if (VH)
  240. OS << " " << *cast<CallInst>(VH)->getArgOperand(0) << "\n";
  241. return PreservedAnalyses::all();
  242. }
  243. void AssumptionCacheTracker::FunctionCallbackVH::deleted() {
  244. auto I = ACT->AssumptionCaches.find_as(cast<Function>(getValPtr()));
  245. if (I != ACT->AssumptionCaches.end())
  246. ACT->AssumptionCaches.erase(I);
  247. // 'this' now dangles!
  248. }
  249. AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) {
  250. // We probe the function map twice to try and avoid creating a value handle
  251. // around the function in common cases. This makes insertion a bit slower,
  252. // but if we have to insert we're going to scan the whole function so that
  253. // shouldn't matter.
  254. auto I = AssumptionCaches.find_as(&F);
  255. if (I != AssumptionCaches.end())
  256. return *I->second;
  257. auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
  258. auto *TTI = TTIWP ? &TTIWP->getTTI(F) : nullptr;
  259. // Ok, build a new cache by scanning the function, insert it and the value
  260. // handle into our map, and return the newly populated cache.
  261. auto IP = AssumptionCaches.insert(std::make_pair(
  262. FunctionCallbackVH(&F, this), std::make_unique<AssumptionCache>(F, TTI)));
  263. assert(IP.second && "Scanning function already in the map?");
  264. return *IP.first->second;
  265. }
  266. AssumptionCache *AssumptionCacheTracker::lookupAssumptionCache(Function &F) {
  267. auto I = AssumptionCaches.find_as(&F);
  268. if (I != AssumptionCaches.end())
  269. return I->second.get();
  270. return nullptr;
  271. }
  272. void AssumptionCacheTracker::verifyAnalysis() const {
  273. // FIXME: In the long term the verifier should not be controllable with a
  274. // flag. We should either fix all passes to correctly update the assumption
  275. // cache and enable the verifier unconditionally or somehow arrange for the
  276. // assumption list to be updated automatically by passes.
  277. if (!VerifyAssumptionCache)
  278. return;
  279. SmallPtrSet<const CallInst *, 4> AssumptionSet;
  280. for (const auto &I : AssumptionCaches) {
  281. for (auto &VH : I.second->assumptions())
  282. if (VH)
  283. AssumptionSet.insert(cast<CallInst>(VH));
  284. for (const BasicBlock &B : cast<Function>(*I.first))
  285. for (const Instruction &II : B)
  286. if (match(&II, m_Intrinsic<Intrinsic::assume>()) &&
  287. !AssumptionSet.count(cast<CallInst>(&II)))
  288. report_fatal_error("Assumption in scanned function not in cache");
  289. }
  290. }
  291. AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) {
  292. initializeAssumptionCacheTrackerPass(*PassRegistry::getPassRegistry());
  293. }
  294. AssumptionCacheTracker::~AssumptionCacheTracker() = default;
  295. char AssumptionCacheTracker::ID = 0;
  296. INITIALIZE_PASS(AssumptionCacheTracker, "assumption-cache-tracker",
  297. "Assumption Cache Tracker", false, true)