diff --git a/FastGithub.DomainResolve/DomainResolver.cs b/FastGithub.DomainResolve/DomainResolver.cs index a6d05ab..9d245e2 100644 --- a/FastGithub.DomainResolve/DomainResolver.cs +++ b/FastGithub.DomainResolve/DomainResolver.cs @@ -16,30 +16,33 @@ namespace FastGithub.DomainResolve /// sealed class DomainResolver : IDomainResolver { - const int MAX_ADDRESS_COUNT = 4; private readonly DnsClient dnsClient; private readonly DomainPersistence persistence; + private readonly IPAddressStatusService statusService; private readonly ILogger logger; - private readonly ConcurrentDictionary dnsEndPointAddressElapseds = new(); + private readonly ConcurrentDictionary dnsEndPointAddress = new(); /// /// 域名解析器 /// /// /// + /// /// public DomainResolver( DnsClient dnsClient, DomainPersistence persistence, + IPAddressStatusService statusService, ILogger logger) { this.dnsClient = dnsClient; this.persistence = persistence; + this.statusService = statusService; this.logger = logger; foreach (var endPoint in persistence.ReadDnsEndPoints()) { - this.dnsEndPointAddressElapseds.TryAdd(endPoint, Array.Empty()); + this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty()); } } @@ -67,18 +70,18 @@ namespace FastGithub.DomainResolve /// public async IAsyncEnumerable ResolveAllAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken) { - if (this.dnsEndPointAddressElapseds.TryGetValue(endPoint, out var addressElapseds) && addressElapseds.Length > 0) + if (this.dnsEndPointAddress.TryGetValue(endPoint, out var addresses) && addresses.Length > 0) { - foreach (var addressElapsed in addressElapseds) + foreach (var address in addresses) { - yield return addressElapsed.Adddress; + yield return address; } } else { - if (this.dnsEndPointAddressElapseds.TryAdd(endPoint, Array.Empty())) + if (this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty())) { - await this.persistence.WriteDnsEndPointsAsync(this.dnsEndPointAddressElapseds.Keys, cancellationToken); + await this.persistence.WriteDnsEndPointsAsync(this.dnsEndPointAddress.Keys, cancellationToken); } await foreach (var adddress in this.dnsClient.ResolveAsync(endPoint, fastSort: true, cancellationToken)) @@ -95,44 +98,29 @@ namespace FastGithub.DomainResolve /// public async Task TestAllEndPointsAsync(CancellationToken cancellationToken) { - foreach (var keyValue in this.dnsEndPointAddressElapseds) + foreach (var keyValue in this.dnsEndPointAddress) { - var oldValues = keyValue.Value; - if (oldValues.Length >= MAX_ADDRESS_COUNT) - { - if (oldValues.Any(item => item.NeedUpdateElapsed()) == false) - { - continue; - } - } - var dnsEndPoint = keyValue.Key; - var hashSet = new HashSet(oldValues); - await foreach (var adddress in this.dnsClient.ResolveAsync(dnsEndPoint, fastSort: false, cancellationToken)) + var oldAddresses = keyValue.Value; + + var hashSet = new HashSet(oldAddresses); + await foreach (var address in this.dnsClient.ResolveAsync(dnsEndPoint, fastSort: false, cancellationToken)) { - hashSet.Add(new IPAddressElapsed(adddress, dnsEndPoint.Port)); + hashSet.Add(address); } - // 两个以上才进行测速排序 - if (hashSet.Count > 1) - { - var updateTasks = hashSet - .Where(item => item.NeedUpdateElapsed()) - .Select(item => item.UpdateElapsedAsync(cancellationToken)); - await Task.WhenAll(updateTasks); - } - - var newValues = hashSet + var statusArray = await this.statusService.GetParallelAsync(hashSet, dnsEndPoint.Port, cancellationToken); + var newAddresses = statusArray .Where(item => item.Elapsed < TimeSpan.MaxValue) .OrderBy(item => item.Elapsed) - .Take(count: MAX_ADDRESS_COUNT) + .Select(item => item.Address) .ToArray(); - if (oldValues.SequenceEqual(newValues) == false) + if (oldAddresses.SequenceEqual(newAddresses) == false) { - this.dnsEndPointAddressElapseds[dnsEndPoint] = newValues; + this.dnsEndPointAddress[dnsEndPoint] = newAddresses; - var addressArray = string.Join(", ", newValues.Select(item => item.ToString())); + var addressArray = string.Join(", ", newAddresses.Select(item => item.ToString())); this.logger.LogInformation($"{dnsEndPoint.Host}->[{addressArray}]"); } } diff --git a/FastGithub.DomainResolve/IPAddressElapsed.cs b/FastGithub.DomainResolve/IPAddressElapsed.cs deleted file mode 100644 index c77895f..0000000 --- a/FastGithub.DomainResolve/IPAddressElapsed.cs +++ /dev/null @@ -1,107 +0,0 @@ -using System; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Net; -using System.Net.Sockets; -using System.Threading; -using System.Threading.Tasks; - -namespace FastGithub.DomainResolve -{ - /// - /// IP延时记录 - /// 5分钟有效期 - /// 5秒连接超时 - /// - [DebuggerDisplay("Adddress={Adddress} Elapsed={Elapsed}")] - sealed class IPAddressElapsed : IEquatable - { - private static readonly long maxLifeTime = 5 * 60 * 1000; - private static readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(5d); - - private long lastTestTickCount = 0L; - - /// - /// 获取IP地址 - /// - public IPAddress Adddress { get; } - - /// - /// 获取端口 - /// - public int Port { get; } - - /// - /// 获取延时 - /// - public TimeSpan Elapsed { get; private set; } - - /// - /// IP延时 - /// - /// - /// - public IPAddressElapsed(IPAddress adddress, int port) - { - this.Adddress = adddress; - this.Port = port; - } - - /// - /// 是否需求更新延时 - /// - /// - public bool NeedUpdateElapsed() - { - return Environment.TickCount64 - this.lastTestTickCount > maxLifeTime; - } - - /// - /// 更新连接耗时 - /// - /// - /// - public async Task UpdateElapsedAsync(CancellationToken cancellationToken) - { - var stopWatch = Stopwatch.StartNew(); - try - { - using var timeoutTokenSource = new CancellationTokenSource(connectTimeout); - using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token); - using var socket = new Socket(this.Adddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp); - await socket.ConnectAsync(this.Adddress, this.Port, linkedTokenSource.Token); - this.Elapsed = stopWatch.Elapsed; - } - catch (Exception) - { - cancellationToken.ThrowIfCancellationRequested(); - this.Elapsed = TimeSpan.MaxValue; - } - finally - { - this.lastTestTickCount = Environment.TickCount64; - stopWatch.Stop(); - } - } - - public bool Equals(IPAddressElapsed? other) - { - return other != null && other.Adddress.Equals(this.Adddress); - } - - public override bool Equals([NotNullWhen(true)] object? obj) - { - return obj is IPAddressElapsed other && this.Equals(other); - } - - public override int GetHashCode() - { - return this.Adddress.GetHashCode(); - } - - public override string ToString() - { - return this.Adddress.ToString(); - } - } -} diff --git a/FastGithub.DomainResolve/IPAddressStatus.cs b/FastGithub.DomainResolve/IPAddressStatus.cs new file mode 100644 index 0000000..8019f8b --- /dev/null +++ b/FastGithub.DomainResolve/IPAddressStatus.cs @@ -0,0 +1,34 @@ +using System; +using System.Net; + +namespace FastGithub.DomainResolve +{ + /// + /// 表示IP的状态 + /// + struct IPAddressStatus + { + /// + /// 获取IP地址 + /// + public IPAddress Address { get; } + + /// + /// 获取延时 + /// 当连接失败时值为MaxValue + /// + public TimeSpan Elapsed { get; } + + + /// + /// IP的状态 + /// + /// + /// + public IPAddressStatus(IPAddress address, TimeSpan elapsed) + { + this.Address = address; + this.Elapsed = elapsed; + } + } +} diff --git a/FastGithub.DomainResolve/IPAddressStatusService.cs b/FastGithub.DomainResolve/IPAddressStatusService.cs new file mode 100644 index 0000000..73c62bb --- /dev/null +++ b/FastGithub.DomainResolve/IPAddressStatusService.cs @@ -0,0 +1,88 @@ +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace FastGithub.DomainResolve +{ + /// + /// IP状态服务 + /// 连接成功的IP缓存5分钟 + /// 连接失败的IP缓存2分钟 + /// + sealed class IPAddressStatusService + { + private readonly TimeSpan activeTTL = TimeSpan.FromMinutes(5d); + private readonly TimeSpan negativeTTL = TimeSpan.FromMinutes(2d); + private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(5d); + private readonly IMemoryCache statusCache = new MemoryCache(Options.Create(new MemoryCacheOptions())); + + + /// + /// 并行获取多个IP的状态 + /// + /// + /// + /// + /// + public Task GetParallelAsync(IEnumerable addresses, int port, CancellationToken cancellationToken) + { + var statusTasks = addresses.Select(item => this.GetAsync(item, port, cancellationToken)); + return Task.WhenAll(statusTasks); + } + + /// + /// 获取IP状态 + /// + /// + /// + /// + /// + public async Task GetAsync(IPAddress address, int port, CancellationToken cancellationToken) + { + var endPoint = new IPEndPoint(address, port); + if (this.statusCache.TryGetValue(endPoint, out var status)) + { + return status; + } + + status = await this.GetAddressStatusAsync(endPoint, cancellationToken); + var ttl = status.Elapsed < TimeSpan.MaxValue ? this.activeTTL : this.negativeTTL; + return this.statusCache.Set(endPoint, status, ttl); + } + + /// + /// 获取IP状态 + /// + /// + /// + /// + private async Task GetAddressStatusAsync(IPEndPoint endPoint, CancellationToken cancellationToken) + { + var stopWatch = Stopwatch.StartNew(); + try + { + using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout); + using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token); + using var socket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(endPoint, linkedTokenSource.Token); + return new IPAddressStatus(endPoint.Address, stopWatch.Elapsed); + } + catch (Exception) + { + cancellationToken.ThrowIfCancellationRequested(); + return new IPAddressStatus(endPoint.Address, TimeSpan.MaxValue); + } + finally + { + stopWatch.Stop(); + } + } + } +} diff --git a/FastGithub.DomainResolve/ServiceCollectionExtensions.cs b/FastGithub.DomainResolve/ServiceCollectionExtensions.cs index 49341c6..69c5075 100644 --- a/FastGithub.DomainResolve/ServiceCollectionExtensions.cs +++ b/FastGithub.DomainResolve/ServiceCollectionExtensions.cs @@ -19,6 +19,7 @@ namespace FastGithub services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddSingleton(); + services.TryAddSingleton(); services.TryAddSingleton(); services.AddHostedService(); return services;