130 lines
4.2 KiB
C#
130 lines
4.2 KiB
C#
using DNS.Protocol;
|
||
using Microsoft.AspNetCore.Http;
|
||
using Microsoft.Extensions.Logging;
|
||
using System;
|
||
using System.IO;
|
||
using System.Linq;
|
||
using System.Net;
|
||
using System.Threading.Tasks;
|
||
|
||
namespace FastGithub.Dns
{
    /// <summary>
    /// DNS-over-HTTPS (DoH) middleware.
    /// Requests addressed to <c>/dns-query</c> that advertise the
    /// <c>application/dns-message</c> media type are decoded into a DNS request,
    /// resolved through the <see cref="RequestResolver"/> and answered in wire format.
    /// Every other request is handed to the next middleware in the pipeline.
    /// </summary>
    sealed class DnsOverHttpsMiddleware
    {
        private static readonly PathString dnsQueryPath = "/dns-query";
        private const string MEDIA_TYPE = "application/dns-message";

        /// <summary>
        /// Upper bound for a POSTed DNS message, in bytes. A DNS message length is
        /// carried in a 16-bit field, so anything larger cannot be a valid message.
        /// </summary>
        private const int MAX_MESSAGE_SIZE = 65535;

        private readonly RequestResolver requestResolver;
        private readonly ILogger<DnsOverHttpsMiddleware> logger;

        /// <summary>
        /// Creates the DoH middleware.
        /// </summary>
        /// <param name="requestResolver">Resolver used to answer decoded DNS requests.</param>
        /// <param name="logger">Logger used to report parsing and resolution failures.</param>
        public DnsOverHttpsMiddleware(
            RequestResolver requestResolver,
            ILogger<DnsOverHttpsMiddleware> logger)
        {
            this.requestResolver = requestResolver;
            this.logger = logger;
        }

        /// <summary>
        /// Handles one HTTP request.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        /// <param name="next">The next middleware, invoked when the request is not a DoH query.</param>
        /// <returns>A task that completes when the request has been handled.</returns>
        public async Task InvokeAsync(HttpContext context, RequestDelegate next)
        {
            Request? request;
            try
            {
                request = await ParseDnsRequestAsync(context.Request);
            }
            catch (Exception ex)
            {
                // Malformed base64, an undecodable DNS payload, an oversized or aborted body:
                // all are reported to the client as a bad request. The cause is kept at
                // debug level so it can be diagnosed without flooding normal logs.
                this.logger.LogDebug(ex, "解析DoH请求异常:{Message}", ex.Message);
                context.Response.StatusCode = StatusCodes.Status400BadRequest;
                return;
            }

            if (request is null)
            {
                // Not a DoH query: let the rest of the pipeline deal with it.
                await next(context);
                return;
            }

            var response = await this.ResolveAsync(context, request);
            context.Response.ContentType = MEDIA_TYPE;
            await context.Response.BodyWriter.WriteAsync(response.ToArray());
        }

        /// <summary>
        /// Resolves a DNS request on behalf of the connected client.
        /// </summary>
        /// <param name="context">The current HTTP context; supplies the client's remote end point.</param>
        /// <param name="request">The decoded DNS request.</param>
        /// <returns>
        /// The resolver's response, or the response built by <c>Response.FromRequest</c>
        /// when resolution throws.
        /// </returns>
        private async Task<IResponse> ResolveAsync(HttpContext context, Request request)
        {
            try
            {
                // RemoteIpAddress may be null; fall back to loopback in that case.
                var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
                var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
                var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
                return await this.requestResolver.Resolve(remoteEndPointRequest);
            }
            catch (Exception ex)
            {
                // Constant message template plus the exception object: the rendered text is
                // unchanged, and the stack trace is no longer lost.
                this.logger.LogWarning(ex, "处理DNS异常:{Message}", ex.Message);
                return Response.FromRequest(request);
            }
        }

