/*
Technitium DNS Server
Copyright (C) 2023  Shreyas Zare (shreyas@technitium.com)

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

*/

using DnsServerCore.ApplicationCommon;
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using TechnitiumLibrary;
using TechnitiumLibrary.Net;
using TechnitiumLibrary.Net.Dns;
using TechnitiumLibrary.Net.Dns.ResourceRecords;
using TechnitiumLibrary.Net.Proxy;

namespace AdvancedForwarding
{
    /// <summary>
    /// DNS application that maps the client IP address to a named group and answers
    /// matching requests with FWD records built from that group's configuration.
    /// </summary>
    public class App : IDnsApplication, IDnsAuthoritativeRequestHandler
    {
        #region variables

        IDnsServer _dnsServer;

        bool _enableForwarding;
        Dictionary<string, ConfigProxyServer> _configProxyServers;
        Dictionary<string, ConfigForwarder> _configForwarders;
        IReadOnlyDictionary<NetworkAddress, string> _networkGroupMap;
        IReadOnlyDictionary<string, Group> _groups;

        #endregion

        #region IDisposable

        /// <summary>
        /// Disposes every currently loaded group.
        /// </summary>
        public void Dispose()
        {
            if (_groups is not null)
            {
                foreach (KeyValuePair<string, Group> group in _groups)
                    group.Value.Dispose();
            }
        }

        #endregion

        #region private

        /// <summary>
        /// Rebuilds each forwarder record with the same protocol and forwarder address
        /// but with the supplied DNSSEC validation flag and proxy settings.
        /// </summary>
        private static IReadOnlyList<DnsForwarderRecordData> GetUpdatedForwarderRecords(IReadOnlyList<DnsForwarderRecordData> forwarderRecords, bool dnssecValidation, ConfigProxyServer configProxyServer)
        {
            List<DnsForwarderRecordData> newForwarderRecords = new List<DnsForwarderRecordData>(forwarderRecords.Count);

            foreach (DnsForwarderRecordData forwarderRecord in forwarderRecords)
                newForwarderRecords.Add(GetForwarderRecord(forwarderRecord.Protocol, forwarderRecord.Forwarder, dnssecValidation, configProxyServer));

            return newForwarderRecords;
        }

        /// <summary>
        /// Builds a forwarder record from a parsed name server address.
        /// </summary>
        private static DnsForwarderRecordData GetForwarderRecord(NameServerAddress forwarder, bool dnssecValidation, ConfigProxyServer configProxyServer)
        {
            return GetForwarderRecord(forwarder.Protocol, forwarder.ToString(), dnssecValidation, configProxyServer);
        }

        /// <summary>
        /// Builds a forwarder record; when <paramref name="configProxyServer"/> is null the record
        /// is created with <see cref="NetProxyType.None"/> and no proxy details.
        /// </summary>
        private static DnsForwarderRecordData GetForwarderRecord(DnsTransportProtocol protocol, string forwarder, bool dnssecValidation, ConfigProxyServer configProxyServer)
        {
            DnsForwarderRecordData forwarderRecord;

            if (configProxyServer is null)
                forwarderRecord = new DnsForwarderRecordData(protocol, forwarder, dnssecValidation, NetProxyType.None, null, 0, null, null);
            else
                forwarderRecord = new DnsForwarderRecordData(protocol, forwarder, dnssecValidation, configProxyServer.Type, configProxyServer.ProxyAddress, configProxyServer.ProxyPort, configProxyServer.ProxyUsername, configProxyServer.ProxyPassword);

            return forwarderRecord;
        }

        /// <summary>
        /// Reads one entry of the "groups" array. A group with the same name that is already
        /// loaded is reconfigured in place instead of being recreated.
        /// </summary>
        private Tuple<string, Group> ReadGroup(JsonElement jsonGroup)
        {
            Group group;
            string name = jsonGroup.GetProperty("name").GetString();

            if ((_groups is not null) && _groups.TryGetValue(name, out group))
                group.ReloadConfig(_configProxyServers, _configForwarders, jsonGroup);
            else
                group = new Group(_dnsServer, _configProxyServers, _configForwarders, jsonGroup);

            return new Tuple<string, Group>(group.Name, group);
        }

