net_acks.h 18 KB


  1. #pragma once
  2. #include "net_test.h"
  3. #include "net_queue_stat.h"
  4. #include <util/system/spinlock.h>
  5. namespace NNetliba {
  // ---- Tuning constants for RTT tracking, congestion control and liveness checks ----
  // Time values share the units of the deltaT passed to Step()/UpdateAlive() (presumably
  // seconds; CONG_CTRL_TIME_WINDOW_LIMIT_PERIOD is explicitly documented as seconds).

  // Lower bound for the RTT spread ("SKO") reported by TPingTracker::GetRTTSKO();
  // avoids drops due to small hiccups.
  constexpr float MIN_PACKET_RTT_SKO = 0.001f;
  // RTT assumed for pacing until an average RTT has been measured (was 0.01f; taking into
  // account "Las Vegas", a 10 ms estimate is too optimistic).
  constexpr float CONG_CTRL_INITIAL_RTT = 0.24f;
  // Window growth coefficient: each Success() at full speed adds
  // sqrt(Window / CONG_CTRL_CHANNEL_INFLATE) * CONG_CTRL_WINDOW_GROW.
  constexpr float CONG_CTRL_WINDOW_GROW = 0.005f;
  // Multiplicative shrink factor applied on a congestion Failure().
  constexpr float CONG_CTRL_WINDOW_SHRINK = 0.9f;
  // Multiplicative shrink factor applied after a run of above-average RTT samples.
  constexpr float CONG_CTRL_WINDOW_SHRINK_RTT = 0.95f;
  // Weight of the previous value in the exponential moving average of RTT.
  constexpr float CONG_CTRL_RTT_MIX_RATE = 0.9f;
  // Number of consecutive above-average RTT samples that triggers a window reduction.
  constexpr int CONG_CTRL_RTT_SEQ_COUNT = 8;
  // Smallest congestion window allowed.
  constexpr float CONG_CTRL_MIN_WINDOW = 0.01f;
  // Packet budget used when packet-rate limiting is disabled (effectively unlimited).
  constexpr float CONG_CTRL_LARGE_TIME_WINDOW = 10000.0f;
  // Cap on the packet-rate limiter's budget, as a period of MaxPacketRate; in seconds.
  constexpr float CONG_CTRL_TIME_WINDOW_LIMIT_PERIOD = 0.4f;
  // The window is kept large enough to send at least one packet per this interval.
  constexpr float CONG_CTRL_MINIMAL_SEND_INTERVAL = 1.0f;
  // The window is shrunk at most once per this interval, however many failures occur.
  constexpr float CONG_CTRL_MIN_FAIL_INTERVAL = 0.001f;
  // Packets that may go out back to back without being paced by virtual packets.
  constexpr float CONG_CTRL_ALLOWED_BURST_SIZE = 3.0f;
  // Virtual-packet pacing is only set up when the average RTT exceeds this.
  constexpr float CONG_CTRL_MIN_RTT_FOR_BURST_REDUCTION = 0.002f;
  // Lame MTU discovery: overall timeout and minimal interval between probes.
  constexpr float LAME_MTU_TIMEOUT = 0.3f;
  constexpr float LAME_MTU_INTERVAL = 0.05f;
  // Port-unreachable probing runs while the time since the last receive lies strictly
  // between these two bounds.
  constexpr float START_CHECK_PORT_DELAY = 0.5f;
  constexpr float FINISH_CHECK_PORT_DELAY = 10.0f;
  // Upper bound on simultaneously active port testers (alternative value: 512).
  constexpr int N_PORT_TEST_COUNT_LIMIT = 256;

  // If enabled, all acks are sent with a different TOS so they end up in a different queue.
  // This allows us to limit the window based on the minimal observed RTT and a 1 Gbit link
  // assumption.
  extern bool UseTOSforAcks;
  28. class TPingTracker {
  29. float AvrgRTT, AvrgRTT2; // RTT statistics
  30. float RTTCount;
  31. public:
  32. TPingTracker();
  33. float GetRTT() const {
  34. return AvrgRTT;
  35. }
  36. float GetRTTSKO() const {
  37. float sko = sqrt(fabs(Sqr(AvrgRTT) - AvrgRTT2));
  38. float minSKO = Max(MIN_PACKET_RTT_SKO, AvrgRTT * 0.05f);
  39. return Max(minSKO, sko);
  40. }
  41. float GetTimeout() const {
  42. return GetRTT() + GetRTTSKO() * 3;
  43. }
  44. void RegisterRTT(float rtt);
  45. void IncreaseRTT();
  46. };
  47. ui32 NetAckRnd();
  48. class TLameMTUDiscovery: public TThrRefBase {
  49. enum EState {
  50. NEED_PING,
  51. WAIT,
  52. };
  53. float TimePassed, TimeSinceLastPing;
  54. EState State;
  55. public:
  56. TLameMTUDiscovery()
  57. : TimePassed(0)
  58. , TimeSinceLastPing(0)
  59. , State(NEED_PING)
  60. {
  61. }
  62. bool CanSend() {
  63. return State == NEED_PING;
  64. }
  65. void PingSent() {
  66. State = WAIT;
  67. TimeSinceLastPing = 0;
  68. }
  69. bool IsTimedOut() const {
  70. return TimePassed > LAME_MTU_TIMEOUT;
  71. }
  72. void Step(float deltaT) {
  73. TimePassed += deltaT;
  74. TimeSinceLastPing += deltaT;
  75. if (TimeSinceLastPing > LAME_MTU_INTERVAL)
  76. State = NEED_PING;
  77. }
  78. };
  79. struct TPeerQueueStats: public IPeerQueueStats {
  80. int Count;
  81. TPeerQueueStats()
  82. : Count(0)
  83. {
  84. }
  85. int GetPacketCount() override {
  86. return Count;
  87. }
  88. };
  89. // pretend we have multiple channels in parallel
  90. // not exact approximation since N channels should have N distinct windows
  91. extern float CONG_CTRL_CHANNEL_INFLATE;
  92. class TCongestionControl: public TThrRefBase {
  93. float Window, PacketsInFly, FailRate;
  94. float MinRTT, MaxWindow;
  95. bool FullSpeed, DoCountTime;
  96. TPingTracker PingTracker;
  97. double TimeSinceLastRecv;
  98. TAdaptiveLock PortTesterLock;
  99. TIntrusivePtr<TPortUnreachableTester> PortTester;
  100. int ActiveTransferCount;
  101. float AvrgRTT;
  102. int HighRTTCounter;
  103. float WindowFraction, FractionRecalc;
  104. float TimeWindow;
  105. double TimeSinceLastFail;
  106. float VirtualPackets;
  107. int MTU;
  108. TIntrusivePtr<TLameMTUDiscovery> MTUDiscovery;
  109. TIntrusivePtr<TPeerQueueStats> QueueStats;
  110. void CalcMaxWindow() {
  111. if (MTU == 0)
  112. return;
  113. MaxWindow = 125000000 / MTU * Max(0.001f, MinRTT);
  114. }
  115. public:
  116. static float StartWindowSize, MaxPacketRate;
  117. public:
  118. TCongestionControl()
  119. : Window(StartWindowSize * CONG_CTRL_CHANNEL_INFLATE)
  120. , PacketsInFly(0)
  121. , FailRate(0)
  122. , MinRTT(10)
  123. , MaxWindow(10000)
  124. , FullSpeed(false)
  125. , DoCountTime(false)
  126. , TimeSinceLastRecv(0)
  127. , ActiveTransferCount(0)
  128. , AvrgRTT(0)
  129. , HighRTTCounter(0)
  130. , WindowFraction(0)
  131. , FractionRecalc(0)
  132. , TimeWindow(CONG_CTRL_LARGE_TIME_WINDOW)
  133. , TimeSinceLastFail(0)
  134. , MTU(0)
  135. {
  136. VirtualPackets = Max(Window - CONG_CTRL_ALLOWED_BURST_SIZE, 0.f);
  137. }
  138. bool CanSend() {
  139. bool res = VirtualPackets + PacketsInFly + WindowFraction <= Window;
  140. FullSpeed |= !res;
  141. res &= TimeWindow > 0;
  142. return res;
  143. }
  144. void LaunchPacket() {
  145. PacketsInFly += 1.0f;
  146. TimeWindow -= 1.0f;
  147. }
  148. void RegisterRTT(float RTT) {
  149. if (RTT < 0)
  150. return;
  151. RTT = ClampVal(RTT, 0.0001f, 1.0f);
  152. if (RTT < MinRTT && MTU != 0) {
  153. MinRTT = RTT;
  154. CalcMaxWindow();
  155. }
  156. MinRTT = Min(MinRTT, RTT);
  157. PingTracker.RegisterRTT(RTT);
  158. if (AvrgRTT == 0)
  159. AvrgRTT = RTT;
  160. if (RTT > AvrgRTT) {
  161. ++HighRTTCounter;
  162. if (HighRTTCounter >= CONG_CTRL_RTT_SEQ_COUNT) {
  163. //printf("Too many high RTT in a row\n");
  164. if (FullSpeed) {
  165. float windowSubtract = Window * ((1 - CONG_CTRL_WINDOW_SHRINK_RTT) / CONG_CTRL_CHANNEL_INFLATE);
  166. Window = Max(CONG_CTRL_MIN_WINDOW, Window - windowSubtract);
  167. VirtualPackets = Max(0.f, VirtualPackets - windowSubtract);
  168. //printf("reducing window by RTT , new window %g\n", Window);
  169. }
  170. // reduce no more then twice per RTT
  171. HighRTTCounter = Min(0, CONG_CTRL_RTT_SEQ_COUNT - (int)(Window * 0.5));
  172. }
  173. } else {
  174. HighRTTCounter = Min(0, HighRTTCounter);
  175. }
  176. float rttMixRate = CONG_CTRL_RTT_MIX_RATE;
  177. AvrgRTT = AvrgRTT * rttMixRate + RTT * (1 - rttMixRate);
  178. }
  179. void Success() {
  180. PacketsInFly -= 1;
  181. Y_ASSERT(PacketsInFly >= 0);
  182. // FullSpeed should be correct at this point
  183. // we assume that after UpdateAlive() we send all packets first then we listen for acks and call Success()
  184. // FullSpeed is set in CanSend() during send if we are using full window
  185. // do not increaese window while send rate is limited by virtual packets (ie start of transfer)
  186. if (FullSpeed && VirtualPackets == 0) {
  187. // there are 2 requirements for window growth
  188. // 1) growth should be proportional to window size to ensure constant FailRate
  189. // 2) growth should be constant to ensure fairness among different flows
  190. // so lets make it square root :)
  191. Window += sqrt(Window / CONG_CTRL_CHANNEL_INFLATE) * CONG_CTRL_WINDOW_GROW;
  192. if (UseTOSforAcks) {
  193. Window = Min(Window, MaxWindow);
  194. }
  195. }
  196. FailRate *= 0.99f;
  197. }
  198. void FailureOnSend() {
  199. //printf("Failure on send\n");
  200. PacketsInFly -= 1;
  201. Y_ASSERT(PacketsInFly >= 0);
  202. // not a congestion event, do not modify Window
  203. // do not set FullSpeed since we are not using full Window
  204. }
  205. void Failure() {
  206. //printf("Congestion failure\n");
  207. PacketsInFly -= 1;
  208. Y_ASSERT(PacketsInFly >= 0);
  209. // account limited number of fails per segment
  210. if (TimeSinceLastFail > CONG_CTRL_MIN_FAIL_INTERVAL) {
  211. TimeSinceLastFail = 0;
  212. if (Window <= CONG_CTRL_MIN_WINDOW) {
  213. // ping dead hosts less frequently
  214. if (PingTracker.GetRTT() / CONG_CTRL_MIN_WINDOW < CONG_CTRL_MINIMAL_SEND_INTERVAL)
  215. PingTracker.IncreaseRTT();
  216. Window = CONG_CTRL_MIN_WINDOW;
  217. VirtualPackets = 0;
  218. } else {
  219. float windowSubtract = Window * ((1 - CONG_CTRL_WINDOW_SHRINK) / CONG_CTRL_CHANNEL_INFLATE);
  220. Window = Max(CONG_CTRL_MIN_WINDOW, Window - windowSubtract);
  221. VirtualPackets = Max(0.f, VirtualPackets - windowSubtract);
  222. }
  223. }
  224. FailRate = FailRate * 0.99f + 0.01f;
  225. }
  226. bool HasPacketsInFly() const {
  227. return PacketsInFly > 0;
  228. }
  229. float GetTimeout() const {
  230. return PingTracker.GetTimeout();
  231. }
  232. float GetWindow() const {
  233. return Window;
  234. }
  235. float GetRTT() const {
  236. return PingTracker.GetRTT();
  237. }
  238. float GetFailRate() const {
  239. return FailRate;
  240. }
  241. float GetTimeSinceLastRecv() const {
  242. return TimeSinceLastRecv;
  243. }
  244. int GetTransferCount() const {
  245. return ActiveTransferCount;
  246. }
  247. float GetMaxWindow() const {
  248. return UseTOSforAcks ? MaxWindow : -1;
  249. }
  250. void MarkAlive() {
  251. TimeSinceLastRecv = 0;
  252. with_lock (PortTesterLock) {
  253. PortTester = nullptr;
  254. }
  255. }
  256. void HasTriedToSend() {
  257. DoCountTime = true;
  258. }
  259. bool IsAlive() const {
  260. return TimeSinceLastRecv < 1e6f;
  261. }
  262. void Kill() {
  263. TimeSinceLastRecv = 1e6f;
  264. }
  265. bool UpdateAlive(const TUdpAddress& toAddress, float deltaT, float timeout, float* resMaxWaitTime) {
  266. if (!FullSpeed) {
  267. // create virtual packets during idle to avoid burst on transmit start
  268. if (AvrgRTT > CONG_CTRL_MIN_RTT_FOR_BURST_REDUCTION) {
  269. VirtualPackets = Max(VirtualPackets, Window - PacketsInFly - CONG_CTRL_ALLOWED_BURST_SIZE);
  270. }
  271. } else {
  272. if (VirtualPackets > 0) {
  273. if (Window <= CONG_CTRL_ALLOWED_BURST_SIZE) {
  274. VirtualPackets = 0;
  275. }
  276. float xRTT = AvrgRTT == 0 ? CONG_CTRL_INITIAL_RTT : AvrgRTT;
  277. float virtualPktsPerSecond = Window / xRTT;
  278. VirtualPackets = Max(0.f, VirtualPackets - deltaT * virtualPktsPerSecond);
  279. *resMaxWaitTime = Min(*resMaxWaitTime, 0.001f); // need to update virtual packets counter regularly
  280. }
  281. }
  282. float currentRTT = GetRTT();
  283. FractionRecalc += deltaT;
  284. if (FractionRecalc > currentRTT) {
  285. int cycleCount = (int)(FractionRecalc / currentRTT);
  286. FractionRecalc -= currentRTT * cycleCount;
  287. WindowFraction = (NetAckRnd() & 1023) * (1 / 1023.0f) / cycleCount;
  288. }
  289. if (MaxPacketRate > 0 && AvrgRTT > 0) {
  290. float maxTimeWindow = CONG_CTRL_TIME_WINDOW_LIMIT_PERIOD * MaxPacketRate;
  291. TimeWindow = Min(maxTimeWindow, TimeWindow + MaxPacketRate * deltaT);
  292. } else
  293. TimeWindow = CONG_CTRL_LARGE_TIME_WINDOW;
  294. // guarantee minimal send rate
  295. if (currentRTT > CONG_CTRL_MINIMAL_SEND_INTERVAL * Window) {
  296. Window = Max(CONG_CTRL_MIN_WINDOW, currentRTT / CONG_CTRL_MINIMAL_SEND_INTERVAL);
  297. VirtualPackets = 0;
  298. }
  299. TimeSinceLastFail += deltaT;
  300. //static int n;
  301. //if ((++n & 127) == 0)
  302. // printf("window = %g, fly = %g, VirtualPkts = %g, deltaT = %g, FailRate = %g FullSpeed = %d AvrgRTT = %g\n",
  303. // Window, PacketsInFly, VirtualPackets, deltaT * 1000, FailRate, (int)FullSpeed, AvrgRTT * 1000);
  304. if (PacketsInFly > 0 || FullSpeed || DoCountTime) {
  305. // считаем время только когда есть пакеты в полете
  306. TimeSinceLastRecv += deltaT;
  307. if (TimeSinceLastRecv > START_CHECK_PORT_DELAY) {
  308. if (TimeSinceLastRecv < FINISH_CHECK_PORT_DELAY) {
  309. TIntrusivePtr<TPortUnreachableTester> portTester;
  310. with_lock (PortTesterLock) {
  311. portTester = PortTester;
  312. }
  313. if (!portTester && AtomicGet(ActivePortTestersCount) < N_PORT_TEST_COUNT_LIMIT) {
  314. portTester = new TPortUnreachableTester();
  315. with_lock (PortTesterLock) {
  316. PortTester = portTester;
  317. }
  318. if (portTester->IsValid()) {
  319. portTester->Connect(toAddress);
  320. } else {
  321. with_lock (PortTesterLock) {
  322. PortTester = nullptr;
  323. }
  324. }
  325. }
  326. if (portTester && !portTester->Test(deltaT)) {
  327. Kill();
  328. return false;
  329. }
  330. } else {
  331. with_lock (PortTesterLock) {
  332. PortTester = nullptr;
  333. }
  334. }
  335. }
  336. if (TimeSinceLastRecv > timeout) {
  337. Kill();
  338. return false;
  339. }
  340. }
  341. FullSpeed = false;
  342. DoCountTime = false;
  343. if (MTUDiscovery.Get())
  344. MTUDiscovery->Step(deltaT);
  345. return true;
  346. }
  347. bool IsKnownMTU() const {
  348. return MTU != 0;
  349. }
  350. int GetMTU() const {
  351. return MTU;
  352. }
  353. TLameMTUDiscovery* GetMTUDiscovery() {
  354. if (MTUDiscovery.Get() == nullptr)
  355. MTUDiscovery = new TLameMTUDiscovery;
  356. return MTUDiscovery.Get();
  357. }
  358. void SetMTU(int sz) {
  359. MTU = sz;
  360. MTUDiscovery = nullptr;
  361. CalcMaxWindow();
  362. }
  363. void AttachQueueStats(TIntrusivePtr<TPeerQueueStats> s) {
  364. if (s.Get()) {
  365. s->Count = ActiveTransferCount;
  366. }
  367. Y_ASSERT(QueueStats.Get() == nullptr);
  368. QueueStats = s;
  369. }
  370. friend class TCongestionControlPtr;
  371. };
  372. class TCongestionControlPtr {
  373. TIntrusivePtr<TCongestionControl> Ptr;
  374. void Inc() {
  375. if (Ptr.Get()) {
  376. ++Ptr->ActiveTransferCount;
  377. if (Ptr->QueueStats.Get()) {
  378. Ptr->QueueStats->Count = Ptr->ActiveTransferCount;
  379. }
  380. }
  381. }
  382. void Dec() {
  383. if (Ptr.Get()) {
  384. --Ptr->ActiveTransferCount;
  385. if (Ptr->QueueStats.Get()) {
  386. Ptr->QueueStats->Count = Ptr->ActiveTransferCount;
  387. }
  388. }
  389. }
  390. public:
  391. TCongestionControlPtr() {
  392. }
  393. ~TCongestionControlPtr() {
  394. Dec();
  395. }
  396. TCongestionControlPtr(TCongestionControl* p)
  397. : Ptr(p)
  398. {
  399. Inc();
  400. }
  401. TCongestionControlPtr& operator=(const TCongestionControlPtr& a) {
  402. Dec();
  403. Ptr = a.Ptr;
  404. Inc();
  405. return *this;
  406. }
  407. TCongestionControlPtr& operator=(TCongestionControl* a) {
  408. Dec();
  409. Ptr = a;
  410. Inc();
  411. return *this;
  412. }
  413. operator TCongestionControl*() const {
  414. return Ptr.Get();
  415. }
  416. TCongestionControl* operator->() const {
  417. return Ptr.Get();
  418. }
  419. TIntrusivePtr<TCongestionControl> Get() const {
  420. return Ptr;
  421. }
  422. };
  423. class TAckTracker {
  424. struct TFlyingPacket {
  425. float T;
  426. int PktId;
  427. TFlyingPacket()
  428. : T(0)
  429. , PktId(-1)
  430. {
  431. }
  432. TFlyingPacket(float t, int pktId)
  433. : T(t)
  434. , PktId(pktId)
  435. {
  436. }
  437. };
  438. int PacketCount, CurrentPacket;
  439. typedef THashMap<int, float> TPacketHash;
  440. TPacketHash PacketsInFly, DroppedPackets;
  441. TVector<int> ResendQueue;
  442. TCongestionControlPtr Congestion;
  443. TVector<bool> AckReceived;
  444. float TimeToNextPacketTimeout;
  445. int SelectPacket();
  446. public:
  447. TAckTracker()
  448. : PacketCount(0)
  449. , CurrentPacket(0)
  450. , TimeToNextPacketTimeout(1000)
  451. {
  452. }
  453. ~TAckTracker();
  454. void AttachCongestionControl(TCongestionControl* p) {
  455. Congestion = p;
  456. }
  457. TIntrusivePtr<TCongestionControl> GetCongestionControl() const {
  458. return Congestion.Get();
  459. }
  460. void SetPacketCount(int n) {
  461. Y_ASSERT(PacketCount == 0);
  462. PacketCount = n;
  463. AckReceived.resize(n, false);
  464. }
  465. void Resend();
  466. bool IsInitialized() {
  467. return PacketCount != 0;
  468. }
  469. int GetPacketToSend(float deltaT);
  470. void AddToResend(int pkt); // called when failed to send packet
  471. void Ack(int pkt, float deltaT, bool updateRTT);
  472. void AckAll();
  473. void MarkAlive() {
  474. Congestion->MarkAlive();
  475. }
  476. bool IsAlive() const {
  477. return Congestion->IsAlive();
  478. }
  479. void Step(float deltaT);
  480. bool CanSend() const {
  481. return Congestion->CanSend();
  482. }
  483. float GetTimeToNextPacketTimeout() const {
  484. return TimeToNextPacketTimeout;
  485. }
  486. };
  487. }