SimpleRemoteEPCServer.cpp 10 KB


  1. //===- SimpleRemoteEPCServer.cpp - EPC over simple abstract channel -------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "llvm/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.h"
  9. #include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h"
  10. #include "llvm/Support/FormatVariadic.h"
  11. #include "llvm/Support/Host.h"
  12. #include "llvm/Support/Process.h"
  13. #include "OrcRTBootstrap.h"
  14. #define DEBUG_TYPE "orc"
  15. using namespace llvm::orc::shared;
  16. namespace llvm {
  17. namespace orc {
  18. ExecutorBootstrapService::~ExecutorBootstrapService() {}
  19. SimpleRemoteEPCServer::Dispatcher::~Dispatcher() {}
  20. #if LLVM_ENABLE_THREADS
  21. void SimpleRemoteEPCServer::ThreadDispatcher::dispatch(
  22. unique_function<void()> Work) {
  23. {
  24. std::lock_guard<std::mutex> Lock(DispatchMutex);
  25. if (!Running)
  26. return;
  27. ++Outstanding;
  28. }
  29. std::thread([this, Work = std::move(Work)]() mutable {
  30. Work();
  31. std::lock_guard<std::mutex> Lock(DispatchMutex);
  32. --Outstanding;
  33. OutstandingCV.notify_all();
  34. }).detach();
  35. }
  36. void SimpleRemoteEPCServer::ThreadDispatcher::shutdown() {
  37. std::unique_lock<std::mutex> Lock(DispatchMutex);
  38. Running = false;
  39. OutstandingCV.wait(Lock, [this]() { return Outstanding == 0; });
  40. }
  41. #endif
  42. StringMap<ExecutorAddr> SimpleRemoteEPCServer::defaultBootstrapSymbols() {
  43. StringMap<ExecutorAddr> DBS;
  44. rt_bootstrap::addTo(DBS);
  45. return DBS;
  46. }
  47. Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
  48. SimpleRemoteEPCServer::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
  49. ExecutorAddr TagAddr,
  50. SimpleRemoteEPCArgBytesVector ArgBytes) {
  51. LLVM_DEBUG({
  52. dbgs() << "SimpleRemoteEPCServer::handleMessage: opc = ";
  53. switch (OpC) {
  54. case SimpleRemoteEPCOpcode::Setup:
  55. dbgs() << "Setup";
  56. assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
  57. assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
  58. break;
  59. case SimpleRemoteEPCOpcode::Hangup:
  60. dbgs() << "Hangup";
  61. assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
  62. assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
  63. break;
  64. case SimpleRemoteEPCOpcode::Result:
  65. dbgs() << "Result";
  66. assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
  67. break;
  68. case SimpleRemoteEPCOpcode::CallWrapper:
  69. dbgs() << "CallWrapper";
  70. break;
  71. }
  72. dbgs() << ", seqno = " << SeqNo
  73. << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
  74. << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
  75. << " bytes\n";
  76. });
  77. using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
  78. if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
  79. return make_error<StringError>("Unexpected opcode",
  80. inconvertibleErrorCode());
  81. // TODO: Clean detach message?
  82. switch (OpC) {
  83. case SimpleRemoteEPCOpcode::Setup:
  84. return make_error<StringError>("Unexpected Setup opcode",
  85. inconvertibleErrorCode());
  86. case SimpleRemoteEPCOpcode::Hangup:
  87. return SimpleRemoteEPCTransportClient::EndSession;
  88. case SimpleRemoteEPCOpcode::Result:
  89. if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
  90. return std::move(Err);
  91. break;
  92. case SimpleRemoteEPCOpcode::CallWrapper:
  93. handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
  94. break;
  95. }
  96. return ContinueSession;
  97. }
  98. Error SimpleRemoteEPCServer::waitForDisconnect() {
  99. std::unique_lock<std::mutex> Lock(ServerStateMutex);
  100. ShutdownCV.wait(Lock, [this]() { return RunState == ServerShutDown; });
  101. return std::move(ShutdownErr);
  102. }
  103. void SimpleRemoteEPCServer::handleDisconnect(Error Err) {
  104. PendingJITDispatchResultsMap TmpPending;
  105. {
  106. std::lock_guard<std::mutex> Lock(ServerStateMutex);
  107. std::swap(TmpPending, PendingJITDispatchResults);
  108. RunState = ServerShuttingDown;
  109. }
  110. // Send out-of-band errors to any waiting threads.
  111. for (auto &KV : TmpPending)
  112. KV.second->set_value(
  113. shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
  114. // Wait for dispatcher to clear.
  115. D->shutdown();
  116. // Shut down services.
  117. while (!Services.empty()) {
  118. ShutdownErr =
  119. joinErrors(std::move(ShutdownErr), Services.back()->shutdown());
  120. Services.pop_back();
  121. }
  122. std::lock_guard<std::mutex> Lock(ServerStateMutex);
  123. ShutdownErr = joinErrors(std::move(ShutdownErr), std::move(Err));
  124. RunState = ServerShutDown;
  125. ShutdownCV.notify_all();
  126. }
  127. Error SimpleRemoteEPCServer::sendMessage(SimpleRemoteEPCOpcode OpC,
  128. uint64_t SeqNo, ExecutorAddr TagAddr,
  129. ArrayRef<char> ArgBytes) {
  130. LLVM_DEBUG({
  131. dbgs() << "SimpleRemoteEPCServer::sendMessage: opc = ";
  132. switch (OpC) {
  133. case SimpleRemoteEPCOpcode::Setup:
  134. dbgs() << "Setup";
  135. assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
  136. assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
  137. break;
  138. case SimpleRemoteEPCOpcode::Hangup:
  139. dbgs() << "Hangup";
  140. assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
  141. assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
  142. break;
  143. case SimpleRemoteEPCOpcode::Result:
  144. dbgs() << "Result";
  145. assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
  146. break;
  147. case SimpleRemoteEPCOpcode::CallWrapper:
  148. dbgs() << "CallWrapper";
  149. break;
  150. }
  151. dbgs() << ", seqno = " << SeqNo
  152. << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
  153. << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
  154. << " bytes\n";
  155. });
  156. auto Err = T->sendMessage(OpC, SeqNo, TagAddr, ArgBytes);
  157. LLVM_DEBUG({
  158. if (Err)
  159. dbgs() << " \\--> SimpleRemoteEPC::sendMessage failed\n";
  160. });
  161. return Err;
  162. }
  163. Error SimpleRemoteEPCServer::sendSetupMessage(
  164. StringMap<ExecutorAddr> BootstrapSymbols) {
  165. using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
  166. std::vector<char> SetupPacket;
  167. SimpleRemoteEPCExecutorInfo EI;
  168. EI.TargetTriple = sys::getProcessTriple();
  169. if (auto PageSize = sys::Process::getPageSize())
  170. EI.PageSize = *PageSize;
  171. else
  172. return PageSize.takeError();
  173. EI.BootstrapSymbols = std::move(BootstrapSymbols);
  174. assert(!EI.BootstrapSymbols.count(ExecutorSessionObjectName) &&
  175. "Dispatch context name should not be set");
  176. assert(!EI.BootstrapSymbols.count(DispatchFnName) &&
  177. "Dispatch function name should not be set");
  178. EI.BootstrapSymbols[ExecutorSessionObjectName] = ExecutorAddr::fromPtr(this);
  179. EI.BootstrapSymbols[DispatchFnName] = ExecutorAddr::fromPtr(jitDispatchEntry);
  180. using SPSSerialize =
  181. shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
  182. auto SetupPacketBytes =
  183. shared::WrapperFunctionResult::allocate(SPSSerialize::size(EI));
  184. shared::SPSOutputBuffer OB(SetupPacketBytes.data(), SetupPacketBytes.size());
  185. if (!SPSSerialize::serialize(OB, EI))
  186. return make_error<StringError>("Could not send setup packet",
  187. inconvertibleErrorCode());
  188. return sendMessage(SimpleRemoteEPCOpcode::Setup, 0, ExecutorAddr(),
  189. {SetupPacketBytes.data(), SetupPacketBytes.size()});
  190. }
  191. Error SimpleRemoteEPCServer::handleResult(
  192. uint64_t SeqNo, ExecutorAddr TagAddr,
  193. SimpleRemoteEPCArgBytesVector ArgBytes) {
  194. std::promise<shared::WrapperFunctionResult> *P = nullptr;
  195. {
  196. std::lock_guard<std::mutex> Lock(ServerStateMutex);
  197. auto I = PendingJITDispatchResults.find(SeqNo);
  198. if (I == PendingJITDispatchResults.end())
  199. return make_error<StringError>("No call for sequence number " +
  200. Twine(SeqNo),
  201. inconvertibleErrorCode());
  202. P = I->second;
  203. PendingJITDispatchResults.erase(I);
  204. releaseSeqNo(SeqNo);
  205. }
  206. auto R = shared::WrapperFunctionResult::allocate(ArgBytes.size());
  207. memcpy(R.data(), ArgBytes.data(), ArgBytes.size());
  208. P->set_value(std::move(R));
  209. return Error::success();
  210. }
  211. void SimpleRemoteEPCServer::handleCallWrapper(
  212. uint64_t RemoteSeqNo, ExecutorAddr TagAddr,
  213. SimpleRemoteEPCArgBytesVector ArgBytes) {
  214. D->dispatch([this, RemoteSeqNo, TagAddr, ArgBytes = std::move(ArgBytes)]() {
  215. using WrapperFnTy =
  216. shared::CWrapperFunctionResult (*)(const char *, size_t);
  217. auto *Fn = TagAddr.toPtr<WrapperFnTy>();
  218. shared::WrapperFunctionResult ResultBytes(
  219. Fn(ArgBytes.data(), ArgBytes.size()));
  220. if (auto Err = sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
  221. ExecutorAddr(),
  222. {ResultBytes.data(), ResultBytes.size()}))
  223. ReportError(std::move(Err));
  224. });
  225. }
  226. shared::WrapperFunctionResult
  227. SimpleRemoteEPCServer::doJITDispatch(const void *FnTag, const char *ArgData,
  228. size_t ArgSize) {
  229. uint64_t SeqNo;
  230. std::promise<shared::WrapperFunctionResult> ResultP;
  231. auto ResultF = ResultP.get_future();
  232. {
  233. std::lock_guard<std::mutex> Lock(ServerStateMutex);
  234. if (RunState != ServerRunning)
  235. return shared::WrapperFunctionResult::createOutOfBandError(
  236. "jit_dispatch not available (EPC server shut down)");
  237. SeqNo = getNextSeqNo();
  238. assert(!PendingJITDispatchResults.count(SeqNo) && "SeqNo already in use");
  239. PendingJITDispatchResults[SeqNo] = &ResultP;
  240. }
  241. if (auto Err = sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
  242. ExecutorAddr::fromPtr(FnTag), {ArgData, ArgSize}))
  243. ReportError(std::move(Err));
  244. return ResultF.get();
  245. }
  246. shared::CWrapperFunctionResult
  247. SimpleRemoteEPCServer::jitDispatchEntry(void *DispatchCtx, const void *FnTag,
  248. const char *ArgData, size_t ArgSize) {
  249. return reinterpret_cast<SimpleRemoteEPCServer *>(DispatchCtx)
  250. ->doJITDispatch(FnTag, ArgData, ArgSize)
  251. .release();
  252. }
  253. } // end namespace orc
  254. } // end namespace llvm