使用多ip尝试连接

This commit is contained in:
陈国伟 2021-11-23 17:08:16 +08:00
parent 4bbc48c9f8
commit 82efd98448
5 changed files with 62 additions and 50 deletions

View File

@ -1,5 +1,4 @@
using FastGithub.Configuration; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging;
using System; using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
@ -44,23 +43,7 @@ namespace FastGithub.DomainResolve
{ {
this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty<IPAddress>()); this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty<IPAddress>());
} }
} }
/// <summary>
/// 解析ip
/// </summary>
/// <param name="endPoint">节点</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async Task<IPAddress> ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default)
{
await foreach (var address in this.ResolveAllAsync(endPoint, cancellationToken))
{
return address;
}
throw new FastGithubException($"解析不到{endPoint.Host}的IP");
}
/// <summary> /// <summary>
/// 解析域名 /// 解析域名
@ -68,7 +51,7 @@ namespace FastGithub.DomainResolve
/// <param name="endPoint">节点</param> /// <param name="endPoint">节点</param>
/// <param name="cancellationToken"></param> /// <param name="cancellationToken"></param>
/// <returns></returns> /// <returns></returns>
public async IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken) public async IAsyncEnumerable<IPAddress> ResolveAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken)
{ {
if (this.dnsEndPointAddress.TryGetValue(endPoint, out var addresses) && addresses.Length > 0) if (this.dnsEndPointAddress.TryGetValue(endPoint, out var addresses) && addresses.Length > 0)
{ {

View File

@ -9,22 +9,14 @@ namespace FastGithub.DomainResolve
/// 域名解析器 /// 域名解析器
/// </summary> /// </summary>
public interface IDomainResolver public interface IDomainResolver
{ {
/// <summary>
/// 解析ip
/// </summary>
/// <param name="endPoint">节点</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
Task<IPAddress> ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default);
/// <summary> /// <summary>
/// 解析所有ip /// 解析所有ip
/// </summary> /// </summary>
/// <param name="endPoint">节点</param> /// <param name="endPoint">节点</param>
/// <param name="cancellationToken"></param> /// <param name="cancellationToken"></param>
/// <returns></returns> /// <returns></returns>
IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default); IAsyncEnumerable<IPAddress> ResolveAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default);
/// <summary> /// <summary>
/// 对所有节点进行测速 /// 对所有节点进行测速

View File

@ -186,7 +186,7 @@ namespace FastGithub.Http
yield return new IPEndPoint(this.domainConfig.IPAddress, dnsEndPoint.Port); yield return new IPEndPoint(this.domainConfig.IPAddress, dnsEndPoint.Port);
} }
await foreach (var item in this.domainResolver.ResolveAllAsync(dnsEndPoint, cancellationToken)) await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint, cancellationToken))
{ {
yield return new IPEndPoint(item, dnsEndPoint.Port); yield return new IPEndPoint(item, dnsEndPoint.Port);
} }

View File

