ib_collective.cpp 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317
  1. #include "stdafx.h"
  2. #include "ib_collective.h"
  3. #include "ib_mem.h"
  4. #include "ib_buffers.h"
  5. #include "ib_low.h"
  6. #include "udp_http.h"
  7. #include "udp_address.h"
  8. #include <util/generic/deque.h>
  9. #include <util/system/hp_timer.h>
  10. namespace NNetliba {
// Service level used by the control/handshake QPs between peers.
const int COL_SERVICE_LEVEL = 2;
// First (base) service level for bulk data QPs; per-transfer SLs are added on top.
const int COL_DATA_SERVICE_LEVEL = 2; // base level
// Number of distinct data service levels available above the base.
const int COL_DATA_SERVICE_LEVEL_COUNT = 6; // level count
// Cap on in-flight send requests per peer QP counter bucket (see SendCountTable).
const int MAX_REQS_PER_PEER = 32;
// Total RDMA limit — usage not visible in this chunk; presumably enforced elsewhere.
const int MAX_TOTAL_RDMA = 20;
// Size of the per-QPN in-flight counter tables; used as a bit mask, so it
const int SEND_COUNT_TABLE_SIZE = 1 << 12; // must be power of 2
  17. struct TMergeRecord {
  18. struct TTransfer {
  19. int DstRank;
  20. int SL;
  21. int RangeBeg, RangeFin;
  22. int Id;
  23. TTransfer()
  24. : DstRank(-1)
  25. , SL(0)
  26. , RangeBeg(0)
  27. , RangeFin(0)
  28. , Id(0)
  29. {
  30. }
  31. TTransfer(int dstRank, int sl, int rangeBeg, int rangeFin, int id)
  32. : DstRank(dstRank)
  33. , SL(sl)
  34. , RangeBeg(rangeBeg)
  35. , RangeFin(rangeFin)
  36. , Id(id)
  37. {
  38. }
  39. };
  40. struct TInTransfer {
  41. int SrcRank;
  42. int SL;
  43. TInTransfer()
  44. : SrcRank(-1)
  45. , SL(0)
  46. {
  47. }
  48. TInTransfer(int srcRank, int sl)
  49. : SrcRank(srcRank)
  50. , SL(sl)
  51. {
  52. }
  53. };
  54. TVector<TTransfer> OutList;
  55. TVector<TInTransfer> InList;
  56. ui64 RecvMask;
  57. TMergeRecord()
  58. : RecvMask(0)
  59. {
  60. }
  61. };
  62. struct TMergeIteration {
  63. TVector<TMergeRecord> Ops;
  64. void Init(int colSize) {
  65. Ops.resize(colSize);
  66. }
  67. void Transfer(int srcRank, int dstRank, int sl, int rangeBeg, int rangeFin, int id) {
  68. Y_ABORT_UNLESS(id < 64, "recv mask overflow");
  69. Ops[srcRank].OutList.push_back(TMergeRecord::TTransfer(dstRank, sl, rangeBeg, rangeFin, id));
  70. Ops[dstRank].InList.push_back(TMergeRecord::TInTransfer(srcRank, sl));
  71. Ops[dstRank].RecvMask |= ui64(1) << id;
  72. }
  73. };
  74. struct TMergePlan {
  75. TVector<TMergeIteration> Iterations;
  76. TVector<int> RankReceiveCount;
  77. int ColSize;
  78. int MaxRankReceiveCount;
  79. TMergePlan()
  80. : ColSize(0)
  81. , MaxRankReceiveCount(0)
  82. {
  83. }
  84. void Init(int colSize) {
  85. Iterations.resize(0);
  86. RankReceiveCount.resize(0);
  87. RankReceiveCount.resize(colSize, 0);
  88. ColSize = colSize;
  89. }
  90. void Transfer(int iter, int srcRank, int dstRank, int sl, int rangeBeg, int rangeFin) {
  91. while (iter >= Iterations.ysize()) {
  92. TMergeIteration& res = Iterations.emplace_back();
  93. res.Init(ColSize);
  94. }
  95. int id = RankReceiveCount[dstRank]++;
  96. MaxRankReceiveCount = Max(MaxRankReceiveCount, id + 1);
  97. Y_ASSERT(id < 64);
  98. Iterations[iter].Transfer(srcRank, dstRank, sl, rangeBeg, rangeFin, id);
  99. }
  100. };
  101. struct TSRTransfer {
  102. int SrcRank, DstRank;
  103. int RangeBeg, RangeFin;
  104. TSRTransfer() {
  105. Zero(*this);
  106. }
  107. TSRTransfer(int srcRank, int dstRank, int rangeBeg, int rangeFin)
  108. : SrcRank(srcRank)
  109. , DstRank(dstRank)
  110. , RangeBeg(rangeBeg)
  111. , RangeFin(rangeFin)
  112. {
  113. }
  114. };
// Recursively splits [beg;fin) into halves and records pairwise exchanges into
// (*res)[iter] so that, when the recorded iterations are later replayed in
// reverse order, every rank in the range ends up holding the whole range.
// Returns the first unused iteration index (the recursion depth from here).
static int SplitRange(THashMap<int, TVector<TSRTransfer>>* res, int iter, int beg, int fin) {
    int mid = (beg + fin + 1) / 2; // upper half is the smaller on odd sizes
    if (mid == fin) {
        // single-element range: nothing to exchange
        return iter;
    }
    // pair ranks across the two halves; each pair swaps its halves' data
    for (int i = 0; i < fin - mid; ++i) {
        (*res)[iter].push_back(TSRTransfer(beg + i, mid + i, beg, mid));
        (*res)[iter].push_back(TSRTransfer(mid + i, beg + i, mid, fin));
    }
    if (fin - mid < mid - beg) {
        // odd split leaves [mid - 1] unpaired:
        // [mid - 1] did not receive [mid;fin)
        (*res)[iter].push_back(TSRTransfer(mid, mid - 1, mid, fin));
    }
    int rv1 = SplitRange(res, iter + 1, beg, mid);
    int rv2 = SplitRange(res, iter + 1, mid, fin);
    return Max(rv1, rv2);
}
  132. static void CreatePow2Merge(TMergePlan* plan, int colSize) {
  133. // finally everybody has full range [0;ColSize)
  134. // construct plan recursively, on each iteration split some range
  135. plan->Init(colSize);
  136. THashMap<int, TVector<TSRTransfer>> allTransfers;
  137. int maxIter = SplitRange(&allTransfers, 0, 0, colSize);
  138. for (int iter = 0; iter < maxIter; ++iter) {
  139. const TVector<TSRTransfer>& arr = allTransfers[maxIter - iter - 1]; // reverse order
  140. for (int i = 0; i < arr.ysize(); ++i) {
  141. const TSRTransfer& sr = arr[i];
  142. plan->Transfer(iter, sr.SrcRank, sr.DstRank, 0, sr.RangeBeg, sr.RangeFin);
  143. }
  144. }
  145. }
  146. struct TCoverInterval {
  147. int Beg, Fin; // [Beg;Fin)
  148. TCoverInterval()
  149. : Beg(0)
  150. , Fin(0)
  151. {
  152. }
  153. TCoverInterval(int b, int f)
  154. : Beg(b)
  155. , Fin(f)
  156. {
  157. }
  158. };
// Exchange strategy used by AllToAll() within one group.
enum EAllToAllMode {
    AA_POW2,       // recursive doubling: delta doubles each iteration
    AA_CIRCLE,     // ring: each iteration forwards one range to the next neighbour
    AA_STAR,       // every rank sends its own range directly to all others, one iteration
    AA_POW2_MERGE, // replays a CreatePow2Merge() schedule with placeholder ranges
};
// Schedules an in-group all-to-all so that afterwards every member of myGroup
// holds the union of the group's data ranges.  `iter` is the first iteration
// index to use, `sl` the service level assigned to the scheduled transfers,
// `cover` maps host -> currently held [Beg;Fin) range and is updated in place.
// Precondition: group members hold adjacent ranges in group order.
// Returns the first free iteration index after the scheduled exchange.
static int AllToAll(TMergePlan* plan, int iter, int sl, EAllToAllMode mode, const TVector<int>& myGroup, TVector<TCoverInterval>* cover) {
    TVector<TCoverInterval>& hostCoverage = *cover;
    int groupSize = myGroup.ysize();
    // verify ranges are contiguous across the group, in order
    for (int k = 1; k < groupSize; ++k) {
        int h1 = myGroup[k - 1];
        int h2 = myGroup[k];
        Y_ABORT_UNLESS(hostCoverage[h1].Fin == hostCoverage[h2].Beg, "Invalid host order in CreateGroupMerge()");
    }
    switch (mode) {
        case AA_POW2: {
            // recursive doubling: each pass, every rank sends the window of the
            // last `sz` group positions ending at itself to the rank `delta` ahead
            for (int delta = 1; delta < groupSize; delta *= 2) {
                int sz = Min(delta, groupSize - delta);
                for (int offset = 0; offset < groupSize; ++offset) {
                    int srcRank = myGroup[offset];
                    int dstRank = myGroup[(offset + delta) % groupSize];
                    int start = offset + 1 - sz;
                    int finish = offset + 1;
                    if (start < 0) {
                        // window wraps around the group: send it as two pieces
                        // [start; myGroup.size())
                        int dataBeg = hostCoverage[myGroup[start + groupSize]].Beg;
                        int dataFin = hostCoverage[myGroup.back()].Fin;
                        plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                        // [0; finish)
                        dataBeg = hostCoverage[myGroup[0]].Beg;
                        dataFin = hostCoverage[myGroup[finish - 1]].Fin;
                        plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                    } else {
                        // [start;finish)
                        int dataBeg = hostCoverage[myGroup[start]].Beg;
                        int dataFin = hostCoverage[myGroup[finish - 1]].Fin;
                        plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                    }
                }
                ++iter;
            }
        } break;
        case AA_CIRCLE: {
            // ring: on pass dataDelta, each rank forwards the range that
            // originated dataDelta-1 positions behind it to its next neighbour
            for (int dataDelta = 1; dataDelta < groupSize; ++dataDelta) {
                for (int offset = 0; offset < groupSize; ++offset) {
                    int srcRank = myGroup[offset];
                    int dstRank = myGroup[(offset + 1) % groupSize];
                    int dataRank = myGroup[(offset + 1 - dataDelta + groupSize) % groupSize];
                    int dataBeg = hostCoverage[dataRank].Beg;
                    int dataFin = hostCoverage[dataRank].Fin;
                    plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                }
                ++iter;
            }
        } break;
        case AA_STAR: {
            // direct fan-out: every rank sends its own range to every other
            // rank, all packed into a single iteration
            for (int offset = 0; offset < groupSize; ++offset) {
                for (int delta = 1; delta < groupSize; ++delta) {
                    int srcRank = myGroup[offset];
                    int dstRank = myGroup[(offset + delta) % groupSize];
                    int dataRank = myGroup[offset];
                    int dataBeg = hostCoverage[dataRank].Beg;
                    int dataFin = hostCoverage[dataRank].Fin;
                    plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                }
            }
            ++iter;
        } break;
        case AA_POW2_MERGE: {
            // replay a pow2 merge schedule; ranges are placeholder [0;1) —
            // real offsets are presumably computed later (cf. TAllReduce::Resize)
            TMergePlan pp;
            CreatePow2Merge(&pp, groupSize);
            for (int z = 0; z < pp.Iterations.ysize(); ++z) {
                const TMergeIteration& mm = pp.Iterations[z];
                for (int src = 0; src < mm.Ops.ysize(); ++src) {
                    const TMergeRecord& mr = mm.Ops[src];
                    int srcRank = myGroup[src];
                    for (int i = 0; i < mr.OutList.ysize(); ++i) {
                        int dstRank = myGroup[mr.OutList[i].DstRank];
                        plan->Transfer(iter, srcRank, dstRank, sl, 0, 1);
                    }
                }
                ++iter;
            }
        } break;
        default:
            Y_ASSERT(0);
            break;
    }
    {
        // after the exchange every member covers the whole group's range
        TCoverInterval cc(hostCoverage[myGroup[0]].Beg, hostCoverage[myGroup.back()].Fin);
        for (int k = 0; k < groupSize; ++k) {
            hostCoverage[myGroup[k]] = cc;
        }
    }
    return iter;
}
// fully populated matrix
// Builds a hierarchical merge plan from group membership.
// hostGroup[groupType][hostId] is the host's group id at that grouping level;
// levels are processed from the last (innermost) type outwards, running
// AllToAll() inside each group of each level.
static void CreateGroupMerge(TMergePlan* plan, EAllToAllMode mode, const TVector<TVector<int>>& hostGroup) {
    int hostCount = hostGroup[0].ysize();
    int groupTypeCount = hostGroup.ysize();
    plan->Init(hostCount);
    // group count per type = max observed id + 1
    TVector<int> gcount;
    gcount.resize(groupTypeCount, 0);
    for (int hostId = 0; hostId < hostCount; ++hostId) {
        for (int groupType = 0; groupType < groupTypeCount; ++groupType) {
            int val = hostGroup[groupType][hostId];
            gcount[groupType] = Max(gcount[groupType], val + 1);
        }
    }
    // validate odometer-like host order: between adjacent hosts a (possibly
    // empty) prefix of group ids wraps to 0, the next id increments by one,
    // and all remaining ids stay equal
    for (int hostId = 1; hostId < hostCount; ++hostId) {
        bool isIncrement = true;
        for (int groupType = 0; groupType < groupTypeCount; ++groupType) {
            int prev = hostGroup[groupType][hostId - 1];
            int cur = hostGroup[groupType][hostId];
            if (isIncrement) {
                if (cur == prev + 1) {
                    isIncrement = false;
                } else {
                    Y_ABORT_UNLESS(cur == 0, "ib_hosts, wrapped to non-zero");
                    Y_ABORT_UNLESS(prev == gcount[groupType] - 1, "ib_hosts, structure is irregular");
                    isIncrement = true;
                }
            } else {
                Y_ABORT_UNLESS(prev == cur, "ib_hosts, structure is irregular");
            }
        }
    }
    // initially every host covers only its own rank
    TVector<TCoverInterval> hostCoverage;
    for (int i = 0; i < hostCount; ++i) {
        hostCoverage.push_back(TCoverInterval(i, i + 1));
    }
    int baseIter = 0;
    for (int groupType = hostGroup.ysize() - 1; groupType >= 0; --groupType) {
        Y_ASSERT(hostGroup[groupType].ysize() == hostCount);
        // bucket ranks by group id at this level
        TVector<TVector<int>> hh;
        hh.resize(gcount[groupType]);
        for (int rank = 0; rank < hostGroup[groupType].ysize(); ++rank) {
            int groupId = hostGroup[groupType][rank];
            hh[groupId].push_back(rank);
        }
        // all groups at one level must consume the same number of iterations
        int newIter = 0;
        for (int groupId = 0; groupId < hh.ysize(); ++groupId) {
            int nn = AllToAll(plan, baseIter, 0, mode, hh[groupId], &hostCoverage); // seems to be fastest
            if (newIter == 0) {
                newIter = nn;
            } else {
                Y_ABORT_UNLESS(newIter == nn, "groups should be of same size");
            }
        }
        baseIter = newIter;
    }
    //printf("%d iterations symmetrical plan\n", baseIter);
}
  312. //////////////////////////////////////////////////////////////////////////
// Double-buffered RDMA engine shared by all-gather and all-reduce.
// Executes a precomputed schedule (Iterations): each iteration posts
// RDMA-write-with-immediate transfers into the peers' buffers, then waits
// until all expected immediates (bits of RecvMask) have arrived, then applies
// local reductions (all-reduce only).
struct TAllDataSync {
    enum {
        WR_COUNT = 64 * 2 // receive work requests kept posted on the SRQ
    };
    int CurrentBuffer; // active half of the double buffer (0 or 1)
    TIntrusivePtr<TIBMemBlock> MemBlock[2];
    TIntrusivePtr<TComplectionQueue> CQ;
    TIntrusivePtr<TSharedReceiveQueue> SRQ;
    TIntrusivePtr<TIBMemBlock> FakeRecvMem; // dummy target: payload arrives via RDMA, recvs only catch immediates
    size_t DataSize, BufSize;          // current payload size / buffer capacity
    size_t CurrentOffset, ReadyOffset; // offset of the user-visible data in the buffer
    bool WasFlushed;                   // Flush() must precede every Sync()
    int ActiveRDMACount;               // RDMA writes posted but not yet completed
    ui64 FutureRecvMask;               // immediates that arrived one Sync() early
    TIntrusivePtr<IReduceOp> ReduceOp; // merge operation; null for all-gather

    // Peer's registered memory block: base address + remote key.
    struct TBlockInfo {
        ui64 Addr;
        ui32 Key;
    };
    // One scheduled outgoing RDMA write.
    struct TSend {
        TBlockInfo RemoteBlocks[2]; // peer's two buffers, indexed by CurrentBuffer
        TIntrusivePtr<TRCQueuePair> QP;
        size_t SrcOffset;
        size_t DstOffset;
        size_t Length;
        ui32 ImmData; // receiver's RecvMask bit index, sent as immediate data
        int DstRank;
        union {
            struct {
                int RangeBeg, RangeFin; // rank range (all-gather); offsets derived on Resize()
            } Gather;
            struct {
                int SrcIndex, DstIndex; // slot indices (all-reduce); offsets derived on Resize()
            } Reduce;
        };
    };
    // One expected incoming transfer.
    struct TRecv {
        TIntrusivePtr<TRCQueuePair> QP;
        int SrcRank;
    };
    // Local merge of one buffer slot into another after data has arrived.
    struct TReduce {
        size_t DstOffset, SrcOffset;
        int DstIndex, SrcIndex;
    };
    struct TIteration {
        TVector<TSend> OutList;
        TVector<TRecv> InList;
        TVector<TReduce> ReduceList;
        ui64 RecvMask; // bits that must all arrive before the iteration completes
    };
    TVector<TIteration> Iterations;

public:
    // Pointer to the caller-visible data area inside the active buffer.
    void* GetRawData() {
        char* myData = (char*)MemBlock[CurrentBuffer]->GetData();
        return myData + CurrentOffset;
    }
    size_t GetRawDataSize() {
        return DataSize;
    }
    // Re-arms one receive WR on the shared receive queue.
    void PostRecv() {
        SRQ->PostReceive(FakeRecvMem->GetMemRegion(), 0, FakeRecvMem->GetData(), FakeRecvMem->GetSize());
    }
    // Runs the whole schedule; on return the synced result is at GetRawData().
    void Sync() {
        Y_ASSERT(WasFlushed && "Have to call Flush() before data fill & Sync()");
        char* myData = (char*)MemBlock[CurrentBuffer]->GetData();
        ui64 recvMask = FutureRecvMask; // credit transfers that already arrived early
        FutureRecvMask = 0;
        int recvDebt = 0; // receives consumed but not yet re-posted
        for (int z = 0; z < Iterations.ysize(); ++z) {
            const TIteration& iter = Iterations[z];
            // post all outgoing writes of this iteration
            for (int k = 0; k < iter.OutList.ysize(); ++k) {
                const TSend& ss = iter.OutList[k];
                const TBlockInfo& remoteBlk = ss.RemoteBlocks[CurrentBuffer];
                ss.QP->PostRDMAWriteImm(remoteBlk.Addr + ss.DstOffset, remoteBlk.Key, ss.ImmData,
                                        MemBlock[CurrentBuffer]->GetMemRegion(), 0, myData + ss.SrcOffset, ss.Length);
                ++ActiveRDMACount;
                //printf("-> %d, imm %d (%" PRId64 " bytes)\n", ss.DstRank, ss.ImmData, ss.Length);
                //printf("send %d\n", ss.SrcOffset);
            }
            // wait until every immediate expected in this iteration has arrived
            ibv_wc wc;
            while ((recvMask & iter.RecvMask) != iter.RecvMask) {
                int rv = CQ->Poll(&wc, 1);
                if (rv > 0) {
                    Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "AllGather::Sync fail, status %d", (int)wc.status);
                    if (wc.opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
                        //printf("Got %d\n", wc.imm_data);
                        ++recvDebt;
                        ui64 newBit = ui64(1) << wc.imm_data;
                        if (recvMask & newBit) {
                            // duplicate bit: the peer has raced ahead into its next Sync()
                            Y_ABORT_UNLESS((FutureRecvMask & newBit) == 0, "data from 2 Sync() ahead is impossible");
                            FutureRecvMask |= newBit;
                        } else {
                            recvMask |= newBit;
                        }
                    } else if (wc.opcode == IBV_WC_RDMA_WRITE) {
                        --ActiveRDMACount;
                    } else {
                        Y_ASSERT(0);
                    }
                } else {
                    // CQ idle: use the gap to re-post consumed receives
                    if (recvDebt > 0) {
                        PostRecv();
                        --recvDebt;
                    }
                }
            }
            // fold freshly received slots into the result (all-reduce only)
            for (int k = 0; k < iter.ReduceList.ysize(); ++k) {
                const TReduce& rr = iter.ReduceList[k];
                ReduceOp->Reduce(myData + rr.DstOffset, myData + rr.SrcOffset, DataSize);
                //printf("Merge %d -> %d (%d bytes)\n", rr.SrcOffset, rr.DstOffset, DataSize);
            }
            //printf("Iteration %d done\n", z);
        }
        while (recvDebt > 0) {
            PostRecv();
            --recvDebt;
        }
        CurrentOffset = ReadyOffset;
        WasFlushed = false;
        //printf("new cur offset %g\n", (double)CurrentOffset);
        //printf("Sync complete\n");
    }
    // Switches to the other buffer; must be called before filling data for Sync().
    void Flush() {
        Y_ASSERT(!WasFlushed);
        CurrentBuffer = 1 - CurrentBuffer;
        CurrentOffset = 0;
        WasFlushed = true;
    }

public:
    // memPool == nullptr falls back to standalone allocations/queues —
    // presumably a local/test mode; only the allocation path differs here.
    TAllDataSync(size_t bufSize, TPtrArg<TIBMemPool> memPool, TPtrArg<IReduceOp> reduceOp)
        : CurrentBuffer(0)
        , DataSize(0)
        , BufSize(bufSize)
        , CurrentOffset(0)
        , ReadyOffset(0)
        , WasFlushed(false)
        , ActiveRDMACount(0)
        , FutureRecvMask(0)
        , ReduceOp(reduceOp)
    {
        if (memPool) {
            MemBlock[0] = memPool->Alloc(BufSize);
            MemBlock[1] = memPool->Alloc(BufSize);
            CQ = new TComplectionQueue(memPool->GetIBContext(), WR_COUNT);
            SRQ = new TSharedReceiveQueue(memPool->GetIBContext(), WR_COUNT);
            FakeRecvMem = memPool->Alloc(4096);
        } else {
            MemBlock[0] = new TIBMemBlock(BufSize);
            MemBlock[1] = new TIBMemBlock(BufSize);
            CQ = new TComplectionQueue(nullptr, WR_COUNT);
            SRQ = new TSharedReceiveQueue(nullptr, WR_COUNT);
            FakeRecvMem = new TIBMemBlock(4096);
        }
        // keep the SRQ fully armed from the start
        for (int i = 0; i < WR_COUNT; ++i) {
            PostRecv();
        }
    }
    // Drain outstanding RDMA writes so buffers are not released while in use.
    ~TAllDataSync() {
        while (ActiveRDMACount > 0) {
            ibv_wc wc;
            int rv = CQ->Poll(&wc, 1);
            if (rv > 0) {
                if (wc.opcode == IBV_WC_RDMA_WRITE) {
                    --ActiveRDMACount;
                } else {
                    Y_ASSERT(0);
                }
            }
        }
    }
};
// All-reduce on top of TAllDataSync: the payload is laid out in 64-byte
// aligned slots; Resize() recomputes all schedule offsets from slot indices.
// BufSizeMult / ReadyOffsetMult are presumably configured by TIBCollective
// (friend below) — their setup is not visible in this chunk.
class TAllReduce: public IAllReduce {
    TAllDataSync DataSync;
    size_t BufSizeMult;     // number of slots the buffer must hold
    size_t ReadyOffsetMult; // slot index where the result is read after Sync()
public:
    TAllReduce(size_t bufSize, TPtrArg<TIBMemPool> memPool, TPtrArg<IReduceOp> reduceOp)
        : DataSync(bufSize, memPool, reduceOp)
        , BufSizeMult(0)
        , ReadyOffsetMult(0)
    {
    }
    TAllDataSync& GetDataSync() {
        return DataSync;
    }
    void* GetRawData() override {
        return DataSync.GetRawData();
    }
    size_t GetRawDataSize() override {
        return DataSync.GetRawDataSize();
    }
    void Sync() override {
        DataSync.Sync();
    }
    void Flush() override {
        DataSync.Flush();
    }
    // Recomputes offsets/lengths for a new element size; returns false when the
    // required buffer would exceed the preallocated capacity.
    bool Resize(size_t dataSize) override {
        size_t repSize = (dataSize + 63) & (~63ull); // slot size, 64-byte aligned
        size_t bufSize = repSize * BufSizeMult;
        if (bufSize > DataSync.BufSize) {
            return false;
        }
        for (int z = 0; z < DataSync.Iterations.ysize(); ++z) {
            TAllDataSync::TIteration& iter = DataSync.Iterations[z];
            for (int i = 0; i < iter.OutList.ysize(); ++i) {
                TAllDataSync::TSend& snd = iter.OutList[i];
                snd.Length = dataSize;
                snd.SrcOffset = snd.Reduce.SrcIndex * repSize;
                snd.DstOffset = snd.Reduce.DstIndex * repSize;
            }
            for (int i = 0; i < iter.ReduceList.ysize(); ++i) {
                TAllDataSync::TReduce& red = iter.ReduceList[i];
                red.SrcOffset = red.SrcIndex * repSize;
                red.DstOffset = red.DstIndex * repSize;
            }
        }
        DataSync.ReadyOffset = ReadyOffsetMult * repSize;
        DataSync.DataSize = dataSize;
        return true;
    }
    friend class TIBCollective;
};
// All-gather on top of TAllDataSync: rank r's chunk lives at offsets
// [offsets[r]; offsets[r+1]) of the shared buffer, identical on every host.
class TAllGather: public IAllGather {
    TAllDataSync DataSync;
    int ColSize; // number of participants
public:
    TAllGather(int colSize, size_t bufSize, TPtrArg<TIBMemPool> memPool)
        : DataSync(bufSize, memPool, nullptr)
        , ColSize(colSize)
    {
    }
    TAllDataSync& GetDataSync() {
        return DataSync;
    }
    void* GetRawData() override {
        return DataSync.GetRawData();
    }
    size_t GetRawDataSize() override {
        return DataSync.GetRawDataSize();
    }
    void Sync() override {
        DataSync.Sync();
    }
    void Flush() override {
        DataSync.Flush();
    }
    // Recomputes transfer offsets/lengths from the per-rank sizes; returns
    // false when the concatenated data would not fit into the buffer.
    bool Resize(const TVector<size_t>& szPerRank) override {
        Y_ABORT_UNLESS(szPerRank.ysize() == ColSize, "Invalid size array");
        // prefix sums: offsets[r] is where rank r's chunk starts
        TVector<size_t> offsets;
        offsets.push_back(0);
        for (int rank = 0; rank < ColSize; ++rank) {
            offsets.push_back(offsets.back() + szPerRank[rank]);
        }
        size_t dataSize = offsets.back();
        if (dataSize > DataSync.BufSize) {
            return false;
        }
        for (int z = 0; z < DataSync.Iterations.ysize(); ++z) {
            TAllDataSync::TIteration& iter = DataSync.Iterations[z];
            for (int i = 0; i < iter.OutList.ysize(); ++i) {
                TAllDataSync::TSend& snd = iter.OutList[i];
                int rangeBeg = snd.Gather.RangeBeg;
                int rangeFin = snd.Gather.RangeFin;
                snd.Length = offsets[rangeFin] - offsets[rangeBeg];
                snd.SrcOffset = offsets[rangeBeg];
                snd.DstOffset = snd.SrcOffset; // data keeps its global position everywhere
            }
        }
        DataSync.DataSize = dataSize;
        return true;
    }
};
  586. struct TIBAddr {
  587. int LID, SL;
  588. TIBAddr()
  589. : LID(0)
  590. , SL(0)
  591. {
  592. }
  593. TIBAddr(int lid, int sl)
  594. : LID(lid)
  595. , SL(sl)
  596. {
  597. }
  598. };
  599. inline bool operator==(const TIBAddr& a, const TIBAddr& b) {
  600. return a.LID == b.LID && a.SL == b.SL;
  601. }
  602. inline bool operator<(const TIBAddr& a, const TIBAddr& b) {
  603. if (a.LID == b.LID) {
  604. return a.SL < b.SL;
  605. }
  606. return a.LID < b.LID;
  607. }
  608. struct TIBAddrHash {
  609. int operator()(const TIBAddr& a) const {
  610. return a.LID + a.SL * 4254515;
  611. }
  612. };
  613. class TIBCollective: public IIBCollective {
  614. struct TPendingMessage {
  615. int QPN;
  616. ui64 WorkId;
  617. TPendingMessage() {
  618. Zero(*this);
  619. }
  620. TPendingMessage(int qpn, ui64 wid)
  621. : QPN(qpn)
  622. , WorkId(wid)
  623. {
  624. }
  625. };
    // Handshake payload exchanged over the control QP when creating a data
    // connection: our two buffer handles plus the new QP's identifiers.
    struct TBlockInform {
        TAllDataSync::TBlockInfo RemoteBlocks[2];
        int PSN, QPN;
    };
    // Established data connection to one peer.
    struct TPeerConnection {
        TAllDataSync::TBlockInfo RemoteBlocks[2];
        TIntrusivePtr<TRCQueuePair> QP;
    };
    // Remote RDMA target: address plus remote key. Presumably used for
    // bandwidth testing — its consumer is not visible in this chunk.
    struct TBWTest {
        ui64 Addr;
        ui32 RKey;
    };
    TIntrusivePtr<TIBPort> Port;
    TIntrusivePtr<TIBMemPool> MemPool;
    int ColSize, ColRank;                       // collective size and this host's rank
    TVector<int> Hosts;                         // host LIDs, indexed by rank
    TVector<TVector<int>> HostGroup;            // [groupType][rank] -> group id (see CreateGroupMerge)
    TVector<TIntrusivePtr<TRCQueuePair>> Peers; // control QP per rank
    TIntrusivePtr<TComplectionQueue> CQ;
    TIntrusivePtr<TIBBufferPool> BP;
    ui8 SendCountTable[SEND_COUNT_TABLE_SIZE];  // in-flight sends per QPN bucket (qpn & mask)
    ui8 RDMACountTable[SEND_COUNT_TABLE_SIZE];  // in-flight RDMA writes per QPN bucket
    TDeque<TPendingMessage> Pending;            // completions waiting for their consumer
    TMergePlan MergePlan, ReducePlan;           // schedules built in Start()
    int QPNTableSizeLog;
    // Retires a finished send/RDMA completion: releases the per-QP slot(s)
    // taken by Alloc*Slot() and returns the send buffer to the pool.
    void WriteCompleted(const ibv_wc& wc) {
        --SendCountTable[wc.qp_num & (SEND_COUNT_TABLE_SIZE - 1)];
        if (wc.opcode == IBV_WC_RDMA_WRITE) {
            --RDMACountTable[wc.qp_num & (SEND_COUNT_TABLE_SIZE - 1)];
        }
        BP->FreeBuf(wc.wr_id);
    }
    // Fetches the next message on a QP that `tbl` still expects (NeedQPN):
    // first replays matching entries from the Pending backlog, then polls the
    // CQ, stashing non-matching receives back into Pending.  Returns false
    // once the CQ runs dry without producing a match.
    bool GetMsg(ui64* resWorkId, int* resQPN, TIBMicroPeerTable* tbl) {
        if (tbl->NeedParsePending()) {
            for (TDeque<TPendingMessage>::iterator z = Pending.begin(); z != Pending.end(); ++z) {
                if (!tbl->NeedQPN(z->QPN)) {
                    continue;
                }
                *resWorkId = z->WorkId;
                *resQPN = z->QPN;
                Pending.erase(z);
                return true;
            }
            //printf("Stop parse pending\n");
            tbl->StopParsePending();
        }
        for (;;) {
            ibv_wc wc;
            int rv = CQ->Poll(&wc, 1);
            if (rv > 0) {
                Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
                if (wc.opcode & IBV_WC_RECV) {
                    BP->RequestPostRecv();
                    if (tbl->NeedQPN(wc.qp_num)) {
                        *resWorkId = wc.wr_id;
                        *resQPN = wc.qp_num;
                        return true;
                    } else {
                        // not wanted yet — keep it for a later consumer
                        Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
                        BP->PostRecv();
                    }
                } else {
                    WriteCompleted(wc);
                }
            } else {
                return false;
            }
        }
    }
    // Classifies one completion while waiting for sends to retire: receive
    // completions are queued into Pending (returns false), send/RDMA
    // completions are retired via WriteCompleted() (returns true).
    bool ProcessSendCompletion(const ibv_wc& wc) {
        Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
        if (wc.opcode & IBV_WC_RECV) {
            BP->RequestPostRecv();
            Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
            BP->PostRecv();
        } else {
            WriteCompleted(wc);
            return true;
        }
        return false;
    }
  707. void WaitCompletion(ibv_wc* res) {
  708. ibv_wc& wc = *res;
  709. for (;;) {
  710. int rv = CQ->Poll(&wc, 1);
  711. if (rv > 0 && ProcessSendCompletion(wc)) {
  712. break;
  713. }
  714. }
  715. }
  716. bool TryWaitCompletion() override {
  717. ibv_wc wc;
  718. for (;;) {
  719. int rv = CQ->Poll(&wc, 1);
  720. if (rv > 0) {
  721. if (ProcessSendCompletion(wc)) {
  722. return true;
  723. }
  724. } else {
  725. return false;
  726. }
  727. }
  728. }
    // Blocking wait that discards the completion details.
    void WaitCompletion() override {
        ibv_wc wc;
        WaitCompletion(&wc);
    }
    // Blocks until a receive completion for QP `qpn` is available and returns
    // its work id.  Checks the Pending backlog first; completions for other
    // QPs encountered while polling are stashed into Pending.
    ui64 WaitForMsg(int qpn) {
        for (TDeque<TPendingMessage>::iterator z = Pending.begin(); z != Pending.end(); ++z) {
            if (z->QPN == qpn) {
                ui64 workId = z->WorkId;
                Pending.erase(z);
                return workId;
            }
        }
        ibv_wc wc;
        for (;;) {
            int rv = CQ->Poll(&wc, 1);
            if (rv > 0) {
                Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
                if (wc.opcode & IBV_WC_RECV) {
                    BP->RequestPostRecv();
                    if ((int)wc.qp_num == qpn) {
                        return wc.wr_id;
                    } else {
                        // belongs to another QP — keep it for a later waiter
                        Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
                        BP->PostRecv();
                    }
                } else {
                    WriteCompleted(wc);
                }
            }
        }
    }
  760. bool AllocOperationSlot(TPtrArg<TRCQueuePair> qp) {
  761. int way = qp->GetQPN() & (SEND_COUNT_TABLE_SIZE - 1);
  762. if (SendCountTable[way] >= MAX_REQS_PER_PEER) {
  763. return false;
  764. }
  765. ++SendCountTable[way];
  766. return true;
  767. }
  768. bool AllocRDMAWriteSlot(TPtrArg<TRCQueuePair> qp) {
  769. int way = qp->GetQPN() & (SEND_COUNT_TABLE_SIZE - 1);
  770. if (SendCountTable[way] >= MAX_REQS_PER_PEER) {
  771. return false;
  772. }
  773. if (RDMACountTable[way] >= MAX_OUTSTANDING_RDMA) {
  774. return false;
  775. }
  776. ++SendCountTable[way];
  777. ++RDMACountTable[way];
  778. return true;
  779. }
  780. bool TryPostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) {
  781. if (AllocOperationSlot(qp)) {
  782. BP->PostSend(qp, data, len);
  783. return true;
  784. }
  785. return false;
  786. }
  787. void PostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) {
  788. while (!TryPostSend(qp, data, len)) {
  789. WaitCompletion();
  790. }
  791. }
    int GetRank() override {
        return ColRank;
    }
    int GetSize() override {
        return ColSize;
    }
    int GetGroupTypeCount() override {
        return HostGroup.ysize();
    }
    // QP number of the control connection to `rank`; asking for the local
    // rank is a caller bug (no loopback QP exists) and yields 0.
    int GetQPN(int rank) override {
        if (rank == ColRank) {
            Y_ASSERT(0 && "there is no qpn connected to localhost");
            return 0;
        }
        return Peers[rank]->GetQPN();
    }
    // Brings up the control QP to every peer listed in `links` and builds the
    // all-gather (MergePlan, AA_STAR) and all-reduce (ReducePlan,
    // AA_POW2_MERGE) schedules from the host grouping.
    void Start(const TCollectiveLinkSet& links) override {
        Hosts = links.Hosts;
        HostGroup = links.HostGroup;
        for (int k = 0; k < ColSize; ++k) {
            if (k == ColRank) {
                continue; // no loopback connection to ourselves
            }
            const TCollectiveLinkSet::TLinkInfo& lnk = links.Links[k];
            ibv_ah_attr peerAddr;
            MakeAH(&peerAddr, Port, Hosts[k], COL_SERVICE_LEVEL);
            Peers[k]->Init(peerAddr, lnk.QPN, lnk.PSN);
        }
        //CreatePow2Merge(&MergePlan, ColSize);
        //CreatePow2Merge(&ReducePlan, ColSize);
        CreateGroupMerge(&MergePlan, AA_STAR, HostGroup);
        CreateGroupMerge(&ReducePlan, AA_POW2_MERGE, HostGroup);
    }
  825. void CreateDataSyncQPs(
  826. TPtrArg<TComplectionQueue> cq,
  827. TPtrArg<TSharedReceiveQueue> srq,
  828. TPtrArg<TIBMemBlock> memBlock0,
  829. TPtrArg<TIBMemBlock> memBlock1,
  830. const TMergePlan& plan,
  831. THashMap<TIBAddr, TPeerConnection, TIBAddrHash>* res) {
  832. THashMap<TIBAddr, TPeerConnection, TIBAddrHash>& connections = *res;
  833. TIBMemBlock* memBlock[2] = {memBlock0, memBlock1};
  834. // make full peer list
  835. TVector<TIBAddr> peerList;
  836. for (int z = 0; z < plan.Iterations.ysize(); ++z) {
  837. const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
  838. for (int i = 0; i < rr.OutList.ysize(); ++i) {
  839. const TMergeRecord::TTransfer& tr = rr.OutList[i];
  840. peerList.push_back(TIBAddr(tr.DstRank, tr.SL));
  841. }
  842. for (int i = 0; i < rr.InList.ysize(); ++i) {
  843. const TMergeRecord::TInTransfer& tr = rr.InList[i];
  844. peerList.push_back(TIBAddr(tr.SrcRank, tr.SL));
  845. }
  846. }
  847. Sort(peerList.begin(), peerList.end());
  848. peerList.erase(Unique(peerList.begin(), peerList.end()), peerList.end());
  849. // establish QPs and exchange mem block handlers
  850. for (int z = 0; z < peerList.ysize(); ++z) {
  851. const TIBAddr& ibAddr = peerList[z];
  852. int dstRank = ibAddr.LID;
  853. TPeerConnection& dst = connections[ibAddr];
  854. dst.QP = new TRCQueuePair(Port->GetCtx(), cq, srq, TAllDataSync::WR_COUNT);
  855. TBlockInform myBlock;
  856. for (int k = 0; k < 2; ++k) {
  857. myBlock.RemoteBlocks[k].Addr = memBlock[k]->GetAddr();
  858. myBlock.RemoteBlocks[k].Key = memBlock[k]->GetMemRegion()->GetRKey();
  859. }
  860. myBlock.PSN = dst.QP->GetPSN();
  861. myBlock.QPN = dst.QP->GetQPN();
  862. PostSend(Peers[dstRank], &myBlock, sizeof(myBlock));
  863. }
  864. for (int z = 0; z < peerList.ysize(); ++z) {
  865. const TIBAddr& ibAddr = peerList[z];
  866. int dstRank = ibAddr.LID;
  867. int sl = COL_DATA_SERVICE_LEVEL + ClampVal(ibAddr.SL, 0, COL_DATA_SERVICE_LEVEL_COUNT);
  868. TPeerConnection& dst = connections[ibAddr];
  869. ui64 wr_id = WaitForMsg(Peers[dstRank]->GetQPN());
  870. TIBRecvPacketProcess pkt(*BP, wr_id);
  871. const TBlockInform& info = *(TBlockInform*)pkt.GetData();
  872. ibv_ah_attr peerAddr;
  873. MakeAH(&peerAddr, Port, Hosts[dstRank], COL_DATA_SERVICE_LEVEL + sl);
  874. dst.QP->Init(peerAddr, info.QPN, info.PSN);
  875. dst.RemoteBlocks[0] = info.RemoteBlocks[0];
  876. dst.RemoteBlocks[1] = info.RemoteBlocks[1];
  877. }
  878. Fence();
  879. }
