#pragma once #include "ib_low.h" namespace NNetliba { // buffer id 0 is special, it is used when data is sent inline and should not be returned const size_t SMALL_PKT_SIZE = 4096; const ui64 BP_AH_USED_FLAG = 0x1000000000000000ul; const ui64 BP_BUF_ID_MASK = 0x00000000fffffffful; // single thread version class TIBBufferPool: public TThrRefBase, TNonCopyable { static constexpr int BLOCK_SIZE_LN = 11; static constexpr int BLOCK_SIZE = 1 << BLOCK_SIZE_LN; static constexpr int BLOCK_COUNT = 1024; struct TSingleBlock { TIntrusivePtr Mem; TVector BlkRefCounts; TVector> AHHolder; void Alloc(TPtrArg ctx) { size_t dataSize = SMALL_PKT_SIZE * BLOCK_SIZE; Mem = new TMemoryRegion(ctx, dataSize); BlkRefCounts.resize(BLOCK_SIZE, 0); AHHolder.resize(BLOCK_SIZE); } char* GetBufData(ui64 idArg) { char* data = Mem->GetData(); return data + (idArg & (BLOCK_SIZE - 1)) * SMALL_PKT_SIZE; } }; TIntrusivePtr IBCtx; TVector FreeList; TVector Blocks; size_t FirstFreeBlock; int PostRecvDeficit; TIntrusivePtr SRQ; void AddBlock() { if (FirstFreeBlock == Blocks.size()) { Y_ABORT_UNLESS(0, "run out of buffers"); } Blocks[FirstFreeBlock].Alloc(IBCtx); size_t start = (FirstFreeBlock == 0) ? 1 : FirstFreeBlock * BLOCK_SIZE; size_t finish = FirstFreeBlock * BLOCK_SIZE + BLOCK_SIZE; for (size_t i = start; i < finish; ++i) { FreeList.push_back(i); } ++FirstFreeBlock; } public: TIBBufferPool(TPtrArg ctx, int maxSRQWorkRequests) : IBCtx(ctx) , FirstFreeBlock(0) , PostRecvDeficit(maxSRQWorkRequests) { Blocks.resize(BLOCK_COUNT); AddBlock(); SRQ = new TSharedReceiveQueue(ctx, maxSRQWorkRequests); PostRecv(); } TSharedReceiveQueue* GetSRQ() const { return SRQ.Get(); } int AllocBuf() { if (FreeList.empty()) { AddBlock(); } int id = FreeList.back(); FreeList.pop_back(); Y_ASSERT(++Blocks[id >> BLOCK_SIZE_LN].BlkRefCounts[id & (BLOCK_SIZE - 1)] == 1); return id; } void FreeBuf(ui64 idArg) { ui64 id = idArg & BP_BUF_ID_MASK; if (id == 0) { return; } Y_ASSERT(id > 0 && id < (ui64)(FirstFreeBlock * BLOCK_SIZE)); FreeList.push_back(id); Y_ASSERT(--Blocks[id >> BLOCK_SIZE_LN].BlkRefCounts[id & (BLOCK_SIZE - 1)] == 0); if (idArg & BP_AH_USED_FLAG) { Blocks[id >> BLOCK_SIZE_LN].AHHolder[id & (BLOCK_SIZE - 1)] = nullptr; } } char* GetBufData(ui64 idArg) { ui64 id = idArg & BP_BUF_ID_MASK; return Blocks[id >> BLOCK_SIZE_LN].GetBufData(id); } int PostSend(TPtrArg qp, const void* data, size_t len) { if (len > SMALL_PKT_SIZE) { Y_ABORT_UNLESS(0, "buffer overrun"); } if (len <= MAX_INLINE_DATA_SIZE) { qp->PostSend(nullptr, 0, data, len); return 0; } else { int id = AllocBuf(); TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN]; char* buf = blk.GetBufData(id); memcpy(buf, data, len); qp->PostSend(blk.Mem, id, buf, len); return id; } } void PostSend(TPtrArg qp, TPtrArg ah, int remoteQPN, int remoteQKey, const void* data, size_t len) { if (len > SMALL_PKT_SIZE - 40) { Y_ABORT_UNLESS(0, "buffer overrun"); } ui64 id = AllocBuf(); TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN]; int ptr = id & (BLOCK_SIZE - 1); blk.AHHolder[ptr] = ah.Get(); id |= BP_AH_USED_FLAG; if (len <= MAX_INLINE_DATA_SIZE) { qp->PostSend(ah, remoteQPN, remoteQKey, nullptr, id, data, len); } else { char* buf = blk.GetBufData(id); memcpy(buf, data, len); qp->PostSend(ah, remoteQPN, remoteQKey, blk.Mem, id, buf, len); } } void RequestPostRecv() { ++PostRecvDeficit; } void PostRecv() { for (int i = 0; i < PostRecvDeficit; ++i) { int id = AllocBuf(); TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN]; char* buf = blk.GetBufData(id); SRQ->PostReceive(blk.Mem, id, buf, SMALL_PKT_SIZE); } PostRecvDeficit = 0; } }; class TIBRecvPacketProcess: public TNonCopyable { TIBBufferPool& BP; ui64 Id; char* Data; public: TIBRecvPacketProcess(TIBBufferPool& bp, const ibv_wc& wc) : BP(bp) , Id(wc.wr_id) { Y_ASSERT(wc.opcode & IBV_WC_RECV); BP.RequestPostRecv(); Data = BP.GetBufData(Id); } // intended for postponed packet processing // with this call RequestPostRecv() should be called outside (and PostRecv() too in order to avoid rnr situation) TIBRecvPacketProcess(TIBBufferPool& bp, const ui64 wr_id) : BP(bp) , Id(wr_id) { Data = BP.GetBufData(Id); } ~TIBRecvPacketProcess() { BP.FreeBuf(Id); BP.PostRecv(); } char* GetData() const { return Data; } char* GetUDData() const { return Data + 40; } ibv_grh* GetGRH() const { return (ibv_grh*)Data; } }; }