/*
Technitium DNS Server
Copyright (C) 2020  Shreyas Zare (shreyas@technitium.com)

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

*/

using System;
using System.Collections.Generic;
using TechnitiumLibrary.IO;
using TechnitiumLibrary.Net.Dns;

namespace DnsServerCore.Dns.Zones
{
    class CacheZone : Zone
    {
        #region constructor

        public CacheZone(string name)
            : base(name)
        { }

        #endregion

        #region private

        /// <summary>
        /// Returns the subset of <paramref name="records"/> that is still usable.
        /// A record is dropped when its TTL has run out (TtlValue &lt; 1).
        /// Unless <paramref name="serveStale"/> is set, stale records are dropped as well.
        /// </summary>
        /// <remarks>
        /// For a single-record input the original list instance (or an empty array) is returned.
        /// For multiple records a new list is built. A and AAAA results with more than one
        /// record are shuffled to spread load across addresses.
        /// </remarks>
        private static IReadOnlyList<DnsResourceRecord> FilterExpiredRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, bool serveStale)
        {
            if (records.Count == 1)
            {
                if (!serveStale && records[0].IsStale)
                    return Array.Empty<DnsResourceRecord>(); //record is stale

                if (records[0].TtlValue < 1u)
                    return Array.Empty<DnsResourceRecord>(); //ttl expired

                return records;
            }

            List<DnsResourceRecord> newRecords = new List<DnsResourceRecord>(records.Count);

            foreach (DnsResourceRecord record in records)
            {
                if (!serveStale && record.IsStale)
                    continue; //record is stale

                if (record.TtlValue < 1u)
                    continue; //ttl expired

                newRecords.Add(record);
            }

            if (newRecords.Count > 1)
            {
                switch (type)
                {
                    case DnsResourceRecordType.A:
                    case DnsResourceRecordType.AAAA:
                        newRecords.Shuffle(); //shuffle records to allow load balancing
                        break;
                }
            }

            return newRecords;
        }

        #endregion

        #region public

        /// <summary>
        /// Stores <paramref name="records"/> as the cache entry for <paramref name="type"/>.
        /// </summary>
        /// <remarks>
        /// A failure record set (first record's RDATA is a DnsCache.DnsFailureRecord) is not
        /// allowed to overwrite an existing non-failure entry. The existing entry is kept so
        /// serve-stale can still answer from it.
        /// Setting any type other than CNAME, SOA or NS removes an existing CNAME entry,
        /// because a stale CNAME would otherwise always win in <see cref="QueryRecords"/>.
        /// </remarks>
        public void SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
        {
            if ((records.Count > 0) && (records[0].RDATA is DnsCache.DnsFailureRecord))
            {
                //call trying to cache failure record
                if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                {
                    if ((existingRecords.Count > 0) && !(existingRecords[0].RDATA is DnsCache.DnsFailureRecord))
                        return; //skip to avoid overwriting a useful stale record with a failure record to allow serve-stale to work as intended
                }
            }

            //set records
            _entries[type] = records;

            switch (type)
            {
                case DnsResourceRecordType.CNAME:
                case DnsResourceRecordType.SOA:
                case DnsResourceRecordType.NS:
                    //do nothing
                    break;

                default:
                    //remove old CNAME entry since current new entry type overlaps any existing CNAME entry in cache
                    //keeping both entries will create issue with serve stale implementation since stale CNAME entry will be always returned
                    _entries.TryRemove(DnsResourceRecordType.CNAME, out _);
                    break;
            }
        }

        /// <summary>
        /// Drops records whose TTL has run out (TtlValue &lt; 1) from every entry.
        /// An entry whose records have all expired is removed.
        /// </summary>
        /// <remarks>
        /// Entries are read from a live enumeration rather than a key snapshot plus indexer.
        /// That way an entry removed by another caller (e.g. the CNAME removal in
        /// <see cref="SetRecords"/>) cannot cause a KeyNotFoundException.
        /// Both the update and the removal are conditional on the entry still holding the list
        /// that was inspected. Records stored concurrently by <see cref="SetRecords"/> are
        /// therefore never overwritten or deleted here.
        /// </remarks>
        public void RemoveExpiredRecords()
        {
            foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
            {
                IReadOnlyList<DnsResourceRecord> records = entry.Value;

                bool hasExpiredRecord = false;

                foreach (DnsResourceRecord record in records)
                {
                    if (record.TtlValue < 1u)
                    {
                        hasExpiredRecord = true;
                        break;
                    }
                }

                if (!hasExpiredRecord)
                    continue; //entry is fine as it is

                //record is expired; build entry with non-expired records
                List<DnsResourceRecord> newRecords = new List<DnsResourceRecord>(records.Count);

                foreach (DnsResourceRecord existingRecord in records)
                {
                    if (existingRecord.TtlValue < 1u)
                        continue; //ttl expired

                    newRecords.Add(existingRecord);
                }

                if (newRecords.Count > 0)
                {
                    //try update entry with non-expired records; no-op if entry was replaced meanwhile
                    _entries.TryUpdate(entry.Key, newRecords, records);
                }
                else
                {
                    //all records expired; remove entry only if it still holds the same records so that fresh records set meanwhile are not lost
                    ((ICollection<KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>>>)_entries).Remove(entry);
                }
            }
        }

        /// <summary>
        /// Returns the usable records for <paramref name="type"/>, filtered through
        /// FilterExpiredRecords.
        /// </summary>
        /// <remarks>
        /// A CNAME entry that still has usable records takes precedence over the requested type.
        /// Returns an empty list when nothing usable is cached.
        /// </remarks>
        public IReadOnlyList<DnsResourceRecord> QueryRecords(DnsResourceRecordType type, bool serveStale)
        {
            //check for CNAME
            if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
            {
                IReadOnlyList<DnsResourceRecord> filteredRecords = FilterExpiredRecords(type, existingCNAMERecords, serveStale);
                if (filteredRecords.Count > 0)
                    return filteredRecords;
            }

            if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                return FilterExpiredRecords(type, existingRecords, serveStale);

            return Array.Empty<DnsResourceRecord>();
        }

        /// <summary>
        /// Returns true when the NS entry contains at least one record that is neither stale
        /// nor TTL-expired.
        /// </summary>
        public override bool ContainsNameServerRecords()
        {
            if (!_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> records))
                return false;

            foreach (DnsResourceRecord record in records)
            {
                if (record.IsStale)
                    continue;

                if (record.TtlValue < 1u)
                    continue;

                return true;
            }

            return false;
        }

        #endregion
    }
}