ManagedMemoryRewrite.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Take a module and rewrite:
  10. // 1. `malloc` -> `polly_mallocManaged`
  11. // 2. `free` -> `polly_freeManaged`
  // 3. zero-initialized global arrays -> global pointers that are initialized
  //    by a constructor which calls `polly_mallocManaged`.
  // 4. (optionally) captured `alloca`s -> `polly_mallocManaged` allocations.
  15. //
  16. //===----------------------------------------------------------------------===//
  17. #include "polly/CodeGen/IRBuilder.h"
  18. #include "polly/CodeGen/PPCGCodeGeneration.h"
  19. #include "polly/DependenceInfo.h"
  20. #include "polly/LinkAllPasses.h"
  21. #include "polly/Options.h"
  22. #include "polly/ScopDetection.h"
  23. #include "llvm/ADT/SmallSet.h"
  24. #include "llvm/Analysis/CaptureTracking.h"
  25. #include "llvm/InitializePasses.h"
  26. #include "llvm/Transforms/Utils/ModuleUtils.h"
  27. using namespace llvm;
  28. using namespace polly;
  29. static cl::opt<bool> RewriteAllocas(
  30. "polly-acc-rewrite-allocas",
  31. cl::desc(
  32. "Ask the managed memory rewriter to also rewrite alloca instructions"),
  33. cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
  34. static cl::opt<bool> IgnoreLinkageForGlobals(
  35. "polly-acc-rewrite-ignore-linkage-for-globals",
  36. cl::desc(
  37. "By default, we only rewrite globals with internal linkage. This flag "
  38. "enables rewriting of globals regardless of linkage"),
  39. cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
  40. #define DEBUG_TYPE "polly-acc-rewrite-managed-memory"
  41. namespace {
  42. static llvm::Function *getOrCreatePollyMallocManaged(Module &M) {
  43. const char *Name = "polly_mallocManaged";
  44. Function *F = M.getFunction(Name);
  45. // If F is not available, declare it.
  46. if (!F) {
  47. GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
  48. PollyIRBuilder Builder(M.getContext());
  49. // TODO: How do I get `size_t`? I assume from DataLayout?
  50. FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(),
  51. {Builder.getInt64Ty()}, false);
  52. F = Function::Create(Ty, Linkage, Name, &M);
  53. }
  54. return F;
  55. }
  56. static llvm::Function *getOrCreatePollyFreeManaged(Module &M) {
  57. const char *Name = "polly_freeManaged";
  58. Function *F = M.getFunction(Name);
  59. // If F is not available, declare it.
  60. if (!F) {
  61. GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
  62. PollyIRBuilder Builder(M.getContext());
  63. // TODO: How do I get `size_t`? I assume from DataLayout?
  64. FunctionType *Ty =
  65. FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false);
  66. F = Function::Create(Ty, Linkage, Name, &M);
  67. }
  68. return F;
  69. }
  70. // Expand a constant expression `Cur`, which is used at instruction `Parent`
  71. // at index `index`.
  72. // Since a constant expression can expand to multiple instructions, store all
  73. // the expands into a set called `Expands`.
  74. // Note that this goes inorder on the constant expression tree.
  75. // A * ((B * D) + C)
  76. // will be processed with first A, then B * D, then B, then D, and then C.
  77. // Though ConstantExprs are not treated as "trees" but as DAGs, since you can
  78. // have something like this:
  79. // *
  80. // / \
  81. // \ /
  82. // (D)
  83. //
  84. // For the purposes of this expansion, we expand the two occurences of D
  85. // separately. Therefore, we expand the DAG into the tree:
  86. // *
  87. // / \
  88. // D D
  89. // TODO: We don't _have_to do this, but this is the simplest solution.
  90. // We can write a solution that keeps track of which constants have been
  91. // already expanded.
  92. static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder,
  93. Instruction *Parent, int index,
  94. SmallPtrSet<Instruction *, 4> &Expands) {
  95. assert(Cur && "invalid constant expression passed");
  96. Instruction *I = Cur->getAsInstruction();
  97. assert(I && "unable to convert ConstantExpr to Instruction");
  98. LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur
  99. << ") in Instruction: (" << *I << ")\n";);
  100. // Invalidate `Cur` so that no one after this point uses `Cur`. Rather,
  101. // they should mutate `I`.
  102. Cur = nullptr;
  103. Expands.insert(I);
  104. Parent->setOperand(index, I);
  105. // The things that `Parent` uses (its operands) should be created
  106. // before `Parent`.
  107. Builder.SetInsertPoint(Parent);
  108. Builder.Insert(I);
  109. for (unsigned i = 0; i < I->getNumOperands(); i++) {
  110. Value *Op = I->getOperand(i);
  111. assert(isa<Constant>(Op) && "constant must have a constant operand");
  112. if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op))
  113. expandConstantExpr(CExprOp, Builder, I, i, Expands);
  114. }
  115. }
  116. // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite
  117. // `ConstantExpr`s that are used in the `Inst`.
  118. // Note that `replaceAllUsesWith` is insufficient for this purpose because it
  119. // does not rewrite values in `ConstantExpr`s.
  120. static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal,
  121. PollyIRBuilder &Builder) {
  122. // This contains a set of instructions in which OldVal must be replaced.
  123. // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s
  124. // from `Inst`s arguments.
  125. // We need to go through this process because `replaceAllUsesWith` does not
  126. // actually edit `ConstantExpr`s.
  127. SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst};
  128. // Expand all `ConstantExpr`s and place it in `InstsToVisit`.
  129. for (unsigned i = 0; i < Inst->getNumOperands(); i++) {
  130. Value *Operand = Inst->getOperand(i);
  131. if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand))
  132. expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit);
  133. }
  134. // Now visit each instruction and use `replaceUsesOfWith`. We know that
  135. // will work because `I` cannot have any `ConstantExpr` within it.
  136. for (Instruction *I : InstsToVisit)
  137. I->replaceUsesOfWith(OldVal, NewVal);
  138. }
  139. // Given a value `Current`, return all Instructions that may contain `Current`
  140. // in an expression.
  141. // We need this auxiliary function, because if we have a
  142. // `Constant` that is a user of `V`, we need to recurse into the
  143. // `Constant`s uses to gather the root instruciton.
  144. static void getInstructionUsersOfValue(Value *V,
  145. SmallVector<Instruction *, 4> &Owners) {
  146. if (auto *I = dyn_cast<Instruction>(V)) {
  147. Owners.push_back(I);
  148. } else {
  149. // Anything that is a `User` must be a constant or an instruction.
  150. auto *C = cast<Constant>(V);
  151. for (Use &CUse : C->uses())
  152. getInstructionUsersOfValue(CUse.getUser(), Owners);
  153. }
  154. }
  155. static void
  156. replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array,
  157. SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) {
  158. // We only want arrays.
  159. ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType());
  160. if (!ArrayTy)
  161. return;
  162. Type *ElemTy = ArrayTy->getElementType();
  163. PointerType *ElemPtrTy = ElemTy->getPointerTo();
  164. // We only wish to replace arrays that are visible in the module they
  165. // inhabit. Otherwise, our type edit from [T] to T* would be illegal across
  166. // modules.
  167. const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() ||
  168. Array.hasInternalLinkage() ||
  169. IgnoreLinkageForGlobals;
  170. if (!OnlyVisibleInsideModule) {
  171. LLVM_DEBUG(
  172. dbgs() << "Not rewriting (" << Array
  173. << ") to managed memory "
  174. "because it could be visible externally. To force rewrite, "
  175. "use -polly-acc-rewrite-ignore-linkage-for-globals.\n");
  176. return;
  177. }
  178. if (!Array.hasInitializer() ||
  179. !isa<ConstantAggregateZero>(Array.getInitializer())) {
  180. LLVM_DEBUG(dbgs() << "Not rewriting (" << Array
  181. << ") to managed memory "
  182. "because it has an initializer which is "
  183. "not a zeroinitializer.\n");
  184. return;
  185. }
  186. // At this point, we have committed to replacing this array.
  187. ReplacedGlobals.insert(&Array);
  188. std::string NewName = Array.getName().str();
  189. NewName += ".toptr";
  190. GlobalVariable *ReplacementToArr =
  191. cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy));
  192. ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy));
  193. Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
  194. std::string FnName = Array.getName().str();
  195. FnName += ".constructor";
  196. PollyIRBuilder Builder(M.getContext());
  197. FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
  198. const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
  199. Function *F = Function::Create(Ty, Linkage, FnName, &M);
  200. BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F);
  201. Builder.SetInsertPoint(Start);
  202. const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy);
  203. Value *ArraySize = Builder.getInt64(ArraySizeInt);
  204. ArraySize->setName("array.size");
  205. Value *AllocatedMemRaw =
  206. Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw");
  207. Value *AllocatedMemTyped =
  208. Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed");
  209. Builder.CreateStore(AllocatedMemTyped, ReplacementToArr);
  210. Builder.CreateRetVoid();
  211. const int Priority = 0;
  212. appendToGlobalCtors(M, F, Priority, ReplacementToArr);
  213. SmallVector<Instruction *, 4> ArrayUserInstructions;
  214. // Get all instructions that use array. We need to do this weird thing
  215. // because `Constant`s that contain this array neeed to be expanded into
  216. // instructions so that we can replace their parameters. `Constant`s cannot
  217. // be edited easily, so we choose to convert all `Constant`s to
  218. // `Instruction`s and handle all of the uses of `Array` uniformly.
  219. for (Use &ArrayUse : Array.uses())
  220. getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions);
  221. for (Instruction *UserOfArrayInst : ArrayUserInstructions) {
  222. Builder.SetInsertPoint(UserOfArrayInst);
  223. // <ty>** -> <ty>*
  224. Value *ArrPtrLoaded =
  225. Builder.CreateLoad(ElemPtrTy, ReplacementToArr, "arrptr.load");
  226. // <ty>* -> [ty]*
  227. Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast(
  228. ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast");
  229. rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder);
  230. }
  231. }
  232. // We return all `allocas` that may need to be converted to a call to
  233. // cudaMallocManaged.
  234. static void getAllocasToBeManaged(Function &F,
  235. SmallSet<AllocaInst *, 4> &Allocas) {
  236. for (BasicBlock &BB : F) {
  237. for (Instruction &I : BB) {
  238. auto *Alloca = dyn_cast<AllocaInst>(&I);
  239. if (!Alloca)
  240. continue;
  241. LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: ");
  242. if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false,
  243. /* StoreCaptures */ true)) {
  244. Allocas.insert(Alloca);
  245. LLVM_DEBUG(dbgs() << "YES (captured).\n");
  246. } else {
  247. LLVM_DEBUG(dbgs() << "NO (not captured).\n");
  248. }
  249. }
  250. }
  251. }
  252. static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca,
  253. const DataLayout &DL) {
  254. LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n");
  255. Module *M = Alloca->getModule();
  256. assert(M && "Alloca does not have a module");
  257. PollyIRBuilder Builder(M->getContext());
  258. Builder.SetInsertPoint(Alloca);
  259. Function *MallocManagedFn =
  260. getOrCreatePollyMallocManaged(*Alloca->getModule());
  261. const uint64_t Size =
  262. DL.getTypeAllocSize(Alloca->getType()->getElementType());
  263. Value *SizeVal = Builder.getInt64(Size);
  264. Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal});
  265. Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType());
  266. Function *F = Alloca->getFunction();
  267. assert(F && "Alloca has invalid function");
  268. Bitcasted->takeName(Alloca);
  269. Alloca->replaceAllUsesWith(Bitcasted);
  270. Alloca->eraseFromParent();
  271. for (BasicBlock &BB : *F) {
  272. ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator());
  273. if (!Return)
  274. continue;
  275. Builder.SetInsertPoint(Return);
  276. Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M);
  277. Builder.CreateCall(FreeManagedFn, {RawManagedMem});
  278. }
  279. }
  280. // Replace all uses of `Old` with `New`, even inside `ConstantExpr`.
  281. //
  282. // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function
  283. // actually does replace it in `ConstantExpr`. The caveat is that if there is
  284. // a use that is *outside* a function (say, at global declarations), we fail.
  285. // So, this is meant to be used on values which we know will only be used
  286. // within functions.
  287. //
  288. // This process works by looking through the uses of `Old`. If it finds a
  289. // `ConstantExpr`, it recursively looks for the owning instruction.
  290. // Then, it expands all the `ConstantExpr` to instructions and replaces
  291. // `Old` with `New` in the expanded instructions.
  292. static void replaceAllUsesAndConstantUses(Value *Old, Value *New,
  293. PollyIRBuilder &Builder) {
  294. SmallVector<Instruction *, 4> UserInstructions;
  295. // Get all instructions that use array. We need to do this weird thing
  296. // because `Constant`s that contain this array neeed to be expanded into
  297. // instructions so that we can replace their parameters. `Constant`s cannot
  298. // be edited easily, so we choose to convert all `Constant`s to
  299. // `Instruction`s and handle all of the uses of `Array` uniformly.
  300. for (Use &ArrayUse : Old->uses())
  301. getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions);
  302. for (Instruction *I : UserInstructions)
  303. rewriteOldValToNew(I, Old, New, Builder);
  304. }
  305. class ManagedMemoryRewritePass : public ModulePass {
  306. public:
  307. static char ID;
  308. GPUArch Architecture;
  309. GPURuntime Runtime;
  310. ManagedMemoryRewritePass() : ModulePass(ID) {}
  311. bool runOnModule(Module &M) override {
  312. const DataLayout &DL = M.getDataLayout();
  313. Function *Malloc = M.getFunction("malloc");
  314. if (Malloc) {
  315. PollyIRBuilder Builder(M.getContext());
  316. Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
  317. assert(PollyMallocManaged && "unable to create polly_mallocManaged");
  318. replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder);
  319. Malloc->eraseFromParent();
  320. }
  321. Function *Free = M.getFunction("free");
  322. if (Free) {
  323. PollyIRBuilder Builder(M.getContext());
  324. Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M);
  325. assert(PollyFreeManaged && "unable to create polly_freeManaged");
  326. replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder);
  327. Free->eraseFromParent();
  328. }
  329. SmallPtrSet<GlobalVariable *, 4> GlobalsToErase;
  330. for (GlobalVariable &Global : M.globals())
  331. replaceGlobalArray(M, DL, Global, GlobalsToErase);
  332. for (GlobalVariable *G : GlobalsToErase)
  333. G->eraseFromParent();
  334. // Rewrite allocas to cudaMallocs if we are asked to do so.
  335. if (RewriteAllocas) {
  336. SmallSet<AllocaInst *, 4> AllocasToBeManaged;
  337. for (Function &F : M.functions())
  338. getAllocasToBeManaged(F, AllocasToBeManaged);
  339. for (AllocaInst *Alloca : AllocasToBeManaged)
  340. rewriteAllocaAsManagedMemory(Alloca, DL);
  341. }
  342. return true;
  343. }
  344. };
  345. } // namespace
  346. char ManagedMemoryRewritePass::ID = 42;
  347. Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch,
  348. GPURuntime Runtime) {
  349. ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass();
  350. pass->Runtime = Runtime;
  351. pass->Architecture = Arch;
  352. return pass;
  353. }
  354. INITIALIZE_PASS_BEGIN(
  355. ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
  356. "Polly - Rewrite all allocations in heap & data section to managed memory",
  357. false, false)
  358. INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration);
  359. INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
  360. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
  361. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
  362. INITIALIZE_PASS_DEPENDENCY(RegionInfoPass);
  363. INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
  364. INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass);
  365. INITIALIZE_PASS_END(
  366. ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
  367. "Polly - Rewrite all allocations in heap & data section to managed memory",
  368. false, false)