ReplaceWithVeclib.cpp 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. //=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls ===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Replaces calls to LLVM vector intrinsics (i.e., calls to LLVM intrinsics
  10. // with vector operands) with matching calls to functions from a vector
  11. // library (e.g., libmvec, SVML) according to TargetLibraryInfo.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "llvm/CodeGen/ReplaceWithVeclib.h"
  15. #include "llvm/ADT/STLExtras.h"
  16. #include "llvm/ADT/Statistic.h"
  17. #include "llvm/Analysis/DemandedBits.h"
  18. #include "llvm/Analysis/GlobalsModRef.h"
  19. #include "llvm/Analysis/OptimizationRemarkEmitter.h"
  20. #include "llvm/Analysis/TargetLibraryInfo.h"
  21. #include "llvm/Analysis/VectorUtils.h"
  22. #include "llvm/CodeGen/Passes.h"
  23. #include "llvm/IR/IRBuilder.h"
  24. #include "llvm/IR/InstIterator.h"
  25. #include "llvm/IR/IntrinsicInst.h"
  26. #include "llvm/Transforms/Utils/ModuleUtils.h"
  27. using namespace llvm;
  28. #define DEBUG_TYPE "replace-with-veclib"
  29. STATISTIC(NumCallsReplaced,
  30. "Number of calls to intrinsics that have been replaced.");
  31. STATISTIC(NumTLIFuncDeclAdded,
  32. "Number of vector library function declarations added.");
  33. STATISTIC(NumFuncUsedAdded,
  34. "Number of functions added to `llvm.compiler.used`");
  35. static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
  36. Module *M = CI.getModule();
  37. Function *OldFunc = CI.getCalledFunction();
  38. // Check if the vector library function is already declared in this module,
  39. // otherwise insert it.
  40. Function *TLIFunc = M->getFunction(TLIName);
  41. if (!TLIFunc) {
  42. TLIFunc = Function::Create(OldFunc->getFunctionType(),
  43. Function::ExternalLinkage, TLIName, *M);
  44. TLIFunc->copyAttributesFrom(OldFunc);
  45. LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
  46. << TLIName << "` of type `" << *(TLIFunc->getType())
  47. << "` to module.\n");
  48. ++NumTLIFuncDeclAdded;
  49. // Add the freshly created function to llvm.compiler.used,
  50. // similar to as it is done in InjectTLIMappings
  51. appendToCompilerUsed(*M, {TLIFunc});
  52. LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
  53. << "` to `@llvm.compiler.used`.\n");
  54. ++NumFuncUsedAdded;
  55. }
  56. // Replace the call to the vector intrinsic with a call
  57. // to the corresponding function from the vector library.
  58. IRBuilder<> IRBuilder(&CI);
  59. SmallVector<Value *> Args(CI.args());
  60. // Preserve the operand bundles.
  61. SmallVector<OperandBundleDef, 1> OpBundles;
  62. CI.getOperandBundlesAsDefs(OpBundles);
  63. CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
  64. assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
  65. "Expecting function types to be identical");
  66. CI.replaceAllUsesWith(Replacement);
  67. if (isa<FPMathOperator>(Replacement)) {
  68. // Preserve fast math flags for FP math.
  69. Replacement->copyFastMathFlags(&CI);
  70. }
  71. LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
  72. << OldFunc->getName() << "` with call to `" << TLIName
  73. << "`.\n");
  74. ++NumCallsReplaced;
  75. return true;
  76. }
  77. static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
  78. CallInst &CI) {
  79. if (!CI.getCalledFunction()) {
  80. return false;
  81. }
  82. auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
  83. if (IntrinsicID == Intrinsic::not_intrinsic) {
  84. // Replacement is only performed for intrinsic functions
  85. return false;
  86. }
  87. // Convert vector arguments to scalar type and check that
  88. // all vector operands have identical vector width.
  89. ElementCount VF = ElementCount::getFixed(0);
  90. SmallVector<Type *> ScalarTypes;
  91. for (auto Arg : enumerate(CI.args())) {
  92. auto *ArgType = Arg.value()->getType();
  93. // Vector calls to intrinsics can still have
  94. // scalar operands for specific arguments.
  95. if (hasVectorInstrinsicScalarOpd(IntrinsicID, Arg.index())) {
  96. ScalarTypes.push_back(ArgType);
  97. } else {
  98. // The argument in this place should be a vector if
  99. // this is a call to a vector intrinsic.
  100. auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
  101. if (!VectorArgTy) {
  102. // The argument is not a vector, do not perform
  103. // the replacement.
  104. return false;
  105. }
  106. ElementCount NumElements = VectorArgTy->getElementCount();
  107. if (NumElements.isScalable()) {
  108. // The current implementation does not support
  109. // scalable vectors.
  110. return false;
  111. }
  112. if (VF.isNonZero() && VF != NumElements) {
  113. // The different arguments differ in vector size.
  114. return false;
  115. } else {
  116. VF = NumElements;
  117. }
  118. ScalarTypes.push_back(VectorArgTy->getElementType());
  119. }
  120. }
  121. // Try to reconstruct the name for the scalar version of this
  122. // intrinsic using the intrinsic ID and the argument types
  123. // converted to scalar above.
  124. std::string ScalarName;
  125. if (Intrinsic::isOverloaded(IntrinsicID)) {
  126. ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
  127. } else {
  128. ScalarName = Intrinsic::getName(IntrinsicID).str();
  129. }
  130. if (!TLI.isFunctionVectorizable(ScalarName)) {
  131. // The TargetLibraryInfo does not contain a vectorized version of
  132. // the scalar function.
  133. return false;
  134. }
  135. // Try to find the mapping for the scalar version of this intrinsic
  136. // and the exact vector width of the call operands in the
  137. // TargetLibraryInfo.
  138. const std::string TLIName =
  139. std::string(TLI.getVectorizedFunction(ScalarName, VF));
  140. LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
  141. << ScalarName << "` and vector width " << VF << ".\n");
  142. if (!TLIName.empty()) {
  143. // Found the correct mapping in the TargetLibraryInfo,
  144. // replace the call to the intrinsic with a call to
  145. // the vector library function.
  146. LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
  147. << "`.\n");
  148. return replaceWithTLIFunction(CI, TLIName);
  149. }
  150. return false;
  151. }
  152. static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
  153. bool Changed = false;
  154. SmallVector<CallInst *> ReplacedCalls;
  155. for (auto &I : instructions(F)) {
  156. if (auto *CI = dyn_cast<CallInst>(&I)) {
  157. if (replaceWithCallToVeclib(TLI, *CI)) {
  158. ReplacedCalls.push_back(CI);
  159. Changed = true;
  160. }
  161. }
  162. }
  163. // Erase the calls to the intrinsics that have been replaced
  164. // with calls to the vector library.
  165. for (auto *CI : ReplacedCalls) {
  166. CI->eraseFromParent();
  167. }
  168. return Changed;
  169. }
  170. ////////////////////////////////////////////////////////////////////////////////
  171. // New pass manager implementation.
  172. ////////////////////////////////////////////////////////////////////////////////
  173. PreservedAnalyses ReplaceWithVeclib::run(Function &F,
  174. FunctionAnalysisManager &AM) {
  175. const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  176. auto Changed = runImpl(TLI, F);
  177. if (Changed) {
  178. PreservedAnalyses PA;
  179. PA.preserveSet<CFGAnalyses>();
  180. PA.preserve<TargetLibraryAnalysis>();
  181. PA.preserve<ScalarEvolutionAnalysis>();
  182. PA.preserve<LoopAccessAnalysis>();
  183. PA.preserve<DemandedBitsAnalysis>();
  184. PA.preserve<OptimizationRemarkEmitterAnalysis>();
  185. return PA;
  186. } else {
  187. // The pass did not replace any calls, hence it preserves all analyses.
  188. return PreservedAnalyses::all();
  189. }
  190. }
  191. ////////////////////////////////////////////////////////////////////////////////
  192. // Legacy PM Implementation.
  193. ////////////////////////////////////////////////////////////////////////////////
  194. bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
  195. const TargetLibraryInfo &TLI =
  196. getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  197. return runImpl(TLI, F);
  198. }
  199. void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
  200. AU.setPreservesCFG();
  201. AU.addRequired<TargetLibraryInfoWrapperPass>();
  202. AU.addPreserved<TargetLibraryInfoWrapperPass>();
  203. AU.addPreserved<ScalarEvolutionWrapperPass>();
  204. AU.addPreserved<AAResultsWrapperPass>();
  205. AU.addPreserved<LoopAccessLegacyAnalysis>();
  206. AU.addPreserved<DemandedBitsWrapperPass>();
  207. AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
  208. AU.addPreserved<GlobalsAAWrapperPass>();
  209. }
  210. ////////////////////////////////////////////////////////////////////////////////
  211. // Legacy Pass manager initialization
  212. ////////////////////////////////////////////////////////////////////////////////
  213. char ReplaceWithVeclibLegacy::ID = 0;
  214. INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
  215. "Replace intrinsics with calls to vector library", false,
  216. false)
  217. INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
  218. INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
  219. "Replace intrinsics with calls to vector library", false,
  220. false)
  221. FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
  222. return new ReplaceWithVeclibLegacy();
  223. }