        #endregion

        #region public

        /// <summary>
        /// Parses the JSON config and (re)loads proxy servers, forwarders, the network-to-group
        /// map and the groups. Previously loaded groups missing from the new config are disposed.
        /// </summary>
        /// <exception cref="FormatException">
        /// The "groups" array is missing or the network group map contains an invalid network address.
        /// </exception>
        public Task InitializeAsync(IDnsServer dnsServer, string config)
        {
            _dnsServer = dnsServer;

            using JsonDocument jsonDocument = JsonDocument.Parse(config);
            JsonElement jsonConfig = jsonDocument.RootElement;

            _enableForwarding = jsonConfig.GetPropertyValue("enableForwarding", true);

            if (jsonConfig.TryReadArrayAsMap("proxyServers", delegate (JsonElement jsonProxy)
            {
                ConfigProxyServer proxyServer = new ConfigProxyServer(jsonProxy);
                return new Tuple<string, ConfigProxyServer>(proxyServer.Name, proxyServer);
            }, out Dictionary<string, ConfigProxyServer> configProxyServers))
                _configProxyServers = configProxyServers;
            else
                _configProxyServers = null;

            //forwarders are read after proxy servers since they resolve their proxy by name
            if (jsonConfig.TryReadArrayAsMap("forwarders", delegate (JsonElement jsonForwarder)
            {
                ConfigForwarder forwarder = new ConfigForwarder(jsonForwarder, _configProxyServers);
                return new Tuple<string, ConfigForwarder>(forwarder.Name, forwarder);
            }, out Dictionary<string, ConfigForwarder> configForwarders))
                _configForwarders = configForwarders;
            else
                _configForwarders = null;

            _networkGroupMap = jsonConfig.ReadObjectAsMap("networkGroupMap", delegate (string network, JsonElement jsonGroup)
            {
                if (!NetworkAddress.TryParse(network, out NetworkAddress networkAddress))
                    throw new FormatException("Network group map contains an invalid network address: " + network);

                return new Tuple<NetworkAddress, string>(networkAddress, jsonGroup.GetString());
            });

            if (jsonConfig.TryReadArrayAsMap("groups", ReadGroup, out Dictionary<string, Group> groups))
            {
                if (_groups is not null)
                {
                    //dispose only the groups that were removed; the rest were reloaded in place by ReadGroup
                    foreach (KeyValuePair<string, Group> group in _groups)
                    {
                        if (!groups.ContainsKey(group.Key))
                            group.Value.Dispose();
                    }
                }

                _groups = groups;
            }
            else
            {
                throw new FormatException("Groups array was not defined.");
            }

            return Task.CompletedTask;
        }

        /// <summary>
        /// Selects the group mapped to the most specific network containing the client address and,
        /// if that group has forwarder records for the question name, returns a response built from them.
        /// </summary>
        /// <returns>
        /// A task whose result is null when forwarding is disabled, recursion is not desired,
        /// no enabled group matches, or the group has no forwarder records for the name.
        /// </returns>
        public Task<DnsDatagram> ProcessRequestAsync(DnsDatagram request, IPEndPoint remoteEP, DnsTransportProtocol protocol, bool isRecursionAllowed)
        {
            if (!_enableForwarding || !request.RecursionDesired)
                return Task.FromResult<DnsDatagram>(null);

            IPAddress remoteIP = remoteEP.Address;
            NetworkAddress network = null;
            string groupName = null;

            //pick the matching network with the longest prefix
            foreach (KeyValuePair<NetworkAddress, string> entry in _networkGroupMap)
            {
                if (entry.Key.Contains(remoteIP) && ((network is null) || (entry.Key.PrefixLength > network.PrefixLength)))
                {
                    network = entry.Key;
                    groupName = entry.Value;
                }
            }

            if ((groupName is null) || !_groups.TryGetValue(groupName, out Group group) || !group.EnableForwarding)
                return Task.FromResult<DnsDatagram>(null);

            DnsQuestionRecord question = request.Question[0];
            string qname = question.Name;

            if (!group.TryGetForwarderRecords(qname, out IReadOnlyList<DnsForwarderRecordData> forwarderRecords))
                return Task.FromResult<DnsDatagram>(null);

            request.SetShadowEDnsClientSubnetOption(network, true);

            DnsResourceRecord[] authority = new DnsResourceRecord[forwarderRecords.Count];

            for (int i = 0; i < forwarderRecords.Count; i++)
                authority[i] = new DnsResourceRecord(qname, DnsResourceRecordType.FWD, DnsClass.IN, 0, forwarderRecords[i]);

            return Task.FromResult(new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.NoError, request.Question, null, authority));
        }

