IP测速使用缓存机制

This commit is contained in:
老九 2021-11-19 22:57:28 +08:00
parent 1f0e7dc456
commit ea3206d20e
5 changed files with 146 additions and 142 deletions

View File

@ -16,30 +16,33 @@ namespace FastGithub.DomainResolve
/// </summary>
sealed class DomainResolver : IDomainResolver
{
const int MAX_ADDRESS_COUNT = 4;
private readonly DnsClient dnsClient;
private readonly DomainPersistence persistence;
private readonly IPAddressStatusService statusService;
private readonly ILogger<DomainResolver> logger;
private readonly ConcurrentDictionary<DnsEndPoint, IPAddressElapsed[]> dnsEndPointAddressElapseds = new();
private readonly ConcurrentDictionary<DnsEndPoint, IPAddress[]> dnsEndPointAddress = new();
/// <summary>
/// 域名解析器
/// </summary>
/// <param name="dnsClient"></param>
/// <param name="persistence"></param>
/// <param name="statusService"></param>
/// <param name="logger"></param>
public DomainResolver(
DnsClient dnsClient,
DomainPersistence persistence,
IPAddressStatusService statusService,
ILogger<DomainResolver> logger)
{
this.dnsClient = dnsClient;
this.persistence = persistence;
this.statusService = statusService;
this.logger = logger;
foreach (var endPoint in persistence.ReadDnsEndPoints())
{
this.dnsEndPointAddressElapseds.TryAdd(endPoint, Array.Empty<IPAddressElapsed>());
this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty<IPAddress>());
}
}
@ -67,18 +70,18 @@ namespace FastGithub.DomainResolve
/// <returns></returns>
public async IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken)
{
if (this.dnsEndPointAddressElapseds.TryGetValue(endPoint, out var addressElapseds) && addressElapseds.Length > 0)
if (this.dnsEndPointAddress.TryGetValue(endPoint, out var addresses) && addresses.Length > 0)
{
foreach (var addressElapsed in addressElapseds)
foreach (var address in addresses)
{
yield return addressElapsed.Adddress;
yield return address;
}
}
else
{
if (this.dnsEndPointAddressElapseds.TryAdd(endPoint, Array.Empty<IPAddressElapsed>()))
if (this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty<IPAddress>()))
{
await this.persistence.WriteDnsEndPointsAsync(this.dnsEndPointAddressElapseds.Keys, cancellationToken);
await this.persistence.WriteDnsEndPointsAsync(this.dnsEndPointAddress.Keys, cancellationToken);
}
await foreach (var adddress in this.dnsClient.ResolveAsync(endPoint, fastSort: true, cancellationToken))
@ -95,44 +98,29 @@ namespace FastGithub.DomainResolve
/// <returns></returns>
public async Task TestAllEndPointsAsync(CancellationToken cancellationToken)
{
foreach (var keyValue in this.dnsEndPointAddressElapseds)
foreach (var keyValue in this.dnsEndPointAddress)
{
var oldValues = keyValue.Value;
if (oldValues.Length >= MAX_ADDRESS_COUNT)
{
if (oldValues.Any(item => item.NeedUpdateElapsed()) == false)
{
continue;
}
}
var dnsEndPoint = keyValue.Key;
var hashSet = new HashSet<IPAddressElapsed>(oldValues);
await foreach (var adddress in this.dnsClient.ResolveAsync(dnsEndPoint, fastSort: false, cancellationToken))
var oldAddresses = keyValue.Value;
var hashSet = new HashSet<IPAddress>(oldAddresses);
await foreach (var address in this.dnsClient.ResolveAsync(dnsEndPoint, fastSort: false, cancellationToken))
{
hashSet.Add(new IPAddressElapsed(adddress, dnsEndPoint.Port));
hashSet.Add(address);
}
// 两个以上才进行测速排序
if (hashSet.Count > 1)
{
var updateTasks = hashSet
.Where(item => item.NeedUpdateElapsed())
.Select(item => item.UpdateElapsedAsync(cancellationToken));
await Task.WhenAll(updateTasks);
}
var newValues = hashSet
var statusArray = await this.statusService.GetParallelAsync(hashSet, dnsEndPoint.Port, cancellationToken);
var newAddresses = statusArray
.Where(item => item.Elapsed < TimeSpan.MaxValue)
.OrderBy(item => item.Elapsed)
.Take(count: MAX_ADDRESS_COUNT)
.Select(item => item.Address)
.ToArray();
if (oldValues.SequenceEqual(newValues) == false)
if (oldAddresses.SequenceEqual(newAddresses) == false)
{
this.dnsEndPointAddressElapseds[dnsEndPoint] = newValues;
this.dnsEndPointAddress[dnsEndPoint] = newAddresses;
var addressArray = string.Join(", ", newValues.Select(item => item.ToString()));
var addressArray = string.Join(", ", newAddresses.Select(item => item.ToString()));
this.logger.LogInformation($"{dnsEndPoint.Host}->[{addressArray}]");
}
}

View File

@ -1,107 +0,0 @@
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
namespace FastGithub.DomainResolve
{
/// <summary>
/// IP延时记录
/// 5分钟有效期
/// 5秒连接超时
/// </summary>
[DebuggerDisplay("Adddress={Adddress} Elapsed={Elapsed}")]
sealed class IPAddressElapsed : IEquatable<IPAddressElapsed>
{
private static readonly long maxLifeTime = 5 * 60 * 1000;
private static readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(5d);
private long lastTestTickCount = 0L;
/// <summary>
/// 获取IP地址
/// </summary>
public IPAddress Adddress { get; }
/// <summary>
/// 获取端口
/// </summary>
public int Port { get; }
/// <summary>
/// 获取延时
/// </summary>
public TimeSpan Elapsed { get; private set; }
/// <summary>
/// IP延时
/// </summary>
/// <param name="adddress"></param>
/// <param name="port"></param>
public IPAddressElapsed(IPAddress adddress, int port)
{
this.Adddress = adddress;
this.Port = port;
}
/// <summary>
/// 是否需求更新延时
/// </summary>
/// <returns></returns>
public bool NeedUpdateElapsed()
{
return Environment.TickCount64 - this.lastTestTickCount > maxLifeTime;
}
/// <summary>
/// 更新连接耗时
/// </summary>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async Task UpdateElapsedAsync(CancellationToken cancellationToken)
{
var stopWatch = Stopwatch.StartNew();
try
{
using var timeoutTokenSource = new CancellationTokenSource(connectTimeout);
using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
using var socket = new Socket(this.Adddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(this.Adddress, this.Port, linkedTokenSource.Token);
this.Elapsed = stopWatch.Elapsed;
}
catch (Exception)
{
cancellationToken.ThrowIfCancellationRequested();
this.Elapsed = TimeSpan.MaxValue;
}
finally
{
this.lastTestTickCount = Environment.TickCount64;
stopWatch.Stop();
}
}
public bool Equals(IPAddressElapsed? other)
{
return other != null && other.Adddress.Equals(this.Adddress);
}
public override bool Equals([NotNullWhen(true)] object? obj)
{
return obj is IPAddressElapsed other && this.Equals(other);
}
public override int GetHashCode()
{
return this.Adddress.GetHashCode();
}
public override string ToString()
{
return this.Adddress.ToString();
}
}
}

View File

@ -0,0 +1,34 @@
using System;
using System.Net;
namespace FastGithub.DomainResolve
{
/// <summary>
/// 表示IP的状态
/// </summary>
struct IPAddressStatus
{
/// <summary>
/// 获取IP地址
/// </summary>
public IPAddress Address { get; }
/// <summary>
/// 获取延时
/// 当连接失败时值为MaxValue
/// </summary>
public TimeSpan Elapsed { get; }
/// <summary>
/// IP的状态
/// </summary>
/// <param name="address"></param>
/// <param name="elapsed"></param>
public IPAddressStatus(IPAddress address, TimeSpan elapsed)
{
this.Address = address;
this.Elapsed = elapsed;
}
}
}

View File

@ -0,0 +1,88 @@
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
namespace FastGithub.DomainResolve
{
/// <summary>
/// IP状态服务
/// 连接成功的IP缓存5分钟
/// 连接失败的IP缓存2分钟
/// </summary>
sealed class IPAddressStatusService
{
private readonly TimeSpan activeTTL = TimeSpan.FromMinutes(5d);
private readonly TimeSpan negativeTTL = TimeSpan.FromMinutes(2d);
private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(5d);
private readonly IMemoryCache statusCache = new MemoryCache(Options.Create(new MemoryCacheOptions()));
/// <summary>
/// 并行获取多个IP的状态
/// </summary>
/// <param name="addresses"></param>
/// <param name="port"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public Task<IPAddressStatus[]> GetParallelAsync(IEnumerable<IPAddress> addresses, int port, CancellationToken cancellationToken)
{
var statusTasks = addresses.Select(item => this.GetAsync(item, port, cancellationToken));
return Task.WhenAll(statusTasks);
}
/// <summary>
/// 获取IP状态
/// </summary>
/// <param name="address"></param>
/// <param name="port"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async Task<IPAddressStatus> GetAsync(IPAddress address, int port, CancellationToken cancellationToken)
{
var endPoint = new IPEndPoint(address, port);
if (this.statusCache.TryGetValue<IPAddressStatus>(endPoint, out var status))
{
return status;
}
status = await this.GetAddressStatusAsync(endPoint, cancellationToken);
var ttl = status.Elapsed < TimeSpan.MaxValue ? this.activeTTL : this.negativeTTL;
return this.statusCache.Set(endPoint, status, ttl);
}
/// <summary>
/// 获取IP状态
/// </summary>
/// <param name="endPoint"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
private async Task<IPAddressStatus> GetAddressStatusAsync(IPEndPoint endPoint, CancellationToken cancellationToken)
{
var stopWatch = Stopwatch.StartNew();
try
{
using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
using var socket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(endPoint, linkedTokenSource.Token);
return new IPAddressStatus(endPoint.Address, stopWatch.Elapsed);
}
catch (Exception)
{
cancellationToken.ThrowIfCancellationRequested();
return new IPAddressStatus(endPoint.Address, TimeSpan.MaxValue);
}
finally
{
stopWatch.Stop();
}
}
}
}

View File

@ -19,6 +19,7 @@ namespace FastGithub
services.TryAddSingleton<DnsClient>();
services.TryAddSingleton<DnscryptProxy>();
services.TryAddSingleton<DomainPersistence>();
services.TryAddSingleton<IPAddressStatusService>();
services.TryAddSingleton<IDomainResolver, DomainResolver>();
services.AddHostedService<DomainResolveHostedService>();
return services;