/*
Technitium DNS Server
Copyright (C) 2020 Shreyas Zare (shreyas@technitium.com)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
using DnsServerCore.Dns.ResourceRecords;
using System;
using System.Collections.Generic;
using System.Net;
using System.Threading.Tasks;
using TechnitiumLibrary.IO;
using TechnitiumLibrary.Net.Dns;
using TechnitiumLibrary.Net.Dns.ResourceRecords;
namespace DnsServerCore.Dns.Zones
{
abstract class AuthZone : Zone, IDisposable
{
#region variables
//true when the entire zone is disabled; checked by FilterDisabledRecords() and exposed via the Disabled and IsActive properties
protected bool _disabled;
#endregion
#region constructor
/// <summary>
/// Initializes the zone by passing <paramref name="name"/> on to the base <see cref="Zone"/> constructor.
/// </summary>
/// <param name="name">Value forwarded unchanged to the base constructor.</param>
protected AuthZone(string name) : base(name)
{
    //no additional state to initialize at this level
}
#endregion
#region IDisposable
/// <summary>
/// Extension point for derived zones to release their resources; this base implementation does nothing.
/// </summary>
/// <param name="disposing"><c>true</c> when invoked from the public <c>Dispose()</c> method.</param>
protected virtual void Dispose(bool disposing)
{
    //nothing to release at this level
}
/// <summary>
/// Releases the zone's resources by invoking <c>Dispose(bool)</c> with <c>true</c>.
/// </summary>
public void Dispose()
{
    Dispose(true);

    //this type is not sealed: a derived zone may declare a finalizer, which has nothing left to do after an explicit dispose
    GC.SuppressFinalize(this);
}
#endregion
#region private
/// <summary>
/// Returns the records that may be served: nothing when the zone is disabled, otherwise every record that is not disabled.
/// </summary>
/// <param name="type">Queried record type; only used to decide whether a multi-record result is shuffled.</param>
/// <param name="records">Candidate records to filter.</param>
/// <returns>The enabled records; an empty list when the zone or all of the records are disabled.</returns>
private IReadOnlyList<DnsResourceRecord> FilterDisabledRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
{
    if (_disabled)
        return Array.Empty<DnsResourceRecord>(); //zone disabled

    if (records.Count == 1)
    {
        //single record: hand back the caller's list as is, no copy needed
        if (records[0].IsDisabled())
            return Array.Empty<DnsResourceRecord>(); //record disabled

        return records;
    }

    List<DnsResourceRecord> newRecords = new List<DnsResourceRecord>(records.Count);

    foreach (DnsResourceRecord record in records)
    {
        if (record.IsDisabled())
            continue; //record disabled

        newRecords.Add(record);
    }

    if (newRecords.Count > 1)
    {
        switch (type)
        {
            case DnsResourceRecordType.A:
            case DnsResourceRecordType.AAAA:
            case DnsResourceRecordType.NS:
                newRecords.Shuffle(); //shuffle records to allow load balancing
                break;
        }
    }

    return newRecords;
}
/// <summary>
/// Collects the addresses of the name server named by an NS or SOA record.
/// </summary>
/// <remarks>
/// Glue records attached to <paramref name="record"/> are used when there are any; otherwise the name server's
/// domain name is resolved through <paramref name="dnsServer"/>. IPv6 addresses are only included when
/// <c>dnsServer.PreferIPv6</c> is set. Resolution failures are ignored, so the result may be empty.
/// </remarks>
/// <param name="dnsServer">Server used to resolve the name server's domain name and to read the IPv6 preference.</param>
/// <param name="record">An NS or SOA record.</param>
/// <returns>The name server addresses that were found.</returns>
/// <exception cref="InvalidOperationException"><paramref name="record"/> is neither an NS nor an SOA record.</exception>
private async Task<IReadOnlyList<NameServerAddress>> GetNameServerAddressesAsync(DnsServer dnsServer, DnsResourceRecord record)
{
    string nsDomain;

    switch (record.Type)
    {
        case DnsResourceRecordType.NS:
            nsDomain = (record.RDATA as DnsNSRecord).NameServer;
            break;

        case DnsResourceRecordType.SOA:
            nsDomain = (record.RDATA as DnsSOARecord).PrimaryNameServer;
            break;

        default:
            throw new InvalidOperationException();
    }

    List<NameServerAddress> nameServers = new List<NameServerAddress>(2);

    IReadOnlyList<DnsResourceRecord> glueRecords = record.GetGlueRecords();
    if (glueRecords.Count > 0)
    {
        //glue addresses are available: use them instead of resolving
        foreach (DnsResourceRecord glueRecord in glueRecords)
        {
            switch (glueRecord.Type)
            {
                case DnsResourceRecordType.A:
                    nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsARecord).Address));
                    break;

                case DnsResourceRecordType.AAAA:
                    if (dnsServer.PreferIPv6)
                        nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsAAAARecord).Address));

                    break;
            }
        }
    }
    else
    {
        //resolve addresses
        try
        {
            DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.A, DnsClass.IN));
            if ((response != null) && (response.Answer.Count > 0))
            {
                IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseA(response);
                foreach (IPAddress address in addresses)
                    nameServers.Add(new NameServerAddress(nsDomain, address));
            }
        }
        catch
        {
            //best effort: a failed A lookup just contributes no addresses
        }

        if (dnsServer.PreferIPv6)
        {
            try
            {
                DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.AAAA, DnsClass.IN));
                if ((response != null) && (response.Answer.Count > 0))
                {
                    IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseAAAA(response);
                    foreach (IPAddress address in addresses)
                        nameServers.Add(new NameServerAddress(nsDomain, address));
                }
            }
            catch
            {
                //best effort: a failed AAAA lookup just contributes no addresses
            }
        }
    }

    return nameServers;
}
#endregion
#region public
/// <summary>
/// Gets the addresses of the zone's primary name server.
/// </summary>
/// <remarks>
/// The addresses of the first enabled NS record whose name server equals the SOA record's primary name server
/// (case-insensitive) come first, followed by any addresses obtained from the SOA record itself that are not
/// already in the list.
/// </remarks>
/// <param name="dnsServer">Server passed on to the address lookup.</param>
/// <returns>The primary name server addresses; may be empty.</returns>
public async Task<IReadOnlyList<NameServerAddress>> GetPrimaryNameServerAddressesAsync(DnsServer dnsServer)
{
    List<NameServerAddress> nameServers = new List<NameServerAddress>();

    DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0];
    DnsSOARecord soa = soaRecord.RDATA as DnsSOARecord;

    IReadOnlyList<DnsResourceRecord> nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords

    foreach (DnsResourceRecord nsRecord in nsRecords)
    {
        if (nsRecord.IsDisabled())
            continue;

        string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer;

        if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase))
        {
            //found primary NS
            nameServers.AddRange(await GetNameServerAddressesAsync(dnsServer, nsRecord));
            break;
        }
    }

    //add the addresses obtained via the SOA record, skipping those already listed
    foreach (NameServerAddress nameServer in await GetNameServerAddressesAsync(dnsServer, soaRecord))
    {
        if (!nameServers.Contains(nameServer))
            nameServers.Add(nameServer);
    }

    return nameServers;
}
/// <summary>
/// Gets the addresses of every enabled NS record's name server except the one named as primary in the SOA record.
/// </summary>
/// <param name="dnsServer">Server passed on to the address lookup.</param>
/// <returns>The secondary name server addresses; may be empty.</returns>
public async Task<IReadOnlyList<NameServerAddress>> GetSecondaryNameServerAddressesAsync(DnsServer dnsServer)
{
    List<NameServerAddress> nameServers = new List<NameServerAddress>();

    DnsSOARecord soa = _entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord;
    IReadOnlyList<DnsResourceRecord> nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords

    foreach (DnsResourceRecord nsRecord in nsRecords)
    {
        if (nsRecord.IsDisabled())
            continue;

        string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer;

        if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase))
            continue; //skip primary name server

        nameServers.AddRange(await GetNameServerAddressesAsync(dnsServer, nsRecord));
    }

    return nameServers;
}
/// <summary>
/// Replaces the zone's record sets with the ones in <paramref name="newEntries"/>.
/// </summary>
/// <remarks>
/// A <see cref="ForwarderZone"/> never receives NS or SOA entries. For other zones an SOA entry is only accepted
/// when it holds exactly one record; for <see cref="SecondaryZone"/> and <see cref="StubZone"/> the glue records
/// of the current SOA record, when there is one, are carried over to the new SOA record.
/// </remarks>
/// <param name="newEntries">Record sets keyed by record type.</param>
/// <param name="dontRemoveRecords">When <c>false</c>, existing record types missing from <paramref name="newEntries"/> are removed first.</param>
public void SyncRecords(Dictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> newEntries, bool dontRemoveRecords)
{
    if (!dontRemoveRecords)
    {
        //remove entries of types that do not exist in new entries
        foreach (DnsResourceRecordType type in _entries.Keys)
        {
            if (!newEntries.ContainsKey(type))
                _entries.TryRemove(type, out _);
        }
    }

    //set new entries into zone
    if (this is ForwarderZone)
    {
        //skip NS and SOA records from being added to ForwarderZone
        foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> newEntry in newEntries)
        {
            switch (newEntry.Key)
            {
                case DnsResourceRecordType.NS:
                case DnsResourceRecordType.SOA:
                    break;

                default:
                    _entries[newEntry.Key] = newEntry.Value;
                    break;
            }
        }
    }
    else
    {
        bool keepSoaGlue = (this is SecondaryZone) || (this is StubZone);

        foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> newEntry in newEntries)
        {
            if (newEntry.Key == DnsResourceRecordType.SOA)
            {
                if (newEntry.Value.Count != 1)
                    continue; //skip invalid SOA record

                //copy existing SOA record's glue addresses to new SOA record;
                //guarded so that a zone without a current SOA entry does not throw here
                if (keepSoaGlue && _entries.TryGetValue(DnsResourceRecordType.SOA, out IReadOnlyList<DnsResourceRecord> existingSoaRecords) && (existingSoaRecords.Count > 0))
                    newEntry.Value[0].SetGlueRecords(existingSoaRecords[0].GetGlueRecords());
            }

            _entries[newEntry.Key] = newEntry.Value;
        }
    }
}
/// <summary>
/// Stores <paramref name="records"/> as the zone's record set for <paramref name="type"/>, replacing any existing set.
/// Unlike <c>SetRecords</c> this method is not virtual, so derived zones cannot intercept it.
/// </summary>
/// <param name="type">Record type the set is stored under.</param>
/// <param name="records">Records to store.</param>
public void LoadRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
{
    _entries[type] = records;
}
/// <summary>
/// Stores <paramref name="records"/> as the zone's record set for <paramref name="type"/>, replacing any existing set.
/// Derived zones may override this method.
/// </summary>
/// <param name="type">Record type the set is stored under.</param>
/// <param name="records">Records to store.</param>
public virtual void SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
{
    _entries[type] = records;
}
/// <summary>
/// Adds <paramref name="record"/> to the record set of its type unless a record with equal RDATA is already present.
/// </summary>
/// <param name="record">Record to add.</param>
/// <exception cref="InvalidOperationException">
/// The record is of type CNAME, ANAME, PTR or SOA; those must be stored with <c>SetRecords()</c>.
/// </exception>
public virtual void AddRecord(DnsResourceRecord record)
{
    switch (record.Type)
    {
        case DnsResourceRecordType.CNAME:
        case DnsResourceRecordType.ANAME:
        case DnsResourceRecordType.PTR:
        case DnsResourceRecordType.SOA:
            throw new InvalidOperationException("Cannot add record: use SetRecords() for " + record.Type.ToString() + " record");
    }

    _entries.AddOrUpdate(
        record.Type,
        (DnsResourceRecordType key) =>
        {
            //first record of this type
            return new DnsResourceRecord[] { record };
        },
        (DnsResourceRecordType key, IReadOnlyList<DnsResourceRecord> existingRecords) =>
        {
            foreach (DnsResourceRecord existingRecord in existingRecords)
            {
                //compare RDATA with RDATA (as DeleteRecord does); comparing the whole record
                //against an RDATA object mixes types and cannot detect the duplicate
                if (record.RDATA.Equals(existingRecord.RDATA))
                    return existingRecords; //already present, keep the set unchanged
            }

            //build a new list rather than mutating the stored one
            List<DnsResourceRecord> updateRecords = new List<DnsResourceRecord>(existingRecords.Count + 1);

            updateRecords.AddRange(existingRecords);
            updateRecords.Add(record);

            return updateRecords;
        });
}
/// <summary>
/// Removes the zone's whole record set for <paramref name="type"/>.
/// </summary>
/// <param name="type">Record type whose set is removed.</param>
/// <returns>The result of the removal attempt on the record store.</returns>
public virtual bool DeleteRecords(DnsResourceRecordType type) => _entries.TryRemove(type, out _);
/// <summary>
/// Deletes every record of the given type whose RDATA equals <paramref name="record"/>.
/// </summary>
/// <param name="type">Record type to look in.</param>
/// <param name="record">RDATA identifying the record(s) to delete.</param>
/// <returns>
/// <c>true</c> when at least one record matched and the record store was updated; <c>false</c> when there is
/// no record set for <paramref name="type"/>, nothing matched, or the store update did not succeed.
/// </returns>
public virtual bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData record)
{
    if (!_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
        return false;

    if (existingRecords.Count == 1)
    {
        //deleting the only record removes the whole entry
        if (record.Equals(existingRecords[0].RDATA))
            return _entries.TryRemove(type, out _);

        return false;
    }

    List<DnsResourceRecord> updateRecords = new List<DnsResourceRecord>(existingRecords.Count);

    foreach (DnsResourceRecord existingRecord in existingRecords)
    {
        if (!record.Equals(existingRecord.RDATA))
            updateRecords.Add(existingRecord);
    }

    if (updateRecords.Count == existingRecords.Count)
        return false; //nothing matched; report it the same way the single record path does

    //swap in the new list only if the stored set is still the one that was read above
    return _entries.TryUpdate(type, updateRecords, existingRecords);
}
/// <summary>
/// Returns the enabled records that answer a query for <paramref name="type"/>.
/// </summary>
/// <remarks>
/// Lookup order: enabled CNAME records win over everything; an ANY query returns the enabled records of all
/// stored types; otherwise the enabled records of the requested type; and for A and AAAA queries with no such
/// records, the enabled ANAME records.
/// </remarks>
/// <param name="type">Queried record type.</param>
/// <returns>The matching enabled records, or an empty list.</returns>
public virtual IReadOnlyList<DnsResourceRecord> QueryRecords(DnsResourceRecordType type)
{
    //check for CNAME
    if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
    {
        IReadOnlyList<DnsResourceRecord> filteredRecords = FilterDisabledRecords(type, existingCNAMERecords);
        if (filteredRecords.Count > 0)
            return filteredRecords;
    }

    if (type == DnsResourceRecordType.ANY)
    {
        List<DnsResourceRecord> records = new List<DnsResourceRecord>(_entries.Count * 2);

        foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
        {
            if (entry.Key != DnsResourceRecordType.ANY)
                records.AddRange(entry.Value);
        }

        return FilterDisabledRecords(type, records);
    }

    if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
    {
        IReadOnlyList<DnsResourceRecord> filteredRecords = FilterDisabledRecords(type, existingRecords);
        if (filteredRecords.Count > 0)
            return filteredRecords;
    }

    switch (type)
    {
        case DnsResourceRecordType.A:
        case DnsResourceRecordType.AAAA:
            //no usable address records: fall back to ANAME records
            if (_entries.TryGetValue(DnsResourceRecordType.ANAME, out IReadOnlyList<DnsResourceRecord> anameRecords))
                return FilterDisabledRecords(type, anameRecords);

            break;
    }

    return Array.Empty<DnsResourceRecord>();
}
/// <summary>
/// Returns the stored record set for <paramref name="type"/> as is, without filtering out disabled records
/// and regardless of whether the zone is disabled.
/// </summary>
/// <param name="type">Record type to look up.</param>
/// <returns>The stored records, or an empty list when there is no entry for the type.</returns>
public IReadOnlyList<DnsResourceRecord> GetRecords(DnsResourceRecordType type)
{
    if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> records))
        return records;

    return Array.Empty<DnsResourceRecord>();
}
/// <summary>
/// Tells whether the zone holds at least one NS record that is not disabled.
/// </summary>
/// <returns><c>true</c> when an enabled NS record exists; otherwise <c>false</c>.</returns>
public override bool ContainsNameServerRecords()
{
    if (!_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> records))
        return false;

    foreach (DnsResourceRecord record in records)
    {
        if (!record.IsDisabled())
            return true; //one enabled record is enough
    }

    return false;
}
#endregion
#region properties
/// <summary>
/// Gets or sets whether the whole zone is disabled.
/// </summary>
public virtual bool Disabled
{
    get => _disabled;
    set => _disabled = value;
}
/// <summary>
/// Gets whether the zone is active; at this level that is simply the negation of the disabled flag.
/// </summary>
public virtual bool IsActive => !_disabled;
#endregion
}
}