WebAssemblyFixFunctionBitcasts.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. ///
  9. /// \file
  10. /// Fix bitcasted functions.
  11. ///
  12. /// WebAssembly requires caller and callee signatures to match, however in LLVM,
  13. /// some amount of slop is vaguely permitted. Detect mismatch by looking for
  14. /// bitcasts of functions and rewrite them to use wrapper functions instead.
  15. ///
  16. /// This doesn't catch all cases, such as when a function's address is taken in
  17. /// one place and casted in another, but it works for many common cases.
  18. ///
  19. /// Note that LLVM already optimizes away function bitcasts in common cases by
  20. /// dropping arguments as needed, so this pass only ends up getting used in less
  21. /// common cases.
  22. ///
  23. //===----------------------------------------------------------------------===//
  24. #include "WebAssembly.h"
  25. #include "llvm/IR/Constants.h"
  26. #include "llvm/IR/Instructions.h"
  27. #include "llvm/IR/Module.h"
  28. #include "llvm/IR/Operator.h"
  29. #include "llvm/Pass.h"
  30. #include "llvm/Support/Debug.h"
  31. #include "llvm/Support/raw_ostream.h"
  32. using namespace llvm;
  33. #define DEBUG_TYPE "wasm-fix-function-bitcasts"
  34. namespace {
  35. class FixFunctionBitcasts final : public ModulePass {
  36. StringRef getPassName() const override {
  37. return "WebAssembly Fix Function Bitcasts";
  38. }
  39. void getAnalysisUsage(AnalysisUsage &AU) const override {
  40. AU.setPreservesCFG();
  41. ModulePass::getAnalysisUsage(AU);
  42. }
  43. bool runOnModule(Module &M) override;
  44. public:
  45. static char ID;
  46. FixFunctionBitcasts() : ModulePass(ID) {}
  47. };
  48. } // End anonymous namespace
  49. char FixFunctionBitcasts::ID = 0;
  50. INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
  51. "Fix mismatching bitcasts for WebAssembly", false, false)
  52. ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {
  53. return new FixFunctionBitcasts();
  54. }
  55. // Recursively descend the def-use lists from V to find non-bitcast users of
  56. // bitcasts of V.
  57. static void findUses(Value *V, Function &F,
  58. SmallVectorImpl<std::pair<CallBase *, Function *>> &Uses) {
  59. for (User *U : V->users()) {
  60. if (auto *BC = dyn_cast<BitCastOperator>(U))
  61. findUses(BC, F, Uses);
  62. else if (auto *A = dyn_cast<GlobalAlias>(U))
  63. findUses(A, F, Uses);
  64. else if (auto *CB = dyn_cast<CallBase>(U)) {
  65. Value *Callee = CB->getCalledOperand();
  66. if (Callee != V)
  67. // Skip calls where the function isn't the callee
  68. continue;
  69. if (CB->getFunctionType() == F.getValueType())
  70. // Skip uses that are immediately called
  71. continue;
  72. Uses.push_back(std::make_pair(CB, &F));
  73. }
  74. }
  75. }
  76. // Create a wrapper function with type Ty that calls F (which may have a
  77. // different type). Attempt to support common bitcasted function idioms:
  78. // - Call with more arguments than needed: arguments are dropped
  79. // - Call with fewer arguments than needed: arguments are filled in with undef
  80. // - Return value is not needed: drop it
  81. // - Return value needed but not present: supply an undef
  82. //
  83. // If the all the argument types of trivially castable to one another (i.e.
  84. // I32 vs pointer type) then we don't create a wrapper at all (return nullptr
  85. // instead).
  86. //
  87. // If there is a type mismatch that we know would result in an invalid wasm
  88. // module then generate wrapper that contains unreachable (i.e. abort at
  89. // runtime). Such programs are deep into undefined behaviour territory,
  90. // but we choose to fail at runtime rather than generate and invalid module
  91. // or fail at compiler time. The reason we delay the error is that we want
  92. // to support the CMake which expects to be able to compile and link programs
  93. // that refer to functions with entirely incorrect signatures (this is how
  94. // CMake detects the existence of a function in a toolchain).
  95. //
  96. // For bitcasts that involve struct types we don't know at this stage if they
  97. // would be equivalent at the wasm level and so we can't know if we need to
  98. // generate a wrapper.
  99. static Function *createWrapper(Function *F, FunctionType *Ty) {
  100. Module *M = F->getParent();
  101. Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage,
  102. F->getName() + "_bitcast", M);
  103. BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
  104. const DataLayout &DL = BB->getModule()->getDataLayout();
  105. // Determine what arguments to pass.
  106. SmallVector<Value *, 4> Args;
  107. Function::arg_iterator AI = Wrapper->arg_begin();
  108. Function::arg_iterator AE = Wrapper->arg_end();
  109. FunctionType::param_iterator PI = F->getFunctionType()->param_begin();
  110. FunctionType::param_iterator PE = F->getFunctionType()->param_end();
  111. bool TypeMismatch = false;
  112. bool WrapperNeeded = false;
  113. Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
  114. Type *RtnType = Ty->getReturnType();
  115. if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||
  116. (F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||
  117. (ExpectedRtnType != RtnType))
  118. WrapperNeeded = true;
  119. for (; AI != AE && PI != PE; ++AI, ++PI) {
  120. Type *ArgType = AI->getType();
  121. Type *ParamType = *PI;
  122. if (ArgType == ParamType) {
  123. Args.push_back(&*AI);
  124. } else {
  125. if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
  126. Instruction *PtrCast =
  127. CastInst::CreateBitOrPointerCast(AI, ParamType, "cast");
  128. PtrCast->insertInto(BB, BB->end());
  129. Args.push_back(PtrCast);
  130. } else if (ArgType->isStructTy() || ParamType->isStructTy()) {
  131. LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "
  132. << F->getName() << "\n");
  133. WrapperNeeded = false;
  134. } else {
  135. LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "
  136. << F->getName() << "\n");
  137. LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
  138. << *ParamType << " Got: " << *ArgType << "\n");
  139. TypeMismatch = true;
  140. break;
  141. }
  142. }
  143. }
  144. if (WrapperNeeded && !TypeMismatch) {
  145. for (; PI != PE; ++PI)
  146. Args.push_back(UndefValue::get(*PI));
  147. if (F->isVarArg())
  148. for (; AI != AE; ++AI)
  149. Args.push_back(&*AI);
  150. CallInst *Call = CallInst::Create(F, Args, "", BB);
  151. Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
  152. Type *RtnType = Ty->getReturnType();
  153. // Determine what value to return.
  154. if (RtnType->isVoidTy()) {
  155. ReturnInst::Create(M->getContext(), BB);
  156. } else if (ExpectedRtnType->isVoidTy()) {
  157. LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
  158. ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB);
  159. } else if (RtnType == ExpectedRtnType) {
  160. ReturnInst::Create(M->getContext(), Call, BB);
  161. } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
  162. DL)) {
  163. Instruction *Cast =
  164. CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
  165. Cast->insertInto(BB, BB->end());
  166. ReturnInst::Create(M->getContext(), Cast, BB);
  167. } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
  168. LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "
  169. << F->getName() << "\n");
  170. WrapperNeeded = false;
  171. } else {
  172. LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "
  173. << F->getName() << "\n");
  174. LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
  175. << " Got: " << *RtnType << "\n");
  176. TypeMismatch = true;
  177. }
  178. }
  179. if (TypeMismatch) {
  180. // Create a new wrapper that simply contains `unreachable`.
  181. Wrapper->eraseFromParent();
  182. Wrapper = Function::Create(Ty, Function::PrivateLinkage,
  183. F->getName() + "_bitcast_invalid", M);
  184. BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
  185. new UnreachableInst(M->getContext(), BB);
  186. Wrapper->setName(F->getName() + "_bitcast_invalid");
  187. } else if (!WrapperNeeded) {
  188. LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()
  189. << "\n");
  190. Wrapper->eraseFromParent();
  191. return nullptr;
  192. }
  193. LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");
  194. return Wrapper;
  195. }
  196. // Test whether a main function with type FuncTy should be rewritten to have
  197. // type MainTy.
  198. static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) {
  199. // Only fix the main function if it's the standard zero-arg form. That way,
  200. // the standard cases will work as expected, and users will see signature
  201. // mismatches from the linker for non-standard cases.
  202. return FuncTy->getReturnType() == MainTy->getReturnType() &&
  203. FuncTy->getNumParams() == 0 &&
  204. !FuncTy->isVarArg();
  205. }
  206. bool FixFunctionBitcasts::runOnModule(Module &M) {
  207. LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");
  208. Function *Main = nullptr;
  209. CallInst *CallMain = nullptr;
  210. SmallVector<std::pair<CallBase *, Function *>, 0> Uses;
  211. // Collect all the places that need wrappers.
  212. for (Function &F : M) {
  213. // Skip to fix when the function is swiftcc because swiftcc allows
  214. // bitcast type difference for swiftself and swifterror.
  215. if (F.getCallingConv() == CallingConv::Swift)
  216. continue;
  217. findUses(&F, F, Uses);
  218. // If we have a "main" function, and its type isn't
  219. // "int main(int argc, char *argv[])", create an artificial call with it
  220. // bitcasted to that type so that we generate a wrapper for it, so that
  221. // the C runtime can call it.
  222. if (F.getName() == "main") {
  223. Main = &F;
  224. LLVMContext &C = M.getContext();
  225. Type *MainArgTys[] = {Type::getInt32Ty(C),
  226. PointerType::get(Type::getInt8PtrTy(C), 0)};
  227. FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
  228. /*isVarArg=*/false);
  229. if (shouldFixMainFunction(F.getFunctionType(), MainTy)) {
  230. LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
  231. << *F.getFunctionType() << "\n");
  232. Value *Args[] = {UndefValue::get(MainArgTys[0]),
  233. UndefValue::get(MainArgTys[1])};
  234. Value *Casted =
  235. ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
  236. CallMain = CallInst::Create(MainTy, Casted, Args, "call_main");
  237. Uses.push_back(std::make_pair(CallMain, &F));
  238. }
  239. }
  240. }
  241. DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;
  242. for (auto &UseFunc : Uses) {
  243. CallBase *CB = UseFunc.first;
  244. Function *F = UseFunc.second;
  245. FunctionType *Ty = CB->getFunctionType();
  246. auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
  247. if (Pair.second)
  248. Pair.first->second = createWrapper(F, Ty);
  249. Function *Wrapper = Pair.first->second;
  250. if (!Wrapper)
  251. continue;
  252. CB->setCalledOperand(Wrapper);
  253. }
  254. // If we created a wrapper for main, rename the wrapper so that it's the
  255. // one that gets called from startup.
  256. if (CallMain) {
  257. Main->setName("__original_main");
  258. auto *MainWrapper =
  259. cast<Function>(CallMain->getCalledOperand()->stripPointerCasts());
  260. delete CallMain;
  261. if (Main->isDeclaration()) {
  262. // The wrapper is not needed in this case as we don't need to export
  263. // it to anyone else.
  264. MainWrapper->eraseFromParent();
  265. } else {
  266. // Otherwise give the wrapper the same linkage as the original main
  267. // function, so that it can be called from the same places.
  268. MainWrapper->setName("main");
  269. MainWrapper->setLinkage(Main->getLinkage());
  270. MainWrapper->setVisibility(Main->getVisibility());
  271. }
  272. }
  273. return true;
  274. }