- #include "stdafx.h"
- #include "ib_collective.h"
- #include "ib_mem.h"
- #include "ib_buffers.h"
- #include "ib_low.h"
- #include "udp_http.h"
- #include "udp_address.h"
- #include <util/generic/deque.h>
- #include <util/system/hp_timer.h>
- namespace NNetliba {
- const int COL_SERVICE_LEVEL = 2;
- const int COL_DATA_SERVICE_LEVEL = 2; // base level
- const int COL_DATA_SERVICE_LEVEL_COUNT = 6; // level count
- const int MAX_REQS_PER_PEER = 32;
- const int MAX_TOTAL_RDMA = 20;
- const int SEND_COUNT_TABLE_SIZE = 1 << 12; // must be power of 2
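- // The structures below encode a static communication schedule ("merge plan").
- // Each TMergeIteration stores, per rank, its outgoing transfers (TTransfer) and
- // incoming transfers (TInTransfer); RecvMask is the bitmask of immediate-data ids
- // the rank has to receive before the iteration counts as finished (hence the
- // 64-transfer-per-rank limit enforced below).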
- struct TMergeRecord {
- struct TTransfer {
- int DstRank;
- int SL;
- int RangeBeg, RangeFin;
- int Id;
- TTransfer()
- : DstRank(-1)
- , SL(0)
- , RangeBeg(0)
- , RangeFin(0)
- , Id(0)
- {
- }
- TTransfer(int dstRank, int sl, int rangeBeg, int rangeFin, int id)
- : DstRank(dstRank)
- , SL(sl)
- , RangeBeg(rangeBeg)
- , RangeFin(rangeFin)
- , Id(id)
- {
- }
- };
- struct TInTransfer {
- int SrcRank;
- int SL;
- TInTransfer()
- : SrcRank(-1)
- , SL(0)
- {
- }
- TInTransfer(int srcRank, int sl)
- : SrcRank(srcRank)
- , SL(sl)
- {
- }
- };
- TVector<TTransfer> OutList;
- TVector<TInTransfer> InList;
- ui64 RecvMask;
- TMergeRecord()
- : RecvMask(0)
- {
- }
- };
- struct TMergeIteration {
- TVector<TMergeRecord> Ops;
- void Init(int colSize) {
- Ops.resize(colSize);
- }
- void Transfer(int srcRank, int dstRank, int sl, int rangeBeg, int rangeFin, int id) {
- Y_ABORT_UNLESS(id < 64, "recv mask overflow");
- Ops[srcRank].OutList.push_back(TMergeRecord::TTransfer(dstRank, sl, rangeBeg, rangeFin, id));
- Ops[dstRank].InList.push_back(TMergeRecord::TInTransfer(srcRank, sl));
- Ops[dstRank].RecvMask |= ui64(1) << id;
- }
- };
- struct TMergePlan {
- TVector<TMergeIteration> Iterations;
- TVector<int> RankReceiveCount;
- int ColSize;
- int MaxRankReceiveCount;
- TMergePlan()
- : ColSize(0)
- , MaxRankReceiveCount(0)
- {
- }
- void Init(int colSize) {
- Iterations.resize(0);
- RankReceiveCount.resize(0);
- RankReceiveCount.resize(colSize, 0);
- ColSize = colSize;
- }
- void Transfer(int iter, int srcRank, int dstRank, int sl, int rangeBeg, int rangeFin) {
- while (iter >= Iterations.ysize()) {
- TMergeIteration& res = Iterations.emplace_back();
- res.Init(ColSize);
- }
- int id = RankReceiveCount[dstRank]++;
- MaxRankReceiveCount = Max(MaxRankReceiveCount, id + 1);
- Y_ASSERT(id < 64);
- Iterations[iter].Transfer(srcRank, dstRank, sl, rangeBeg, rangeFin, id);
- }
- };
- struct TSRTransfer {
- int SrcRank, DstRank;
- int RangeBeg, RangeFin;
- TSRTransfer() {
- Zero(*this);
- }
- TSRTransfer(int srcRank, int dstRank, int rangeBeg, int rangeFin)
- : SrcRank(srcRank)
- , DstRank(dstRank)
- , RangeBeg(rangeBeg)
- , RangeFin(rangeFin)
- {
- }
- };
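- // Recursively halves [beg;fin) and records, for every recursion depth, the pairwise
- // exchanges between the two halves (plus one extra transfer when the halves are
- // unequal). CreatePow2Merge later replays these depths in reverse order so that
- // every rank ends up with the whole range. Returns the number of depths used.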
- static int SplitRange(THashMap<int, TVector<TSRTransfer>>* res, int iter, int beg, int fin) {
- int mid = (beg + fin + 1) / 2;
- if (mid == fin) {
- return iter;
- }
- for (int i = 0; i < fin - mid; ++i) {
- (*res)[iter].push_back(TSRTransfer(beg + i, mid + i, beg, mid));
- (*res)[iter].push_back(TSRTransfer(mid + i, beg + i, mid, fin));
- }
- if (fin - mid < mid - beg) {
- // [mid - 1] did not receive [mid;fin)
- (*res)[iter].push_back(TSRTransfer(mid, mid - 1, mid, fin));
- }
- int rv1 = SplitRange(res, iter + 1, beg, mid);
- int rv2 = SplitRange(res, iter + 1, mid, fin);
- return Max(rv1, rv2);
- }
- static void CreatePow2Merge(TMergePlan* plan, int colSize) {
- // finally everybody has full range [0;ColSize)
- // construct plan recursively, on each iteration split some range
- plan->Init(colSize);
- THashMap<int, TVector<TSRTransfer>> allTransfers;
- int maxIter = SplitRange(&allTransfers, 0, 0, colSize);
- for (int iter = 0; iter < maxIter; ++iter) {
- const TVector<TSRTransfer>& arr = allTransfers[maxIter - iter - 1]; // reverse order
- for (int i = 0; i < arr.ysize(); ++i) {
- const TSRTransfer& sr = arr[i];
- plan->Transfer(iter, sr.SrcRank, sr.DstRank, 0, sr.RangeBeg, sr.RangeFin);
- }
- }
- }
- struct TCoverInterval {
- int Beg, Fin; // [Beg;Fin)
- TCoverInterval()
- : Beg(0)
- , Fin(0)
- {
- }
- TCoverInterval(int b, int f)
- : Beg(b)
- , Fin(f)
- {
- }
- };
- enum EAllToAllMode {
- AA_POW2,
- AA_CIRCLE,
- AA_STAR,
- AA_POW2_MERGE,
- };
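- // Schedules a full data exchange within one group of ranks starting at iteration
- // 'iter' and returns the next free iteration index:
- //   AA_POW2       - recursive doubling, delta = 1, 2, 4, ... positions ahead;
- //   AA_CIRCLE     - ring, each step forwards one block to the next group member;
- //   AA_STAR       - every rank sends its own block directly to every other rank;
- //   AA_POW2_MERGE - reuses the CreatePow2Merge schedule inside the group.
- // Afterwards the cover intervals of all group members are widened to the union of
- // the group's data.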
- static int AllToAll(TMergePlan* plan, int iter, int sl, EAllToAllMode mode, const TVector<int>& myGroup, TVector<TCoverInterval>* cover) {
- TVector<TCoverInterval>& hostCoverage = *cover;
- int groupSize = myGroup.ysize();
- for (int k = 1; k < groupSize; ++k) {
- int h1 = myGroup[k - 1];
- int h2 = myGroup[k];
- Y_ABORT_UNLESS(hostCoverage[h1].Fin == hostCoverage[h2].Beg, "Invalid host order in CreateGroupMerge()");
- }
- switch (mode) {
- case AA_POW2: {
- for (int delta = 1; delta < groupSize; delta *= 2) {
- int sz = Min(delta, groupSize - delta);
- for (int offset = 0; offset < groupSize; ++offset) {
- int srcRank = myGroup[offset];
- int dstRank = myGroup[(offset + delta) % groupSize];
- int start = offset + 1 - sz;
- int finish = offset + 1;
- if (start < 0) {
- // [start; myGroup.size())
- int dataBeg = hostCoverage[myGroup[start + groupSize]].Beg;
- int dataFin = hostCoverage[myGroup.back()].Fin;
- plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
- // [0; finish)
- dataBeg = hostCoverage[myGroup[0]].Beg;
- dataFin = hostCoverage[myGroup[finish - 1]].Fin;
- plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
- } else {
- // [start;finish)
- int dataBeg = hostCoverage[myGroup[start]].Beg;
- int dataFin = hostCoverage[myGroup[finish - 1]].Fin;
- plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
- }
- }
- ++iter;
- }
- } break;
- case AA_CIRCLE: {
- for (int dataDelta = 1; dataDelta < groupSize; ++dataDelta) {
- for (int offset = 0; offset < groupSize; ++offset) {
- int srcRank = myGroup[offset];
- int dstRank = myGroup[(offset + 1) % groupSize];
- int dataRank = myGroup[(offset + 1 - dataDelta + groupSize) % groupSize];
- int dataBeg = hostCoverage[dataRank].Beg;
- int dataFin = hostCoverage[dataRank].Fin;
- plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
- }
- ++iter;
- }
- } break;
- case AA_STAR: {
- for (int offset = 0; offset < groupSize; ++offset) {
- for (int delta = 1; delta < groupSize; ++delta) {
- int srcRank = myGroup[offset];
- int dstRank = myGroup[(offset + delta) % groupSize];
- int dataRank = myGroup[offset];
- int dataBeg = hostCoverage[dataRank].Beg;
- int dataFin = hostCoverage[dataRank].Fin;
- plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
- }
- }
- ++iter;
- } break;
- case AA_POW2_MERGE: {
- TMergePlan pp;
- CreatePow2Merge(&pp, groupSize);
- for (int z = 0; z < pp.Iterations.ysize(); ++z) {
- const TMergeIteration& mm = pp.Iterations[z];
- for (int src = 0; src < mm.Ops.ysize(); ++src) {
- const TMergeRecord& mr = mm.Ops[src];
- int srcRank = myGroup[src];
- for (int i = 0; i < mr.OutList.ysize(); ++i) {
- int dstRank = myGroup[mr.OutList[i].DstRank];
- plan->Transfer(iter, srcRank, dstRank, sl, 0, 1);
- }
- }
- ++iter;
- }
- } break;
- default:
- Y_ASSERT(0);
- break;
- }
- {
- TCoverInterval cc(hostCoverage[myGroup[0]].Beg, hostCoverage[myGroup.back()].Fin);
- for (int k = 0; k < groupSize; ++k) {
- hostCoverage[myGroup[k]] = cc;
- }
- }
- return iter;
- }
- // fully populated matrix
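- // hostGroup[groupType][hostId] is the group id of each host at each grouping level.
- // The loop below first checks that this matrix is regular (ids only increment by one
- // or wrap back to zero), then runs AllToAll inside the groups of the last group type
- // and proceeds towards group type 0, so the covered data range grows level by level
- // until every host covers [0;hostCount).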
- static void CreateGroupMerge(TMergePlan* plan, EAllToAllMode mode, const TVector<TVector<int>>& hostGroup) {
- int hostCount = hostGroup[0].ysize();
- int groupTypeCount = hostGroup.ysize();
- plan->Init(hostCount);
- TVector<int> gcount;
- gcount.resize(groupTypeCount, 0);
- for (int hostId = 0; hostId < hostCount; ++hostId) {
- for (int groupType = 0; groupType < groupTypeCount; ++groupType) {
- int val = hostGroup[groupType][hostId];
- gcount[groupType] = Max(gcount[groupType], val + 1);
- }
- }
- for (int hostId = 1; hostId < hostCount; ++hostId) {
- bool isIncrement = true;
- for (int groupType = 0; groupType < groupTypeCount; ++groupType) {
- int prev = hostGroup[groupType][hostId - 1];
- int cur = hostGroup[groupType][hostId];
- if (isIncrement) {
- if (cur == prev + 1) {
- isIncrement = false;
- } else {
- Y_ABORT_UNLESS(cur == 0, "ib_hosts, wrapped to non-zero");
- Y_ABORT_UNLESS(prev == gcount[groupType] - 1, "ib_hosts, structure is irregular");
- isIncrement = true;
- }
- } else {
- Y_ABORT_UNLESS(prev == cur, "ib_hosts, structure is irregular");
- }
- }
- }
- TVector<TCoverInterval> hostCoverage;
- for (int i = 0; i < hostCount; ++i) {
- hostCoverage.push_back(TCoverInterval(i, i + 1));
- }
- int baseIter = 0;
- for (int groupType = hostGroup.ysize() - 1; groupType >= 0; --groupType) {
- Y_ASSERT(hostGroup[groupType].ysize() == hostCount);
- TVector<TVector<int>> hh;
- hh.resize(gcount[groupType]);
- for (int rank = 0; rank < hostGroup[groupType].ysize(); ++rank) {
- int groupId = hostGroup[groupType][rank];
- hh[groupId].push_back(rank);
- }
- int newIter = 0;
- for (int groupId = 0; groupId < hh.ysize(); ++groupId) {
- int nn = AllToAll(plan, baseIter, 0, mode, hh[groupId], &hostCoverage); // seems to be fastest
- if (newIter == 0) {
- newIter = nn;
- } else {
- Y_ABORT_UNLESS(newIter == nn, "groups should be of same size");
- }
- }
- baseIter = newIter;
- }
- //printf("%d iterations symmetrical plan\n", baseIter);
- }
- //////////////////////////////////////////////////////////////////////////
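- // TAllDataSync is the data-plane engine shared by the AllGather and AllReduce
- // implementations. It owns two registered memory blocks and flips between them on
- // every Flush(), so the next payload can be staged while the previous one drains.
- // Data moves with RDMA-write-with-immediate: the immediate value carries the
- // transfer id, and every iteration waits until all ids in its RecvMask have arrived;
- // completions that belong to the following Sync() round are parked in FutureRecvMask.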
- struct TAllDataSync {
- enum {
- WR_COUNT = 64 * 2
- };
- int CurrentBuffer;
- TIntrusivePtr<TIBMemBlock> MemBlock[2];
- TIntrusivePtr<TComplectionQueue> CQ;
- TIntrusivePtr<TSharedReceiveQueue> SRQ;
- TIntrusivePtr<TIBMemBlock> FakeRecvMem;
- size_t DataSize, BufSize;
- size_t CurrentOffset, ReadyOffset;
- bool WasFlushed;
- int ActiveRDMACount;
- ui64 FutureRecvMask;
- TIntrusivePtr<IReduceOp> ReduceOp;
- struct TBlockInfo {
- ui64 Addr;
- ui32 Key;
- };
- struct TSend {
- TBlockInfo RemoteBlocks[2];
- TIntrusivePtr<TRCQueuePair> QP;
- size_t SrcOffset;
- size_t DstOffset;
- size_t Length;
- ui32 ImmData;
- int DstRank;
- union {
- struct {
- int RangeBeg, RangeFin;
- } Gather;
- struct {
- int SrcIndex, DstIndex;
- } Reduce;
- };
- };
- struct TRecv {
- TIntrusivePtr<TRCQueuePair> QP;
- int SrcRank;
- };
- struct TReduce {
- size_t DstOffset, SrcOffset;
- int DstIndex, SrcIndex;
- };
- struct TIteration {
- TVector<TSend> OutList;
- TVector<TRecv> InList;
- TVector<TReduce> ReduceList;
- ui64 RecvMask;
- };
- TVector<TIteration> Iterations;
- public:
- void* GetRawData() {
- char* myData = (char*)MemBlock[CurrentBuffer]->GetData();
- return myData + CurrentOffset;
- }
- size_t GetRawDataSize() {
- return DataSize;
- }
- void PostRecv() {
- SRQ->PostReceive(FakeRecvMem->GetMemRegion(), 0, FakeRecvMem->GetData(), FakeRecvMem->GetSize());
- }
- void Sync() {
- Y_ASSERT(WasFlushed && "Have to call Flush() before data fill & Sync()");
- char* myData = (char*)MemBlock[CurrentBuffer]->GetData();
- ui64 recvMask = FutureRecvMask;
- FutureRecvMask = 0;
- int recvDebt = 0;
- for (int z = 0; z < Iterations.ysize(); ++z) {
- const TIteration& iter = Iterations[z];
- for (int k = 0; k < iter.OutList.ysize(); ++k) {
- const TSend& ss = iter.OutList[k];
- const TBlockInfo& remoteBlk = ss.RemoteBlocks[CurrentBuffer];
- ss.QP->PostRDMAWriteImm(remoteBlk.Addr + ss.DstOffset, remoteBlk.Key, ss.ImmData,
- MemBlock[CurrentBuffer]->GetMemRegion(), 0, myData + ss.SrcOffset, ss.Length);
- ++ActiveRDMACount;
- //printf("-> %d, imm %d (%" PRId64 " bytes)\n", ss.DstRank, ss.ImmData, ss.Length);
- //printf("send %d\n", ss.SrcOffset);
- }
- ibv_wc wc;
- while ((recvMask & iter.RecvMask) != iter.RecvMask) {
- int rv = CQ->Poll(&wc, 1);
- if (rv > 0) {
- Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "AllGather::Sync fail, status %d", (int)wc.status);
- if (wc.opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
- //printf("Got %d\n", wc.imm_data);
- ++recvDebt;
- ui64 newBit = ui64(1) << wc.imm_data;
- if (recvMask & newBit) {
- Y_ABORT_UNLESS((FutureRecvMask & newBit) == 0, "data from 2 Sync() ahead is impossible");
- FutureRecvMask |= newBit;
- } else {
- recvMask |= newBit;
- }
- } else if (wc.opcode == IBV_WC_RDMA_WRITE) {
- --ActiveRDMACount;
- } else {
- Y_ASSERT(0);
- }
- } else {
- if (recvDebt > 0) {
- PostRecv();
- --recvDebt;
- }
- }
- }
- for (int k = 0; k < iter.ReduceList.ysize(); ++k) {
- const TReduce& rr = iter.ReduceList[k];
- ReduceOp->Reduce(myData + rr.DstOffset, myData + rr.SrcOffset, DataSize);
- //printf("Merge %d -> %d (%d bytes)\n", rr.SrcOffset, rr.DstOffset, DataSize);
- }
- //printf("Iteration %d done\n", z);
- }
- while (recvDebt > 0) {
- PostRecv();
- --recvDebt;
- }
- CurrentOffset = ReadyOffset;
- WasFlushed = false;
- //printf("new cur offset %g\n", (double)CurrentOffset);
- //printf("Sync complete\n");
- }
- void Flush() {
- Y_ASSERT(!WasFlushed);
- CurrentBuffer = 1 - CurrentBuffer;
- CurrentOffset = 0;
- WasFlushed = true;
- }
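- // Illustrative call order (sketch only; 'sync' stands for whichever TAllDataSync
- // instance the caller owns, it is not a name from this file):
- //   sync.Flush();                      // flip buffers, reset CurrentOffset
- //   /* fill sync.GetRawData() with GetRawDataSize() bytes */
- //   sync.Sync();                       // push data, wait for RecvMask, run reduces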
- public:
- TAllDataSync(size_t bufSize, TPtrArg<TIBMemPool> memPool, TPtrArg<IReduceOp> reduceOp)
- : CurrentBuffer(0)
- , DataSize(0)
- , BufSize(bufSize)
- , CurrentOffset(0)
- , ReadyOffset(0)
- , WasFlushed(false)
- , ActiveRDMACount(0)
- , FutureRecvMask(0)
- , ReduceOp(reduceOp)
- {
- if (memPool) {
- MemBlock[0] = memPool->Alloc(BufSize);
- MemBlock[1] = memPool->Alloc(BufSize);
- CQ = new TComplectionQueue(memPool->GetIBContext(), WR_COUNT);
- SRQ = new TSharedReceiveQueue(memPool->GetIBContext(), WR_COUNT);
- FakeRecvMem = memPool->Alloc(4096);
- } else {
- MemBlock[0] = new TIBMemBlock(BufSize);
- MemBlock[1] = new TIBMemBlock(BufSize);
- CQ = new TComplectionQueue(nullptr, WR_COUNT);
- SRQ = new TSharedReceiveQueue(nullptr, WR_COUNT);
- FakeRecvMem = new TIBMemBlock(4096);
- }
- for (int i = 0; i < WR_COUNT; ++i) {
- PostRecv();
- }
- }
- ~TAllDataSync() {
- while (ActiveRDMACount > 0) {
- ibv_wc wc;
- int rv = CQ->Poll(&wc, 1);
- if (rv > 0) {
- if (wc.opcode == IBV_WC_RDMA_WRITE) {
- --ActiveRDMACount;
- } else {
- Y_ASSERT(0);
- }
- }
- }
- }
- };
- class TAllReduce: public IAllReduce {
- TAllDataSync DataSync;
- size_t BufSizeMult;
- size_t ReadyOffsetMult;
- public:
- TAllReduce(size_t bufSize, TPtrArg<TIBMemPool> memPool, TPtrArg<IReduceOp> reduceOp)
- : DataSync(bufSize, memPool, reduceOp)
- , BufSizeMult(0)
- , ReadyOffsetMult(0)
- {
- }
- TAllDataSync& GetDataSync() {
- return DataSync;
- }
- void* GetRawData() override {
- return DataSync.GetRawData();
- }
- size_t GetRawDataSize() override {
- return DataSync.GetRawDataSize();
- }
- void Sync() override {
- DataSync.Sync();
- }
- void Flush() override {
- DataSync.Flush();
- }
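- // Resize() rescales the precomputed schedule for a new payload: repSize is the
- // payload rounded up to a 64-byte multiple, and the slot indices stored in
- // Reduce.SrcIndex / Reduce.DstIndex are converted to byte offsets as index * repSize.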
- bool Resize(size_t dataSize) override {
- size_t repSize = (dataSize + 63) & (~63ull);
- size_t bufSize = repSize * BufSizeMult;
- if (bufSize > DataSync.BufSize) {
- return false;
- }
- for (int z = 0; z < DataSync.Iterations.ysize(); ++z) {
- TAllDataSync::TIteration& iter = DataSync.Iterations[z];
- for (int i = 0; i < iter.OutList.ysize(); ++i) {
- TAllDataSync::TSend& snd = iter.OutList[i];
- snd.Length = dataSize;
- snd.SrcOffset = snd.Reduce.SrcIndex * repSize;
- snd.DstOffset = snd.Reduce.DstIndex * repSize;
- }
- for (int i = 0; i < iter.ReduceList.ysize(); ++i) {
- TAllDataSync::TReduce& red = iter.ReduceList[i];
- red.SrcOffset = red.SrcIndex * repSize;
- red.DstOffset = red.DstIndex * repSize;
- }
- }
- DataSync.ReadyOffset = ReadyOffsetMult * repSize;
- DataSync.DataSize = dataSize;
- return true;
- }
- friend class TIBCollective;
- };
- class TAllGather: public IAllGather {
- TAllDataSync DataSync;
- int ColSize;
- public:
- TAllGather(int colSize, size_t bufSize, TPtrArg<TIBMemPool> memPool)
- : DataSync(bufSize, memPool, nullptr)
- , ColSize(colSize)
- {
- }
- TAllDataSync& GetDataSync() {
- return DataSync;
- }
- void* GetRawData() override {
- return DataSync.GetRawData();
- }
- size_t GetRawDataSize() override {
- return DataSync.GetRawDataSize();
- }
- void Sync() override {
- DataSync.Sync();
- }
- void Flush() override {
- DataSync.Flush();
- }
- bool Resize(const TVector<size_t>& szPerRank) override {
- Y_ABORT_UNLESS(szPerRank.ysize() == ColSize, "Invalid size array");
- TVector<size_t> offsets;
- offsets.push_back(0);
- for (int rank = 0; rank < ColSize; ++rank) {
- offsets.push_back(offsets.back() + szPerRank[rank]);
- }
- size_t dataSize = offsets.back();
- if (dataSize > DataSync.BufSize) {
- return false;
- }
- for (int z = 0; z < DataSync.Iterations.ysize(); ++z) {
- TAllDataSync::TIteration& iter = DataSync.Iterations[z];
- for (int i = 0; i < iter.OutList.ysize(); ++i) {
- TAllDataSync::TSend& snd = iter.OutList[i];
- int rangeBeg = snd.Gather.RangeBeg;
- int rangeFin = snd.Gather.RangeFin;
- snd.Length = offsets[rangeFin] - offsets[rangeBeg];
- snd.SrcOffset = offsets[rangeBeg];
- snd.DstOffset = snd.SrcOffset;
- }
- }
- DataSync.DataSize = dataSize;
- return true;
- }
- };
- struct TIBAddr {
- int LID, SL;
- TIBAddr()
- : LID(0)
- , SL(0)
- {
- }
- TIBAddr(int lid, int sl)
- : LID(lid)
- , SL(sl)
- {
- }
- };
- inline bool operator==(const TIBAddr& a, const TIBAddr& b) {
- return a.LID == b.LID && a.SL == b.SL;
- }
- inline bool operator<(const TIBAddr& a, const TIBAddr& b) {
- if (a.LID == b.LID) {
- return a.SL < b.SL;
- }
- return a.LID < b.LID;
- }
- struct TIBAddrHash {
- int operator()(const TIBAddr& a) const {
- return a.LID + a.SL * 4254515;
- }
- };
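- // TIBCollective keeps one reliable-connected QP per peer for small control messages,
- // sharing a single completion queue and buffer pool. SendCountTable / RDMACountTable
- // are indexed by the low bits of the QP number and bound the number of outstanding
- // sends and RDMA writes per peer; received messages nobody is waiting for yet are
- // parked in the Pending deque.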
- class TIBCollective: public IIBCollective {
- struct TPendingMessage {
- int QPN;
- ui64 WorkId;
- TPendingMessage() {
- Zero(*this);
- }
- TPendingMessage(int qpn, ui64 wid)
- : QPN(qpn)
- , WorkId(wid)
- {
- }
- };
- struct TBlockInform {
- TAllDataSync::TBlockInfo RemoteBlocks[2];
- int PSN, QPN;
- };
- struct TPeerConnection {
- TAllDataSync::TBlockInfo RemoteBlocks[2];
- TIntrusivePtr<TRCQueuePair> QP;
- };
- struct TBWTest {
- ui64 Addr;
- ui32 RKey;
- };
- TIntrusivePtr<TIBPort> Port;
- TIntrusivePtr<TIBMemPool> MemPool;
- int ColSize, ColRank;
- TVector<int> Hosts; // host LIDs
- TVector<TVector<int>> HostGroup;
- TVector<TIntrusivePtr<TRCQueuePair>> Peers;
- TIntrusivePtr<TComplectionQueue> CQ;
- TIntrusivePtr<TIBBufferPool> BP;
- ui8 SendCountTable[SEND_COUNT_TABLE_SIZE];
- ui8 RDMACountTable[SEND_COUNT_TABLE_SIZE];
- TDeque<TPendingMessage> Pending;
- TMergePlan MergePlan, ReducePlan;
- int QPNTableSizeLog;
- void WriteCompleted(const ibv_wc& wc) {
- --SendCountTable[wc.qp_num & (SEND_COUNT_TABLE_SIZE - 1)];
- if (wc.opcode == IBV_WC_RDMA_WRITE) {
- --RDMACountTable[wc.qp_num & (SEND_COUNT_TABLE_SIZE - 1)];
- }
- BP->FreeBuf(wc.wr_id);
- }
- bool GetMsg(ui64* resWorkId, int* resQPN, TIBMicroPeerTable* tbl) {
- if (tbl->NeedParsePending()) {
- for (TDeque<TPendingMessage>::iterator z = Pending.begin(); z != Pending.end(); ++z) {
- if (!tbl->NeedQPN(z->QPN)) {
- continue;
- }
- *resWorkId = z->WorkId;
- *resQPN = z->QPN;
- Pending.erase(z);
- return true;
- }
- //printf("Stop parse pending\n");
- tbl->StopParsePending();
- }
- for (;;) {
- ibv_wc wc;
- int rv = CQ->Poll(&wc, 1);
- if (rv > 0) {
- Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
- if (wc.opcode & IBV_WC_RECV) {
- BP->RequestPostRecv();
- if (tbl->NeedQPN(wc.qp_num)) {
- *resWorkId = wc.wr_id;
- *resQPN = wc.qp_num;
- return true;
- } else {
- Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
- BP->PostRecv();
- }
- } else {
- WriteCompleted(wc);
- }
- } else {
- return false;
- }
- }
- }
- bool ProcessSendCompletion(const ibv_wc& wc) {
- Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
- if (wc.opcode & IBV_WC_RECV) {
- BP->RequestPostRecv();
- Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
- BP->PostRecv();
- } else {
- WriteCompleted(wc);
- return true;
- }
- return false;
- }
- void WaitCompletion(ibv_wc* res) {
- ibv_wc& wc = *res;
- for (;;) {
- int rv = CQ->Poll(&wc, 1);
- if (rv > 0 && ProcessSendCompletion(wc)) {
- break;
- }
- }
- }
- bool TryWaitCompletion() override {
- ibv_wc wc;
- for (;;) {
- int rv = CQ->Poll(&wc, 1);
- if (rv > 0) {
- if (ProcessSendCompletion(wc)) {
- return true;
- }
- } else {
- return false;
- }
- }
- }
- void WaitCompletion() override {
- ibv_wc wc;
- WaitCompletion(&wc);
- }
- ui64 WaitForMsg(int qpn) {
- for (TDeque<TPendingMessage>::iterator z = Pending.begin(); z != Pending.end(); ++z) {
- if (z->QPN == qpn) {
- ui64 workId = z->WorkId;
- Pending.erase(z);
- return workId;
- }
- }
- ibv_wc wc;
- for (;;) {
- int rv = CQ->Poll(&wc, 1);
- if (rv > 0) {
- Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
- if (wc.opcode & IBV_WC_RECV) {
- BP->RequestPostRecv();
- if ((int)wc.qp_num == qpn) {
- return wc.wr_id;
- } else {
- Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
- BP->PostRecv();
- }
- } else {
- WriteCompleted(wc);
- }
- }
- }
- }
- bool AllocOperationSlot(TPtrArg<TRCQueuePair> qp) {
- int way = qp->GetQPN() & (SEND_COUNT_TABLE_SIZE - 1);
- if (SendCountTable[way] >= MAX_REQS_PER_PEER) {
- return false;
- }
- ++SendCountTable[way];
- return true;
- }
- bool AllocRDMAWriteSlot(TPtrArg<TRCQueuePair> qp) {
- int way = qp->GetQPN() & (SEND_COUNT_TABLE_SIZE - 1);
- if (SendCountTable[way] >= MAX_REQS_PER_PEER) {
- return false;
- }
- if (RDMACountTable[way] >= MAX_OUTSTANDING_RDMA) {
- return false;
- }
- ++SendCountTable[way];
- ++RDMACountTable[way];
- return true;
- }
- bool TryPostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) {
- if (AllocOperationSlot(qp)) {
- BP->PostSend(qp, data, len);
- return true;
- }
- return false;
- }
- void PostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) {
- while (!TryPostSend(qp, data, len)) {
- WaitCompletion();
- }
- }
- int GetRank() override {
- return ColRank;
- }
- int GetSize() override {
- return ColSize;
- }
- int GetGroupTypeCount() override {
- return HostGroup.ysize();
- }
- int GetQPN(int rank) override {
- if (rank == ColRank) {
- Y_ASSERT(0 && "there is no qpn connected to localhost");
- return 0;
- }
- return Peers[rank]->GetQPN();
- }
- void Start(const TCollectiveLinkSet& links) override {
- Hosts = links.Hosts;
- HostGroup = links.HostGroup;
- for (int k = 0; k < ColSize; ++k) {
- if (k == ColRank) {
- continue;
- }
- const TCollectiveLinkSet::TLinkInfo& lnk = links.Links[k];
- ibv_ah_attr peerAddr;
- MakeAH(&peerAddr, Port, Hosts[k], COL_SERVICE_LEVEL);
- Peers[k]->Init(peerAddr, lnk.QPN, lnk.PSN);
- }
- //CreatePow2Merge(&MergePlan, ColSize);
- //CreatePow2Merge(&ReducePlan, ColSize);
- CreateGroupMerge(&MergePlan, AA_STAR, HostGroup);
- CreateGroupMerge(&ReducePlan, AA_POW2_MERGE, HostGroup);
- }
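- // Creates the dedicated data-plane QPs for one merge plan: for every peer the plan
- // communicates with, a fresh RC QP is built, the addresses/rkeys of both memory
- // blocks plus the QPN/PSN are exchanged over the existing control QPs (TBlockInform),
- // and the data QPs are then connected at the per-transfer service level.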
- void CreateDataSyncQPs(
- TPtrArg<TComplectionQueue> cq,
- TPtrArg<TSharedReceiveQueue> srq,
- TPtrArg<TIBMemBlock> memBlock0,
- TPtrArg<TIBMemBlock> memBlock1,
- const TMergePlan& plan,
- THashMap<TIBAddr, TPeerConnection, TIBAddrHash>* res) {
- THashMap<TIBAddr, TPeerConnection, TIBAddrHash>& connections = *res;
- TIBMemBlock* memBlock[2] = {memBlock0, memBlock1};
- // make full peer list
- TVector<TIBAddr> peerList;
- for (int z = 0; z < plan.Iterations.ysize(); ++z) {
- const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
- for (int i = 0; i < rr.OutList.ysize(); ++i) {
- const TMergeRecord::TTransfer& tr = rr.OutList[i];
- peerList.push_back(TIBAddr(tr.DstRank, tr.SL));
- }
- for (int i = 0; i < rr.InList.ysize(); ++i) {
- const TMergeRecord::TInTransfer& tr = rr.InList[i];
- peerList.push_back(TIBAddr(tr.SrcRank, tr.SL));
- }
- }
- Sort(peerList.begin(), peerList.end());
- peerList.erase(Unique(peerList.begin(), peerList.end()), peerList.end());
- // establish QPs and exchange mem block handlers
- for (int z = 0; z < peerList.ysize(); ++z) {
- const TIBAddr& ibAddr = peerList[z];
- int dstRank = ibAddr.LID;
- TPeerConnection& dst = connections[ibAddr];
- dst.QP = new TRCQueuePair(Port->GetCtx(), cq, srq, TAllDataSync::WR_COUNT);
- TBlockInform myBlock;
- for (int k = 0; k < 2; ++k) {
- myBlock.RemoteBlocks[k].Addr = memBlock[k]->GetAddr();
- myBlock.RemoteBlocks[k].Key = memBlock[k]->GetMemRegion()->GetRKey();
- }
- myBlock.PSN = dst.QP->GetPSN();
- myBlock.QPN = dst.QP->GetQPN();
- PostSend(Peers[dstRank], &myBlock, sizeof(myBlock));
- }
- for (int z = 0; z < peerList.ysize(); ++z) {
- const TIBAddr& ibAddr = peerList[z];
- int dstRank = ibAddr.LID;
- int sl = COL_DATA_SERVICE_LEVEL + ClampVal(ibAddr.SL, 0, COL_DATA_SERVICE_LEVEL_COUNT);
- TPeerConnection& dst = connections[ibAddr];
- ui64 wr_id = WaitForMsg(Peers[dstRank]->GetQPN());
- TIBRecvPacketProcess pkt(*BP, wr_id);
- const TBlockInform& info = *(TBlockInform*)pkt.GetData();
- ibv_ah_attr peerAddr;
- MakeAH(&peerAddr, Port, Hosts[dstRank], sl); // sl already includes COL_DATA_SERVICE_LEVEL
- dst.QP->Init(peerAddr, info.QPN, info.PSN);
- dst.RemoteBlocks[0] = info.RemoteBlocks[0];
- dst.RemoteBlocks[1] = info.RemoteBlocks[1];
- }
- Fence();
- }
- IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) override {
- const TMergePlan& plan = MergePlan;
- Y_ABORT_UNLESS(szPerRank.ysize() == ColSize, "Invalid size array");
- size_t totalSize = 0;
- for (int i = 0; i < szPerRank.ysize(); ++i) {
- totalSize += szPerRank[i];
- }
- size_t bufSize = 4096;
- while (totalSize >= bufSize) {
- bufSize *= 2;
- }
- TAllGather* res = new TAllGather(ColSize, bufSize, MemPool);
- TAllDataSync& ds = res->GetDataSync();
- THashMap<TIBAddr, TPeerConnection, TIBAddrHash> connections;
- CreateDataSyncQPs(ds.CQ, ds.SRQ, ds.MemBlock[0], ds.MemBlock[1], plan, &connections);
- // build plan
- for (int z = 0; z < plan.Iterations.ysize(); ++z) {
- const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
- if (rr.OutList.empty() && rr.InList.empty()) {
- continue;
- }
- TAllDataSync::TIteration& iter = ds.Iterations.emplace_back();
- for (int i = 0; i < rr.OutList.ysize(); ++i) {
- const TMergeRecord::TTransfer& tr = rr.OutList[i];
- TAllDataSync::TSend& snd = iter.OutList.emplace_back();
- TPeerConnection& pc = connections[TIBAddr(tr.DstRank, tr.SL)];
- snd.ImmData = tr.Id;
- snd.Gather.RangeBeg = tr.RangeBeg;
- snd.Gather.RangeFin = tr.RangeFin;
- snd.QP = pc.QP;
- snd.RemoteBlocks[0] = pc.RemoteBlocks[0];
- snd.RemoteBlocks[1] = pc.RemoteBlocks[1];
- snd.DstRank = tr.DstRank;
- }
- for (int i = 0; i < rr.InList.ysize(); ++i) {
- const TMergeRecord::TInTransfer& tr = rr.InList[i];
- TAllDataSync::TRecv& rcv = iter.InList.emplace_back();
- TPeerConnection& pc = connections[TIBAddr(tr.SrcRank, tr.SL)];
- rcv.QP = pc.QP;
- rcv.SrcRank = tr.SrcRank;
- }
- iter.RecvMask = rr.RecvMask;
- }
- bool rv = res->Resize(szPerRank);
- Y_ABORT_UNLESS(rv, "oops");
- return res;
- }
- IAllGather* CreateAllGather(size_t szPerRank) override {
- TVector<size_t> arr;
- arr.resize(ColSize, szPerRank);
- return CreateAllGather(arr);
- }
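- // The reduce buffer is sized as (MaxRankReceiveCount + 1) equal slots: local data
- // starts in slot 0, a transfer carrying immediate id 'Id' is written into slot
- // Id + 1, and ReduceList folds every filled slot into the highest-numbered one,
- // whose index is remembered in ReadyOffsetMult as the final result slot.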
- IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) override {
- const TMergePlan& plan = ReducePlan;
- size_t bufSizeMult = plan.MaxRankReceiveCount + 1;
- size_t bufSize = 4096;
- {
- size_t sz = (dataSize + 64) * bufSizeMult;
- while (sz > bufSize) {
- bufSize *= 2;
- }
- }
- TAllReduce* res = new TAllReduce(bufSize, MemPool, reduceOp);
- TAllDataSync& ds = res->GetDataSync();
- THashMap<TIBAddr, TPeerConnection, TIBAddrHash> connections;
- CreateDataSyncQPs(ds.CQ, ds.SRQ, ds.MemBlock[0], ds.MemBlock[1], plan, &connections);
- // build plan
- int currentDataOffset = 0;
- for (int z = 0; z < plan.Iterations.ysize(); ++z) {
- const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
- if (rr.OutList.empty() && rr.InList.empty()) {
- continue;
- }
- TAllDataSync::TIteration& iter = ds.Iterations.emplace_back();
- for (int i = 0; i < rr.OutList.ysize(); ++i) {
- const TMergeRecord::TTransfer& tr = rr.OutList[i];
- TAllDataSync::TSend& snd = iter.OutList.emplace_back();
- TPeerConnection& pc = connections[TIBAddr(tr.DstRank, tr.SL)];
- snd.ImmData = tr.Id;
- snd.Reduce.SrcIndex = currentDataOffset;
- snd.Reduce.DstIndex = tr.Id + 1;
- snd.QP = pc.QP;
- snd.RemoteBlocks[0] = pc.RemoteBlocks[0];
- snd.RemoteBlocks[1] = pc.RemoteBlocks[1];
- snd.DstRank = tr.DstRank;
- }
- for (int i = 0; i < rr.InList.ysize(); ++i) {
- const TMergeRecord::TInTransfer& tr = rr.InList[i];
- TAllDataSync::TRecv& rcv = iter.InList.emplace_back();
- TPeerConnection& pc = connections[TIBAddr(tr.SrcRank, tr.SL)];
- rcv.QP = pc.QP;
- rcv.SrcRank = tr.SrcRank;
- }
- iter.RecvMask = rr.RecvMask;
- TVector<int> inputOffset;
- inputOffset.push_back(currentDataOffset);
- int newDataOffset = currentDataOffset;
- for (int i = 0; i < 64; ++i) {
- if (rr.RecvMask & (1ull << i)) {
- int offset = i + 1;
- inputOffset.push_back(offset);
- newDataOffset = Max(offset, newDataOffset);
- }
- }
- for (int i = 0; i < inputOffset.ysize(); ++i) {
- if (inputOffset[i] == newDataOffset) {
- continue;
- }
- TAllDataSync::TReduce& red = iter.ReduceList.emplace_back();
- red.SrcIndex = inputOffset[i];
- red.DstIndex = newDataOffset;
- }
- currentDataOffset = newDataOffset;
- }
- res->BufSizeMult = bufSizeMult;
- res->ReadyOffsetMult = currentDataOffset;
- bool rv = res->Resize(dataSize);
- Y_ABORT_UNLESS(rv, "oops");
- return res;
- }
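- // Sketch of how the factory above might be used (illustrative only; 'col', 'op'
- // and 'n' are assumptions, not names from this file):
- //   IAllReduce* rd = col->CreateAllReduce(n * sizeof(float), op);
- //   rd->Flush();
- //   /* write n floats into rd->GetRawData() */
- //   rd->Sync();   // GetRawData() now points at the reduced result slot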
- void Fence() override {
- const TMergePlan& plan = ReducePlan;
- for (int z = 0; z < plan.Iterations.ysize(); ++z) {
- const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
- for (int i = 0; i < rr.OutList.ysize(); ++i) {
- const TMergeRecord::TTransfer& tr = rr.OutList[i];
- char c;
- PostSend(Peers[tr.DstRank], &c, sizeof(c));
- }
- for (int i = 0; i < rr.InList.ysize(); ++i) {
- const TMergeRecord::TInTransfer& tr = rr.InList[i];
- ui64 wr_id = WaitForMsg(Peers[tr.SrcRank]->GetQPN());
- TIBRecvPacketProcess pkt(*BP, wr_id);
- }
- }
- }
- void RunBWTest(int groupType, int delta, int* targetRank, float* res) override {
- const int BUF_SIZE = 8 * 1024 * 1024;
- TIntrusivePtr<TIBMemBlock> sendMem, recvMem;
- sendMem = MemPool->Alloc(BUF_SIZE);
- recvMem = MemPool->Alloc(BUF_SIZE);
- int myGroup = HostGroup[groupType][ColRank];
- int myGroupPos = 0;
- TVector<int> gg;
- Y_ASSERT(HostGroup[groupType].ysize() == ColSize);
- for (int rank = 0; rank < ColSize; ++rank) {
- if (HostGroup[groupType][rank] == myGroup) {
- if (rank == ColRank) {
- myGroupPos = gg.ysize();
- }
- gg.push_back(rank);
- }
- }
- if (delta >= gg.ysize()) {
- *targetRank = -1;
- *res = 0;
- return;
- }
- int sendRank = gg[(myGroupPos + delta) % gg.ysize()];
- int recvRank = gg[(myGroupPos + gg.ysize() - delta) % gg.ysize()];
- *targetRank = sendRank;
- TIntrusivePtr<TRCQueuePair> sendRC = Peers[sendRank];
- TIntrusivePtr<TRCQueuePair> recvRC = Peers[recvRank];
- {
- TBWTest bw;
- bw.Addr = recvMem->GetAddr();
- bw.RKey = recvMem->GetMemRegion()->GetRKey();
- PostSend(recvRC, &bw, sizeof(bw));
- }
- TBWTest dstMem;
- {
- ui64 wr_id = WaitForMsg(sendRC->GetQPN());
- TIBRecvPacketProcess pkt(*BP, wr_id);
- dstMem = *(TBWTest*)pkt.GetData();
- }
- // run
- TVector<double> score;
- for (int iter = 0; iter < 5; ++iter) {
- while (!AllocRDMAWriteSlot(sendRC)) {
- WaitCompletion();
- Y_ASSERT(0 && "measurements are imprecise");
- }
- NHPTimer::STime t;
- NHPTimer::GetTime(&t);
- sendRC->PostRDMAWrite(dstMem.Addr, dstMem.RKey, sendMem->GetMemRegion(), 0, sendMem->GetData(), BUF_SIZE);
- for (;;) {
- ibv_wc wc;
- WaitCompletion(&wc);
- if (wc.opcode == IBV_WC_RDMA_WRITE) {
- if (wc.qp_num != (ui32)sendRC->GetQPN()) {
- abort();
- }
- break;
- }
- }
- double tPassed = NHPTimer::GetTimePassed(&t);
- double speed = BUF_SIZE / tPassed / 1000000000.0; // G/sec
- score.push_back(speed);
- }
- Sort(score.begin(), score.end());
- // signal completion & wait for signal
- *res = score[score.size() / 2];
- {
- char bb;
- PostSend(sendRC, &bb, sizeof(bb));
- ui64 wr_id = WaitForMsg(recvRC->GetQPN());
- TIBRecvPacketProcess pkt(*BP, wr_id);
- }
- }
- bool TrySendMicro(int dstRank, const void* data, int dataSize) override {
- return TryPostSend(Peers[dstRank], data, dataSize);
- }
- void InitPeerTable(TIBMicroPeerTable* res) override {
- res->Init(QPNTableSizeLog);
- }
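- // Issues a batch of one-sided RDMA writes, round-robining over destination ranks
- // and keeping at most MAX_TOTAL_RDMA writes in flight, since flooding the fabric
- // with writes degrades throughput; completions are drained between refills.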
- void RdmaWrite(const TVector<TRdmaRequest>& reqs) override {
- TVector<TVector<int>> reqPerRank;
- reqPerRank.resize(ColSize);
- int reqCount = reqs.ysize();
- for (int i = 0; i < reqCount; ++i) {
- reqPerRank[reqs[i].DstRank].push_back(i);
- }
- int inFlight = 0; // IB congestion control sucks :/ so we limit number of simultaneous rdmas
- int startRank = ColRank;
- while (reqCount > 0) {
- if (inFlight < MAX_TOTAL_RDMA) {
- for (int z = 0; z < ColSize; ++z) {
- int dstRank = (startRank + 1 + z) % ColSize;
- if (reqPerRank[dstRank].empty()) {
- continue;
- }
- Y_ASSERT(dstRank != ColRank && "sending self is meaningless");
- TRCQueuePair* qp = Peers[dstRank].Get();
- if (AllocRDMAWriteSlot(qp)) {
- const TRdmaRequest& rr = reqs[reqPerRank[dstRank].back()];
- qp->PostRDMAWrite(rr.RemoteAddr, rr.RemoteKey, rr.LocalAddr, rr.LocalKey, 0, rr.Size);
- reqPerRank[dstRank].pop_back();
- if (++inFlight >= MAX_TOTAL_RDMA) {
- startRank = dstRank;
- break;
- }
- }
- }
- }
- {
- ibv_wc wc;
- WaitCompletion(&wc);
- if (wc.opcode == IBV_WC_RDMA_WRITE) {
- --inFlight;
- --reqCount;
- }
- }
- }
- }
- public:
- TIBCollective(TPtrArg<TIBPort> port, TPtrArg<TIBMemPool> memPool,
- const TCollectiveInit& params,
- TCollectiveLinkSet* resLinks)
- : Port(port)
- , MemPool(memPool)
- , QPNTableSizeLog(0)
- {
- ColSize = params.Size;
- ColRank = params.Rank;
- int maxOutstandingQueries = MAX_REQS_PER_PEER * ColSize + 10;
- CQ = new TComplectionQueue(Port->GetCtx(), maxOutstandingQueries * 2);
- BP = new TIBBufferPool(Port->GetCtx(), maxOutstandingQueries);
- Peers.resize(ColSize);
- resLinks->Links.resize(ColSize);
- TVector<int> qpnArr;
- for (int k = 0; k < ColSize; ++k) {
- if (k == ColRank) {
- continue;
- }
- TRCQueuePair* rc = new TRCQueuePair(Port->GetCtx(), CQ, BP->GetSRQ(), MAX_REQS_PER_PEER);
- Peers[k] = rc;
- TCollectiveLinkSet::TLinkInfo& lnk = resLinks->Links[k];
- lnk.PSN = rc->GetPSN();
- lnk.QPN = rc->GetQPN();
- qpnArr.push_back(lnk.QPN);
- }
- resLinks->Hosts.resize(ColSize);
- resLinks->Hosts[ColRank] = Port->GetLID();
- static_assert(MAX_REQS_PER_PEER < 256, "expect MAX_REQS_PER_PEER < 256"); // sent count will fit into SendCountTable[]
- Zero(SendCountTable);
- Zero(RDMACountTable);
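- // Find the smallest power-of-two table size at which all peer QPNs map to distinct
- // slots; TIBMicroPeerTable instances are later initialized with this QPNTableSizeLog.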
- if (!qpnArr.empty()) {
- for (;;) {
- TVector<ui8> qpnTable;
- int qpnTableSize = 1 << QPNTableSizeLog;
- qpnTable.resize(qpnTableSize, 0);
- bool ok = true;
- for (int i = 0; i < qpnArr.ysize(); ++i) {
- int idx = qpnArr[i] & (qpnTableSize - 1);
- if (++qpnTable[idx] == 2) {
- ok = false;
- break;
- }
- }
- if (ok) {
- break;
- }
- ++QPNTableSizeLog;
- }
- //printf("QPN table, size_log %d\n", QPNTableSizeLog);
- }
- }
- friend class TIBRecvMicro;
- };
- TIBRecvMicro::TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable)
- : IB(*(TIBCollective*)col)
- {
- Y_ASSERT(typeid(IB) == typeid(TIBCollective));
- if (IB.GetMsg(&Id, &QPN, peerTable)) {
- Data = IB.BP->GetBufData(Id);
- } else {
- Data = nullptr;
- }
- }
- TIBRecvMicro::~TIBRecvMicro() {
- if (Data) {
- IB.BP->FreeBuf(Id);
- IB.BP->PostRecv();
- }
- }
- IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks) {
- return new TIBCollective(GetIBDevice(), GetIBMemPool(), params, resLinks);
- }
- }