NVPTXLowerAggrCopies.cpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. //===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // \file
  10. // Lower aggregate copies and memset/memcpy/memmove intrinsics into loops when
  11. // the size is large or is not a compile-time constant.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "NVPTXLowerAggrCopies.h"
  15. #include "llvm/Analysis/TargetTransformInfo.h"
  16. #include "llvm/CodeGen/StackProtector.h"
  17. #include "llvm/IR/Constants.h"
  18. #include "llvm/IR/DataLayout.h"
  19. #include "llvm/IR/Function.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include "llvm/IR/Instructions.h"
  22. #include "llvm/IR/IntrinsicInst.h"
  23. #include "llvm/IR/Intrinsics.h"
  24. #include "llvm/IR/LLVMContext.h"
  25. #include "llvm/IR/Module.h"
  26. #include "llvm/Support/Debug.h"
  27. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  28. #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
  29. #define DEBUG_TYPE "nvptx"
  30. using namespace llvm;
  31. namespace {
  32. // actual analysis class, which is a functionpass
  33. struct NVPTXLowerAggrCopies : public FunctionPass {
  34. static char ID;
  35. NVPTXLowerAggrCopies() : FunctionPass(ID) {}
  36. void getAnalysisUsage(AnalysisUsage &AU) const override {
  37. AU.addPreserved<StackProtector>();
  38. AU.addRequired<TargetTransformInfoWrapperPass>();
  39. }
  40. bool runOnFunction(Function &F) override;
  41. static const unsigned MaxAggrCopySize = 128;
  42. StringRef getPassName() const override {
  43. return "Lower aggregate copies/intrinsics into loops";
  44. }
  45. };
  46. char NVPTXLowerAggrCopies::ID = 0;
  47. bool NVPTXLowerAggrCopies::runOnFunction(Function &F) {
  48. SmallVector<LoadInst *, 4> AggrLoads;
  49. SmallVector<MemIntrinsic *, 4> MemCalls;
  50. const DataLayout &DL = F.getParent()->getDataLayout();
  51. LLVMContext &Context = F.getParent()->getContext();
  52. const TargetTransformInfo &TTI =
  53. getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  54. // Collect all aggregate loads and mem* calls.
  55. for (BasicBlock &BB : F) {
  56. for (Instruction &I : BB) {
  57. if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
  58. if (!LI->hasOneUse())
  59. continue;
  60. if (DL.getTypeStoreSize(LI->getType()) < MaxAggrCopySize)
  61. continue;
  62. if (StoreInst *SI = dyn_cast<StoreInst>(LI->user_back())) {
  63. if (SI->getOperand(0) != LI)
  64. continue;
  65. AggrLoads.push_back(LI);
  66. }
  67. } else if (MemIntrinsic *IntrCall = dyn_cast<MemIntrinsic>(&I)) {
  68. // Convert intrinsic calls with variable size or with constant size
  69. // larger than the MaxAggrCopySize threshold.
  70. if (ConstantInt *LenCI = dyn_cast<ConstantInt>(IntrCall->getLength())) {
  71. if (LenCI->getZExtValue() >= MaxAggrCopySize) {
  72. MemCalls.push_back(IntrCall);
  73. }
  74. } else {
  75. MemCalls.push_back(IntrCall);
  76. }
  77. }
  78. }
  79. }
  80. if (AggrLoads.size() == 0 && MemCalls.size() == 0) {
  81. return false;
  82. }
  83. //
  84. // Do the transformation of an aggr load/copy/set to a loop
  85. //
  86. for (LoadInst *LI : AggrLoads) {
  87. auto *SI = cast<StoreInst>(*LI->user_begin());
  88. Value *SrcAddr = LI->getOperand(0);
  89. Value *DstAddr = SI->getOperand(1);
  90. unsigned NumLoads = DL.getTypeStoreSize(LI->getType());
  91. ConstantInt *CopyLen =
  92. ConstantInt::get(Type::getInt32Ty(Context), NumLoads);
  93. createMemCpyLoopKnownSize(/* ConvertedInst */ SI,
  94. /* SrcAddr */ SrcAddr, /* DstAddr */ DstAddr,
  95. /* CopyLen */ CopyLen,
  96. /* SrcAlign */ LI->getAlign(),
  97. /* DestAlign */ SI->getAlign(),
  98. /* SrcIsVolatile */ LI->isVolatile(),
  99. /* DstIsVolatile */ SI->isVolatile(), TTI);
  100. SI->eraseFromParent();
  101. LI->eraseFromParent();
  102. }
  103. // Transform mem* intrinsic calls.
  104. for (MemIntrinsic *MemCall : MemCalls) {
  105. if (MemCpyInst *Memcpy = dyn_cast<MemCpyInst>(MemCall)) {
  106. expandMemCpyAsLoop(Memcpy, TTI);
  107. } else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(MemCall)) {
  108. expandMemMoveAsLoop(Memmove);
  109. } else if (MemSetInst *Memset = dyn_cast<MemSetInst>(MemCall)) {
  110. expandMemSetAsLoop(Memset);
  111. }
  112. MemCall->eraseFromParent();
  113. }
  114. return true;
  115. }
  116. } // namespace
  117. namespace llvm {
  118. void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);
  119. }
  120. INITIALIZE_PASS(NVPTXLowerAggrCopies, "nvptx-lower-aggr-copies",
  121. "Lower aggregate copies, and llvm.mem* intrinsics into loops",
  122. false, false)
  123. FunctionPass *llvm::createLowerAggrCopies() {
  124. return new NVPTXLowerAggrCopies();
  125. }