AuthZone.cs 15 KB


  1. /*
  2. Technitium DNS Server
  3. Copyright (C) 2020 Shreyas Zare (shreyas@technitium.com)
  4. This program is free software: you can redistribute it and/or modify
  5. it under the terms of the GNU General Public License as published by
  6. the Free Software Foundation, either version 3 of the License, or
  7. (at your option) any later version.
  8. This program is distributed in the hope that it will be useful,
  9. but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  11. GNU General Public License for more details.
  12. You should have received a copy of the GNU General Public License
  13. along with this program. If not, see <http://www.gnu.org/licenses/>.
  14. */
  15. using DnsServerCore.Dns.ResourceRecords;
  16. using System;
  17. using System.Collections.Generic;
  18. using System.Net;
  19. using System.Threading.Tasks;
  20. using TechnitiumLibrary.IO;
  21. using TechnitiumLibrary.Net.Dns;
  22. using TechnitiumLibrary.Net.Dns.ResourceRecords;
  namespace DnsServerCore.Dns.Zones
  {
      /// <summary>
      /// Base class for authoritative zones. Stores resource records keyed by record type and provides
      /// record management (add / set / delete / sync), query-time filtering of disabled records, and
      /// resolution of the zone's primary and secondary name server addresses from its SOA and NS records.
      /// </summary>
      abstract class AuthZone : Zone, IDisposable
      {
          #region variables

          //zone-level disabled flag; when set, FilterDisabledRecords (and hence QueryRecords) returns no records
          protected bool _disabled;

          #endregion

          #region constructor

          protected AuthZone(string name)
              : base(name)
          { }

          #endregion

          #region IDisposable

          /// <summary>
          /// Releases resources held by the zone. The base implementation does nothing; derived zones override it.
          /// </summary>
          /// <param name="disposing">True when called from <see cref="Dispose()"/>.</param>
          protected virtual void Dispose(bool disposing)
          { }

          public void Dispose()
          {
              Dispose(true);
              GC.SuppressFinalize(this); //standard dispose pattern (CA1816)
          }

          #endregion

          #region private

          /// <summary>
          /// Removes disabled records from <paramref name="records"/> for answering a query.
          /// </summary>
          /// <param name="type">The queried record type; A, AAAA and NS results with more than one record are shuffled.</param>
          /// <param name="records">The stored records to filter.</param>
          /// <returns>
          /// An empty list when the whole zone is disabled; otherwise the enabled records. A single enabled
          /// record is returned as the original list instance without copying.
          /// </returns>
          private IReadOnlyList<DnsResourceRecord> FilterDisabledRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
          {
              if (_disabled)
                  return Array.Empty<DnsResourceRecord>();

              if (records.Count == 1)
              {
                  if (records[0].IsDisabled())
                      return Array.Empty<DnsResourceRecord>(); //record disabled

                  return records;
              }

              List<DnsResourceRecord> newRecords = new List<DnsResourceRecord>(records.Count);

              foreach (DnsResourceRecord record in records)
              {
                  if (record.IsDisabled())
                      continue; //record disabled

                  newRecords.Add(record);
              }

              if (newRecords.Count > 1)
              {
                  switch (type)
                  {
                      case DnsResourceRecordType.A:
                      case DnsResourceRecordType.AAAA:
                      case DnsResourceRecordType.NS:
                          newRecords.Shuffle(); //shuffle records to allow load balancing
                          break;
                  }
              }

              return newRecords;
          }

          /// <summary>
          /// Gets the addresses of the name server named by an NS record (NameServer) or SOA record (PrimaryNameServer).
          /// Glue records attached to <paramref name="record"/> are used when present; otherwise the name is resolved
          /// with <c>DnsServer.DirectQueryAsync</c>. AAAA addresses are included only when <c>PreferIPv6</c> is set.
          /// </summary>
          /// <exception cref="InvalidOperationException">The record is neither NS nor SOA.</exception>
          private async Task<IReadOnlyList<NameServerAddress>> GetNameServerAddressesAsync(DnsServer dnsServer, DnsResourceRecord record)
          {
              string nsDomain;

              switch (record.Type)
              {
                  case DnsResourceRecordType.NS:
                      nsDomain = (record.RDATA as DnsNSRecord).NameServer;
                      break;

                  case DnsResourceRecordType.SOA:
                      nsDomain = (record.RDATA as DnsSOARecord).PrimaryNameServer;
                      break;

                  default:
                      throw new InvalidOperationException();
              }

              List<NameServerAddress> nameServers = new List<NameServerAddress>(2);

              IReadOnlyList<DnsResourceRecord> glueRecords = record.GetGlueRecords();
              if (glueRecords.Count > 0)
              {
                  foreach (DnsResourceRecord glueRecord in glueRecords)
                  {
                      switch (glueRecord.Type)
                      {
                          case DnsResourceRecordType.A:
                              nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsARecord).Address));
                              break;

                          case DnsResourceRecordType.AAAA:
                              if (dnsServer.PreferIPv6)
                                  nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsAAAARecord).Address));

                              break;
                      }
                  }
              }
              else
              {
                  //resolve addresses
                  try
                  {
                      DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.A, DnsClass.IN));
                      if ((response != null) && (response.Answer.Count > 0))
                      {
                          IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseA(response);
                          foreach (IPAddress address in addresses)
                              nameServers.Add(new NameServerAddress(nsDomain, address));
                      }
                  }
                  catch
                  {
                      //best-effort: a failed A lookup just contributes no addresses
                  }

                  if (dnsServer.PreferIPv6)
                  {
                      try
                      {
                          DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.AAAA, DnsClass.IN));
                          if ((response != null) && (response.Answer.Count > 0))
                          {
                              IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseAAAA(response);
                              foreach (IPAddress address in addresses)
                                  nameServers.Add(new NameServerAddress(nsDomain, address));
                          }
                      }
                      catch
                      {
                          //best-effort: a failed AAAA lookup just contributes no addresses
                      }
                  }
              }

              return nameServers;
          }

          #endregion

          #region public

          /// <summary>
          /// Gets the addresses of the zone's primary name server. Addresses come from the enabled NS record whose
          /// name matches the SOA PrimaryNameServer (if any), followed by any additional addresses derived from the
          /// SOA record itself, without duplicates.
          /// </summary>
          /// <remarks>Indexes the SOA entry directly, so the zone is expected to contain an SOA record.</remarks>
          public async Task<IReadOnlyList<NameServerAddress>> GetPrimaryNameServerAddressesAsync(DnsServer dnsServer)
          {
              List<NameServerAddress> nameServers = new List<NameServerAddress>();

              DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0];
              DnsSOARecord soa = soaRecord.RDATA as DnsSOARecord;
              IReadOnlyList<DnsResourceRecord> nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords

              foreach (DnsResourceRecord nsRecord in nsRecords)
              {
                  if (nsRecord.IsDisabled())
                      continue;

                  string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer;

                  if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase))
                  {
                      //found primary NS
                      nameServers.AddRange(await GetNameServerAddressesAsync(dnsServer, nsRecord));
                      break;
                  }
              }

              foreach (NameServerAddress nameServer in await GetNameServerAddressesAsync(dnsServer, soaRecord))
              {
                  if (!nameServers.Contains(nameServer))
                      nameServers.Add(nameServer);
              }

              return nameServers;
          }

          /// <summary>
          /// Gets the addresses of every enabled NS record in the zone except the one matching the SOA PrimaryNameServer.
          /// </summary>
          /// <remarks>Indexes the SOA entry directly, so the zone is expected to contain an SOA record.</remarks>
          public async Task<IReadOnlyList<NameServerAddress>> GetSecondaryNameServerAddressesAsync(DnsServer dnsServer)
          {
              List<NameServerAddress> nameServers = new List<NameServerAddress>();

              DnsSOARecord soa = _entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord;
              IReadOnlyList<DnsResourceRecord> nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords

              foreach (DnsResourceRecord nsRecord in nsRecords)
              {
                  if (nsRecord.IsDisabled())
                      continue;

                  string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer;

                  if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase))
                      continue; //skip primary name server

                  nameServers.AddRange(await GetNameServerAddressesAsync(dnsServer, nsRecord));
              }

              return nameServers;
          }

          /// <summary>
          /// Replaces the zone's records with <paramref name="newEntries"/>.
          /// </summary>
          /// <param name="newEntries">The new records, grouped by type.</param>
          /// <param name="dontRemoveRecords">When false, record types absent from <paramref name="newEntries"/> are removed from the zone.</param>
          /// <remarks>
          /// Forwarder zones never receive NS or SOA records. Other zones accept only a single SOA record;
          /// secondary and stub zones carry the existing SOA record's glue addresses over to the new SOA record.
          /// </remarks>
          public void SyncRecords(Dictionary<DnsResourceRecordType, List<DnsResourceRecord>> newEntries, bool dontRemoveRecords)
          {
              if (!dontRemoveRecords)
              {
                  //remove entries of types that do not exist in new entries
                  foreach (DnsResourceRecordType type in _entries.Keys)
                  {
                      if (!newEntries.ContainsKey(type))
                          _entries.TryRemove(type, out _);
                  }
              }

              //set new entries into zone
              if (this is ForwarderZone)
              {
                  //skip NS and SOA records from being added to ForwarderZone
                  foreach (KeyValuePair<DnsResourceRecordType, List<DnsResourceRecord>> newEntry in newEntries)
                  {
                      switch (newEntry.Key)
                      {
                          case DnsResourceRecordType.NS:
                          case DnsResourceRecordType.SOA:
                              break;

                          default:
                              _entries[newEntry.Key] = newEntry.Value;
                              break;
                      }
                  }
              }
              else
              {
                  foreach (KeyValuePair<DnsResourceRecordType, List<DnsResourceRecord>> newEntry in newEntries)
                  {
                      if (newEntry.Key == DnsResourceRecordType.SOA)
                      {
                          if (newEntry.Value.Count != 1)
                              continue; //skip invalid SOA record

                          if ((this is SecondaryZone) || (this is StubZone))
                          {
                              //copy existing SOA record's glue addresses to new SOA record;
                              //the zone may not have an SOA record yet, in which case there is nothing to copy
                              if (_entries.TryGetValue(DnsResourceRecordType.SOA, out IReadOnlyList<DnsResourceRecord> existingSoaRecords) && (existingSoaRecords.Count > 0))
                                  newEntry.Value[0].SetGlueRecords(existingSoaRecords[0].GetGlueRecords());
                          }
                      }

                      _entries[newEntry.Key] = newEntry.Value;
                  }
              }
          }

          /// <summary>
          /// Stores <paramref name="records"/> as the records of <paramref name="type"/>. Unlike the virtual
          /// <see cref="SetRecords"/>, this is not overridable, so derived-zone logic in SetRecords overrides is bypassed.
          /// </summary>
          public void LoadRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
          {
              _entries[type] = records;
          }

          /// <summary>
          /// Replaces all records of <paramref name="type"/> with <paramref name="records"/>.
          /// </summary>
          public virtual void SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
          {
              _entries[type] = records;
          }

          /// <summary>
          /// Adds <paramref name="record"/> to the records of its type, unless a record with equal RDATA already exists.
          /// </summary>
          /// <exception cref="InvalidOperationException">The record is CNAME, ANAME, PTR or SOA; use SetRecords() for those.</exception>
          public virtual void AddRecord(DnsResourceRecord record)
          {
              switch (record.Type)
              {
                  case DnsResourceRecordType.CNAME:
                  case DnsResourceRecordType.ANAME:
                  case DnsResourceRecordType.PTR:
                  case DnsResourceRecordType.SOA:
                      throw new InvalidOperationException("Cannot add record: use SetRecords() for " + record.Type.ToString() + " record");
              }

              _entries.AddOrUpdate(record.Type, delegate (DnsResourceRecordType key)
              {
                  return new DnsResourceRecord[] { record };
              },
              delegate (DnsResourceRecordType key, IReadOnlyList<DnsResourceRecord> existingRecords)
              {
                  foreach (DnsResourceRecord existingRecord in existingRecords)
                  {
                      //compare RDATA with RDATA; comparing the record object itself to RDATA never matched
                      if (record.RDATA.Equals(existingRecord.RDATA))
                          return existingRecords; //duplicate; keep existing set unchanged
                  }

                  List<DnsResourceRecord> updateRecords = new List<DnsResourceRecord>(existingRecords.Count + 1);

                  updateRecords.AddRange(existingRecords);
                  updateRecords.Add(record);

                  return updateRecords;
              });
          }

          /// <summary>
          /// Removes all records of <paramref name="type"/>.
          /// </summary>
          /// <returns>True if the type had an entry that was removed.</returns>
          public virtual bool DeleteRecords(DnsResourceRecordType type)
          {
              return _entries.TryRemove(type, out _);
          }

          /// <summary>
          /// Removes every record of <paramref name="type"/> whose RDATA equals <paramref name="record"/>.
          /// The type's entry is removed entirely when no records remain.
          /// </summary>
          /// <returns>True only if at least one record was actually removed.</returns>
          public virtual bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData record)
          {
              if (!_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
                  return false;

              if (existingRecords.Count == 1)
              {
                  if (record.Equals(existingRecords[0].RDATA))
                      return _entries.TryRemove(type, out _);

                  return false;
              }

              List<DnsResourceRecord> updateRecords = new List<DnsResourceRecord>(existingRecords.Count);

              for (int i = 0; i < existingRecords.Count; i++)
              {
                  if (!record.Equals(existingRecords[i].RDATA))
                      updateRecords.Add(existingRecords[i]);
              }

              if (updateRecords.Count == existingRecords.Count)
                  return false; //no matching record; nothing was deleted

              if (updateRecords.Count == 0)
                  return _entries.TryRemove(type, out _); //every record matched; drop the now-empty entry

              return _entries.TryUpdate(type, updateRecords, existingRecords);
          }

          /// <summary>
          /// Gets the enabled records that answer a query for <paramref name="type"/>.
          /// </summary>
          /// <returns>
          /// Enabled CNAME records if any exist (they take precedence over every type); for ANY, all enabled
          /// records of every stored type; otherwise the enabled records of <paramref name="type"/>, falling back
          /// to enabled ANAME records for A/AAAA queries; or an empty list.
          /// </returns>
          public virtual IReadOnlyList<DnsResourceRecord> QueryRecords(DnsResourceRecordType type)
          {
              //check for CNAME
              if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
              {
                  IReadOnlyList<DnsResourceRecord> filteredRecords = FilterDisabledRecords(type, existingCNAMERecords);
                  if (filteredRecords.Count > 0)
                      return filteredRecords;
              }

              if (type == DnsResourceRecordType.ANY)
              {
                  List<DnsResourceRecord> records = new List<DnsResourceRecord>(_entries.Count * 2);

                  foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
                  {
                      if (entry.Key != DnsResourceRecordType.ANY)
                          records.AddRange(entry.Value);
                  }

                  return FilterDisabledRecords(type, records);
              }

              if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
              {
                  IReadOnlyList<DnsResourceRecord> filteredRecords = FilterDisabledRecords(type, existingRecords);
                  if (filteredRecords.Count > 0)
                      return filteredRecords;
              }

              switch (type)
              {
                  case DnsResourceRecordType.A:
                  case DnsResourceRecordType.AAAA:
                      if (_entries.TryGetValue(DnsResourceRecordType.ANAME, out IReadOnlyList<DnsResourceRecord> anameRecords))
                          return FilterDisabledRecords(type, anameRecords);

                      break;
              }

              return Array.Empty<DnsResourceRecord>();
          }

          /// <summary>
          /// Gets all stored records of <paramref name="type"/>, including disabled ones, or an empty list.
          /// </summary>
          public IReadOnlyList<DnsResourceRecord> GetRecords(DnsResourceRecordType type)
          {
              if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> records))
                  return records;

              return Array.Empty<DnsResourceRecord>();
          }

          /// <summary>
          /// Returns true if the zone has at least one NS record that is not disabled.
          /// </summary>
          public override bool ContainsNameServerRecords()
          {
              if (!_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> records))
                  return false;

              foreach (DnsResourceRecord record in records)
              {
                  if (record.IsDisabled())
                      continue;

                  return true;
              }

              return false;
          }

          #endregion

          #region properties

          /// <summary>
          /// Gets or sets whether the whole zone is disabled.
          /// </summary>
          public virtual bool Disabled
          {
              get { return _disabled; }
              set { _disabled = value; }
          }

          /// <summary>
          /// True when the zone is not disabled; derived zones may override this with additional conditions.
          /// </summary>
          public virtual bool IsActive
          {
              get { return !_disabled; }
          }

          #endregion
      }
  }