/*
Technitium DNS Server
Copyright (C) 2024  Shreyas Zare (shreyas@technitium.com)

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

*/

using DnsServerCore.ApplicationCommon;
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using TechnitiumLibrary;
using TechnitiumLibrary.Net.Dns;
using TechnitiumLibrary.Net.Dns.ResourceRecords;

namespace DefaultRecords
{
    /// <summary>
    /// DNS application that post-processes authoritative responses and appends
    /// records taken from configured "sets" that are mapped to the queried zone.
    /// </summary>
    public sealed class App : IDnsApplication, IDnsPostProcessor
    {
        #region variables

        IDnsServer _dnsServer;

        bool _enableDefaultRecords;
        uint _defaultTtl;

        //maps a lower-cased zone key (exact name, "*.parent" or "*") to the names of the sets that apply to it
        Dictionary<string, string[]> _zoneSetMap;

        //maps a set name to its definition
        Dictionary<string, Set> _sets;

        #endregion

        #region IDisposable

        /// <summary>
        /// Nothing to release; the app holds no disposable state.
        /// </summary>
        public void Dispose()
        {
            //do nothing
        }

        #endregion

        #region private

        /// <summary>
        /// Returns the part of <paramref name="domain"/> that follows its first dot.
        /// </summary>
        /// <param name="domain">The domain name to strip the left-most label from.</param>
        /// <returns>
        /// The text after the first '.', or <c>null</c> when <paramref name="domain"/> contains no dot.
        /// </returns>
        private static string GetParentZone(string domain)
        {
            int i = domain.IndexOf('.');
            if (i > -1)
                return domain.Substring(i + 1);

            //dont return root zone
            return null;
        }

        /// <summary>
        /// Looks up the set names mapped to <paramref name="domain"/> or to one of its ancestors.
        /// </summary>
        /// <param name="domain">The domain name to look up; it is lower-cased before matching.</param>
        /// <param name="zone">
        /// On success, the key of <c>_zoneSetMap</c> that matched (an exact name, a "*.parent" key or "*");
        /// otherwise <c>null</c>.
        /// </param>
        /// <param name="setNames">On success, the set names stored under the matched key.</param>
        /// <returns><c>true</c> when a mapping was found; otherwise <c>false</c>.</returns>
        private bool TryGetMappedSets(string domain, out string zone, out string[] setNames)
        {
            domain = domain.ToLowerInvariant();

            while (true)
            {
                //exact match on the current name
                if (_zoneSetMap.TryGetValue(domain, out setNames))
                {
                    zone = domain;
                    return true;
                }

                string parent = GetParentZone(domain);
                if (parent is null)
                {
                    //no more labels to strip; fall back to the catch-all entry
                    if (_zoneSetMap.TryGetValue("*", out setNames))
                    {
                        zone = "*";
                        return true;
                    }

                    break;
                }

                //wildcard entry for the parent
                string wildcard = "*." + parent;

                if (_zoneSetMap.TryGetValue(wildcard, out setNames))
                {
                    zone = wildcard;
                    return true;
                }

                //continue the walk one level up
                domain = parent;
            }

            zone = null;
            return false;
        }

        #endregion

        #region public

        /// <summary>
        /// Stores the DNS server reference and loads the app settings from the JSON <paramref name="config"/>.
        /// </summary>
        /// <param name="dnsServer">The hosting DNS server.</param>
        /// <param name="config">
        /// JSON text read for the properties "enableDefaultRecords", "defaultTtl" (3600 is passed as the
        /// default value), "zoneSetMap" and "sets".
        /// </param>
        /// <returns>A completed task.</returns>
        public Task InitializeAsync(IDnsServer dnsServer, string config)
        {
            _dnsServer = dnsServer;

            using JsonDocument jsonDocument = JsonDocument.Parse(config);
            JsonElement jsonConfig = jsonDocument.RootElement;

            _enableDefaultRecords = jsonConfig.GetProperty("enableDefaultRecords").GetBoolean();
            _defaultTtl = jsonConfig.GetPropertyValue("defaultTtl", 3600u);

            _zoneSetMap = jsonConfig.ReadObjectAsMap("zoneSetMap", delegate (string zone, JsonElement jsonSets)
            {
                string[] sets = jsonSets.GetArray();

                //keys are stored lower-cased since lookups lower-case the queried name
                return new Tuple<string, string[]>(zone.ToLowerInvariant(), sets);
            });

            _sets = jsonConfig.ReadArrayAsMap("sets", delegate (JsonElement jsonSet)
            {
                Set set = new Set(jsonSet);

                return new Tuple<string, Set>(set.Name, set);
            });

            return Task.CompletedTask;
        }