        #endregion

        #region properties

        public string Description
        { get { return "Performs bulk conditional forwarding for configured domain names and AdGuard Upstream config files."; } }

        #endregion

        /// <summary>
        /// A named set of forwardings and AdGuard upstream config files.
        /// </summary>
        class Group : IDisposable
        {
            #region variables

            readonly IDnsServer _dnsServer;
            Dictionary<string, ConfigProxyServer> _configProxyServers;
            Dictionary<string, ConfigForwarder> _configForwarders;

            readonly string _name;
            bool _enableForwarding;
            IReadOnlyList<Forwarding> _forwardings;
            IReadOnlyDictionary<string, AdGuardUpstream> _adguardUpstreams;

            #endregion

            #region constructor

            public Group(IDnsServer dnsServer, Dictionary<string, ConfigProxyServer> configProxyServers, Dictionary<string, ConfigForwarder> configForwarders, JsonElement jsonGroup)
            {
                _dnsServer = dnsServer;

                _name = jsonGroup.GetProperty("name").GetString();

                ReloadConfig(configProxyServers, configForwarders, jsonGroup);
            }

            #endregion

            #region IDisposable

            /// <summary>
            /// Disposes all AdGuard upstreams owned by this group.
            /// </summary>
            public void Dispose()
            {
                if (_adguardUpstreams is not null)
                {
                    foreach (KeyValuePair<string, AdGuardUpstream> adguardUpstream in _adguardUpstreams)
                        adguardUpstream.Value.Dispose();

                    _adguardUpstreams = null;
                }
            }

            #endregion

            #region private

            /// <summary>
            /// Reads one "adguardUpstreams" entry, reusing an already loaded upstream
            /// with the same "configFile" value when there is one.
            /// </summary>
            private Tuple<string, AdGuardUpstream> ReadAdGuardUpstream(JsonElement jsonAdguardUpstream)
            {
                AdGuardUpstream adGuardUpstream;
                string name = jsonAdguardUpstream.GetProperty("configFile").GetString();

                if ((_adguardUpstreams is not null) && _adguardUpstreams.TryGetValue(name, out adGuardUpstream))
                    adGuardUpstream.ReloadConfig(_configProxyServers, jsonAdguardUpstream);
                else
                    adGuardUpstream = new AdGuardUpstream(_dnsServer, _configProxyServers, jsonAdguardUpstream);

                return new Tuple<string, AdGuardUpstream>(adGuardUpstream.Name, adGuardUpstream);
            }

            #endregion

            #region public

            /// <summary>
            /// Applies a new JSON definition to this group. AdGuard upstreams that are no longer
            /// listed (or all of them, when the array is absent) are disposed.
            /// </summary>
            public void ReloadConfig(Dictionary<string, ConfigProxyServer> configProxyServers, Dictionary<string, ConfigForwarder> configForwarders, JsonElement jsonGroup)
            {
                _configProxyServers = configProxyServers;
                _configForwarders = configForwarders;

                _enableForwarding = jsonGroup.GetPropertyValue("enableForwarding", true);

                if (jsonGroup.TryReadArray("forwardings", delegate (JsonElement jsonForwarding) { return new Forwarding(jsonForwarding, _configForwarders); }, out Forwarding[] forwardings))
                    _forwardings = forwardings;
                else
                    _forwardings = null;

                if (jsonGroup.TryReadArrayAsMap("adguardUpstreams", ReadAdGuardUpstream, out Dictionary<string, AdGuardUpstream> adguardUpstreams))
                {
                    if (_adguardUpstreams is not null)
                    {
                        foreach (KeyValuePair<string, AdGuardUpstream> adguardUpstream in _adguardUpstreams)
                        {
                            if (!adguardUpstreams.ContainsKey(adguardUpstream.Key))
                                adguardUpstream.Value.Dispose();
                        }
                    }

                    _adguardUpstreams = adguardUpstreams;
                }
                else
                {
                    if (_adguardUpstreams is not null)
                    {
                        foreach (KeyValuePair<string, AdGuardUpstream> adguardUpstream in _adguardUpstreams)
                            adguardUpstream.Value.Dispose();
                    }

                    _adguardUpstreams = null;
                }
            }