// Builds an all-gather primitive: each rank contributes szPerRank[rank] bytes
// and every rank ends up with the concatenation of all contributions.
// Translates this rank's slice of MergePlan into TAllDataSync iterations.
// Returns an owning pointer; aborts on invalid size array or resize failure.
IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) override {
    const TMergePlan& plan = MergePlan;
    Y_ABORT_UNLESS(szPerRank.ysize() == ColSize, "Invalid size array");
    size_t totalSize = 0;
    for (int i = 0; i < szPerRank.ysize(); ++i) {
        totalSize += szPerRank[i];
    }
    // Round the buffer up to the smallest power of two strictly greater than
    // totalSize (minimum 8192 when totalSize >= 4096).
    size_t bufSize = 4096;
    while (totalSize >= bufSize) {
        bufSize *= 2;
    }
    TAllGather* res = new TAllGather(ColSize, bufSize, MemPool);
    TAllDataSync& ds = res->GetDataSync();
    // One RC QP per distinct peer used by the plan, with the data-sync memory
    // block handles exchanged up front.
    THashMap<TIBAddr, TPeerConnection, TIBAddrHash> connections;
    CreateDataSyncQPs(ds.CQ, ds.SRQ, ds.MemBlock[0], ds.MemBlock[1], plan, &connections);
    // build plan
    for (int z = 0; z < plan.Iterations.ysize(); ++z) {
        const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
        if (rr.OutList.empty() && rr.InList.empty()) {
            continue; // this rank is idle in this plan iteration
        }
        TAllDataSync::TIteration& iter = ds.Iterations.emplace_back();
        // Outgoing transfers: where to read (gather range) and where to write
        // on the peer (its remote blocks), tagged with the transfer id.
        for (int i = 0; i < rr.OutList.ysize(); ++i) {
            const TMergeRecord::TTransfer& tr = rr.OutList[i];
            TAllDataSync::TSend& snd = iter.OutList.emplace_back();
            TPeerConnection& pc = connections[TIBAddr(tr.DstRank, tr.SL)];
            snd.ImmData = tr.Id;
            snd.Gather.RangeBeg = tr.RangeBeg;
            snd.Gather.RangeFin = tr.RangeFin;
            snd.QP = pc.QP;
            snd.RemoteBlocks[0] = pc.RemoteBlocks[0];
            snd.RemoteBlocks[1] = pc.RemoteBlocks[1];
            snd.DstRank = tr.DstRank;
        }
        // Incoming transfers: which QP/rank to expect data from.
        for (int i = 0; i < rr.InList.ysize(); ++i) {
            const TMergeRecord::TInTransfer& tr = rr.InList[i];
            TAllDataSync::TRecv& rcv = iter.InList.emplace_back();
            TPeerConnection& pc = connections[TIBAddr(tr.SrcRank, tr.SL)];
            rcv.QP = pc.QP;
            rcv.SrcRank = tr.SrcRank;
        }
        iter.RecvMask = rr.RecvMask; // bitmask of transfer ids to wait for
    }
    bool rv = res->Resize(szPerRank);
    Y_ABORT_UNLESS(rv, "oops");
    return res;
}
  927. IAllGather* CreateAllGather(size_t szPerRank) override {
  928. TVector<size_t> arr;
  929. arr.resize(ColSize, szPerRank);
  930. return CreateAllGather(arr);
  931. }
