优化DoH中间件

This commit is contained in:
陈国伟 2021-08-27 09:01:56 +08:00
parent d4f9172574
commit 97727944ff
2 changed files with 43 additions and 17 deletions

View File

@ -1,5 +1,6 @@
using DNS.Protocol; using DNS.Protocol;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using System; using System;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
@ -16,14 +17,19 @@ namespace FastGithub.Dns
private static readonly PathString dnsQueryPath = "/dns-query"; private static readonly PathString dnsQueryPath = "/dns-query";
private const string MEDIA_TYPE = "application/dns-message"; private const string MEDIA_TYPE = "application/dns-message";
private readonly RequestResolver requestResolver; private readonly RequestResolver requestResolver;
private readonly ILogger<DnsOverHttpsMiddleware> logger;
/// <summary> /// <summary>
/// DoH中间件 /// DoH中间件
/// </summary> /// </summary>
/// <param name="requestResolver"></param> /// <param name="requestResolver"></param>
public DnsOverHttpsMiddleware(RequestResolver requestResolver) /// <param name="logger"></param>
public DnsOverHttpsMiddleware(
RequestResolver requestResolver,
ILogger<DnsOverHttpsMiddleware> logger)
{ {
this.requestResolver = requestResolver; this.requestResolver = requestResolver;
this.logger = logger;
} }
/// <summary> /// <summary>
@ -34,27 +40,47 @@ namespace FastGithub.Dns
/// <returns></returns> /// <returns></returns>
public async Task InvokeAsync(HttpContext context, RequestDelegate next) public async Task InvokeAsync(HttpContext context, RequestDelegate next)
{ {
Request? request;
try try
{ {
var request = await ParseDnsRequestAsync(context.Request); request = await ParseDnsRequestAsync(context.Request);
if (request == null)
{
await next(context);
}
else
{
var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
var response = await this.requestResolver.Resolve(remoteEndPointRequest);
context.Response.ContentType = MEDIA_TYPE;
await context.Response.BodyWriter.WriteAsync(response.ToArray());
}
} }
catch (Exception) catch (Exception)
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
return;
}
if (request == null)
{ {
await next(context); await next(context);
return;
}
var response = await this.ResolveAsync(context, request);
context.Response.ContentType = MEDIA_TYPE;
await context.Response.BodyWriter.WriteAsync(response.ToArray());
}
/// <summary>
/// 解析dns域名
/// </summary>
/// <param name="context"></param>
/// <param name="request"></param>
/// <returns></returns>
private async Task<IResponse> ResolveAsync(HttpContext context, Request request)
{
try
{
var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
return await this.requestResolver.Resolve(remoteEndPointRequest);
}
catch (Exception ex)
{
this.logger.LogWarning($"处理DNS异常{ex.Message}");
return Response.FromRequest(request);
} }
} }

View File

@ -61,7 +61,7 @@ namespace FastGithub.DomainResolve
var semaphore = this.semaphoreSlims.GetOrAdd(domain, _ => new SemaphoreSlim(1, 1)); var semaphore = this.semaphoreSlims.GetOrAdd(domain, _ => new SemaphoreSlim(1, 1));
try try
{ {
await semaphore.WaitAsync(cancellationToken); await semaphore.WaitAsync();
return await this.LookupAsync(domain, cancellationToken); return await this.LookupAsync(domain, cancellationToken);
} }
finally finally