            /// <summary>
            /// Looks up forwarder records for a domain: configured forwardings are tried first,
            /// then each AdGuard upstream in enumeration order.
            /// </summary>
            public bool TryGetForwarderRecords(string domain, out IReadOnlyList<DnsForwarderRecordData> forwarderRecords)
            {
                //culture-invariant so the lookup key never depends on the server's locale
                domain = domain.ToLowerInvariant();

                if ((_forwardings is not null) && (_forwardings.Count > 0) && Forwarding.TryGetForwarderRecords(domain, _forwardings, out forwarderRecords))
                    return true;

                if (_adguardUpstreams is not null)
                {
                    foreach (KeyValuePair<string, AdGuardUpstream> adguardUpstream in _adguardUpstreams)
                    {
                        if (adguardUpstream.Value.TryGetForwarderRecords(domain, out forwarderRecords))
                            return true;
                    }
                }

                forwarderRecords = null;
                return false;
            }

            #endregion

            #region properties

            public string Name
            { get { return _name; } }

            public bool EnableForwarding
            { get { return _enableForwarding; } }

            #endregion
        }

        /// <summary>
        /// A set of domain names (exact, "*.zone" or "*") together with the forwarder records to use for them.
        /// </summary>
        class Forwarding
        {
            #region variables

            IReadOnlyList<DnsForwarderRecordData> _forwarderRecords;
            readonly IReadOnlyDictionary<string, object> _domainMap; //only the keys are used; values are always null

            #endregion

            #region constructor

            /// <summary>
            /// Reads a forwarding from JSON, resolving each name in "forwarders" against the configured forwarders.
            /// </summary>
            /// <exception cref="FormatException">A referenced forwarder name is not defined.</exception>
            public Forwarding(JsonElement jsonForwarding, Dictionary<string, ConfigForwarder> configForwarders)
            {
                JsonElement jsonForwarders = jsonForwarding.GetProperty("forwarders");
                List<DnsForwarderRecordData> forwarderRecords = new List<DnsForwarderRecordData>();

                foreach (JsonElement jsonForwarder in jsonForwarders.EnumerateArray())
                {
                    string forwarderName = jsonForwarder.GetString();

                    if ((configForwarders is null) || !configForwarders.TryGetValue(forwarderName, out ConfigForwarder configForwarder))
                        throw new FormatException("Forwarder was not defined: " + forwarderName);

                    forwarderRecords.AddRange(configForwarder.ForwarderRecords);
                }

                _forwarderRecords = forwarderRecords;

                _domainMap = jsonForwarding.ReadArrayAsMap("domains", delegate (JsonElement jsonDomain)
                {
                    return new Tuple<string, object>(jsonDomain.GetString().ToLowerInvariant(), null);
                });
            }

            public Forwarding(IReadOnlyList<string> domains, NameServerAddress forwarder, bool dnssecValidation, ConfigProxyServer proxy)
                : this(new DnsForwarderRecordData[] { GetForwarderRecord(forwarder, dnssecValidation, proxy) }, domains)
            { }

            /// <summary>
            /// Creates a forwarding from explicit records; domains failing
            /// <c>DnsClient.IsDomainNameValid</c> and duplicates are skipped.
            /// </summary>
            public Forwarding(IReadOnlyList<DnsForwarderRecordData> forwarderRecords, IReadOnlyList<string> domains)
            {
                _forwarderRecords = forwarderRecords;

                Dictionary<string, object> domainMap = new Dictionary<string, object>(domains.Count);

                foreach (string domain in domains)
                {
                    if (DnsClient.IsDomainNameValid(domain))
                        domainMap.TryAdd(domain.ToLowerInvariant(), null);
                }

                _domainMap = domainMap;
            }

