CacheZone.cs 22 KB


  /*
  Technitium DNS Server
  Copyright (C) 2023 Shreyas Zare (shreyas@technitium.com)

  This program is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program. If not, see <http://www.gnu.org/licenses/>.
  */
  using DnsServerCore.Dns.ResourceRecords;
  using System;
  using System.Collections.Concurrent;
  using System.Collections.Generic;
  using System.IO;
  using System.Threading;
  using TechnitiumLibrary;
  using TechnitiumLibrary.Net;
  using TechnitiumLibrary.Net.Dns;
  using TechnitiumLibrary.Net.Dns.ResourceRecords;
  namespace DnsServerCore.Dns.Zones
  {
      /// <summary>
      /// A cache zone holding cached RR sets keyed by record type. Regular RR sets are kept in the
      /// entries inherited from <see cref="Zone"/>; RR sets that carry an EDNS Client Subnet (ECS)
      /// are kept separately, grouped by the ECS network they were cached for.
      /// </summary>
      class CacheZone : Zone
      {
          #region variables

          // RR sets grouped by the ECS network they were cached for; null until the first ECS-scoped
          // RR set is stored. DeleteEDnsClientSubnetData() can reset it to null at any time, so every
          // reader copies the field into a local once and works only with that local.
          ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> _ecsEntries;

          #endregion

          #region constructor

          /// <summary>
          /// Creates an empty cache zone.
          /// </summary>
          /// <param name="name">The zone name.</param>
          /// <param name="capacity">Initial capacity forwarded to the base <see cref="Zone"/> constructor.</param>
          public CacheZone(string name, int capacity)
              : base(name, capacity)
          { }

          // used by ReadFrom() to build a zone around already deserialized entries
          private CacheZone(string name, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries)
              : base(name, entries)
          { }

          #endregion

          #region static

          /// <summary>
          /// Deserializes a cache zone written by <see cref="WriteTo(BinaryWriter)"/>.
          /// RR sets that are already expired with respect to <paramref name="serveStale"/> are dropped
          /// while reading, and ECS networks left without any RR set are not kept.
          /// </summary>
          /// <param name="bR">Reader positioned at the start of a serialized cache zone.</param>
          /// <param name="serveStale">Passed to the expiry check applied to each RR set read.</param>
          /// <returns>The deserialized cache zone.</returns>
          /// <exception cref="InvalidDataException">The serialized format version is not supported.</exception>
          public static CacheZone ReadFrom(BinaryReader bR, bool serveStale)
          {
              byte version = bR.ReadByte();
              switch (version)
              {
                  case 1:
                      string name = bR.ReadString();
                      ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries = ReadEntriesFrom(bR, serveStale);

                      CacheZone cacheZone = new CacheZone(name, entries);

                      //read all ECS cache records
                      {
                          int ecsCount = bR.ReadInt32();
                          if (ecsCount > 0)
                          {
                              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = new ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>(1, ecsCount);

                              for (int i = 0; i < ecsCount; i++)
                              {
                                  NetworkAddress key = NetworkAddress.ReadFrom(bR);
                                  ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> ecsEntry = ReadEntriesFrom(bR, serveStale);

                                  //every entry must still be read to advance the stream, but only non-empty ones are kept
                                  if (!ecsEntry.IsEmpty)
                                      ecsEntries.TryAdd(key, ecsEntry);
                              }

                              if (!ecsEntries.IsEmpty)
                                  cacheZone._ecsEntries = ecsEntries;
                          }
                      }

                      return cacheZone;

                  default:
                      throw new InvalidDataException("CacheZone format version not supported.");
              }
          }

          #endregion

          #region private

          /// <summary>
          /// Checks a cached RR set before it is returned to a caller.
          /// </summary>
          /// <param name="type">The type used to decide whether the RR set is shuffled (A and AAAA only).</param>
          /// <param name="records">The cached RR set.</param>
          /// <param name="serveStale">Passed to each record's expiry check.</param>
          /// <param name="skipSpecialCacheRecord">When true, an RR set containing a special cache record is treated as absent.</param>
          /// <returns>
          /// An empty list when any record is expired, or is a special cache record and
          /// <paramref name="skipSpecialCacheRecord"/> is set. Otherwise the RR set itself. Multi-record
          /// A/AAAA sets are returned as a shuffled copy.
          /// </returns>
          private static IReadOnlyList<DnsResourceRecord> ValidateRRSet(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, bool serveStale, bool skipSpecialCacheRecord)
          {
              foreach (DnsResourceRecord record in records)
              {
                  if (record.IsExpired(serveStale))
                      return Array.Empty<DnsResourceRecord>(); //RR Set is expired

                  if (skipSpecialCacheRecord && (record.RDATA is DnsCache.DnsSpecialCacheRecordData))
                      return Array.Empty<DnsResourceRecord>(); //RR Set is special cache record
              }

              //update last used on; this must happen before the shuffle below returns, otherwise
              //load balanced A/AAAA RR sets are never marked as used and get evicted by RemoveLeastUsedRecords()
              DateTime utcNow = DateTime.UtcNow;

              foreach (DnsResourceRecord record in records)
                  record.GetCacheRecordInfo().LastUsedOn = utcNow;

              if (records.Count > 1)
              {
                  switch (type)
                  {
                      case DnsResourceRecordType.A:
                      case DnsResourceRecordType.AAAA:
                          //shuffle a copy to allow load balancing without modifying the cached RR set
                          List<DnsResourceRecord> newRecords = new List<DnsResourceRecord>(records);
                          newRecords.Shuffle();
                          return newRecords;
                  }
              }

              return records;
          }

          /// <summary>
          /// Looks up the CNAME RR set cached in <paramref name="entries"/> for a query of type <paramref name="type"/>.
          /// </summary>
          /// <returns>
          /// The validated CNAME-key RR set, or null when there is none to return. For non-CNAME queries,
          /// the RR set is returned only if its first record actually carries <see cref="DnsCNAMERecordData"/>.
          /// </returns>
          private static IReadOnlyList<DnsResourceRecord> GetCNAMERRSet(ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries, DnsResourceRecordType type, bool serveStale, bool skipSpecialCacheRecord)
          {
              if (!entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
                  return null;

              //validate with the CNAME type so a CNAME-key RR set is never shuffled like an A/AAAA answer
              IReadOnlyList<DnsResourceRecord> rrset = ValidateRRSet(DnsResourceRecordType.CNAME, existingCNAMERecords, serveStale, skipSpecialCacheRecord);
              if (rrset.Count == 0)
                  return null;

              //a CNAME query gets whatever is stored under the CNAME key; other query types only
              //follow an actual CNAME record (not e.g. a special cache record stored under that key)
              if ((type == DnsResourceRecordType.CNAME) || (rrset[0].RDATA is DnsCNAMERecordData))
                  return rrset;

              return null;
          }

          /// <summary>
          /// Selects the ECS-scoped entries to answer a query from the given client subnet.
          /// Cached networks with a longer prefix than the client subnet are skipped. A cached network
          /// qualifies when it equals the client subnet or, unless
          /// <paramref name="conditionalForwardingClientSubnet"/> is set, contains the client subnet's
          /// address. Among qualifying networks, the one with the longest prefix wins.
          /// </summary>
          /// <returns>The selected entries, or null when no ECS entries exist or none qualify.</returns>
          private ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> FindECSEntries(NetworkAddress eDnsClientSubnet, bool conditionalForwardingClientSubnet)
          {
              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is null)
                  return null;

              NetworkAddress selectedNetwork = null;
              ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> selectedEntries = null;

              foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
              {
                  NetworkAddress cacheSubnet = ecsEntry.Key;

                  if (cacheSubnet.PrefixLength > eDnsClientSubnet.PrefixLength)
                      continue;

                  if (cacheSubnet.Equals(eDnsClientSubnet) || (!conditionalForwardingClientSubnet && cacheSubnet.Contains(eDnsClientSubnet.Address)))
                  {
                      if ((selectedNetwork is null) || (cacheSubnet.PrefixLength > selectedNetwork.PrefixLength))
                      {
                          selectedNetwork = cacheSubnet;
                          selectedEntries = ecsEntry.Value;
                      }
                  }
              }

              return selectedEntries;
          }

          // Removes every RR set in entries that IsRRSetExpired() reports as expired; returns the number removed.
          // Removing while enumerating is allowed for ConcurrentDictionary.
          private static int RemoveExpiredFrom(ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries, bool serveStale)
          {
              int removedEntries = 0;

              foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in entries)
              {
                  if (DnsResourceRecord.IsRRSetExpired(entry.Value, serveStale))
                  {
                      if (entries.TryRemove(entry.Key, out _)) //RR Set is expired; remove entry
                          removedEntries++;
                  }
              }

              return removedEntries;
          }

          // Removes every RR set in entries that is empty or whose first record was last used before cutoff;
          // returns the number removed.
          private static int RemoveLeastUsedFrom(ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries, DateTime cutoff)
          {
              int removedEntries = 0;

              foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in entries)
              {
                  if ((entry.Value.Count == 0) || (entry.Value[0].GetCacheRecordInfo().LastUsedOn < cutoff))
                  {
                      if (entries.TryRemove(entry.Key, out _)) //RR Set was last used before cutoff; remove entry
                          removedEntries++;
                  }
              }

              return removedEntries;
          }

          /// <summary>
          /// Deserializes an entries dictionary written by <see cref="WriteEntriesTo"/>, dropping RR sets
          /// that are already expired with respect to <paramref name="serveStale"/>.
          /// </summary>
          private static ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> ReadEntriesFrom(BinaryReader bR, bool serveStale)
          {
              int count = bR.ReadInt32();
              ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries = new ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>(1, count);

              for (int i = 0; i < count; i++)
              {
                  DnsResourceRecordType key = (DnsResourceRecordType)bR.ReadUInt16();
                  int rrCount = bR.ReadInt32();
                  DnsResourceRecord[] records = new DnsResourceRecord[rrCount];

                  for (int j = 0; j < rrCount; j++)
                  {
                      records[j] = DnsResourceRecord.ReadCacheRecordFrom(bR, delegate (DnsResourceRecord record)
                      {
                          //the per-record cache info is serialized right after each record
                          record.Tag = new CacheRecordInfo(bR);
                      });
                  }

                  if (!DnsResourceRecord.IsRRSetExpired(records, serveStale))
                      entries.TryAdd(key, records);
              }

              return entries;
          }

          /// <summary>
          /// Serializes an entries dictionary: the entry count, then for each entry its type, record count
          /// and records, each followed by its <see cref="CacheRecordInfo"/> (or the default info when the
          /// record has none).
          /// </summary>
          private static void WriteEntriesTo(ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries, BinaryWriter bW)
          {
              //serialize a point-in-time snapshot so the count written always matches the number of
              //entries that follow, even if the dictionary is modified concurrently
              KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>[] snapshot = entries.ToArray();

              bW.Write(snapshot.Length);

              foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in snapshot)
              {
                  bW.Write((ushort)entry.Key);
                  bW.Write(entry.Value.Count);

                  foreach (DnsResourceRecord record in entry.Value)
                  {
                      record.WriteCacheRecordTo(bW, delegate ()
                      {
                          if (record.Tag is not CacheRecordInfo rrInfo)
                              rrInfo = CacheRecordInfo.Default; //default info

                          rrInfo.WriteTo(bW);
                      });
                  }
              }
          }

          #endregion

          #region public

          /// <summary>
          /// Stores an RR set in the cache, replacing any existing RR set of the same type. The RR set goes
          /// into the ECS-scoped entries when its first record's cache info carries an EDNS Client Subnet,
          /// otherwise into the regular entries.
          /// </summary>
          /// <param name="type">The RR set type (dictionary key).</param>
          /// <param name="records">The RR set to store; an empty list stores nothing.</param>
          /// <param name="serveStale">
          /// Passed to the expiry check on the existing RR set. When true (and the new RR set is not a
          /// failure record), a stale CNAME entry is removed for types other than CNAME/SOA/NS/DS.
          /// </param>
          /// <returns>
          /// true if a new entry was added. false if an existing entry was replaced, or nothing was stored:
          /// an empty RR set, or a failure record that would overwrite a non-expired useful RR set.
          /// </returns>
          public bool SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, bool serveStale)
          {
              if (records.Count == 0)
                  return false;

              ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries;

              CacheRecordInfo cacheRecordInfo = records[0].GetCacheRecordInfo();
              NetworkAddress eDnsClientSubnet = cacheRecordInfo.EDnsClientSubnet;

              if (eDnsClientSubnet is null)
              {
                  entries = _entries;
              }
              else
              {
                  ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
                  if (ecsEntries is null)
                  {
                      //lazily create the ECS dictionary; if another thread published one first, use that instance
                      ecsEntries = new ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>(1, 5);
                      ecsEntries = Interlocked.CompareExchange(ref _ecsEntries, ecsEntries, null) ?? ecsEntries;
                  }

                  //GetOrAdd so a concurrent add for the same subnet is shared instead of dropping this RR set
                  entries = ecsEntries.GetOrAdd(eDnsClientSubnet, delegate (NetworkAddress key)
                  {
                      return new ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>(1, 1);
                  });
              }

              bool isFailureRecord = false;

              if (records[0].RDATA is DnsCache.DnsSpecialCacheRecordData splRecord)
              {
                  if (splRecord.IsFailureOrBadCache)
                  {
                      //call trying to cache failure record
                      isFailureRecord = true;

                      if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords) && (existingRecords.Count > 0) && !DnsResourceRecord.IsRRSetExpired(existingRecords, serveStale))
                      {
                          if ((existingRecords[0].RDATA is not DnsCache.DnsSpecialCacheRecordData existingSplRecord) || !existingSplRecord.IsFailureOrBadCache)
                              return false; //skip to avoid overwriting a useful record with a failure record

                          //copy extended errors from existing spl record
                          splRecord.CopyExtendedDnsErrorsFrom(existingSplRecord);
                      }
                  }
              }
              else if ((type == DnsResourceRecordType.NS) && (records[0].RDATA is DnsNSRecordData ns) && !ns.IsParentSideTtlSet)
              {
                  //for ns revalidation: carry the parent side TTL over from the existing NS RR set
                  if (entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> existingNSRecords))
                  {
                      if ((existingNSRecords.Count > 0) && (existingNSRecords[0].RDATA is DnsNSRecordData existingNS) && existingNS.IsParentSideTtlSet)
                      {
                          uint parentSideTtl = existingNS.ParentSideTtl;

                          foreach (DnsResourceRecord record in records)
                          {
                              if (record.RDATA is DnsNSRecordData nsRecord)
                                  nsRecord.ParentSideTtl = parentSideTtl;
                          }
                      }
                  }
              }

              //set last used date time
              DateTime utcNow = DateTime.UtcNow;

              foreach (DnsResourceRecord record in records)
                  record.GetCacheRecordInfo().LastUsedOn = utcNow;

              //set records
              bool added = true;

              entries.AddOrUpdate(type, records, delegate (DnsResourceRecordType key, IReadOnlyList<DnsResourceRecord> existingRecords)
              {
                  added = false;
                  return records;
              });

              if (serveStale && !isFailureRecord)
              {
                  //remove stale CNAME entry only when serve stale is enabled
                  //making sure current record is not a failure record causing removal of useful stale CNAME record
                  switch (type)
                  {
                      case DnsResourceRecordType.CNAME:
                      case DnsResourceRecordType.SOA:
                      case DnsResourceRecordType.NS:
                      case DnsResourceRecordType.DS:
                          //do nothing
                          break;

                      default:
                          //remove stale CNAME entry since current new entry type overlaps any existing CNAME entry in cache
                          //keeping both entries will create issue with serve stale implementation since stale CNAME entry will be always returned
                          if (entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
                          {
                              if ((existingCNAMERecords.Count > 0) && (existingCNAMERecords[0].RDATA is DnsCNAMERecordData) && existingCNAMERecords[0].IsStale)
                              {
                                  //delete CNAME entry only when it contains stale DnsCNAMERecord RDATA and not special cache records
                                  entries.TryRemove(DnsResourceRecordType.CNAME, out _);
                              }
                          }
                          break;
                  }
              }

              return added;
          }

          /// <summary>
          /// Removes expired RR sets from the ECS-scoped entries and then from the regular entries.
          /// ECS networks left without any RR set are removed as well.
          /// </summary>
          /// <returns>The number of RR sets removed.</returns>
          public int RemoveExpiredRecords(bool serveStale)
          {
              int removedEntries = 0;

              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is not null)
              {
                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  {
                      removedEntries += RemoveExpiredFrom(ecsEntry.Value, serveStale);

                      if (ecsEntry.Value.IsEmpty)
                          ecsEntries.TryRemove(ecsEntry.Key, out _);
                  }
              }

              removedEntries += RemoveExpiredFrom(_entries, serveStale);

              return removedEntries;
          }

          /// <summary>
          /// Removes RR sets that are empty or were last used before <paramref name="cutoff"/>, from the
          /// ECS-scoped entries and then from the regular entries. ECS networks left without any RR set
          /// are removed as well.
          /// </summary>
          /// <returns>The number of RR sets removed.</returns>
          public int RemoveLeastUsedRecords(DateTime cutoff)
          {
              int removedEntries = 0;

              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is not null)
              {
                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  {
                      removedEntries += RemoveLeastUsedFrom(ecsEntry.Value, cutoff);

                      if (ecsEntry.Value.IsEmpty)
                          ecsEntries.TryRemove(ecsEntry.Key, out _);
                  }
              }

              removedEntries += RemoveLeastUsedFrom(_entries, cutoff);

              return removedEntries;
          }

          /// <summary>
          /// Drops all ECS-scoped RR sets.
          /// </summary>
          /// <returns>The number of RR sets dropped.</returns>
          public int DeleteEDnsClientSubnetData()
          {
              //atomically detach the ECS dictionary so concurrent users see either the old instance or null
              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = Interlocked.Exchange(ref _ecsEntries, null);
              if (ecsEntries is null)
                  return 0;

              int count = 0;

              foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  count += ecsEntry.Value.Count;

              return count;
          }

          /// <summary>
          /// Looks up cached records for a query.
          /// </summary>
          /// <param name="type">The query type.</param>
          /// <param name="serveStale">Passed to the expiry check of the returned RR set.</param>
          /// <param name="skipSpecialCacheRecord">When true, RR sets holding special cache records are not returned (always true for ANY).</param>
          /// <param name="eDnsClientSubnet">When null, the regular entries are searched; otherwise the best matching ECS-scoped entries.</param>
          /// <param name="conditionalForwardingClientSubnet">When true, only an ECS network exactly equal to <paramref name="eDnsClientSubnet"/> matches.</param>
          /// <returns>
          /// The matching RR set, or an empty list. DS queries never follow a CNAME. SOA/DNSKEY queries
          /// prefer their own RR set over a CNAME. ANY returns all non-DS RR sets. Other types return a
          /// cached CNAME first.
          /// </returns>
          public IReadOnlyList<DnsResourceRecord> QueryRecords(DnsResourceRecordType type, bool serveStale, bool skipSpecialCacheRecord, NetworkAddress eDnsClientSubnet, bool conditionalForwardingClientSubnet)
          {
              ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries;

              if (eDnsClientSubnet is null)
              {
                  entries = _entries;
              }
              else
              {
                  entries = FindECSEntries(eDnsClientSubnet, conditionalForwardingClientSubnet);
                  if (entries is null)
                      return Array.Empty<DnsResourceRecord>();
              }

              switch (type)
              {
                  case DnsResourceRecordType.DS:
                      {
                          //since some zones have CNAME at apex so no CNAME lookup for DS queries!
                          if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                              return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);
                      }
                      break;

                  case DnsResourceRecordType.SOA:
                  case DnsResourceRecordType.DNSKEY:
                      {
                          //since some zones have CNAME at apex!
                          if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                              return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);

                          IReadOnlyList<DnsResourceRecord> cnameRRSet = GetCNAMERRSet(entries, type, serveStale, skipSpecialCacheRecord);
                          if (cnameRRSet is not null)
                              return cnameRRSet;
                      }
                      break;

                  case DnsResourceRecordType.ANY:
                      {
                          List<DnsResourceRecord> anyRecords = new List<DnsResourceRecord>(entries.Count * 2);

                          foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in entries)
                          {
                              if (entry.Key == DnsResourceRecordType.DS)
                                  continue;

                              anyRecords.AddRange(ValidateRRSet(type, entry.Value, serveStale, true));
                          }

                          return anyRecords;
                      }

                  default:
                      {
                          IReadOnlyList<DnsResourceRecord> cnameRRSet = GetCNAMERRSet(entries, type, serveStale, skipSpecialCacheRecord);
                          if (cnameRRSet is not null)
                              return cnameRRSet;

                          if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                              return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);
                      }
                      break;
              }

              return Array.Empty<DnsResourceRecord>();
          }

          /// <summary>
          /// Appends every cached record, ECS-scoped ones first and then the regular entries, to <paramref name="records"/>.
          /// No expiry filtering is applied here.
          /// </summary>
          public override void ListAllRecords(List<DnsResourceRecord> records)
          {
              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is not null)
              {
                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                  {
                      foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in ecsEntry.Value)
                          records.AddRange(entry.Value);
                  }
              }

              base.ListAllRecords(records);
          }

          /// <summary>
          /// Returns true when the regular (non-ECS) NS RR set contains at least one non-stale record
          /// whose RDATA is <see cref="DnsNSRecordData"/>.
          /// </summary>
          public override bool ContainsNameServerRecords()
          {
              if (!_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> records))
                  return false;

              foreach (DnsResourceRecord record in records)
              {
                  if (record.IsStale)
                      continue;

                  if (record.RDATA is DnsNSRecordData)
                      return true;
              }

              return false;
          }

          /// <summary>
          /// Serializes this cache zone in format version 1: the name, the regular entries, then the
          /// ECS network count followed by each network and its entries. <see cref="ReadFrom"/> reads it back.
          /// </summary>
          public void WriteTo(BinaryWriter bW)
          {
              bW.Write((byte)1); //version

              //cache zone info
              bW.Write(_name);

              //write all cache records
              WriteEntriesTo(_entries, bW);

              //write all ECS cache records
              ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
              if (ecsEntries is null)
              {
                  bW.Write(0);
              }
              else
              {
                  //snapshot so the network count written matches the networks that follow
                  KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>[] ecsSnapshot = ecsEntries.ToArray();

                  bW.Write(ecsSnapshot.Length);

                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsSnapshot)
                  {
                      ecsEntry.Key.WriteTo(bW);
                      WriteEntriesTo(ecsEntry.Value, bW);
                  }
              }
          }

          #endregion

          #region properties

          /// <summary>
          /// True when neither the regular entries nor the ECS dictionary hold any entry.
          /// </summary>
          public override bool IsEmpty
          {
              get
              {
                  ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
                  if (ecsEntries is null)
                      return _entries.IsEmpty;

                  return ecsEntries.IsEmpty && _entries.IsEmpty;
              }
          }

          /// <summary>
          /// Total number of RR sets held: regular entries plus the entries of every ECS network.
          /// </summary>
          public int TotalEntries
          {
              get
              {
                  int count = _entries.Count;

                  ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = _ecsEntries;
                  if (ecsEntries is null)
                      return count;

                  foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in ecsEntries)
                      count += ecsEntry.Value.Count;

                  return count;
              }
          }

          #endregion
      }
  }