// Builds an all-reduce primitive over ReducePlan. The data-sync buffer is laid
// out in slots of (dataSize + padding): slot 0 holds this rank's own data, and
// a transfer tagged with immediate id `i` lands in slot (i + 1) on the
// receiver. After each iteration all received slots are reduced (via reduceOp)
// into the highest-numbered slot, which becomes the new partial result.
IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) override {
    const TMergePlan& plan = ReducePlan;
    // Own slot plus one slot per possible receive across the whole plan.
    size_t bufSizeMult = plan.MaxRankReceiveCount + 1;
    size_t bufSize = 4096;
    {
        // Round up to a power of two that fits all slots (64 bytes slack per slot).
        size_t sz = (dataSize + 64) * bufSizeMult;
        while (sz > bufSize) {
            bufSize *= 2;
        }
    }
    TAllReduce* res = new TAllReduce(bufSize, MemPool, reduceOp);
    TAllDataSync& ds = res->GetDataSync();
    THashMap<TIBAddr, TPeerConnection, TIBAddrHash> connections;
    CreateDataSyncQPs(ds.CQ, ds.SRQ, ds.MemBlock[0], ds.MemBlock[1], plan, &connections);
    // build plan
    int currentDataOffset = 0; // slot currently holding our partial reduction
    for (int z = 0; z < plan.Iterations.ysize(); ++z) {
        const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
        if (rr.OutList.empty() && rr.InList.empty()) {
            continue; // this rank is idle in this plan iteration
        }
        TAllDataSync::TIteration& iter = ds.Iterations.emplace_back();
        for (int i = 0; i < rr.OutList.ysize(); ++i) {
            const TMergeRecord::TTransfer& tr = rr.OutList[i];
            TAllDataSync::TSend& snd = iter.OutList.emplace_back();
            TPeerConnection& pc = connections[TIBAddr(tr.DstRank, tr.SL)];
            snd.ImmData = tr.Id;
            snd.Reduce.SrcIndex = currentDataOffset; // send our current partial result
            snd.Reduce.DstIndex = tr.Id + 1;         // receiver stores it in slot id+1
            snd.QP = pc.QP;
            snd.RemoteBlocks[0] = pc.RemoteBlocks[0];
            snd.RemoteBlocks[1] = pc.RemoteBlocks[1];
            snd.DstRank = tr.DstRank;
        }
        for (int i = 0; i < rr.InList.ysize(); ++i) {
            const TMergeRecord::TInTransfer& tr = rr.InList[i];
            TAllDataSync::TRecv& rcv = iter.InList.emplace_back();
            TPeerConnection& pc = connections[TIBAddr(tr.SrcRank, tr.SL)];
            rcv.QP = pc.QP;
            rcv.SrcRank = tr.SrcRank;
        }
        iter.RecvMask = rr.RecvMask;
        // Collect every slot holding live data after this iteration: our current
        // slot plus slot (i + 1) for each transfer id i set in RecvMask.
        TVector<int> inputOffset;
        inputOffset.push_back(currentDataOffset);
        int newDataOffset = currentDataOffset;
        for (int i = 0; i < 64; ++i) {
            if (rr.RecvMask & (1ull << i)) {
                int offset = i + 1;
                inputOffset.push_back(offset);
                newDataOffset = Max(offset, newDataOffset);
            }
        }
        // Reduce every other input slot into the highest-numbered slot, which
        // becomes the partial-result slot for the next iteration.
        for (int i = 0; i < inputOffset.ysize(); ++i) {
            if (inputOffset[i] == newDataOffset) {
                continue;
            }
            TAllDataSync::TReduce& red = iter.ReduceList.emplace_back();
            red.SrcIndex = inputOffset[i];
            red.DstIndex = newDataOffset;
        }
        currentDataOffset = newDataOffset;
    }
    res->BufSizeMult = bufSizeMult;
    res->ReadyOffsetMult = currentDataOffset; // final result lives in this slot
    bool rv = res->Resize(dataSize);
    Y_ABORT_UNLESS(rv, "oops");
    return res;
}
  1000. void Fence() override {
  1001. const TMergePlan& plan = ReducePlan;
  1002. for (int z = 0; z < plan.Iterations.ysize(); ++z) {
  1003. const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
  1004. for (int i = 0; i < rr.OutList.ysize(); ++i) {
  1005. const TMergeRecord::TTransfer& tr = rr.OutList[i];
  1006. char c;
  1007. PostSend(Peers[tr.DstRank], &c, sizeof(c));
  1008. }
  1009. for (int i = 0; i < rr.InList.ysize(); ++i) {
  1010. const TMergeRecord::TInTransfer& tr = rr.InList[i];
  1011. ui64 wr_id = WaitForMsg(Peers[tr.SrcRank]->GetQPN());
  1012. TIBRecvPacketProcess pkt(*BP, wr_id);
  1013. }
  1014. }
  1015. }