            #endregion

            #region static

            /// <summary>
            /// Finds forwarder records for a domain across several forwardings. Records of forwardings
            /// that matched the same key are merged; among different keys the longest one wins, and on
            /// equal length a later key replaces an earlier one that starts with "*.".
            /// </summary>
            public static bool TryGetForwarderRecords(string domain, IReadOnlyList<Forwarding> forwardings, out IReadOnlyList<DnsForwarderRecordData> forwarderRecords)
            {
                if (forwardings.Count == 1)
                {
                    if (forwardings[0].TryGetForwarderRecords(domain, out forwarderRecords, out _))
                        return true;
                }
                else
                {
                    Dictionary<string, List<DnsForwarderRecordData>> fwdMap = new Dictionary<string, List<DnsForwarderRecordData>>(forwardings.Count);

                    foreach (Forwarding forwarding in forwardings)
                    {
                        if (forwarding.TryGetForwarderRecords(domain, out IReadOnlyList<DnsForwarderRecordData> fwdRecords, out string matchedDomain))
                        {
                            if (fwdMap.TryGetValue(matchedDomain, out List<DnsForwarderRecordData> fwdRecordsList))
                            {
                                fwdRecordsList.AddRange(fwdRecords);
                            }
                            else
                            {
                                fwdRecordsList = new List<DnsForwarderRecordData>(fwdRecords);
                                fwdMap.Add(matchedDomain, fwdRecordsList);
                            }
                        }
                    }

                    if (fwdMap.Count > 0)
                    {
                        forwarderRecords = null;
                        string lastMatchedDomain = null;

                        foreach (KeyValuePair<string, List<DnsForwarderRecordData>> fwdEntry in fwdMap)
                        {
                            //ordinal: "*." is a literal marker, not linguistic text
                            if ((lastMatchedDomain is null) || (fwdEntry.Key.Length > lastMatchedDomain.Length) || ((fwdEntry.Key.Length == lastMatchedDomain.Length) && lastMatchedDomain.StartsWith("*.", StringComparison.Ordinal)))
                            {
                                lastMatchedDomain = fwdEntry.Key;
                                forwarderRecords = fwdEntry.Value;
                            }
                        }

                        return true;
                    }
                }

                forwarderRecords = null;
                return false;
            }

            /// <summary>
            /// Returns true when <paramref name="domain"/> equals the name server host of any record in any forwarding.
            /// </summary>
            public static bool IsForwarderDomain(string domain, IReadOnlyList<Forwarding> forwardings)
            {
                foreach (Forwarding forwarding in forwardings)
                {
                    if (IsForwarderDomain(domain, forwarding._forwarderRecords))
                        return true;
                }

                return false;
            }

            /// <summary>
            /// Returns true when <paramref name="domain"/> equals (ignoring case) the name server host of any record.
            /// </summary>
            public static bool IsForwarderDomain(string domain, IReadOnlyList<DnsForwarderRecordData> forwarderRecords)
            {
                foreach (DnsForwarderRecordData forwarderRecord in forwarderRecords)
                {
                    if (domain.Equals(forwarderRecord.NameServer.Host, StringComparison.OrdinalIgnoreCase))
                        return true;
                }

                return false;
            }

            #endregion

            #region private

            private static string GetParentZone(string domain)
            {
                int i = domain.IndexOf('.');
                if (i > -1)
                    return domain.Substring(i + 1);

                //dont return root zone
                return null;
            }

            /// <summary>
            /// Walks from the domain up through its parent zones, testing at each level the exact
            /// name and then "*.parent"; "*" is tested last once no parent zone is left.
            /// </summary>
            private bool IsDomainMatching(string domain, out string matchedDomain)
            {
                string parent;

                do
                {
                    if (_domainMap.TryGetValue(domain, out _))
                    {
                        matchedDomain = domain;
                        return true;
                    }

                    parent = GetParentZone(domain);
                    if (parent is null)
                    {
                        if (_domainMap.TryGetValue("*", out _))
                        {
                            matchedDomain = "*";
                            return true;
                        }

                        break;
                    }

                    domain = "*." + parent;

                    if (_domainMap.TryGetValue(domain, out _))
                    {
                        matchedDomain = domain;
                        return true;
                    }

                    domain = parent;
                }
                while (true);

                matchedDomain = null;
                return false;
            }

