123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441 |
- /*
- Technitium DNS Server
- Copyright (C) 2020 Shreyas Zare (shreyas@technitium.com)
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU General Public License for more details.
- You should have received a copy of the GNU General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
- */
- using DnsServerCore.Dns.ResourceRecords;
- using System;
- using System.Collections.Generic;
- using System.Net;
- using System.Threading.Tasks;
- using TechnitiumLibrary.IO;
- using TechnitiumLibrary.Net.Dns;
- using TechnitiumLibrary.Net.Dns.ResourceRecords;
namespace DnsServerCore.Dns.Zones
{
    /// <summary>
    /// Base class for zones whose records are held locally in the inherited <c>_entries</c> map,
    /// keyed by <see cref="DnsResourceRecordType"/>. <see cref="SyncRecords"/> special-cases the
    /// <c>ForwarderZone</c>, <c>SecondaryZone</c> and <c>StubZone</c> types.
    /// </summary>
    abstract class AuthZone : Zone, IDisposable
    {
        #region variables

        //when set, every QueryRecords() call returns no records for this zone
        protected bool _disabled;

        #endregion

        #region constructor

        /// <summary>
        /// Creates a zone for the given domain name.
        /// </summary>
        /// <param name="name">Zone domain name, passed through to the <see cref="Zone"/> base class.</param>
        protected AuthZone(string name)
            : base(name)
        { }

        #endregion

        #region IDisposable

        /// <summary>
        /// Dispose hook for subclasses. This base implementation holds nothing and does nothing.
        /// </summary>
        /// <param name="disposing"><c>true</c> when invoked from <see cref="Dispose()"/>.</param>
        protected virtual void Dispose(bool disposing)
        { }

        /// <summary>
        /// Releases resources held by the zone through <see cref="Dispose(bool)"/>.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this); //standard dispose pattern (CA1816): a derived finalizer need not run after explicit disposal
        }

        #endregion

        #region private

        /// <summary>
        /// Returns the enabled subset of <paramref name="records"/>.
        /// </summary>
        /// <param name="type">Record type being answered; decides whether the result is shuffled.</param>
        /// <param name="records">Candidate records.</param>
        /// <returns>
        /// An empty list when the zone is disabled; otherwise the records that are not individually
        /// disabled. When more than one record remains and <paramref name="type"/> is A, AAAA or NS,
        /// the result is shuffled so repeated answers spread load across the records.
        /// </returns>
        private IReadOnlyList<DnsResourceRecord> FilterDisabledRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
        {
            if (_disabled)
                return Array.Empty<DnsResourceRecord>();

            if (records.Count == 1)
            {
                if (records[0].IsDisabled())
                    return Array.Empty<DnsResourceRecord>(); //record disabled

                return records; //single enabled record: return the input list as-is, no copy needed
            }

            List<DnsResourceRecord> newRecords = new List<DnsResourceRecord>(records.Count);

            foreach (DnsResourceRecord record in records)
            {
                if (record.IsDisabled())
                    continue; //record disabled

                newRecords.Add(record);
            }

            if (newRecords.Count > 1)
            {
                switch (type)
                {
                    case DnsResourceRecordType.A:
                    case DnsResourceRecordType.AAAA:
                    case DnsResourceRecordType.NS:
                        newRecords.Shuffle(); //shuffle records to allow load balancing
                        break;
                }
            }

            return newRecords;
        }

