/*
Technitium DNS Server
Copyright (C) 2023 Shreyas Zare (shreyas@technitium.com)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
using DnsServerCore.ApplicationCommon;
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using TechnitiumLibrary;
using TechnitiumLibrary.Net;
using TechnitiumLibrary.Net.Dns;
using TechnitiumLibrary.Net.Dns.ResourceRecords;
using TechnitiumLibrary.Net.Proxy;
namespace AdvancedForwarding
{
public class App : IDnsApplication, IDnsAuthoritativeRequestHandler
{
#region variables
//host DNS server handed to InitializeAsync()
IDnsServer _dnsServer;
//global "enableForwarding" switch (defaults to true when absent)
bool _enableForwarding;

//named proxy servers from the "proxyServers" array; null when the array is absent
Dictionary<string, ConfigProxyServer> _configProxyServers;
//named forwarder sets from the "forwarders" array; null when the array is absent
Dictionary<string, ConfigForwarder> _configForwarders;

//client network -> group name, from the "networkGroupMap" object
IReadOnlyDictionary<NetworkAddress, string> _networkGroupMap;
//group name -> group, from the "groups" array
IReadOnlyDictionary<string, Group> _groups;
#endregion
#region IDisposable
/// <summary>
/// Disposes every configured group (each group in turn disposes its AdGuard upstreams).
/// Safe to call when initialization never produced any groups.
/// </summary>
public void Dispose()
{
    if (_groups is null)
        return;

    foreach (Group group in _groups.Values)
        group.Dispose();
}
#endregion
#region private
/// <summary>
/// Rebuilds a list of forwarder records. Each record keeps its protocol and forwarder address,
/// but gets the given DNSSEC validation flag and proxy settings applied.
/// </summary>
/// <param name="forwarderRecords">Existing records to rebuild.</param>
/// <param name="dnssecValidation">DNSSEC validation flag for the new records.</param>
/// <param name="configProxyServer">Proxy to attach, or null for no proxy.</param>
/// <returns>A new list, in the same order as the input.</returns>
private static IReadOnlyList<DnsForwarderRecordData> GetUpdatedForwarderRecords(IReadOnlyList<DnsForwarderRecordData> forwarderRecords, bool dnssecValidation, ConfigProxyServer configProxyServer)
{
    DnsForwarderRecordData[] updated = new DnsForwarderRecordData[forwarderRecords.Count];

    for (int index = 0; index < updated.Length; index++)
    {
        DnsForwarderRecordData existing = forwarderRecords[index];
        updated[index] = GetForwarderRecord(existing.Protocol, existing.Forwarder, dnssecValidation, configProxyServer);
    }

    return updated;
}
/// <summary>
/// Builds a forwarder record from a parsed name server address. The address's protocol and its
/// string form are used as the record's protocol and forwarder.
/// </summary>
private static DnsForwarderRecordData GetForwarderRecord(NameServerAddress forwarder, bool dnssecValidation, ConfigProxyServer configProxyServer)
    => GetForwarderRecord(forwarder.Protocol, forwarder.ToString(), dnssecValidation, configProxyServer);
/// <summary>
/// Creates a FWD record for the given protocol and forwarder address. When a proxy is supplied,
/// its type, address, port and credentials are copied into the record; otherwise the record uses
/// <see cref="NetProxyType.None"/> with no proxy details.
/// </summary>
private static DnsForwarderRecordData GetForwarderRecord(DnsTransportProtocol protocol, string forwarder, bool dnssecValidation, ConfigProxyServer configProxyServer)
{
    if (configProxyServer is not null)
        return new DnsForwarderRecordData(protocol, forwarder, dnssecValidation, configProxyServer.Type, configProxyServer.ProxyAddress, configProxyServer.ProxyPort, configProxyServer.ProxyUsername, configProxyServer.ProxyPassword);

    return new DnsForwarderRecordData(protocol, forwarder, dnssecValidation, NetProxyType.None, null, 0, null, null);
}
/// <summary>
/// Parses one entry of the "groups" config array. A group already loaded under the same name is
/// reconfigured in place; otherwise a new group is created.
/// </summary>
/// <returns>The group keyed by its name.</returns>
private Tuple<string, Group> ReadGroup(JsonElement jsonGroup)
{
    string name = jsonGroup.GetProperty("name").GetString();

    if ((_groups is null) || !_groups.TryGetValue(name, out Group group))
        group = new Group(_dnsServer, _configProxyServers, _configForwarders, jsonGroup);
    else
        group.ReloadConfig(_configProxyServers, _configForwarders, jsonGroup);

    return new Tuple<string, Group>(group.Name, group);
}
#endregion
#region public
/// <summary>
/// Loads (or reloads) the app config. The config is made up of the global switch, proxy servers,
/// forwarders, the client-network-to-group map and the groups. Groups absent from the new config
/// are disposed; groups whose name still exists are reconfigured in place by ReadGroup.
/// </summary>
/// <param name="dnsServer">Host DNS server.</param>
/// <param name="config">App config as a JSON document.</param>
/// <exception cref="FormatException">
/// The network group map contains an invalid network, or the "groups" array is missing.
/// </exception>
public Task InitializeAsync(IDnsServer dnsServer, string config)
{
    _dnsServer = dnsServer;

    using JsonDocument jsonDocument = JsonDocument.Parse(config);
    JsonElement jsonConfig = jsonDocument.RootElement;

    _enableForwarding = jsonConfig.GetPropertyValue("enableForwarding", true);

    _configProxyServers = jsonConfig.TryReadArrayAsMap("proxyServers", delegate (JsonElement jsonProxy)
    {
        ConfigProxyServer proxyServer = new ConfigProxyServer(jsonProxy);
        return new Tuple<string, ConfigProxyServer>(proxyServer.Name, proxyServer);
    }, out Dictionary<string, ConfigProxyServer> configProxyServers) ? configProxyServers : null;

    //forwarders may reference proxies, so they are read after _configProxyServers is assigned
    _configForwarders = jsonConfig.TryReadArrayAsMap("forwarders", delegate (JsonElement jsonForwarder)
    {
        ConfigForwarder forwarder = new ConfigForwarder(jsonForwarder, _configProxyServers);
        return new Tuple<string, ConfigForwarder>(forwarder.Name, forwarder);
    }, out Dictionary<string, ConfigForwarder> configForwarders) ? configForwarders : null;

    _networkGroupMap = jsonConfig.ReadObjectAsMap("networkGroupMap", delegate (string network, JsonElement jsonGroup)
    {
        if (!NetworkAddress.TryParse(network, out NetworkAddress networkAddress))
            throw new FormatException("Network group map contains an invalid network address: " + network);

        return new Tuple<NetworkAddress, string>(networkAddress, jsonGroup.GetString());
    });

    if (!jsonConfig.TryReadArrayAsMap("groups", ReadGroup, out Dictionary<string, Group> groups))
        throw new FormatException("Groups array was not defined.");

    //dispose groups that were dropped from the config; reused groups are kept alive
    if (_groups is not null)
    {
        foreach (KeyValuePair<string, Group> existing in _groups)
        {
            if (!groups.ContainsKey(existing.Key))
                existing.Value.Dispose();
        }
    }

    _groups = groups;

    return Task.CompletedTask;
}
/// <summary>
/// Answers a recursion-desired query with FWD records in the authority section (TTL 0, one record
/// per forwarder). The client's address is matched against the network group map by longest
/// prefix. The result is null (no answer from this app) in these cases:
/// - forwarding is disabled or recursion was not desired;
/// - the request has no question;
/// - the client maps to no enabled group;
/// - the group has no rule for the queried name.
/// </summary>
/// <param name="request">Incoming DNS request.</param>
/// <param name="remoteEP">Client endpoint; its address selects the group.</param>
/// <param name="protocol">Transport protocol of the request (unused here).</param>
/// <param name="isRecursionAllowed">Whether recursion is allowed for the client (unused here).</param>
public Task<DnsDatagram> ProcessRequestAsync(DnsDatagram request, IPEndPoint remoteEP, DnsTransportProtocol protocol, bool isRecursionAllowed)
{
    if (!_enableForwarding || !request.RecursionDesired)
        return Task.FromResult<DnsDatagram>(null);

    //a request with no question cannot be matched (Question[0] below would throw); the maps are
    //null when initialization never completed successfully
    if ((request.Question.Count == 0) || (_networkGroupMap is null) || (_groups is null))
        return Task.FromResult<DnsDatagram>(null);

    IPAddress remoteIP = remoteEP.Address;

    //longest-prefix match of the client address against the configured networks
    NetworkAddress network = null;
    string groupName = null;

    foreach (KeyValuePair<NetworkAddress, string> entry in _networkGroupMap)
    {
        if (entry.Key.Contains(remoteIP) && ((network is null) || (entry.Key.PrefixLength > network.PrefixLength)))
        {
            network = entry.Key;
            groupName = entry.Value;
        }
    }

    if ((groupName is null) || !_groups.TryGetValue(groupName, out Group group) || !group.EnableForwarding)
        return Task.FromResult<DnsDatagram>(null);

    DnsQuestionRecord question = request.Question[0];
    string qname = question.Name;

    if (!group.TryGetForwarderRecords(qname, out IReadOnlyList<DnsForwarderRecordData> forwarderRecords))
        return Task.FromResult<DnsDatagram>(null);

    //attach the matched client network to the request as a shadow EDNS Client Subnet option
    request.SetShadowEDnsClientSubnetOption(network, true);

    DnsResourceRecord[] authority = new DnsResourceRecord[forwarderRecords.Count];
    for (int i = 0; i < forwarderRecords.Count; i++)
        authority[i] = new DnsResourceRecord(qname, DnsResourceRecordType.FWD, DnsClass.IN, 0, forwarderRecords[i]);

    return Task.FromResult(new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.NoError, request.Question, null, authority));
}
#endregion
#region properties
//human-readable summary of what this app does
public string Description => "Performs bulk conditional forwarding for configured domain names and AdGuard Upstream config files.";
#endregion
/// <summary>
/// A named set of forwarding rules that clients are assigned to through the network group map.
/// A group holds static domain forwardings plus any number of AdGuard upstream config files.
/// </summary>
class Group : IDisposable
{
    #region variables

    readonly IDnsServer _dnsServer;

    //app-level proxy/forwarder definitions current as of the last ReloadConfig()
    Dictionary<string, ConfigProxyServer> _configProxyServers;
    Dictionary<string, ConfigForwarder> _configForwarders;

    readonly string _name;
    bool _enableForwarding;

    //static rules from the "forwardings" array; null when absent
    IReadOnlyList<Forwarding> _forwardings;
    //config file name -> upstream, from the "adguardUpstreams" array; null when absent
    IReadOnlyDictionary<string, AdGuardUpstream> _adguardUpstreams;

    #endregion

    #region constructor

    public Group(IDnsServer dnsServer, Dictionary<string, ConfigProxyServer> configProxyServers, Dictionary<string, ConfigForwarder> configForwarders, JsonElement jsonGroup)
    {
        _dnsServer = dnsServer;
        _name = jsonGroup.GetProperty("name").GetString();

        ReloadConfig(configProxyServers, configForwarders, jsonGroup);
    }

    #endregion

    #region IDisposable

    /// <summary>Disposes all AdGuard upstreams of this group and forgets them.</summary>
    public void Dispose()
    {
        if (_adguardUpstreams is null)
            return;

        foreach (AdGuardUpstream upstream in _adguardUpstreams.Values)
            upstream.Dispose();

        _adguardUpstreams = null;
    }

    #endregion

    #region private

    //parses one "adguardUpstreams" entry, reusing (and reconfiguring) an upstream already loaded
    //for the same config file instead of creating a new one
    private Tuple<string, AdGuardUpstream> ReadAdGuardUpstream(JsonElement jsonAdguardUpstream)
    {
        string name = jsonAdguardUpstream.GetProperty("configFile").GetString();

        if ((_adguardUpstreams is null) || !_adguardUpstreams.TryGetValue(name, out AdGuardUpstream adGuardUpstream))
            adGuardUpstream = new AdGuardUpstream(_dnsServer, _configProxyServers, jsonAdguardUpstream);
        else
            adGuardUpstream.ReloadConfig(_configProxyServers, jsonAdguardUpstream);

        return new Tuple<string, AdGuardUpstream>(adGuardUpstream.Name, adGuardUpstream);
    }

    #endregion

    #region public

    /// <summary>
    /// Applies a (possibly changed) group config. Upstreams whose config file is no longer listed
    /// are disposed.
    /// </summary>
    public void ReloadConfig(Dictionary<string, ConfigProxyServer> configProxyServers, Dictionary<string, ConfigForwarder> configForwarders, JsonElement jsonGroup)
    {
        _configProxyServers = configProxyServers;
        _configForwarders = configForwarders;

        _enableForwarding = jsonGroup.GetPropertyValue("enableForwarding", true);

        _forwardings = jsonGroup.TryReadArray("forwardings", delegate (JsonElement jsonForwarding) { return new Forwarding(jsonForwarding, _configForwarders); }, out Forwarding[] forwardings) ? forwardings : null;

        bool hasUpstreams = jsonGroup.TryReadArrayAsMap("adguardUpstreams", ReadAdGuardUpstream, out Dictionary<string, AdGuardUpstream> adguardUpstreams);

        if (_adguardUpstreams is not null)
        {
            //dispose every upstream that is not carried over into the new map
            foreach (KeyValuePair<string, AdGuardUpstream> existing in _adguardUpstreams)
            {
                if (!hasUpstreams || !adguardUpstreams.ContainsKey(existing.Key))
                    existing.Value.Dispose();
            }
        }

        _adguardUpstreams = hasUpstreams ? adguardUpstreams : null;
    }

    /// <summary>
    /// Finds forwarder records for a domain (compared lower-cased). Static forwardings are checked
    /// first, then each AdGuard upstream in map order; the first hit wins.
    /// </summary>
    public bool TryGetForwarderRecords(string domain, out IReadOnlyList<DnsForwarderRecordData> forwarderRecords)
    {
        domain = domain.ToLower();

        if ((_forwardings is not null) && (_forwardings.Count > 0) && Forwarding.TryGetForwarderRecords(domain, _forwardings, out forwarderRecords))
            return true;

        if (_adguardUpstreams is not null)
        {
            foreach (AdGuardUpstream upstream in _adguardUpstreams.Values)
            {
                if (upstream.TryGetForwarderRecords(domain, out forwarderRecords))
                    return true;
            }
        }

        forwarderRecords = null;
        return false;
    }

    #endregion

    #region properties

    public string Name => _name;

    public bool EnableForwarding => _enableForwarding;

    #endregion
}
/// <summary>
/// One forwarding rule: a set of domain entries resolved through a fixed list of forwarder records.
/// Each entry takes one of three forms:
/// - an exact name, which also matches its subdomains;
/// - a "*.parent" wildcard;
/// - "*", which matches any name.
/// </summary>
class Forwarding
{
    #region variables

    //replaced wholesale by UpdateForwarderRecords()
    IReadOnlyList<DnsForwarderRecordData> _forwarderRecords;

    //lower-cased domain entries; used as a set (values are always null)
    readonly IReadOnlyDictionary<string, object> _domainMap;

    #endregion

    #region constructor

    /// <summary>
    /// Creates a rule from a "forwardings" config entry. Its "forwarders" array names entries of
    /// <paramref name="configForwarders"/>, whose records are concatenated; its "domains" array
    /// lists the domain entries.
    /// </summary>
    /// <exception cref="FormatException">A referenced forwarder is not defined.</exception>
    public Forwarding(JsonElement jsonForwarding, Dictionary<string, ConfigForwarder> configForwarders)
    {
        JsonElement jsonForwarders = jsonForwarding.GetProperty("forwarders");

        List<DnsForwarderRecordData> forwarderRecords = new List<DnsForwarderRecordData>();

        foreach (JsonElement jsonForwarder in jsonForwarders.EnumerateArray())
        {
            string forwarderName = jsonForwarder.GetString();

            if ((configForwarders is null) || !configForwarders.TryGetValue(forwarderName, out ConfigForwarder configForwarder))
                throw new FormatException("Forwarder was not defined: " + forwarderName);

            forwarderRecords.AddRange(configForwarder.ForwarderRecords);
        }

        _forwarderRecords = forwarderRecords;

        _domainMap = jsonForwarding.ReadArrayAsMap("domains", delegate (JsonElement jsonDomain)
        {
            return new Tuple<string, object>(jsonDomain.GetString().ToLower(), null);
        });
    }

    /// <summary>Creates a rule that resolves <paramref name="domains"/> through a single forwarder.</summary>
    public Forwarding(IReadOnlyList<string> domains, NameServerAddress forwarder, bool dnssecValidation, ConfigProxyServer proxy)
        : this(new DnsForwarderRecordData[] { GetForwarderRecord(forwarder, dnssecValidation, proxy) }, domains)
    { }

    /// <summary>
    /// Creates a rule from ready-made forwarder records. Domains that fail
    /// <see cref="DnsClient.IsDomainNameValid"/> are skipped, and duplicates are ignored.
    /// </summary>
    public Forwarding(IReadOnlyList<DnsForwarderRecordData> forwarderRecords, IReadOnlyList<string> domains)
    {
        _forwarderRecords = forwarderRecords;

        Dictionary<string, object> domainMap = new Dictionary<string, object>(domains.Count);

        foreach (string domain in domains)
        {
            if (DnsClient.IsDomainNameValid(domain))
                domainMap.TryAdd(domain.ToLower(), null);
        }

        _domainMap = domainMap;
    }

    #endregion

    #region static

    /// <summary>
    /// Finds the forwarder records of the most specific matching rule among
    /// <paramref name="forwardings"/>. Rules that matched the same domain entry have their records
    /// merged. Among different matched entries, the longest one wins; on a length tie, a non-wildcard
    /// entry beats a wildcard ("*" or "*.parent").
    /// </summary>
    /// <param name="domain">Lower-cased domain name to look up.</param>
    public static bool TryGetForwarderRecords(string domain, IReadOnlyList<Forwarding> forwardings, out IReadOnlyList<DnsForwarderRecordData> forwarderRecords)
    {
        if (forwardings.Count == 1)
            return forwardings[0].TryGetForwarderRecords(domain, out forwarderRecords, out _);

        //matched domain entry -> merged forwarder records of every rule that matched it
        Dictionary<string, List<DnsForwarderRecordData>> fwdMap = new Dictionary<string, List<DnsForwarderRecordData>>(forwardings.Count);

        foreach (Forwarding forwarding in forwardings)
        {
            if (!forwarding.TryGetForwarderRecords(domain, out IReadOnlyList<DnsForwarderRecordData> fwdRecords, out string matchedDomain))
                continue;

            if (fwdMap.TryGetValue(matchedDomain, out List<DnsForwarderRecordData> fwdRecordsList))
                fwdRecordsList.AddRange(fwdRecords);
            else
                fwdMap.Add(matchedDomain, new List<DnsForwarderRecordData>(fwdRecords));
        }

        if (fwdMap.Count == 0)
        {
            forwarderRecords = null;
            return false;
        }

        forwarderRecords = null;
        string bestMatchedDomain = null;

        foreach (KeyValuePair<string, List<DnsForwarderRecordData>> fwdEntry in fwdMap)
        {
            //StartsWith('*') covers both "*.parent" and the root "*" entry, so an exact name always
            //wins a tie against a wildcard of the same length regardless of iteration order
            if ((bestMatchedDomain is null) || (fwdEntry.Key.Length > bestMatchedDomain.Length) || ((fwdEntry.Key.Length == bestMatchedDomain.Length) && bestMatchedDomain.StartsWith('*')))
            {
                bestMatchedDomain = fwdEntry.Key;
                forwarderRecords = fwdEntry.Value;
            }
        }

        return true;
    }

    /// <summary>
    /// Returns true when <paramref name="domain"/> equals (case-insensitively) the host name of any
    /// forwarder used by <paramref name="forwardings"/>.
    /// </summary>
    public static bool IsForwarderDomain(string domain, IReadOnlyList<Forwarding> forwardings)
    {
        foreach (Forwarding forwarding in forwardings)
        {
            if (IsForwarderDomain(domain, forwarding._forwarderRecords))
                return true;
        }

        return false;
    }

    /// <summary>
    /// Returns true when <paramref name="domain"/> equals (case-insensitively) the host name of any
    /// of the forwarder records. Presumably used so a forwarder's own host name is not resolved
    /// through that forwarder — confirm against callers.
    /// </summary>
    public static bool IsForwarderDomain(string domain, IReadOnlyList<DnsForwarderRecordData> forwarderRecords)
    {
        foreach (DnsForwarderRecordData forwarderRecord in forwarderRecords)
        {
            if (domain.Equals(forwarderRecord.NameServer.Host, StringComparison.OrdinalIgnoreCase))
                return true;
        }

        return false;
    }

    #endregion

    #region private

    //strips the left-most label; returns null for a single-label name (the root zone is not returned)
    private static string GetParentZone(string domain)
    {
        int i = domain.IndexOf('.');
        if (i > -1)
            return domain.Substring(i + 1);

        //dont return root zone
        return null;
    }

    //walks from the name towards the root. At each level it checks the name itself, then
    //"*.parent". Once no parent is left, it checks the catch-all "*". Reports the first
    //(most specific) entry found.
    private bool IsDomainMatching(string domain, out string matchedDomain)
    {
        string parent;

        do
        {
            if (_domainMap.TryGetValue(domain, out _))
            {
                matchedDomain = domain;
                return true;
            }

            parent = GetParentZone(domain);
            if (parent is null)
            {
                if (_domainMap.TryGetValue("*", out _))
                {
                    matchedDomain = "*";
                    return true;
                }

                break;
            }

            domain = "*." + parent;
            if (_domainMap.TryGetValue(domain, out _))
            {
                matchedDomain = domain;
                return true;
            }

            domain = parent;
        }
        while (true);

        matchedDomain = null;
        return false;
    }

    //returns this rule's forwarder records when any of its domain entries matches
    private bool TryGetForwarderRecords(string domain, out IReadOnlyList<DnsForwarderRecordData> forwarderRecords, out string matchedDomain)
    {
        if (IsDomainMatching(domain, out matchedDomain))
        {
            forwarderRecords = _forwarderRecords;
            return true;
        }

        forwarderRecords = null;
        return false;
    }

    #endregion

    #region public

    /// <summary>
    /// Rebuilds this rule's forwarder records with new DNSSEC/proxy settings, keeping each record's
    /// protocol and address.
    /// </summary>
    public void UpdateForwarderRecords(bool dnssecValidation, ConfigProxyServer proxy)
    {
        _forwarderRecords = GetUpdatedForwarderRecords(_forwarderRecords, dnssecValidation, proxy);
    }

    #endregion
}
/// <summary>
/// Forwarding rules loaded from an AdGuard Upstreams config file. The file format is:
/// - "[/domain1/domain2/]upstream" maps domains to an upstream, where "#" means the default upstreams;
/// - any other non-empty, non-comment line is a default upstream.
/// A one-shot timer re-reads the file when its last-write time advances, re-arming itself every
/// AUTO_RELOAD_TIMER_INTERVAL ms.
/// </summary>
class AdGuardUpstream : IDisposable
{
    #region variables

    readonly IDnsServer _dnsServer;
    readonly string _name;

    ConfigProxyServer _configProxyServer;
    bool _dnssecValidation;

    IReadOnlyList<DnsForwarderRecordData> _defaultForwarderRecords;
    IReadOnlyList<Forwarding> _forwardings;

    readonly string _configFile;
    //last-write time of the config file as of the last successful load (default until then)
    DateTime _configFileLastModified;

    Timer _autoReloadTimer;
    const int AUTO_RELOAD_TIMER_INTERVAL = 60000;

    #endregion

    #region constructor

    /// <exception cref="FormatException">The configured proxy is not defined.</exception>
    public AdGuardUpstream(IDnsServer dnsServer, Dictionary<string, ConfigProxyServer> configProxyServers, JsonElement jsonAdguardUpstream)
    {
        _dnsServer = dnsServer;

        _name = jsonAdguardUpstream.GetProperty("configFile").GetString();

        //relative paths are resolved against the app folder
        _configFile = _name;

        if (!Path.IsPathRooted(_configFile))
            _configFile = Path.Combine(_dnsServer.ApplicationFolder, _configFile);

        _autoReloadTimer = new Timer(delegate (object state)
        {
            try
            {
                DateTime configFileLastModified = File.GetLastWriteTimeUtc(_configFile);
                if (configFileLastModified > _configFileLastModified)
                {
                    ReloadUpstreamsFile();

                    //force GC collection to remove old cache data from memory quickly
                    GC.Collect();
                }
            }
            catch (Exception ex)
            {
                _dnsServer.WriteLog(ex);
            }
            finally
            {
                RescheduleAutoReloadTimer();
            }
        });

        ReloadConfig(configProxyServers, jsonAdguardUpstream);
    }

    #endregion

    #region IDisposable

    public void Dispose()
    {
        if (_autoReloadTimer is not null)
        {
            _autoReloadTimer.Dispose();
            _autoReloadTimer = null;
        }
    }

    #endregion

    #region private

    //re-arms the one-shot reload timer. Dispose() can run on another thread between the null check
    //and Change(), and Change() on a disposed Timer throws ObjectDisposedException; letting that
    //escape a timer callback would be an unhandled exception on a thread-pool thread.
    private void RescheduleAutoReloadTimer()
    {
        Timer timer = _autoReloadTimer;
        if (timer is null)
            return;

        try
        {
            timer.Change(AUTO_RELOAD_TIMER_INTERVAL, Timeout.Infinite);
        }
        catch (ObjectDisposedException)
        {
            //disposed while the callback was running; nothing left to reschedule
        }
    }

    //reads and parses the config file. On any error the previous records stay in place and the
    //failure is logged.
    private void ReloadUpstreamsFile()
    {
        try
        {
            _dnsServer.WriteLog("The app is reading AdGuard Upstreams config file: " + _configFile);

            List<DnsForwarderRecordData> defaultForwarderRecords = new List<DnsForwarderRecordData>();
            List<Forwarding> forwardings = new List<Forwarding>();
            DateTime configFileLastModified;

            using (FileStream fS = new FileStream(_configFile, FileMode.Open, FileAccess.Read))
            using (StreamReader sR = new StreamReader(fS, true))
            {
                string line;

                while ((line = sR.ReadLine()) is not null)
                {
                    //trim both ends so stray whitespace around an upstream does not break parsing
                    line = line.Trim();

                    if (line.Length == 0)
                        continue; //skip empty line

                    if (line.StartsWith('#'))
                        continue; //skip comment line

                    if (line.StartsWith('['))
                    {
                        int i = line.LastIndexOf(']');
                        if (i < 0)
                            throw new FormatException("Invalid AdGuard Upstreams config file format: missing ']' bracket.");

                        string[] domains = line.Substring(1, i - 1).Split('/', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
                        string forwarder = line.Substring(i + 1).Trim();

                        if (forwarder == "#")
                        {
                            if (defaultForwarderRecords.Count == 0)
                                throw new FormatException("Invalid AdGuard Upstreams config file format: missing default upstream servers.");

                            //the default list instance is shared, so defaults listed later are included too
                            forwardings.Add(new Forwarding(defaultForwarderRecords, domains));
                        }
                        else
                        {
                            forwardings.Add(new Forwarding(domains, NameServerAddress.Parse(forwarder), _dnssecValidation, _configProxyServer));
                        }
                    }
                    else
                    {
                        defaultForwarderRecords.Add(GetForwarderRecord(NameServerAddress.Parse(line), _dnssecValidation, _configProxyServer));
                    }
                }

                //taken from the still-open handle of the file that was just read
                configFileLastModified = File.GetLastWriteTimeUtc(fS.SafeFileHandle);
            }

            _defaultForwarderRecords = defaultForwarderRecords;
            _forwardings = forwardings;

            //publish the timestamp last: ReloadConfig() treats an up-to-date timestamp as "records loaded"
            _configFileLastModified = configFileLastModified;

            _dnsServer.WriteLog("The app has successfully loaded AdGuard Upstreams config file: " + _configFile);
        }
        catch (Exception ex)
        {
            _dnsServer.WriteLog("The app failed to read AdGuard Upstreams config file: " + _configFile + "\r\n" + ex.ToString());
        }
    }

    #endregion

    #region public

    /// <summary>
    /// Applies proxy/DNSSEC settings. If the config file changed since the last load, a full reload
    /// is scheduled on the timer. Otherwise the already loaded records are rebuilt with the new
    /// settings.
    /// </summary>
    /// <exception cref="FormatException">The configured proxy is not defined.</exception>
    public void ReloadConfig(Dictionary<string, ConfigProxyServer> configProxyServers, JsonElement jsonAdguardUpstream)
    {
        string proxyName = jsonAdguardUpstream.GetPropertyValue("proxy", null);
        _dnssecValidation = jsonAdguardUpstream.GetPropertyValue("dnssecValidation", true);

        ConfigProxyServer configProxyServer = null;

        if (!string.IsNullOrEmpty(proxyName) && ((configProxyServers is null) || !configProxyServers.TryGetValue(proxyName, out configProxyServer)))
            throw new FormatException("Proxy server was not defined: " + proxyName);

        _configProxyServer = configProxyServer;

        DateTime configFileLastModified = File.GetLastWriteTimeUtc(_configFile);
        if (configFileLastModified > _configFileLastModified)
        {
            //reload complete config file
            _autoReloadTimer.Change(0, Timeout.Infinite);
        }
        else
        {
            //update only forwarder records; skip lists that have not been loaded
            IReadOnlyList<DnsForwarderRecordData> defaultForwarderRecords = _defaultForwarderRecords;
            if (defaultForwarderRecords is not null)
                _defaultForwarderRecords = GetUpdatedForwarderRecords(defaultForwarderRecords, _dnssecValidation, _configProxyServer);

            IReadOnlyList<Forwarding> forwardings = _forwardings;
            if (forwardings is not null)
            {
                foreach (Forwarding forwarding in forwardings)
                    forwarding.UpdateForwarderRecords(_dnssecValidation, _configProxyServer);
            }
        }
    }

    /// <summary>
    /// Finds forwarder records for a lower-cased domain. Domain-specific rules are tried first,
    /// then the default upstreams. A name equal to one of the forwarders' own host names yields no
    /// records.
    /// </summary>
    public bool TryGetForwarderRecords(string domain, out IReadOnlyList<DnsForwarderRecordData> forwarderRecords)
    {
        if ((_forwardings is not null) && (_forwardings.Count > 0))
        {
            if (Forwarding.IsForwarderDomain(domain, _forwardings))
            {
                forwarderRecords = null;
                return false;
            }

            if (Forwarding.TryGetForwarderRecords(domain, _forwardings, out forwarderRecords))
                return true;
        }

        if ((_defaultForwarderRecords is not null) && (_defaultForwarderRecords.Count > 0))
        {
            if (Forwarding.IsForwarderDomain(domain, _defaultForwarderRecords))
            {
                forwarderRecords = null;
                return false;
            }

            forwarderRecords = _defaultForwarderRecords;
            return true;
        }

        forwarderRecords = null;
        return false;
    }

    #endregion

    #region property

    public string Name
    { get { return _name; } }

    #endregion
}
/// <summary>
/// A named proxy server definition read from one entry of the "proxyServers" config array.
/// "type" defaults to Http; username and password are optional.
/// </summary>
class ConfigProxyServer
{
    #region constructor

    public ConfigProxyServer(JsonElement jsonProxy)
    {
        Name = jsonProxy.GetProperty("name").GetString();
        Type = jsonProxy.GetPropertyEnumValue("type", NetProxyType.Http);
        ProxyAddress = jsonProxy.GetProperty("proxyAddress").GetString();
        ProxyPort = jsonProxy.GetProperty("proxyPort").GetUInt16();
        ProxyUsername = jsonProxy.GetPropertyValue("proxyUsername", null);
        ProxyPassword = jsonProxy.GetPropertyValue("proxyPassword", null);
    }

    #endregion

    #region properties

    public string Name { get; }

    public NetProxyType Type { get; }

    public string ProxyAddress { get; }

    public ushort ProxyPort { get; }

    public string ProxyUsername { get; }

    public string ProxyPassword { get; }

    #endregion
}
/// <summary>
/// A named set of forwarders read from one entry of the "forwarders" config array. All addresses
/// share the entry's protocol (default Udp), DNSSEC flag (default true) and optional proxy.
/// </summary>
class ConfigForwarder
{
    #region constructor

    /// <exception cref="FormatException">The referenced proxy is not defined.</exception>
    public ConfigForwarder(JsonElement jsonForwarder, Dictionary<string, ConfigProxyServer> configProxyServers)
    {
        Name = jsonForwarder.GetProperty("name").GetString();

        string proxyName = jsonForwarder.GetPropertyValue("proxy", null);
        bool dnssecValidation = jsonForwarder.GetPropertyValue("dnssecValidation", true);
        DnsTransportProtocol forwarderProtocol = jsonForwarder.GetPropertyEnumValue("forwarderProtocol", DnsTransportProtocol.Udp);

        ConfigProxyServer configProxyServer = null;

        if (!string.IsNullOrEmpty(proxyName))
        {
            bool found = (configProxyServers is not null) && configProxyServers.TryGetValue(proxyName, out configProxyServer);
            if (!found)
                throw new FormatException("Proxy server was not defined: " + proxyName);
        }

        ForwarderRecords = jsonForwarder.ReadArray("forwarderAddresses", delegate (string address)
        {
            return GetForwarderRecord(forwarderProtocol, address, dnssecValidation, configProxyServer);
        });
    }

    #endregion

    #region properties

    public string Name { get; }

    public IReadOnlyList<DnsForwarderRecordData> ForwarderRecords { get; }

    #endregion
}
}
}