// Measures RDMA-write bandwidth from this rank to the rank `delta` positions
// ahead of it inside its `groupType` host group, while simultaneously serving
// as the write target for the rank `delta` positions behind. Writes the peer's
// rank to *targetRank (-1 if the group is smaller than delta) and the median
// bandwidth of 5 transfers, in GB/s, to *res.
void RunBWTest(int groupType, int delta, int* targetRank, float* res) override {
    const int BUF_SIZE = 8 * 1024 * 1024;
    TIntrusivePtr<TIBMemBlock> sendMem, recvMem;
    sendMem = MemPool->Alloc(BUF_SIZE);
    recvMem = MemPool->Alloc(BUF_SIZE);
    // Find our position within the members of our group (in rank order).
    int myGroup = HostGroup[groupType][ColRank];
    int myGroupPos = 0;
    TVector<int> gg;
    Y_ASSERT(HostGroup[groupType].ysize() == ColSize);
    for (int rank = 0; rank < ColSize; ++rank) {
        if (HostGroup[groupType][rank] == myGroup) {
            if (rank == ColRank) {
                myGroupPos = gg.ysize();
            }
            gg.push_back(rank);
        }
    }
    if (delta >= gg.ysize()) {
        // Group too small for this delta; nothing to measure.
        *targetRank = -1;
        *res = 0;
        return;
    }
    // We write to the rank `delta` ahead; the rank `delta` behind writes to us.
    int sendRank = gg[(myGroupPos + delta) % gg.ysize()];
    int recvRank = gg[(myGroupPos + gg.ysize() - delta) % gg.ysize()];
    *targetRank = sendRank;
    TIntrusivePtr<TRCQueuePair> sendRC = Peers[sendRank];
    TIntrusivePtr<TRCQueuePair> recvRC = Peers[recvRank];
    // Tell our writer where to put its data (our recv buffer addr + RKey).
    {
        TBWTest bw;
        bw.Addr = recvMem->GetAddr();
        bw.RKey = recvMem->GetMemRegion()->GetRKey();
        PostSend(recvRC, &bw, sizeof(bw));
    }
    // Learn where our target wants us to write.
    TBWTest dstMem;
    {
        ui64 wr_id = WaitForMsg(sendRC->GetQPN());
        TIBRecvPacketProcess pkt(*BP, wr_id);
        dstMem = *(TBWTest*)pkt.GetData();
    }
    // run
    TVector<double> score;
    for (int iter = 0; iter < 5; ++iter) {
        while (!AllocRDMAWriteSlot(sendRC)) {
            // Having to drain completions here means other traffic is in flight
            // and the timing below is skewed.
            WaitCompletion();
            Y_ASSERT(0 && "measurements are imprecise");
        }
        NHPTimer::STime t;
        NHPTimer::GetTime(&t);
        sendRC->PostRDMAWrite(dstMem.Addr, dstMem.RKey, sendMem->GetMemRegion(), 0, sendMem->GetData(), BUF_SIZE);
        // Spin until our RDMA write completes; non-RDMA completions are ignored.
        for (;;) {
            ibv_wc wc;
            WaitCompletion(&wc);
            if (wc.opcode == IBV_WC_RDMA_WRITE) {
                if (wc.qp_num != (ui32)sendRC->GetQPN()) {
                    abort(); // RDMA completion from an unexpected QP — state corrupt
                }
                break;
            }
        }
        double tPassed = NHPTimer::GetTimePassed(&t);
        double speed = BUF_SIZE / tPassed / 1000000000.0; // G/sec
        score.push_back(speed);
    }
    Sort(score.begin(), score.end());
    // signal completion & wait for signal
    *res = score[score.size() / 2]; // median of the 5 measurements
    {
        char bb;
        PostSend(sendRC, &bb, sizeof(bb));
        ui64 wr_id = WaitForMsg(recvRC->GetQPN());
        TIBRecvPacketProcess pkt(*BP, wr_id);
    }
}
  1089. bool TrySendMicro(int dstRank, const void* data, int dataSize) override {
  1090. return TryPostSend(Peers[dstRank], data, dataSize);
  1091. }
  1092. void InitPeerTable(TIBMicroPeerTable* res) override {
  1093. res->Init(QPNTableSizeLog);
  1094. }
