#include "stdafx.h" #include "ib_collective.h" #include "ib_mem.h" #include "ib_buffers.h" #include "ib_low.h" #include "udp_http.h" #include "udp_address.h" #include #include namespace NNetliba { const int COL_SERVICE_LEVEL = 2; const int COL_DATA_SERVICE_LEVEL = 2; // base level const int COL_DATA_SERVICE_LEVEL_COUNT = 6; // level count const int MAX_REQS_PER_PEER = 32; const int MAX_TOTAL_RDMA = 20; const int SEND_COUNT_TABLE_SIZE = 1 << 12; // must be power of 2 struct TMergeRecord { struct TTransfer { int DstRank; int SL; int RangeBeg, RangeFin; int Id; TTransfer() : DstRank(-1) , SL(0) , RangeBeg(0) , RangeFin(0) , Id(0) { } TTransfer(int dstRank, int sl, int rangeBeg, int rangeFin, int id) : DstRank(dstRank) , SL(sl) , RangeBeg(rangeBeg) , RangeFin(rangeFin) , Id(id) { } }; struct TInTransfer { int SrcRank; int SL; TInTransfer() : SrcRank(-1) , SL(0) { } TInTransfer(int srcRank, int sl) : SrcRank(srcRank) , SL(sl) { } }; TVector OutList; TVector InList; ui64 RecvMask; TMergeRecord() : RecvMask(0) { } }; struct TMergeIteration { TVector Ops; void Init(int colSize) { Ops.resize(colSize); } void Transfer(int srcRank, int dstRank, int sl, int rangeBeg, int rangeFin, int id) { Y_ABORT_UNLESS(id < 64, "recv mask overflow"); Ops[srcRank].OutList.push_back(TMergeRecord::TTransfer(dstRank, sl, rangeBeg, rangeFin, id)); Ops[dstRank].InList.push_back(TMergeRecord::TInTransfer(srcRank, sl)); Ops[dstRank].RecvMask |= ui64(1) << id; } }; struct TMergePlan { TVector Iterations; TVector RankReceiveCount; int ColSize; int MaxRankReceiveCount; TMergePlan() : ColSize(0) , MaxRankReceiveCount(0) { } void Init(int colSize) { Iterations.resize(0); RankReceiveCount.resize(0); RankReceiveCount.resize(colSize, 0); ColSize = colSize; } void Transfer(int iter, int srcRank, int dstRank, int sl, int rangeBeg, int rangeFin) { while (iter >= Iterations.ysize()) { TMergeIteration& res = Iterations.emplace_back(); res.Init(ColSize); } int id = RankReceiveCount[dstRank]++; MaxRankReceiveCount = Max(MaxRankReceiveCount, id + 1); Y_ASSERT(id < 64); Iterations[iter].Transfer(srcRank, dstRank, sl, rangeBeg, rangeFin, id); } }; struct TSRTransfer { int SrcRank, DstRank; int RangeBeg, RangeFin; TSRTransfer() { Zero(*this); } TSRTransfer(int srcRank, int dstRank, int rangeBeg, int rangeFin) : SrcRank(srcRank) , DstRank(dstRank) , RangeBeg(rangeBeg) , RangeFin(rangeFin) { } }; static int SplitRange(THashMap>* res, int iter, int beg, int fin) { int mid = (beg + fin + 1) / 2; if (mid == fin) { return iter; } for (int i = 0; i < fin - mid; ++i) { (*res)[iter].push_back(TSRTransfer(beg + i, mid + i, beg, mid)); (*res)[iter].push_back(TSRTransfer(mid + i, beg + i, mid, fin)); } if (fin - mid < mid - beg) { // [mid - 1] did not receive [mid;fin) (*res)[iter].push_back(TSRTransfer(mid, mid - 1, mid, fin)); } int rv1 = SplitRange(res, iter + 1, beg, mid); int rv2 = SplitRange(res, iter + 1, mid, fin); return Max(rv1, rv2); } static void CreatePow2Merge(TMergePlan* plan, int colSize) { // finally everybody has full range [0;ColSize) // construct plan recursively, on each iteration split some range plan->Init(colSize); THashMap> allTransfers; int maxIter = SplitRange(&allTransfers, 0, 0, colSize); for (int iter = 0; iter < maxIter; ++iter) { const TVector& arr = allTransfers[maxIter - iter - 1]; // reverse order for (int i = 0; i < arr.ysize(); ++i) { const TSRTransfer& sr = arr[i]; plan->Transfer(iter, sr.SrcRank, sr.DstRank, 0, sr.RangeBeg, sr.RangeFin); } } } struct TCoverInterval { int Beg, Fin; // [Beg;Fin) TCoverInterval() : 
    struct TCoverInterval {
        int Beg, Fin; // [Beg;Fin)

        TCoverInterval()
            : Beg(0)
            , Fin(0)
        {
        }
        TCoverInterval(int b, int f)
            : Beg(b)
            , Fin(f)
        {
        }
    };

    enum EAllToAllMode {
        AA_POW2,
        AA_CIRCLE,
        AA_STAR,
        AA_POW2_MERGE,
    };

    static int AllToAll(TMergePlan* plan, int iter, int sl, EAllToAllMode mode, const TVector<int>& myGroup, TVector<TCoverInterval>* cover) {
        TVector<TCoverInterval>& hostCoverage = *cover;
        int groupSize = myGroup.ysize();

        for (int k = 1; k < groupSize; ++k) {
            int h1 = myGroup[k - 1];
            int h2 = myGroup[k];
            Y_ABORT_UNLESS(hostCoverage[h1].Fin == hostCoverage[h2].Beg, "Invalid host order in CreateGroupMerge()");
        }
        switch (mode) {
            case AA_POW2: {
                for (int delta = 1; delta < groupSize; delta *= 2) {
                    int sz = Min(delta, groupSize - delta);
                    for (int offset = 0; offset < groupSize; ++offset) {
                        int srcRank = myGroup[offset];
                        int dstRank = myGroup[(offset + delta) % groupSize];
                        int start = offset + 1 - sz;
                        int finish = offset + 1;
                        if (start < 0) {
                            // [start; myGroup.size())
                            int dataBeg = hostCoverage[myGroup[start + groupSize]].Beg;
                            int dataFin = hostCoverage[myGroup.back()].Fin;
                            plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                            // [0; finish)
                            dataBeg = hostCoverage[myGroup[0]].Beg;
                            dataFin = hostCoverage[myGroup[finish - 1]].Fin;
                            plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                        } else {
                            // [start;finish)
                            int dataBeg = hostCoverage[myGroup[start]].Beg;
                            int dataFin = hostCoverage[myGroup[finish - 1]].Fin;
                            plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                        }
                    }
                    ++iter;
                }
            } break;
            case AA_CIRCLE: {
                for (int dataDelta = 1; dataDelta < groupSize; ++dataDelta) {
                    for (int offset = 0; offset < groupSize; ++offset) {
                        int srcRank = myGroup[offset];
                        int dstRank = myGroup[(offset + 1) % groupSize];
                        int dataRank = myGroup[(offset + 1 - dataDelta + groupSize) % groupSize];
                        int dataBeg = hostCoverage[dataRank].Beg;
                        int dataFin = hostCoverage[dataRank].Fin;
                        plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                    }
                    ++iter;
                }
            } break;
            case AA_STAR: {
                for (int offset = 0; offset < groupSize; ++offset) {
                    for (int delta = 1; delta < groupSize; ++delta) {
                        int srcRank = myGroup[offset];
                        int dstRank = myGroup[(offset + delta) % groupSize];
                        int dataRank = myGroup[offset];
                        int dataBeg = hostCoverage[dataRank].Beg;
                        int dataFin = hostCoverage[dataRank].Fin;
                        plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                    }
                }
                ++iter;
            } break;
            case AA_POW2_MERGE: {
                TMergePlan pp;
                CreatePow2Merge(&pp, groupSize);
                for (int z = 0; z < pp.Iterations.ysize(); ++z) {
                    const TMergeIteration& mm = pp.Iterations[z];
                    for (int src = 0; src < mm.Ops.ysize(); ++src) {
                        const TMergeRecord& mr = mm.Ops[src];
                        int srcRank = myGroup[src];
                        for (int i = 0; i < mr.OutList.ysize(); ++i) {
                            int dstRank = myGroup[mr.OutList[i].DstRank];
                            plan->Transfer(iter, srcRank, dstRank, sl, 0, 1);
                        }
                    }
                    ++iter;
                }
            } break;
            default:
                Y_ASSERT(0);
                break;
        }
        {
            TCoverInterval cc(hostCoverage[myGroup[0]].Beg, hostCoverage[myGroup.back()].Fin);
            for (int k = 0; k < groupSize; ++k) {
                hostCoverage[myGroup[k]] = cc;
            }
        }
        return iter;
    }
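    // AllToAll() above schedules one all-gather stage inside a single group of ranks and
    // then marks every group member as covering the union of the group's intervals, so
    // the next (outer) group level exchanges whole-group blocks. AA_STAR sends each
    // member's own interval directly to every other member in one iteration; AA_CIRCLE
    // pipelines the intervals around a ring; AA_POW2 uses recursive doubling along the
    // group order. AA_POW2_MERGE only replays the pow2 schedule with a dummy [0;1)
    // range; it is used for the reduce plan, where the real offsets are assigned later
    // in CreateAllReduce()/TAllReduce::Resize().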
    // fully populated matrix
    static void CreateGroupMerge(TMergePlan* plan, EAllToAllMode mode, const TVector<TVector<int>>& hostGroup) {
        int hostCount = hostGroup[0].ysize();
        int groupTypeCount = hostGroup.ysize();

        plan->Init(hostCount);

        TVector<int> gcount;
        gcount.resize(groupTypeCount, 0);
        for (int hostId = 0; hostId < hostCount; ++hostId) {
            for (int groupType = 0; groupType < groupTypeCount; ++groupType) {
                int val = hostGroup[groupType][hostId];
                gcount[groupType] = Max(gcount[groupType], val + 1);
            }
        }
        for (int hostId = 1; hostId < hostCount; ++hostId) {
            bool isIncrement = true;
            for (int groupType = 0; groupType < groupTypeCount; ++groupType) {
                int prev = hostGroup[groupType][hostId - 1];
                int cur = hostGroup[groupType][hostId];
                if (isIncrement) {
                    if (cur == prev + 1) {
                        isIncrement = false;
                    } else {
                        Y_ABORT_UNLESS(cur == 0, "ib_hosts, wrapped to non-zero");
                        Y_ABORT_UNLESS(prev == gcount[groupType] - 1, "ib_hosts, structure is irregular");
                        isIncrement = true;
                    }
                } else {
                    Y_ABORT_UNLESS(prev == cur, "ib_hosts, structure is irregular");
                }
            }
        }

        TVector<TCoverInterval> hostCoverage;
        for (int i = 0; i < hostCount; ++i) {
            hostCoverage.push_back(TCoverInterval(i, i + 1));
        }

        int baseIter = 0;
        for (int groupType = hostGroup.ysize() - 1; groupType >= 0; --groupType) {
            Y_ASSERT(hostGroup[groupType].ysize() == hostCount);
            TVector<TVector<int>> hh;
            hh.resize(gcount[groupType]);
            for (int rank = 0; rank < hostGroup[groupType].ysize(); ++rank) {
                int groupId = hostGroup[groupType][rank];
                hh[groupId].push_back(rank);
            }
            int newIter = 0;
            for (int groupId = 0; groupId < hh.ysize(); ++groupId) {
                int nn = AllToAll(plan, baseIter, 0, mode, hh[groupId], &hostCoverage); // seems to be fastest
                if (newIter == 0) {
                    newIter = nn;
                } else {
                    Y_ABORT_UNLESS(newIter == nn, "groups should be of same size");
                }
            }
            baseIter = newIter;
        }
        //printf("%d iterations symmetrical plan\n", baseIter);
    }

    //////////////////////////////////////////////////////////////////////////
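    // TAllDataSync drives one collective operation over a fixed plan. It keeps two
    // registered memory blocks and alternates between them on every Flush(), so the
    // caller can fill the next buffer while peers may still be reading the previous one.
    // Data moves via RDMA-write-with-immediate; the immediate value carries the transfer
    // id, and the per-iteration RecvMask tells Sync() which ids must have arrived before
    // the iteration's local reduce steps may run.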
    struct TAllDataSync {
        static constexpr int WR_COUNT = 64 * 2;

        int CurrentBuffer;
        TIntrusivePtr<TIBMemBlock> MemBlock[2];
        TIntrusivePtr<TComplectionQueue> CQ;
        TIntrusivePtr<TSharedReceiveQueue> SRQ;
        TIntrusivePtr<TIBMemBlock> FakeRecvMem;
        size_t DataSize, BufSize;
        size_t CurrentOffset, ReadyOffset;
        bool WasFlushed;
        int ActiveRDMACount;
        ui64 FutureRecvMask;
        TIntrusivePtr<IReduceOp> ReduceOp;

        struct TBlockInfo {
            ui64 Addr;
            ui32 Key;
        };

        struct TSend {
            TBlockInfo RemoteBlocks[2];
            TIntrusivePtr<TRCQueuePair> QP;
            size_t SrcOffset;
            size_t DstOffset;
            size_t Length;
            ui32 ImmData;
            int DstRank;
            union {
                struct {
                    int RangeBeg, RangeFin;
                } Gather;
                struct {
                    int SrcIndex, DstIndex;
                } Reduce;
            };
        };

        struct TRecv {
            TIntrusivePtr<TRCQueuePair> QP;
            int SrcRank;
        };

        struct TReduce {
            size_t DstOffset, SrcOffset;
            int DstIndex, SrcIndex;
        };

        struct TIteration {
            TVector<TSend> OutList;
            TVector<TRecv> InList;
            TVector<TReduce> ReduceList;
            ui64 RecvMask;
        };

        TVector<TIteration> Iterations;

    public:
        void* GetRawData() {
            char* myData = (char*)MemBlock[CurrentBuffer]->GetData();
            return myData + CurrentOffset;
        }
        size_t GetRawDataSize() {
            return DataSize;
        }
        void PostRecv() {
            SRQ->PostReceive(FakeRecvMem->GetMemRegion(), 0, FakeRecvMem->GetData(), FakeRecvMem->GetSize());
        }
        void Sync() {
            Y_ASSERT(WasFlushed && "Have to call Flush() before data fill & Sync()");
            char* myData = (char*)MemBlock[CurrentBuffer]->GetData();

            ui64 recvMask = FutureRecvMask;
            FutureRecvMask = 0;
            int recvDebt = 0;
            for (int z = 0; z < Iterations.ysize(); ++z) {
                const TIteration& iter = Iterations[z];
                for (int k = 0; k < iter.OutList.ysize(); ++k) {
                    const TSend& ss = iter.OutList[k];
                    const TBlockInfo& remoteBlk = ss.RemoteBlocks[CurrentBuffer];
                    ss.QP->PostRDMAWriteImm(remoteBlk.Addr + ss.DstOffset, remoteBlk.Key, ss.ImmData,
                                            MemBlock[CurrentBuffer]->GetMemRegion(), 0, myData + ss.SrcOffset, ss.Length);
                    ++ActiveRDMACount;
                    //printf("-> %d, imm %d (%" PRId64 " bytes)\n", ss.DstRank, ss.ImmData, ss.Length);
                    //printf("send %d\n", ss.SrcOffset);
                }

                ibv_wc wc;
                while ((recvMask & iter.RecvMask) != iter.RecvMask) {
                    int rv = CQ->Poll(&wc, 1);
                    if (rv > 0) {
                        Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "AllGather::Sync fail, status %d", (int)wc.status);
                        if (wc.opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
                            //printf("Got %d\n", wc.imm_data);
                            ++recvDebt;
                            ui64 newBit = ui64(1) << wc.imm_data;
                            if (recvMask & newBit) {
                                Y_ABORT_UNLESS((FutureRecvMask & newBit) == 0, "data from 2 Sync() ahead is impossible");
                                FutureRecvMask |= newBit;
                            } else {
                                recvMask |= newBit;
                            }
                        } else if (wc.opcode == IBV_WC_RDMA_WRITE) {
                            --ActiveRDMACount;
                        } else {
                            Y_ASSERT(0);
                        }
                    } else {
                        if (recvDebt > 0) {
                            PostRecv();
                            --recvDebt;
                        }
                    }
                }

                for (int k = 0; k < iter.ReduceList.ysize(); ++k) {
                    const TReduce& rr = iter.ReduceList[k];
                    ReduceOp->Reduce(myData + rr.DstOffset, myData + rr.SrcOffset, DataSize);
                    //printf("Merge %d -> %d (%d bytes)\n", rr.SrcOffset, rr.DstOffset, DataSize);
                }
                //printf("Iteration %d done\n", z);
            }
            while (recvDebt > 0) {
                PostRecv();
                --recvDebt;
            }
            CurrentOffset = ReadyOffset;
            WasFlushed = false;
            //printf("new cur offset %g\n", (double)CurrentOffset);
            //printf("Sync complete\n");
        }
        void Flush() {
            Y_ASSERT(!WasFlushed);
            CurrentBuffer = 1 - CurrentBuffer;
            CurrentOffset = 0;
            WasFlushed = true;
        }

    public:
        TAllDataSync(size_t bufSize, TPtrArg<TIBMemPool> memPool, TPtrArg<IReduceOp> reduceOp)
            : CurrentBuffer(0)
            , DataSize(0)
            , BufSize(bufSize)
            , CurrentOffset(0)
            , ReadyOffset(0)
            , WasFlushed(false)
            , ActiveRDMACount(0)
            , FutureRecvMask(0)
            , ReduceOp(reduceOp)
        {
            if (memPool) {
                MemBlock[0] = memPool->Alloc(BufSize);
                MemBlock[1] = memPool->Alloc(BufSize);
                CQ = new TComplectionQueue(memPool->GetIBContext(), WR_COUNT);
                SRQ = new TSharedReceiveQueue(memPool->GetIBContext(), WR_COUNT);
                FakeRecvMem = memPool->Alloc(4096);
            } else {
                MemBlock[0] = new TIBMemBlock(BufSize);
                MemBlock[1] = new TIBMemBlock(BufSize);
                CQ = new TComplectionQueue(nullptr, WR_COUNT);
                SRQ = new TSharedReceiveQueue(nullptr, WR_COUNT);
                FakeRecvMem = new TIBMemBlock(4096);
            }

            for (int i = 0; i < WR_COUNT; ++i) {
                PostRecv();
            }
        }
        ~TAllDataSync() {
            while (ActiveRDMACount > 0) {
                ibv_wc wc;
                int rv = CQ->Poll(&wc, 1);
                if (rv > 0) {
                    if (wc.opcode == IBV_WC_RDMA_WRITE) {
                        --ActiveRDMACount;
                    } else {
                        Y_ASSERT(0);
                    }
                }
            }
        }
    };

    class TAllReduce: public IAllReduce {
        TAllDataSync DataSync;
        size_t BufSizeMult;
        size_t ReadyOffsetMult;

    public:
        TAllReduce(size_t bufSize, TPtrArg<TIBMemPool> memPool, TPtrArg<IReduceOp> reduceOp)
            : DataSync(bufSize, memPool, reduceOp)
            , BufSizeMult(0)
            , ReadyOffsetMult(0)
        {
        }
        TAllDataSync& GetDataSync() {
            return DataSync;
        }
        void* GetRawData() override {
            return DataSync.GetRawData();
        }
        size_t GetRawDataSize() override {
            return DataSync.GetRawDataSize();
        }
        void Sync() override {
            DataSync.Sync();
        }
        void Flush() override {
            DataSync.Flush();
        }
        bool Resize(size_t dataSize) override {
            size_t repSize = (dataSize + 63) & (~63ull);
            size_t bufSize = repSize * BufSizeMult;
            if (bufSize > DataSync.BufSize) {
                return false;
            }

            for (int z = 0; z < DataSync.Iterations.ysize(); ++z) {
                TAllDataSync::TIteration& iter = DataSync.Iterations[z];
                for (int i = 0; i < iter.OutList.ysize(); ++i) {
                    TAllDataSync::TSend& snd = iter.OutList[i];
                    snd.Length = dataSize;
                    snd.SrcOffset = snd.Reduce.SrcIndex * repSize;
                    snd.DstOffset = snd.Reduce.DstIndex * repSize;
                }
                for (int i = 0; i < iter.ReduceList.ysize(); ++i) {
                    TAllDataSync::TReduce& red = iter.ReduceList[i];
                    red.SrcOffset = red.SrcIndex * repSize;
                    red.DstOffset = red.DstIndex * repSize;
                }
            }
            DataSync.ReadyOffset = ReadyOffsetMult * repSize;
            DataSync.DataSize = dataSize;
            return true;
        }

        friend class TIBCollective;
    };
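    // TAllGather: rank r contributes szPerRank[r] bytes, and the prefix sums of szPerRank
    // define each rank's offset inside the shared result buffer. Resize() turns the
    // [RangeBeg;RangeFin) rank ranges recorded in the plan into byte offsets, so the same
    // connection set can be reused for different size distributions as long as the total
    // fits into the allocated buffer.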
    class TAllGather: public IAllGather {
        TAllDataSync DataSync;
        int ColSize;

    public:
        TAllGather(int colSize, size_t bufSize, TPtrArg<TIBMemPool> memPool)
            : DataSync(bufSize, memPool, nullptr)
            , ColSize(colSize)
        {
        }
        TAllDataSync& GetDataSync() {
            return DataSync;
        }
        void* GetRawData() override {
            return DataSync.GetRawData();
        }
        size_t GetRawDataSize() override {
            return DataSync.GetRawDataSize();
        }
        void Sync() override {
            DataSync.Sync();
        }
        void Flush() override {
            DataSync.Flush();
        }
        bool Resize(const TVector<size_t>& szPerRank) override {
            Y_ABORT_UNLESS(szPerRank.ysize() == ColSize, "Invalid size array");

            TVector<size_t> offsets;
            offsets.push_back(0);
            for (int rank = 0; rank < ColSize; ++rank) {
                offsets.push_back(offsets.back() + szPerRank[rank]);
            }
            size_t dataSize = offsets.back();
            if (dataSize > DataSync.BufSize) {
                return false;
            }
            for (int z = 0; z < DataSync.Iterations.ysize(); ++z) {
                TAllDataSync::TIteration& iter = DataSync.Iterations[z];
                for (int i = 0; i < iter.OutList.ysize(); ++i) {
                    TAllDataSync::TSend& snd = iter.OutList[i];
                    int rangeBeg = snd.Gather.RangeBeg;
                    int rangeFin = snd.Gather.RangeFin;
                    snd.Length = offsets[rangeFin] - offsets[rangeBeg];
                    snd.SrcOffset = offsets[rangeBeg];
                    snd.DstOffset = snd.SrcOffset;
                }
            }
            DataSync.DataSize = dataSize;
            return true;
        }
    };

    struct TIBAddr {
        int LID, SL;

        TIBAddr()
            : LID(0)
            , SL(0)
        {
        }
        TIBAddr(int lid, int sl)
            : LID(lid)
            , SL(sl)
        {
        }
    };

    inline bool operator==(const TIBAddr& a, const TIBAddr& b) {
        return a.LID == b.LID && a.SL == b.SL;
    }
    inline bool operator<(const TIBAddr& a, const TIBAddr& b) {
        if (a.LID == b.LID) {
            return a.SL < b.SL;
        }
        return a.LID < b.LID;
    }

    struct TIBAddrHash {
        int operator()(const TIBAddr& a) const {
            return a.LID + a.SL * 4254515;
        }
    };

    class TIBCollective: public IIBCollective {
        struct TPendingMessage {
            int QPN;
            ui64 WorkId;

            TPendingMessage() {
                Zero(*this);
            }
            TPendingMessage(int qpn, ui64 wid)
                : QPN(qpn)
                , WorkId(wid)
            {
            }
        };

        struct TBlockInform {
            TAllDataSync::TBlockInfo RemoteBlocks[2];
            int PSN, QPN;
        };

        struct TPeerConnection {
            TAllDataSync::TBlockInfo RemoteBlocks[2];
            TIntrusivePtr<TRCQueuePair> QP;
        };

        struct TBWTest {
            ui64 Addr;
            ui32 RKey;
        };

        TIntrusivePtr<TIBPort> Port;
        TIntrusivePtr<TIBMemPool> MemPool;
        int ColSize, ColRank;
        TVector<int> Hosts; // host LIDs
        TVector<TVector<int>> HostGroup;
        TVector<TIntrusivePtr<TRCQueuePair>> Peers;
        TIntrusivePtr<TComplectionQueue> CQ;
        TIntrusivePtr<TIBBufferPool> BP;
        ui8 SendCountTable[SEND_COUNT_TABLE_SIZE];
        ui8 RDMACountTable[SEND_COUNT_TABLE_SIZE];
        TDeque<TPendingMessage> Pending;
        TMergePlan MergePlan, ReducePlan;
        int QPNTableSizeLog;
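        // Per-peer flow control: SendCountTable/RDMACountTable are indexed by the low
        // bits of the QP number (SEND_COUNT_TABLE_SIZE is a power of 2). A peer may have
        // at most MAX_REQS_PER_PEER outstanding sends and MAX_OUTSTANDING_RDMA
        // outstanding RDMA writes; completions observed for other QPs while waiting are
        // parked in Pending and consumed later by WaitForMsg()/GetMsg().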
        void WriteCompleted(const ibv_wc& wc) {
            --SendCountTable[wc.qp_num & (SEND_COUNT_TABLE_SIZE - 1)];
            if (wc.opcode == IBV_WC_RDMA_WRITE) {
                --RDMACountTable[wc.qp_num & (SEND_COUNT_TABLE_SIZE - 1)];
            }
            BP->FreeBuf(wc.wr_id);
        }
        bool GetMsg(ui64* resWorkId, int* resQPN, TIBMicroPeerTable* tbl) {
            if (tbl->NeedParsePending()) {
                for (TDeque<TPendingMessage>::iterator z = Pending.begin(); z != Pending.end(); ++z) {
                    if (!tbl->NeedQPN(z->QPN)) {
                        continue;
                    }
                    *resWorkId = z->WorkId;
                    *resQPN = z->QPN;
                    Pending.erase(z);
                    return true;
                }
                //printf("Stop parse pending\n");
                tbl->StopParsePending();
            }
            for (;;) {
                ibv_wc wc;
                int rv = CQ->Poll(&wc, 1);
                if (rv > 0) {
                    Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
                    if (wc.opcode & IBV_WC_RECV) {
                        BP->RequestPostRecv();
                        if (tbl->NeedQPN(wc.qp_num)) {
                            *resWorkId = wc.wr_id;
                            *resQPN = wc.qp_num;
                            return true;
                        } else {
                            Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
                            BP->PostRecv();
                        }
                    } else {
                        WriteCompleted(wc);
                    }
                } else {
                    return false;
                }
            }
        }
        bool ProcessSendCompletion(const ibv_wc& wc) {
            Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
            if (wc.opcode & IBV_WC_RECV) {
                BP->RequestPostRecv();
                Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
                BP->PostRecv();
            } else {
                WriteCompleted(wc);
                return true;
            }
            return false;
        }
        void WaitCompletion(ibv_wc* res) {
            ibv_wc& wc = *res;
            for (;;) {
                int rv = CQ->Poll(&wc, 1);
                if (rv > 0 && ProcessSendCompletion(wc)) {
                    break;
                }
            }
        }
        bool TryWaitCompletion() override {
            ibv_wc wc;
            for (;;) {
                int rv = CQ->Poll(&wc, 1);
                if (rv > 0) {
                    if (ProcessSendCompletion(wc)) {
                        return true;
                    }
                } else {
                    return false;
                }
            }
        }
        void WaitCompletion() override {
            ibv_wc wc;
            WaitCompletion(&wc);
        }
        ui64 WaitForMsg(int qpn) {
            for (TDeque<TPendingMessage>::iterator z = Pending.begin(); z != Pending.end(); ++z) {
                if (z->QPN == qpn) {
                    ui64 workId = z->WorkId;
                    Pending.erase(z);
                    return workId;
                }
            }
            ibv_wc wc;
            for (;;) {
                int rv = CQ->Poll(&wc, 1);
                if (rv > 0) {
                    Y_ABORT_UNLESS(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
                    if (wc.opcode & IBV_WC_RECV) {
                        BP->RequestPostRecv();
                        if ((int)wc.qp_num == qpn) {
                            return wc.wr_id;
                        } else {
                            Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
                            BP->PostRecv();
                        }
                    } else {
                        WriteCompleted(wc);
                    }
                }
            }
        }
        bool AllocOperationSlot(TPtrArg<TRCQueuePair> qp) {
            int way = qp->GetQPN() & (SEND_COUNT_TABLE_SIZE - 1);
            if (SendCountTable[way] >= MAX_REQS_PER_PEER) {
                return false;
            }
            ++SendCountTable[way];
            return true;
        }
        bool AllocRDMAWriteSlot(TPtrArg<TRCQueuePair> qp) {
            int way = qp->GetQPN() & (SEND_COUNT_TABLE_SIZE - 1);
            if (SendCountTable[way] >= MAX_REQS_PER_PEER) {
                return false;
            }
            if (RDMACountTable[way] >= MAX_OUTSTANDING_RDMA) {
                return false;
            }
            ++SendCountTable[way];
            ++RDMACountTable[way];
            return true;
        }
        bool TryPostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) {
            if (AllocOperationSlot(qp)) {
                BP->PostSend(qp, data, len);
                return true;
            }
            return false;
        }
        void PostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) {
            while (!TryPostSend(qp, data, len)) {
                WaitCompletion();
            }
        }
        int GetRank() override {
            return ColRank;
        }
        int GetSize() override {
            return ColSize;
        }
        int GetGroupTypeCount() override {
            return HostGroup.ysize();
        }
        int GetQPN(int rank) override {
            if (rank == ColRank) {
                Y_ASSERT(0 && "there is no qpn connected to localhost");
                return 0;
            }
            return Peers[rank]->GetQPN();
        }
        void Start(const TCollectiveLinkSet& links) override {
            Hosts = links.Hosts;
            HostGroup = links.HostGroup;

            for (int k = 0; k < ColSize; ++k) {
                if (k == ColRank) {
                    continue;
                }
                const TCollectiveLinkSet::TLinkInfo& lnk = links.Links[k];
                ibv_ah_attr peerAddr;
                MakeAH(&peerAddr, Port, Hosts[k], COL_SERVICE_LEVEL);
                Peers[k]->Init(peerAddr, lnk.QPN, lnk.PSN);
            }

            //CreatePow2Merge(&MergePlan, ColSize);
            //CreatePow2Merge(&ReducePlan, ColSize);
            CreateGroupMerge(&MergePlan, AA_STAR, HostGroup);
            CreateGroupMerge(&ReducePlan, AA_POW2_MERGE, HostGroup);
        }
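        // CreateDataSyncQPs(): for every (peer rank, service level) pair that the plan
        // touches from this rank, create a dedicated RC queue pair bound to the data-sync
        // CQ/SRQ, send our two memory-block handles (address + rkey) plus QPN/PSN over
        // the already-established control connection, then receive the peer's
        // TBlockInform and connect the new QP to it. The Fence() at the end keeps all
        // ranks in step before the collective is first used.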
        void CreateDataSyncQPs(
            TPtrArg<TComplectionQueue> cq,
            TPtrArg<TSharedReceiveQueue> srq,
            TPtrArg<TIBMemBlock> memBlock0,
            TPtrArg<TIBMemBlock> memBlock1,
            const TMergePlan& plan,
            THashMap<TIBAddr, TPeerConnection, TIBAddrHash>* res) {
            THashMap<TIBAddr, TPeerConnection, TIBAddrHash>& connections = *res;
            TIBMemBlock* memBlock[2] = {memBlock0, memBlock1};

            // make full peer list
            TVector<TIBAddr> peerList;
            for (int z = 0; z < plan.Iterations.ysize(); ++z) {
                const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
                for (int i = 0; i < rr.OutList.ysize(); ++i) {
                    const TMergeRecord::TTransfer& tr = rr.OutList[i];
                    peerList.push_back(TIBAddr(tr.DstRank, tr.SL));
                }
                for (int i = 0; i < rr.InList.ysize(); ++i) {
                    const TMergeRecord::TInTransfer& tr = rr.InList[i];
                    peerList.push_back(TIBAddr(tr.SrcRank, tr.SL));
                }
            }
            Sort(peerList.begin(), peerList.end());
            peerList.erase(Unique(peerList.begin(), peerList.end()), peerList.end());

            // establish QPs and exchange mem block handlers
            for (int z = 0; z < peerList.ysize(); ++z) {
                const TIBAddr& ibAddr = peerList[z];
                int dstRank = ibAddr.LID;

                TPeerConnection& dst = connections[ibAddr];
                dst.QP = new TRCQueuePair(Port->GetCtx(), cq, srq, TAllDataSync::WR_COUNT);

                TBlockInform myBlock;
                for (int k = 0; k < 2; ++k) {
                    myBlock.RemoteBlocks[k].Addr = memBlock[k]->GetAddr();
                    myBlock.RemoteBlocks[k].Key = memBlock[k]->GetMemRegion()->GetRKey();
                }
                myBlock.PSN = dst.QP->GetPSN();
                myBlock.QPN = dst.QP->GetQPN();
                PostSend(Peers[dstRank], &myBlock, sizeof(myBlock));
            }

            for (int z = 0; z < peerList.ysize(); ++z) {
                const TIBAddr& ibAddr = peerList[z];
                int dstRank = ibAddr.LID;
                int sl = COL_DATA_SERVICE_LEVEL + ClampVal(ibAddr.SL, 0, COL_DATA_SERVICE_LEVEL_COUNT);

                TPeerConnection& dst = connections[ibAddr];

                ui64 wr_id = WaitForMsg(Peers[dstRank]->GetQPN());
                TIBRecvPacketProcess pkt(*BP, wr_id);
                const TBlockInform& info = *(TBlockInform*)pkt.GetData();
                ibv_ah_attr peerAddr;
                MakeAH(&peerAddr, Port, Hosts[dstRank], COL_DATA_SERVICE_LEVEL + sl);
                dst.QP->Init(peerAddr, info.QPN, info.PSN);

                dst.RemoteBlocks[0] = info.RemoteBlocks[0];
                dst.RemoteBlocks[1] = info.RemoteBlocks[1];
            }
            Fence();
        }
        IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) override {
            const TMergePlan& plan = MergePlan;

            Y_ABORT_UNLESS(szPerRank.ysize() == ColSize, "Invalid size array");

            size_t totalSize = 0;
            for (int i = 0; i < szPerRank.ysize(); ++i) {
                totalSize += szPerRank[i];
            }
            size_t bufSize = 4096;
            while (totalSize >= bufSize) {
                bufSize *= 2;
            }

            TAllGather* res = new TAllGather(ColSize, bufSize, MemPool);
            TAllDataSync& ds = res->GetDataSync();

            THashMap<TIBAddr, TPeerConnection, TIBAddrHash> connections;
            CreateDataSyncQPs(ds.CQ, ds.SRQ, ds.MemBlock[0], ds.MemBlock[1], plan, &connections);

            // build plan
            for (int z = 0; z < plan.Iterations.ysize(); ++z) {
                const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
                if (rr.OutList.empty() && rr.InList.empty()) {
                    continue;
                }
                TAllDataSync::TIteration& iter = ds.Iterations.emplace_back();
                for (int i = 0; i < rr.OutList.ysize(); ++i) {
                    const TMergeRecord::TTransfer& tr = rr.OutList[i];
                    TAllDataSync::TSend& snd = iter.OutList.emplace_back();
                    TPeerConnection& pc = connections[TIBAddr(tr.DstRank, tr.SL)];

                    snd.ImmData = tr.Id;
                    snd.Gather.RangeBeg = tr.RangeBeg;
                    snd.Gather.RangeFin = tr.RangeFin;
                    snd.QP = pc.QP;
                    snd.RemoteBlocks[0] = pc.RemoteBlocks[0];
                    snd.RemoteBlocks[1] = pc.RemoteBlocks[1];
                    snd.DstRank = tr.DstRank;
                }
                for (int i = 0; i < rr.InList.ysize(); ++i) {
                    const TMergeRecord::TInTransfer& tr = rr.InList[i];
                    TAllDataSync::TRecv& rcv = iter.InList.emplace_back();
                    TPeerConnection& pc = connections[TIBAddr(tr.SrcRank, tr.SL)];
                    rcv.QP = pc.QP;
                    rcv.SrcRank = tr.SrcRank;
                }
                iter.RecvMask = rr.RecvMask;
            }

            bool rv = res->Resize(szPerRank);
            Y_ABORT_UNLESS(rv, "oops");

            return res;
        }
        IAllGather* CreateAllGather(size_t szPerRank) override {
            TVector<size_t> arr;
            arr.resize(ColSize, szPerRank);
            return CreateAllGather(arr);
        }
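        // CreateAllReduce(): the buffer is divided into (MaxRankReceiveCount + 1) slots
        // of repSize bytes each (see TAllReduce::Resize). Slot 0 holds this rank's own
        // data; a transfer with id k lands in slot k + 1 on the receiver. After each
        // iteration every received slot is reduced into the highest-numbered slot touched
        // so far, which becomes the source for the next iteration and, in the end, the
        // result offset (ReadyOffsetMult).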
        IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) override {
            const TMergePlan& plan = ReducePlan;

            size_t bufSizeMult = plan.MaxRankReceiveCount + 1;
            size_t bufSize = 4096;
            {
                size_t sz = (dataSize + 64) * bufSizeMult;
                while (sz > bufSize) {
                    bufSize *= 2;
                }
            }

            TAllReduce* res = new TAllReduce(bufSize, MemPool, reduceOp);
            TAllDataSync& ds = res->GetDataSync();

            THashMap<TIBAddr, TPeerConnection, TIBAddrHash> connections;
            CreateDataSyncQPs(ds.CQ, ds.SRQ, ds.MemBlock[0], ds.MemBlock[1], plan, &connections);

            // build plan
            int currentDataOffset = 0;
            for (int z = 0; z < plan.Iterations.ysize(); ++z) {
                const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
                if (rr.OutList.empty() && rr.InList.empty()) {
                    continue;
                }
                TAllDataSync::TIteration& iter = ds.Iterations.emplace_back();
                for (int i = 0; i < rr.OutList.ysize(); ++i) {
                    const TMergeRecord::TTransfer& tr = rr.OutList[i];
                    TAllDataSync::TSend& snd = iter.OutList.emplace_back();
                    TPeerConnection& pc = connections[TIBAddr(tr.DstRank, tr.SL)];

                    snd.ImmData = tr.Id;
                    snd.Reduce.SrcIndex = currentDataOffset;
                    snd.Reduce.DstIndex = tr.Id + 1;
                    snd.QP = pc.QP;
                    snd.RemoteBlocks[0] = pc.RemoteBlocks[0];
                    snd.RemoteBlocks[1] = pc.RemoteBlocks[1];
                    snd.DstRank = tr.DstRank;
                }
                for (int i = 0; i < rr.InList.ysize(); ++i) {
                    const TMergeRecord::TInTransfer& tr = rr.InList[i];
                    TAllDataSync::TRecv& rcv = iter.InList.emplace_back();
                    TPeerConnection& pc = connections[TIBAddr(tr.SrcRank, tr.SL)];
                    rcv.QP = pc.QP;
                    rcv.SrcRank = tr.SrcRank;
                }
                iter.RecvMask = rr.RecvMask;

                TVector<int> inputOffset;
                inputOffset.push_back(currentDataOffset);
                int newDataOffset = currentDataOffset;
                for (int i = 0; i < 64; ++i) {
                    if (rr.RecvMask & (1ull << i)) {
                        int offset = i + 1;
                        inputOffset.push_back(offset);
                        newDataOffset = Max(offset, newDataOffset);
                    }
                }
                for (int i = 0; i < inputOffset.ysize(); ++i) {
                    if (inputOffset[i] == newDataOffset) {
                        continue;
                    }
                    TAllDataSync::TReduce& red = iter.ReduceList.emplace_back();
                    red.SrcIndex = inputOffset[i];
                    red.DstIndex = newDataOffset;
                }
                currentDataOffset = newDataOffset;
            }
            res->BufSizeMult = bufSizeMult;
            res->ReadyOffsetMult = currentDataOffset;

            bool rv = res->Resize(dataSize);
            Y_ABORT_UNLESS(rv, "oops");

            return res;
        }
        void Fence() override {
            const TMergePlan& plan = ReducePlan;

            for (int z = 0; z < plan.Iterations.ysize(); ++z) {
                const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
                for (int i = 0; i < rr.OutList.ysize(); ++i) {
                    const TMergeRecord::TTransfer& tr = rr.OutList[i];
                    char c;
                    PostSend(Peers[tr.DstRank], &c, sizeof(c));
                }
                for (int i = 0; i < rr.InList.ysize(); ++i) {
                    const TMergeRecord::TInTransfer& tr = rr.InList[i];
                    ui64 wr_id = WaitForMsg(Peers[tr.SrcRank]->GetQPN());
                    TIBRecvPacketProcess pkt(*BP, wr_id);
                }
            }
        }
        void RunBWTest(int groupType, int delta, int* targetRank, float* res) override {
            const int BUF_SIZE = 8 * 1024 * 1024;
            TIntrusivePtr<TIBMemBlock> sendMem, recvMem;
            sendMem = MemPool->Alloc(BUF_SIZE);
            recvMem = MemPool->Alloc(BUF_SIZE);

            int myGroup = HostGroup[groupType][ColRank];
            int myGroupPos = 0;
            TVector<int> gg;
            Y_ASSERT(HostGroup[groupType].ysize() == ColSize);
            for (int rank = 0; rank < ColSize; ++rank) {
                if (HostGroup[groupType][rank] == myGroup) {
                    if (rank == ColRank) {
                        myGroupPos = gg.ysize();
                    }
                    gg.push_back(rank);
                }
            }
            if (delta >= gg.ysize()) {
                *targetRank = -1;
                *res = 0;
                return;
            }
            int sendRank = gg[(myGroupPos + delta) % gg.ysize()];
            int recvRank = gg[(myGroupPos + gg.ysize() - delta) % gg.ysize()];
            *targetRank = sendRank;

            TIntrusivePtr<TRCQueuePair> sendRC = Peers[sendRank];
            TIntrusivePtr<TRCQueuePair> recvRC = Peers[recvRank];
            {
                TBWTest bw;
                bw.Addr = recvMem->GetAddr();
                bw.RKey = recvMem->GetMemRegion()->GetRKey();
                PostSend(recvRC, &bw, sizeof(bw));
            }
            TBWTest dstMem;
            {
                ui64 wr_id = WaitForMsg(sendRC->GetQPN());
                TIBRecvPacketProcess pkt(*BP, wr_id);
                dstMem = *(TBWTest*)pkt.GetData();
            }
            // run
            TVector<double> score;
            for (int iter = 0; iter < 5; ++iter) {
                while (!AllocRDMAWriteSlot(sendRC)) {
                    WaitCompletion();
                    Y_ASSERT(0 && "measurements are imprecise");
                }
                NHPTimer::STime t;
                NHPTimer::GetTime(&t);
                sendRC->PostRDMAWrite(dstMem.Addr, dstMem.RKey, sendMem->GetMemRegion(), 0, sendMem->GetData(), BUF_SIZE);
                for (;;) {
                    ibv_wc wc;
                    WaitCompletion(&wc);
                    if (wc.opcode == IBV_WC_RDMA_WRITE) {
                        if (wc.qp_num != (ui32)sendRC->GetQPN()) {
                            abort();
                        }
                        break;
                    }
                }
                double tPassed = NHPTimer::GetTimePassed(&t);
                double speed = BUF_SIZE / tPassed / 1000000000.0; // G/sec
                score.push_back(speed);
            }
            Sort(score.begin(), score.end());
            // signal completion & wait for signal
            *res = score[score.size() / 2];
            {
                char bb;
                PostSend(sendRC, &bb, sizeof(bb));
                ui64 wr_id = WaitForMsg(recvRC->GetQPN());
                TIBRecvPacketProcess pkt(*BP, wr_id);
            }
        }
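        // Micro-message path: TrySendMicro() pushes a small packet over the control QP of
        // the destination rank (subject to the per-peer send limit); the receiving side
        // pulls it back out via TIBRecvMicro/GetMsg(), which filters completions through
        // a TIBMicroPeerTable sized by QPNTableSizeLog (chosen in the constructor so that
        // peer QPNs do not collide modulo the table size).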
        bool TrySendMicro(int dstRank, const void* data, int dataSize) override {
            return TryPostSend(Peers[dstRank], data, dataSize);
        }
        void InitPeerTable(TIBMicroPeerTable* res) override {
            res->Init(QPNTableSizeLog);
        }
        void RdmaWrite(const TVector<TRdmaRequest>& reqs) override {
            TVector<TVector<int>> reqPerRank;
            reqPerRank.resize(ColSize);
            int reqCount = reqs.ysize();
            for (int i = 0; i < reqCount; ++i) {
                reqPerRank[reqs[i].DstRank].push_back(i);
            }
            int inFlight = 0;
            // IB congestion control sucks :/ so we limit number of simultaneous rdmas
            int startRank = ColRank;
            while (reqCount > 0) {
                if (inFlight < MAX_TOTAL_RDMA) {
                    for (int z = 0; z < ColSize; ++z) {
                        int dstRank = (startRank + 1 + z) % ColSize;
                        if (reqPerRank[dstRank].empty()) {
                            continue;
                        }
                        Y_ASSERT(dstRank != ColRank && "sending self is meaningless");
                        TRCQueuePair* qp = Peers[dstRank].Get();
                        if (AllocRDMAWriteSlot(qp)) {
                            const TRdmaRequest& rr = reqs[reqPerRank[dstRank].back()];
                            qp->PostRDMAWrite(rr.RemoteAddr, rr.RemoteKey, rr.LocalAddr, rr.LocalKey, 0, rr.Size);
                            reqPerRank[dstRank].pop_back();
                            if (++inFlight >= MAX_TOTAL_RDMA) {
                                startRank = dstRank;
                                break;
                            }
                        }
                    }
                }
                {
                    ibv_wc wc;
                    WaitCompletion(&wc);
                    if (wc.opcode == IBV_WC_RDMA_WRITE) {
                        --inFlight;
                        --reqCount;
                    }
                }
            }
        }

    public:
        TIBCollective(TPtrArg<TIBPort> port, TPtrArg<TIBMemPool> memPool,
                      const TCollectiveInit& params,
                      TCollectiveLinkSet* resLinks)
            : Port(port)
            , MemPool(memPool)
            , QPNTableSizeLog(0)
        {
            ColSize = params.Size;
            ColRank = params.Rank;

            int maxOutstandingQueries = MAX_REQS_PER_PEER * ColSize + 10;
            CQ = new TComplectionQueue(Port->GetCtx(), maxOutstandingQueries * 2);
            BP = new TIBBufferPool(Port->GetCtx(), maxOutstandingQueries);

            Peers.resize(ColSize);
            resLinks->Links.resize(ColSize);
            TVector<int> qpnArr;
            for (int k = 0; k < ColSize; ++k) {
                if (k == ColRank) {
                    continue;
                }
                TRCQueuePair* rc = new TRCQueuePair(Port->GetCtx(), CQ, BP->GetSRQ(), MAX_REQS_PER_PEER);
                Peers[k] = rc;

                TCollectiveLinkSet::TLinkInfo& lnk = resLinks->Links[k];
                lnk.PSN = rc->GetPSN();
                lnk.QPN = rc->GetQPN();

                qpnArr.push_back(lnk.QPN);
            }
            resLinks->Hosts.resize(ColSize);
            resLinks->Hosts[ColRank] = Port->GetLID();

            static_assert(MAX_REQS_PER_PEER < 256, "expect MAX_REQS_PER_PEER < 256"); // sent count will fit into SendCountTable[]
            Zero(SendCountTable);
            Zero(RDMACountTable);

            if (!qpnArr.empty()) {
                for (;;) {
                    TVector<int> qpnTable;
                    int qpnTableSize = 1 << QPNTableSizeLog;
                    qpnTable.resize(qpnTableSize, 0);
                    bool ok = true;
                    for (int i = 0; i < qpnArr.ysize(); ++i) {
                        int idx = qpnArr[i] & (qpnTableSize - 1);
                        if (++qpnTable[idx] == 2) {
                            ok = false;
                            break;
                        }
                    }
                    if (ok) {
                        break;
                    }
                    ++QPNTableSizeLog;
                }
                //printf("QPN table, size_log %d\n", QPNTableSizeLog);
            }
        }

        friend class TIBRecvMicro;
    };

    TIBRecvMicro::TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable)
        : IB(*(TIBCollective*)col)
    {
        Y_ASSERT(typeid(IB) == typeid(TIBCollective));
        if (IB.GetMsg(&Id, &QPN, peerTable)) {
            Data = IB.BP->GetBufData(Id);
        } else {
            Data = nullptr;
        }
    }

    TIBRecvMicro::~TIBRecvMicro() {
        if (Data) {
            IB.BP->FreeBuf(Id);
            IB.BP->PostRecv();
        }
    }

    IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks) {
        return new TIBCollective(GetIBDevice(), GetIBMemPool(), params, resLinks);
    }
}
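
// Usage sketch (kept in a comment, not part of the library): a rough illustration of how
// the collective API above is typically driven, following the Flush() -> fill -> Sync()
// order enforced by TAllDataSync. The rank/link exchange, the reduce-op type TMyAddOp and
// ComputeLocalValue() are hypothetical placeholders; distributing TCollectiveLinkSet
// between hosts is the caller's job (e.g. over the netliba UDP transport).
/*
using namespace NNetliba;

void ExampleAllReduce(const TCollectiveInit& init, TCollectiveLinkSet* links) {
    TIntrusivePtr<IIBCollective> col = CreateCollective(init, links);
    // ... exchange *links with all other ranks here, then:
    col->Start(*links);

    TIntrusivePtr<IAllReduce> red = col->CreateAllReduce(1000 * sizeof(float), new TMyAddOp);
    for (int step = 0; step < 10; ++step) {
        red->Flush();                             // switch to the next buffer
        float* data = (float*)red->GetRawData();  // fill the current buffer
        for (int i = 0; i < 1000; ++i) {
            data[i] = ComputeLocalValue(step, i); // hypothetical producer
        }
        red->Sync();                              // blocking all-reduce
        // GetRawData() now points at the reduced result slot
    }
}
*/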