            private bool TryGetForwarderRecords(string domain, out IReadOnlyList<DnsForwarderRecordData> forwarderRecords, out string matchedDomain)
            {
                if (IsDomainMatching(domain, out matchedDomain))
                {
                    forwarderRecords = _forwarderRecords;
                    return true;
                }

                forwarderRecords = null;
                return false;
            }

            #endregion

            #region public

            /// <summary>
            /// Replaces the forwarder records with copies carrying the given DNSSEC validation flag and proxy.
            /// </summary>
            public void UpdateForwarderRecords(bool dnssecValidation, ConfigProxyServer proxy)
            {
                _forwarderRecords = GetUpdatedForwarderRecords(_forwarderRecords, dnssecValidation, proxy);
            }

            #endregion
        }

        /// <summary>
        /// Forwardings loaded from an AdGuard upstreams config file; a timer re-reads the file
        /// when its last write time becomes newer than the one recorded at the last successful load.
        /// </summary>
        class AdGuardUpstream : IDisposable
        {
            #region variables

            readonly IDnsServer _dnsServer;

            readonly string _name;
            ConfigProxyServer _configProxyServer;
            bool _dnssecValidation;

            IReadOnlyList<DnsForwarderRecordData> _defaultForwarderRecords;
            IReadOnlyList<Forwarding> _forwardings;

            readonly string _configFile;
            DateTime _configFileLastModified;

            Timer _autoReloadTimer;
            const int AUTO_RELOAD_TIMER_INTERVAL = 60000;

            #endregion

            #region constructor

            /// <exception cref="FormatException">The "proxy" named in the JSON is not a configured proxy server.</exception>
            public AdGuardUpstream(IDnsServer dnsServer, Dictionary<string, ConfigProxyServer> configProxyServers, JsonElement jsonAdguardUpstream)
            {
                _dnsServer = dnsServer;

                _name = jsonAdguardUpstream.GetProperty("configFile").GetString();

                _configFile = _name;

                //relative paths are resolved against the application folder
                if (!Path.IsPathRooted(_configFile))
                    _configFile = Path.Combine(_dnsServer.ApplicationFolder, _configFile);

                _autoReloadTimer = new Timer(delegate (object state)
                {
                    try
                    {
                        DateTime configFileLastModified = File.GetLastWriteTimeUtc(_configFile);
                        if (configFileLastModified > _configFileLastModified)
                        {
                            ReloadUpstreamsFile();

                            //force GC collection to remove old cache data from memory quickly
                            GC.Collect();
                        }
                    }
                    catch (Exception ex)
                    {
                        _dnsServer.WriteLog(ex);
                    }
                    finally
                    {
                        try
                        {
                            _autoReloadTimer?.Change(AUTO_RELOAD_TIMER_INTERVAL, Timeout.Infinite);
                        }
                        catch (ObjectDisposedException)
                        {
                            //Dispose() ran between the null check and Change(); there is nothing to re-arm
                            //and the exception must not escape the timer callback
                        }
                    }
                });

                ReloadConfig(configProxyServers, jsonAdguardUpstream);
            }

            #endregion

            #region IDisposable

            public void Dispose()
            {
                if (_autoReloadTimer is not null)
                {
                    _autoReloadTimer.Dispose();
                    _autoReloadTimer = null;
                }
            }

            #endregion

            #region private

