OrcRPCTargetProcessControl.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===--- OrcRPCTargetProcessControl.h - Remote target control ---*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // Utilities for interacting with target processes.
  15. //
  16. //===----------------------------------------------------------------------===//
  17. #ifndef LLVM_EXECUTIONENGINE_ORC_ORCRPCTARGETPROCESSCONTROL_H
  18. #define LLVM_EXECUTIONENGINE_ORC_ORCRPCTARGETPROCESSCONTROL_H
  19. #include "llvm/ExecutionEngine/Orc/Shared/RPCUtils.h"
  20. #include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h"
  21. #include "llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h"
  22. #include "llvm/ExecutionEngine/Orc/TargetProcessControl.h"
  23. #include "llvm/Support/MSVCErrorWorkarounds.h"
  24. namespace llvm {
  25. namespace orc {
  26. /// JITLinkMemoryManager implementation for a process connected via an ORC RPC
  27. /// endpoint.
  28. template <typename OrcRPCTPCImplT>
  29. class OrcRPCTPCJITLinkMemoryManager : public jitlink::JITLinkMemoryManager {
  30. private:
  31. struct HostAlloc {
  32. std::unique_ptr<char[]> Mem;
  33. uint64_t Size;
  34. };
  35. struct TargetAlloc {
  36. JITTargetAddress Address = 0;
  37. uint64_t AllocatedSize = 0;
  38. };
  39. using HostAllocMap = DenseMap<int, HostAlloc>;
  40. using TargetAllocMap = DenseMap<int, TargetAlloc>;
  41. public:
  42. class OrcRPCAllocation : public Allocation {
  43. public:
  44. OrcRPCAllocation(OrcRPCTPCJITLinkMemoryManager<OrcRPCTPCImplT> &Parent,
  45. HostAllocMap HostAllocs, TargetAllocMap TargetAllocs)
  46. : Parent(Parent), HostAllocs(std::move(HostAllocs)),
  47. TargetAllocs(std::move(TargetAllocs)) {
  48. assert(HostAllocs.size() == TargetAllocs.size() &&
  49. "HostAllocs size should match TargetAllocs");
  50. }
  51. ~OrcRPCAllocation() override {
  52. assert(TargetAllocs.empty() && "failed to deallocate");
  53. }
  54. MutableArrayRef<char> getWorkingMemory(ProtectionFlags Seg) override {
  55. auto I = HostAllocs.find(Seg);
  56. assert(I != HostAllocs.end() && "No host allocation for segment");
  57. auto &HA = I->second;
  58. return {HA.Mem.get(), static_cast<size_t>(HA.Size)};
  59. }
  60. JITTargetAddress getTargetMemory(ProtectionFlags Seg) override {
  61. auto I = TargetAllocs.find(Seg);
  62. assert(I != TargetAllocs.end() && "No target allocation for segment");
  63. return I->second.Address;
  64. }
  65. void finalizeAsync(FinalizeContinuation OnFinalize) override {
  66. std::vector<tpctypes::BufferWrite> BufferWrites;
  67. orcrpctpc::ReleaseOrFinalizeMemRequest FMR;
  68. for (auto &KV : HostAllocs) {
  69. assert(TargetAllocs.count(KV.first) &&
  70. "No target allocation for buffer");
  71. auto &HA = KV.second;
  72. auto &TA = TargetAllocs[KV.first];
  73. BufferWrites.push_back({TA.Address, StringRef(HA.Mem.get(), HA.Size)});
  74. FMR.push_back({orcrpctpc::toWireProtectionFlags(
  75. static_cast<sys::Memory::ProtectionFlags>(KV.first)),
  76. TA.Address, TA.AllocatedSize});
  77. }
  78. DEBUG_WITH_TYPE("orc", {
  79. dbgs() << "finalizeAsync " << (void *)this << ":\n";
  80. auto FMRI = FMR.begin();
  81. for (auto &B : BufferWrites) {
  82. auto Prot = FMRI->Prot;
  83. ++FMRI;
  84. dbgs() << " Writing " << formatv("{0:x16}", B.Buffer.size())
  85. << " bytes to " << ((Prot & orcrpctpc::WPF_Read) ? 'R' : '-')
  86. << ((Prot & orcrpctpc::WPF_Write) ? 'W' : '-')
  87. << ((Prot & orcrpctpc::WPF_Exec) ? 'X' : '-')
  88. << " segment: local " << (const void *)B.Buffer.data()
  89. << " -> target " << formatv("{0:x16}", B.Address) << "\n";
  90. }
  91. });
  92. if (auto Err =
  93. Parent.Parent.getMemoryAccess().writeBuffers(BufferWrites)) {
  94. OnFinalize(std::move(Err));
  95. return;
  96. }
  97. DEBUG_WITH_TYPE("orc", dbgs() << " Applying permissions...\n");
  98. if (auto Err =
  99. Parent.getEndpoint().template callAsync<orcrpctpc::FinalizeMem>(
  100. [OF = std::move(OnFinalize)](Error Err2) {
  101. // FIXME: Dispatch to work queue.
  102. std::thread([OF = std::move(OF),
  103. Err3 = std::move(Err2)]() mutable {
  104. DEBUG_WITH_TYPE(
  105. "orc", { dbgs() << " finalizeAsync complete\n"; });
  106. OF(std::move(Err3));
  107. }).detach();
  108. return Error::success();
  109. },
  110. FMR)) {
  111. DEBUG_WITH_TYPE("orc", dbgs() << " failed.\n");
  112. Parent.getEndpoint().abandonPendingResponses();
  113. Parent.reportError(std::move(Err));
  114. }
  115. DEBUG_WITH_TYPE("orc", {
  116. dbgs() << "Leaving finalizeAsync (finalization may continue in "
  117. "background)\n";
  118. });
  119. }
  120. Error deallocate() override {
  121. orcrpctpc::ReleaseOrFinalizeMemRequest RMR;
  122. for (auto &KV : TargetAllocs)
  123. RMR.push_back({orcrpctpc::toWireProtectionFlags(
  124. static_cast<sys::Memory::ProtectionFlags>(KV.first)),
  125. KV.second.Address, KV.second.AllocatedSize});
  126. TargetAllocs.clear();
  127. return Parent.getEndpoint().template callB<orcrpctpc::ReleaseMem>(RMR);
  128. }
  129. private:
  130. OrcRPCTPCJITLinkMemoryManager<OrcRPCTPCImplT> &Parent;
  131. HostAllocMap HostAllocs;
  132. TargetAllocMap TargetAllocs;
  133. };
  134. OrcRPCTPCJITLinkMemoryManager(OrcRPCTPCImplT &Parent) : Parent(Parent) {}
  135. Expected<std::unique_ptr<Allocation>>
  136. allocate(const jitlink::JITLinkDylib *JD,
  137. const SegmentsRequestMap &Request) override {
  138. orcrpctpc::ReserveMemRequest RMR;
  139. HostAllocMap HostAllocs;
  140. for (auto &KV : Request) {
  141. assert(KV.second.getContentSize() <= std::numeric_limits<size_t>::max() &&
  142. "Content size is out-of-range for host");
  143. RMR.push_back({orcrpctpc::toWireProtectionFlags(
  144. static_cast<sys::Memory::ProtectionFlags>(KV.first)),
  145. KV.second.getContentSize() + KV.second.getZeroFillSize(),
  146. KV.second.getAlignment()});
  147. HostAllocs[KV.first] = {
  148. std::make_unique<char[]>(KV.second.getContentSize()),
  149. KV.second.getContentSize()};
  150. }
  151. DEBUG_WITH_TYPE("orc", {
  152. dbgs() << "Orc remote memmgr got request:\n";
  153. for (auto &KV : Request)
  154. dbgs() << " permissions: "
  155. << ((KV.first & sys::Memory::MF_READ) ? 'R' : '-')
  156. << ((KV.first & sys::Memory::MF_WRITE) ? 'W' : '-')
  157. << ((KV.first & sys::Memory::MF_EXEC) ? 'X' : '-')
  158. << ", content size: "
  159. << formatv("{0:x16}", KV.second.getContentSize())
  160. << " + zero-fill-size: "
  161. << formatv("{0:x16}", KV.second.getZeroFillSize())
  162. << ", align: " << KV.second.getAlignment() << "\n";
  163. });
  164. // FIXME: LLVM RPC needs to be fixed to support alt
  165. // serialization/deserialization on return types. For now just
  166. // translate from std::map to DenseMap manually.
  167. auto TmpTargetAllocs =
  168. Parent.getEndpoint().template callB<orcrpctpc::ReserveMem>(RMR);
  169. if (!TmpTargetAllocs)
  170. return TmpTargetAllocs.takeError();
  171. if (TmpTargetAllocs->size() != RMR.size())
  172. return make_error<StringError>(
  173. "Number of target allocations does not match request",
  174. inconvertibleErrorCode());
  175. TargetAllocMap TargetAllocs;
  176. for (auto &E : *TmpTargetAllocs)
  177. TargetAllocs[orcrpctpc::fromWireProtectionFlags(E.Prot)] = {
  178. E.Address, E.AllocatedSize};
  179. DEBUG_WITH_TYPE("orc", {
  180. auto HAI = HostAllocs.begin();
  181. for (auto &KV : TargetAllocs)
  182. dbgs() << " permissions: "
  183. << ((KV.first & sys::Memory::MF_READ) ? 'R' : '-')
  184. << ((KV.first & sys::Memory::MF_WRITE) ? 'W' : '-')
  185. << ((KV.first & sys::Memory::MF_EXEC) ? 'X' : '-')
  186. << " assigned local " << (void *)HAI->second.Mem.get()
  187. << ", target " << formatv("{0:x16}", KV.second.Address) << "\n";
  188. });
  189. return std::make_unique<OrcRPCAllocation>(*this, std::move(HostAllocs),
  190. std::move(TargetAllocs));
  191. }
  192. private:
  193. void reportError(Error Err) { Parent.reportError(std::move(Err)); }
  194. decltype(std::declval<OrcRPCTPCImplT>().getEndpoint()) getEndpoint() {
  195. return Parent.getEndpoint();
  196. }
  197. OrcRPCTPCImplT &Parent;
  198. };
  199. /// TargetProcessControl::MemoryAccess implementation for a process connected
  200. /// via an ORC RPC endpoint.
  201. template <typename OrcRPCTPCImplT>
  202. class OrcRPCTPCMemoryAccess : public TargetProcessControl::MemoryAccess {
  203. public:
  204. OrcRPCTPCMemoryAccess(OrcRPCTPCImplT &Parent) : Parent(Parent) {}
  205. void writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws,
  206. WriteResultFn OnWriteComplete) override {
  207. writeViaRPC<orcrpctpc::WriteUInt8s>(Ws, std::move(OnWriteComplete));
  208. }
  209. void writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws,
  210. WriteResultFn OnWriteComplete) override {
  211. writeViaRPC<orcrpctpc::WriteUInt16s>(Ws, std::move(OnWriteComplete));
  212. }
  213. void writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws,
  214. WriteResultFn OnWriteComplete) override {
  215. writeViaRPC<orcrpctpc::WriteUInt32s>(Ws, std::move(OnWriteComplete));
  216. }
  217. void writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws,
  218. WriteResultFn OnWriteComplete) override {
  219. writeViaRPC<orcrpctpc::WriteUInt64s>(Ws, std::move(OnWriteComplete));
  220. }
  221. void writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws,
  222. WriteResultFn OnWriteComplete) override {
  223. writeViaRPC<orcrpctpc::WriteBuffers>(Ws, std::move(OnWriteComplete));
  224. }
  225. private:
  226. template <typename WriteRPCFunction, typename WriteElementT>
  227. void writeViaRPC(ArrayRef<WriteElementT> Ws, WriteResultFn OnWriteComplete) {
  228. if (auto Err = Parent.getEndpoint().template callAsync<WriteRPCFunction>(
  229. [OWC = std::move(OnWriteComplete)](Error Err2) mutable -> Error {
  230. OWC(std::move(Err2));
  231. return Error::success();
  232. },
  233. Ws)) {
  234. Parent.reportError(std::move(Err));
  235. Parent.getEndpoint().abandonPendingResponses();
  236. }
  237. }
  238. OrcRPCTPCImplT &Parent;
  239. };
  240. // TargetProcessControl for a process connected via an ORC RPC Endpoint.
  241. template <typename RPCEndpointT>
  242. class OrcRPCTargetProcessControlBase : public TargetProcessControl {
  243. public:
  244. using ErrorReporter = unique_function<void(Error)>;
  245. using OnCloseConnectionFunction = unique_function<Error(Error)>;
  246. OrcRPCTargetProcessControlBase(std::shared_ptr<SymbolStringPool> SSP,
  247. RPCEndpointT &EP, ErrorReporter ReportError)
  248. : TargetProcessControl(std::move(SSP)),
  249. ReportError(std::move(ReportError)), EP(EP) {}
  250. void reportError(Error Err) { ReportError(std::move(Err)); }
  251. RPCEndpointT &getEndpoint() { return EP; }
  252. Expected<tpctypes::DylibHandle> loadDylib(const char *DylibPath) override {
  253. DEBUG_WITH_TYPE("orc", {
  254. dbgs() << "Loading dylib \"" << (DylibPath ? DylibPath : "") << "\" ";
  255. if (!DylibPath)
  256. dbgs() << "(process symbols)";
  257. dbgs() << "\n";
  258. });
  259. if (!DylibPath)
  260. DylibPath = "";
  261. auto H = EP.template callB<orcrpctpc::LoadDylib>(DylibPath);
  262. DEBUG_WITH_TYPE("orc", {
  263. if (H)
  264. dbgs() << " got handle " << formatv("{0:x16}", *H) << "\n";
  265. else
  266. dbgs() << " error, unable to load\n";
  267. });
  268. return H;
  269. }
  270. Expected<std::vector<tpctypes::LookupResult>>
  271. lookupSymbols(ArrayRef<LookupRequest> Request) override {
  272. std::vector<orcrpctpc::RemoteLookupRequest> RR;
  273. for (auto &E : Request) {
  274. RR.push_back({});
  275. RR.back().first = E.Handle;
  276. for (auto &KV : E.Symbols)
  277. RR.back().second.push_back(
  278. {(*KV.first).str(),
  279. KV.second == SymbolLookupFlags::WeaklyReferencedSymbol});
  280. }
  281. DEBUG_WITH_TYPE("orc", {
  282. dbgs() << "Compound lookup:\n";
  283. for (auto &R : Request) {
  284. dbgs() << " In " << formatv("{0:x16}", R.Handle) << ": {";
  285. bool First = true;
  286. for (auto &KV : R.Symbols) {
  287. dbgs() << (First ? "" : ",") << " " << *KV.first;
  288. First = false;
  289. }
  290. dbgs() << " }\n";
  291. }
  292. });
  293. return EP.template callB<orcrpctpc::LookupSymbols>(RR);
  294. }
  295. Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr,
  296. ArrayRef<std::string> Args) override {
  297. DEBUG_WITH_TYPE("orc", {
  298. dbgs() << "Running as main: " << formatv("{0:x16}", MainFnAddr)
  299. << ", args = [";
  300. for (unsigned I = 0; I != Args.size(); ++I)
  301. dbgs() << (I ? "," : "") << " \"" << Args[I] << "\"";
  302. dbgs() << "]\n";
  303. });
  304. auto Result = EP.template callB<orcrpctpc::RunMain>(MainFnAddr, Args);
  305. DEBUG_WITH_TYPE("orc", {
  306. dbgs() << " call to " << formatv("{0:x16}", MainFnAddr);
  307. if (Result)
  308. dbgs() << " returned result " << *Result << "\n";
  309. else
  310. dbgs() << " failed\n";
  311. });
  312. return Result;
  313. }
  314. Expected<tpctypes::WrapperFunctionResult>
  315. runWrapper(JITTargetAddress WrapperFnAddr,
  316. ArrayRef<uint8_t> ArgBuffer) override {
  317. DEBUG_WITH_TYPE("orc", {
  318. dbgs() << "Running as wrapper function "
  319. << formatv("{0:x16}", WrapperFnAddr) << " with "
  320. << formatv("{0:x16}", ArgBuffer.size()) << " argument buffer\n";
  321. });
  322. auto Result =
  323. EP.template callB<orcrpctpc::RunWrapper>(WrapperFnAddr, ArgBuffer);
  324. // dbgs() << "Returned from runWrapper...\n";
  325. return Result;
  326. }
  327. Error closeConnection(OnCloseConnectionFunction OnCloseConnection) {
  328. DEBUG_WITH_TYPE("orc", dbgs() << "Closing connection to remote\n");
  329. return EP.template callAsync<orcrpctpc::CloseConnection>(
  330. std::move(OnCloseConnection));
  331. }
  332. Error closeConnectionAndWait() {
  333. std::promise<MSVCPError> P;
  334. auto F = P.get_future();
  335. if (auto Err = closeConnection([&](Error Err2) -> Error {
  336. P.set_value(std::move(Err2));
  337. return Error::success();
  338. })) {
  339. EP.abandonAllPendingResponses();
  340. return joinErrors(std::move(Err), F.get());
  341. }
  342. return F.get();
  343. }
  344. protected:
  345. /// Subclasses must call this during construction to initialize the
  346. /// TargetTriple and PageSize members.
  347. Error initializeORCRPCTPCBase() {
  348. if (auto TripleOrErr = EP.template callB<orcrpctpc::GetTargetTriple>())
  349. TargetTriple = Triple(*TripleOrErr);
  350. else
  351. return TripleOrErr.takeError();
  352. if (auto PageSizeOrErr = EP.template callB<orcrpctpc::GetPageSize>())
  353. PageSize = *PageSizeOrErr;
  354. else
  355. return PageSizeOrErr.takeError();
  356. return Error::success();
  357. }
  358. private:
  359. ErrorReporter ReportError;
  360. RPCEndpointT &EP;
  361. };
  362. } // end namespace orc
  363. } // end namespace llvm
  364. #endif // LLVM_EXECUTIONENGINE_ORC_ORCRPCTARGETPROCESSCONTROL_H
  365. #ifdef __GNUC__
  366. #pragma GCC diagnostic pop
  367. #endif