基于tcp连接测速

This commit is contained in:
陈国伟 2021-09-30 16:09:36 +08:00
parent 03254aa0d9
commit 255c44af9c
10 changed files with 199 additions and 248 deletions

View File

@ -30,7 +30,6 @@ namespace FastGithub.DomainResolve
private readonly FastGithubConfig fastGithubConfig; private readonly FastGithubConfig fastGithubConfig;
private readonly ILogger<DnsClient> logger; private readonly ILogger<DnsClient> logger;
private readonly ConcurrentDictionary<string, IPAddressCollection> domainIPAddressCollection = new();
private readonly ConcurrentDictionary<string, SemaphoreSlim> semaphoreSlims = new(); private readonly ConcurrentDictionary<string, SemaphoreSlim> semaphoreSlims = new();
private readonly IMemoryCache dnsCache = new MemoryCache(Options.Create(new MemoryCacheOptions())); private readonly IMemoryCache dnsCache = new MemoryCache(Options.Create(new MemoryCacheOptions()));
private readonly TimeSpan defaultEmptyTtl = TimeSpan.FromSeconds(30d); private readonly TimeSpan defaultEmptyTtl = TimeSpan.FromSeconds(30d);
@ -54,15 +53,6 @@ namespace FastGithub.DomainResolve
this.logger = logger; this.logger = logger;
} }
/// <summary>
/// 预加载
/// </summary>
/// <param name="domain">域名</param>
public void Prefetch(string domain)
{
this.domainIPAddressCollection.TryAdd(domain, new IPAddressCollection());
}
/// <summary> /// <summary>
/// 解析域名 /// 解析域名
/// </summary> /// </summary>
@ -70,51 +60,6 @@ namespace FastGithub.DomainResolve
/// <param name="cancellationToken"></param> /// <param name="cancellationToken"></param>
/// <returns></returns> /// <returns></returns>
public async IAsyncEnumerable<IPAddress> ResolveAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken) public async IAsyncEnumerable<IPAddress> ResolveAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
{
if (this.domainIPAddressCollection.TryGetValue(domain, out var collection) && collection.Count > 0)
{
foreach (var address in collection.ToArray())
{
yield return address;
}
}
else
{
this.domainIPAddressCollection.TryAdd(domain, new IPAddressCollection());
await foreach (var adddress in this.ResolveCoreAsync(domain, cancellationToken))
{
yield return adddress;
}
}
}
/// <summary>
/// 对所有域名所有IP进行ping测试
/// </summary>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async Task PingAllDomainsAsync(CancellationToken cancellationToken)
{
foreach (var keyValue in this.domainIPAddressCollection)
{
var domain = keyValue.Key;
var collection = keyValue.Value;
await foreach (var address in this.ResolveCoreAsync(domain, cancellationToken))
{
collection.Add(address);
}
await collection.PingAllAsync();
}
}
/// <summary>
/// 解析域名
/// </summary>
/// <param name="domain">域名</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
private async IAsyncEnumerable<IPAddress> ResolveCoreAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
{ {
var hashSet = new HashSet<IPAddress>(); var hashSet = new HashSet<IPAddress>();
foreach (var dns in this.GetDnsServers()) foreach (var dns in this.GetDnsServers())

View File

@ -11,20 +11,20 @@ namespace FastGithub.DomainResolve
sealed class DomainResolveHostedService : BackgroundService sealed class DomainResolveHostedService : BackgroundService
{ {
private readonly DnscryptProxy dnscryptProxy; private readonly DnscryptProxy dnscryptProxy;
private readonly DnsClient dnsClient; private readonly IDomainResolver domainResolver;
private readonly TimeSpan pingPeriodTimeSpan = TimeSpan.FromSeconds(10d); private readonly TimeSpan testPeriodTimeSpan = TimeSpan.FromSeconds (1d);
/// <summary> /// <summary>
/// 域名解析后台服务 /// 域名解析后台服务
/// </summary> /// </summary>
/// <param name="dnscryptProxy"></param> /// <param name="dnscryptProxy"></param>
/// <param name="dnsClient"></param> /// <param name="domainResolver"></param>
public DomainResolveHostedService( public DomainResolveHostedService(
DnscryptProxy dnscryptProxy, DnscryptProxy dnscryptProxy,
DnsClient dnsClient) IDomainResolver domainResolver)
{ {
this.dnscryptProxy = dnscryptProxy; this.dnscryptProxy = dnscryptProxy;
this.dnsClient = dnsClient; this.domainResolver = domainResolver;
} }
/// <summary> /// <summary>
@ -37,8 +37,8 @@ namespace FastGithub.DomainResolve
await this.dnscryptProxy.StartAsync(stoppingToken); await this.dnscryptProxy.StartAsync(stoppingToken);
while (stoppingToken.IsCancellationRequested == false) while (stoppingToken.IsCancellationRequested == false)
{ {
await this.dnsClient.PingAllDomainsAsync(stoppingToken); await this.domainResolver.TestAllEndPointsAsync(stoppingToken);
await Task.Delay(this.pingPeriodTimeSpan, stoppingToken); await Task.Delay(this.testPeriodTimeSpan, stoppingToken);
} }
} }

View File

@ -1,6 +1,12 @@
using FastGithub.Configuration; using FastGithub.Configuration;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net; using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -12,6 +18,7 @@ namespace FastGithub.DomainResolve
sealed class DomainResolver : IDomainResolver sealed class DomainResolver : IDomainResolver
{ {
private readonly DnsClient dnsClient; private readonly DnsClient dnsClient;
private readonly ConcurrentDictionary<DnsEndPoint, IPAddressTestResult> dnsEndPointAddressTestResult = new();
/// <summary> /// <summary>
/// 域名解析器 /// 域名解析器
@ -23,38 +30,117 @@ namespace FastGithub.DomainResolve
} }
/// <summary> /// <summary>
/// 加载 /// 加载
/// </summary> /// </summary>
/// <param name="domain">域名</param> /// <param name="domain">域名</param>
public void Prefetch(string domain) public void Prefetch(string domain)
{ {
this.dnsClient.Prefetch(domain); var endPoint = new DnsEndPoint(domain, 443);
this.dnsEndPointAddressTestResult.TryAdd(endPoint, IPAddressTestResult.Empty);
}
/// <summary>
/// 对所有节点进行测速
/// </summary>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async Task TestAllEndPointsAsync(CancellationToken cancellationToken)
{
foreach (var keyValue in this.dnsEndPointAddressTestResult)
{
if (keyValue.Value.IsEmpty || keyValue.Value.IsExpired)
{
var dnsEndPoint = keyValue.Key;
var addresses = new List<IPAddress>();
await foreach (var adddress in this.dnsClient.ResolveAsync(dnsEndPoint.Host, cancellationToken))
{
addresses.Add(adddress);
}
var addressTestResult = IPAddressTestResult.Empty;
if (addresses.Count == 1)
{
var addressElapseds = new[] { new IPAddressElapsed(addresses[0], TimeSpan.Zero) };
addressTestResult = new IPAddressTestResult(addressElapseds);
}
else if (addresses.Count > 1)
{
var tasks = addresses.Select(item => GetIPAddressElapsedAsync(item, dnsEndPoint.Port, cancellationToken));
var addressElapseds = await Task.WhenAll(tasks);
addressTestResult = new IPAddressTestResult(addressElapseds);
}
this.dnsEndPointAddressTestResult[dnsEndPoint] = addressTestResult;
}
}
}
/// <summary>
/// 获取连接耗时
/// </summary>
/// <param name="address"></param>
/// <param name="port"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
private static async Task<IPAddressElapsed> GetIPAddressElapsedAsync(IPAddress address, int port, CancellationToken cancellationToken)
{
var stopWatch = Stopwatch.StartNew();
try
{
using var timeoutTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10d));
using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
using var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(address, port, linkedTokenSource.Token);
return new IPAddressElapsed(address, stopWatch.Elapsed);
}
catch (Exception)
{
cancellationToken.ThrowIfCancellationRequested();
return new IPAddressElapsed(address, TimeSpan.MaxValue);
}
finally
{
stopWatch.Stop();
}
} }
/// <summary> /// <summary>
/// 解析ip /// 解析ip
/// </summary> /// </summary>
/// <param name="domain">域名</param> /// <param name="endPoint">节点</param>
/// <param name="cancellationToken"></param> /// <param name="cancellationToken"></param>
/// <returns></returns> /// <returns></returns>
public async Task<IPAddress> ResolveAnyAsync(string domain, CancellationToken cancellationToken = default) public async Task<IPAddress> ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default)
{ {
await foreach (var address in this.ResolveAllAsync(domain, cancellationToken)) await foreach (var address in this.ResolveAllAsync(endPoint, cancellationToken))
{ {
return address; return address;
} }
throw new FastGithubException($"解析不到{domain}的IP"); throw new FastGithubException($"解析不到{endPoint.Host}的IP");
} }
/// <summary> /// <summary>
/// 解析域名 /// 解析域名
/// </summary> /// </summary>
/// <param name="domain">域名</param> /// <param name="endPoint">节点</param>
/// <param name="cancellationToken"></param> /// <param name="cancellationToken"></param>
/// <returns></returns> /// <returns></returns>
public IAsyncEnumerable<IPAddress> ResolveAllAsync(string domain, CancellationToken cancellationToken) public async IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken)
{ {
return this.dnsClient.ResolveAsync(domain, cancellationToken); if (this.dnsEndPointAddressTestResult.TryGetValue(endPoint, out var speedTestResult) && speedTestResult.IsEmpty == false)
{
foreach (var addressElapsed in speedTestResult.AddressElapseds)
{
yield return addressElapsed.Adddress;
}
}
else
{
this.dnsEndPointAddressTestResult.TryAdd(endPoint, IPAddressTestResult.Empty);
await foreach (var adddress in this.dnsClient.ResolveAsync(endPoint.Host, cancellationToken))
{
yield return adddress;
}
}
} }
} }
} }

