/*
Technitium DNS Server
Copyright (C) 2024  Shreyas Zare (shreyas@technitium.com)

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

*/

using DnsServerCore.Dns.ResourceRecords;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using TechnitiumLibrary;
using TechnitiumLibrary.Net;
using TechnitiumLibrary.Net.Dns;
using TechnitiumLibrary.Net.Dns.ResourceRecords;

namespace DnsServerCore.Dns.Zones
{
    /// <summary>
    /// A zone that stores cached RR sets keyed by record type. RR sets whose first record
    /// carries an EDNS Client Subnet (ECS) are kept in a separate per-subnet map
    /// (<c>_ecsEntries</c>) instead of the base <c>_entries</c> map.
    /// </summary>
    class CacheZone : Zone
    {
        #region variables

        //per-subnet RR sets; stays null until the first ECS-scoped RR set is stored or loaded
        ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> _ecsEntries;

        #endregion

        #region constructor

        public CacheZone(string name, int capacity)
            : base(name, capacity)
        { }

        //used by ReadFrom() to wrap entries that were already deserialized
        private CacheZone(string name, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries)
            : base(name, entries)
        { }

        #endregion

        #region static

        /// <summary>
        /// Deserializes a cache zone that was written by <see cref="WriteTo(BinaryWriter)"/>.
        /// RR sets that are already expired (as per <paramref name="serveStale"/>) are dropped while reading.
        /// </summary>
        /// <exception cref="InvalidDataException">When the version byte is not supported.</exception>
        public static CacheZone ReadFrom(BinaryReader bR, bool serveStale)
        {
            byte version = bR.ReadByte();
            switch (version)
            {
                case 1:
                    string name = bR.ReadString();
                    ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries = ReadEntriesFrom(bR, serveStale);

                    CacheZone cacheZone = new CacheZone(name, entries);

                    //read all ECS cache records
                    {
                        int ecsCount = bR.ReadInt32();
                        if (ecsCount > 0)
                        {
                            ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntries = new ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>(1, ecsCount);

                            for (int i = 0; i < ecsCount; i++)
                            {
                                NetworkAddress key = NetworkAddress.ReadFrom(bR);
                                ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> ecsEntry = ReadEntriesFrom(bR, serveStale);

                                //subnets whose RR sets were all dropped as expired are not kept
                                if (!ecsEntry.IsEmpty)
                                    ecsEntries.TryAdd(key, ecsEntry);
                            }

                            if (!ecsEntries.IsEmpty)
                                cacheZone._ecsEntries = ecsEntries;
                        }
                    }

                    return cacheZone;

                default:
                    throw new InvalidDataException("CacheZone format version not supported.");
            }
        }

        #endregion

        #region private

        /// <summary>
        /// Checks a cached RR set before it is returned to a caller.
        /// Returns an empty list when any record is expired, or when
        /// <paramref name="skipSpecialCacheRecord"/> is set and any record is a special cache record.
        /// Otherwise marks every record as used now and returns the set; A/AAAA sets with more
        /// than one record are returned as a shuffled copy.
        /// </summary>
        private static IReadOnlyList<DnsResourceRecord> ValidateRRSet(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, bool serveStale, bool skipSpecialCacheRecord)
        {
            foreach (DnsResourceRecord record in records)
            {
                if (record.IsExpired(serveStale))
                    return Array.Empty<DnsResourceRecord>(); //RR Set is expired

                if (skipSpecialCacheRecord && (record.RDATA is DnsCache.DnsSpecialCacheRecordData))
                    return Array.Empty<DnsResourceRecord>(); //RR Set is special cache record
            }

            //update last used on
            //done before the shuffle branch below so that multi-record A/AAAA sets are also
            //marked as used; the shuffled copy references the same record objects
            DateTime utcNow = DateTime.UtcNow;

            foreach (DnsResourceRecord record in records)
                record.GetCacheRecordInfo().LastUsedOn = utcNow;

            if (records.Count > 1)
            {
                switch (type)
                {
                    case DnsResourceRecordType.A:
                    case DnsResourceRecordType.AAAA:
                        List<DnsResourceRecord> newRecords = new List<DnsResourceRecord>(records);
                        newRecords.Shuffle(); //shuffle records to allow load balancing
                        return newRecords;
                }
            }

            return records;
        }

