ib_collective.h 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. #pragma once
  2. #include <library/cpp/binsaver/bin_saver.h>
  3. namespace NNetliba {
  4. struct TCollectiveInit {
  5. int Size, Rank;
  6. SAVELOAD(Size, Rank);
  7. };
  8. struct TCollectiveLinkSet {
  9. struct TLinkInfo {
  10. int QPN, PSN;
  11. };
  12. TVector<int> Hosts; // host LIDs
  13. TVector<TVector<int>> HostGroup; // HostGroup[0] - switchId, HostGroup[1] - hostId within the switch
  14. TVector<TLinkInfo> Links;
  15. SAVELOAD(Hosts, HostGroup, Links);
  16. };
  17. struct IAllDataSync: public TThrRefBase {
  18. virtual void* GetRawData() = 0;
  19. virtual size_t GetRawDataSize() = 0;
  20. virtual void Sync() = 0;
  21. virtual void Flush() = 0;
  22. template <class T>
  23. T* GetData() {
  24. return static_cast<T*>(GetRawData());
  25. }
  26. template <class T>
  27. size_t GetSize() {
  28. return GetRawDataSize() / sizeof(T);
  29. }
  30. };
  31. struct IAllReduce: public IAllDataSync {
  32. virtual bool Resize(size_t dataSize) = 0;
  33. };
  34. struct IAllGather: public IAllDataSync {
  35. virtual bool Resize(const TVector<size_t>& szPerRank) = 0;
  36. };
  37. struct IReduceOp: public TThrRefBase {
  38. virtual void Reduce(void* dst, const void* add, size_t dataSize) const = 0;
  39. };
  40. template <class T, class TElem = typename T::TElem>
  41. class TAllReduceOp: public IReduceOp {
  42. T Op;
  43. public:
  44. TAllReduceOp() {
  45. }
  46. TAllReduceOp(T op)
  47. : Op(op)
  48. {
  49. }
  50. void Reduce(void* dst, const void* add, size_t dataSize) const override {
  51. TElem* dstPtr = (TElem*)(dst);
  52. const TElem* addPtr = (const TElem*)(add);
  53. TElem* finPtr = (TElem*)(((char*)dst) + dataSize);
  54. while (dstPtr < finPtr) {
  55. Op(dstPtr, *addPtr);
  56. ++dstPtr;
  57. ++addPtr;
  58. }
  59. }
  60. };
  61. // table of active peers for micro send/recv
  62. class TIBMicroPeerTable {
  63. TVector<ui8> Table; // == 0 means accept mesages from this qpn
  64. int TableSize;
  65. bool ParsePending;
  66. public:
  67. TIBMicroPeerTable()
  68. : ParsePending(true)
  69. {
  70. Init(0);
  71. }
  72. void Init(int tableSizeLog) {
  73. TableSize = 1 << tableSizeLog;
  74. ParsePending = true;
  75. Table.resize(0);
  76. Table.resize(TableSize, 0);
  77. }
  78. bool NeedParsePending() const {
  79. return ParsePending;
  80. }
  81. void StopParsePending() {
  82. ParsePending = false;
  83. }
  84. void StopQPN(int qpn, ui8 mask) {
  85. Y_ASSERT((Table[qpn & (TableSize - 1)] & mask) == 0);
  86. Table[qpn & (TableSize - 1)] |= mask;
  87. }
  88. void StopQPN(int qpn) {
  89. Y_ASSERT(Table[qpn & (TableSize - 1)] == 0);
  90. Table[qpn & (TableSize - 1)] = 0xff;
  91. }
  92. bool NeedQPN(int qpn) const {
  93. return Table[qpn & (TableSize - 1)] != 0xff;
  94. }
  95. };
  96. struct IIBCollective;
  97. class TIBCollective;
  98. class TIBRecvMicro: public TNonCopyable {
  99. TIBCollective& IB;
  100. ui64 Id;
  101. int QPN;
  102. void* Data;
  103. public:
  104. TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable);
  105. ~TIBRecvMicro();
  106. void* GetRawData() const {
  107. return Data;
  108. }
  109. template <class T>
  110. T* GetData() {
  111. return static_cast<T*>(GetRawData());
  112. }
  113. int GetQPN() const {
  114. return QPN;
  115. }
  116. };
  117. struct IIBCollective: public TThrRefBase {
  118. struct TRdmaRequest {
  119. int DstRank;
  120. ui64 RemoteAddr, LocalAddr;
  121. ui32 RemoteKey, LocalKey;
  122. ui64 Size;
  123. };
  124. virtual int GetRank() = 0;
  125. virtual int GetSize() = 0;
  126. virtual int GetGroupTypeCount() = 0;
  127. virtual int GetQPN(int rank) = 0;
  128. virtual bool TryWaitCompletion() = 0;
  129. virtual void WaitCompletion() = 0;
  130. virtual void Start(const TCollectiveLinkSet& links) = 0;
  131. virtual IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) = 0;
  132. virtual IAllGather* CreateAllGather(size_t szPerRank) = 0;
  133. virtual IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) = 0;
  134. virtual void RunBWTest(int groupType, int delta, int* targetRank, float* res) = 0;
  135. virtual void Fence() = 0;
  136. virtual void InitPeerTable(TIBMicroPeerTable* res) = 0;
  137. virtual bool TrySendMicro(int dstRank, const void* data, int dataSize) = 0;
  138. virtual void RdmaWrite(const TVector<TRdmaRequest>& reqs) = 0;
  139. };
  140. IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks);
  141. }