View File

@ -11,25 +11,32 @@ namespace FastGithub.DomainResolve
public interface IDomainResolver public interface IDomainResolver
{ {
/// <summary> /// <summary>
/// 加载 /// 加载
/// </summary> /// </summary>
/// <param name="domain">域名</param> /// <param name="domain">域名</param>
void Prefetch(string domain); void Prefetch(string domain);
/// <summary> /// <summary>
/// 解析ip /// 对所有节点进行测速
/// </summary> /// </summary>
/// <param name="domain">域名</param>
/// <param name="cancellationToken"></param> /// <param name="cancellationToken"></param>
/// <returns></returns> /// <returns></returns>
Task<IPAddress> ResolveAnyAsync(string domain, CancellationToken cancellationToken = default); Task TestAllEndPointsAsync(CancellationToken cancellationToken);
/// <summary>
/// 解析ip
/// </summary>
/// <param name="endPoint">节点</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
Task<IPAddress> ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default);
/// <summary> /// <summary>
/// 解析所有ip /// 解析所有ip
/// </summary> /// </summary>
/// <param name="domain">域名</param> /// <param name="endPoint">节点</param>
/// <param name="cancellationToken"></param> /// <param name="cancellationToken"></param>
/// <returns></returns> /// <returns></returns>
IAsyncEnumerable<IPAddress> ResolveAllAsync(string domain, CancellationToken cancellationToken = default); IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default);
} }
} }

