// net_acks.cpp
  1. #include "stdafx.h"
  2. #include "net_acks.h"
  3. #include <util/datetime/cputimer.h>
  4. #include <atomic>
  5. namespace NNetliba {
  6. const float RTT_AVERAGE_OVER = 15;
  7. float TCongestionControl::StartWindowSize = 3;
  8. float TCongestionControl::MaxPacketRate = 0; // unlimited
  // Global switch, off by default. Presumably selects TOS marking for ack packets;
  // the code that reads it is not in this file.
  bool UseTOSforAcks = false;

  // Sets the UseTOSforAcks switch to the given value.
  void EnableUseTOSforAcks(bool enable) {
      UseTOSforAcks = enable;
  }
  // Global tuning factor, 1 by default. It is only stored here; the code that
  // applies it lives outside this file.
  float CONG_CTRL_CHANNEL_INFLATE = 1.0f;

  // Overwrites CONG_CTRL_CHANNEL_INFLATE with the given factor.
  void SetCongCtrlChannelInflate(float inflate) {
      CONG_CTRL_CHANNEL_INFLATE = inflate;
  }
  17. //////////////////////////////////////////////////////////////////////////
  18. TPingTracker::TPingTracker()
  19. : AvrgRTT(CONG_CTRL_INITIAL_RTT)
  20. , AvrgRTT2(CONG_CTRL_INITIAL_RTT * CONG_CTRL_INITIAL_RTT)
  21. , RTTCount(0)
  22. {
  23. }
  24. void TPingTracker::RegisterRTT(float rtt) {
  25. Y_ASSERT(rtt > 0);
  26. float keep = RTTCount / (RTTCount + 1);
  27. AvrgRTT *= keep;
  28. AvrgRTT += (1 - keep) * rtt;
  29. AvrgRTT2 *= keep;
  30. AvrgRTT2 += (1 - keep) * Sqr(rtt);
  31. RTTCount = Min(RTTCount + 1, RTT_AVERAGE_OVER);
  32. //static int n;
  33. //if ((++n % 1024) == 0)
  34. // printf("Average RTT = %g (sko = %g)\n", GetRTT() * 1000, GetRTTSKO() * 1000);
  35. }
  36. void TPingTracker::IncreaseRTT() {
  37. const float F_RTT_DECAY_RATE = 1.1f;
  38. AvrgRTT *= F_RTT_DECAY_RATE;
  39. AvrgRTT2 *= Sqr(F_RTT_DECAY_RATE);
  40. }
  41. //////////////////////////////////////////////////////////////////////////
  42. void TAckTracker::Resend() {
  43. CurrentPacket = 0;
  44. for (TPacketHash::const_iterator i = PacketsInFly.begin(); i != PacketsInFly.end(); ++i)
  45. Congestion->Failure(); // not actually correct but simplifies logic a lot
  46. PacketsInFly.clear();
  47. DroppedPackets.clear();
  48. ResendQueue.clear();
  49. for (size_t i = 0; i < AckReceived.size(); ++i)
  50. AckReceived[i] = false;
  51. }
  52. int TAckTracker::SelectPacket() {
  53. if (!ResendQueue.empty()) {
  54. int res = ResendQueue.back();
  55. ResendQueue.pop_back();
  56. //printf("resending packet %d\n", res);
  57. return res;
  58. }
  59. if (CurrentPacket == PacketCount) {
  60. return -1;
  61. }
  62. return CurrentPacket++;
  63. }
  64. TAckTracker::~TAckTracker() {
  65. for (TPacketHash::const_iterator i = PacketsInFly.begin(); i != PacketsInFly.end(); ++i)
  66. Congestion->Failure();
  67. // object will be incorrect state after this (failed packets are not added to resend queue), but who cares
  68. }
  69. int TAckTracker::GetPacketToSend(float deltaT) {
  70. int res = SelectPacket();
  71. if (res == -1) {
  72. // needed to count time even if we don't have anything to send
  73. Congestion->HasTriedToSend();
  74. return res;
  75. }
  76. Congestion->LaunchPacket();
  77. PacketsInFly[res] = -deltaT; // deltaT is time since last Step(), so for the timing to be correct we should subtract it
  78. return res;
  79. }
  80. // called on SendTo() failure
  81. void TAckTracker::AddToResend(int pkt) {
  82. //printf("AddToResend(%d)\n", pkt);
  83. TPacketHash::iterator i = PacketsInFly.find(pkt);
  84. if (i != PacketsInFly.end()) {
  85. PacketsInFly.erase(i);
  86. Congestion->FailureOnSend();
  87. ResendQueue.push_back(pkt);
  88. } else
  89. Y_ASSERT(0);
  90. }
  91. void TAckTracker::Ack(int pkt, float deltaT, bool updateRTT) {
  92. Y_ASSERT(pkt >= 0 && pkt < PacketCount);
  93. if (AckReceived[pkt])
  94. return;
  95. AckReceived[pkt] = true;
  96. //printf("Ack received for %d\n", pkt);
  97. TPacketHash::iterator i = PacketsInFly.find(pkt);
  98. if (i == PacketsInFly.end()) {
  99. for (size_t k = 0; k < ResendQueue.size(); ++k) {
  100. if (ResendQueue[k] == pkt) {
  101. ResendQueue[k] = ResendQueue.back();
  102. ResendQueue.pop_back();
  103. break;
  104. }
  105. }
  106. TPacketHash::iterator z = DroppedPackets.find(pkt);
  107. if (z != DroppedPackets.end()) {
  108. // late packet arrived
  109. if (updateRTT) {
  110. float ping = z->second + deltaT;
  111. Congestion->RegisterRTT(ping);
  112. }
  113. DroppedPackets.erase(z);
  114. } else {
  115. // Y_ASSERT(0); // ack on nonsent packet, possible in resend scenario
  116. }
  117. return;
  118. }
  119. if (updateRTT) {
  120. float ping = i->second + deltaT;
  121. //printf("Register RTT %g\n", ping * 1000);
  122. Congestion->RegisterRTT(ping);
  123. }
  124. PacketsInFly.erase(i);
  125. Congestion->Success();
  126. }
  127. void TAckTracker::AckAll() {
  128. for (TPacketHash::const_iterator i = PacketsInFly.begin(); i != PacketsInFly.end(); ++i) {
  129. int pkt = i->first;
  130. AckReceived[pkt] = true;
  131. Congestion->Success();
  132. }
  133. PacketsInFly.clear();
  134. }
  135. void TAckTracker::Step(float deltaT) {
  136. float timeoutVal = Congestion->GetTimeout();
  137. //static int n;
  138. //if ((++n % 1024) == 0)
  139. // printf("timeout = %g, window = %g, fail_rate %g, pkt_rate = %g\n", timeoutVal * 1000, Congestion->GetWindow(), Congestion->GetFailRate(), (1 - Congestion->GetFailRate()) * Congestion->GetWindow() / Congestion->GetRTT());
  140. TimeToNextPacketTimeout = 1000;
  141. // для окон меньше единицы мы кидаем рандом один раз за RTT на то, можно ли пускать пакет
  142. // поэтому можно ждать максимум RTT, после этого надо кинуть новый random
  143. if (Congestion->GetWindow() < 1)
  144. TimeToNextPacketTimeout = Congestion->GetRTT();
  145. for (auto& droppedPacket : DroppedPackets) {
  146. float& t = droppedPacket.second;
  147. t += deltaT;
  148. }
  149. for (TPacketHash::iterator i = PacketsInFly.begin(); i != PacketsInFly.end();) {
  150. float& t = i->second;
  151. t += deltaT;
  152. if (t > timeoutVal) {
  153. //printf("packet %d timed out (timeout = %g)\n", i->first, timeoutVal);
  154. ResendQueue.push_back(i->first);
  155. DroppedPackets[i->first] = i->second;
  156. TPacketHash::iterator k = i++;
  157. PacketsInFly.erase(k);
  158. Congestion->Failure();
  159. } else {
  160. TimeToNextPacketTimeout = Min(TimeToNextPacketTimeout, timeoutVal - t);
  161. ++i;
  162. }
  163. }
  164. }
  165. static std::atomic<ui32> netAckRndVal = (ui32)GetCycleCount();
  166. ui32 NetAckRnd() {
  167. const auto nextNetAckRndVal = static_cast<ui32>(((ui64)netAckRndVal.load(std::memory_order_acquire) * 279470273) % 4294967291);
  168. netAckRndVal.store(nextNetAckRndVal, std::memory_order_release);
  169. return nextNetAckRndVal;
  170. }
  171. }