// Executes a batch of RDMA writes, keeping at most MAX_TOTAL_RDMA in flight at
// once and rotating across destination ranks to spread load. Blocks until all
// requests have completed.
void RdmaWrite(const TVector<TRdmaRequest>& reqs) override {
    // Per-destination queues of indices into reqs.
    TVector<TVector<int>> reqPerRank;
    reqPerRank.resize(ColSize);
    int reqCount = reqs.ysize();
    for (int i = 0; i < reqCount; ++i) {
        reqPerRank[reqs[i].DstRank].push_back(i);
    }
    int inFlight = 0; // IB congestion control sucks :/ so we limit number of simultaneous rdmas
    int startRank = ColRank; // scan start rotates so no destination is starved
    while (reqCount > 0) {
        if (inFlight < MAX_TOTAL_RDMA) {
            for (int z = 0; z < ColSize; ++z) {
                int dstRank = (startRank + 1 + z) % ColSize;
                if (reqPerRank[dstRank].empty()) {
                    continue;
                }
                // NOTE(review): in release builds a request addressed to ColRank
                // would never be drained and this loop would spin forever —
                // confirm callers never enqueue a self-write.
                Y_ASSERT(dstRank != ColRank && "sending self is meaningless");
                TRCQueuePair* qp = Peers[dstRank].Get();
                if (AllocRDMAWriteSlot(qp)) {
                    const TRdmaRequest& rr = reqs[reqPerRank[dstRank].back()];
                    qp->PostRDMAWrite(rr.RemoteAddr, rr.RemoteKey, rr.LocalAddr, rr.LocalKey, 0, rr.Size);
                    reqPerRank[dstRank].pop_back();
                    if (++inFlight >= MAX_TOTAL_RDMA) {
                        startRank = dstRank; // resume next scan after this rank
                        break;
                    }
                }
            }
        }
        {
            // Reap one completion; only RDMA-write completions count against
            // our counters (other completions belong to micro-message traffic).
            ibv_wc wc;
            WaitCompletion(&wc);
            if (wc.opcode == IBV_WC_RDMA_WRITE) {
                --inFlight;
                --reqCount;
            }
        }
    }
}
  1134. public:
