#pragma once #include namespace NNetliba { struct TCollectiveInit { int Size, Rank; SAVELOAD(Size, Rank); }; struct TCollectiveLinkSet { struct TLinkInfo { int QPN, PSN; }; TVector Hosts; // host LIDs TVector> HostGroup; // HostGroup[0] - switchId, HostGroup[1] - hostId within the switch TVector Links; SAVELOAD(Hosts, HostGroup, Links); }; struct IAllDataSync: public TThrRefBase { virtual void* GetRawData() = 0; virtual size_t GetRawDataSize() = 0; virtual void Sync() = 0; virtual void Flush() = 0; template T* GetData() { return static_cast(GetRawData()); } template size_t GetSize() { return GetRawDataSize() / sizeof(T); } }; struct IAllReduce: public IAllDataSync { virtual bool Resize(size_t dataSize) = 0; }; struct IAllGather: public IAllDataSync { virtual bool Resize(const TVector& szPerRank) = 0; }; struct IReduceOp: public TThrRefBase { virtual void Reduce(void* dst, const void* add, size_t dataSize) const = 0; }; template class TAllReduceOp: public IReduceOp { T Op; public: TAllReduceOp() { } TAllReduceOp(T op) : Op(op) { } void Reduce(void* dst, const void* add, size_t dataSize) const override { TElem* dstPtr = (TElem*)(dst); const TElem* addPtr = (const TElem*)(add); TElem* finPtr = (TElem*)(((char*)dst) + dataSize); while (dstPtr < finPtr) { Op(dstPtr, *addPtr); ++dstPtr; ++addPtr; } } }; // table of active peers for micro send/recv class TIBMicroPeerTable { TVector Table; // == 0 means accept mesages from this qpn int TableSize; bool ParsePending; public: TIBMicroPeerTable() : ParsePending(true) { Init(0); } void Init(int tableSizeLog) { TableSize = 1 << tableSizeLog; ParsePending = true; Table.resize(0); Table.resize(TableSize, 0); } bool NeedParsePending() const { return ParsePending; } void StopParsePending() { ParsePending = false; } void StopQPN(int qpn, ui8 mask) { Y_ASSERT((Table[qpn & (TableSize - 1)] & mask) == 0); Table[qpn & (TableSize - 1)] |= mask; } void StopQPN(int qpn) { Y_ASSERT(Table[qpn & (TableSize - 1)] == 0); Table[qpn & (TableSize - 1)] = 0xff; } bool NeedQPN(int qpn) const { return Table[qpn & (TableSize - 1)] != 0xff; } }; struct IIBCollective; class TIBCollective; class TIBRecvMicro: public TNonCopyable { TIBCollective& IB; ui64 Id; int QPN; void* Data; public: TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable); ~TIBRecvMicro(); void* GetRawData() const { return Data; } template T* GetData() { return static_cast(GetRawData()); } int GetQPN() const { return QPN; } }; struct IIBCollective: public TThrRefBase { struct TRdmaRequest { int DstRank; ui64 RemoteAddr, LocalAddr; ui32 RemoteKey, LocalKey; ui64 Size; }; virtual int GetRank() = 0; virtual int GetSize() = 0; virtual int GetGroupTypeCount() = 0; virtual int GetQPN(int rank) = 0; virtual bool TryWaitCompletion() = 0; virtual void WaitCompletion() = 0; virtual void Start(const TCollectiveLinkSet& links) = 0; virtual IAllGather* CreateAllGather(const TVector& szPerRank) = 0; virtual IAllGather* CreateAllGather(size_t szPerRank) = 0; virtual IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg reduceOp) = 0; virtual void RunBWTest(int groupType, int delta, int* targetRank, float* res) = 0; virtual void Fence() = 0; virtual void InitPeerTable(TIBMicroPeerTable* res) = 0; virtual bool TrySendMicro(int dstRank, const void* data, int dataSize) = 0; virtual void RdmaWrite(const TVector& reqs) = 0; }; IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks); }