using DNS.Protocol;
using Microsoft.AspNetCore.Http;
using System;
using System.IO;
using System.Linq;
using System.Net;
using System.Threading.Tasks;

namespace FastGithub.Dns
{
    /// <summary>
    /// DNS-over-HTTPS (DoH) middleware. Answers DNS queries sent to
    /// <c>/dns-query</c> and passes every other request down the pipeline.
    /// </summary>
    sealed class DnsOverHttpsMiddleware
    {
        private static readonly PathString dnsQueryPath = "/dns-query";
        private const string MEDIA_TYPE = "application/dns-message";
        private readonly RequestResolver requestResolver;

        /// <summary>
        /// Creates the DoH middleware.
        /// </summary>
        /// <param name="requestResolver">Resolver used to answer the parsed DNS requests.</param>
        public DnsOverHttpsMiddleware(RequestResolver requestResolver)
        {
            this.requestResolver = requestResolver;
        }

        /// <summary>
        /// Handles the request: resolves it when it carries a DNS query,
        /// otherwise forwards it to the next middleware.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        /// <param name="next">The next middleware in the pipeline.</param>
        /// <returns>A task that completes when the request has been handled.</returns>
        public async Task InvokeAsync(HttpContext context, RequestDelegate next)
        {
            Request? request;
            try
            {
                request = await ParseDnsRequestAsync(context.Request);
            }
            catch (Exception)
            {
                // Deliberate best-effort: a query that cannot be decoded is treated
                // like a non-DoH request and handed to the rest of the pipeline.
                request = null;
            }

            if (request is null)
            {
                // Invoked outside of any try/catch so that an exception thrown by
                // downstream middleware is not swallowed and does not cause the
                // pipeline to be executed a second time.
                await next(context);
                return;
            }

            try
            {
                var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
                var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
                var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
                var response = await this.requestResolver.Resolve(remoteEndPointRequest);

                context.Response.ContentType = MEDIA_TYPE;
                await context.Response.BodyWriter.WriteAsync(response.ToArray());
            }
            catch (Exception) when (context.Response.HasStarted == false)
            {
                // Resolution failed before anything was sent: fall through to the
                // next middleware. Once the response has started the exception is
                // left to propagate, since the pipeline can no longer take over.
                await next(context);
            }
        }

        /// <summary>
        /// Extracts the DNS request carried by an HTTP request.
        /// </summary>
        /// <param name="request">The incoming HTTP request.</param>
        /// <returns>
        /// The parsed DNS request, or <c>null</c> when the HTTP request is not a DoH query
        /// (wrong path, no matching <c>accept</c> header, missing <c>dns</c> parameter,
        /// or an unsupported method/content type).
        /// </returns>
        private static async Task<Request?> ParseDnsRequestAsync(HttpRequest request)
        {
            if (request.Path != dnsQueryPath ||
                request.Headers.TryGetValue("accept", out var accept) == false)
            {
                return default;
            }

            // A single Accept header line may list several comma-separated media
            // ranges, so every element is checked rather than the whole value.
            var acceptsDnsMessage = accept
                .SelectMany(value => (value ?? string.Empty).Split(','))
                .Any(IsDnsMessageMediaType);

            if (acceptsDnsMessage == false)
            {
                return default;
            }

            if (request.Method == HttpMethods.Get)
            {
                if (request.Query.TryGetValue("dns", out var dns) == false)
                {
                    return default;
                }

                // The "dns" parameter is base64url without padding: map it back to
                // the standard base64 alphabet and restore the '=' padding.
                var dnsRequest = dns.ToString().Replace('-', '+').Replace('_', '/');
                int mod = dnsRequest.Length % 4;
                if (mod > 0)
                {
                    dnsRequest = dnsRequest.PadRight(dnsRequest.Length - mod + 4, '=');
                }

                var message = Convert.FromBase64String(dnsRequest);
                return Request.FromArray(message);
            }

            if (request.Method == HttpMethods.Post && IsDnsMessageMediaType(request.ContentType))
            {
                using var message = new MemoryStream();
                await request.Body.CopyToAsync(message);
                return Request.FromArray(message.ToArray());
            }

            return default;
        }

        /// <summary>
        /// Tells whether a media type value designates <c>application/dns-message</c>.
        /// Parameters (for example <c>;q=0.9</c>), surrounding whitespace and letter
        /// case are ignored.
        /// </summary>
        /// <param name="value">A single media type, optionally followed by parameters.</param>
        /// <returns><c>true</c> when the value is the DNS message media type.</returns>
        private static bool IsDnsMessageMediaType(string? value)
        {
            if (string.IsNullOrEmpty(value))
            {
                return false;
            }

            var parametersIndex = value.IndexOf(';');
            var mediaType = parametersIndex < 0 ? value : value.Substring(0, parametersIndex);
            return mediaType.Trim().Equals(MEDIA_TYPE, StringComparison.OrdinalIgnoreCase);
        }
    }
}