using FastGithub.DomainResolve;
using Microsoft.AspNetCore.Connections;
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
namespace FastGithub.HttpServer
{
    /// <summary>
    /// TCP reverse proxy handler: forwards every accepted connection to a single
    /// upstream endpoint and relays bytes in both directions until either side stops.
    /// </summary>
    abstract class TcpReverseProxyHandler : ConnectionHandler
    {
        private readonly IDomainResolver domainResolver;
        private readonly DnsEndPoint endPoint;
        private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);

        /// <summary>
        /// TCP reverse proxy handler.
        /// </summary>
        /// <param name="domainResolver">Resolver used to look up the addresses of <paramref name="endPoint"/>.</param>
        /// <param name="endPoint">Upstream host and port that accepted connections are forwarded to.</param>
        protected TcpReverseProxyHandler(IDomainResolver domainResolver, DnsEndPoint endPoint)
        {
            this.domainResolver = domainResolver;
            this.endPoint = endPoint;
        }

        /// <summary>
        /// Called after a client TCP connection is accepted: opens an upstream
        /// connection and relays data in both directions.
        /// </summary>
        /// <param name="context">The accepted client connection.</param>
        /// <returns>A task that completes once either relay direction has finished.</returns>
        public override async Task OnConnectedAsync(ConnectionContext context)
        {
            // Linked to ConnectionClosed so the relay can also be stopped locally
            // as soon as one direction finishes.
            using var relayTokenSource = CancellationTokenSource.CreateLinkedTokenSource(context.ConnectionClosed);
            var cancellationToken = relayTokenSource.Token;

            using var connection = await this.CreateConnectionAsync(cancellationToken);
            var upstreamToClient = connection.CopyToAsync(context.Transport.Output, cancellationToken);
            var clientToUpstream = context.Transport.Input.CopyToAsync(connection, cancellationToken);
            await Task.WhenAny(upstreamToClient, clientToUpstream);

            // One direction is done: cancel the other instead of leaving it pending
            // after the upstream stream is disposed on return.
            relayTokenSource.Cancel();
        }

        /// <summary>
        /// Opens a TCP connection to the upstream endpoint, trying each resolved
        /// address in turn until one connects.
        /// </summary>
        /// <param name="cancellationToken">Cancels address resolution and connection attempts.</param>
        /// <returns>A stream over the connected socket; disposing the stream also closes the socket.</returns>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
        /// <exception cref="AggregateException">No resolved address could be connected to; holds the failure of each attempt.</exception>
        private async Task<Stream> CreateConnectionAsync(CancellationToken cancellationToken)
        {
            var innerExceptions = new List<Exception>();
            await foreach (var address in this.domainResolver.ResolveAsync(this.endPoint, cancellationToken))
            {
                var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
                try
                {
                    // Bound each attempt by connectTimeout so one unreachable address
                    // does not stall the remaining candidates.
                    using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
                    using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
                    await socket.ConnectAsync(address, this.endPoint.Port, linkedTokenSource.Token);

                    // The stream must own the socket: callers only dispose the stream,
                    // so with ownsSocket: false the socket would never be closed.
                    return new NetworkStream(socket, ownsSocket: true);
                }
                catch (Exception ex)
                {
                    socket.Dispose();
                    // Caller cancellation aborts the whole operation; any other failure
                    // (including a per-attempt timeout) falls through to the next address.
                    cancellationToken.ThrowIfCancellationRequested();
                    innerExceptions.Add(ex);
                }
            }
            throw new AggregateException($"无法连接到{this.endPoint.Host}:{this.endPoint.Port}", innerExceptions);
        }
    }
}