diff --git a/FastGithub.DomainResolve/DomainResolver.cs b/FastGithub.DomainResolve/DomainResolver.cs index 7121a7d..a588ad5 100644 --- a/FastGithub.DomainResolve/DomainResolver.cs +++ b/FastGithub.DomainResolve/DomainResolver.cs @@ -1,5 +1,4 @@ -using FastGithub.Configuration; -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; using System; using System.Collections.Concurrent; using System.Collections.Generic; @@ -44,23 +43,7 @@ namespace FastGithub.DomainResolve { this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty()); } - } - - - /// - /// 解析ip - /// - /// 节点 - /// - /// - public async Task ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default) - { - await foreach (var address in this.ResolveAllAsync(endPoint, cancellationToken)) - { - return address; - } - throw new FastGithubException($"解析不到{endPoint.Host}的IP"); - } + } /// /// 解析域名 @@ -68,7 +51,7 @@ namespace FastGithub.DomainResolve /// 节点 /// /// - public async IAsyncEnumerable ResolveAllAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken) + public async IAsyncEnumerable ResolveAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken) { if (this.dnsEndPointAddress.TryGetValue(endPoint, out var addresses) && addresses.Length > 0) { diff --git a/FastGithub.DomainResolve/IDomainResolver.cs b/FastGithub.DomainResolve/IDomainResolver.cs index f965210..5f8af9c 100644 --- a/FastGithub.DomainResolve/IDomainResolver.cs +++ b/FastGithub.DomainResolve/IDomainResolver.cs @@ -9,22 +9,14 @@ namespace FastGithub.DomainResolve /// 域名解析器 /// public interface IDomainResolver - { - /// - /// 解析ip - /// - /// 节点 - /// - /// - Task ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default); - + { /// /// 解析所有ip /// /// 节点 /// /// - IAsyncEnumerable ResolveAllAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default); + IAsyncEnumerable ResolveAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default); /// /// 对所有节点进行测速 diff --git a/FastGithub.Http/HttpClientHandler.cs b/FastGithub.Http/HttpClientHandler.cs index a1001b5..f7f5ec8 100644 --- a/FastGithub.Http/HttpClientHandler.cs +++ b/FastGithub.Http/HttpClientHandler.cs @@ -186,7 +186,7 @@ namespace FastGithub.Http yield return new IPEndPoint(this.domainConfig.IPAddress, dnsEndPoint.Port); } - await foreach (var item in this.domainResolver.ResolveAllAsync(dnsEndPoint, cancellationToken)) + await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint, cancellationToken)) { yield return new IPEndPoint(item, dnsEndPoint.Port); } diff --git a/FastGithub.HttpServer/HttpProxyMiddleware.cs b/FastGithub.HttpServer/HttpProxyMiddleware.cs index 1efebcc..8283f2c 100644 --- a/FastGithub.HttpServer/HttpProxyMiddleware.cs +++ b/FastGithub.HttpServer/HttpProxyMiddleware.cs @@ -4,12 +4,16 @@ using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using System; +using System.Collections.Generic; +using System.IO; using System.IO.Pipelines; using System.Net; using System.Net.Http; using System.Net.Sockets; using System.Reflection; using System.Text; +using System.Threading; using System.Threading.Tasks; using Yarp.ReverseProxy.Forwarder; @@ -30,6 +34,7 @@ namespace FastGithub.HttpServer private readonly HttpReverseProxyMiddleware httpReverseProxy; private readonly HttpMessageInvoker defaultHttpClient; + private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d); static HttpProxyMiddleware() { @@ -83,10 +88,7 @@ namespace FastGithub.HttpServer } else if (context.Request.Method == HttpMethods.Connect) { - var endpoint = await this.GetTargetEndPointAsync(host); - using var targetSocket = new Socket(SocketType.Stream, ProtocolType.Tcp); - await targetSocket.ConnectAsync(endpoint); - + using var connection = await this.CreateConnectionAsync(host); var responseFeature = context.Features.Get(); if (responseFeature != null) { @@ -98,8 +100,7 @@ namespace FastGithub.HttpServer var transport = context.Features.Get()?.Transport; if (transport != null) { - var targetStream = new NetworkStream(targetSocket, ownsSocket: false); - await Task.WhenAny(targetStream.CopyToAsync(transport.Output), transport.Input.CopyToAsync(targetStream)); + await Task.WhenAny(connection.CopyToAsync(transport.Output), transport.Input.CopyToAsync(connection)); } } else @@ -151,40 +152,73 @@ namespace FastGithub.HttpServer return buidler.ToString(); } + /// + /// 创建连接 + /// + /// + /// + /// + private async Task CreateConnectionAsync(HostString host) + { + var innerExceptions = new List(); + await foreach (var endPoint in this.GetTargetEndPointsAsync(host)) + { + var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + try + { + using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout); + await socket.ConnectAsync(endPoint, timeoutTokenSource.Token); + return new NetworkStream(socket, ownsSocket: false); + } + catch (Exception ex) + { + socket.Dispose(); + innerExceptions.Add(ex); + } + } + throw new AggregateException($"无法连接到{host}", innerExceptions); + } + /// /// 获取目标终节点 /// /// /// - private async Task GetTargetEndPointAsync(HostString host) + private async IAsyncEnumerable GetTargetEndPointsAsync(HostString host) { var targetHost = host.Host; var targetPort = host.Port ?? HTTPS_PORT; if (IPAddress.TryParse(targetHost, out var address) == true) { - return new IPEndPoint(address, targetPort); + yield return new IPEndPoint(address, targetPort); + yield break; } // 不关心的域名,直接使用系统dns if (this.fastGithubConfig.IsMatch(targetHost) == false) { - return new DnsEndPoint(targetHost, targetPort); + yield return new DnsEndPoint(targetHost, targetPort); + yield break; } if (targetPort == HTTP_PORT) { - return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Http); + yield return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Http); + yield break; } if (targetPort == HTTPS_PORT) { - return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Https); + yield return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Https); + yield break; } - // 不使用系统dns - address = await this.domainResolver.ResolveAnyAsync(new DnsEndPoint(targetHost, targetPort)); - return new IPEndPoint(address, targetPort); + var dnsEndPoint = new DnsEndPoint(targetHost, targetPort); + await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint)) + { + yield return new IPEndPoint(item, targetPort); + } } /// diff --git a/FastGithub.HttpServer/TcpReverseProxyHandler.cs b/FastGithub.HttpServer/TcpReverseProxyHandler.cs index 91229f0..e5eb39b 100644 --- a/FastGithub.HttpServer/TcpReverseProxyHandler.cs +++ b/FastGithub.HttpServer/TcpReverseProxyHandler.cs @@ -6,6 +6,7 @@ using System.IO; using System.IO.Pipelines; using System.Net; using System.Net.Sockets; +using System.Threading; using System.Threading.Tasks; namespace FastGithub.HttpServer @@ -17,6 +18,7 @@ namespace FastGithub.HttpServer { private readonly IDomainResolver domainResolver; private readonly DnsEndPoint endPoint; + private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d); /// /// tcp反射代理处理者 @@ -36,9 +38,9 @@ namespace FastGithub.HttpServer /// public override async Task OnConnectedAsync(ConnectionContext context) { - using var targetStream = await this.CreateConnectionAsync(); - var task1 = targetStream.CopyToAsync(context.Transport.Output); - var task2 = context.Transport.Input.CopyToAsync(targetStream); + using var connection = await this.CreateConnectionAsync(); + var task1 = connection.CopyToAsync(context.Transport.Output); + var task2 = context.Transport.Input.CopyToAsync(connection); await Task.WhenAny(task1, task2); } @@ -50,12 +52,13 @@ namespace FastGithub.HttpServer private async Task CreateConnectionAsync() { var innerExceptions = new List(); - await foreach (var address in this.domainResolver.ResolveAllAsync(this.endPoint)) + await foreach (var address in this.domainResolver.ResolveAsync(this.endPoint)) { var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp); try { - await socket.ConnectAsync(address, this.endPoint.Port); + using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout); + await socket.ConnectAsync(address, this.endPoint.Port, timeoutTokenSource.Token); return new NetworkStream(socket, ownsSocket: false); } catch (Exception ex)