        /// <summary>
        /// Extracts the DNS request carried by an HTTP request.
        /// GET requests carry it base64url-encoded in the <c>dns</c> query parameter;
        /// POST requests carry it as an <c>application/dns-message</c> body.
        /// </summary>
        /// <param name="request">The incoming HTTP request.</param>
        /// <returns>
        /// The decoded DNS request, or <c>null</c> when the HTTP request is not a DoH query
        /// (wrong path, media type not accepted, unsupported method or content type,
        /// or missing <c>dns</c> parameter).
        /// </returns>
        /// <exception cref="FormatException">The <c>dns</c> query parameter is not valid base64url.</exception>
        /// <exception cref="InvalidDataException">The POST body exceeds the maximum DNS message size.</exception>
        private static async Task<Request?> ParseDnsRequestAsync(HttpRequest request)
        {
            if (request.Path != dnsQueryPath || AcceptsDnsMessage(request) == false)
            {
                return default;
            }

            if (HttpMethods.IsGet(request.Method))
            {
                if (request.Query.TryGetValue("dns", out var dns) == false)
                {
                    return default;
                }

                return Request.FromArray(DecodeBase64Url(dns.ToString()));
            }

            if (HttpMethods.IsPost(request.Method) && IsDnsMessageContentType(request.ContentType))
            {
                var message = await ReadBodyAsync(request);
                return Request.FromArray(message);
            }

            return default;
        }

        /// <summary>
        /// Tells whether the Accept header mentions the DNS message media type.
        /// The comparison is case-insensitive and tolerates comma-separated lists and
        /// parameters (for example <c>application/dns-message, */*;q=0.8</c>).
        /// </summary>
        /// <param name="request">The incoming HTTP request.</param>
        /// <returns><c>true</c> when any Accept value contains the DNS message media type.</returns>
        private static bool AcceptsDnsMessage(HttpRequest request)
        {
            if (request.Headers.TryGetValue("accept", out var accept) == false)
            {
                return false;
            }

            return accept.Any(value =>
                value != null &&
                value.Contains(MEDIA_TYPE, StringComparison.OrdinalIgnoreCase));
        }

        /// <summary>
        /// Tells whether a Content-Type header value denotes the DNS message media type,
        /// ignoring letter case and any parameters after the first <c>;</c>.
        /// </summary>
        /// <param name="contentType">The raw Content-Type header value; may be null.</param>
        /// <returns><c>true</c> when the media type is <c>application/dns-message</c>.</returns>
        private static bool IsDnsMessageContentType(string? contentType)
        {
            if (string.IsNullOrEmpty(contentType))
            {
                return false;
            }

            var separator = contentType.IndexOf(';');
            var mediaType = separator < 0 ? contentType : contentType.Substring(0, separator);
            return mediaType.Trim().Equals(MEDIA_TYPE, StringComparison.OrdinalIgnoreCase);
        }

        /// <summary>
        /// Decodes a base64url string (URL-safe alphabet, padding optional) into bytes.
        /// </summary>
        /// <param name="value">The base64url text.</param>
        /// <returns>The decoded bytes.</returns>
        /// <exception cref="FormatException">The text is not valid base64url.</exception>
        private static byte[] DecodeBase64Url(string value)
        {
            // Map the URL-safe alphabet back to the standard one.
            var base64 = value.Replace('-', '+').Replace('_', '/');

            // Restore the '=' padding that base64url omits.
            var remainder = base64.Length % 4;
            if (remainder > 0)
            {
                base64 = base64.PadRight(base64.Length - remainder + 4, '=');
            }

            return Convert.FromBase64String(base64);
        }

        /// <summary>
        /// Reads the request body into memory, refusing bodies larger than
        /// <see cref="MAX_MESSAGE_SIZE"/> so a client cannot force unbounded buffering.
        /// </summary>
        /// <param name="request">The incoming HTTP request.</param>
        /// <returns>The body bytes.</returns>
        /// <exception cref="InvalidDataException">The body exceeds the maximum DNS message size.</exception>
        private static async Task<byte[]> ReadBodyAsync(HttpRequest request)
        {
            // Reject early when the client declares an oversized body.
            // A missing Content-Length compares as false and is enforced while reading.
            if (request.ContentLength > MAX_MESSAGE_SIZE)
            {
                throw new InvalidDataException("The DNS message exceeds the maximum allowed size.");
            }

            // Stop reading as soon as the client disconnects.
            var cancellationToken = request.HttpContext.RequestAborted;
            var buffer = new byte[4096];
            using var message = new MemoryStream();

            int read;
            while ((read = await request.Body.ReadAsync(buffer, 0, buffer.Length, cancellationToken)) > 0)
            {
                if (message.Length + read > MAX_MESSAGE_SIZE)
                {
                    throw new InvalidDataException("The DNS message exceeds the maximum allowed size.");
                }

                message.Write(buffer, 0, read);
            }

            return message.ToArray();
        }
    }
}
|