@ -4,12 +4,16 @@ using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines; using System.IO.Pipelines;
using System.Net; using System.Net;
using System.Net.Http; using System.Net.Http;
using System.Net.Sockets; using System.Net.Sockets;
using System.Reflection; using System.Reflection;
using System.Text; using System.Text;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Yarp.ReverseProxy.Forwarder; using Yarp.ReverseProxy.Forwarder;
@ -30,6 +34,7 @@ namespace FastGithub.HttpServer
private readonly HttpReverseProxyMiddleware httpReverseProxy; private readonly HttpReverseProxyMiddleware httpReverseProxy;
private readonly HttpMessageInvoker defaultHttpClient; private readonly HttpMessageInvoker defaultHttpClient;
private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);
static HttpProxyMiddleware() static HttpProxyMiddleware()
{ {
@ -83,10 +88,7 @@ namespace FastGithub.HttpServer
} }
else if (context.Request.Method == HttpMethods.Connect) else if (context.Request.Method == HttpMethods.Connect)
{ {
var endpoint = await this.GetTargetEndPointAsync(host); using var connection = await this.CreateConnectionAsync(host);
using var targetSocket = new Socket(SocketType.Stream, ProtocolType.Tcp);
await targetSocket.ConnectAsync(endpoint);
var responseFeature = context.Features.Get<IHttpResponseFeature>(); var responseFeature = context.Features.Get<IHttpResponseFeature>();
if (responseFeature != null) if (responseFeature != null)
{ {
@ -98,8 +100,7 @@ namespace FastGithub.HttpServer
var transport = context.Features.Get<IConnectionTransportFeature>()?.Transport; var transport = context.Features.Get<IConnectionTransportFeature>()?.Transport;
if (transport != null) if (transport != null)
{ {
var targetStream = new NetworkStream(targetSocket, ownsSocket: false); await Task.WhenAny(connection.CopyToAsync(transport.Output), transport.Input.CopyToAsync(connection));
await Task.WhenAny(targetStream.CopyToAsync(transport.Output), transport.Input.CopyToAsync(targetStream));
} }
} }
else else
@ -151,40 +152,73 @@ namespace FastGithub.HttpServer
return buidler.ToString(); return buidler.ToString();
} }
/// <summary>
/// 创建连接
/// </summary>
/// <param name="host"></param>
/// <returns></returns>
/// <exception cref="AggregateException"></exception>
private async Task<Stream> CreateConnectionAsync(HostString host)
{
var innerExceptions = new List<Exception>();
await foreach (var endPoint in this.GetTargetEndPointsAsync(host))
{
var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
try
{
using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
await socket.ConnectAsync(endPoint, timeoutTokenSource.Token);
return new NetworkStream(socket, ownsSocket: false);
}
catch (Exception ex)
{
socket.Dispose();
innerExceptions.Add(ex);
}
}
throw new AggregateException($"无法连接到{host}", innerExceptions);
}
/// <summary> /// <summary>
/// 获取目标终节点 /// 获取目标终节点
/// </summary> /// </summary>
/// <param name="host"></param> /// <param name="host"></param>
/// <returns></returns> /// <returns></returns>
private async Task<EndPoint> GetTargetEndPointAsync(HostString host) private async IAsyncEnumerable<EndPoint> GetTargetEndPointsAsync(HostString host)
{ {
var targetHost = host.Host; var targetHost = host.Host;
var targetPort = host.Port ?? HTTPS_PORT; var targetPort = host.Port ?? HTTPS_PORT;
if (IPAddress.TryParse(targetHost, out var address) == true) if (IPAddress.TryParse(targetHost, out var address) == true)
{ {
return new IPEndPoint(address, targetPort); yield return new IPEndPoint(address, targetPort);
yield break;
} }
// 不关心的域名直接使用系统dns // 不关心的域名直接使用系统dns
if (this.fastGithubConfig.IsMatch(targetHost) == false) if (this.fastGithubConfig.IsMatch(targetHost) == false)
{ {
return new DnsEndPoint(targetHost, targetPort); yield return new DnsEndPoint(targetHost, targetPort);
yield break;
} }
if (targetPort == HTTP_PORT) if (targetPort == HTTP_PORT)
{ {
return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Http); yield return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Http);
yield break;
} }
if (targetPort == HTTPS_PORT) if (targetPort == HTTPS_PORT)
{ {
return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Https); yield return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Https);
yield break;
} }
// 不使用系统dns var dnsEndPoint = new DnsEndPoint(targetHost, targetPort);
address = await this.domainResolver.ResolveAnyAsync(new DnsEndPoint(targetHost, targetPort)); await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint))
return new IPEndPoint(address, targetPort); {
yield return new IPEndPoint(item, targetPort);
}
} }
/// <summary> /// <summary>

View File

@ -6,6 +6,7 @@ using System.IO;
using System.IO.Pipelines; using System.IO.Pipelines;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace FastGithub.HttpServer namespace FastGithub.HttpServer
@ -17,6 +18,7 @@ namespace FastGithub.HttpServer
{ {
private readonly IDomainResolver domainResolver; private readonly IDomainResolver domainResolver;
private readonly DnsEndPoint endPoint; private readonly DnsEndPoint endPoint;
private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);
/// <summary> /// <summary>
/// tcp反射代理处理者 /// tcp反射代理处理者
@ -36,9 +38,9 @@ namespace FastGithub.HttpServer
/// <returns></returns> /// <returns></returns>
public override async Task OnConnectedAsync(ConnectionContext context) public override async Task OnConnectedAsync(ConnectionContext context)
{ {
using var targetStream = await this.CreateConnectionAsync(); using var connection = await this.CreateConnectionAsync();
var task1 = targetStream.CopyToAsync(context.Transport.Output); var task1 = connection.CopyToAsync(context.Transport.Output);
var task2 = context.Transport.Input.CopyToAsync(targetStream); var task2 = context.Transport.Input.CopyToAsync(connection);
await Task.WhenAny(task1, task2); await Task.WhenAny(task1, task2);
} }
@ -50,12 +52,13 @@ namespace FastGithub.HttpServer
private async Task<Stream> CreateConnectionAsync() private async Task<Stream> CreateConnectionAsync()
{ {
var innerExceptions = new List<Exception>(); var innerExceptions = new List<Exception>();
await foreach (var address in this.domainResolver.ResolveAllAsync(this.endPoint)) await foreach (var address in this.domainResolver.ResolveAsync(this.endPoint))
{ {
var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp); var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try try
{ {
await socket.ConnectAsync(address, this.endPoint.Port); using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
await socket.ConnectAsync(address, this.endPoint.Port, timeoutTokenSource.Token);
return new NetworkStream(socket, ownsSocket: false); return new NetworkStream(socket, ownsSocket: false);
} }
catch (Exception ex) catch (Exception ex)