123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- #pragma once
- #include <library/cpp/binsaver/bin_saver.h>
- namespace NNetliba {
- struct TCollectiveInit {
- int Size, Rank;
- SAVELOAD(Size, Rank);
- };
- struct TCollectiveLinkSet {
- struct TLinkInfo {
- int QPN, PSN;
- };
- TVector<int> Hosts; // host LIDs
- TVector<TVector<int>> HostGroup; // HostGroup[0] - switchId, HostGroup[1] - hostId within the switch
- TVector<TLinkInfo> Links;
- SAVELOAD(Hosts, HostGroup, Links);
- };
- struct IAllDataSync: public TThrRefBase {
- virtual void* GetRawData() = 0;
- virtual size_t GetRawDataSize() = 0;
- virtual void Sync() = 0;
- virtual void Flush() = 0;
- template <class T>
- T* GetData() {
- return static_cast<T*>(GetRawData());
- }
- template <class T>
- size_t GetSize() {
- return GetRawDataSize() / sizeof(T);
- }
- };
- struct IAllReduce: public IAllDataSync {
- virtual bool Resize(size_t dataSize) = 0;
- };
- struct IAllGather: public IAllDataSync {
- virtual bool Resize(const TVector<size_t>& szPerRank) = 0;
- };
- struct IReduceOp: public TThrRefBase {
- virtual void Reduce(void* dst, const void* add, size_t dataSize) const = 0;
- };
- template <class T, class TElem = typename T::TElem>
- class TAllReduceOp: public IReduceOp {
- T Op;
- public:
- TAllReduceOp() {
- }
- TAllReduceOp(T op)
- : Op(op)
- {
- }
- void Reduce(void* dst, const void* add, size_t dataSize) const override {
- TElem* dstPtr = (TElem*)(dst);
- const TElem* addPtr = (const TElem*)(add);
- TElem* finPtr = (TElem*)(((char*)dst) + dataSize);
- while (dstPtr < finPtr) {
- Op(dstPtr, *addPtr);
- ++dstPtr;
- ++addPtr;
- }
- }
- };
- // table of active peers for micro send/recv
- class TIBMicroPeerTable {
- TVector<ui8> Table; // == 0 means accept mesages from this qpn
- int TableSize;
- bool ParsePending;
- public:
- TIBMicroPeerTable()
- : ParsePending(true)
- {
- Init(0);
- }
- void Init(int tableSizeLog) {
- TableSize = 1 << tableSizeLog;
- ParsePending = true;
- Table.resize(0);
- Table.resize(TableSize, 0);
- }
- bool NeedParsePending() const {
- return ParsePending;
- }
- void StopParsePending() {
- ParsePending = false;
- }
- void StopQPN(int qpn, ui8 mask) {
- Y_ASSERT((Table[qpn & (TableSize - 1)] & mask) == 0);
- Table[qpn & (TableSize - 1)] |= mask;
- }
- void StopQPN(int qpn) {
- Y_ASSERT(Table[qpn & (TableSize - 1)] == 0);
- Table[qpn & (TableSize - 1)] = 0xff;
- }
- bool NeedQPN(int qpn) const {
- return Table[qpn & (TableSize - 1)] != 0xff;
- }
- };
- struct IIBCollective;
- class TIBCollective;
- class TIBRecvMicro: public TNonCopyable {
- TIBCollective& IB;
- ui64 Id;
- int QPN;
- void* Data;
- public:
- TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable);
- ~TIBRecvMicro();
- void* GetRawData() const {
- return Data;
- }
- template <class T>
- T* GetData() {
- return static_cast<T*>(GetRawData());
- }
- int GetQPN() const {
- return QPN;
- }
- };
- struct IIBCollective: public TThrRefBase {
- struct TRdmaRequest {
- int DstRank;
- ui64 RemoteAddr, LocalAddr;
- ui32 RemoteKey, LocalKey;
- ui64 Size;
- };
- virtual int GetRank() = 0;
- virtual int GetSize() = 0;
- virtual int GetGroupTypeCount() = 0;
- virtual int GetQPN(int rank) = 0;
- virtual bool TryWaitCompletion() = 0;
- virtual void WaitCompletion() = 0;
- virtual void Start(const TCollectiveLinkSet& links) = 0;
- virtual IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) = 0;
- virtual IAllGather* CreateAllGather(size_t szPerRank) = 0;
- virtual IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) = 0;
- virtual void RunBWTest(int groupType, int delta, int* targetRank, float* res) = 0;
- virtual void Fence() = 0;
- virtual void InitPeerTable(TIBMicroPeerTable* res) = 0;
- virtual bool TrySendMicro(int dstRank, const void* data, int dataSize) = 0;
- virtual void RdmaWrite(const TVector<TRdmaRequest>& reqs) = 0;
- };
- IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks);
- }
|