ReplaceWithVeclib.cpp 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. //=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Replaces calls to LLVM vector intrinsics (i.e., calls to LLVM intrinsics
  10. // with vector operands) with matching calls to functions from a vector
  11. // library (e.g., libmvec, SVML) according to TargetLibraryInfo.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "llvm/CodeGen/ReplaceWithVeclib.h"
  15. #include "llvm/ADT/STLExtras.h"
  16. #include "llvm/ADT/Statistic.h"
  17. #include "llvm/Analysis/DemandedBits.h"
  18. #include "llvm/Analysis/GlobalsModRef.h"
  19. #include "llvm/Analysis/OptimizationRemarkEmitter.h"
  20. #include "llvm/Analysis/TargetLibraryInfo.h"
  21. #include "llvm/Analysis/VectorUtils.h"
  22. #include "llvm/CodeGen/Passes.h"
  23. #include "llvm/IR/IRBuilder.h"
  24. #include "llvm/IR/InstIterator.h"
  25. #include "llvm/Transforms/Utils/ModuleUtils.h"
  26. using namespace llvm;
  // Tag printed as a prefix in this file's debug messages and handed to the
  // legacy pass registration macros as the pass argument.
  27. #define DEBUG_TYPE "replace-with-veclib"
  // Incremented once for every intrinsic call that replaceWithTLIFunction
  // replaces.
  28. STATISTIC(NumCallsReplaced,
  29. "Number of calls to intrinsics that have been replaced.");
  // Incremented each time a vector library function declaration is created in
  // the module because none with that name existed yet.
  30. STATISTIC(NumTLIFuncDeclAdded,
  31. "Number of vector library function declarations added.");
  // Incremented each time a freshly created declaration is appended to
  // `llvm.compiler.used`.
  32. STATISTIC(NumFuncUsedAdded,
  33. "Number of functions added to `llvm.compiler.used`");
  34. static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
  35. Module *M = CI.getModule();
  36. Function *OldFunc = CI.getCalledFunction();
  37. // Check if the vector library function is already declared in this module,
  38. // otherwise insert it.
  39. Function *TLIFunc = M->getFunction(TLIName);
  40. if (!TLIFunc) {
  41. TLIFunc = Function::Create(OldFunc->getFunctionType(),
  42. Function::ExternalLinkage, TLIName, *M);
  43. TLIFunc->copyAttributesFrom(OldFunc);
  44. LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
  45. << TLIName << "` of type `" << *(TLIFunc->getType())
  46. << "` to module.\n");
  47. ++NumTLIFuncDeclAdded;
  48. // Add the freshly created function to llvm.compiler.used,
  49. // similar to as it is done in InjectTLIMappings
  50. appendToCompilerUsed(*M, {TLIFunc});
  51. LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
  52. << "` to `@llvm.compiler.used`.\n");
  53. ++NumFuncUsedAdded;
  54. }
  55. // Replace the call to the vector intrinsic with a call
  56. // to the corresponding function from the vector library.
  57. IRBuilder<> IRBuilder(&CI);
  58. SmallVector<Value *> Args(CI.args());
  59. // Preserve the operand bundles.
  60. SmallVector<OperandBundleDef, 1> OpBundles;
  61. CI.getOperandBundlesAsDefs(OpBundles);
  62. CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
  63. assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
  64. "Expecting function types to be identical");
  65. CI.replaceAllUsesWith(Replacement);
  66. if (isa<FPMathOperator>(Replacement)) {
  67. // Preserve fast math flags for FP math.
  68. Replacement->copyFastMathFlags(&CI);
  69. }
  70. LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
  71. << OldFunc->getName() << "` with call to `" << TLIName
  72. << "`.\n");
  73. ++NumCallsReplaced;
  74. return true;
  75. }
  76. static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
  77. CallInst &CI) {
  78. if (!CI.getCalledFunction()) {
  79. return false;
  80. }
  81. auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
  82. if (IntrinsicID == Intrinsic::not_intrinsic) {
  83. // Replacement is only performed for intrinsic functions
  84. return false;
  85. }
  86. // Convert vector arguments to scalar type and check that
  87. // all vector operands have identical vector width.
  88. ElementCount VF = ElementCount::getFixed(0);
  89. SmallVector<Type *> ScalarTypes;
  90. for (auto Arg : enumerate(CI.args())) {
  91. auto *ArgType = Arg.value()->getType();
  92. // Vector calls to intrinsics can still have
  93. // scalar operands for specific arguments.
  94. if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
  95. ScalarTypes.push_back(ArgType);
  96. } else {
  97. // The argument in this place should be a vector if
  98. // this is a call to a vector intrinsic.
  99. auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
  100. if (!VectorArgTy) {
  101. // The argument is not a vector, do not perform
  102. // the replacement.
  103. return false;
  104. }
  105. ElementCount NumElements = VectorArgTy->getElementCount();
  106. if (NumElements.isScalable()) {
  107. // The current implementation does not support
  108. // scalable vectors.
  109. return false;
  110. }
  111. if (VF.isNonZero() && VF != NumElements) {
  112. // The different arguments differ in vector size.
  113. return false;
  114. } else {
  115. VF = NumElements;
  116. }
  117. ScalarTypes.push_back(VectorArgTy->getElementType());
  118. }
  119. }
  120. // Try to reconstruct the name for the scalar version of this
  121. // intrinsic using the intrinsic ID and the argument types
  122. // converted to scalar above.
  123. std::string ScalarName;
  124. if (Intrinsic::isOverloaded(IntrinsicID)) {
  125. ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
  126. } else {
  127. ScalarName = Intrinsic::getName(IntrinsicID).str();
  128. }
  129. if (!TLI.isFunctionVectorizable(ScalarName)) {
  130. // The TargetLibraryInfo does not contain a vectorized version of
  131. // the scalar function.
  132. return false;
  133. }
  134. // Try to find the mapping for the scalar version of this intrinsic
  135. // and the exact vector width of the call operands in the
  136. // TargetLibraryInfo.
  137. const std::string TLIName =
  138. std::string(TLI.getVectorizedFunction(ScalarName, VF));
  139. LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
  140. << ScalarName << "` and vector width " << VF << ".\n");
  141. if (!TLIName.empty()) {
  142. // Found the correct mapping in the TargetLibraryInfo,
  143. // replace the call to the intrinsic with a call to
  144. // the vector library function.
  145. LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
  146. << "`.\n");
  147. return replaceWithTLIFunction(CI, TLIName);
  148. }
  149. return false;
  150. }
  151. static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
  152. bool Changed = false;
  153. SmallVector<CallInst *> ReplacedCalls;
  154. for (auto &I : instructions(F)) {
  155. if (auto *CI = dyn_cast<CallInst>(&I)) {
  156. if (replaceWithCallToVeclib(TLI, *CI)) {
  157. ReplacedCalls.push_back(CI);
  158. Changed = true;
  159. }
  160. }
  161. }
  162. // Erase the calls to the intrinsics that have been replaced
  163. // with calls to the vector library.
  164. for (auto *CI : ReplacedCalls) {
  165. CI->eraseFromParent();
  166. }
  167. return Changed;
  168. }
  169. ////////////////////////////////////////////////////////////////////////////////
  170. // New pass manager implementation.
  171. ////////////////////////////////////////////////////////////////////////////////
  172. PreservedAnalyses ReplaceWithVeclib::run(Function &F,
  173. FunctionAnalysisManager &AM) {
  174. const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  175. auto Changed = runImpl(TLI, F);
  176. if (Changed) {
  177. PreservedAnalyses PA;
  178. PA.preserveSet<CFGAnalyses>();
  179. PA.preserve<TargetLibraryAnalysis>();
  180. PA.preserve<ScalarEvolutionAnalysis>();
  181. PA.preserve<LoopAccessAnalysis>();
  182. PA.preserve<DemandedBitsAnalysis>();
  183. PA.preserve<OptimizationRemarkEmitterAnalysis>();
  184. return PA;
  185. } else {
  186. // The pass did not replace any calls, hence it preserves all analyses.
  187. return PreservedAnalyses::all();
  188. }
  189. }
  190. ////////////////////////////////////////////////////////////////////////////////
  191. // Legacy PM Implementation.
  192. ////////////////////////////////////////////////////////////////////////////////
  193. bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
  194. const TargetLibraryInfo &TLI =
  195. getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  196. return runImpl(TLI, F);
  197. }
  198. void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
  199. AU.setPreservesCFG();
  200. AU.addRequired<TargetLibraryInfoWrapperPass>();
  201. AU.addPreserved<TargetLibraryInfoWrapperPass>();
  202. AU.addPreserved<ScalarEvolutionWrapperPass>();
  203. AU.addPreserved<AAResultsWrapperPass>();
  204. AU.addPreserved<LoopAccessLegacyAnalysis>();
  205. AU.addPreserved<DemandedBitsWrapperPass>();
  206. AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
  207. AU.addPreserved<GlobalsAAWrapperPass>();
  208. }
  209. ////////////////////////////////////////////////////////////////////////////////
  210. // Legacy Pass manager initialization
  211. ////////////////////////////////////////////////////////////////////////////////
  // Definition of the legacy pass's static ID member.
  212. char ReplaceWithVeclibLegacy::ID = 0;
  // Registers the legacy pass under the DEBUG_TYPE argument with the
  // description string below and declares its dependency on
  // TargetLibraryInfoWrapperPass.
  // NOTE(review): the two trailing `false` arguments are presumably the
  // CFG-only and is-analysis flags of the registration macros - confirm.
  213. INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
  214. "Replace intrinsics with calls to vector library", false,
  215. false)
  216. INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
  217. INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
  218. "Replace intrinsics with calls to vector library", false,
  219. false)
  /// Allocates and returns a new ReplaceWithVeclibLegacy pass object.
  220. FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
  221. return new ReplaceWithVeclibLegacy();
  222. }