// DnsApplicationAssemblyLoadContext.cs (6.0 KB)
  /*
  Technitium DNS Server
  Copyright (C) 2022 Shreyas Zare (shreyas@technitium.com)

  This program is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program. If not, see <http://www.gnu.org/licenses/>.
  */
  15. using System;
  16. using System.Collections.Generic;
  17. using System.IO;
  18. using System.Reflection;
  19. using System.Runtime.InteropServices;
  20. using System.Runtime.Loader;
  namespace DnsServerCore.Dns
  {
      /// <summary>
      /// Collectible assembly load context that hosts a single DNS application.
      /// Managed assembly requests are deferred to the default context, while native
      /// libraries are resolved from the application's <c>runtimes/&lt;rid&gt;/native</c>
      /// folder and loaded from a temporary copy so the files in the application folder
      /// can be uninstalled or updated while the application is loaded.
      /// </summary>
      class DnsApplicationAssemblyLoadContext : AssemblyLoadContext
      {
          #region variables

          // Whether the host is Alpine Linux. This cannot change while the process runs,
          // so /etc/os-release is read at most once instead of on every native library request.
          static readonly Lazy<bool> _isAlpineLinux = new Lazy<bool>(DetectAlpineLinux);

          readonly string _applicationFolder;

          // Resolved native library path -> loaded handle, compared case-insensitively with an
          // ordinal (culture-independent) comparer. This dictionary is also the lock object
          // guarding _unmanagedDllTempPaths.
          readonly Dictionary<string, IntPtr> _loadedUnmanagedDlls = new Dictionary<string, IntPtr>(StringComparer.OrdinalIgnoreCase);

          // Temp copies of native libraries that were loaded successfully; deleted when the context unloads.
          readonly List<string> _unmanagedDllTempPaths = new List<string>();

          #endregion

          #region constructor

          /// <summary>
          /// Creates a collectible load context for the application stored in <paramref name="applicationFolder"/>.
          /// </summary>
          /// <param name="applicationFolder">
          /// Folder of the application; native libraries are searched under its <c>runtimes</c> subfolder.
          /// </param>
          public DnsApplicationAssemblyLoadContext(string applicationFolder)
              : base(isCollectible: true)
          {
              _applicationFolder = applicationFolder;

              Unloading += delegate (AssemblyLoadContext obj)
              {
                  DeleteUnmanagedDllTempFiles();
              };
          }

          #endregion

          #region overrides

          /// <summary>
          /// Always returns <c>null</c> so that managed assembly resolution falls back to the default context.
          /// </summary>
          protected override Assembly Load(AssemblyName assemblyName)
          {
              return null;
          }

          /// <summary>
          /// Resolves <paramref name="unmanagedDllName"/> in the application's
          /// <c>runtimes/&lt;rid&gt;/native</c> folder for the current OS and process architecture
          /// and loads it from a temporary copy. Each resolved path is loaded at most once per context.
          /// </summary>
          /// <param name="unmanagedDllName">
          /// Library name; platform prefixes ("lib") and extensions are appended during the search.
          /// </param>
          /// <returns>
          /// The native library handle, or <see cref="IntPtr.Zero"/> when no matching file exists
          /// so the runtime falls back to its default probing.
          /// </returns>
          protected override IntPtr LoadUnmanagedDll(string unmanagedDllName)
          {
              string unmanagedDllPath = FindUnmanagedDllPathForCurrentPlatform(unmanagedDllName);
              if (unmanagedDllPath is null)
                  return IntPtr.Zero;

              lock (_loadedUnmanagedDlls)
              {
                  if (_loadedUnmanagedDlls.TryGetValue(unmanagedDllPath, out IntPtr handle))
                      return handle;

                  // load from a temp copy so the file in the application folder is not the one
                  // mapped by the loader, allowing the app to be uninstalled/updated at runtime
                  string tempPath = Path.GetTempFileName();
                  try
                  {
                      CopyFile(unmanagedDllPath, tempPath);
                      handle = LoadUnmanagedDllFromPath(tempPath);
                  }
                  catch
                  {
                      // nothing was loaded from the copy, so remove it now instead of leaking it
                      TryDeleteFile(tempPath);
                      throw;
                  }

                  _unmanagedDllTempPaths.Add(tempPath);
                  _loadedUnmanagedDlls.Add(unmanagedDllPath, handle);
                  return handle;
              }
          }

          #endregion

          #region private

          /// <summary>
          /// Chooses the runtime identifier folder(s), file prefixes and extensions for the current OS
          /// and searches them for <paramref name="unmanagedDllName"/>. On Alpine Linux the
          /// <c>alpine-&lt;arch&gt;</c> folder is tried before <c>linux-&lt;arch&gt;</c>.
          /// </summary>
          /// <returns>
          /// The first existing file path, or <c>null</c> when nothing matches or the OS is not
          /// Windows, Linux or macOS.
          /// </returns>
          private string FindUnmanagedDllPathForCurrentPlatform(string unmanagedDllName)
          {
              string architecture = RuntimeInformation.ProcessArchitecture.ToString().ToLowerInvariant();

              if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
                  return FindUnmanagedDllPath(unmanagedDllName, "win-" + architecture, new string[] { "" }, new string[] { ".dll" });

              if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
              {
                  string[] prefixes = new string[] { "", "lib" };
                  string[] extensions = new string[] { ".so", ".so.1" };

                  string alpinePath = _isAlpineLinux.Value
                      ? FindUnmanagedDllPath(unmanagedDllName, "alpine-" + architecture, prefixes, extensions)
                      : null;

                  return alpinePath ?? FindUnmanagedDllPath(unmanagedDllName, "linux-" + architecture, prefixes, extensions);
              }

              if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
                  return FindUnmanagedDllPath(unmanagedDllName, "osx-" + architecture, new string[] { "", "lib" }, new string[] { ".dylib" });

              return null;
          }

          /// <summary>
          /// Searches <c>&lt;applicationFolder&gt;/runtimes/&lt;runtime&gt;/native</c> for
          /// <c>prefix + unmanagedDllName + extension</c>, trying every prefix with every extension in order.
          /// </summary>
          /// <returns>The first existing file path, or <c>null</c> if none exists.</returns>
          private string FindUnmanagedDllPath(string unmanagedDllName, string runtime, string[] prefixes, string[] extensions)
          {
              string nativeFolder = Path.Combine(_applicationFolder, "runtimes", runtime, "native");

              foreach (string prefix in prefixes)
              {
                  foreach (string extension in extensions)
                  {
                      string path = Path.Combine(nativeFolder, prefix + unmanagedDllName + extension);
                      if (File.Exists(path))
                          return path;
                  }
              }

              return null;
          }

          /// <summary>Overwrites <paramref name="destinationPath"/> with the bytes of <paramref name="sourcePath"/>.</summary>
          private static void CopyFile(string sourcePath, string destinationPath)
          {
              // a plain stream copy writes only the content; unlike File.Copy it does not carry over
              // file attributes (e.g. read-only) that could stop the temp copy from being deleted later
              using (FileStream source = new FileStream(sourcePath, FileMode.Open, FileAccess.Read))
              using (FileStream destination = new FileStream(destinationPath, FileMode.Create, FileAccess.Write))
              {
                  source.CopyTo(destination);
              }
          }

          /// <summary>
          /// Best-effort removal of the temp copies created by <see cref="LoadUnmanagedDll"/>;
          /// invoked from the <see cref="AssemblyLoadContext.Unloading"/> event.
          /// </summary>
          private void DeleteUnmanagedDllTempFiles()
          {
              // NOTE(review): the loaded native handles are never released (no NativeLibrary.Free),
              // so on Windows the temp copies are typically still locked here and remain on disk.
              lock (_loadedUnmanagedDlls)
              {
                  foreach (string tempPath in _unmanagedDllTempPaths)
                      TryDeleteFile(tempPath);
              }
          }

          /// <summary>Deletes <paramref name="path"/>, ignoring I/O and access-denied failures.</summary>
          private static void TryDeleteFile(string path)
          {
              try
              {
                  File.Delete(path);
              }
              catch (IOException)
              {
                  // file in use or its directory is missing; deletion is best-effort
              }
              catch (UnauthorizedAccessException)
              {
                  // access denied, e.g. a native library that is still mapped on Windows
              }
          }

          /// <summary>
          /// Returns <c>true</c> when <c>/etc/os-release</c> exists and contains "alpine" (case-insensitive);
          /// <c>false</c> otherwise, including when the file cannot be read.
          /// </summary>
          private static bool DetectAlpineLinux()
          {
              const string osReleaseFile = "/etc/os-release";

              try
              {
                  return File.Exists(osReleaseFile) && File.ReadAllText(osReleaseFile).Contains("alpine", StringComparison.OrdinalIgnoreCase);
              }
              catch (IOException)
              {
                  return false;
              }
              catch (UnauthorizedAccessException)
              {
                  return false;
              }
          }

          #endregion
      }
  }