        /// <summary>
        /// Reads a map of RR sets in the layout produced by <see cref="WriteEntriesTo"/>:
        /// entry count, then for each entry the record type (UInt16), the record count and the records.
        /// RR sets for which <c>DnsResourceRecord.IsRRSetExpired</c> returns true are read but not added.
        /// </summary>
        private static ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> ReadEntriesFrom(BinaryReader bR, bool serveStale)
        {
            int count = bR.ReadInt32();
            ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries = new ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>(1, count);

            for (int i = 0; i < count; i++)
            {
                DnsResourceRecordType key = (DnsResourceRecordType)bR.ReadUInt16();
                int rrCount = bR.ReadInt32();
                DnsResourceRecord[] records = new DnsResourceRecord[rrCount];

                for (int j = 0; j < rrCount; j++)
                {
                    records[j] = DnsResourceRecord.ReadCacheRecordFrom(bR, delegate (DnsResourceRecord record)
                    {
                        //the cache record info is read from the same stream and attached as the record's tag
                        record.Tag = new CacheRecordInfo(bR);
                    });
                }

                if (!DnsResourceRecord.IsRRSetExpired(records, serveStale))
                    entries.TryAdd(key, records);
            }

            return entries;
        }

        /// <summary>
        /// Writes a map of RR sets: entry count, then for each entry the record type (UInt16),
        /// the record count and each record followed by its cache record info.
        /// </summary>
        private static void WriteEntriesTo(ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries, BinaryWriter bW)
        {
            bW.Write(entries.Count);

            foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in entries)
            {
                bW.Write((ushort)entry.Key);
                bW.Write(entry.Value.Count);

                foreach (DnsResourceRecord record in entry.Value)
                {
                    record.WriteCacheRecordTo(bW, delegate ()
                    {
                        if (record.Tag is not CacheRecordInfo rrInfo)
                            rrInfo = CacheRecordInfo.Default; //default info

                        rrInfo.WriteTo(bW);
                    });
                }
            }
        }

        #endregion

        #region public