// Creates the completion queue, the shared buffer pool, and one RC queue pair
// per remote rank. Publishes each QP's QPN/PSN and this host's LID in
// *resLinks so the caller can exchange link info between ranks and later
// cross-connect the mesh.
TIBCollective(TPtrArg<TIBPort> port, TPtrArg<TIBMemPool> memPool,
              const TCollectiveInit& params,
              TCollectiveLinkSet* resLinks)
    : Port(port)
    , MemPool(memPool)
    , QPNTableSizeLog(0)
{
    ColSize = params.Size;
    ColRank = params.Rank;
    // Capacity budget: MAX_REQS_PER_PEER outstanding sends per peer, plus slack.
    int maxOutstandingQueries = MAX_REQS_PER_PEER * ColSize + 10;
    CQ = new TComplectionQueue(Port->GetCtx(), maxOutstandingQueries * 2);
    BP = new TIBBufferPool(Port->GetCtx(), maxOutstandingQueries);
    Peers.resize(ColSize);
    resLinks->Links.resize(ColSize);
    TVector<int> qpnArr;
    for (int k = 0; k < ColSize; ++k) {
        if (k == ColRank) {
            continue; // no QP to ourselves
        }
        TRCQueuePair* rc = new TRCQueuePair(Port->GetCtx(), CQ, BP->GetSRQ(), MAX_REQS_PER_PEER);
        Peers[k] = rc;
        TCollectiveLinkSet::TLinkInfo& lnk = resLinks->Links[k];
        lnk.PSN = rc->GetPSN();
        lnk.QPN = rc->GetQPN();
        qpnArr.push_back(lnk.QPN);
    }
    resLinks->Hosts.resize(ColSize);
    resLinks->Hosts[ColRank] = Port->GetLID();
    static_assert(MAX_REQS_PER_PEER < 256, "expect MAX_REQS_PER_PEER < 256"); // sent count will fit into SendCountTable[]
    Zero(SendCountTable);
    Zero(RDMACountTable);
    // Find the smallest power-of-two table size whose low-bit mask maps every
    // peer QPN to a distinct slot; QPNTableSizeLog is its log2. This lets a
    // QPN be resolved to a peer by masking alone.
    if (!qpnArr.empty()) {
        for (;;) {
            TVector<ui8> qpnTable;
            int qpnTableSize = 1 << QPNTableSizeLog;
            qpnTable.resize(qpnTableSize, 0);
            bool ok = true;
            for (int i = 0; i < qpnArr.ysize(); ++i) {
                int idx = qpnArr[i] & (qpnTableSize - 1);
                if (++qpnTable[idx] == 2) {
                    ok = false; // collision — double the table and retry
                    break;
                }
            }
            if (ok) {
                break;
            }
            ++QPNTableSizeLog;
        }
        //printf("QPN table, size_log %d\n", QPNTableSizeLog);
    }
}
  1187. friend class TIBRecvMicro;
  1188. };
// RAII helper: tries to dequeue one pending micro message from the collective.
// On success Data points at the packet payload (buffer is released in the
// destructor); otherwise Data is nullptr.
TIBRecvMicro::TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable)
    : IB(*(TIBCollective*)col)
{
    // The downcast is unchecked in release builds; the assert documents the
    // contract that col is always a TIBCollective.
    Y_ASSERT(typeid(IB) == typeid(TIBCollective));
    if (IB.GetMsg(&Id, &QPN, peerTable)) {
        Data = IB.BP->GetBufData(Id);
    } else {
        Data = nullptr; // no message pending
    }
}
  1199. TIBRecvMicro::~TIBRecvMicro() {
  1200. if (Data) {
  1201. IB.BP->FreeBuf(Id);
  1202. IB.BP->PostRecv();
  1203. }
  1204. }
// Factory: builds a collective bound to the default IB device and the shared
// IB memory pool. Link info for the rank-exchange handshake is returned in
// *resLinks; the caller owns the returned object.
IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks) {
    return new TIBCollective(GetIBDevice(), GetIBMemPool(), params, resLinks);
}
  1208. }