diff --git a/FastGithub.Dns/DnsOverHttpsMiddleware.cs b/FastGithub.Dns/DnsOverHttpsMiddleware.cs index e7d9b35..de63218 100644 --- a/FastGithub.Dns/DnsOverHttpsMiddleware.cs +++ b/FastGithub.Dns/DnsOverHttpsMiddleware.cs @@ -1,5 +1,6 @@ using DNS.Protocol; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; using System; using System.IO; using System.Linq; @@ -16,14 +17,19 @@ namespace FastGithub.Dns private static readonly PathString dnsQueryPath = "/dns-query"; private const string MEDIA_TYPE = "application/dns-message"; private readonly RequestResolver requestResolver; + private readonly ILogger logger; /// /// DoH中间件 /// /// - public DnsOverHttpsMiddleware(RequestResolver requestResolver) + /// + public DnsOverHttpsMiddleware( + RequestResolver requestResolver, + ILogger logger) { this.requestResolver = requestResolver; + this.logger = logger; } /// @@ -34,27 +40,47 @@ namespace FastGithub.Dns /// public async Task InvokeAsync(HttpContext context, RequestDelegate next) { + Request? request; try { - var request = await ParseDnsRequestAsync(context.Request); - if (request == null) - { - await next(context); - } - else - { - var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback; - var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort); - var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint); - var response = await this.requestResolver.Resolve(remoteEndPointRequest); - - context.Response.ContentType = MEDIA_TYPE; - await context.Response.BodyWriter.WriteAsync(response.ToArray()); - } + request = await ParseDnsRequestAsync(context.Request); } catch (Exception) + { + context.Response.StatusCode = StatusCodes.Status400BadRequest; + return; + } + + if (request == null) { await next(context); + return; + } + + var response = await this.ResolveAsync(context, request); + context.Response.ContentType = MEDIA_TYPE; + await context.Response.BodyWriter.WriteAsync(response.ToArray()); + } + + /// + /// 解析dns域名 + /// + /// + /// + /// + private async Task ResolveAsync(HttpContext context, Request request) + { + try + { + var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback; + var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort); + var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint); + return await this.requestResolver.Resolve(remoteEndPointRequest); + } + catch (Exception ex) + { + this.logger.LogWarning($"处理DNS异常:{ex.Message}"); + return Response.FromRequest(request); } } diff --git a/FastGithub.DomainResolve/DomainResolver.cs b/FastGithub.DomainResolve/DomainResolver.cs index 2601817..1e80aea 100644 --- a/FastGithub.DomainResolve/DomainResolver.cs +++ b/FastGithub.DomainResolve/DomainResolver.cs @@ -61,7 +61,7 @@ namespace FastGithub.DomainResolve var semaphore = this.semaphoreSlims.GetOrAdd(domain, _ => new SemaphoreSlim(1, 1)); try { - await semaphore.WaitAsync(cancellationToken); + await semaphore.WaitAsync(); return await this.LookupAsync(domain, cancellationToken); } finally