using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Features;
using System.Buffers;
using System.IO.Pipelines;
using System.Threading.Tasks;
namespace FastGithub.HttpServer.TlsMiddlewares
{
    /// <summary>
    /// HTTPS "invade" middleware: lets plaintext (non-TLS) connections pass through a
    /// connection pipeline that also contains the HTTPS middleware.
    /// </summary>
    sealed class TlsInvadeMiddleware
    {
        /// <summary>
        /// TLS record content type for a handshake message (the first byte of a ClientHello record).
        /// </summary>
        private const byte TlsHandshakeContentType = 0x16;

        /// <summary>
        /// Major protocol version byte used in the record header by SSL 3.0 and all TLS versions (3.x).
        /// </summary>
        private const byte TlsMajorVersion = 0x03;

        /// <summary>
        /// Executes the middleware.
        /// </summary>
        /// <param name="next">The next connection delegate in the pipeline.</param>
        /// <param name="context">The current connection.</param>
        /// <returns>A task that completes when the rest of the pipeline has finished with the connection.</returns>
        public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext context)
        {
            // The connection is not TLS, and no TLS middleware has run on it yet.
            if (!await IsTlsConnectionAsync(context) &&
                context.Features.Get<ITlsConnectionFeature>() is null)
            {
                // Register a fake ITlsConnectionFeature to force the HTTPS middleware to skip its own work.
                // The type argument is explicit on purpose: letting it be inferred from the static type of
                // Instance could key the feature by the concrete class, where Get<ITlsConnectionFeature>()
                // would not find it.
                context.Features.Set<ITlsConnectionFeature>(FakeTlsConnectionFeature.Instance);
            }

            await next(context);
        }

        /// <summary>
        /// Peeks at the first two bytes of the connection to decide whether it starts with a TLS record.
        /// Nothing is consumed, so the next reader of the transport still sees the complete stream.
        /// </summary>
        /// <param name="context">The connection to inspect.</param>
        /// <returns>
        /// <see langword="true"/> if the first bytes look like a TLS handshake record header;
        /// <see langword="false"/> otherwise, including when fewer than two bytes arrive before the
        /// input completes or when the read throws.
        /// </returns>
        private static async Task<bool> IsTlsConnectionAsync(ConnectionContext context)
        {
            var input = context.Transport.Input;
            try
            {
                var result = await input.ReadAtLeastAsync(2, context.ConnectionClosed);
                var isTls = IsTlsProtocol(result.Buffer);

                // Consume and examine nothing: the buffered bytes remain available to the next read.
                input.AdvanceTo(result.Buffer.Start);
                return isTls;
            }
            catch
            {
                // Deliberate best effort: any read failure (e.g. cancellation through ConnectionClosed)
                // is treated as "not TLS" and the rest of the pipeline handles the connection.
                return false;
            }

            // Checks the record header: handshake content type followed by major version 3.
            static bool IsTlsProtocol(ReadOnlySequence<byte> buffer)
            {
                var reader = new SequenceReader<byte>(buffer);
                return reader.TryRead(out var contentType) && contentType == TlsHandshakeContentType &&
                    reader.TryRead(out var majorVersion) && majorVersion == TlsMajorVersion;
            }
        }
    }
}