        /// <summary>
        /// Appends matching default records to an authoritative response.
        /// </summary>
        /// <param name="request">The request; its first question selects the name, type and class to match.</param>
        /// <param name="remoteEP">The client end point (not used here).</param>
        /// <param name="protocol">The transport protocol of the request (not used here).</param>
        /// <param name="response">The response produced so far.</param>
        /// <returns>
        /// <paramref name="response"/> unchanged when the feature is disabled, the response is not an
        /// authoritative standard-query response with RCODE NoError or NxDomain, no sets are mapped, or no
        /// record matched; otherwise a new NoError datagram whose answer section is the original answer
        /// followed by the matched records.
        /// </returns>
        public async Task<DnsDatagram> PostProcessAsync(DnsDatagram request, IPEndPoint remoteEP, DnsTransportProtocol protocol, DnsDatagram response)
        {
            if (!_enableDefaultRecords)
                return response;

            if (!response.AuthoritativeAnswer || (response.OPCODE != DnsOpcode.StandardQuery))
                return response;

            switch (response.RCODE)
            {
                case DnsResponseCode.NoError:
                case DnsResponseCode.NxDomain:
                    break;

                default:
                    return response;
            }

            DnsQuestionRecord question = request.Question[0];

            if (!TryGetMappedSets(question.Name, out string zone, out string[] setNames))
                return response;

            if (zone.StartsWith('*'))
            {
                //a wildcard key is not a usable zone name; take the zone name from an SOA lookup instead
                DnsDatagram soaResponse = await _dnsServer.DirectQueryAsync(new DnsQuestionRecord(question.Name, DnsResourceRecordType.SOA, DnsClass.IN));
                if (soaResponse is null)
                    return response;

                if ((soaResponse.Answer.Count > 0) && (soaResponse.Answer[soaResponse.Answer.Count - 1].Type == DnsResourceRecordType.SOA))
                    zone = soaResponse.Answer[soaResponse.Answer.Count - 1].Name;
                else if ((soaResponse.Authority.Count > 0) && (soaResponse.Authority[0].Type == DnsResourceRecordType.SOA))
                    zone = soaResponse.Authority[0].Name;
                else
                    return response;
            }

            //concatenate the record lines of all enabled sets into a single zone file text
            StringBuilder sb = new StringBuilder();

            foreach (string setName in setNames)
            {
                if (_sets.TryGetValue(setName, out Set set) && set.Enable)
                {
                    foreach (string record in set.Records)
                        sb.AppendLine(record);
                }
            }

            if (sb.Length == 0)
                return response;

            using StringReader sR = new StringReader(sb.ToString());

            //await instead of blocking the calling thread on the parse task
            List<DnsResourceRecord> records = await ZoneFile.ReadZoneFileFromAsync(sR, zone, _defaultTtl);

            List<DnsResourceRecord> newAnswer = new List<DnsResourceRecord>(response.Answer.Count + records.Count);
            string qname = question.Name;

            if (response.Answer.Count > 0)
            {
                newAnswer.AddRange(response.Answer);

                //when the existing answer ends in a CNAME, keep matching from its target
                DnsResourceRecord lastRR = response.Answer[response.Answer.Count - 1];
                if ((lastRR.Type == DnsResourceRecordType.CNAME) && (lastRR.RDATA is DnsCNAMERecordData lastCname))
                    qname = lastCname.Domain;
            }

            foreach (DnsResourceRecord record in records)
            {
                if (record.Class != question.Class)
                    continue;

                if ((record.Type != question.Type) && (record.Type != DnsResourceRecordType.CNAME))
                    continue;

                if (!record.Name.Equals(qname, StringComparison.OrdinalIgnoreCase))
                    continue;

                newAnswer.Add(record);

                //follow an added CNAME so that later records are matched against its target
                if ((record.Type == DnsResourceRecordType.CNAME) && (record.RDATA is DnsCNAMERecordData cname))
                    qname = cname.Domain;
            }

            //nothing was appended
            if (newAnswer.Count == response.Answer.Count)
                return response;

            return new DnsDatagram(response.Identifier, true, response.OPCODE, response.AuthoritativeAnswer, response.Truncation, response.RecursionDesired, response.RecursionAvailable, response.AuthenticData, response.CheckingDisabled, DnsResponseCode.NoError, response.Question, newAnswer) { Tag = response.Tag };
        }

        #endregion

        #region properties

        /// <summary>
        /// Short description of the app.
        /// </summary>
        public string Description
        { get { return "Enables default records for configured local zones."; } }

        #endregion

        /// <summary>
        /// A named group of record lines read from the "sets" array of the config.
        /// </summary>
        class Set
        {
            #region variables

            readonly string _name;
            readonly bool _enable;
            readonly string[] _records;

            #endregion

            #region constructor

            /// <summary>
            /// Reads the "name", "enable" and "records" properties of <paramref name="jsonSet"/>.
            /// </summary>
            /// <param name="jsonSet">The JSON object describing the set.</param>
            public Set(JsonElement jsonSet)
            {
                _name = jsonSet.GetProperty("name").GetString();
                _enable = jsonSet.GetProperty("enable").GetBoolean();
                _records = jsonSet.ReadArray("records");
            }

            #endregion

            #region properties

            /// <summary>The value of the "name" property.</summary>
            public string Name
            { get { return _name; } }

            /// <summary>The value of the "enable" property.</summary>
            public bool Enable
            { get { return _enable; } }

            /// <summary>The entries of the "records" array; each is appended as one line of zone file text.</summary>
            public string[] Records
            { get { return _records; } }

            #endregion
        }
    }
}