/*
Technitium DNS Server
Copyright (C) 2024 Shreyas Zare (shreyas@technitium.com)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
using DnsServerCore.ApplicationCommon;
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Text.Json;
using System.Threading.Tasks;
using TechnitiumLibrary;
using TechnitiumLibrary.Net.Dns;
using TechnitiumLibrary.Net.Dns.ResourceRecords;
namespace NxDomainOverride
{
    /// <summary>
    /// DNS post-processor app that rewrites NXDOMAIN responses for configured domain names into
    /// NOERROR responses carrying the A/AAAA addresses of the address sets mapped to that name.
    /// </summary>
    public sealed class App : IDnsApplication, IDnsPostProcessor
    {
        #region variables

        bool _enableOverride;
        uint _defaultTtl;

        //lower-cased domain name (exact name, "*.parent" wildcard or "*" catch-all) => address set names
        Dictionary<string, string[]> _domainSetMap;

        //address set name => address set
        Dictionary<string, Set> _sets;

        #endregion

        #region IDisposable

        public void Dispose()
        {
            //do nothing
        }

        #endregion

        #region private

        /// <summary>
        /// Returns <paramref name="domain"/> with its left-most label removed, or null when the name
        /// contains no dot (the root zone is never returned as a parent).
        /// </summary>
        private static string GetParentZone(string domain)
        {
            int i = domain.IndexOf('.');
            if (i > -1)
                return domain.Substring(i + 1);

            //don't return root zone
            return null;
        }

        /// <summary>
        /// Looks up the address set names mapped to <paramref name="domain"/> (case-insensitively).
        /// Walking up from the full name, each level tries the exact name first and then the
        /// wildcard entry "*.parent" for its parent; the catch-all "*" entry is tried last.
        /// </summary>
        /// <param name="domain">The domain name to look up.</param>
        /// <param name="setNames">The mapped set names when found; otherwise null.</param>
        /// <returns>True when a mapping entry was found.</returns>
        private bool TryGetMappedSets(string domain, out string[] setNames)
        {
            string current = domain.ToLowerInvariant();

            while (true)
            {
                if (_domainSetMap.TryGetValue(current, out setNames))
                    return true;

                string parent = GetParentZone(current);
                if (parent is null)
                    return _domainSetMap.TryGetValue("*", out setNames);

                if (_domainSetMap.TryGetValue("*." + parent, out setNames))
                    return true;

                current = parent;
            }
        }

        /// <summary>
        /// Copies <paramref name="response"/> with RCODE changed to NOERROR, keeping its answer,
        /// authority and additional sections unchanged (i.e. a NO DATA response).
        /// </summary>
        private static DnsDatagram CreateNoDataResponse(DnsDatagram response)
        {
            return new DnsDatagram(response.Identifier, true, response.OPCODE, response.AuthoritativeAnswer, response.Truncation, response.RecursionDesired, response.RecursionAvailable, response.AuthenticData, response.CheckingDisabled, DnsResponseCode.NoError, response.Question, response.Answer, response.Authority, response.Additional) { Tag = response.Tag };
        }

        /// <summary>
        /// Returns true when <paramref name="rdata"/> is address data of the family requested by
        /// <paramref name="type"/> (IPv4 for A, IPv6 for AAAA); false for any other type.
        /// </summary>
        private static bool IsMatchingAddress(DnsResourceRecordType type, DnsResourceRecordData rdata)
        {
            return type switch
            {
                DnsResourceRecordType.A => rdata is DnsARecordData,
                DnsResourceRecordType.AAAA => rdata is DnsAAAARecordData,
                _ => false
            };
        }

        #endregion

        #region public

        /// <summary>
        /// Parses the app's JSON config: "enableOverride" (default true), "defaultTtl" (default 300),
        /// the "domainSetMap" object (domain name => array of set names) and the "sets" array.
        /// </summary>
        public Task InitializeAsync(IDnsServer dnsServer, string config)
        {
            using JsonDocument jsonDocument = JsonDocument.Parse(config);
            JsonElement jsonConfig = jsonDocument.RootElement;

            _enableOverride = jsonConfig.GetPropertyValue("enableOverride", true);
            _defaultTtl = jsonConfig.GetPropertyValue("defaultTtl", 300u);

            _domainSetMap = jsonConfig.ReadObjectAsMap("domainSetMap", delegate (string domain, JsonElement jsonSets)
            {
                string[] sets = jsonSets.GetArray();

                //keys are lower-cased so lookups in TryGetMappedSets are case-insensitive
                return new Tuple<string, string[]>(domain.ToLowerInvariant(), sets);
            });

            _sets = jsonConfig.ReadArrayAsMap("sets", delegate (JsonElement jsonSet)
            {
                Set set = new Set(jsonSet);
                return new Tuple<string, Set>(set.Name, set);
            });

            return Task.CompletedTask;
        }

        /// <summary>
        /// Rewrites an NXDOMAIN response for a configured domain name. A/AAAA queries get the
        /// configured addresses of the matching family (NO DATA when none match); other query types
        /// get a NO DATA response. Responses for names without a mapping are returned unchanged.
        /// </summary>
        public Task<DnsDatagram> PostProcessAsync(DnsDatagram request, IPEndPoint remoteEP, DnsTransportProtocol protocol, DnsDatagram response)
        {
            if (!_enableOverride)
                return Task.FromResult(response);

            //responses with the DO bit set are never modified
            if (response.DnssecOk)
                return Task.FromResult(response);

            if (response.OPCODE != DnsOpcode.StandardQuery)
                return Task.FromResult(response);

            if (response.RCODE != DnsResponseCode.NxDomain)
                return Task.FromResult(response);

            //nothing to match against without a question
            if (request.Question.Count < 1)
                return Task.FromResult(response);

            DnsQuestionRecord question = request.Question[0];

            //the last CNAME target in the answer section is taken as the name that does not exist
            //(assumes the CNAME chain is listed in order)
            string nxDomain = question.Name;

            foreach (DnsResourceRecord record in response.Answer)
            {
                if ((record.Type == DnsResourceRecordType.CNAME) && (record.RDATA is DnsCNAMERecordData cname))
                    nxDomain = cname.Domain;
            }

            //only configured domain names are overridden; all others keep their NXDOMAIN response
            if (!TryGetMappedSets(nxDomain, out string[] setNames))
                return Task.FromResult(response);

            //the configured name is treated as existing, but only A/AAAA data can be synthesized
            if ((question.Type != DnsResourceRecordType.A) && (question.Type != DnsResourceRecordType.AAAA))
                return Task.FromResult(CreateNoDataResponse(response));

            List<DnsResourceRecord> newAnswer = new List<DnsResourceRecord>(response.Answer);
            int initialCount = newAnswer.Count;

            foreach (string setName in setNames)
            {
                //set names that are not defined in the config are ignored
                if (!_sets.TryGetValue(setName, out Set set))
                    continue;

                foreach (DnsResourceRecordData rdata in set.RecordDataAddresses)
                {
                    if (IsMatchingAddress(question.Type, rdata))
                        newAnswer.Add(new DnsResourceRecord(nxDomain, question.Type, DnsClass.IN, _defaultTtl, rdata));
                }
            }

            //no address of the requested family was configured: answer NO DATA and keep the authority (SOA) section
            if (newAnswer.Count == initialCount)
                return Task.FromResult(CreateNoDataResponse(response));

            return Task.FromResult(new DnsDatagram(response.Identifier, true, response.OPCODE, response.AuthoritativeAnswer, response.Truncation, response.RecursionDesired, response.RecursionAvailable, response.AuthenticData, response.CheckingDisabled, DnsResponseCode.NoError, response.Question, newAnswer) { Tag = response.Tag });
        }

        #endregion

        #region properties

        public string Description
        { get { return "Overrides NX Domain response with custom A/AAAA record response for configured domain names."; } }

        #endregion

        /// <summary>
        /// A named set of IPv4/IPv6 addresses read from one entry of the "sets" config array.
        /// </summary>
        sealed class Set
        {
            #region variables

            readonly string _name;
            readonly DnsResourceRecordData[] _rdataAddresses;

            #endregion

            #region constructor

            /// <summary>
            /// Reads "name" and the "addresses" string array. Addresses are parsed with
            /// IPAddress.Parse, which throws on malformed input. Address families other than
            /// IPv4/IPv6 throw NotSupportedException.
            /// </summary>
            public Set(JsonElement jsonSet)
            {
                _name = jsonSet.GetProperty("name").GetString();
                _rdataAddresses = jsonSet.ReadArray<DnsResourceRecordData>("addresses", delegate (string item)
                {
                    IPAddress address = IPAddress.Parse(item);

                    switch (address.AddressFamily)
                    {
                        case AddressFamily.InterNetwork:
                            return new DnsARecordData(address);

                        case AddressFamily.InterNetworkV6:
                            return new DnsAAAARecordData(address);

                        default:
                            throw new NotSupportedException("Address family not supported: " + address.AddressFamily.ToString());
                    }
                });
            }

            #endregion

            #region properties

            public string Name
            { get { return _name; } }

            public DnsResourceRecordData[] RecordDataAddresses
            { get { return _rdataAddresses; } }

            #endregion
        }
    }
}