using FastGithub.Configuration;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace FastGithub.DomainResolve
{
    /// <summary>
    /// Domain resolver: serves the cached, speed-ranked addresses of an endpoint,
    /// and falls back to a live DNS lookup when no ranked addresses are cached yet.
    /// </summary>
    sealed class DomainResolver : IDomainResolver
    {
        /// <summary>
        /// Maximum number of ranked addresses kept per endpoint after a speed test.
        /// </summary>
        const int MAX_ADDRESS_COUNT = 4;

        private readonly DnsClient dnsClient;
        private readonly DomainPersistence persistence;
        private readonly ILogger<DomainResolver> logger;

        /// <summary>
        /// Known endpoints mapped to their ranked addresses (ascending elapsed time, as
        /// produced by <see cref="TestAllEndPointsAsync"/>). An empty array means the
        /// endpoint is known but has no ranked addresses yet.
        /// </summary>
        private readonly ConcurrentDictionary<DnsEndPoint, IPAddressElapsed[]> dnsEndPointAddressElapseds = new();

        /// <summary>
        /// Domain resolver.
        /// </summary>
        /// <param name="dnsClient">Client used for live DNS lookups.</param>
        /// <param name="persistence">Store from which known endpoints are read and to which newly seen endpoints are written.</param>
        /// <param name="logger">Logger.</param>
        public DomainResolver(
            DnsClient dnsClient,
            DomainPersistence persistence,
            ILogger<DomainResolver> logger)
        {
            this.dnsClient = dnsClient;
            this.persistence = persistence;
            this.logger = logger;

            // Register persisted endpoints with no ranked addresses so the next
            // TestAllEndPointsAsync sweep measures them.
            foreach (var endPoint in persistence.ReadDnsEndPoints())
            {
                this.dnsEndPointAddressElapseds.TryAdd(endPoint, Array.Empty<IPAddressElapsed>());
            }
        }

        /// <summary>
        /// Resolves one IP address for the endpoint: the first address yielded by
        /// <see cref="ResolveAllAsync"/>.
        /// </summary>
        /// <param name="endPoint">Endpoint to resolve.</param>
        /// <param name="cancellationToken">Cancellation token.</param>
        /// <returns>The first resolved address.</returns>
        /// <exception cref="FastGithubException">No address could be resolved for <paramref name="endPoint"/>.</exception>
        public async Task<IPAddress> ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default)
        {
            await foreach (var address in this.ResolveAllAsync(endPoint, cancellationToken))
            {
                return address;
            }
            throw new FastGithubException($"解析不到{endPoint.Host}的IP");
        }

        /// <summary>
        /// Resolves the addresses of an endpoint. Cached ranked addresses are yielded
        /// when present; otherwise the endpoint is registered (and the endpoint list
        /// persisted if it was new) and the addresses come from a live DNS lookup.
        /// </summary>
        /// <param name="endPoint">Endpoint to resolve.</param>
        /// <param name="cancellationToken">Cancellation token.</param>
        /// <returns>The resolved addresses.</returns>
        public async IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken)
        {
            if (this.dnsEndPointAddressElapseds.TryGetValue(endPoint, out var addressElapseds) && addressElapseds.Length > 0)
            {
                foreach (var addressElapsed in addressElapseds)
                {
                    yield return addressElapsed.Adddress;
                }
                yield break;
            }

            // First sighting of this endpoint: register it so speed tests include it,
            // and write out the updated endpoint list. TryAdd guarantees only one
            // concurrent caller performs the write.
            if (this.dnsEndPointAddressElapseds.TryAdd(endPoint, Array.Empty<IPAddressElapsed>()))
            {
                await this.persistence.WriteDnsEndPointsAsync(this.dnsEndPointAddressElapseds.Keys, cancellationToken);
            }

            await foreach (var address in this.dnsClient.ResolveAsync(endPoint, fastSort: true, cancellationToken))
            {
                yield return address;
            }
        }

        /// <summary>
        /// Speed-tests every known endpoint. Cached addresses are merged with a fresh
        /// DNS lookup, elapsed times are refreshed where allowed, and the fastest
        /// addresses (at most <c>MAX_ADDRESS_COUNT</c>) replace the cache entry.
        /// </summary>
        /// <param name="cancellationToken">Cancellation token; checked before each endpoint.</param>
        public async Task TestAllEndPointsAsync(CancellationToken cancellationToken)
        {
            // ConcurrentDictionary enumeration tolerates the writes made inside the loop.
            foreach (var keyValue in this.dnsEndPointAddressElapseds)
            {
                cancellationToken.ThrowIfCancellationRequested();

                var dnsEndPoint = keyValue.Key;
                var oldAddressElapseds = keyValue.Value;

                // Candidate set: previously ranked entries plus freshly resolved addresses, deduplicated.
                var hashSet = new HashSet<IPAddressElapsed>(oldAddressElapseds);
                await foreach (var address in this.dnsClient.ResolveAsync(dnsEndPoint, fastSort: false, cancellationToken))
                {
                    hashSet.Add(new IPAddressElapsed(address, dnsEndPoint.Port));
                }

                var updateTasks = hashSet
                    .Where(item => item.CanUpdateElapsed())
                    .Select(item => item.UpdateElapsedAsync(cancellationToken));
                await Task.WhenAll(updateTasks);

                // Drop entries whose Elapsed is TimeSpan.MaxValue (presumably unreachable
                // or unmeasured — TODO confirm in IPAddressElapsed), fastest first.
                var addressElapseds = hashSet
                    .Where(item => item.Elapsed < TimeSpan.MaxValue)
                    .OrderBy(item => item.Elapsed)
                    .Take(MAX_ADDRESS_COUNT)
                    .ToArray();

                if (!oldAddressElapseds.SequenceEqual(addressElapseds))
                {
                    var addressArray = string.Join(", ", addressElapseds.Select(item => item.ToString()));
                    // Constant template (CA2254): same rendered text as before, but structured
                    // and not formatted when Information logging is disabled.
                    this.logger.LogInformation("{Host}->[{Addresses}]", dnsEndPoint.Host, addressArray);
                }

                this.dnsEndPointAddressElapseds[dnsEndPoint] = addressElapseds;
            }
        }
    }
}