View File

@ -1,165 +0,0 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.NetworkInformation;
using System.Threading.Tasks;
namespace FastGithub.DomainResolve
{
/// <summary>
/// IPAddress集合
/// </summary>
[DebuggerDisplay("Count = {Count}")]
sealed class IPAddressCollection
{
private readonly object syncRoot = new();
private readonly HashSet<IPAddressItem> hashSet = new();
/// <summary>
/// 获取元素数量
/// </summary>
public int Count => this.hashSet.Count;
/// <summary>
/// 添加元素
/// </summary>
/// <param name="address"></param>
/// <returns></returns>
public bool Add(IPAddress address)
{
lock (this.syncRoot)
{
return this.hashSet.Add(new IPAddressItem(address));
}
}
/// <summary>
/// 转后为数组
/// </summary>
/// <returns></returns>
public IPAddress[] ToArray()
{
lock (this.syncRoot)
{
return this.hashSet.OrderBy(item => item.PingElapsed).Select(item => item.Address).ToArray();
}
}
/// <summary>
/// Ping所有IP
/// </summary>
/// <returns></returns>
public async Task PingAllAsync()
{
foreach (var item in this.ToItemArray())
{
await item.PingAsync();
}
}
/// <summary>
/// 转换为数组
/// </summary>
/// <returns></returns>
private IPAddressItem[] ToItemArray()
{
lock (this.syncRoot)
{
return this.hashSet.ToArray();
}
}
/// <summary>
/// IP地址项
/// </summary>
[DebuggerDisplay("Address = {Address}, PingElapsed = {PingElapsed}")]
private class IPAddressItem : IEquatable<IPAddressItem>
{
/// <summary>
/// Ping的时间点
/// </summary>
private int? pingTicks;
/// <summary>
/// 地址
/// </summary>
public IPAddress Address { get; }
/// <summary>
/// Ping耗时
/// </summary>
public TimeSpan PingElapsed { get; private set; } = TimeSpan.MaxValue;
/// <summary>
/// IP地址项
/// </summary>
/// <param name="address"></param>
public IPAddressItem(IPAddress address)
{
this.Address = address;
}
/// <summary>
/// 发起ping请求
/// </summary>
/// <returns></returns>
public async Task PingAsync()
{
if (this.NeedToPing() == false)
{
return;
}
try
{
using var ping = new Ping();
var reply = await ping.SendPingAsync(this.Address);
this.PingElapsed = reply.Status == IPStatus.Success
? TimeSpan.FromMilliseconds(reply.RoundtripTime)
: TimeSpan.MaxValue;
}
catch (Exception)
{
this.PingElapsed = TimeSpan.MaxValue;
}
finally
{
this.pingTicks = Environment.TickCount;
}
}
/// <summary>
/// 是否需要ping
/// 5分钟内只ping一次
/// </summary>
/// <returns></returns>
private bool NeedToPing()
{
var ticks = this.pingTicks;
if (ticks == null)
{
return true;
}
var pingTimeSpan = TimeSpan.FromMilliseconds(Environment.TickCount - ticks.Value);
return pingTimeSpan > TimeSpan.FromMinutes(5d);
}
public bool Equals(IPAddressItem? other)
{
return other != null && other.Address.Equals(this.Address);
}
public override bool Equals(object? obj)
{
return obj is IPAddressItem other && this.Equals(other);
}
public override int GetHashCode()
{
return this.Address.GetHashCode();
}
}
}
}

