using FastGithub.Configuration;
using FastGithub.DomainResolve;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
namespace FastGithub.HttpServer.TcpMiddlewares
{
    /// <summary>
    /// Tunnel middleware: for connections whose proxy feature reports a tunnel proxy and that do not
    /// yet carry TLS, opens a TCP connection to the requested proxy host and relays raw bytes in both
    /// directions between the client transport and that upstream connection.
    /// </summary>
    sealed class TunnelMiddleware
    {
        private readonly FastGithubConfig fastGithubConfig;
        private readonly IDomainResolver domainResolver;
        private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);

        /// <summary>
        /// Creates the tunnel middleware.
        /// </summary>
        /// <param name="fastGithubConfig">Configuration consulted (via <c>IsMatch</c>) to decide whether a host is resolved through <paramref name="domainResolver"/>.</param>
        /// <param name="domainResolver">Resolver used for hosts matched by <paramref name="fastGithubConfig"/>.</param>
        public TunnelMiddleware(
            FastGithubConfig fastGithubConfig,
            IDomainResolver domainResolver)
        {
            this.fastGithubConfig = fastGithubConfig;
            this.domainResolver = domainResolver;
        }

        /// <summary>
        /// Runs the middleware: non-tunnel connections are passed to <paramref name="next"/>;
        /// tunnel connections are relayed to the upstream host until either direction ends.
        /// </summary>
        /// <param name="next">The next connection middleware.</param>
        /// <param name="context">The current connection.</param>
        /// <returns>A task that completes when the connection has been handled.</returns>
        public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext context)
        {
            var proxyFeature = context.Features.Get<IHttpProxyFeature>();
            if (proxyFeature == null || // not a proxy request
                proxyFeature.ProxyProtocol != ProxyProtocol.TunnelProxy || // not a tunnel proxy
                context.Features.Get<ITlsConnectionFeature>() != null) // TLS present: https that came through the tunnel
            {
                await next(context);
                return;
            }

            var transport = context.Features.Get<IConnectionTransportFeature>()?.Transport;
            if (transport == null)
            {
                return;
            }

            var cancellationToken = context.ConnectionClosed;
            using var connection = await this.CreateConnectionAsync(proxyFeature.ProxyHost, cancellationToken);
            using var relayTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

            var downstream = connection.CopyToAsync(transport.Output, relayTokenSource.Token);
            var upstream = transport.Input.CopyToAsync(connection, relayTokenSource.Token);
            await Task.WhenAny(downstream, upstream);

            // One direction has ended. Stop the other one rather than leaving it running against a
            // connection that is about to be disposed (and whose failure nobody would observe).
            relayTokenSource.Cancel();
            connection.Dispose(); // also aborts any pending socket operation; the later using-dispose is a no-op

            try
            {
                await Task.WhenAll(downstream, upstream);
            }
            catch (Exception)
            {
                // Resets, cancellation and disposal are the normal way a tunnel ends. The faults are
                // observed here so they are not reported as unobserved task exceptions, and are otherwise
                // ignored, preserving the previous best-effort behaviour.
            }
        }

        /// <summary>
        /// Connects to the first reachable upstream endpoint of <paramref name="host"/>.
        /// </summary>
        /// <param name="host">Target host; when it has no port, 443 is used.</param>
        /// <param name="cancellationToken">Cancels the whole connection attempt.</param>
        /// <returns>A <see cref="NetworkStream"/> that owns the connected socket.</returns>
        /// <exception cref="AggregateException">Every candidate endpoint failed, or none were produced.</exception>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
        private async Task<Stream> CreateConnectionAsync(HostString host, CancellationToken cancellationToken)
        {
            var innerExceptions = new List<Exception>();
            await foreach (var endPoint in this.GetUpstreamEndPointsAsync(host, cancellationToken))
            {
                var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
                try
                {
                    // Each endpoint gets its own timeout so one unreachable address cannot stall the rest.
                    using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
                    using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
                    await socket.ConnectAsync(endPoint, linkedTokenSource.Token);
                    return new NetworkStream(socket, ownsSocket: true);
                }
                catch (Exception ex)
                {
                    socket.Dispose();
                    // Caller cancellation aborts everything; any other failure (including the per-endpoint
                    // timeout) is recorded and the next endpoint is tried.
                    cancellationToken.ThrowIfCancellationRequested();
                    innerExceptions.Add(ex);
                }
            }
            throw new AggregateException($"无法连接到{host}", innerExceptions);
        }

        /// <summary>
        /// Yields the endpoints to try for <paramref name="host"/>.
        /// </summary>
        /// <remarks>
        /// An IP literal yields a single <see cref="IPEndPoint"/>; a host not matched by the configuration
        /// yields a single <see cref="DnsEndPoint"/> (left for the socket connect to resolve); a matched host
        /// yields one <see cref="IPEndPoint"/> per address produced by the domain resolver.
        /// </remarks>
        /// <param name="host">Target host; when it has no port, 443 is used.</param>
        /// <param name="cancellationToken">Cancels resolution.</param>
        /// <returns>The candidate endpoints, in the order they should be tried.</returns>
        private async IAsyncEnumerable<EndPoint> GetUpstreamEndPointsAsync(HostString host, [EnumeratorCancellation] CancellationToken cancellationToken)
        {
            const int HTTPS_PORT = 443;
            var targetHost = host.Host;
            var targetPort = host.Port ?? HTTPS_PORT;

            if (IPAddress.TryParse(targetHost, out var address))
            {
                yield return new IPEndPoint(address, targetPort);
            }
            else if (!this.fastGithubConfig.IsMatch(targetHost))
            {
                yield return new DnsEndPoint(targetHost, targetPort);
            }
            else
            {
                var dnsEndPoint = new DnsEndPoint(targetHost, targetPort);
                await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint, cancellationToken))
                {
                    yield return new IPEndPoint(item, targetPort);
                }
            }
        }
    }
}