CacheZone.cs 23 KB


  /*
  Technitium DNS Server
  Copyright (C) 2023 Shreyas Zare (shreyas@technitium.com)

  This program is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program. If not, see <http://www.gnu.org/licenses/>.
  */
  using DnsServerCore.Dns.ResourceRecords;
  using System;
  using System.Collections.Concurrent;
  using System.Collections.Generic;
  using System.IO;
  using System.Threading;
  using TechnitiumLibrary;
  using TechnitiumLibrary.Net;
  using TechnitiumLibrary.Net.Dns;
  using TechnitiumLibrary.Net.Dns.ResourceRecords;
  namespace DnsServerCore.Dns.Zones
  {
      /// <summary>
      /// Cached DNS resource record sets (RR sets) for the name this zone was created with, keyed by record type.
      /// RR sets cached with an EDNS Client Subnet (ECS) scope are kept separately, per subnet, for the record
      /// types accepted by <see cref="IsTypeSupportedForEDnsClientSubnet(DnsResourceRecordType)"/>.
      /// </summary>
      class CacheZone : Zone
      {
          #region variables

          //RR sets cached per ECS subnet; stays null until the first ECS RR set is cached.
          //methods read this field into a local before using it so that DeleteEDnsClientSubnetData() setting it
          //to null on another thread cannot cause a NullReferenceException between the null check and the use.
          private ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> _ecsEntries;

          #endregion

          #region constructor

          /// <summary>
          /// Creates an empty cache zone.
          /// </summary>
          /// <param name="name">The zone name.</param>
          /// <param name="capacity">Capacity hint forwarded to the base <see cref="Zone"/> constructor.</param>
          public CacheZone(string name, int capacity)
              : base(name, capacity)
          { }

          //used by ReadFrom() to wrap the entries deserialized from a stream
          private CacheZone(string name, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries)
              : base(name, entries)
          { }

          #endregion

          #region static

          /// <summary>
          /// Deserializes a cache zone written by <see cref="WriteTo(BinaryWriter)"/>.
          /// RR sets that are already expired (honoring <paramref name="serveStale"/>) are dropped while reading,
          /// and ECS subnets left without any RR set are not restored.
          /// </summary>
          /// <param name="bR">Reader positioned at the start of a serialized cache zone.</param>
          /// <param name="serveStale">Passed to the expiry check; when true, stale RR sets are kept.</param>
          /// <returns>The deserialized cache zone.</returns>
          /// <exception cref="InvalidDataException">The stored format version is not supported.</exception>
          public static CacheZone ReadFrom(BinaryReader bR, bool serveStale)
          {
              byte version = bR.ReadByte();
              switch (version)
              {
                  case 1:
                      {
                          string name = bR.ReadString();
                          ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries = ReadEntriesFrom(bR, serveStale);

                          CacheZone cacheZone = new CacheZone(name, entries);

                          //read all ECS cache records
                          int ecsCount = bR.ReadInt32();
                          if (ecsCount > 0)
                          {
                              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = new ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>(1, ecsCount);

                              for (int i = 0; i < ecsCount; i++)
                              {
                                  NetworkAddress key = NetworkAddress.ReadFrom(bR);
                                  ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> ecsEntry = ReadEntriesFrom(bR, serveStale);

                                  //every RR set of this subnet may have expired while the data was stored
                                  if (!ecsEntry.IsEmpty)
                                      ecsEntries.TryAdd(key, ecsEntry);
                              }

                              if (!ecsEntries.IsEmpty)
                                  cacheZone._ecsEntries = ecsEntries;
                          }

                          return cacheZone;
                      }

                  default:
                      throw new InvalidDataException("CacheZone format version not supported.");
              }
          }

          /// <summary>
          /// Tells whether RR sets of the given type are cached per ECS subnet (A, AAAA and CNAME only).
          /// </summary>
          /// <param name="type">The record type to check.</param>
          /// <returns>True for A, AAAA and CNAME; otherwise false.</returns>
          public static bool IsTypeSupportedForEDnsClientSubnet(DnsResourceRecordType type)
          {
              switch (type)
              {
                  case DnsResourceRecordType.A:
                  case DnsResourceRecordType.AAAA:
                  case DnsResourceRecordType.CNAME:
                      return true;

                  default:
                      return false;
              }
          }

          #endregion

          #region private

          //returns the RR set to answer with, or an empty list when any record is expired or (when requested) when the
          //RR set holds a special cache record. LastUsedOn of the records is refreshed. multi-record A/AAAA RR sets are
          //returned as a shuffled copy to allow load balancing without reordering the cached list.
          private static IReadOnlyList<DnsResourceRecord> ValidateRRSet(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, bool serveStale, bool skipSpecialCacheRecord)
          {
              foreach (DnsResourceRecord record in records)
              {
                  if (record.IsExpired(serveStale))
                      return Array.Empty<DnsResourceRecord>(); //RR Set is expired

                  if (skipSpecialCacheRecord && (record.RDATA is DnsCache.DnsSpecialCacheRecordData))
                      return Array.Empty<DnsResourceRecord>(); //RR Set is special cache record
              }

              //update last used on BEFORE the shuffle below returns early, so that RemoveLeastUsedRecords()
              //does not evict multi-address A/AAAA RR sets that are being queried
              DateTime utcNow = DateTime.UtcNow;

              foreach (DnsResourceRecord record in records)
                  record.GetCacheRecordInfo().LastUsedOn = utcNow;

              if (records.Count > 1)
              {
                  switch (type)
                  {
                      case DnsResourceRecordType.A:
                      case DnsResourceRecordType.AAAA:
                          List<DnsResourceRecord> newRecords = new List<DnsResourceRecord>(records);
                          newRecords.Shuffle(); //shuffle records to allow load balancing
                          return newRecords;
                  }
              }

              return records;
          }

          //reads one entries table written by WriteEntriesTo(): an Int32 RR set count, then per RR set a UInt16 record
          //type and an Int32 record count followed by the records, each carrying its CacheRecordInfo.
          //RR sets that are already expired are not added to the returned dictionary.
          private static ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> ReadEntriesFrom(BinaryReader bR, bool serveStale)
          {
              int count = bR.ReadInt32();
              ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries = new ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>(1, count);

              for (int i = 0; i < count; i++)
              {
                  DnsResourceRecordType key = (DnsResourceRecordType)bR.ReadUInt16();
                  int rrCount = bR.ReadInt32();
                  DnsResourceRecord[] records = new DnsResourceRecord[rrCount];

                  for (int j = 0; j < rrCount; j++)
                  {
                      records[j] = DnsResourceRecord.ReadCacheRecordFrom(bR, delegate (DnsResourceRecord record)
                      {
                          record.Tag = new CacheRecordInfo(bR);
                      });
                  }

                  if (!DnsResourceRecord.IsRRSetExpired(records, serveStale))
                      entries.TryAdd(key, records);
              }

              return entries;
          }

          //writes one entries table in the format read by ReadEntriesFrom(); records without a CacheRecordInfo tag
          //are written with CacheRecordInfo.Default
          private static void WriteEntriesTo(ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries, BinaryWriter bW)
          {
              //take a point-in-time snapshot so that the count written always matches the number of entries that follow,
              //even if another thread adds or removes entries while writing
              KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>[] snapshot = entries.ToArray();

              bW.Write(snapshot.Length);

              foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in snapshot)
              {
                  bW.Write((ushort)entry.Key);
                  bW.Write(entry.Value.Count);

                  foreach (DnsResourceRecord record in entry.Value)
                  {
                      record.WriteCacheRecordTo(bW, delegate ()
                      {
                          if (record.Tag is not CacheRecordInfo rrInfo)
                              rrInfo = CacheRecordInfo.Default; //default info

                          rrInfo.WriteTo(bW);
                      });
                  }
              }
          }

          //returns the cached CNAME RR set when it can answer a query of the given type, otherwise null.
          //for non-CNAME query types only an RR set holding real CNAME RDATA (not a special cache record) is returned.
          private static IReadOnlyList<DnsResourceRecord> QueryCNAMERRSet(ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries, DnsResourceRecordType type, bool serveStale, bool skipSpecialCacheRecord)
          {
              if (!entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
                  return null;

              IReadOnlyList<DnsResourceRecord> rrset = ValidateRRSet(type, existingCNAMERecords, serveStale, skipSpecialCacheRecord);
              if (rrset.Count == 0)
                  return null;

              if ((type == DnsResourceRecordType.CNAME) || (rrset[0].RDATA is DnsCNAMERecordData))
                  return rrset;

              return null;
          }

          //selects the ECS entries of the cached subnet that covers the client subnet with the longest prefix;
          //returns null when no ECS entries exist or no cached subnet covers the client subnet
          private ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> FindEDnsClientSubnetEntries(NetworkAddress eDnsClientSubnet)
          {
              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is null)
                  return null;

              NetworkAddress selectedNetwork = null;
              ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> selectedEntries = null;

              foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
              {
                  NetworkAddress cacheSubnet = ecsEntry.Key;

                  //skip cached subnets that are more specific than the client subnet
                  if (cacheSubnet.PrefixLength > eDnsClientSubnet.PrefixLength)
                      continue;

                  if (cacheSubnet.Equals(eDnsClientSubnet) || cacheSubnet.Contains(eDnsClientSubnet.Address))
                  {
                      //prefer the most specific (longest prefix) matching subnet
                      if ((selectedNetwork is null) || (cacheSubnet.PrefixLength > selectedNetwork.PrefixLength))
                      {
                          selectedNetwork = cacheSubnet;
                          selectedEntries = ecsEntry.Value;
                      }
                  }
              }

              return selectedEntries;
          }

          //removes every RR set, in both the ECS entries and the regular entries, for which shouldRemove returns true.
          //ECS subnets left without any RR set are dropped. returns the number of RR sets removed.
          private int RemoveRRSets(Func<IReadOnlyList<DnsResourceRecord>, bool> shouldRemove)
          {
              int removedEntries = 0;

              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is not null)
              {
                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  {
                      removedEntries += RemoveRRSetsFrom(ecsEntry.Value, shouldRemove);

                      if (ecsEntry.Value.IsEmpty)
                          ecsEntries.TryRemove(ecsEntry); //only removes the subnet if it still maps to this same dictionary
                  }
              }

              removedEntries += RemoveRRSetsFrom(_entries, shouldRemove);

              return removedEntries;
          }

          //removes the RR sets of one entries dictionary for which shouldRemove returns true; returns the number removed
          private static int RemoveRRSetsFrom(ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries, Func<IReadOnlyList<DnsResourceRecord>, bool> shouldRemove)
          {
              int removedEntries = 0;

              foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in entries)
              {
                  //remove the key only while it still maps to the RR set that was checked, so that an RR set stored by
                  //SetRecords() on another thread in the meantime is not removed by mistake
                  if (shouldRemove(entry.Value) && entries.TryRemove(entry))
                      removedEntries++;
              }

              return removedEntries;
          }

          #endregion

          #region public

          /// <summary>
          /// Stores an RR set of the given type, replacing any existing RR set of that type.
          /// When the first record's cache info carries an ECS subnet and the type is supported for ECS,
          /// the RR set is stored under that subnet instead of the regular entries.
          /// </summary>
          /// <remarks>
          /// A failure/bad-cache special record is not stored over an unexpired RR set of the same type that is not itself
          /// a failure record; when the existing RR set is a failure record, its extended DNS errors are copied to the new one.
          /// NS records stored without a parent side TTL inherit the parent side TTL of an existing NS RR set.
          /// </remarks>
          /// <param name="type">Record type of the RR set.</param>
          /// <param name="records">The RR set to store; an empty list is ignored.</param>
          /// <param name="serveStale">
          /// Passed to the expiry check of existing records; when true (and the new RR set is not a failure record),
          /// an existing stale CNAME RR set that overlaps the new type is removed.
          /// </param>
          /// <returns>True when the RR set was newly added; false when it was ignored or replaced an existing RR set.</returns>
          public bool SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, bool serveStale)
          {
              if (records.Count == 0)
                  return false;

              ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries;

              NetworkAddress eDnsClientSubnet = records[0].GetCacheRecordInfo().EDnsClientSubnet;
              if ((eDnsClientSubnet is null) || !IsTypeSupportedForEDnsClientSubnet(type))
              {
                  entries = _entries;
              }
              else
              {
                  ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
                  if (ecsEntries is null)
                  {
                      //publish a new ECS dictionary only if no other thread has published one first; otherwise use theirs
                      ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> newEcsEntries = new ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>(1, 5);
                      ecsEntries = Interlocked.CompareExchange(ref _ecsEntries, newEcsEntries, null) ?? newEcsEntries;
                  }

                  //reuse the subnet's entries when another thread added them concurrently instead of dropping this RR set
                  entries = ecsEntries.GetOrAdd(eDnsClientSubnet, delegate (NetworkAddress key)
                  {
                      return new ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>(1, 1);
                  });
              }

              bool isFailureRecord = false;

              if (records[0].RDATA is DnsCache.DnsSpecialCacheRecordData splRecord)
              {
                  if (splRecord.IsFailureOrBadCache)
                  {
                      //call trying to cache failure record
                      isFailureRecord = true;

                      if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords) && (existingRecords.Count > 0) && !DnsResourceRecord.IsRRSetExpired(existingRecords, serveStale))
                      {
                          if ((existingRecords[0].RDATA is not DnsCache.DnsSpecialCacheRecordData existingSplRecord) || !existingSplRecord.IsFailureOrBadCache)
                              return false; //skip to avoid overwriting a useful record with a failure record

                          //copy extended errors from existing spl record
                          splRecord.CopyExtendedDnsErrorsFrom(existingSplRecord);
                      }
                  }
              }
              else if ((type == DnsResourceRecordType.NS) && (records[0].RDATA is DnsNSRecordData ns) && !ns.IsParentSideTtlSet)
              {
                  //for ns revalidation: keep the parent side TTL already known for this name
                  if (entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> existingNSRecords))
                  {
                      if ((existingNSRecords.Count > 0) && (existingNSRecords[0].RDATA is DnsNSRecordData existingNS) && existingNS.IsParentSideTtlSet)
                      {
                          uint parentSideTtl = existingNS.ParentSideTtl;

                          foreach (DnsResourceRecord record in records)
                          {
                              if (record.RDATA is DnsNSRecordData nsRecord)
                                  nsRecord.ParentSideTtl = parentSideTtl;
                          }
                      }
                  }
              }

              //set last used date time
              DateTime utcNow = DateTime.UtcNow;

              foreach (DnsResourceRecord record in records)
                  record.GetCacheRecordInfo().LastUsedOn = utcNow;

              //set records
              bool added = true;

              entries.AddOrUpdate(type, records, delegate (DnsResourceRecordType key, IReadOnlyList<DnsResourceRecord> existingRecords)
              {
                  added = false;
                  return records;
              });

              if (serveStale && !isFailureRecord)
              {
                  //remove stale CNAME entry only when serve stale is enabled
                  //making sure current record is not a failure record causing removal of useful stale CNAME record
                  switch (type)
                  {
                      case DnsResourceRecordType.CNAME:
                      case DnsResourceRecordType.SOA:
                      case DnsResourceRecordType.NS:
                      case DnsResourceRecordType.DS:
                          //do nothing
                          break;

                      default:
                          //remove stale CNAME entry since current new entry type overlaps any existing CNAME entry in cache
                          //keeping both entries will create issue with serve stale implementation since stale CNAME entry will be always returned
                          if (entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
                          {
                              if ((existingCNAMERecords.Count > 0) && (existingCNAMERecords[0].RDATA is DnsCNAMERecordData) && existingCNAMERecords[0].IsStale)
                              {
                                  //delete CNAME entry only when it contains stale DnsCNAMERecord RDATA and not special cache records
                                  entries.TryRemove(DnsResourceRecordType.CNAME, out _);
                              }
                          }

                          break;
                  }
              }

              return added;
          }

          /// <summary>
          /// Removes expired RR sets from both the regular and the ECS entries; ECS subnets left empty are dropped.
          /// </summary>
          /// <param name="serveStale">Passed to the expiry check; when true, stale RR sets are kept.</param>
          /// <returns>The number of RR sets removed.</returns>
          public int RemoveExpiredRecords(bool serveStale)
          {
              return RemoveRRSets(delegate (IReadOnlyList<DnsResourceRecord> rrset)
              {
                  return DnsResourceRecord.IsRRSetExpired(rrset, serveStale);
              });
          }

          /// <summary>
          /// Removes RR sets that are empty or whose first record was last used before <paramref name="cutoff"/>,
          /// from both the regular and the ECS entries; ECS subnets left empty are dropped.
          /// </summary>
          /// <param name="cutoff">RR sets last used before this time are removed (compared against LastUsedOn, which is set from UTC).</param>
          /// <returns>The number of RR sets removed.</returns>
          public int RemoveLeastUsedRecords(DateTime cutoff)
          {
              return RemoveRRSets(delegate (IReadOnlyList<DnsResourceRecord> rrset)
              {
                  return (rrset.Count == 0) || (rrset[0].GetCacheRecordInfo().LastUsedOn < cutoff);
              });
          }

          /// <summary>
          /// Discards all RR sets cached per ECS subnet.
          /// </summary>
          /// <returns>The number of RR sets that were discarded.</returns>
          public int DeleteEDnsClientSubnetData()
          {
              //detach atomically so concurrent readers see either the old dictionary or null, never a half-cleared state
              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = Interlocked.Exchange(ref _ecsEntries, null);
              if (ecsEntries is null)
                  return 0;

              int count = 0;

              foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  count += ecsEntry.Value.Count;

              return count;
          }

          /// <summary>
          /// Looks up the cached RR set answering a query of the given type.
          /// </summary>
          /// <remarks>
          /// For types other than DS, SOA, DNSKEY and ANY, a cached CNAME RR set takes precedence over an RR set of the
          /// queried type. SOA and DNSKEY prefer their own RR set and fall back to a CNAME. DS never follows a CNAME.
          /// ANY returns every RR set except DS, skipping special cache records.
          /// </remarks>
          /// <param name="type">The queried record type.</param>
          /// <param name="serveStale">Passed to the expiry check; when true, stale records may be returned.</param>
          /// <param name="skipSpecialCacheRecord">When true, RR sets holding special cache records are not returned.</param>
          /// <param name="eDnsClientSubnet">
          /// Client subnet of the query, or null. For ECS-supported types the RR set is looked up only in the entries of
          /// the longest-prefix cached subnet covering it.
          /// </param>
          /// <returns>The matching records, or an empty list when nothing usable is cached.</returns>
          public IReadOnlyList<DnsResourceRecord> QueryRecords(DnsResourceRecordType type, bool serveStale, bool skipSpecialCacheRecord, NetworkAddress eDnsClientSubnet)
          {
              ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries;

              if ((eDnsClientSubnet is null) || !IsTypeSupportedForEDnsClientSubnet(type))
              {
                  entries = _entries;
              }
              else
              {
                  entries = FindEDnsClientSubnetEntries(eDnsClientSubnet);
                  if (entries is null)
                      return Array.Empty<DnsResourceRecord>();
              }

              switch (type)
              {
                  case DnsResourceRecordType.DS:
                      {
                          //since some zones have CNAME at apex so no CNAME lookup for DS queries!
                          if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                              return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);
                      }
                      break;

                  case DnsResourceRecordType.SOA:
                  case DnsResourceRecordType.DNSKEY:
                      {
                          //since some zones have CNAME at apex!
                          if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                              return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);

                          IReadOnlyList<DnsResourceRecord> cnameRRSet = QueryCNAMERRSet(entries, type, serveStale, skipSpecialCacheRecord);
                          if (cnameRRSet is not null)
                              return cnameRRSet;
                      }
                      break;

                  case DnsResourceRecordType.ANY:
                      {
                          List<DnsResourceRecord> anyRecords = new List<DnsResourceRecord>(entries.Count * 2);

                          foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in entries)
                          {
                              //DS RR sets are excluded from ANY responses
                              if (entry.Key == DnsResourceRecordType.DS)
                                  continue;

                              anyRecords.AddRange(ValidateRRSet(type, entry.Value, serveStale, true));
                          }

                          return anyRecords;
                      }

                  default:
                      {
                          IReadOnlyList<DnsResourceRecord> cnameRRSet = QueryCNAMERRSet(entries, type, serveStale, skipSpecialCacheRecord);
                          if (cnameRRSet is not null)
                              return cnameRRSet;

                          if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                              return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);
                      }
                      break;
              }

              return Array.Empty<DnsResourceRecord>();
          }

          /// <summary>
          /// Appends every cached record, ECS RR sets first and then the regular entries, to <paramref name="records"/>.
          /// No expiry filtering is applied.
          /// </summary>
          /// <param name="records">List that receives the records.</param>
          public override void ListAllRecords(List<DnsResourceRecord> records)
          {
              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is not null)
              {
                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  {
                      foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in ecsEntry.Value)
                          records.AddRange(entry.Value);
                  }
              }

              base.ListAllRecords(records);
          }

          /// <summary>
          /// Tells whether the regular entries hold a non-stale NS record with NS RDATA (not a special cache record).
          /// </summary>
          /// <returns>True when such a record exists; otherwise false.</returns>
          public override bool ContainsNameServerRecords()
          {
              if (!_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> records))
                  return false;

              foreach (DnsResourceRecord record in records)
              {
                  if (record.IsStale)
                      continue;

                  if (record.RDATA is DnsNSRecordData)
                      return true;
              }

              return false;
          }

          /// <summary>
          /// Serializes this cache zone in format version 1, readable by <see cref="ReadFrom(BinaryReader, bool)"/>:
          /// version byte, zone name, regular entries, then an Int32 ECS subnet count followed by each subnet and its entries.
          /// </summary>
          /// <param name="bW">The writer to serialize into.</param>
          public void WriteTo(BinaryWriter bW)
          {
              bW.Write((byte)1); //version

              //cache zone info
              bW.Write(_name);

              //write all cache records
              WriteEntriesTo(_entries, bW);

              //write all ECS cache records
              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is null)
              {
                  bW.Write(0);
              }
              else
              {
                  //snapshot so the subnet count written matches the subnets that follow under concurrent modification
                  KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>[] ecsSnapshot = ecsEntries.ToArray();

                  bW.Write(ecsSnapshot.Length);

                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsSnapshot)
                  {
                      ecsEntry.Key.WriteTo(bW);
                      WriteEntriesTo(ecsEntry.Value, bW);
                  }
              }
          }

          #endregion

          #region properties

          /// <summary>
          /// True when neither the regular entries nor the ECS entries hold anything.
          /// </summary>
          public override bool IsEmpty
          {
              get
              {
                  ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
                  return _entries.IsEmpty && ((ecsEntries is null) || ecsEntries.IsEmpty);
              }
          }

          /// <summary>
          /// Total number of cached RR sets, regular plus all ECS subnets.
          /// </summary>
          public int TotalEntries
          {
              get
              {
                  int count = _entries.Count;

                  ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
                  if (ecsEntries is null)
                      return count;

                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                      count += ecsEntry.Value.Count;

                  return count;
              }
          }

          #endregion
      }
  }