diff --git a/FastGithub.Dns/DnsServerHostedService.cs b/FastGithub.Dns/DnsServerHostedService.cs index 034605e..4218d0b 100644 --- a/FastGithub.Dns/DnsServerHostedService.cs +++ b/FastGithub.Dns/DnsServerHostedService.cs @@ -20,6 +20,7 @@ namespace FastGithub.Dns { private readonly RequestResolver requestResolver; private readonly FastGithubConfig fastGithubConfig; + private readonly HostsValidator hostsValidator; private readonly ILogger logger; private readonly Socket socket = new(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); @@ -40,11 +41,13 @@ namespace FastGithub.Dns public DnsServerHostedService( RequestResolver requestResolver, FastGithubConfig fastGithubConfig, + HostsValidator hostsValidator, IOptionsMonitor options, ILogger logger) { this.requestResolver = requestResolver; this.fastGithubConfig = fastGithubConfig; + this.hostsValidator = hostsValidator; this.logger = logger; options.OnChange(opt => FlushResolverCache()); } @@ -74,18 +77,21 @@ namespace FastGithub.Dns } await BindAsync(this.socket, new IPEndPoint(IPAddress.Any, DNS_PORT), cancellationToken); - if (OperatingSystem.IsWindows()) { const int SIO_UDP_CONNRESET = unchecked((int)0x9800000C); this.socket.IOControl(SIO_UDP_CONNRESET, new byte[4], new byte[4]); } - this.logger.LogInformation("dns服务启动成功"); + // 验证host文件 + await this.hostsValidator.ValidateAsync(); + + // 设置网关的dns var secondary = this.fastGithubConfig.FastDns.Address; this.dnsAddresses = this.SetNameServers(IPAddress.Loopback, secondary); FlushResolverCache(); + this.logger.LogInformation("dns服务启动成功"); await base.StartAsync(cancellationToken); } diff --git a/FastGithub.Dns/DnsServerServiceCollectionExtensions.cs b/FastGithub.Dns/DnsServerServiceCollectionExtensions.cs index e5d9a72..7ab8f2e 100644 --- a/FastGithub.Dns/DnsServerServiceCollectionExtensions.cs +++ b/FastGithub.Dns/DnsServerServiceCollectionExtensions.cs @@ -17,6 +17,7 @@ namespace FastGithub { return services .AddSingleton() + .AddSingleton() .AddHostedService(); } } diff --git a/FastGithub.Dns/FastGithub.Dns.csproj b/FastGithub.Dns/FastGithub.Dns.csproj index 6e90918..65ae28c 100644 --- a/FastGithub.Dns/FastGithub.Dns.csproj +++ b/FastGithub.Dns/FastGithub.Dns.csproj @@ -4,7 +4,7 @@ net5.0 true - + diff --git a/FastGithub.Dns/HostsValidator.cs b/FastGithub.Dns/HostsValidator.cs new file mode 100644 index 0000000..cbcbba0 --- /dev/null +++ b/FastGithub.Dns/HostsValidator.cs @@ -0,0 +1,98 @@ +using Microsoft.Extensions.Logging; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.NetworkInformation; +using System.Threading.Tasks; + +namespace FastGithub.Dns +{ + /// + /// host文件配置验证器 + /// + sealed class HostsValidator + { + private readonly FastGithubConfig fastGithubConfig; + private readonly ILogger logger; + + /// + /// host文件配置验证器 + /// + /// + /// + public HostsValidator( + FastGithubConfig fastGithubConfig, + ILogger logger) + { + this.fastGithubConfig = fastGithubConfig; + this.logger = logger; + } + + /// + /// 验证host文件的域名解析配置 + /// + /// + public async Task ValidateAsync() + { + var hostsPath = @"/etc/hosts"; + if (OperatingSystem.IsWindows()) + { + hostsPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.System), $"drivers/{hostsPath}"); + } + + if (File.Exists(hostsPath) == false) + { + return; + } + + var lines = await File.ReadAllLinesAsync(hostsPath); + var records = lines.Where(item => item.TrimStart().StartsWith("#") == false); + var localAddresses = GetLocalMachineIPAddress().ToArray(); + + foreach (var record in records) + { + var items = record.Split(' ', StringSplitOptions.RemoveEmptyEntries); + if (items.Length < 2) + { + continue; + } + + if (IPAddress.TryParse(items[0], out var address) == false) + { + continue; + } + + if (localAddresses.Contains(address)) + { + continue; + } + + var domain = items[1]; + if (this.fastGithubConfig.IsMatch(domain)) + { + this.logger.LogWarning($"hosts文件设置了[{domain}->{address}],{nameof(FastGithub)}对此域名反向代理失效"); + } + } + } + + /// + /// 获取本机所有ip + /// + /// + private static IEnumerable GetLocalMachineIPAddress() + { + yield return IPAddress.Loopback; + yield return IPAddress.IPv6Loopback; + + foreach (var @interface in NetworkInterface.GetAllNetworkInterfaces()) + { + foreach (var addressInfo in @interface.GetIPProperties().UnicastAddresses) + { + yield return addressInfo.Address; + } + } + } + } +}