CacheZone.cs 23 KB


  1. /*
  2. Technitium DNS Server
  3. Copyright (C) 2024 Shreyas Zare (shreyas@technitium.com)
  4. This program is free software: you can redistribute it and/or modify
  5. it under the terms of the GNU General Public License as published by
  6. the Free Software Foundation, either version 3 of the License, or
  7. (at your option) any later version.
  8. This program is distributed in the hope that it will be useful,
  9. but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  11. GNU General Public License for more details.
  12. You should have received a copy of the GNU General Public License
  13. along with this program. If not, see <http://www.gnu.org/licenses/>.
  14. */
  15. using DnsServerCore.Dns.ResourceRecords;
  16. using System;
  17. using System.Collections.Concurrent;
  18. using System.Collections.Generic;
  19. using System.IO;
  20. using TechnitiumLibrary;
  21. using TechnitiumLibrary.Net;
  22. using TechnitiumLibrary.Net.Dns;
  23. using TechnitiumLibrary.Net.Dns.ResourceRecords;
  namespace DnsServerCore.Dns.Zones
  {
      /// <summary>
      /// Cache zone holding the cached RR sets of a single domain name, keyed by record type. RR sets cached with an
      /// EDNS Client Subnet (ECS) are kept apart from the regular entries, in one set of entries per ECS network address.
      /// </summary>
      class CacheZone : Zone
      {
          #region variables

          //per ECS network entries; remains null until the first ECS RR set is cached
          //DeleteEDnsClientSubnetData() can reset this field to null at any time, so methods read it into a local once
          ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> _ecsEntries;

          #endregion

          #region constructor

          /// <summary>
          /// Creates an empty cache zone.
          /// </summary>
          /// <param name="name">The domain name of this zone.</param>
          /// <param name="capacity">Initial capacity, passed on to the base zone.</param>
          public CacheZone(string name, int capacity)
              : base(name, capacity)
          { }

          private CacheZone(string name, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries)
              : base(name, entries)
          { }

          #endregion

          #region static

          /// <summary>
          /// Reads a cache zone in the format produced by <see cref="WriteTo(BinaryWriter)"/>.
          /// </summary>
          /// <param name="bR">Reader positioned at the start of a serialized cache zone.</param>
          /// <param name="serveStale">Passed to the expiry check; RR sets reported as expired are not loaded.</param>
          /// <returns>The deserialized cache zone.</returns>
          /// <exception cref="InvalidDataException">The format version byte is not supported.</exception>
          public static CacheZone ReadFrom(BinaryReader bR, bool serveStale)
          {
              byte version = bR.ReadByte();
              switch (version)
              {
                  case 1:
                      string name = bR.ReadString();
                      ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries = ReadEntriesFrom(bR, serveStale);

                      CacheZone cacheZone = new CacheZone(name, entries);

                      //read all ECS cache records
                      {
                          int ecsCount = bR.ReadInt32();
                          if (ecsCount > 0)
                          {
                              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = new ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>(1, ecsCount);

                              for (int i = 0; i < ecsCount; i++)
                              {
                                  NetworkAddress key = NetworkAddress.ReadFrom(bR);
                                  ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> ecsEntry = ReadEntriesFrom(bR, serveStale);

                                  //subnets whose RR sets were all expired are not kept
                                  if (!ecsEntry.IsEmpty)
                                      ecsEntries.TryAdd(key, ecsEntry);
                              }

                              if (!ecsEntries.IsEmpty)
                                  cacheZone._ecsEntries = ecsEntries;
                          }
                      }

                      return cacheZone;

                  default:
                      throw new InvalidDataException("CacheZone format version not supported.");
              }
          }

          #endregion

          #region private

          /// <summary>
          /// Returns the RR set to serve for a query, or an empty list when any of its records is expired (per
          /// <paramref name="serveStale"/>) or, when <paramref name="skipSpecialCacheRecord"/> is set, when it holds a
          /// special cache record. Multi-record A/AAAA sets are returned as a shuffled copy to allow load balancing.
          /// Every record of a returned RR set gets its last used time set to now.
          /// </summary>
          private static IReadOnlyList<DnsResourceRecord> ValidateRRSet(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, bool serveStale, bool skipSpecialCacheRecord)
          {
              foreach (DnsResourceRecord record in records)
              {
                  if (record.IsExpired(serveStale))
                      return Array.Empty<DnsResourceRecord>(); //RR Set is expired

                  if (skipSpecialCacheRecord && (record.RDATA is DnsCache.DnsSpecialCacheRecordData))
                      return Array.Empty<DnsResourceRecord>(); //RR Set is special cache record
              }

              //update last used on before any early return below so that shuffled A/AAAA sets are also
              //marked as used and are not evicted by RemoveLeastUsedRecords() while still being queried
              DateTime utcNow = DateTime.UtcNow;

              foreach (DnsResourceRecord record in records)
                  record.GetCacheRecordInfo().LastUsedOn = utcNow;

              if (records.Count > 1)
              {
                  switch (type)
                  {
                      case DnsResourceRecordType.A:
                      case DnsResourceRecordType.AAAA:
                          List<DnsResourceRecord> newRecords = new List<DnsResourceRecord>(records);
                          newRecords.Shuffle(); //shuffle records to allow load balancing
                          return newRecords;
                  }
              }

              return records;
          }

          /// <summary>
          /// Reads RR set entries written by <see cref="WriteEntriesTo"/>. RR sets that are expired (per
          /// <paramref name="serveStale"/>) are skipped.
          /// </summary>
          private static ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> ReadEntriesFrom(BinaryReader bR, bool serveStale)
          {
              int count = bR.ReadInt32();
              ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries = new ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>(1, count);

              for (int i = 0; i < count; i++)
              {
                  DnsResourceRecordType key = (DnsResourceRecordType)bR.ReadUInt16();
                  int rrCount = bR.ReadInt32();
                  DnsResourceRecord[] records = new DnsResourceRecord[rrCount];

                  for (int j = 0; j < rrCount; j++)
                  {
                      records[j] = DnsResourceRecord.ReadCacheRecordFrom(bR, delegate (DnsResourceRecord record)
                      {
                          record.Tag = new CacheRecordInfo(bR);
                      });
                  }

                  if (!DnsResourceRecord.IsRRSetExpired(records, serveStale))
                      entries.TryAdd(key, records);
              }

              return entries;
          }

          /// <summary>
          /// Writes RR set entries: the entry count, then for each entry its record type, record count and records
          /// (records without a <see cref="CacheRecordInfo"/> tag are written with <see cref="CacheRecordInfo.Default"/>).
          /// </summary>
          private static void WriteEntriesTo(ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries, BinaryWriter bW)
          {
              //snapshot first so the written count always matches the number of entries that follow,
              //even if the dictionary is modified concurrently while being written
              KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>[] snapshot = entries.ToArray();

              bW.Write(snapshot.Length);

              foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in snapshot)
              {
                  bW.Write((ushort)entry.Key);
                  bW.Write(entry.Value.Count);

                  foreach (DnsResourceRecord record in entry.Value)
                  {
                      record.WriteCacheRecordTo(bW, delegate ()
                      {
                          if (record.Tag is not CacheRecordInfo rrInfo)
                              rrInfo = CacheRecordInfo.Default; //default info

                          rrInfo.WriteTo(bW);
                      });
                  }
              }
          }

          #endregion

          #region public

          /// <summary>
          /// Caches an RR set for the given type, replacing any existing RR set of that type. RR sets whose first
          /// record carries an EDNS Client Subnet are stored under that subnet instead of the regular entries.
          /// </summary>
          /// <param name="type">The record type the RR set is cached under.</param>
          /// <param name="records">The RR set to cache.</param>
          /// <param name="serveStale">Whether serve stale is enabled; affects expiry checks and stale CNAME cleanup.</param>
          /// <returns>
          /// True when a new entry was added; false when an existing entry was replaced, when <paramref name="records"/>
          /// is empty, or when a failure record was not cached to avoid overwriting a useful unexpired RR set.
          /// </returns>
          public bool SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, bool serveStale)
          {
              if (records.Count == 0)
                  return false;

              ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries;

              CacheRecordInfo cacheRecordInfo = records[0].GetCacheRecordInfo();
              NetworkAddress eDnsClientSubnet = cacheRecordInfo.EDnsClientSubnet;
              if (eDnsClientSubnet is null)
              {
                  entries = _entries;
              }
              else
              {
                  ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
                  if (ecsEntries is null)
                  {
                      //publish a new ECS dictionary only if no other caller did so meanwhile; otherwise use the one already published
                      ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> newEcsEntries = new ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>(1, 5);
                      ecsEntries = System.Threading.Interlocked.CompareExchange(ref _ecsEntries, newEcsEntries, null) ?? newEcsEntries;
                  }

                  //GetOrAdd so that a concurrent insert for the same subnet does not cause these records to be dropped
                  entries = ecsEntries.GetOrAdd(eDnsClientSubnet, delegate (NetworkAddress subnet)
                  {
                      return new ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>(1, 1);
                  });
              }

              bool isFailureRecord = false;

              if (records[0].RDATA is DnsCache.DnsSpecialCacheRecordData splRecord)
              {
                  if (splRecord.IsFailureOrBadCache)
                  {
                      //call trying to cache failure record
                      isFailureRecord = true;

                      if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords) && (existingRecords.Count > 0) && !DnsResourceRecord.IsRRSetExpired(existingRecords, serveStale))
                      {
                          if ((existingRecords[0].RDATA is not DnsCache.DnsSpecialCacheRecordData existingSplRecord) || !existingSplRecord.IsFailureOrBadCache)
                              return false; //skip to avoid overwriting a useful record with a failure record

                          //copy extended errors from existing spl record
                          splRecord.CopyExtendedDnsErrorsFrom(existingSplRecord);
                      }
                  }
              }
              else if ((type == DnsResourceRecordType.NS) && (records[0].RDATA is DnsNSRecordData ns) && !ns.IsParentSideTtlSet)
              {
                  //for ns revalidation: carry over the parent side TTL from the existing NS RR set
                  if (entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> existingNSRecords) && (existingNSRecords.Count > 0) && (existingNSRecords[0].RDATA is DnsNSRecordData existingNS) && existingNS.IsParentSideTtlSet)
                  {
                      uint parentSideTtl = existingNS.ParentSideTtl;

                      foreach (DnsResourceRecord record in records)
                      {
                          if (record.RDATA is DnsNSRecordData nsRecordData)
                              nsRecordData.ParentSideTtl = parentSideTtl;
                      }
                  }
              }

              //set last used date time
              DateTime utcNow = DateTime.UtcNow;

              foreach (DnsResourceRecord record in records)
                  record.GetCacheRecordInfo().LastUsedOn = utcNow;

              //set records
              bool added = true;

              entries.AddOrUpdate(type, records, delegate (DnsResourceRecordType key, IReadOnlyList<DnsResourceRecord> existingRecords)
              {
                  added = false;
                  return records;
              });

              if (serveStale && !isFailureRecord)
              {
                  //remove stale CNAME entry only when serve stale is enabled
                  //making sure current record is not a failure record causing removal of useful stale CNAME record
                  switch (type)
                  {
                      case DnsResourceRecordType.CNAME:
                      case DnsResourceRecordType.SOA:
                      case DnsResourceRecordType.NS:
                      case DnsResourceRecordType.DS:
                          //do nothing
                          break;

                      default:
                          //remove stale CNAME entry since current new entry type overlaps any existing CNAME entry in cache
                          //keeping both entries will create issue with serve stale implementation since stale CNAME entry will be always returned
                          if (entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
                          {
                              if ((existingCNAMERecords.Count > 0) && (existingCNAMERecords[0].RDATA is DnsCNAMERecordData) && existingCNAMERecords[0].IsStale)
                              {
                                  //delete CNAME entry only when it contains stale DnsCNAMERecord RDATA and not special cache records
                                  entries.TryRemove(DnsResourceRecordType.CNAME, out _);
                              }
                          }
                          break;
                  }
              }

              return added;
          }

          /// <summary>
          /// Removes every RR set (regular and ECS) that is expired per <paramref name="serveStale"/>, and drops ECS
          /// subnets left without entries.
          /// </summary>
          /// <returns>The number of RR set entries removed.</returns>
          public int RemoveExpiredRecords(bool serveStale)
          {
              int removedEntries = 0;

              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is not null)
              {
                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  {
                      foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in ecsEntry.Value)
                      {
                          if (DnsResourceRecord.IsRRSetExpired(entry.Value, serveStale))
                          {
                              if (ecsEntry.Value.TryRemove(entry.Key, out _)) //RR Set is expired; remove entry
                                  removedEntries++;
                          }
                      }

                      if (ecsEntry.Value.IsEmpty)
                          ecsEntries.TryRemove(ecsEntry.Key, out _);
                  }
              }

              foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
              {
                  if (DnsResourceRecord.IsRRSetExpired(entry.Value, serveStale))
                  {
                      if (_entries.TryRemove(entry.Key, out _)) //RR Set is expired; remove entry
                          removedEntries++;
                  }
              }

              return removedEntries;
          }

          /// <summary>
          /// Removes every RR set (regular and ECS) that is empty or whose first record was last used before
          /// <paramref name="cutoff"/>, and drops ECS subnets left without entries.
          /// </summary>
          /// <returns>The number of RR set entries removed.</returns>
          public int RemoveLeastUsedRecords(DateTime cutoff)
          {
              int removedEntries = 0;

              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is not null)
              {
                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  {
                      foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in ecsEntry.Value)
                      {
                          if ((entry.Value.Count == 0) || (entry.Value[0].GetCacheRecordInfo().LastUsedOn < cutoff))
                          {
                              if (ecsEntry.Value.TryRemove(entry.Key, out _)) //RR Set was last used before cutoff; remove entry
                                  removedEntries++;
                          }
                      }

                      if (ecsEntry.Value.IsEmpty)
                          ecsEntries.TryRemove(ecsEntry.Key, out _);
                  }
              }

              foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
              {
                  if ((entry.Value.Count == 0) || (entry.Value[0].GetCacheRecordInfo().LastUsedOn < cutoff))
                  {
                      if (_entries.TryRemove(entry.Key, out _)) //RR Set was last used before cutoff; remove entry
                          removedEntries++;
                  }
              }

              return removedEntries;
          }

          /// <summary>
          /// Discards all ECS cached RR sets of this zone.
          /// </summary>
          /// <returns>The number of RR set entries discarded.</returns>
          public int DeleteEDnsClientSubnetData()
          {
              //detach atomically so that concurrent readers see either the old dictionary or null, never a torn state
              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = System.Threading.Interlocked.Exchange(ref _ecsEntries, null);
              if (ecsEntries is null)
                  return 0;

              int count = 0;

              foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  count += ecsEntry.Value.Count;

              return count;
          }

          /// <summary>
          /// Looks up the cached RR set to answer a query of the given type.
          /// </summary>
          /// <param name="type">Queried record type. DS is looked up without CNAME; SOA and DNSKEY fall back to CNAME;
          /// ANY returns all non-DS, non-special RR sets; other types prefer a CNAME RR set when present.</param>
          /// <param name="serveStale">Passed to the expiry check.</param>
          /// <param name="skipSpecialCacheRecord">When true, RR sets holding special cache records are not returned.</param>
          /// <param name="eDnsClientSubnet">When not null, only ECS entries are searched: cached subnets with a longer
          /// prefix than this one are ignored.</param>
          /// <param name="advancedForwardingClientSubnet">When true, the cached subnet must equal
          /// <paramref name="eDnsClientSubnet"/>; otherwise the equal or containing cached subnet with the shortest prefix is used.</param>
          /// <returns>The matching records, or an empty list when nothing usable is cached.</returns>
          public IReadOnlyList<DnsResourceRecord> QueryRecords(DnsResourceRecordType type, bool serveStale, bool skipSpecialCacheRecord, NetworkAddress eDnsClientSubnet, bool advancedForwardingClientSubnet)
          {
              ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries;

              if (eDnsClientSubnet is null)
              {
                  entries = _entries;
              }
              else
              {
                  ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
                  if (ecsEntries is null)
                      return Array.Empty<DnsResourceRecord>();

                  NetworkAddress selectedNetwork = null;
                  entries = null;

                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  {
                      NetworkAddress cacheSubnet = ecsEntry.Key;

                      if (cacheSubnet.PrefixLength > eDnsClientSubnet.PrefixLength)
                          continue;

                      if (advancedForwardingClientSubnet)
                      {
                          if (cacheSubnet.Equals(eDnsClientSubnet))
                          {
                              if ((selectedNetwork is null) || (cacheSubnet.PrefixLength > selectedNetwork.PrefixLength))
                              {
                                  selectedNetwork = cacheSubnet;
                                  entries = ecsEntry.Value;
                              }
                          }
                      }
                      else
                      {
                          if (cacheSubnet.Equals(eDnsClientSubnet) || cacheSubnet.Contains(eDnsClientSubnet.Address))
                          {
                              if ((selectedNetwork is null) || (cacheSubnet.PrefixLength < selectedNetwork.PrefixLength))
                              {
                                  selectedNetwork = cacheSubnet;
                                  entries = ecsEntry.Value;
                              }
                          }
                      }
                  }

                  if (entries is null)
                      return Array.Empty<DnsResourceRecord>();
              }

              switch (type)
              {
                  case DnsResourceRecordType.DS:
                      {
                          //since some zones have CNAME at apex so no CNAME lookup for DS queries!
                          if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                              return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);
                      }
                      break;

                  case DnsResourceRecordType.SOA:
                  case DnsResourceRecordType.DNSKEY:
                      {
                          //since some zones have CNAME at apex!
                          if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                              return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);

                          if (entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
                          {
                              //validate as a CNAME RR set so that type specific handling (A/AAAA shuffle) never applies to it
                              IReadOnlyList<DnsResourceRecord> rrset = ValidateRRSet(DnsResourceRecordType.CNAME, existingCNAMERecords, serveStale, skipSpecialCacheRecord);
                              if (rrset.Count > 0)
                              {
                                  //type is never CNAME in this case so only a real CNAME answer is returned
                                  if (rrset[0].RDATA is DnsCNAMERecordData)
                                      return rrset;
                              }
                          }
                      }
                      break;

                  case DnsResourceRecordType.ANY:
                      List<DnsResourceRecord> anyRecords = new List<DnsResourceRecord>(entries.Count * 2);

                      foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in entries)
                      {
                          if (entry.Key == DnsResourceRecordType.DS)
                              continue;

                          anyRecords.AddRange(ValidateRRSet(type, entry.Value, serveStale, true));
                      }

                      return anyRecords;

                  default:
                      {
                          if (entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
                          {
                              //validate as a CNAME RR set so that type specific handling (A/AAAA shuffle) never applies to it
                              IReadOnlyList<DnsResourceRecord> rrset = ValidateRRSet(DnsResourceRecordType.CNAME, existingCNAMERecords, serveStale, skipSpecialCacheRecord);
                              if (rrset.Count > 0)
                              {
                                  if ((type == DnsResourceRecordType.CNAME) || (rrset[0].RDATA is DnsCNAMERecordData))
                                      return rrset;
                              }
                          }

                          if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                              return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);
                      }
                      break;
              }

              return Array.Empty<DnsResourceRecord>();
          }

          /// <summary>
          /// Appends all cached records of this zone, ECS entries first, then the regular entries via the base class.
          /// </summary>
          public override void ListAllRecords(List<DnsResourceRecord> records)
          {
              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is not null)
              {
                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  {
                      foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in ecsEntry.Value)
                          records.AddRange(entry.Value);
                  }
              }

              base.ListAllRecords(records);
          }

          /// <summary>
          /// Returns true when the regular (non-ECS) NS entry contains at least one non-stale record with NS RDATA.
          /// </summary>
          public override bool ContainsNameServerRecords()
          {
              if (!_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> records))
                  return false;

              foreach (DnsResourceRecord record in records)
              {
                  if (record.IsStale)
                      continue;

                  if (record.RDATA is DnsNSRecordData)
                      return true;
              }

              return false;
          }

          /// <summary>
          /// Writes this cache zone in format version 1, as read by <see cref="ReadFrom(BinaryReader, bool)"/>:
          /// version byte, zone name, regular entries, then the ECS subnet count followed by each subnet and its entries.
          /// </summary>
          public void WriteTo(BinaryWriter bW)
          {
              bW.Write((byte)1); //version

              //cache zone info
              bW.Write(_name);

              //write all cache records
              WriteEntriesTo(_entries, bW);

              //write all ECS cache records
              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is null)
              {
                  bW.Write(0);
              }
              else
              {
                  //snapshot first so the written subnet count matches the subnets that follow under concurrent modification
                  KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>[] ecsSnapshot = ecsEntries.ToArray();

                  bW.Write(ecsSnapshot.Length);

                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsSnapshot)
                  {
                      ecsEntry.Key.WriteTo(bW);
                      WriteEntriesTo(ecsEntry.Value, bW);
                  }
              }
          }

          #endregion

          #region properties

          /// <summary>
          /// True when neither the regular entries nor the ECS entries hold anything.
          /// </summary>
          public override bool IsEmpty
          {
              get
              {
                  ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
                  if (ecsEntries is null)
                      return _entries.IsEmpty;

                  return ecsEntries.IsEmpty && _entries.IsEmpty;
              }
          }

          /// <summary>
          /// Total number of RR set entries, counting the regular entries plus the entries of every ECS subnet.
          /// </summary>
          public int TotalEntries
          {
              get
              {
                  ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
                  if (ecsEntries is null)
                      return _entries.Count;

                  int count = _entries.Count;

                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                      count += ecsEntry.Value.Count;

                  return count;
              }
          }

          #endregion
      }
  }