/*
Technitium DNS Server
Copyright (C) 2023 Shreyas Zare (shreyas@technitium.com)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
using DnsServerCore.Dns.ResourceRecords;
using System;
using System.Collections.Generic;
using TechnitiumLibrary;
using TechnitiumLibrary.Net.Dns.ResourceRecords;
namespace DnsServerCore.Dns.Zones
{
abstract class AuthZone : Zone
{
#region variables
//true when the whole zone is disabled; FilterDisabledRecords() returns no records while this is set
protected bool _disabled;
#endregion
#region constructor
/// <summary>
/// Creates the zone from stored zone info: the zone name is passed to the base class
/// and the disabled flag is copied from <paramref name="zoneInfo"/>.
/// </summary>
/// <param name="zoneInfo">Source of the zone name and disabled state.</param>
protected AuthZone(AuthZoneInfo zoneInfo)
: base(zoneInfo.Name)
{
_disabled = zoneInfo.Disabled;
}
/// <summary>
/// Creates the zone with the given name; the disabled flag is not assigned here and so keeps its default value.
/// </summary>
/// <param name="name">The zone name passed to the base class.</param>
protected AuthZone(string name)
: base(name)
{ }
#endregion
#region private
/// <summary>
/// Returns the records from <paramref name="records"/> that are not marked disabled and stamps
/// the auth record info of each returned record with the current UTC time as its last-used time.
/// </summary>
/// <param name="type">The queried record type; only used to decide whether the result gets shuffled.</param>
/// <param name="records">The candidate records to filter.</param>
/// <returns>
/// An empty list when the zone is disabled or every record is disabled; otherwise the enabled records.
/// </returns>
private IReadOnlyList FilterDisabledRecords(DnsResourceRecordType type, IReadOnlyList records)
{
if (_disabled)
return Array.Empty(); //whole zone is disabled
//single record fast path: the input list itself is returned, no new list is allocated
if (records.Count == 1)
{
AuthRecordInfo authRecordInfo = records[0].GetAuthRecordInfo();
if (authRecordInfo.Disabled)
return Array.Empty(); //record disabled
//update last used on
authRecordInfo.LastUsedOn = DateTime.UtcNow;
return records;
}
List newRecords = new List(records.Count);
DateTime utcNow = DateTime.UtcNow; //one timestamp shared by all records returned from this call
foreach (DnsResourceRecord record in records)
{
AuthRecordInfo authRecordInfo = record.GetAuthRecordInfo();
if (authRecordInfo.Disabled)
continue; //record disabled
//update last used on
authRecordInfo.LastUsedOn = utcNow;
newRecords.Add(record);
}
//only the freshly built list is reordered; the stored list passed in is not modified
if (newRecords.Count > 1)
{
switch (type)
{
case DnsResourceRecordType.A:
case DnsResourceRecordType.AAAA:
case DnsResourceRecordType.NS:
newRecords.Shuffle(); //shuffle records to allow load balancing
break;
}
}
return newRecords;
}
/// <summary>
/// Returns <paramref name="records"/> followed by this zone's RRSIG records that cover the
/// type of the first record in the set.
/// </summary>
/// <param name="records">The RRSet to append signatures to; the covered type is taken from its first record.</param>
/// <returns>
/// The input list itself when it is empty or the zone holds no RRSIG records; otherwise a new
/// list with the matching RRSIG records appended after the input records.
/// </returns>
private IReadOnlyList<DnsResourceRecord> AppendRRSigTo(IReadOnlyList<DnsResourceRecord> records)
{
    //nothing to sign; also avoids indexing into an empty list below
    if (records.Count == 0)
        return records;

    IReadOnlyList<DnsResourceRecord> rrsigRecords = GetRecords(DnsResourceRecordType.RRSIG);
    if (rrsigRecords.Count == 0)
        return records;

    DnsResourceRecordType type = records[0].Type;

    List<DnsResourceRecord> newRecords = new List<DnsResourceRecord>(records.Count + 2);
    newRecords.AddRange(records);

    DateTime utcNow = DateTime.UtcNow;

    foreach (DnsResourceRecord rrsigRecord in rrsigRecords)
    {
        //pattern match instead of an unchecked "as" cast: unexpected RDATA is skipped rather than dereferenced as null
        if ((rrsigRecord.RDATA is DnsRRSIGRecordData rrsig) && (rrsig.TypeCovered == type))
        {
            //the signature is being served, so update last used on
            rrsigRecord.GetAuthRecordInfo().LastUsedOn = utcNow;
            newRecords.Add(rrsigRecord);
        }
    }

    return newRecords;
}
#endregion
#region versioning
/// <summary>
/// Attempts to store <paramref name="records"/> as the RRSet for <paramref name="type"/>.
/// </summary>
/// <param name="type">The record type whose RRSet is being set.</param>
/// <param name="records">The new RRSet.</param>
/// <param name="deletedRecords">The RRSet that was found for the type, or an empty list when none existed.</param>
/// <returns>The result of the underlying TryAdd (no prior entry) or TryUpdate (prior entry) call.</returns>
internal bool TrySetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, out IReadOnlyList<DnsResourceRecord> deletedRecords)
{
    if (!_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> currentRecords))
    {
        //no RRSet of this type yet, so nothing gets replaced
        deletedRecords = Array.Empty<DnsResourceRecord>();
        return _entries.TryAdd(type, records);
    }

    //swap only if the entry still holds the list that was just read
    deletedRecords = currentRecords;
    return _entries.TryUpdate(type, records, currentRecords);
}
/// <summary>
/// Attempts to delete the first record of <paramref name="type"/> whose RDATA equals <paramref name="rdata"/>.
/// </summary>
/// <param name="type">The record type to delete from.</param>
/// <param name="rdata">The RDATA identifying the record to delete.</param>
/// <param name="deletedRecord">The record that matched, or null when no record matched or the type has no entry.</param>
/// <returns>
/// True when the entry was removed or updated; false when nothing matched or the
/// TryRemove/TryUpdate call did not succeed.
/// </returns>
internal bool TryDeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata, out DnsResourceRecord deletedRecord)
{
if (_entries.TryGetValue(type, out IReadOnlyList existingRecords))
{
if (existingRecords.Count == 1)
{
//deleting the only record removes the whole entry for the type
if (rdata.Equals(existingRecords[0].RDATA))
{
if (_entries.TryRemove(type, out IReadOnlyList removedRecords))
{
deletedRecord = removedRecords[0];
return true;
}
}
}
else
{
deletedRecord = null;
List updatedRecords = new List(existingRecords.Count);
foreach (DnsResourceRecord existingRecord in existingRecords)
{
//only the first match is dropped; any later records with equal RDATA are kept
if ((deletedRecord is null) && rdata.Equals(existingRecord.RDATA))
deletedRecord = existingRecord;
else
updatedRecords.Add(existingRecord);
}
if (deletedRecord is null)
return false; //not found
//replace only if the entry still holds the list that was read above
//note: deletedRecord stays set even when this call returns false
return _entries.TryUpdate(type, updatedRecords, existingRecords);
}
}
deletedRecord = null;
return false;
}
/// <summary>
/// Attempts to delete every existing record of <paramref name="type"/> whose RDATA equals the RDATA of
/// any record in <paramref name="records"/>.
/// </summary>
/// <param name="type">The record type to delete from.</param>
/// <param name="records">Records whose RDATA identifies what to delete.</param>
/// <param name="deletedRecords">
/// The existing records that matched; null when the type has no entry or nothing matched.
/// </param>
/// <returns>
/// True when the entry was removed or updated; false when nothing matched or the
/// TryRemove/TryUpdate call did not succeed.
/// </returns>
internal bool TryDeleteRecords(DnsResourceRecordType type, IReadOnlyList records, out IReadOnlyList deletedRecords)
{
if (_entries.TryGetValue(type, out IReadOnlyList existingRecords))
{
if (existingRecords.Count == 1)
{
//single existing record: any match removes the whole entry for the type
DnsResourceRecord existingRecord = existingRecords[0];
foreach (DnsResourceRecord record in records)
{
if (record.RDATA.Equals(existingRecord.RDATA))
{
if (_entries.TryRemove(type, out IReadOnlyList removedRecords))
{
deletedRecords = removedRecords;
return true;
}
}
}
}
else
{
//split the existing records into the ones to delete and the ones to keep
List deleted = new List(records.Count);
List updatedRecords = new List(existingRecords.Count);
foreach (DnsResourceRecord existingRecord in existingRecords)
{
bool found = false;
foreach (DnsResourceRecord record in records)
{
if (record.RDATA.Equals(existingRecord.RDATA))
{
found = true;
break;
}
}
if (found)
deleted.Add(existingRecord);
else
updatedRecords.Add(existingRecord);
}
if (deleted.Count > 0)
{
//note: deletedRecords is assigned before the update, so it is set even if the call below returns false
deletedRecords = deleted;
if (updatedRecords.Count > 0)
return _entries.TryUpdate(type, updatedRecords, existingRecords);
//nothing left for this type, so drop the entry
return _entries.TryRemove(type, out _);
}
}
}
deletedRecords = null;
return false;
}
/// <summary>
/// Merges <paramref name="newRRSigRecords"/> into the zone's RRSIG entry. An existing RRSIG is replaced
/// when a new RRSIG has the same TypeCovered and KeyTag; all other existing RRSIG records are kept.
/// </summary>
/// <param name="newRRSigRecords">The RRSIG records to add.</param>
/// <param name="deletedRRSigRecords">
/// The existing RRSIG records that were replaced; an empty list when there was no RRSIG entry before.
/// </param>
internal void AddOrUpdateRRSigRecords(IReadOnlyList newRRSigRecords, out IReadOnlyList deletedRRSigRecords)
{
//out parameters cannot be captured by the delegates, so the result is collected in a local first
IReadOnlyList deleted = null;
_entries.AddOrUpdate(DnsResourceRecordType.RRSIG, delegate (DnsResourceRecordType key)
{
//no RRSIG entry yet: the new records become the entry and nothing is replaced
deleted = Array.Empty();
return newRRSigRecords;
},
delegate (DnsResourceRecordType key, IReadOnlyList existingRecords)
{
List updatedRecords = new List(existingRecords.Count + newRRSigRecords.Count);
List deletedRecords = new List();
foreach (DnsResourceRecord existingRecord in existingRecords)
{
bool found = false;
DnsRRSIGRecordData existingRRSig = existingRecord.RDATA as DnsRRSIGRecordData;
foreach (DnsResourceRecord newRRSigRecord in newRRSigRecords)
{
DnsRRSIGRecordData newRRSig = newRRSigRecord.RDATA as DnsRRSIGRecordData;
//same covered type and key tag means the new signature supersedes the existing one
if ((newRRSig.TypeCovered == existingRRSig.TypeCovered) && (newRRSig.KeyTag == existingRRSig.KeyTag))
{
deletedRecords.Add(existingRecord);
found = true;
break;
}
}
if (!found)
updatedRecords.Add(existingRecord);
}
//kept existing signatures first, then all new signatures
updatedRecords.AddRange(newRRSigRecords);
deleted = deletedRecords;
return updatedRecords;
});
deletedRRSigRecords = deleted;
}
/// <summary>
/// Adds <paramref name="record"/> to the RRSet of its type and reports which records were added
/// and which existing records were replaced as a result.
/// </summary>
/// <remarks>
/// When an existing record has equal RDATA the RRSet is left unchanged and both output lists are empty.
/// Existing records whose OriginalTtlValue differs from the new record's are replaced by copies that
/// carry the new record's OriginalTtlValue; each replacement is reported as one deleted and one added record.
/// </remarks>
/// <param name="record">The record to add.</param>
/// <param name="addedRecords">Records that were put into the RRSet by this call.</param>
/// <param name="deletedRecords">Existing records that were replaced by this call.</param>
/// <exception cref="InvalidOperationException">The record type is CNAME, DNAME or SOA.</exception>
internal void AddRecord(DnsResourceRecord record, out IReadOnlyList<DnsResourceRecord> addedRecords, out IReadOnlyList<DnsResourceRecord> deletedRecords)
{
    switch (record.Type)
    {
        case DnsResourceRecordType.CNAME:
        case DnsResourceRecordType.DNAME:
        case DnsResourceRecordType.SOA:
            throw new InvalidOperationException("Cannot add record: use SetRecords() for " + record.Type.ToString() + " record");
    }

    List<DnsResourceRecord> added = new List<DnsResourceRecord>();
    List<DnsResourceRecord> deleted = new List<DnsResourceRecord>();

    addedRecords = added;
    deletedRecords = deleted;

    //NOTE(review): _entries is declared outside this block; its AddOrUpdate usage suggests a
    //ConcurrentDictionary, whose factory delegates may be invoked more than once - confirm.
    //Both delegates therefore reset the captured lists first so that only the final invocation is reported.
    _entries.AddOrUpdate(record.Type, delegate (DnsResourceRecordType key)
    {
        added.Clear();
        deleted.Clear();

        added.Add(record);
        return new DnsResourceRecord[] { record };
    },
    delegate (DnsResourceRecordType key, IReadOnlyList<DnsResourceRecord> existingRecords)
    {
        added.Clear();
        deleted.Clear();

        //duplicate RDATA: keep the RRSet as it is and report no changes
        foreach (DnsResourceRecord existingRecord in existingRecords)
        {
            if (record.RDATA.Equals(existingRecord.RDATA))
                return existingRecords;
        }

        List<DnsResourceRecord> updatedRecords = new List<DnsResourceRecord>(existingRecords.Count + 1);

        foreach (DnsResourceRecord existingRecord in existingRecords)
        {
            if (existingRecord.OriginalTtlValue == record.OriginalTtlValue)
            {
                updatedRecords.Add(existingRecord);
            }
            else
            {
                //re-create the existing record with the new record's TTL value
                DnsResourceRecord updatedExistingRecord = new DnsResourceRecord(existingRecord.Name, existingRecord.Type, existingRecord.Class, record.OriginalTtlValue, existingRecord.RDATA);

                updatedRecords.Add(updatedExistingRecord);
                added.Add(updatedExistingRecord);
                deleted.Add(existingRecord);
            }
        }

        updatedRecords.Add(record);
        added.Add(record);

        return updatedRecords;
    });
}
#endregion
#region DNSSEC
/// <summary>
/// Calls SignRRSet() for every RRSet in the zone except the RRSIG entry itself and collects the results.
/// </summary>
/// <returns>All records returned by the SignRRSet() calls, in enumeration order of the entries.</returns>
internal IReadOnlyList<DnsResourceRecord> SignAllRRSets()
{
    List<DnsResourceRecord> signatures = new List<DnsResourceRecord>(_entries.Count);

    foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> rrset in _entries)
    {
        //signatures themselves are not signed
        if (rrset.Key != DnsResourceRecordType.RRSIG)
            signatures.AddRange(SignRRSet(rrset.Value));
    }

    return signatures;
}
/// <summary>
/// Removes the DNSKEY, RRSIG, NSEC, NSEC3PARAM and NSEC3 entries from the zone.
/// </summary>
/// <returns>All records that were removed.</returns>
internal IReadOnlyList<DnsResourceRecord> RemoveAllDnssecRecords()
{
    List<DnsResourceRecord> removedSoFar = new List<DnsResourceRecord>();

    foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> rrset in _entries)
    {
        bool isDnssecType = rrset.Key is DnsResourceRecordType.DNSKEY
            or DnsResourceRecordType.RRSIG
            or DnsResourceRecordType.NSEC
            or DnsResourceRecordType.NSEC3PARAM
            or DnsResourceRecordType.NSEC3;

        if (!isDnssecType)
            continue;

        if (_entries.TryRemove(rrset.Key, out IReadOnlyList<DnsResourceRecord> removed))
            removedSoFar.AddRange(removed);
    }

    return removedSoFar;
}
/// <summary>
/// Removes the NSEC entry from the zone together with the RRSIG records whose TypeCovered is NSEC.
/// </summary>
/// <returns>All records that were removed.</returns>
internal IReadOnlyList RemoveNSecRecordsWithRRSig()
{
List allRemovedRecords = new List(2);
foreach (KeyValuePair> entry in _entries)
{
switch (entry.Key)
{
case DnsResourceRecordType.NSEC:
//drop the whole NSEC entry
if (_entries.TryRemove(entry.Key, out IReadOnlyList removedRecords))
allRemovedRecords.AddRange(removedRecords);
break;
case DnsResourceRecordType.RRSIG:
//only signatures covering NSEC are removed; other RRSIG records stay in the entry
List recordsToRemove = new List(1);
foreach (DnsResourceRecord rrsigRecord in entry.Value)
{
DnsRRSIGRecordData rrsig = rrsigRecord.RDATA as DnsRRSIGRecordData;
if (rrsig.TypeCovered == DnsResourceRecordType.NSEC)
recordsToRemove.Add(rrsigRecord);
}
if (recordsToRemove.Count > 0)
{
//TryDeleteRecords() matches by RDATA and reports what it actually deleted
if (TryDeleteRecords(DnsResourceRecordType.RRSIG, recordsToRemove, out IReadOnlyList deletedRecords))
allRemovedRecords.AddRange(deletedRecords);
}
break;
}
}
return allRemovedRecords;
}
/// <summary>
/// Removes the NSEC3 and NSEC3PARAM entries from the zone together with the RRSIG records whose
/// TypeCovered is NSEC3 or NSEC3PARAM.
/// </summary>
/// <returns>All records that were removed.</returns>
internal IReadOnlyList RemoveNSec3RecordsWithRRSig()
{
List allRemovedRecords = new List(2);
foreach (KeyValuePair> entry in _entries)
{
switch (entry.Key)
{
case DnsResourceRecordType.NSEC3:
case DnsResourceRecordType.NSEC3PARAM:
//drop the whole entry for the type
if (_entries.TryRemove(entry.Key, out IReadOnlyList removedRecords))
allRemovedRecords.AddRange(removedRecords);
break;
case DnsResourceRecordType.RRSIG:
//only signatures covering NSEC3/NSEC3PARAM are removed; other RRSIG records stay in the entry
List recordsToRemove = new List(1);
foreach (DnsResourceRecord rrsigRecord in entry.Value)
{
DnsRRSIGRecordData rrsig = rrsigRecord.RDATA as DnsRRSIGRecordData;
switch (rrsig.TypeCovered)
{
case DnsResourceRecordType.NSEC3:
case DnsResourceRecordType.NSEC3PARAM:
recordsToRemove.Add(rrsigRecord);
break;
}
}
if (recordsToRemove.Count > 0)
{
//TryDeleteRecords() matches by RDATA and reports what it actually deleted
if (TryDeleteRecords(DnsResourceRecordType.RRSIG, recordsToRemove, out IReadOnlyList deletedRecords))
allRemovedRecords.AddRange(deletedRecords);
}
break;
}
}
return allRemovedRecords;
}
/// <summary>
/// Tells whether the zone holds an NSEC3 entry and no entries of any type other than NSEC3 and RRSIG.
/// </summary>
/// <returns>True when NSEC3 is present and every entry is of type NSEC3 or RRSIG; otherwise false.</returns>
internal bool HasOnlyNSec3Records()
{
    bool hasNSec3 = _entries.ContainsKey(DnsResourceRecordType.NSEC3);
    if (!hasNSec3)
        return false;

    foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> rrset in _entries)
    {
        bool isAllowedType = (rrset.Key == DnsResourceRecordType.NSEC3) || (rrset.Key == DnsResourceRecordType.RRSIG);
        if (!isAllowedType)
            return false; //found non NSEC3 records
    }

    return true;
}
/// <summary>
/// Re-signs every RRSet that has at least one RRSIG in the last third of its validity period.
/// </summary>
/// <remarks>
/// Each covered type is re-signed at most once per call, even when several of its RRSIG records
/// are due for refresh.
/// </remarks>
/// <returns>The records returned by SignRRSet() for the refreshed RRSets; empty when nothing is due.</returns>
/// <exception cref="InvalidOperationException">The zone has no RRSIG entry.</exception>
internal IReadOnlyList<DnsResourceRecord> RefreshSignatures()
{
    if (!_entries.TryGetValue(DnsResourceRecordType.RRSIG, out IReadOnlyList<DnsResourceRecord> rrsigRecords))
        throw new InvalidOperationException();

    List<DnsResourceRecordType> typesToRefresh = new List<DnsResourceRecordType>();
    DateTime utcNow = DateTime.UtcNow;

    foreach (DnsResourceRecord rrsigRecord in rrsigRecords)
    {
        DnsRRSIGRecordData rrsig = rrsigRecord.RDATA as DnsRRSIGRecordData;

        uint signatureValidityPeriod = rrsig.SignatureExpiration - rrsig.SignatureInception;
        uint refreshPeriod = signatureValidityPeriod / 3;

        //refresh is due once the remaining validity drops below one third of the full period
        if (utcNow > DateTime.UnixEpoch.AddSeconds(rrsig.SignatureExpiration - refreshPeriod))
        {
            //several RRSIGs can cover the same type; queue the type only once so its RRSet is not signed repeatedly
            if (!typesToRefresh.Contains(rrsig.TypeCovered))
                typesToRefresh.Add(rrsig.TypeCovered);
        }
    }

    List<DnsResourceRecord> newRRSigRecords = new List<DnsResourceRecord>(typesToRefresh.Count);

    foreach (DnsResourceRecordType type in typesToRefresh)
    {
        //the covered RRSet may no longer exist; in that case there is nothing to sign
        if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> records))
            newRRSigRecords.AddRange(SignRRSet(records));
    }

    return newRRSigRecords;
}
/// <summary>
/// Signs the given RRSet. This base implementation always throws; it is virtual so that
/// derived zone types can supply an implementation.
/// </summary>
/// <param name="records">The RRSet to sign.</param>
/// <exception cref="InvalidOperationException">Always thrown by this base implementation.</exception>
internal virtual IReadOnlyList<DnsResourceRecord> SignRRSet(IReadOnlyList<DnsResourceRecord> records)
    => throw new InvalidOperationException();
/// <summary>
/// Builds the NSEC record this zone should have for the given next domain name and TTL and
/// returns it only when it differs from the NSEC record currently stored.
/// </summary>
/// <param name="nextDomainName">The next domain name to put into the NSEC record.</param>
/// <param name="ttl">The TTL for the NSEC record.</param>
/// <returns>
/// A single-record RRSet when there is no stored NSEC record or the stored one has a different
/// TTL or RDATA; otherwise an empty list.
/// </returns>
internal IReadOnlyList<DnsResourceRecord> GetUpdatedNSecRRSet(string nextDomainName, uint ttl)
{
    //type list: every type present in the zone, plus NSEC and RRSIG when missing
    List<DnsResourceRecordType> typeList = new List<DnsResourceRecordType>(_entries.Count);

    foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> rrset in _entries)
        typeList.Add(rrset.Key);

    if (!_entries.ContainsKey(DnsResourceRecordType.NSEC))
        typeList.Add(DnsResourceRecordType.NSEC);

    if (!_entries.ContainsKey(DnsResourceRecordType.RRSIG))
        typeList.Add(DnsResourceRecordType.RRSIG);

    typeList.Sort();

    DnsNSECRecordData expectedNSec = new DnsNSECRecordData(nextDomainName, typeList);

    if (_entries.TryGetValue(DnsResourceRecordType.NSEC, out IReadOnlyList<DnsResourceRecord> storedRecords))
    {
        //stored record already matches, so no update is needed
        if ((storedRecords[0].TTL == ttl) && storedRecords[0].RDATA.Equals(expectedNSec))
            return Array.Empty<DnsResourceRecord>();
    }

    DnsResourceRecord nsecRecord = new DnsResourceRecord(_name, DnsResourceRecordType.NSEC, DnsClass.IN, ttl, expectedNSec);
    return new DnsResourceRecord[] { nsecRecord };
}
/// <summary>
/// Returns <paramref name="newNSec3Records"/> only when its first record differs from the first
/// NSEC3 record currently stored in the zone.
/// </summary>
/// <param name="newNSec3Records">The candidate NSEC3 RRSet; its first record is the one compared.</param>
/// <returns>
/// The candidate RRSet when there is no stored NSEC3 entry or the TTL or RDATA differs; otherwise an empty list.
/// </returns>
internal IReadOnlyList<DnsResourceRecord> GetUpdatedNSec3RRSet(IReadOnlyList<DnsResourceRecord> newNSec3Records)
{
    if (_entries.TryGetValue(DnsResourceRecordType.NSEC3, out IReadOnlyList<DnsResourceRecord> storedRecords))
    {
        bool sameTtl = storedRecords[0].TTL == newNSec3Records[0].TTL;

        //stored record already matches, so no update is needed
        if (sameTtl && storedRecords[0].RDATA.Equals(newNSec3Records[0].RDATA))
            return Array.Empty<DnsResourceRecord>();
    }

    return newNSec3Records;
}
/// <summary>
/// Creates a single-record NSEC3 RRSet whose type list holds every type present in the zone except
/// NSEC3 and RRSIG, plus RRSIG whenever that list is not empty.
/// </summary>
/// <param name="hashedOwnerName">The owner name for the NSEC3 record.</param>
/// <param name="nextHashedOwnerName">The next hashed owner name for the NSEC3 record.</param>
/// <param name="ttl">The TTL for the NSEC3 record.</param>
/// <param name="iterations">The NSEC3 iterations value.</param>
/// <param name="salt">The NSEC3 salt.</param>
/// <returns>An RRSet containing the one new NSEC3 record (SHA1, no flags).</returns>
internal IReadOnlyList<DnsResourceRecord> CreateNSec3RRSet(string hashedOwnerName, byte[] nextHashedOwnerName, uint ttl, ushort iterations, byte[] salt)
{
    List<DnsResourceRecordType> typeList = new List<DnsResourceRecordType>(_entries.Count);

    foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> rrset in _entries)
    {
        if (rrset.Key is not (DnsResourceRecordType.NSEC3 or DnsResourceRecordType.RRSIG))
            typeList.Add(rrset.Key);
    }

    //zone is not an empty non-terminal (ENT) when it has types, so it will carry signatures
    if ((typeList.Count > 0) && !_entries.ContainsKey(DnsResourceRecordType.RRSIG))
        typeList.Add(DnsResourceRecordType.RRSIG);

    typeList.Sort();

    DnsNSEC3RecordData nsec3Data = new DnsNSEC3RecordData(DnssecNSEC3HashAlgorithm.SHA1, DnssecNSEC3Flags.None, iterations, salt, nextHashedOwnerName, typeList);
    DnsResourceRecord nsec3Record = new DnsResourceRecord(hashedOwnerName, DnsResourceRecordType.NSEC3, DnsClass.IN, ttl, nsec3Data);

    return new DnsResourceRecord[] { nsec3Record };
}
/// <summary>
/// Creates an NSEC3 record for this zone node with an empty next hashed owner name. The owner name is
/// the hashed form of this node's name, followed by "." and <paramref name="zoneName"/> when that is not empty.
/// </summary>
/// <param name="zoneName">The zone apex name; compared case-insensitively with this node's name.</param>
/// <param name="ttl">The TTL for the NSEC3 record.</param>
/// <param name="iterations">The NSEC3 iterations value.</param>
/// <param name="salt">The NSEC3 salt.</param>
/// <returns>The new NSEC3 record (SHA1, no flags).</returns>
internal DnsResourceRecord GetPartialNSec3Record(string zoneName, uint ttl, ushort iterations, byte[] salt)
{
    List<DnsResourceRecordType> typeList = new List<DnsResourceRecordType>(_entries.Count);

    foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> rrset in _entries)
    {
        if (rrset.Key is not (DnsResourceRecordType.NSEC3 or DnsResourceRecordType.RRSIG))
            typeList.Add(rrset.Key);
    }

    bool isZoneApex = _name.Equals(zoneName, StringComparison.OrdinalIgnoreCase);

    if (isZoneApex)
        typeList.Add(DnsResourceRecordType.NSEC3PARAM); //add NSEC3PARAM type to NSEC3 for unsigned zone apex

    //the apex, and any node that is not an empty non-terminal (ENT), also lists RRSIG
    if ((isZoneApex || (typeList.Count > 0)) && !_entries.ContainsKey(DnsResourceRecordType.RRSIG))
        typeList.Add(DnsResourceRecordType.RRSIG);

    typeList.Sort();

    DnsNSEC3RecordData nsec3Data = new DnsNSEC3RecordData(DnssecNSEC3HashAlgorithm.SHA1, DnssecNSEC3Flags.None, iterations, salt, Array.Empty<byte>(), typeList);

    string hashedLabel = nsec3Data.ComputeHashedOwnerName(_name);
    string ownerName = hashedLabel + (zoneName.Length > 0 ? "." + zoneName : "");

    return new DnsResourceRecord(ownerName, DnsResourceRecordType.NSEC3, DnsClass.IN, ttl, nsec3Data);
}
#endregion
#region public
/// <summary>
/// Replaces the zone's contents with <paramref name="newEntries"/>: entries of types missing from
/// the new set are removed and every accepted new entry overwrites the stored one.
/// </summary>
/// <remarks>
/// A ForwarderZone skips NS and SOA entries. Other zones skip an SOA entry that does not contain
/// exactly one record, and a SecondaryZone copies the existing SOA record's info onto the new SOA record.
/// </remarks>
/// <param name="newEntries">The complete set of RRSets for this zone, keyed by record type.</param>
public void SyncRecords(Dictionary> newEntries)
{
//remove entries of types that do not exist in new entries
foreach (KeyValuePair> entry in _entries)
{
if (!newEntries.ContainsKey(entry.Key))
_entries.TryRemove(entry.Key, out _);
}
//set new entries into zone
if (this is ForwarderZone)
{
//skip NS and SOA records from being added to ForwarderZone
foreach (KeyValuePair> newEntry in newEntries)
{
switch (newEntry.Key)
{
case DnsResourceRecordType.NS:
case DnsResourceRecordType.SOA:
break;
default:
_entries[newEntry.Key] = newEntry.Value;
break;
}
}
}
else
{
foreach (KeyValuePair> newEntry in newEntries)
{
if (newEntry.Key == DnsResourceRecordType.SOA)
{
if (newEntry.Value.Count != 1)
continue; //skip invalid SOA record
if (this is SecondaryZone)
{
//copy existing SOA record's info to new SOA record
//NOTE(review): the indexer access assumes a secondary zone always has a stored SOA entry - confirm
DnsResourceRecord existingSoaRecord = _entries[DnsResourceRecordType.SOA][0];
DnsResourceRecord newSoaRecord = newEntry.Value[0];
newSoaRecord.CopyRecordInfoFrom(existingSoaRecord);
}
}
_entries[newEntry.Key] = newEntry.Value;
}
}
}
/// <summary>
/// Applies an incremental change to the zone: first removes the records listed in
/// <paramref name="deletedEntries"/>, then adds the records listed in <paramref name="addedEntries"/>.
/// Records are matched by RDATA in both steps.
/// </summary>
/// <param name="deletedEntries">Records to delete, keyed by record type; may be null.</param>
/// <param name="addedEntries">Records to add, keyed by record type; may be null.</param>
public void SyncRecords(Dictionary> deletedEntries, Dictionary> addedEntries)
{
if (deletedEntries is not null)
{
foreach (KeyValuePair> deletedEntry in deletedEntries)
{
if (_entries.TryGetValue(deletedEntry.Key, out IReadOnlyList existingRecords))
{
//capacity is an estimate; clamp at zero since the delete list can be longer than the stored list
List updatedRecords = new List(Math.Max(0, existingRecords.Count - deletedEntry.Value.Count));
foreach (DnsResourceRecord existingRecord in existingRecords)
{
bool deleted = false;
foreach (DnsResourceRecord deletedRecord in deletedEntry.Value)
{
if (existingRecord.RDATA.Equals(deletedRecord.RDATA))
{
deleted = true;
break;
}
}
if (!deleted)
updatedRecords.Add(existingRecord);
}
//write back only when at least one record was actually dropped
if (existingRecords.Count > updatedRecords.Count)
{
if (updatedRecords.Count > 0)
_entries[deletedEntry.Key] = updatedRecords;
else
_entries.TryRemove(deletedEntry.Key, out _); //nothing left for this type
}
}
}
}
if (addedEntries is not null)
{
foreach (KeyValuePair> addedEntry in addedEntries)
{
//no stored entry: the added list is stored as-is; otherwise merge into the stored list
_entries.AddOrUpdate(addedEntry.Key, addedEntry.Value, delegate (DnsResourceRecordType key, IReadOnlyList existingRecords)
{
List updatedRecords = new List(existingRecords.Count + addedEntry.Value.Count);
updatedRecords.AddRange(existingRecords);
foreach (DnsResourceRecord addedRecord in addedEntry.Value)
{
//skip records whose RDATA is already stored
bool exists = false;
foreach (DnsResourceRecord existingRecord in existingRecords)
{
if (addedRecord.RDATA.Equals(existingRecord.RDATA))
{
exists = true;
break;
}
}
if (!exists)
updatedRecords.Add(addedRecord);
}
//keep the stored list instance when nothing new was added
if (updatedRecords.Count > existingRecords.Count)
return updatedRecords;
else
return existingRecords;
});
}
}
}
/// <summary>
/// Forwards the glue record changes to every NS record stored in the zone; does nothing when
/// the zone has no NS entry.
/// </summary>
/// <param name="deletedGlueRecords">Glue records that were deleted.</param>
/// <param name="addedGlueRecords">Glue records that were added.</param>
public void SyncGlueRecords(IReadOnlyCollection<DnsResourceRecord> deletedGlueRecords, IReadOnlyCollection<DnsResourceRecord> addedGlueRecords)
{
    if (!_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> nameServers))
        return;

    foreach (DnsResourceRecord nameServer in nameServers)
    {
        nameServer.SyncGlueRecords(deletedGlueRecords, addedGlueRecords);
    }
}
/// <summary>
/// Stores <paramref name="records"/> as the RRSet for <paramref name="type"/>, overwriting any stored RRSet.
/// Unlike SetRecords() this method is not virtual.
/// </summary>
/// <param name="type">The record type of the RRSet.</param>
/// <param name="records">The RRSet to store.</param>
public void LoadRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
    => _entries[type] = records;
/// <summary>
/// Stores <paramref name="records"/> as the RRSet for <paramref name="type"/>, overwriting any stored RRSet.
/// </summary>
/// <param name="type">The record type of the RRSet.</param>
/// <param name="records">The RRSet to store.</param>
public virtual void SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
    => _entries[type] = records;
/// <summary>
/// Adds <paramref name="record"/> to the RRSet of its type, discarding the added/deleted change lists.
/// </summary>
/// <param name="record">The record to add.</param>
public virtual void AddRecord(DnsResourceRecord record)
    => AddRecord(record, out _, out _);
/// <summary>
/// Removes the whole RRSet stored for <paramref name="type"/>.
/// </summary>
/// <param name="type">The record type to remove.</param>
/// <returns>The result of the TryRemove call on the zone's entries.</returns>
public virtual bool DeleteRecords(DnsResourceRecordType type)
    => _entries.TryRemove(type, out _);
/// <summary>
/// Deletes the record of <paramref name="type"/> whose RDATA equals <paramref name="rdata"/>,
/// discarding the deleted record.
/// </summary>
/// <param name="type">The record type to delete from.</param>
/// <param name="rdata">The RDATA identifying the record.</param>
/// <returns>The result of TryDeleteRecord().</returns>
public virtual bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata)
    => TryDeleteRecord(type, rdata, out _);
/// <summary>
/// Replaces <paramref name="oldRecord"/> with <paramref name="newRecord"/> by deleting the old record
/// and then adding the new one.
/// </summary>
/// <param name="oldRecord">The record to replace; located by its type and RDATA.</param>
/// <param name="newRecord">The replacement record; must have the same type as the old record.</param>
/// <exception cref="InvalidOperationException">
/// The record type is CNAME, DNAME or SOA, or the old and new record types differ.
/// </exception>
/// <exception cref="DnsWebServiceException">The old record could not be deleted.</exception>
public virtual void UpdateRecord(DnsResourceRecord oldRecord, DnsResourceRecord newRecord)
{
    //reject the types that this class's AddRecord() refuses before anything is deleted;
    //otherwise the old record would be removed and the add would then throw, losing the record
    switch (oldRecord.Type)
    {
        case DnsResourceRecordType.CNAME:
        case DnsResourceRecordType.DNAME:
        case DnsResourceRecordType.SOA:
            throw new InvalidOperationException("Cannot update record: use SetRecords() for " + oldRecord.Type.ToString() + " record");
    }

    if (oldRecord.Type != newRecord.Type)
        throw new InvalidOperationException("Old and new record types do not match.");

    if (!DeleteRecord(oldRecord.Type, oldRecord.RDATA))
        throw new DnsWebServiceException("Cannot update record: the old record does not exists.");

    AddRecord(newRecord);
}
/// <summary>
/// Returns the enabled records that answer a query for <paramref name="type"/> at this zone node.
/// </summary>
/// <remarks>
/// APP, FWD, NSEC and NSEC3 queries return only records of exactly that type. ANY returns all records
/// except FWD and APP. Every other type returns CNAME records when enabled ones exist, then records of
/// the exact type, and for A/AAAA finally falls back to ANAME records.
/// </remarks>
/// <param name="type">The queried record type.</param>
/// <param name="dnssecOk">When true, matching RRSIG records are appended to CNAME and exact-type answers.</param>
/// <returns>The matching records, or an empty list when nothing matches.</returns>
public virtual IReadOnlyList QueryRecords(DnsResourceRecordType type, bool dnssecOk)
{
switch (type)
{
case DnsResourceRecordType.APP:
case DnsResourceRecordType.FWD:
case DnsResourceRecordType.NSEC:
case DnsResourceRecordType.NSEC3:
{
//return only exact type if exists
if (_entries.TryGetValue(type, out IReadOnlyList existingRecords))
{
IReadOnlyList filteredRecords = FilterDisabledRecords(type, existingRecords);
if (filteredRecords.Count > 0)
{
if (dnssecOk)
return AppendRRSigTo(filteredRecords);
return filteredRecords;
}
}
}
break;
case DnsResourceRecordType.ANY:
//collect every RRSet except FWD and APP; no RRSIG records are appended on this path
List records = new List(_entries.Count * 2);
foreach (KeyValuePair> entry in _entries)
{
switch (entry.Key)
{
case DnsResourceRecordType.FWD:
case DnsResourceRecordType.APP:
//skip records
continue;
default:
records.AddRange(entry.Value);
break;
}
}
return FilterDisabledRecords(type, records);
default:
{
//check for CNAME
if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList existingCNAMERecords))
{
IReadOnlyList filteredRecords = FilterDisabledRecords(type, existingCNAMERecords);
if (filteredRecords.Count > 0)
{
if (dnssecOk)
return AppendRRSigTo(filteredRecords);
return filteredRecords;
}
}
//check for exact type
if (_entries.TryGetValue(type, out IReadOnlyList existingRecords))
{
IReadOnlyList filteredRecords = FilterDisabledRecords(type, existingRecords);
if (filteredRecords.Count > 0)
{
if (dnssecOk)
return AppendRRSigTo(filteredRecords);
return filteredRecords;
}
}
//check special processing
switch (type)
{
case DnsResourceRecordType.A:
case DnsResourceRecordType.AAAA:
//check for ANAME
//the filtered ANAME records are returned even when the list is empty, and without RRSIG records
if (_entries.TryGetValue(DnsResourceRecordType.ANAME, out IReadOnlyList anameRecords))
return FilterDisabledRecords(type, anameRecords);
break;
}
}
break;
}
return Array.Empty();
}
/// <summary>
/// Returns the stored RRSet for <paramref name="type"/> without any filtering of disabled records.
/// </summary>
/// <param name="type">The record type to look up.</param>
/// <returns>The stored list, or an empty list when the zone has no entry for the type.</returns>
public IReadOnlyList<DnsResourceRecord> GetRecords(DnsResourceRecordType type)
{
    bool hasEntry = _entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> storedRecords);

    return hasEntry ? storedRecords : Array.Empty<DnsResourceRecord>();
}
/// <summary>
/// Exposes the zone's entries as a read-only dictionary keyed by record type.
/// The returned object is the zone's own entries collection, not a copy.
/// </summary>
public IReadOnlyDictionary<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> GetAllRecords()
    => _entries;
/// <summary>
/// Tells whether the zone has at least one NS record that is not disabled.
/// </summary>
/// <returns>True when an enabled NS record exists; false when there is no NS entry or all NS records are disabled.</returns>
public override bool ContainsNameServerRecords()
{
    if (_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> nameServers))
    {
        foreach (DnsResourceRecord nameServer in nameServers)
        {
            //the first enabled NS record settles it
            if (!nameServer.GetAuthRecordInfo().Disabled)
                return true;
        }
    }

    return false;
}
#endregion
#region properties
/// <summary>
/// Gets or sets whether the whole zone is disabled.
/// </summary>
public virtual bool Disabled
{
    get => _disabled;
    set => _disabled = value;
}
/// <summary>
/// Gets whether the zone is active; in this base class that is simply the inverse of the disabled flag.
/// </summary>
public virtual bool IsActive => !_disabled;
#endregion
}
}