        /// <summary>
        /// Resolves the addresses of the name server named by an NS record or by an SOA record's primary name server.
        /// </summary>
        /// <param name="dnsServer">Server used for resolution and for the <c>PreferIPv6</c> setting.</param>
        /// <param name="record">An NS or SOA record.</param>
        /// <returns>
        /// Addresses taken from the record's glue records when it has any; otherwise addresses
        /// resolved with A (and, when <c>PreferIPv6</c> is set, AAAA) queries. IPv6 addresses are
        /// only included when <c>PreferIPv6</c> is set. May be empty.
        /// </returns>
        /// <exception cref="InvalidOperationException">The record is neither NS nor SOA.</exception>
        private async Task<IReadOnlyList<NameServerAddress>> GetNameServerAddressesAsync(DnsServer dnsServer, DnsResourceRecord record)
        {
            string nsDomain;

            switch (record.Type)
            {
                case DnsResourceRecordType.NS:
                    nsDomain = (record.RDATA as DnsNSRecord).NameServer;
                    break;

                case DnsResourceRecordType.SOA:
                    nsDomain = (record.RDATA as DnsSOARecord).PrimaryNameServer;
                    break;

                default:
                    throw new InvalidOperationException();
            }

            List<NameServerAddress> nameServers = new List<NameServerAddress>(2);

            IReadOnlyList<DnsResourceRecord> glueRecords = record.GetGlueRecords();
            if (glueRecords.Count > 0)
            {
                foreach (DnsResourceRecord glueRecord in glueRecords)
                {
                    switch (glueRecord.Type)
                    {
                        case DnsResourceRecordType.A:
                            nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsARecord).Address));
                            break;

                        case DnsResourceRecordType.AAAA:
                            if (dnsServer.PreferIPv6)
                                nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsAAAARecord).Address));

                            break;
                    }
                }
            }
            else
            {
                //no glue: resolve addresses
                try
                {
                    DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.A, DnsClass.IN));
                    if ((response != null) && (response.Answer.Count > 0))
                    {
                        IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseA(response);
                        foreach (IPAddress address in addresses)
                            nameServers.Add(new NameServerAddress(nsDomain, address));
                    }
                }
                catch
                {
                    //best effort: a failed lookup simply contributes no addresses
                }

                if (dnsServer.PreferIPv6)
                {
                    try
                    {
                        DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.AAAA, DnsClass.IN));
                        if ((response != null) && (response.Answer.Count > 0))
                        {
                            IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseAAAA(response);
                            foreach (IPAddress address in addresses)
                                nameServers.Add(new NameServerAddress(nsDomain, address));
                        }
                    }
                    catch
                    {
                        //best effort: a failed lookup simply contributes no addresses
                    }
                }
            }

            return nameServers;
        }

        #endregion

        #region public

        /// <summary>
        /// Returns the addresses of the zone's primary name server, as named by the SOA record.
        /// </summary>
        /// <remarks>
        /// Addresses from the first enabled NS record whose name matches the SOA primary name server
        /// (case-insensitive) come first. Any addresses derived from the SOA record itself that are
        /// not already present follow. The zone is assumed to contain an SOA record.
        /// </remarks>
        /// <param name="dnsServer">Server used to resolve name server addresses.</param>
        public async Task<IReadOnlyList<NameServerAddress>> GetPrimaryNameServerAddressesAsync(DnsServer dnsServer)
        {
            List<NameServerAddress> nameServers = new List<NameServerAddress>();

            DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0];
            DnsSOARecord soa = soaRecord.RDATA as DnsSOARecord;
            IReadOnlyList<DnsResourceRecord> nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so QueryRecords() cannot be used

            foreach (DnsResourceRecord nsRecord in nsRecords)
            {
                if (nsRecord.IsDisabled())
                    continue;

                string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer;

                if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase))
                {
                    //found primary NS
                    nameServers.AddRange(await GetNameServerAddressesAsync(dnsServer, nsRecord));
                    break;
                }
            }

            foreach (NameServerAddress nameServer in await GetNameServerAddressesAsync(dnsServer, soaRecord))
            {
                if (!nameServers.Contains(nameServer))
                    nameServers.Add(nameServer);
            }

            return nameServers;
        }

        /// <summary>
        /// Returns the addresses of every enabled NS record other than the SOA primary name server.
        /// </summary>
        /// <param name="dnsServer">Server used to resolve name server addresses.</param>
        /// <remarks>The zone is assumed to contain an SOA record.</remarks>
        public async Task<IReadOnlyList<NameServerAddress>> GetSecondaryNameServerAddressesAsync(DnsServer dnsServer)
        {
            List<NameServerAddress> nameServers = new List<NameServerAddress>();

            DnsSOARecord soa = _entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord;
            IReadOnlyList<DnsResourceRecord> nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so QueryRecords() cannot be used

            foreach (DnsResourceRecord nsRecord in nsRecords)
            {
                if (nsRecord.IsDisabled())
                    continue;

                string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer;

                if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase))
                    continue; //skip primary name server

                nameServers.AddRange(await GetNameServerAddressesAsync(dnsServer, nsRecord));
            }

            return nameServers;
        }

        /// <summary>
        /// Replaces the zone's record sets with those in <paramref name="newEntries"/>.
        /// </summary>
        /// <param name="newEntries">New record sets per type. The lists are stored as-is, not copied.</param>
        /// <param name="dontRemoveRecords">
        /// When <c>false</c>, record types missing from <paramref name="newEntries"/> are deleted first.
        /// </param>
        /// <remarks>
        /// A ForwarderZone never receives NS or SOA sets. For other zones, an SOA set that does not
        /// hold exactly one record is ignored. For SecondaryZone and StubZone, the glue records of the
        /// existing SOA record, if there is one, are carried over to the new SOA record.
        /// </remarks>
        public void SyncRecords(Dictionary<DnsResourceRecordType, List<DnsResourceRecord>> newEntries, bool dontRemoveRecords)
        {
            if (!dontRemoveRecords)
            {
                //remove entries of types that do not exist in new entries
                foreach (DnsResourceRecordType type in _entries.Keys)
                {
                    if (!newEntries.ContainsKey(type))
                        _entries.TryRemove(type, out _);
                }
            }

            //set new entries into zone
            if (this is ForwarderZone)
            {
                //skip NS and SOA records from being added to ForwarderZone
                foreach (KeyValuePair<DnsResourceRecordType, List<DnsResourceRecord>> newEntry in newEntries)
                {
                    switch (newEntry.Key)
                    {
                        case DnsResourceRecordType.NS:
                        case DnsResourceRecordType.SOA:
                            break;

                        default:
                            _entries[newEntry.Key] = newEntry.Value;
                            break;
                    }
                }
            }
            else
            {
                foreach (KeyValuePair<DnsResourceRecordType, List<DnsResourceRecord>> newEntry in newEntries)
                {
                    if (newEntry.Key == DnsResourceRecordType.SOA)
                    {
                        if (newEntry.Value.Count != 1)
                            continue; //skip invalid SOA record

                        //copy existing SOA record's glue addresses to new SOA record; nothing to copy when the zone has no SOA yet
                        if (((this is SecondaryZone) || (this is StubZone)) && _entries.TryGetValue(DnsResourceRecordType.SOA, out IReadOnlyList<DnsResourceRecord> existingSOARecords) && (existingSOARecords.Count > 0))
                            newEntry.Value[0].SetGlueRecords(existingSOARecords[0].GetGlueRecords());
                    }

                    _entries[newEntry.Key] = newEntry.Value;
                }
            }
        }

        /// <summary>
        /// Stores <paramref name="records"/> as the record set for <paramref name="type"/>, replacing any existing set.
        /// </summary>
        public void LoadRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
        {
            _entries[type] = records;
        }

        /// <summary>
        /// Overridable counterpart of <see cref="LoadRecords"/>: stores <paramref name="records"/> as the
        /// record set for <paramref name="type"/>, replacing any existing set.
        /// </summary>
        public virtual void SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
        {
            _entries[type] = records;
        }

        /// <summary>
        /// Adds <paramref name="record"/> to the record set of its type, unless that set already
        /// holds a record with equal RDATA.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The record is CNAME, ANAME, PTR or SOA; those must be set with <see cref="SetRecords"/>.
        /// </exception>
        public virtual void AddRecord(DnsResourceRecord record)
        {
            switch (record.Type)
            {
                case DnsResourceRecordType.CNAME:
                case DnsResourceRecordType.ANAME:
                case DnsResourceRecordType.PTR:
                case DnsResourceRecordType.SOA:
                    throw new InvalidOperationException("Cannot add record: use SetRecords() for " + record.Type.ToString() + " record");
            }

            _entries.AddOrUpdate(record.Type,
                key => new DnsResourceRecord[] { record },
                (key, existingRecords) =>
                {
                    foreach (DnsResourceRecord existingRecord in existingRecords)
                    {
                        //compare RDATA with RDATA (as DeleteRecord does); comparing the record itself to RDATA never matched
                        if (record.RDATA.Equals(existingRecord.RDATA))
                            return existingRecords;
                    }

                    List<DnsResourceRecord> updateRecords = new List<DnsResourceRecord>(existingRecords.Count + 1);

                    updateRecords.AddRange(existingRecords);
                    updateRecords.Add(record);

                    return updateRecords;
                });
        }

        /// <summary>
        /// Deletes the whole record set for <paramref name="type"/>.
        /// </summary>
        /// <returns><c>true</c> if a set was removed.</returns>
        public virtual bool DeleteRecords(DnsResourceRecordType type)
        {
            return _entries.TryRemove(type, out _);
        }

        /// <summary>
        /// Deletes every record of <paramref name="type"/> whose RDATA equals <paramref name="record"/>.
        /// </summary>
        /// <returns>
        /// <c>true</c> only when at least one record was removed and the zone was updated. Returns
        /// <c>false</c> when nothing matched or the set changed concurrently.
        /// </returns>
        public virtual bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData record)
        {
            if (!_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                return false;

            if (existingRecords.Count == 1)
            {
                if (record.Equals(existingRecords[0].RDATA))
                    return _entries.TryRemove(type, out _);

                return false;
            }

            List<DnsResourceRecord> updateRecords = new List<DnsResourceRecord>(existingRecords.Count);

            foreach (DnsResourceRecord existingRecord in existingRecords)
            {
                if (!record.Equals(existingRecord.RDATA))
                    updateRecords.Add(existingRecord);
            }

            if (updateRecords.Count == existingRecords.Count)
                return false; //no record matched; nothing was deleted

            if (updateRecords.Count == 0)
                return _entries.TryRemove(type, out _); //every record matched; do not leave an empty set behind

            return _entries.TryUpdate(type, updateRecords, existingRecords);
        }

        /// <summary>
        /// Answers a query of <paramref name="type"/> from this zone's enabled records.
        /// </summary>
        /// <remarks>
        /// Lookup order:
        /// 1. Enabled CNAME records, for any query type.
        /// 2. For ANY, every stored record set.
        /// 3. The set of the requested type.
        /// 4. For A/AAAA only, ANAME records as a fallback.
        /// Every path is filtered through FilterDisabledRecords, so a disabled zone yields nothing.
        /// </remarks>
        public virtual IReadOnlyList<DnsResourceRecord> QueryRecords(DnsResourceRecordType type)
        {
            //check for CNAME
            if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
            {
                IReadOnlyList<DnsResourceRecord> filteredRecords = FilterDisabledRecords(type, existingCNAMERecords);
                if (filteredRecords.Count > 0)
                    return filteredRecords;
            }

            if (type == DnsResourceRecordType.ANY)
            {
                List<DnsResourceRecord> records = new List<DnsResourceRecord>(_entries.Count * 2);

                foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
                {
                    if (entry.Key != DnsResourceRecordType.ANY)
                        records.AddRange(entry.Value);
                }

                return FilterDisabledRecords(type, records);
            }

            if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
            {
                IReadOnlyList<DnsResourceRecord> filteredRecords = FilterDisabledRecords(type, existingRecords);
                if (filteredRecords.Count > 0)
                    return filteredRecords;
            }

            switch (type)
            {
                case DnsResourceRecordType.A:
                case DnsResourceRecordType.AAAA:
                    if (_entries.TryGetValue(DnsResourceRecordType.ANAME, out IReadOnlyList<DnsResourceRecord> anameRecords))
                        return FilterDisabledRecords(type, anameRecords);

                    break;
            }

            return Array.Empty<DnsResourceRecord>();
        }

        /// <summary>
        /// Returns the stored record set for <paramref name="type"/> without any filtering: disabled
        /// records and a disabled zone are not taken into account.
        /// </summary>
        /// <returns>The stored set, or an empty list when the type has none.</returns>
        public IReadOnlyList<DnsResourceRecord> GetRecords(DnsResourceRecordType type)
        {
            if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> records))
                return records;

            return Array.Empty<DnsResourceRecord>();
        }

        /// <summary>
        /// Returns <c>true</c> when the zone holds at least one NS record that is not disabled.
        /// </summary>
        public override bool ContainsNameServerRecords()
        {
            if (!_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> records))
                return false;

            foreach (DnsResourceRecord record in records)
            {
                if (!record.IsDisabled())
                    return true;
            }

            return false;
        }

        #endregion

        #region properties

        /// <summary>
        /// Gets or sets whether the zone is disabled. A disabled zone answers every
        /// <see cref="QueryRecords"/> call with no records.
        /// </summary>
        public virtual bool Disabled
        {
            get { return _disabled; }
            set { _disabled = value; }
        }

        /// <summary>
        /// <c>true</c> when the zone is not disabled. Subclasses may override this.
        /// </summary>
        public virtual bool IsActive
        {
            get { return !_disabled; }
        }

        #endregion
    }
}
|