优化DoH中间件

This commit is contained in:
陈国伟 2021-08-27 09:01:56 +08:00
parent d4f9172574
commit 97727944ff
2 changed files with 43 additions and 17 deletions

View File

@ -1,5 +1,6 @@
using DNS.Protocol;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using System;
using System.IO;
using System.Linq;
@ -16,14 +17,19 @@ namespace FastGithub.Dns
private static readonly PathString dnsQueryPath = "/dns-query";
private const string MEDIA_TYPE = "application/dns-message";
private readonly RequestResolver requestResolver;
private readonly ILogger<DnsOverHttpsMiddleware> logger;
/// <summary>
/// DoH中间件
/// </summary>
/// <param name="requestResolver"></param>
public DnsOverHttpsMiddleware(RequestResolver requestResolver)
/// <param name="logger"></param>
public DnsOverHttpsMiddleware(
RequestResolver requestResolver,
ILogger<DnsOverHttpsMiddleware> logger)
{
this.requestResolver = requestResolver;
this.logger = logger;
}
/// <summary>
@ -34,27 +40,47 @@ namespace FastGithub.Dns
/// <returns></returns>
public async Task InvokeAsync(HttpContext context, RequestDelegate next)
{
Request? request;
try
{
var request = await ParseDnsRequestAsync(context.Request);
if (request == null)
{
await next(context);
}
else
{
var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
var response = await this.requestResolver.Resolve(remoteEndPointRequest);
context.Response.ContentType = MEDIA_TYPE;
await context.Response.BodyWriter.WriteAsync(response.ToArray());
}
request = await ParseDnsRequestAsync(context.Request);
}
catch (Exception)
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
return;
}
if (request == null)
{
await next(context);
return;
}
var response = await this.ResolveAsync(context, request);
context.Response.ContentType = MEDIA_TYPE;
await context.Response.BodyWriter.WriteAsync(response.ToArray());
}
/// <summary>
/// 解析dns域名
/// </summary>
/// <param name="context"></param>
/// <param name="request"></param>
/// <returns></returns>
private async Task<IResponse> ResolveAsync(HttpContext context, Request request)
{
try
{
var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
return await this.requestResolver.Resolve(remoteEndPointRequest);
}
catch (Exception ex)
{
this.logger.LogWarning($"处理DNS异常{ex.Message}");
return Response.FromRequest(request);
}
}

View File

@ -61,7 +61,7 @@ namespace FastGithub.DomainResolve
var semaphore = this.semaphoreSlims.GetOrAdd(domain, _ => new SemaphoreSlim(1, 1));
try
{
await semaphore.WaitAsync(cancellationToken);
await semaphore.WaitAsync();
return await this.LookupAsync(domain, cancellationToken);
}
finally