View File

@ -0,0 +1,34 @@
using System;
using System.Diagnostics;
using System.Net;
namespace FastGithub.DomainResolve
{
/// <summary>
/// IP连接耗时
/// </summary>
[DebuggerDisplay("Adddress={Adddress} Elapsed={Elapsed}")]
struct IPAddressElapsed
{
/// <summary>
/// 获取IP地址
/// </summary>
public IPAddress Adddress { get; }
/// <summary>
/// 获取连接耗时
/// </summary>
public TimeSpan Elapsed { get; }
/// <summary>
/// IP连接耗时
/// </summary>
/// <param name="adddress"></param>
/// <param name="elapsed"></param>
public IPAddressElapsed(IPAddress adddress, TimeSpan elapsed)
{
this.Adddress = adddress;
this.Elapsed = elapsed;
}
}
}

View File

@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using System.Linq;
namespace FastGithub.DomainResolve
{
/// <summary>
/// IP测速结果
/// </summary>
sealed class IPAddressTestResult
{
private static readonly TimeSpan lifeTime = TimeSpan.FromMinutes(2d);
private readonly int creationTickCount = Environment.TickCount;
/// <summary>
/// 获取空的
/// </summary>
public static IPAddressTestResult Empty = new(Array.Empty<IPAddressElapsed>());
/// <summary>
/// 获取是否为空
/// </summary>
public bool IsEmpty => this.AddressElapseds.Length == 0;
/// <summary>
/// 获取是否已过期
/// </summary>
public bool IsExpired => lifeTime < TimeSpan.FromMilliseconds(Environment.TickCount - this.creationTickCount);
/// <summary>
/// 获取测速结果
/// </summary>
public IPAddressElapsed[] AddressElapseds { get; }
/// <summary>
/// 测速结果
/// </summary>
/// <param name="result"></param>
public IPAddressTestResult(IEnumerable<IPAddressElapsed> addressElapseds)
{
this.AddressElapseds = addressElapseds.OrderBy(item => item.Elapsed).ToArray();
}
}
}

