123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428 |
- //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===//
- //
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
- // See https://llvm.org/LICENSE.txt for license information.
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- //
- //===----------------------------------------------------------------------===//
- //
- // Take a module and rewrite:
- // 1. `malloc` -> `polly_mallocManaged`
- // 2. `free` -> `polly_freeManaged`
- // 3. global arrays with initializers -> global arrays that are initialized
- // with a constructor call to
- // `polly_mallocManaged`.
- //
- //===----------------------------------------------------------------------===//
- #include "polly/CodeGen/IRBuilder.h"
- #include "polly/CodeGen/PPCGCodeGeneration.h"
- #include "polly/DependenceInfo.h"
- #include "polly/LinkAllPasses.h"
- #include "polly/Options.h"
- #include "polly/ScopDetection.h"
- #include "llvm/ADT/SmallSet.h"
- #include "llvm/Analysis/CaptureTracking.h"
- #include "llvm/InitializePasses.h"
- #include "llvm/Transforms/Utils/ModuleUtils.h"
- using namespace llvm;
- using namespace polly;
- static cl::opt<bool> RewriteAllocas(
- "polly-acc-rewrite-allocas",
- cl::desc(
- "Ask the managed memory rewriter to also rewrite alloca instructions"),
- cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
- static cl::opt<bool> IgnoreLinkageForGlobals(
- "polly-acc-rewrite-ignore-linkage-for-globals",
- cl::desc(
- "By default, we only rewrite globals with internal linkage. This flag "
- "enables rewriting of globals regardless of linkage"),
- cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
- #define DEBUG_TYPE "polly-acc-rewrite-managed-memory"
- namespace {
- static llvm::Function *getOrCreatePollyMallocManaged(Module &M) {
- const char *Name = "polly_mallocManaged";
- Function *F = M.getFunction(Name);
- // If F is not available, declare it.
- if (!F) {
- GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
- PollyIRBuilder Builder(M.getContext());
- // TODO: How do I get `size_t`? I assume from DataLayout?
- FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(),
- {Builder.getInt64Ty()}, false);
- F = Function::Create(Ty, Linkage, Name, &M);
- }
- return F;
- }
- static llvm::Function *getOrCreatePollyFreeManaged(Module &M) {
- const char *Name = "polly_freeManaged";
- Function *F = M.getFunction(Name);
- // If F is not available, declare it.
- if (!F) {
- GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
- PollyIRBuilder Builder(M.getContext());
- // TODO: How do I get `size_t`? I assume from DataLayout?
- FunctionType *Ty =
- FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false);
- F = Function::Create(Ty, Linkage, Name, &M);
- }
- return F;
- }
- // Expand a constant expression `Cur`, which is used at instruction `Parent`
- // at index `index`.
- // Since a constant expression can expand to multiple instructions, store all
- // the expands into a set called `Expands`.
- // Note that this goes inorder on the constant expression tree.
- // A * ((B * D) + C)
- // will be processed with first A, then B * D, then B, then D, and then C.
- // Though ConstantExprs are not treated as "trees" but as DAGs, since you can
- // have something like this:
- // *
- // / \
- // \ /
- // (D)
- //
- // For the purposes of this expansion, we expand the two occurences of D
- // separately. Therefore, we expand the DAG into the tree:
- // *
- // / \
- // D D
- // TODO: We don't _have_to do this, but this is the simplest solution.
- // We can write a solution that keeps track of which constants have been
- // already expanded.
- static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder,
- Instruction *Parent, int index,
- SmallPtrSet<Instruction *, 4> &Expands) {
- assert(Cur && "invalid constant expression passed");
- Instruction *I = Cur->getAsInstruction();
- assert(I && "unable to convert ConstantExpr to Instruction");
- LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur
- << ") in Instruction: (" << *I << ")\n";);
- // Invalidate `Cur` so that no one after this point uses `Cur`. Rather,
- // they should mutate `I`.
- Cur = nullptr;
- Expands.insert(I);
- Parent->setOperand(index, I);
- // The things that `Parent` uses (its operands) should be created
- // before `Parent`.
- Builder.SetInsertPoint(Parent);
- Builder.Insert(I);
- for (unsigned i = 0; i < I->getNumOperands(); i++) {
- Value *Op = I->getOperand(i);
- assert(isa<Constant>(Op) && "constant must have a constant operand");
- if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op))
- expandConstantExpr(CExprOp, Builder, I, i, Expands);
- }
- }
- // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite
- // `ConstantExpr`s that are used in the `Inst`.
- // Note that `replaceAllUsesWith` is insufficient for this purpose because it
- // does not rewrite values in `ConstantExpr`s.
- static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal,
- PollyIRBuilder &Builder) {
- // This contains a set of instructions in which OldVal must be replaced.
- // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s
- // from `Inst`s arguments.
- // We need to go through this process because `replaceAllUsesWith` does not
- // actually edit `ConstantExpr`s.
- SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst};
- // Expand all `ConstantExpr`s and place it in `InstsToVisit`.
- for (unsigned i = 0; i < Inst->getNumOperands(); i++) {
- Value *Operand = Inst->getOperand(i);
- if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand))
- expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit);
- }
- // Now visit each instruction and use `replaceUsesOfWith`. We know that
- // will work because `I` cannot have any `ConstantExpr` within it.
- for (Instruction *I : InstsToVisit)
- I->replaceUsesOfWith(OldVal, NewVal);
- }
- // Given a value `Current`, return all Instructions that may contain `Current`
- // in an expression.
- // We need this auxiliary function, because if we have a
- // `Constant` that is a user of `V`, we need to recurse into the
- // `Constant`s uses to gather the root instruciton.
- static void getInstructionUsersOfValue(Value *V,
- SmallVector<Instruction *, 4> &Owners) {
- if (auto *I = dyn_cast<Instruction>(V)) {
- Owners.push_back(I);
- } else {
- // Anything that is a `User` must be a constant or an instruction.
- auto *C = cast<Constant>(V);
- for (Use &CUse : C->uses())
- getInstructionUsersOfValue(CUse.getUser(), Owners);
- }
- }
- static void
- replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array,
- SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) {
- // We only want arrays.
- ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType());
- if (!ArrayTy)
- return;
- Type *ElemTy = ArrayTy->getElementType();
- PointerType *ElemPtrTy = ElemTy->getPointerTo();
- // We only wish to replace arrays that are visible in the module they
- // inhabit. Otherwise, our type edit from [T] to T* would be illegal across
- // modules.
- const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() ||
- Array.hasInternalLinkage() ||
- IgnoreLinkageForGlobals;
- if (!OnlyVisibleInsideModule) {
- LLVM_DEBUG(
- dbgs() << "Not rewriting (" << Array
- << ") to managed memory "
- "because it could be visible externally. To force rewrite, "
- "use -polly-acc-rewrite-ignore-linkage-for-globals.\n");
- return;
- }
- if (!Array.hasInitializer() ||
- !isa<ConstantAggregateZero>(Array.getInitializer())) {
- LLVM_DEBUG(dbgs() << "Not rewriting (" << Array
- << ") to managed memory "
- "because it has an initializer which is "
- "not a zeroinitializer.\n");
- return;
- }
- // At this point, we have committed to replacing this array.
- ReplacedGlobals.insert(&Array);
- std::string NewName = Array.getName().str();
- NewName += ".toptr";
- GlobalVariable *ReplacementToArr =
- cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy));
- ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy));
- Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
- std::string FnName = Array.getName().str();
- FnName += ".constructor";
- PollyIRBuilder Builder(M.getContext());
- FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
- const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
- Function *F = Function::Create(Ty, Linkage, FnName, &M);
- BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F);
- Builder.SetInsertPoint(Start);
- const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy);
- Value *ArraySize = Builder.getInt64(ArraySizeInt);
- ArraySize->setName("array.size");
- Value *AllocatedMemRaw =
- Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw");
- Value *AllocatedMemTyped =
- Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed");
- Builder.CreateStore(AllocatedMemTyped, ReplacementToArr);
- Builder.CreateRetVoid();
- const int Priority = 0;
- appendToGlobalCtors(M, F, Priority, ReplacementToArr);
- SmallVector<Instruction *, 4> ArrayUserInstructions;
- // Get all instructions that use array. We need to do this weird thing
- // because `Constant`s that contain this array neeed to be expanded into
- // instructions so that we can replace their parameters. `Constant`s cannot
- // be edited easily, so we choose to convert all `Constant`s to
- // `Instruction`s and handle all of the uses of `Array` uniformly.
- for (Use &ArrayUse : Array.uses())
- getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions);
- for (Instruction *UserOfArrayInst : ArrayUserInstructions) {
- Builder.SetInsertPoint(UserOfArrayInst);
- // <ty>** -> <ty>*
- Value *ArrPtrLoaded =
- Builder.CreateLoad(ElemPtrTy, ReplacementToArr, "arrptr.load");
- // <ty>* -> [ty]*
- Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast(
- ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast");
- rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder);
- }
- }
- // We return all `allocas` that may need to be converted to a call to
- // cudaMallocManaged.
- static void getAllocasToBeManaged(Function &F,
- SmallSet<AllocaInst *, 4> &Allocas) {
- for (BasicBlock &BB : F) {
- for (Instruction &I : BB) {
- auto *Alloca = dyn_cast<AllocaInst>(&I);
- if (!Alloca)
- continue;
- LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: ");
- if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false,
- /* StoreCaptures */ true)) {
- Allocas.insert(Alloca);
- LLVM_DEBUG(dbgs() << "YES (captured).\n");
- } else {
- LLVM_DEBUG(dbgs() << "NO (not captured).\n");
- }
- }
- }
- }
- static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca,
- const DataLayout &DL) {
- LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n");
- Module *M = Alloca->getModule();
- assert(M && "Alloca does not have a module");
- PollyIRBuilder Builder(M->getContext());
- Builder.SetInsertPoint(Alloca);
- Function *MallocManagedFn =
- getOrCreatePollyMallocManaged(*Alloca->getModule());
- const uint64_t Size =
- DL.getTypeAllocSize(Alloca->getType()->getElementType());
- Value *SizeVal = Builder.getInt64(Size);
- Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal});
- Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType());
- Function *F = Alloca->getFunction();
- assert(F && "Alloca has invalid function");
- Bitcasted->takeName(Alloca);
- Alloca->replaceAllUsesWith(Bitcasted);
- Alloca->eraseFromParent();
- for (BasicBlock &BB : *F) {
- ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator());
- if (!Return)
- continue;
- Builder.SetInsertPoint(Return);
- Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M);
- Builder.CreateCall(FreeManagedFn, {RawManagedMem});
- }
- }
- // Replace all uses of `Old` with `New`, even inside `ConstantExpr`.
- //
- // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function
- // actually does replace it in `ConstantExpr`. The caveat is that if there is
- // a use that is *outside* a function (say, at global declarations), we fail.
- // So, this is meant to be used on values which we know will only be used
- // within functions.
- //
- // This process works by looking through the uses of `Old`. If it finds a
- // `ConstantExpr`, it recursively looks for the owning instruction.
- // Then, it expands all the `ConstantExpr` to instructions and replaces
- // `Old` with `New` in the expanded instructions.
- static void replaceAllUsesAndConstantUses(Value *Old, Value *New,
- PollyIRBuilder &Builder) {
- SmallVector<Instruction *, 4> UserInstructions;
- // Get all instructions that use array. We need to do this weird thing
- // because `Constant`s that contain this array neeed to be expanded into
- // instructions so that we can replace their parameters. `Constant`s cannot
- // be edited easily, so we choose to convert all `Constant`s to
- // `Instruction`s and handle all of the uses of `Array` uniformly.
- for (Use &ArrayUse : Old->uses())
- getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions);
- for (Instruction *I : UserInstructions)
- rewriteOldValToNew(I, Old, New, Builder);
- }
- class ManagedMemoryRewritePass : public ModulePass {
- public:
- static char ID;
- GPUArch Architecture;
- GPURuntime Runtime;
- ManagedMemoryRewritePass() : ModulePass(ID) {}
- bool runOnModule(Module &M) override {
- const DataLayout &DL = M.getDataLayout();
- Function *Malloc = M.getFunction("malloc");
- if (Malloc) {
- PollyIRBuilder Builder(M.getContext());
- Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
- assert(PollyMallocManaged && "unable to create polly_mallocManaged");
- replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder);
- Malloc->eraseFromParent();
- }
- Function *Free = M.getFunction("free");
- if (Free) {
- PollyIRBuilder Builder(M.getContext());
- Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M);
- assert(PollyFreeManaged && "unable to create polly_freeManaged");
- replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder);
- Free->eraseFromParent();
- }
- SmallPtrSet<GlobalVariable *, 4> GlobalsToErase;
- for (GlobalVariable &Global : M.globals())
- replaceGlobalArray(M, DL, Global, GlobalsToErase);
- for (GlobalVariable *G : GlobalsToErase)
- G->eraseFromParent();
- // Rewrite allocas to cudaMallocs if we are asked to do so.
- if (RewriteAllocas) {
- SmallSet<AllocaInst *, 4> AllocasToBeManaged;
- for (Function &F : M.functions())
- getAllocasToBeManaged(F, AllocasToBeManaged);
- for (AllocaInst *Alloca : AllocasToBeManaged)
- rewriteAllocaAsManagedMemory(Alloca, DL);
- }
- return true;
- }
- };
- } // namespace
- char ManagedMemoryRewritePass::ID = 42;
- Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch,
- GPURuntime Runtime) {
- ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass();
- pass->Runtime = Runtime;
- pass->Architecture = Arch;
- return pass;
- }
- INITIALIZE_PASS_BEGIN(
- ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
- "Polly - Rewrite all allocations in heap & data section to managed memory",
- false, false)
- INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration);
- INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
- INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
- INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
- INITIALIZE_PASS_DEPENDENCY(RegionInfoPass);
- INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
- INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass);
- INITIALIZE_PASS_END(
- ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
- "Polly - Rewrite all allocations in heap & data section to managed memory",
- false, false)
|