            /// <summary>
            /// Parses the config file. Lines starting with '#' and empty lines are skipped, lines of the
            /// form "[/domain/.../]upstream" add a forwarding ("#" as upstream reuses the default records
            /// read so far), and any other line is parsed as a default upstream. Failures are logged and
            /// leave the previously loaded data untouched.
            /// </summary>
            private void ReloadUpstreamsFile()
            {
                try
                {
                    _dnsServer.WriteLog("The app is reading AdGuard Upstreams config file: " + _configFile);

                    List<DnsForwarderRecordData> defaultForwarderRecords = new List<DnsForwarderRecordData>();
                    List<Forwarding> forwardings = new List<Forwarding>();
                    DateTime configFileLastModified;

                    using (FileStream fS = new FileStream(_configFile, FileMode.Open, FileAccess.Read))
                    using (StreamReader sR = new StreamReader(fS, true))
                    {
                        string line;

                        while (true)
                        {
                            line = sR.ReadLine();
                            if (line is null)
                                break; //eof

                            //trim both ends so that trailing whitespace cannot break the checks below
                            line = line.Trim();

                            if (line.Length == 0)
                                continue; //skip empty line

                            if (line.StartsWith('#'))
                                continue; //skip comment line

                            if (line.StartsWith('['))
                            {
                                int i = line.LastIndexOf(']');
                                if (i < 0)
                                    throw new FormatException("Invalid AdGuard Upstreams config file format: missing ']' bracket.");

                                string[] domains = line.Substring(1, i - 1).Split('/', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
                                string forwarder = line.Substring(i + 1).Trim();

                                if (forwarder == "#")
                                {
                                    if (defaultForwarderRecords.Count == 0)
                                        throw new FormatException("Invalid AdGuard Upstreams config file format: missing default upstream servers.");

                                    forwardings.Add(new Forwarding(defaultForwarderRecords, domains));
                                }
                                else
                                {
                                    forwardings.Add(new Forwarding(domains, NameServerAddress.Parse(forwarder), _dnssecValidation, _configProxyServer));
                                }
                            }
                            else
                            {
                                defaultForwarderRecords.Add(GetForwarderRecord(NameServerAddress.Parse(line), _dnssecValidation, _configProxyServer));
                            }
                        }

                        //read the timestamp from the open handle so it belongs to the content just parsed
                        configFileLastModified = File.GetLastWriteTimeUtc(fS.SafeFileHandle);
                    }

                    _defaultForwarderRecords = defaultForwarderRecords;
                    _forwardings = forwardings;

                    //set last: a recorded timestamp then implies that the lists above are loaded
                    _configFileLastModified = configFileLastModified;

                    _dnsServer.WriteLog("The app has successfully loaded AdGuard Upstreams config file: " + _configFile);
                }
                catch (Exception ex)
                {
                    _dnsServer.WriteLog("The app failed to read AdGuard Upstreams config file: " + _configFile + "\r\n" + ex.ToString());
                }
            }

            #endregion

            #region public

            /// <summary>
            /// Applies new proxy and DNSSEC settings. If the file is newer than the last loaded version a
            /// full reload is scheduled on the timer; otherwise only the existing records are rebuilt.
            /// </summary>
            /// <exception cref="FormatException">The "proxy" named in the JSON is not a configured proxy server.</exception>
            public void ReloadConfig(Dictionary<string, ConfigProxyServer> configProxyServers, JsonElement jsonAdguardUpstream)
            {
                string proxyName = jsonAdguardUpstream.GetPropertyValue("proxy", null);
                _dnssecValidation = jsonAdguardUpstream.GetPropertyValue("dnssecValidation", true);

                ConfigProxyServer configProxyServer = null;

                if (!string.IsNullOrEmpty(proxyName) && ((configProxyServers is null) || !configProxyServers.TryGetValue(proxyName, out configProxyServer)))
                    throw new FormatException("Proxy server was not defined: " + proxyName);

                _configProxyServer = configProxyServer;

                DateTime configFileLastModified = File.GetLastWriteTimeUtc(_configFile);
                if (configFileLastModified > _configFileLastModified)
                {
                    //reload complete config file
                    _autoReloadTimer.Change(0, Timeout.Infinite);
                }
                else
                {
                    //update only forwarder records; null guards keep this branch safe if nothing is loaded
                    if (_defaultForwarderRecords is not null)
                        _defaultForwarderRecords = GetUpdatedForwarderRecords(_defaultForwarderRecords, _dnssecValidation, _configProxyServer);

                    if (_forwardings is not null)
                    {
                        foreach (Forwarding forwarding in _forwardings)
                            forwarding.UpdateForwarderRecords(_dnssecValidation, _configProxyServer);
                    }
                }
            }

            /// <summary>
            /// Looks up forwarder records for a domain, falling back to the default upstreams.
            /// Returns false when the domain is itself the host name of one of the forwarders checked.
            /// </summary>
            public bool TryGetForwarderRecords(string domain, out IReadOnlyList<DnsForwarderRecordData> forwarderRecords)
            {
                if ((_forwardings is not null) && (_forwardings.Count > 0))
                {
                    if (Forwarding.IsForwarderDomain(domain, _forwardings))
                    {
                        forwarderRecords = null;
                        return false;
                    }

                    if (Forwarding.TryGetForwarderRecords(domain, _forwardings, out forwarderRecords))
                        return true;
                }

                if ((_defaultForwarderRecords is not null) && (_defaultForwarderRecords.Count > 0))
                {
                    if (Forwarding.IsForwarderDomain(domain, _defaultForwarderRecords))
                    {
                        forwarderRecords = null;
                        return false;
                    }

                    forwarderRecords = _defaultForwarderRecords;
                    return true;
                }

                forwarderRecords = null;
                return false;
            }

            #endregion

            #region property

            public string Name
            { get { return _name; } }

            #endregion
        }

