FastGithub/FastGithub.Dns/DnsOverHttpsMiddleware.cs
2021-08-27 09:01:56 +08:00

130 lines
4.2 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using DNS.Protocol;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using System;
using System.IO;
using System.Linq;
using System.Net;
using System.Threading.Tasks;
namespace FastGithub.Dns
{
/// <summary>
/// DNS-over-HTTPS (DoH) middleware.
/// Handles GET and POST requests to <c>/dns-query</c> whose Accept header lists
/// <c>application/dns-message</c>; every other request is passed to the next middleware.
/// </summary>
sealed class DnsOverHttpsMiddleware
{
    private static readonly PathString dnsQueryPath = "/dns-query";
    private const string MEDIA_TYPE = "application/dns-message";
    private readonly RequestResolver requestResolver;
    private readonly ILogger<DnsOverHttpsMiddleware> logger;

    /// <summary>
    /// Creates the DoH middleware.
    /// </summary>
    /// <param name="requestResolver">Resolver used to answer the decoded DNS requests.</param>
    /// <param name="logger">Logger for resolution failures and malformed DoH requests.</param>
    public DnsOverHttpsMiddleware(
        RequestResolver requestResolver,
        ILogger<DnsOverHttpsMiddleware> logger)
    {
        this.requestResolver = requestResolver;
        this.logger = logger;
    }

    /// <summary>
    /// Processes one HTTP request.
    /// Non-DoH requests go to <paramref name="next"/>.
    /// Malformed DoH payloads get a 400 response.
    /// Valid DoH requests are resolved, and the wire-format DNS response is written
    /// with content type <c>application/dns-message</c>.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="next">The next middleware in the pipeline.</param>
    /// <returns>A task that completes when the response has been produced.</returns>
    public async Task InvokeAsync(HttpContext context, RequestDelegate next)
    {
        Request? request;
        try
        {
            request = await ParseDnsRequestAsync(context.Request);
        }
        catch (Exception ex)
        {
            // The request targeted the DoH endpoint but its DNS payload could not be decoded.
            // Logged at debug level only, because any client can trigger this.
            this.logger.LogDebug("Invalid DoH request: {Message}", ex.Message);
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
            return;
        }

        if (request == null)
        {
            await next(context);
            return;
        }

        var response = await this.ResolveAsync(context, request);
        var payload = response.ToArray();
        context.Response.ContentType = MEDIA_TYPE;
        context.Response.ContentLength = payload.Length;
        await context.Response.BodyWriter.WriteAsync(payload, context.RequestAborted);
    }

    /// <summary>
    /// Resolves a DNS request, tagging it with the caller's remote endpoint.
    /// </summary>
    /// <param name="context">The current HTTP context; supplies the remote address and port.</param>
    /// <param name="request">The decoded DNS request.</param>
    /// <returns>
    /// The resolver's response. If resolution throws, the result is a response that echoes
    /// the request and carries the SERVFAIL response code.
    /// </returns>
    private async Task<IResponse> ResolveAsync(HttpContext context, Request request)
    {
        try
        {
            // Fall back to loopback when the connection has no remote address.
            var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
            var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
            var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
            return await this.requestResolver.Resolve(remoteEndPointRequest);
        }
        catch (Exception ex)
        {
            this.logger.LogWarning("处理DNS异常{Message}", ex.Message);

            // An empty NOERROR answer would tell the client the name has no records.
            // SERVFAIL tells it the server could not answer, so it can retry or fail over.
            var failure = Response.FromRequest(request);
            failure.ResponseCode = ResponseCode.ServerFailure;
            return failure;
        }
    }

    /// <summary>
    /// Extracts a DNS request from an HTTP request when it is a DoH request.
    /// Supports GET with a base64url <c>dns</c> query parameter, and POST with an
    /// <c>application/dns-message</c> body.
    /// </summary>
    /// <param name="request">The HTTP request.</param>
    /// <returns>
    /// The decoded DNS request, or <c>null</c> when this is not a DoH request
    /// (wrong path, no matching Accept header, unsupported method or content type,
    /// or a GET without a <c>dns</c> parameter).
    /// </returns>
    /// <exception cref="FormatException">The <c>dns</c> parameter is not valid base64url.</exception>
    private static async Task<Request?> ParseDnsRequestAsync(HttpRequest request)
    {
        if (request.Path != dnsQueryPath ||
            request.Headers.TryGetValue("accept", out var accept) == false ||
            accept.Any(ContainsDnsMessageMediaType) == false)
        {
            return default;
        }

        if (HttpMethods.IsGet(request.Method))
        {
            if (request.Query.TryGetValue("dns", out var dns) == false)
            {
                return default;
            }
            var message = DecodeBase64Url(dns.ToString());
            return Request.FromArray(message);
        }

        if (HttpMethods.IsPost(request.Method) && IsDnsMessageMediaType(request.ContentType))
        {
            using var message = new MemoryStream();
            await request.Body.CopyToAsync(message, request.HttpContext.RequestAborted);
            return Request.FromArray(message.ToArray());
        }

        return default;
    }

    /// <summary>
    /// Returns true when a comma-separated header value (such as one Accept header line)
    /// lists <c>application/dns-message</c>.
    /// </summary>
    /// <param name="headerValue">A raw header value; may be null.</param>
    private static bool ContainsDnsMessageMediaType(string? headerValue)
    {
        return headerValue != null && headerValue.Split(',').Any(IsDnsMessageMediaType);
    }

    /// <summary>
    /// Returns true when <paramref name="mediaType"/> is <c>application/dns-message</c>.
    /// The comparison is case-insensitive, and any <c>;</c> parameters are ignored.
    /// </summary>
    /// <param name="mediaType">A single media type, for example a Content-Type value; may be null.</param>
    private static bool IsDnsMessageMediaType(string? mediaType)
    {
        if (string.IsNullOrEmpty(mediaType))
        {
            return false;
        }
        var separator = mediaType.IndexOf(';');
        var type = separator < 0 ? mediaType : mediaType.Substring(0, separator);
        return type.Trim().Equals(MEDIA_TYPE, StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>
    /// Decodes unpadded base64url text (as used by the DoH GET <c>dns</c> parameter) into bytes.
    /// </summary>
    /// <param name="value">The base64url text.</param>
    /// <returns>The decoded bytes.</returns>
    /// <exception cref="FormatException">The text is not valid base64 after normalization.</exception>
    private static byte[] DecodeBase64Url(string value)
    {
        // Map the URL-safe alphabet back to the standard one and restore the '=' padding.
        var base64 = value.Replace('-', '+').Replace('_', '/');
        var mod = base64.Length % 4;
        if (mod > 0)
        {
            base64 = base64.PadRight(base64.Length - mod + 4, '=');
        }
        return Convert.FromBase64String(base64);
    }
}
}