SimpleRemoteEPC.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. //===------- SimpleRemoteEPC.cpp -- Simple remote executor control --------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h"
  9. #include "llvm/ExecutionEngine/Orc/EPCGenericJITLinkMemoryManager.h"
  10. #include "llvm/ExecutionEngine/Orc/EPCGenericMemoryAccess.h"
  11. #include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h"
  12. #include "llvm/Support/FormatVariadic.h"
  13. #define DEBUG_TYPE "orc"
  14. namespace llvm {
  15. namespace orc {
  16. SimpleRemoteEPC::~SimpleRemoteEPC() {
  17. #ifndef NDEBUG
  18. std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
  19. assert(Disconnected && "Destroyed without disconnection");
  20. #endif // NDEBUG
  21. }
  22. Expected<tpctypes::DylibHandle>
  23. SimpleRemoteEPC::loadDylib(const char *DylibPath) {
  24. return DylibMgr->open(DylibPath, 0);
  25. }
  26. Expected<std::vector<tpctypes::LookupResult>>
  27. SimpleRemoteEPC::lookupSymbols(ArrayRef<LookupRequest> Request) {
  28. std::vector<tpctypes::LookupResult> Result;
  29. for (auto &Element : Request) {
  30. if (auto R = DylibMgr->lookup(Element.Handle, Element.Symbols)) {
  31. Result.push_back({});
  32. Result.back().reserve(R->size());
  33. for (auto Addr : *R)
  34. Result.back().push_back(Addr);
  35. } else
  36. return R.takeError();
  37. }
  38. return std::move(Result);
  39. }
  40. Expected<int32_t> SimpleRemoteEPC::runAsMain(ExecutorAddr MainFnAddr,
  41. ArrayRef<std::string> Args) {
  42. int64_t Result = 0;
  43. if (auto Err = callSPSWrapper<rt::SPSRunAsMainSignature>(
  44. RunAsMainAddr, Result, ExecutorAddr(MainFnAddr), Args))
  45. return std::move(Err);
  46. return Result;
  47. }
  48. Expected<int32_t> SimpleRemoteEPC::runAsVoidFunction(ExecutorAddr VoidFnAddr) {
  49. int32_t Result = 0;
  50. if (auto Err = callSPSWrapper<rt::SPSRunAsVoidFunctionSignature>(
  51. RunAsVoidFunctionAddr, Result, ExecutorAddr(VoidFnAddr)))
  52. return std::move(Err);
  53. return Result;
  54. }
  55. Expected<int32_t> SimpleRemoteEPC::runAsIntFunction(ExecutorAddr IntFnAddr,
  56. int Arg) {
  57. int32_t Result = 0;
  58. if (auto Err = callSPSWrapper<rt::SPSRunAsIntFunctionSignature>(
  59. RunAsIntFunctionAddr, Result, ExecutorAddr(IntFnAddr), Arg))
  60. return std::move(Err);
  61. return Result;
  62. }
  63. void SimpleRemoteEPC::callWrapperAsync(ExecutorAddr WrapperFnAddr,
  64. IncomingWFRHandler OnComplete,
  65. ArrayRef<char> ArgBuffer) {
  66. uint64_t SeqNo;
  67. {
  68. std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
  69. SeqNo = getNextSeqNo();
  70. assert(!PendingCallWrapperResults.count(SeqNo) && "SeqNo already in use");
  71. PendingCallWrapperResults[SeqNo] = std::move(OnComplete);
  72. }
  73. if (auto Err = sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
  74. WrapperFnAddr, ArgBuffer)) {
  75. IncomingWFRHandler H;
  76. // We just registered OnComplete, but there may be a race between this
  77. // thread returning from sendMessage and handleDisconnect being called from
  78. // the transport's listener thread. If handleDisconnect gets there first
  79. // then it will have failed 'H' for us. If we get there first (or if
  80. // handleDisconnect already ran) then we need to take care of it.
  81. {
  82. std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
  83. auto I = PendingCallWrapperResults.find(SeqNo);
  84. if (I != PendingCallWrapperResults.end()) {
  85. H = std::move(I->second);
  86. PendingCallWrapperResults.erase(I);
  87. }
  88. }
  89. if (H)
  90. H(shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
  91. getExecutionSession().reportError(std::move(Err));
  92. }
  93. }
  94. Error SimpleRemoteEPC::disconnect() {
  95. T->disconnect();
  96. D->shutdown();
  97. std::unique_lock<std::mutex> Lock(SimpleRemoteEPCMutex);
  98. DisconnectCV.wait(Lock, [this] { return Disconnected; });
  99. return std::move(DisconnectErr);
  100. }
  101. Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
  102. SimpleRemoteEPC::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
  103. ExecutorAddr TagAddr,
  104. SimpleRemoteEPCArgBytesVector ArgBytes) {
  105. LLVM_DEBUG({
  106. dbgs() << "SimpleRemoteEPC::handleMessage: opc = ";
  107. switch (OpC) {
  108. case SimpleRemoteEPCOpcode::Setup:
  109. dbgs() << "Setup";
  110. assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
  111. assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
  112. break;
  113. case SimpleRemoteEPCOpcode::Hangup:
  114. dbgs() << "Hangup";
  115. assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
  116. assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
  117. break;
  118. case SimpleRemoteEPCOpcode::Result:
  119. dbgs() << "Result";
  120. assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
  121. break;
  122. case SimpleRemoteEPCOpcode::CallWrapper:
  123. dbgs() << "CallWrapper";
  124. break;
  125. }
  126. dbgs() << ", seqno = " << SeqNo
  127. << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
  128. << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
  129. << " bytes\n";
  130. });
  131. using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
  132. if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
  133. return make_error<StringError>("Unexpected opcode",
  134. inconvertibleErrorCode());
  135. switch (OpC) {
  136. case SimpleRemoteEPCOpcode::Setup:
  137. if (auto Err = handleSetup(SeqNo, TagAddr, std::move(ArgBytes)))
  138. return std::move(Err);
  139. break;
  140. case SimpleRemoteEPCOpcode::Hangup:
  141. T->disconnect();
  142. if (auto Err = handleHangup(std::move(ArgBytes)))
  143. return std::move(Err);
  144. return EndSession;
  145. case SimpleRemoteEPCOpcode::Result:
  146. if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
  147. return std::move(Err);
  148. break;
  149. case SimpleRemoteEPCOpcode::CallWrapper:
  150. handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
  151. break;
  152. }
  153. return ContinueSession;
  154. }
  155. void SimpleRemoteEPC::handleDisconnect(Error Err) {
  156. LLVM_DEBUG({
  157. dbgs() << "SimpleRemoteEPC::handleDisconnect: "
  158. << (Err ? "failure" : "success") << "\n";
  159. });
  160. PendingCallWrapperResultsMap TmpPending;
  161. {
  162. std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
  163. std::swap(TmpPending, PendingCallWrapperResults);
  164. }
  165. for (auto &KV : TmpPending)
  166. KV.second(
  167. shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
  168. std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
  169. DisconnectErr = joinErrors(std::move(DisconnectErr), std::move(Err));
  170. Disconnected = true;
  171. DisconnectCV.notify_all();
  172. }
  173. Expected<std::unique_ptr<jitlink::JITLinkMemoryManager>>
  174. SimpleRemoteEPC::createDefaultMemoryManager(SimpleRemoteEPC &SREPC) {
  175. EPCGenericJITLinkMemoryManager::SymbolAddrs SAs;
  176. if (auto Err = SREPC.getBootstrapSymbols(
  177. {{SAs.Allocator, rt::SimpleExecutorMemoryManagerInstanceName},
  178. {SAs.Reserve, rt::SimpleExecutorMemoryManagerReserveWrapperName},
  179. {SAs.Finalize, rt::SimpleExecutorMemoryManagerFinalizeWrapperName},
  180. {SAs.Deallocate,
  181. rt::SimpleExecutorMemoryManagerDeallocateWrapperName}}))
  182. return std::move(Err);
  183. return std::make_unique<EPCGenericJITLinkMemoryManager>(SREPC, SAs);
  184. }
  185. Expected<std::unique_ptr<ExecutorProcessControl::MemoryAccess>>
  186. SimpleRemoteEPC::createDefaultMemoryAccess(SimpleRemoteEPC &SREPC) {
  187. return nullptr;
  188. }
  189. Error SimpleRemoteEPC::sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
  190. ExecutorAddr TagAddr,
  191. ArrayRef<char> ArgBytes) {
  192. assert(OpC != SimpleRemoteEPCOpcode::Setup &&
  193. "SimpleRemoteEPC sending Setup message? That's the wrong direction.");
  194. LLVM_DEBUG({
  195. dbgs() << "SimpleRemoteEPC::sendMessage: opc = ";
  196. switch (OpC) {
  197. case SimpleRemoteEPCOpcode::Hangup:
  198. dbgs() << "Hangup";
  199. assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
  200. assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
  201. break;
  202. case SimpleRemoteEPCOpcode::Result:
  203. dbgs() << "Result";
  204. assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
  205. break;
  206. case SimpleRemoteEPCOpcode::CallWrapper:
  207. dbgs() << "CallWrapper";
  208. break;
  209. default:
  210. llvm_unreachable("Invalid opcode");
  211. }
  212. dbgs() << ", seqno = " << SeqNo
  213. << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
  214. << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
  215. << " bytes\n";
  216. });
  217. auto Err = T->sendMessage(OpC, SeqNo, TagAddr, ArgBytes);
  218. LLVM_DEBUG({
  219. if (Err)
  220. dbgs() << " \\--> SimpleRemoteEPC::sendMessage failed\n";
  221. });
  222. return Err;
  223. }
  224. Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddr TagAddr,
  225. SimpleRemoteEPCArgBytesVector ArgBytes) {
  226. if (SeqNo != 0)
  227. return make_error<StringError>("Setup packet SeqNo not zero",
  228. inconvertibleErrorCode());
  229. if (TagAddr)
  230. return make_error<StringError>("Setup packet TagAddr not zero",
  231. inconvertibleErrorCode());
  232. std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
  233. auto I = PendingCallWrapperResults.find(0);
  234. assert(PendingCallWrapperResults.size() == 1 &&
  235. I != PendingCallWrapperResults.end() &&
  236. "Setup message handler not connectly set up");
  237. auto SetupMsgHandler = std::move(I->second);
  238. PendingCallWrapperResults.erase(I);
  239. auto WFR =
  240. shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
  241. SetupMsgHandler(std::move(WFR));
  242. return Error::success();
  243. }
  244. Error SimpleRemoteEPC::setup(Setup S) {
  245. using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
  246. std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> EIP;
  247. auto EIF = EIP.get_future();
  248. // Prepare a handler for the setup packet.
  249. PendingCallWrapperResults[0] =
  250. RunInPlace()(
  251. [&](shared::WrapperFunctionResult SetupMsgBytes) {
  252. if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
  253. EIP.set_value(
  254. make_error<StringError>(ErrMsg, inconvertibleErrorCode()));
  255. return;
  256. }
  257. using SPSSerialize =
  258. shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
  259. shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());
  260. SimpleRemoteEPCExecutorInfo EI;
  261. if (SPSSerialize::deserialize(IB, EI))
  262. EIP.set_value(EI);
  263. else
  264. EIP.set_value(make_error<StringError>(
  265. "Could not deserialize setup message", inconvertibleErrorCode()));
  266. });
  267. // Start the transport.
  268. if (auto Err = T->start())
  269. return Err;
  270. // Wait for setup packet to arrive.
  271. auto EI = EIF.get();
  272. if (!EI) {
  273. T->disconnect();
  274. return EI.takeError();
  275. }
  276. LLVM_DEBUG({
  277. dbgs() << "SimpleRemoteEPC received setup message:\n"
  278. << " Triple: " << EI->TargetTriple << "\n"
  279. << " Page size: " << EI->PageSize << "\n"
  280. << " Bootstrap symbols:\n";
  281. for (const auto &KV : EI->BootstrapSymbols)
  282. dbgs() << " " << KV.first() << ": "
  283. << formatv("{0:x16}", KV.second.getValue()) << "\n";
  284. });
  285. TargetTriple = Triple(EI->TargetTriple);
  286. PageSize = EI->PageSize;
  287. BootstrapSymbols = std::move(EI->BootstrapSymbols);
  288. if (auto Err = getBootstrapSymbols(
  289. {{JDI.JITDispatchContext, ExecutorSessionObjectName},
  290. {JDI.JITDispatchFunction, DispatchFnName},
  291. {RunAsMainAddr, rt::RunAsMainWrapperName},
  292. {RunAsVoidFunctionAddr, rt::RunAsVoidFunctionWrapperName},
  293. {RunAsIntFunctionAddr, rt::RunAsIntFunctionWrapperName}}))
  294. return Err;
  295. if (auto DM =
  296. EPCGenericDylibManager::CreateWithDefaultBootstrapSymbols(*this))
  297. DylibMgr = std::make_unique<EPCGenericDylibManager>(std::move(*DM));
  298. else
  299. return DM.takeError();
  300. // Set a default CreateMemoryManager if none is specified.
  301. if (!S.CreateMemoryManager)
  302. S.CreateMemoryManager = createDefaultMemoryManager;
  303. if (auto MemMgr = S.CreateMemoryManager(*this)) {
  304. OwnedMemMgr = std::move(*MemMgr);
  305. this->MemMgr = OwnedMemMgr.get();
  306. } else
  307. return MemMgr.takeError();
  308. // Set a default CreateMemoryAccess if none is specified.
  309. if (!S.CreateMemoryAccess)
  310. S.CreateMemoryAccess = createDefaultMemoryAccess;
  311. if (auto MemAccess = S.CreateMemoryAccess(*this)) {
  312. OwnedMemAccess = std::move(*MemAccess);
  313. this->MemAccess = OwnedMemAccess.get();
  314. } else
  315. return MemAccess.takeError();
  316. return Error::success();
  317. }
  318. Error SimpleRemoteEPC::handleResult(uint64_t SeqNo, ExecutorAddr TagAddr,
  319. SimpleRemoteEPCArgBytesVector ArgBytes) {
  320. IncomingWFRHandler SendResult;
  321. if (TagAddr)
  322. return make_error<StringError>("Unexpected TagAddr in result message",
  323. inconvertibleErrorCode());
  324. {
  325. std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
  326. auto I = PendingCallWrapperResults.find(SeqNo);
  327. if (I == PendingCallWrapperResults.end())
  328. return make_error<StringError>("No call for sequence number " +
  329. Twine(SeqNo),
  330. inconvertibleErrorCode());
  331. SendResult = std::move(I->second);
  332. PendingCallWrapperResults.erase(I);
  333. releaseSeqNo(SeqNo);
  334. }
  335. auto WFR =
  336. shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
  337. SendResult(std::move(WFR));
  338. return Error::success();
  339. }
  340. void SimpleRemoteEPC::handleCallWrapper(
  341. uint64_t RemoteSeqNo, ExecutorAddr TagAddr,
  342. SimpleRemoteEPCArgBytesVector ArgBytes) {
  343. assert(ES && "No ExecutionSession attached");
  344. D->dispatch(makeGenericNamedTask(
  345. [this, RemoteSeqNo, TagAddr, ArgBytes = std::move(ArgBytes)]() {
  346. ES->runJITDispatchHandler(
  347. [this, RemoteSeqNo](shared::WrapperFunctionResult WFR) {
  348. if (auto Err =
  349. sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
  350. ExecutorAddr(), {WFR.data(), WFR.size()}))
  351. getExecutionSession().reportError(std::move(Err));
  352. },
  353. TagAddr.getValue(), ArgBytes);
  354. },
  355. "callWrapper task"));
  356. }
  357. Error SimpleRemoteEPC::handleHangup(SimpleRemoteEPCArgBytesVector ArgBytes) {
  358. using namespace llvm::orc::shared;
  359. auto WFR = WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
  360. if (const char *ErrMsg = WFR.getOutOfBandError())
  361. return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
  362. detail::SPSSerializableError Info;
  363. SPSInputBuffer IB(WFR.data(), WFR.size());
  364. if (!SPSArgList<SPSError>::deserialize(IB, Info))
  365. return make_error<StringError>("Could not deserialize hangup info",
  366. inconvertibleErrorCode());
  367. return fromSPSSerializable(std::move(Info));
  368. }
  369. } // end namespace orc
  370. } // end namespace llvm