        /// <summary>
        /// A proxy server entry read from the "proxyServers" config array.
        /// </summary>
        class ConfigProxyServer
        {
            #region variables

            readonly string _name;
            readonly NetProxyType _type;
            readonly string _proxyAddress;
            readonly ushort _proxyPort;
            readonly string _proxyUsername;
            readonly string _proxyPassword;

            #endregion

            #region constructor

            public ConfigProxyServer(JsonElement jsonProxy)
            {
                _name = jsonProxy.GetProperty("name").GetString();
                _type = jsonProxy.GetPropertyEnumValue("type", NetProxyType.Http);
                _proxyAddress = jsonProxy.GetProperty("proxyAddress").GetString();
                _proxyPort = jsonProxy.GetProperty("proxyPort").GetUInt16();
                _proxyUsername = jsonProxy.GetPropertyValue("proxyUsername", null);
                _proxyPassword = jsonProxy.GetPropertyValue("proxyPassword", null);
            }

            #endregion

            #region properties

            public string Name
            { get { return _name; } }

            public NetProxyType Type
            { get { return _type; } }

            public string ProxyAddress
            { get { return _proxyAddress; } }

            public ushort ProxyPort
            { get { return _proxyPort; } }

            public string ProxyUsername
            { get { return _proxyUsername; } }

            public string ProxyPassword
            { get { return _proxyPassword; } }

            #endregion
        }

        /// <summary>
        /// A named forwarder entry read from the "forwarders" config array, expanded into one
        /// forwarder record per entry of "forwarderAddresses".
        /// </summary>
        class ConfigForwarder
        {
            #region variables

            readonly string _name;
            readonly IReadOnlyList<DnsForwarderRecordData> _forwarderRecords;

            #endregion

            #region constructor

            /// <exception cref="FormatException">The "proxy" named in the JSON is not a configured proxy server.</exception>
            public ConfigForwarder(JsonElement jsonForwarder, Dictionary<string, ConfigProxyServer> configProxyServers)
            {
                _name = jsonForwarder.GetProperty("name").GetString();

                string proxyName = jsonForwarder.GetPropertyValue("proxy", null);
                bool dnssecValidation = jsonForwarder.GetPropertyValue("dnssecValidation", true);
                DnsTransportProtocol forwarderProtocol = jsonForwarder.GetPropertyEnumValue("forwarderProtocol", DnsTransportProtocol.Udp);

                ConfigProxyServer configProxyServer = null;

                if (!string.IsNullOrEmpty(proxyName) && ((configProxyServers is null) || !configProxyServers.TryGetValue(proxyName, out configProxyServer)))
                    throw new FormatException("Proxy server was not defined: " + proxyName);

                _forwarderRecords = jsonForwarder.ReadArray("forwarderAddresses", delegate (string address)
                {
                    return GetForwarderRecord(forwarderProtocol, address, dnssecValidation, configProxyServer);
                });
            }

            #endregion

            #region properties

            public string Name
            { get { return _name; } }

            public IReadOnlyList<DnsForwarderRecordData> ForwarderRecords
            { get { return _forwarderRecords; } }

            #endregion
        }
    }
}