using FastGithub.Configuration;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
namespace FastGithub.DomainResolve
{
    /// <summary>
    /// Domain resolver: answers DNS end point lookups from a per-end-point cache of tested
    /// addresses, and falls back to a live DNS lookup for end points with no cached addresses.
    /// </summary>
    sealed class DomainResolver : IDomainResolver
    {
        private readonly DnsClient dnsClient;
        private readonly DomainPersistence persistence;
        private readonly IPAddressService addressService;
        private readonly ILogger<DomainResolver> logger;

        /// <summary>
        /// Known end points mapped to the addresses from the most recent speed test.
        /// An empty array means the end point is known but has no tested addresses yet.
        /// </summary>
        private readonly ConcurrentDictionary<DnsEndPoint, IPAddress[]> dnsEndPointAddress = new();

        /// <summary>
        /// Creates the resolver and seeds the cache with the end points returned by
        /// <see cref="DomainPersistence"/>; each starts with an empty address list.
        /// </summary>
        /// <param name="dnsClient">Client used for live lookups when an end point has no cached addresses.</param>
        /// <param name="persistence">Store the set of known end points is read from and written back to.</param>
        /// <param name="addressService">Service that produces the tested address list for an end point.</param>
        /// <param name="logger">Logger used to report address changes.</param>
        public DomainResolver(
            DnsClient dnsClient,
            DomainPersistence persistence,
            IPAddressService addressService,
            ILogger<DomainResolver> logger)
        {
            this.dnsClient = dnsClient;
            this.persistence = persistence;
            this.addressService = addressService;
            this.logger = logger;

            foreach (var endPoint in persistence.ReadDnsEndPoints())
            {
                this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty<IPAddress>());
            }
        }

        /// <summary>
        /// Resolves a single IP address for an end point: the first address produced by
        /// <see cref="ResolveAllAsync"/>.
        /// </summary>
        /// <param name="endPoint">The end point to resolve.</param>
        /// <param name="cancellationToken">Token used to cancel the lookup.</param>
        /// <returns>The first resolved address.</returns>
        /// <exception cref="FastGithubException">No address could be resolved for <paramref name="endPoint"/>.</exception>
        public async Task<IPAddress> ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default)
        {
            await foreach (var address in this.ResolveAllAsync(endPoint, cancellationToken))
            {
                return address;
            }
            throw new FastGithubException($"解析不到{endPoint.Host}的IP");
        }

        /// <summary>
        /// Resolves all IP addresses for an end point. If the cache holds tested addresses they
        /// are returned as-is. Otherwise the end point is registered in the cache (and the
        /// end point set is persisted if it was new), and the addresses come from a live DNS lookup.
        /// </summary>
        /// <param name="endPoint">The end point to resolve.</param>
        /// <param name="cancellationToken">Token used to cancel the lookup.</param>
        /// <returns>The resolved addresses, streamed as they become available.</returns>
        public async IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken)
        {
            if (this.dnsEndPointAddress.TryGetValue(endPoint, out var addresses) && addresses.Length > 0)
            {
                foreach (var address in addresses)
                {
                    yield return address;
                }
            }
            else
            {
                // Only the call that actually registers a new end point persists the updated set;
                // known-but-untested end points skip the write.
                if (this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty<IPAddress>()))
                {
                    await this.persistence.WriteDnsEndPointsAsync(this.dnsEndPointAddress.Keys, cancellationToken);
                }

                await foreach (var address in this.dnsClient.ResolveAsync(endPoint, fastSort: true, cancellationToken))
                {
                    yield return address;
                }
            }
        }

        /// <summary>
        /// Runs a speed test on every known end point and replaces its cached addresses
        /// whenever the result differs from what was cached before.
        /// </summary>
        /// <param name="cancellationToken">Token used to cancel the tests.</param>
        /// <returns>A task that completes when every end point has been tested.</returns>
        public async Task TestAllEndPointsAsync(CancellationToken cancellationToken)
        {
            // ConcurrentDictionary enumeration tolerates the in-loop updates below.
            foreach (var keyValue in this.dnsEndPointAddress)
            {
                var dnsEndPoint = keyValue.Key;
                var oldAddresses = keyValue.Value;
                var newAddresses = await this.addressService.GetAddressesAsync(dnsEndPoint, oldAddresses, cancellationToken);
                if (!oldAddresses.SequenceEqual(newAddresses))
                {
                    this.dnsEndPointAddress[dnsEndPoint] = newAddresses;
                    var addressArray = string.Join(", ", newAddresses.Select(item => item.ToString()));
                    // Message template instead of interpolation: same rendered text, but structured fields.
                    this.logger.LogInformation("{Host}->[{Addresses}]", dnsEndPoint.Host, addressArray);
                }
            }
        }
    }
}