using DNS.Protocol;
using Microsoft.AspNetCore.Http;
using System;
using System.IO;
using System.Linq;
using System.Net;
using System.Threading.Tasks;
namespace FastGithub.Dns
{
    /// <summary>
    /// DNS-over-HTTPS (DoH) middleware: answers DNS queries sent to <c>/dns-query</c>
    /// and hands every other request to the next middleware in the pipeline.
    /// </summary>
    sealed class DnsOverHttpsMiddleware
    {
        private static readonly PathString dnsQueryPath = "/dns-query";
        private const string MEDIA_TYPE = "application/dns-message";
        private readonly RequestResolver requestResolver;

        /// <summary>
        /// Creates the DoH middleware.
        /// </summary>
        /// <param name="requestResolver">Resolver used to answer the decoded DNS queries.</param>
        public DnsOverHttpsMiddleware(RequestResolver requestResolver)
        {
            this.requestResolver = requestResolver;
        }

        /// <summary>
        /// Handles one HTTP request: if it is a DoH query, resolves it and writes the DNS
        /// response as <c>application/dns-message</c>; otherwise invokes <paramref name="next"/>.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        /// <param name="next">The next middleware in the pipeline.</param>
        /// <returns>A task that completes when the request has been handled.</returns>
        public async Task InvokeAsync(HttpContext context, RequestDelegate next)
        {
            Request? request;
            try
            {
                request = await ParseDnsRequestAsync(context.Request);
            }
            catch (Exception)
            {
                // Malformed base64url / DNS payload: treat it as "not a DoH request".
                request = null;
            }

            if (request == null)
            {
                // next() is deliberately called outside any try block: an exception thrown by the
                // downstream pipeline must propagate, not trigger a second invocation of next().
                await next(context);
                return;
            }

            byte[] payload;
            try
            {
                var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
                var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
                var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
                var response = await this.requestResolver.Resolve(remoteEndPointRequest);
                payload = response.ToArray();
            }
            catch (Exception)
            {
                // Resolution failed before anything was written to the response, so the
                // original fallback (hand the request to the rest of the pipeline) is still safe.
                await next(context);
                return;
            }

            // Failures while writing the body (e.g. client disconnect) propagate: the response
            // has already started, so falling back to next() would no longer be valid.
            context.Response.ContentType = MEDIA_TYPE;
            context.Response.ContentLength = payload.Length;
            await context.Response.BodyWriter.WriteAsync(payload, context.RequestAborted);
        }

        /// <summary>
        /// Decodes a DoH HTTP request into a DNS <see cref="Request"/>.
        /// </summary>
        /// <param name="request">The incoming HTTP request.</param>
        /// <returns>
        /// The decoded DNS request, or <c>null</c> when the HTTP request is not a DoH query:
        /// wrong path, no <c>application/dns-message</c> entry in the Accept header, a GET without
        /// a <c>dns</c> query parameter, or any other method / content type.
        /// </returns>
        /// <remarks>
        /// Malformed payloads are not caught here: an invalid base64url string makes
        /// <see cref="Convert.FromBase64String(string)"/> throw <see cref="FormatException"/>,
        /// and <c>Request.FromArray</c> may throw on bytes that are not a valid DNS message.
        /// </remarks>
        private static async Task<Request?> ParseDnsRequestAsync(HttpRequest request)
        {
            // Accept may hold several values, each possibly a comma-separated list with parameters.
            if (request.Path != dnsQueryPath ||
                request.Headers.TryGetValue("accept", out var accept) == false ||
                accept.SelectMany(value => (value ?? string.Empty).Split(',')).Any(IsDnsMessageMediaType) == false)
            {
                return default;
            }

            if (HttpMethods.IsGet(request.Method))
            {
                if (request.Query.TryGetValue("dns", out var dns) == false)
                {
                    return default;
                }

                // The dns parameter is unpadded base64url: map it back to the standard
                // base64 alphabet and restore the '=' padding before decoding.
                var dnsRequest = dns.ToString().Replace('-', '+').Replace('_', '/');
                var mod = dnsRequest.Length % 4;
                if (mod > 0)
                {
                    dnsRequest = dnsRequest.PadRight(dnsRequest.Length - mod + 4, '=');
                }
                var message = Convert.FromBase64String(dnsRequest);
                return Request.FromArray(message);
            }

            if (HttpMethods.IsPost(request.Method) && IsDnsMessageMediaType(request.ContentType))
            {
                // POST carries the raw DNS message as the request body.
                using var message = new MemoryStream();
                await request.Body.CopyToAsync(message, request.HttpContext.RequestAborted);
                return Request.FromArray(message.ToArray());
            }

            return default;
        }

        /// <summary>
        /// Returns whether a single media-type entry (one Accept list item or the Content-Type
        /// header) names <c>application/dns-message</c>. Case, surrounding whitespace and any
        /// <c>;parameter</c> suffix are ignored. Note that a <c>q=0</c> parameter is not
        /// inspected, so such an entry still counts as a match.
        /// </summary>
        /// <param name="mediaType">The media-type text; may be null or empty.</param>
        /// <returns><c>true</c> if the entry's type/subtype is <c>application/dns-message</c>.</returns>
        private static bool IsDnsMessageMediaType(string? mediaType)
        {
            if (string.IsNullOrEmpty(mediaType))
            {
                return false;
            }
            var separator = mediaType.IndexOf(';');
            var type = separator < 0 ? mediaType : mediaType.Substring(0, separator);
            return string.Equals(type.Trim(), MEDIA_TYPE, StringComparison.OrdinalIgnoreCase);
        }
    }
}