/*
Technitium DNS Server
Copyright (C) 2023 Shreyas Zare (shreyas@technitium.com)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using TechnitiumLibrary.Net;
using TechnitiumLibrary.Net.Dns;
using TechnitiumLibrary.Net.Dns.EDnsOptions;
using TechnitiumLibrary.Net.Dns.ResourceRecords;
using TechnitiumLibrary.Net.Http.Client;
namespace DnsServerCore.Dns.ZoneManagers
{
/// <summary>
/// Loads allow lists and block lists (hosts-file and adblock formats) from a local cache of downloaded list files,
/// answers block-list membership questions, and synthesizes blocking responses for blocked queries.
/// </summary>
public sealed class BlockListZoneManager
{
    #region variables

    // separators between words on a hosts-file line
    static readonly char[] _popWordSeperator = new char[] { ' ', '\t' };

    // characters stripped from the start of every list line before it is parsed
    static readonly char[] _lineTrimSeperator = new char[] { ' ', '\t', '*', '.' };

    readonly DnsServer _dnsServer;
    readonly string _localCacheFolder;

    readonly List<Uri> _allowListUrls = new List<Uri>();
    readonly List<Uri> _blockListUrls = new List<Uri>();

    // LoadBlockLists() and Flush() never mutate these; they build new dictionaries and assign them to these fields
    IReadOnlyDictionary<string, object> _allowListZone = new Dictionary<string, object>();
    IReadOnlyDictionary<string, List<Uri>> _blockListZone = new Dictionary<string, List<Uri>>();

    DnsSOARecordData _soaRecord;
    DnsNSRecordData _nsRecord;

    readonly IReadOnlyCollection<DnsARecordData> _aRecords = new DnsARecordData[] { new DnsARecordData(IPAddress.Any) };
    readonly IReadOnlyCollection<DnsAAAARecordData> _aaaaRecords = new DnsAAAARecordData[] { new DnsAAAARecordData(IPAddress.IPv6Any) };

    #endregion

    #region constructor

    /// <summary>
    /// Creates the manager, ensures the "blocklists" cache folder exists under the server's config folder,
    /// and initializes the SOA/NS data from the server's domain.
    /// </summary>
    /// <param name="dnsServer">The owning DNS server.</param>
    public BlockListZoneManager(DnsServer dnsServer)
    {
        _dnsServer = dnsServer;

        _localCacheFolder = Path.Combine(_dnsServer.ConfigFolder, "blocklists");

        if (!Directory.Exists(_localCacheFolder))
            Directory.CreateDirectory(_localCacheFolder);

        UpdateServerDomain(_dnsServer.ServerDomain);
    }

    #endregion

    #region private

    // Rebuilds the SOA and NS record data used in synthesized blocking responses.
    private void UpdateServerDomain(string serverDomain)
    {
        _soaRecord = new DnsSOARecordData(serverDomain, "hostadmin@" + serverDomain, 1, 14400, 3600, 604800, 60);
        _nsRecord = new DnsNSRecordData(serverDomain);
    }

    // The cache file name is the lowercase hex SHA-256 of the URL's AbsoluteUri, inside the local cache folder.
    private string GetBlockListFilePath(Uri blockListUrl)
    {
        byte[] hash = SHA256.HashData(Encoding.UTF8.GetBytes(blockListUrl.AbsoluteUri));
        return Path.Combine(_localCacheFolder, Convert.ToHexString(hash).ToLowerInvariant());
    }

    // Removes and returns the next space/tab separated word from line; line keeps the text after the separator.
    private static string PopWord(ref string line)
    {
        if (line.Length == 0)
            return line;

        line = line.TrimStart(_popWordSeperator);

        int i = line.IndexOfAny(_popWordSeperator);
        string word;

        if (i < 0)
        {
            word = line;
            line = "";
        }
        else
        {
            word = line.Substring(0, i);
            line = line.Substring(i + 1);
        }

        return word;
    }

    // Parses an adblock rule whose domain starts after prefixLength characters ("||" = 2, "@@||" = 4).
    // A rule with a '^' terminator is accepted only when nothing follows it or the "$" options mention "doc" or "all".
    // Returns the lowercased domain when it is a valid domain name.
    private static bool TryParseAdblockDomain(string line, int prefixLength, out string domain)
    {
        int i = line.IndexOf('^');
        if (i > -1)
        {
            domain = line.Substring(prefixLength, i - prefixLength);
            string options = line.Substring(i + 1);

            bool optionsAccepted = (options.Length == 0) || (options.StartsWith('$') && (options.Contains("doc") || options.Contains("all")));
            if (!optionsAccepted)
                return false;
        }
        else
        {
            domain = line.Substring(prefixLength);
        }

        if (!DnsClient.IsDomainNameValid(domain))
            return false;

        domain = domain.ToLowerInvariant();
        return true;
    }

    // Extracts the hostname from a hosts-file style line, either "address hostname [# comment]" or a bare "hostname".
    // Returns null for local/special names, invalid domain names and IP addresses, which must be skipped.
    private static string ParseHostsFileHostname(string line)
    {
        string hostname;
        string firstWord = PopWord(ref line);

        if (line.Length == 0)
        {
            hostname = firstWord;
        }
        else
        {
            string secondWord = PopWord(ref line);

            if ((secondWord.Length == 0) || secondWord.StartsWith('#'))
                hostname = firstWord;
            else
                hostname = secondWord;
        }

        //DNS names are ASCII case-insensitive; invariant lowercasing avoids culture-specific mappings (e.g. Turkish 'I')
        hostname = hostname.Trim('.').ToLowerInvariant();

        switch (hostname)
        {
            case "":
            case "localhost":
            case "localhost.localdomain":
            case "local":
            case "broadcasthost":
            case "ip6-localhost":
            case "ip6-loopback":
            case "ip6-localnet":
            case "ip6-mcastprefix":
            case "ip6-allnodes":
            case "ip6-allrouters":
            case "ip6-allhosts":
                return null; //skip these hostnames
        }

        if (!DnsClient.IsDomainNameValid(hostname))
            return null;

        if (IPAddress.TryParse(hostname, out _))
            return null; //skip line when hostname is IP address

        return hostname;
    }

    /// <summary>
    /// Parses the cached file for <paramref name="listUrl"/>.
    /// </summary>
    /// <param name="listUrl">The list URL whose cached file is read.</param>
    /// <param name="isAllowList">Used only for log messages.</param>
    /// <param name="exceptionDomains">Receives domains from adblock "@@||" exception rules.</param>
    /// <returns>
    /// Domains from "||" adblock rules and hosts-file lines. On failure, the error is logged and whatever was
    /// parsed so far is returned.
    /// </returns>
    private Queue<string> ReadListFile(Uri listUrl, bool isAllowList, out Queue<string> exceptionDomains)
    {
        Queue<string> domains = new Queue<string>();
        exceptionDomains = new Queue<string>();

        try
        {
            _dnsServer.LogManager?.Write("DNS Server is reading " + (isAllowList ? "allow" : "block") + " list from: " + listUrl.AbsoluteUri);

            using (FileStream fS = new FileStream(GetBlockListFilePath(listUrl), FileMode.Open, FileAccess.Read))
            using (StreamReader sR = new StreamReader(fS, true))
            {
                string line;

                while ((line = sR.ReadLine()) is not null)
                {
                    line = line.TrimStart(_lineTrimSeperator);

                    if (line.Length == 0)
                        continue; //skip empty line

                    if (line.StartsWith('#') || line.StartsWith('!'))
                        continue; //skip comment line

                    if (line.StartsWith("||", StringComparison.Ordinal))
                    {
                        //adblock format
                        if (TryParseAdblockDomain(line, 2, out string domain))
                            domains.Enqueue(domain);
                    }
                    else if (line.StartsWith("@@||", StringComparison.Ordinal))
                    {
                        //adblock format - exception syntax
                        if (TryParseAdblockDomain(line, 4, out string domain))
                            exceptionDomains.Enqueue(domain);
                    }
                    else
                    {
                        //hosts file format
                        string hostname = ParseHostsFileHostname(line);
                        if (hostname is not null)
                            domains.Enqueue(hostname);
                    }
                }
            }

            _dnsServer.LogManager?.Write("DNS Server read " + (isAllowList ? "allow" : "block") + " list file (" + domains.Count + " domains) from: " + listUrl.AbsoluteUri);
        }
        catch (Exception ex)
        {
            _dnsServer.LogManager?.Write("DNS Server failed to read " + (isAllowList ? "allow" : "block") + " list from: " + listUrl.AbsoluteUri + "\r\n" + ex.ToString());
        }

        return domains;
    }

    // Walks from domain up through its parents (via AuthZoneManager.GetParentZone) and returns the block list URLs
    // for the first name found in the block list zone, or null when none matches.
    private List<Uri> IsZoneBlocked(string domain, out string blockedDomain)
    {
        domain = domain.ToLowerInvariant();

        do
        {
            if (_blockListZone.TryGetValue(domain, out List<Uri> blockLists))
            {
                //found zone blocked
                blockedDomain = domain;
                return blockLists;
            }

            domain = AuthZoneManager.GetParentZone(domain);
        }
        while (domain is not null);

        blockedDomain = null;
        return null;
    }

    // Returns true when domain or any of its parents is present in the allow list zone.
    private bool IsZoneAllowed(string domain)
    {
        domain = domain.ToLowerInvariant();

        do
        {
            if (_allowListZone.TryGetValue(domain, out _))
                return true;

            domain = AuthZoneManager.GetParentZone(domain);
        }
        while (domain is not null);

        return false;
    }

    #endregion

    #region public

    /// <summary>
    /// Re-reads every cached allow list and block list file and replaces the allow and block zones.
    /// </summary>
    /// <remarks>
    /// Exception ("@@||") rules in an allow list are added to the block zone, and exception rules in a block list
    /// are added to the allow zone. A URL present in both lists is read once, as an allow list.
    /// </remarks>
    public void LoadBlockLists()
    {
        Dictionary<Uri, Queue<string>> allowListQueues = new Dictionary<Uri, Queue<string>>(_allowListUrls.Count);
        Dictionary<Uri, Queue<string>> blockListQueues = new Dictionary<Uri, Queue<string>>(_blockListUrls.Count);
        int totalAllowedDomains = 0;
        int totalBlockedDomains = 0;

        //read all allow lists in a queue
        foreach (Uri allowListUrl in _allowListUrls)
        {
            if (!allowListQueues.ContainsKey(allowListUrl))
            {
                Queue<string> allowListQueue = ReadListFile(allowListUrl, true, out Queue<string> blockListQueue);

                totalAllowedDomains += allowListQueue.Count;
                allowListQueues.Add(allowListUrl, allowListQueue);

                totalBlockedDomains += blockListQueue.Count;
                blockListQueues.Add(allowListUrl, blockListQueue);
            }
        }

        //read all block lists in a queue
        foreach (Uri blockListUrl in _blockListUrls)
        {
            if (!blockListQueues.ContainsKey(blockListUrl))
            {
                Queue<string> blockListQueue = ReadListFile(blockListUrl, false, out Queue<string> allowListQueue);

                totalBlockedDomains += blockListQueue.Count;
                blockListQueues.Add(blockListUrl, blockListQueue);

                totalAllowedDomains += allowListQueue.Count;
                allowListQueues.Add(blockListUrl, allowListQueue);
            }
        }

        //load allow list zone
        Dictionary<string, object> allowListZone = new Dictionary<string, object>(totalAllowedDomains);

        foreach (KeyValuePair<Uri, Queue<string>> allowListQueue in allowListQueues)
        {
            Queue<string> queue = allowListQueue.Value;

            while (queue.Count > 0)
            {
                string domain = queue.Dequeue();
                allowListZone.TryAdd(domain, null);
            }
        }

        //load block list zone
        Dictionary<string, List<Uri>> blockListZone = new Dictionary<string, List<Uri>>(totalBlockedDomains);

        foreach (KeyValuePair<Uri, Queue<string>> blockListQueue in blockListQueues)
        {
            Queue<string> queue = blockListQueue.Value;

            while (queue.Count > 0)
            {
                string domain = queue.Dequeue();

                if (!blockListZone.TryGetValue(domain, out List<Uri> blockLists))
                {
                    blockLists = new List<Uri>(2);
                    blockListZone.Add(domain, blockLists);
                }

                //each queue is drained completely before the next one, so a domain repeated within the same list
                //would only ever duplicate the last entry; skip it to avoid duplicate TXT records / EDE options
                if ((blockLists.Count == 0) || !blockLists[^1].Equals(blockListQueue.Key))
                    blockLists.Add(blockListQueue.Key);
            }
        }

        //set new allowed and blocked zones
        _allowListZone = allowListZone;
        _blockListZone = blockListZone;

        _dnsServer.LogManager?.Write("DNS Server block list zone was loaded successfully.");
    }

    /// <summary>
    /// Replaces the allow and block zones with empty ones. The URL lists and cached files are left untouched.
    /// </summary>
    public void Flush()
    {
        _allowListZone = new Dictionary<string, object>();
        _blockListZone = new Dictionary<string, List<Uri>>();
    }

    /// <summary>
    /// Downloads every configured allow list and block list URL in parallel into the local cache.
    /// Each request sends If-Modified-Since when a cached file exists. If any list was newly downloaded,
    /// the zones are reloaded.
    /// </summary>
    /// <returns>True when at least one list was downloaded or reported as not modified.</returns>
    public async Task<bool> UpdateBlockListsAsync()
    {
        bool downloaded = false;
        bool notModified = false;

        async Task DownloadListUrlAsync(Uri listUrl, bool isAllowList)
        {
            string listFilePath = GetBlockListFilePath(listUrl);
            string listDownloadFilePath = listFilePath + ".downloading";

            try
            {
                if (File.Exists(listDownloadFilePath))
                    File.Delete(listDownloadFilePath);

                SocketsHttpHandler handler = new SocketsHttpHandler();
                handler.Proxy = _dnsServer.Proxy;
                handler.UseProxy = _dnsServer.Proxy is not null;
                handler.AutomaticDecompression = DecompressionMethods.All;

                using (HttpClient http = new HttpClient(new HttpClientRetryHandler(handler)))
                {
                    if (File.Exists(listFilePath))
                        http.DefaultRequestHeaders.IfModifiedSince = File.GetLastWriteTimeUtc(listFilePath);

                    using (HttpResponseMessage httpResponse = await http.GetAsync(listUrl))
                    {
                        switch (httpResponse.StatusCode)
                        {
                            case HttpStatusCode.OK:
                                {
                                    using (FileStream fS = new FileStream(listDownloadFilePath, FileMode.Create, FileAccess.Write))
                                    {
                                        using (Stream httpStream = await httpResponse.Content.ReadAsStreamAsync())
                                        {
                                            await httpStream.CopyToAsync(fS);
                                        }
                                    }

                                    //replace any previously cached file without a separate delete step
                                    File.Move(listDownloadFilePath, listFilePath, true);

                                    //keep the server's Last-Modified so the next If-Modified-Since check uses it
                                    if (httpResponse.Content.Headers.LastModified is not null)
                                        File.SetLastWriteTimeUtc(listFilePath, httpResponse.Content.Headers.LastModified.Value.UtcDateTime);

                                    downloaded = true;

                                    _dnsServer.LogManager?.Write("DNS Server successfully downloaded " + (isAllowList ? "allow" : "block") + " list (" + WebUtilities.GetFormattedSize(new FileInfo(listFilePath).Length) + "): " + listUrl.AbsoluteUri);
                                }
                                break;

                            case HttpStatusCode.NotModified:
                                {
                                    notModified = true;

                                    _dnsServer.LogManager?.Write("DNS Server successfully checked for a new update of the " + (isAllowList ? "allow" : "block") + " list: " + listUrl.AbsoluteUri);
                                }
                                break;

                            default:
                                throw new HttpRequestException((int)httpResponse.StatusCode + " " + httpResponse.ReasonPhrase);
                        }
                    }
                }
            }
            catch (Exception ex)
            {
                _dnsServer.LogManager?.Write("DNS Server failed to download " + (isAllowList ? "allow" : "block") + " list and will use previously downloaded file (if available): " + listUrl.AbsoluteUri + "\r\n" + ex.ToString());
            }
        }

        //download each distinct URL once: duplicates (e.g. a URL in both lists) would otherwise run concurrent
        //tasks writing the same ".downloading" file; LoadBlockLists() likewise reads a repeated URL only once
        HashSet<Uri> scheduledUrls = new HashSet<Uri>();
        List<Task> tasks = new List<Task>(_allowListUrls.Count + _blockListUrls.Count);

        foreach (Uri allowListUrl in _allowListUrls)
        {
            if (scheduledUrls.Add(allowListUrl))
                tasks.Add(DownloadListUrlAsync(allowListUrl, true));
        }

        foreach (Uri blockListUrl in _blockListUrls)
        {
            if (scheduledUrls.Add(blockListUrl))
                tasks.Add(DownloadListUrlAsync(blockListUrl, false));
        }

        await Task.WhenAll(tasks);

        if (downloaded)
        {
            LoadBlockLists();

            //force GC collection to remove old zone data from memory quickly
            GC.Collect();
        }

        return downloaded || notModified;
    }

    /// <summary>
    /// Returns true when the name of the request's first question, or one of its parents, is in the allow list zone.
    /// </summary>
    public bool IsAllowed(DnsDatagram request)
    {
        if (_allowListZone.Count < 1)
            return false;

        return IsZoneAllowed(request.Question[0].Name);
    }

    /// <summary>
    /// Builds a blocking response for the request's first question, or returns null when the name (and all its
    /// parents) is not in the block list zone.
    /// </summary>
    /// <remarks>
    /// When TXT blocking reports are enabled, a TXT question is answered with one TXT record per matching block
    /// list. Other questions are then given Extended DNS Error options, provided the request carries EDNS.
    /// Otherwise the response depends on the server's blocking type: the any-address or custom A/AAAA answers, or
    /// NXDOMAIN with an SOA owned by the blocked domain's parent.
    /// </remarks>
    public DnsDatagram Query(DnsDatagram request)
    {
        if (_blockListZone.Count < 1)
            return null;

        DnsQuestionRecord question = request.Question[0];

        List<Uri> blockLists = IsZoneBlocked(question.Name, out string blockedDomain);
        if (blockLists is null)
            return null; //zone not blocked

        //zone is blocked
        if (_dnsServer.AllowTxtBlockingReport && (question.Type == DnsResourceRecordType.TXT))
        {
            //return meta data
            DnsResourceRecord[] answer = new DnsResourceRecord[blockLists.Count];

            for (int i = 0; i < answer.Length; i++)
                answer[i] = new DnsResourceRecord(question.Name, DnsResourceRecordType.TXT, question.Class, 60, new DnsTXTRecordData("source=block-list-zone; blockListUrl=" + blockLists[i].AbsoluteUri + "; domain=" + blockedDomain));

            return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question, answer);
        }
        else
        {
            EDnsOption[] options = null;

            if (_dnsServer.AllowTxtBlockingReport && (request.EDNS is not null))
            {
                options = new EDnsOption[blockLists.Count];

                for (int i = 0; i < options.Length; i++)
                    options[i] = new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Blocked, "source=block-list-zone; blockListUrl=" + blockLists[i].AbsoluteUri + "; domain=" + blockedDomain));
            }

            IReadOnlyCollection<DnsARecordData> aRecords;
            IReadOnlyCollection<DnsAAAARecordData> aaaaRecords;

            switch (_dnsServer.BlockingType)
            {
                case DnsServerBlockingType.AnyAddress:
                    aRecords = _aRecords;
                    aaaaRecords = _aaaaRecords;
                    break;

                case DnsServerBlockingType.CustomAddress:
                    aRecords = _dnsServer.CustomBlockingARecords;
                    aaaaRecords = _dnsServer.CustomBlockingAAAARecords;
                    break;

                case DnsServerBlockingType.NxDomain:
                    string parentDomain = AuthZoneManager.GetParentZone(blockedDomain) ?? string.Empty;

                    return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NxDomain, request.Question, null, new DnsResourceRecord[] { new DnsResourceRecord(parentDomain, DnsResourceRecordType.SOA, question.Class, 60, _soaRecord) }, null, request.EDNS is null ? ushort.MinValue : _dnsServer.UdpPayloadSize, EDnsHeaderFlags.None, options);

                default:
                    throw new InvalidOperationException();
            }

            IReadOnlyList<DnsResourceRecord> answer = null;
            IReadOnlyList<DnsResourceRecord> authority = null;

            switch (question.Type)
            {
                case DnsResourceRecordType.A:
                    {
                        List<DnsResourceRecord> rrList = new List<DnsResourceRecord>(aRecords.Count);

                        foreach (DnsARecordData record in aRecords)
                            rrList.Add(new DnsResourceRecord(question.Name, DnsResourceRecordType.A, question.Class, 60, record));

                        answer = rrList;
                    }
                    break;

                case DnsResourceRecordType.AAAA:
                    {
                        List<DnsResourceRecord> rrList = new List<DnsResourceRecord>(aaaaRecords.Count);

                        foreach (DnsAAAARecordData record in aaaaRecords)
                            rrList.Add(new DnsResourceRecord(question.Name, DnsResourceRecordType.AAAA, question.Class, 60, record));

                        answer = rrList;
                    }
                    break;

                case DnsResourceRecordType.NS:
                    if (question.Name.Equals(blockedDomain, StringComparison.OrdinalIgnoreCase))
                        answer = new DnsResourceRecord[] { new DnsResourceRecord(blockedDomain, DnsResourceRecordType.NS, question.Class, 60, _nsRecord) };
                    else
                        authority = new DnsResourceRecord[] { new DnsResourceRecord(blockedDomain, DnsResourceRecordType.SOA, question.Class, 60, _soaRecord) };

                    break;

                case DnsResourceRecordType.SOA:
                    answer = new DnsResourceRecord[] { new DnsResourceRecord(blockedDomain, DnsResourceRecordType.SOA, question.Class, 60, _soaRecord) };
                    break;

                default:
                    //NODATA response: SOA of the blocked domain in the authority section
                    authority = new DnsResourceRecord[] { new DnsResourceRecord(blockedDomain, DnsResourceRecordType.SOA, question.Class, 60, _soaRecord) };
                    break;
            }

            return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question, answer, authority, null, request.EDNS is null ? ushort.MinValue : _dnsServer.UdpPayloadSize, EDnsHeaderFlags.None, options);
        }
    }

    #endregion

    #region properties

    /// <summary>
    /// Domain used as the primary name server in the synthesized SOA, and as the NS record data.
    /// </summary>
    public string ServerDomain
    {
        get { return _soaRecord.PrimaryNameServer; }
        set { UpdateServerDomain(value); }
    }

    /// <summary>Mutable list of allow list URLs; changes take effect on the next load or update.</summary>
    public List<Uri> AllowListUrls
    { get { return _allowListUrls; } }

    /// <summary>Mutable list of block list URLs; changes take effect on the next load or update.</summary>
    public List<Uri> BlockListUrls
    { get { return _blockListUrls; } }

    /// <summary>Number of domains currently in the allow list zone.</summary>
    public int TotalZonesAllowed
    { get { return _allowListZone.Count; } }

    /// <summary>Number of domains currently in the block list zone.</summary>
    public int TotalZonesBlocked
    { get { return _blockListZone.Count; } }

    #endregion
}
}