View File

@ -186,7 +186,7 @@ namespace FastGithub.Http
} }
else else
{ {
await foreach (var item in this.domainResolver.ResolveAllAsync(dnsEndPoint.Host, cancellationToken)) await foreach (var item in this.domainResolver.ResolveAllAsync(dnsEndPoint, cancellationToken))
{ {
yield return new IPEndPoint(item, dnsEndPoint.Port); yield return new IPEndPoint(item, dnsEndPoint.Port);
} }

View File

@ -158,7 +158,7 @@ namespace FastGithub.HttpServer
} }
// 不使用系统dns // 不使用系统dns
address = await this.domainResolver.ResolveAnyAsync(targetHost); address = await this.domainResolver.ResolveAnyAsync(new DnsEndPoint(targetHost, targetPort));
return new IPEndPoint(address, targetPort); return new IPEndPoint(address, targetPort);
} }

View File

@ -1,6 +1,7 @@
using FastGithub.DomainResolve; using FastGithub.DomainResolve;
using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections;
using System.IO.Pipelines; using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -12,8 +13,7 @@ namespace FastGithub.HttpServer
sealed class SshReverseProxyHandler : ConnectionHandler sealed class SshReverseProxyHandler : ConnectionHandler
{ {
private readonly IDomainResolver domainResolver; private readonly IDomainResolver domainResolver;
private const string SSH_GITHUB_COM = "ssh.github.com"; private readonly DnsEndPoint sshOverHttpsEndPoint = new("ssh.github.com", 443);
private const int SSH_OVER_HTTPS_PORT = 443;
/// <summary> /// <summary>
/// github的ssh代理处理者 /// github的ssh代理处理者
@ -31,9 +31,9 @@ namespace FastGithub.HttpServer
/// <returns></returns> /// <returns></returns>
public override async Task OnConnectedAsync(ConnectionContext context) public override async Task OnConnectedAsync(ConnectionContext context)
{ {
var address = await this.domainResolver.ResolveAnyAsync(SSH_GITHUB_COM); var address = await this.domainResolver.ResolveAnyAsync(this.sshOverHttpsEndPoint);
using var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); using var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(address, SSH_OVER_HTTPS_PORT); await socket.ConnectAsync(address, this.sshOverHttpsEndPoint.Port);
var targetStream = new NetworkStream(socket, ownsSocket: false); var targetStream = new NetworkStream(socket, ownsSocket: false);
var task1 = targetStream.CopyToAsync(context.Transport.Output); var task1 = targetStream.CopyToAsync(context.Transport.Output);