        /// <summary>
        /// Stores an RR set for the given type, replacing any existing RR set of that type.
        /// The target map is chosen from the EDNS Client Subnet of the first record's cache info.
        /// </summary>
        /// <returns>
        /// <c>true</c> when a new entry was added; <c>false</c> when an existing entry was replaced,
        /// when <paramref name="records"/> is empty, when the per-subnet map could not be added,
        /// or when a failure record was skipped to keep an existing useful RR set.
        /// </returns>
        public bool SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, bool serveStale)
        {
            if (records.Count == 0)
                return false;

            ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries;

            CacheRecordInfo cacheRecordInfo = records[0].GetCacheRecordInfo();
            NetworkAddress eDnsClientSubnet = cacheRecordInfo.EDnsClientSubnet;

            if (eDnsClientSubnet is null)
            {
                entries = _entries;
            }
            else
            {
                if (_ecsEntries is null)
                {
                    _ecsEntries = new ConcurrentDictionary<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>(1, 5);

                    entries = new ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>(1, 1);

                    if (!_ecsEntries.TryAdd(eDnsClientSubnet, entries))
                        return false;
                }
                else if (!_ecsEntries.TryGetValue(eDnsClientSubnet, out entries))
                {
                    entries = new ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>(1, 1);

                    if (!_ecsEntries.TryAdd(eDnsClientSubnet, entries))
                        return false;
                }
            }

            bool isFailureRecord = false;

            if (records[0].RDATA is DnsCache.DnsSpecialCacheRecordData splRecord)
            {
                if (splRecord.IsFailureOrBadCache)
                {
                    //call trying to cache failure record
                    isFailureRecord = true;

                    if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords) && (existingRecords.Count > 0) && !DnsResourceRecord.IsRRSetExpired(existingRecords, serveStale))
                    {
                        if ((existingRecords[0].RDATA is not DnsCache.DnsSpecialCacheRecordData existingSplRecord) || !existingSplRecord.IsFailureOrBadCache)
                            return false; //skip to avoid overwriting a useful record with a failure record

                        //copy extended errors from existing spl record
                        splRecord.CopyExtendedDnsErrorsFrom(existingSplRecord);
                    }
                }
            }
            else if ((type == DnsResourceRecordType.NS) && (records[0].RDATA is DnsNSRecordData ns) && !ns.IsParentSideTtlSet)
            {
                //for ns revalidation
                if (entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> existingNSRecords))
                {
                    if ((existingNSRecords.Count > 0) && (existingNSRecords[0].RDATA is DnsNSRecordData existingNS) && existingNS.IsParentSideTtlSet)
                    {
                        uint parentSideTtl = existingNS.ParentSideTtl;

                        foreach (DnsResourceRecord record in records)
                        {
                            //only the first record was verified to hold NS RDATA; skip anything else instead of dereferencing a failed cast
                            if (record.RDATA is DnsNSRecordData nsData)
                                nsData.ParentSideTtl = parentSideTtl;
                        }
                    }
                }
            }

            //set last used date time
            DateTime utcNow = DateTime.UtcNow;

            foreach (DnsResourceRecord record in records)
                record.GetCacheRecordInfo().LastUsedOn = utcNow;

            //set records
            bool added = true;

            entries.AddOrUpdate(type, records, delegate (DnsResourceRecordType key, IReadOnlyList<DnsResourceRecord> existingRecords)
            {
                //update callback runs only when an entry for the type already exists
                added = false;
                return records;
            });

            if (serveStale && !isFailureRecord)
            {
                //remove stale CNAME entry only when serve stale is enabled
                //making sure current record is not a failure record causing removal of useful stale CNAME record

                switch (type)
                {
                    case DnsResourceRecordType.CNAME:
                    case DnsResourceRecordType.SOA:
                    case DnsResourceRecordType.NS:
                    case DnsResourceRecordType.DS:
                        //do nothing
                        break;

                    default:
                        //remove stale CNAME entry since current new entry type overlaps any existing CNAME entry in cache
                        //keeping both entries will create issue with serve stale implementation since stale CNAME entry will be always returned

                        if (entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
                        {
                            if ((existingCNAMERecords.Count > 0) && (existingCNAMERecords[0].RDATA is DnsCNAMERecordData) && existingCNAMERecords[0].IsStale)
                            {
                                //delete CNAME entry only when it contains stale DnsCNAMERecord RDATA and not special cache records
                                entries.TryRemove(DnsResourceRecordType.CNAME, out _);
                            }
                        }
                        break;
                }
            }

            return added;
        }

        /// <summary>
        /// Removes every RR set for which <c>DnsResourceRecord.IsRRSetExpired</c> returns true,
        /// from both the per-subnet maps and the main map. Per-subnet maps left empty are removed too.
        /// </summary>
        /// <returns>The number of RR sets removed.</returns>
        public int RemoveExpiredRecords(bool serveStale)
        {
            int removedEntries = 0;

            if (_ecsEntries is not null)
            {
                foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in _ecsEntries)
                {
                    foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in ecsEntry.Value)
                    {
                        if (DnsResourceRecord.IsRRSetExpired(entry.Value, serveStale))
                        {
                            if (ecsEntry.Value.TryRemove(entry.Key, out _)) //RR Set is expired; remove entry
                                removedEntries++;
                        }
                    }

                    if (ecsEntry.Value.IsEmpty)
                        _ecsEntries.TryRemove(ecsEntry.Key, out _);
                }
            }

            foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
            {
                if (DnsResourceRecord.IsRRSetExpired(entry.Value, serveStale))
                {
                    if (_entries.TryRemove(entry.Key, out _)) //RR Set is expired; remove entry
                        removedEntries++;
                }
            }

            return removedEntries;
        }

        /// <summary>
        /// Removes every RR set that is empty or whose first record was last used before
        /// <paramref name="cutoff"/>, from both the per-subnet maps and the main map.
        /// Per-subnet maps left empty are removed too.
        /// </summary>
        /// <returns>The number of RR sets removed.</returns>
        public int RemoveLeastUsedRecords(DateTime cutoff)
        {
            int removedEntries = 0;

            if (_ecsEntries is not null)
            {
                foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in _ecsEntries)
                {
                    foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in ecsEntry.Value)
                    {
                        if ((entry.Value.Count == 0) || (entry.Value[0].GetCacheRecordInfo().LastUsedOn < cutoff))
                        {
                            if (ecsEntry.Value.TryRemove(entry.Key, out _)) //RR Set was last used before cutoff; remove entry
                                removedEntries++;
                        }
                    }

                    if (ecsEntry.Value.IsEmpty)
                        _ecsEntries.TryRemove(ecsEntry.Key, out _);
                }
            }

            foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
            {
                if ((entry.Value.Count == 0) || (entry.Value[0].GetCacheRecordInfo().LastUsedOn < cutoff))
                {
                    if (_entries.TryRemove(entry.Key, out _)) //RR Set was last used before cutoff; remove entry
                        removedEntries++;
                }
            }

            return removedEntries;
        }

        /// <summary>
        /// Discards all per-subnet (ECS) RR sets by resetting the ECS map to null.
        /// </summary>
        /// <returns>The number of RR sets that were held in the ECS map.</returns>
        public int DeleteEDnsClientSubnetData()
        {
            if (_ecsEntries is null)
                return 0;

            int count = 0;

            foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in _ecsEntries)
                count += ecsEntry.Value.Count;

            _ecsEntries = null;

            return count;
        }

        /// <summary>
        /// Looks up the cached RR set for <paramref name="type"/>.
        /// When <paramref name="eDnsClientSubnet"/> is null the main map is searched; otherwise a
        /// per-subnet map is selected first and an empty list is returned when none matches.
        /// Cached subnets with a longer prefix than the client subnet are never selected.
        /// With <paramref name="advancedForwardingClientSubnet"/> only a cached subnet equal to the
        /// client subnet matches; without it, a cached subnet that equals or contains the client
        /// address matches and the one with the shortest prefix is chosen.
        /// </summary>
        /// <returns>The validated RR set, or an empty list when nothing usable is cached.</returns>
        public IReadOnlyList<DnsResourceRecord> QueryRecords(DnsResourceRecordType type, bool serveStale, bool skipSpecialCacheRecord, NetworkAddress eDnsClientSubnet, bool advancedForwardingClientSubnet)
        {
            ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entries;

            if (eDnsClientSubnet is null)
            {
                entries = _entries;
            }
            else
            {
                if (_ecsEntries is null)
                    return Array.Empty<DnsResourceRecord>();

                NetworkAddress selectedNetwork = null;
                entries = null;

                foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in _ecsEntries)
                {
                    NetworkAddress cacheSubnet = ecsEntry.Key;

                    if (cacheSubnet.PrefixLength > eDnsClientSubnet.PrefixLength)
                        continue;

                    if (advancedForwardingClientSubnet)
                    {
                        if (cacheSubnet.Equals(eDnsClientSubnet))
                        {
                            if ((selectedNetwork is null) || (cacheSubnet.PrefixLength > selectedNetwork.PrefixLength))
                            {
                                selectedNetwork = cacheSubnet;
                                entries = ecsEntry.Value;
                            }
                        }
                    }
                    else
                    {
                        if (cacheSubnet.Equals(eDnsClientSubnet) || cacheSubnet.Contains(eDnsClientSubnet.Address))
                        {
                            if ((selectedNetwork is null) || (cacheSubnet.PrefixLength < selectedNetwork.PrefixLength))
                            {
                                selectedNetwork = cacheSubnet;
                                entries = ecsEntry.Value;
                            }
                        }
                    }
                }

                if (entries is null)
                    return Array.Empty<DnsResourceRecord>();
            }

            switch (type)
            {
                case DnsResourceRecordType.DS:
                    {
                        //since some zones have CNAME at apex so no CNAME lookup for DS queries!
                        if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                            return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);
                    }
                    break;

                case DnsResourceRecordType.SOA:
                case DnsResourceRecordType.DNSKEY:
                    {
                        //since some zones have CNAME at apex!
                        //the requested type is tried first and the CNAME entry is only a fallback
                        if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                            return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);

                        if (entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
                        {
                            IReadOnlyList<DnsResourceRecord> rrset = ValidateRRSet(type, existingCNAMERecords, serveStale, skipSpecialCacheRecord);
                            if (rrset.Count > 0)
                            {
                                if ((type == DnsResourceRecordType.CNAME) || (rrset[0].RDATA is DnsCNAMERecordData))
                                    return rrset;
                            }
                        }
                    }
                    break;

                case DnsResourceRecordType.ANY:
                    List<DnsResourceRecord> anyRecords = new List<DnsResourceRecord>(entries.Count * 2);

                    foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in entries)
                    {
                        if (entry.Key == DnsResourceRecordType.DS)
                            continue;

                        //special cache records are always skipped for ANY
                        anyRecords.AddRange(ValidateRRSet(type, entry.Value, serveStale, true));
                    }

                    return anyRecords;

                default:
                    {
                        //a usable CNAME entry takes precedence over the requested type
                        if (entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
                        {
                            IReadOnlyList<DnsResourceRecord> rrset = ValidateRRSet(type, existingCNAMERecords, serveStale, skipSpecialCacheRecord);
                            if (rrset.Count > 0)
                            {
                                if ((type == DnsResourceRecordType.CNAME) || (rrset[0].RDATA is DnsCNAMERecordData))
                                    return rrset;
                            }
                        }

                        if (entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                            return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord);
                    }
                    break;
            }

            return Array.Empty<DnsResourceRecord>();
        }

        /// <summary>
        /// Appends all per-subnet (ECS) records, followed by whatever the base implementation lists, to <paramref name="records"/>.
        /// </summary>
        public override void ListAllRecords(List<DnsResourceRecord> records)
        {
            if (_ecsEntries is not null)
            {
                foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in _ecsEntries)
                {
                    foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in ecsEntry.Value)
                        records.AddRange(entry.Value);
                }
            }

            base.ListAllRecords(records);
        }

        /// <summary>
        /// Returns true when the main map holds at least one NS record that is not stale and carries NS RDATA.
        /// The per-subnet (ECS) maps are not consulted.
        /// </summary>
        public override bool ContainsNameServerRecords()
        {
            if (!_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> records))
                return false;

            foreach (DnsResourceRecord record in records)
            {
                if (record.IsStale)
                    continue;

                if (record.RDATA is DnsNSRecordData)
                    return true;
            }

            return false;
        }

        /// <summary>
        /// Serializes the zone: version byte, zone name, the main RR sets, then the ECS subnet
        /// count (0 when there is no ECS map) with each subnet followed by its RR sets.
        /// </summary>
        public void WriteTo(BinaryWriter bW)
        {
            bW.Write((byte)1); //version

            //cache zone info
            bW.Write(_name);

            //write all cache records
            WriteEntriesTo(_entries, bW);

            //write all ECS cache records
            if (_ecsEntries is null)
            {
                bW.Write(0);
            }
            else
            {
                bW.Write(_ecsEntries.Count);

                foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in _ecsEntries)
                {
                    ecsEntry.Key.WriteTo(bW);
                    WriteEntriesTo(ecsEntry.Value, bW);
                }
            }
        }

        #endregion

        #region properties

        /// <summary>
        /// True when the main map is empty and the ECS map is either null or empty.
        /// </summary>
        public override bool IsEmpty
        {
            get
            {
                if (_ecsEntries is null)
                    return _entries.IsEmpty;

                return _ecsEntries.IsEmpty && _entries.IsEmpty;
            }
        }

        /// <summary>
        /// Number of RR sets in the main map plus the RR sets in every per-subnet (ECS) map.
        /// </summary>
        public int TotalEntries
        {
            get
            {
                if (_ecsEntries is null)
                    return _entries.Count;

                int count = _entries.Count;

                foreach (KeyValuePair<NetworkAddress, ConcurrentDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>> ecsEntry in _ecsEntries)
                    count += ecsEntry.Value.Count;

                return count;
            }
        }

        #endregion
    }
}