简化DomainResolver

This commit is contained in:
陈国伟 2021-09-26 17:39:36 +08:00
parent a0cb04cec3
commit 930a8b624a
5 changed files with 28 additions and 75 deletions

View File

@ -1,11 +1,6 @@
using FastGithub.Configuration;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
@ -21,11 +16,6 @@ namespace FastGithub.DomainResolve
private readonly FastGithubConfig fastGithubConfig;
private readonly DnsClient dnsClient;
private readonly ConcurrentDictionary<IPEndPoint, SemaphoreSlim> semaphoreSlims = new();
private readonly IMemoryCache ipEndPointAvailableCache = new MemoryCache(Options.Create(new MemoryCacheOptions()));
private readonly TimeSpan ipEndPointExpiration = TimeSpan.FromMinutes(2d);
private readonly TimeSpan ipEndPointConnectTimeout = TimeSpan.FromSeconds(5d);
/// <summary>
/// 域名解析器
/// </summary>
@ -43,72 +33,27 @@ namespace FastGithub.DomainResolve
}
/// <summary>
/// 解析可用的ip
/// 解析ip
/// </summary>
/// <param name="endPoint">远程节点</param>
/// <param name="domain">域名</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async Task<IPAddress> ResolveAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default)
public async Task<IPAddress?> ResolveAsync(string domain, CancellationToken cancellationToken = default)
{
await foreach (var address in this.ResolveAsync(endPoint.Host, cancellationToken))
await foreach (var address in this.ResolveAllAsync(domain, cancellationToken))
{
if (await this.IsAvailableAsync(new IPEndPoint(address, endPoint.Port), cancellationToken))
{
return address;
}
return address;
}
throw new FastGithubException($"解析不到{endPoint.Host}可用的IP");
return default;
}
/// <summary>
/// 验证远程节点是否可连接
/// </summary>
/// <param name="ipEndPoint"></param>
/// <param name="cancellationToken"></param>
/// <exception cref="OperationCanceledException"></exception>
/// <returns></returns>
private async Task<bool> IsAvailableAsync(IPEndPoint ipEndPoint, CancellationToken cancellationToken)
{
var semaphore = this.semaphoreSlims.GetOrAdd(ipEndPoint, _ => new SemaphoreSlim(1, 1));
try
{
await semaphore.WaitAsync(CancellationToken.None);
if (this.ipEndPointAvailableCache.TryGetValue<bool>(ipEndPoint, out var available))
{
return available;
}
try
{
using var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
using var timeoutTokenSource = new CancellationTokenSource(this.ipEndPointConnectTimeout);
using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
await socket.ConnectAsync(ipEndPoint, linkedTokenSource.Token);
available = true;
}
catch (Exception)
{
cancellationToken.ThrowIfCancellationRequested();
available = false;
}
this.ipEndPointAvailableCache.Set(ipEndPoint, available, ipEndPointExpiration);
return available;
}
finally
{
semaphore.Release();
}
}
/// <summary>
/// 解析域名
/// </summary>
/// <param name="domain">域名</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async IAsyncEnumerable<IPAddress> ResolveAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
public async IAsyncEnumerable<IPAddress> ResolveAllAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
{
var hashSet = new HashSet<IPAddress>();
foreach (var dns in this.GetDnsServers())

View File

@ -11,12 +11,12 @@ namespace FastGithub.DomainResolve
public interface IDomainResolver
{
/// <summary>
/// 解析可用的ip
/// 解析ip
/// </summary>
/// <param name="endPoint">远程节点</param>
/// <param name="domain">域名</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
Task<IPAddress> ResolveAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default);
Task<IPAddress?> ResolveAsync(string domain, CancellationToken cancellationToken = default);
/// <summary>
/// 解析所有ip
@ -24,6 +24,6 @@ namespace FastGithub.DomainResolve
/// <param name="domain">域名</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
IAsyncEnumerable<IPAddress> ResolveAsync(string domain, CancellationToken cancellationToken = default);
IAsyncEnumerable<IPAddress> ResolveAllAsync(string domain, CancellationToken cancellationToken = default);
}
}

View File

@ -175,7 +175,7 @@ namespace FastGithub.Http
}
else
{
await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint.Host, cancellationToken))
await foreach (var item in this.domainResolver.ResolveAllAsync(dnsEndPoint.Host, cancellationToken))
{
yield return new IPEndPoint(item, dnsEndPoint.Port);
}

View File

@ -152,8 +152,10 @@ namespace FastGithub.HttpServer
}
// dns优选
address = await this.domainResolver.ResolveAsync(new DnsEndPoint(targetHost, targetPort));
return new IPEndPoint(address, targetPort);
address = await this.domainResolver.ResolveAsync(targetHost);
return address == null
? throw new FastGithubException($"解析不到{targetHost}的IP")
: new IPEndPoint(address, targetPort);
}
/// <summary>

View File

@ -1,7 +1,7 @@
using FastGithub.DomainResolve;
using FastGithub.Configuration;
using FastGithub.DomainResolve;
using Microsoft.AspNetCore.Connections;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Threading.Tasks;
@ -13,7 +13,8 @@ namespace FastGithub.HttpServer
sealed class SshReverseProxyHandler : ConnectionHandler
{
private readonly IDomainResolver domainResolver;
private readonly DnsEndPoint githubSshEndPoint = new("ssh.github.com", 443);
private const string SSH_GITHUB_COM = "ssh.github.com";
private const int SSH_OVER_HTTPS_PORT = 443;
/// <summary>
/// github的ssh代理处理者
@ -31,9 +32,14 @@ namespace FastGithub.HttpServer
/// <returns></returns>
public override async Task OnConnectedAsync(ConnectionContext context)
{
var address = await this.domainResolver.ResolveAsync(this.githubSshEndPoint);
using var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(address, this.githubSshEndPoint.Port);
var address = await this.domainResolver.ResolveAsync(SSH_GITHUB_COM);
if (address == null)
{
throw new FastGithubException($"解析不到{SSH_GITHUB_COM}的IP");
}
using var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(address, SSH_OVER_HTTPS_PORT);
var targetStream = new NetworkStream(socket, ownsSocket: false);
var task1 = targetStream.CopyToAsync(context.Transport.Output);