应用CancellationToken

This commit is contained in:
陈国伟 2021-11-24 08:57:47 +08:00
parent 1ed3dee689
commit 2f9827a5fd
5 changed files with 35 additions and 23 deletions

View File

@ -5,13 +5,13 @@ namespace FastGithub.HttpServer
/// <summary>
/// github的git代理处理者
/// </summary>
sealed class GitReverseProxyHandler : TcpReverseProxyHandler
sealed class GithubGitReverseProxyHandler : TcpReverseProxyHandler
{
/// <summary>
/// github的git代理处理者
/// </summary>
/// <param name="domainResolver"></param>
public GitReverseProxyHandler(IDomainResolver domainResolver)
public GithubGitReverseProxyHandler(IDomainResolver domainResolver)
: base(domainResolver, new("github.com", 9418))
{
}

View File

@ -5,13 +5,13 @@ namespace FastGithub.HttpServer
/// <summary>
/// github的ssh代理处理者
/// </summary>
sealed class SshReverseProxyHandler : TcpReverseProxyHandler
sealed class GithubSshReverseProxyHandler : TcpReverseProxyHandler
{
/// <summary>
/// github的ssh代理处理者
/// </summary>
/// <param name="domainResolver"></param>
public SshReverseProxyHandler(IDomainResolver domainResolver)
public GithubSshReverseProxyHandler(IDomainResolver domainResolver)
: base(domainResolver, new("github.com", 22))
{
}

View File

@ -12,6 +12,7 @@ using System.Net;
using System.Net.Http;
using System.Net.Sockets;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@ -88,7 +89,8 @@ namespace FastGithub.HttpServer
}
else if (context.Request.Method == HttpMethods.Connect)
{
using var connection = await this.CreateConnectionAsync(host);
var cancellationToken = context.RequestAborted;
using var connection = await this.CreateConnectionAsync(host, cancellationToken);
var responseFeature = context.Features.Get<IHttpResponseFeature>();
if (responseFeature != null)
{
@ -100,7 +102,9 @@ namespace FastGithub.HttpServer
var transport = context.Features.Get<IConnectionTransportFeature>()?.Transport;
if (transport != null)
{
await Task.WhenAny(connection.CopyToAsync(transport.Output), transport.Input.CopyToAsync(connection));
var task1 = connection.CopyToAsync(transport.Output, cancellationToken);
var task2 = transport.Input.CopyToAsync(connection, cancellationToken);
await Task.WhenAny(task1, task2);
}
}
else
@ -156,23 +160,26 @@ namespace FastGithub.HttpServer
/// 创建连接
/// </summary>
/// <param name="host"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="AggregateException"></exception>
private async Task<Stream> CreateConnectionAsync(HostString host)
private async Task<Stream> CreateConnectionAsync(HostString host, CancellationToken cancellationToken)
{
var innerExceptions = new List<Exception>();
await foreach (var endPoint in this.GetTargetEndPointsAsync(host))
await foreach (var endPoint in this.GetUpstreamEndPointsAsync(host, cancellationToken))
{
var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
try
{
using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
await socket.ConnectAsync(endPoint, timeoutTokenSource.Token);
using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
await socket.ConnectAsync(endPoint, linkedTokenSource.Token);
return new NetworkStream(socket, ownsSocket: false);
}
catch (Exception ex)
{
socket.Dispose();
cancellationToken.ThrowIfCancellationRequested();
innerExceptions.Add(ex);
}
}
@ -183,8 +190,9 @@ namespace FastGithub.HttpServer
/// 获取目标终节点
/// </summary>
/// <param name="host"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
private async IAsyncEnumerable<EndPoint> GetTargetEndPointsAsync(HostString host)
private async IAsyncEnumerable<EndPoint> GetUpstreamEndPointsAsync(HostString host, [EnumeratorCancellation] CancellationToken cancellationToken)
{
var targetHost = host.Host;
var targetPort = host.Port ?? HTTPS_PORT;
@ -215,7 +223,7 @@ namespace FastGithub.HttpServer
}
var dnsEndPoint = new DnsEndPoint(targetHost, targetPort);
await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint))
await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint, cancellationToken))
{
yield return new IPEndPoint(item, targetPort);
}

View File

@ -48,7 +48,7 @@ namespace FastGithub
}
/// <summary>
/// 尝试监听ssh反向代理
/// 监听ssh反向代理
/// </summary>
/// <param name="kestrel"></param>
public static void ListenSshReverseProxy(this KestrelServerOptions kestrel)
@ -57,14 +57,14 @@ namespace FastGithub
kestrel.ListenLocalhost(sshPort, listen =>
{
listen.UseFlowAnalyze();
listen.UseConnectionHandler<SshReverseProxyHandler>();
listen.UseConnectionHandler<GithubSshReverseProxyHandler>();
});
kestrel.GetLogger().LogInformation($"已监听ssh://localhost:{sshPort}github的ssh反向代理服务启动完成");
}
/// <summary>
/// 尝试监听git反向代理
/// 监听git反向代理
/// </summary>
/// <param name="kestrel"></param>
public static void ListenGitReverseProxy(this KestrelServerOptions kestrel)
@ -73,14 +73,14 @@ namespace FastGithub
kestrel.ListenLocalhost(gitPort, listen =>
{
listen.UseFlowAnalyze();
listen.UseConnectionHandler<GitReverseProxyHandler>();
listen.UseConnectionHandler<GithubGitReverseProxyHandler>();
});
kestrel.GetLogger().LogInformation($"已监听git://localhost:{gitPort}github的git反向代理服务启动完成");
}
/// <summary>
/// 尝试监听http反向代理
/// 监听http反向代理
/// </summary>
/// <param name="kestrel"></param>
public static void ListenHttpReverseProxy(this KestrelServerOptions kestrel)

View File

@ -32,38 +32,42 @@ namespace FastGithub.HttpServer
}
/// <summary>
/// ssh连接后
/// tcp连接后
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
public override async Task OnConnectedAsync(ConnectionContext context)
{
using var connection = await this.CreateConnectionAsync();
var task1 = connection.CopyToAsync(context.Transport.Output);
var task2 = context.Transport.Input.CopyToAsync(connection);
var cancellationToken = context.ConnectionClosed;
using var connection = await this.CreateConnectionAsync(cancellationToken);
var task1 = connection.CopyToAsync(context.Transport.Output, cancellationToken);
var task2 = context.Transport.Input.CopyToAsync(connection, cancellationToken);
await Task.WhenAny(task1, task2);
}
/// <summary>
/// 创建连接
/// </summary>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="AggregateException"></exception>
private async Task<Stream> CreateConnectionAsync()
private async Task<Stream> CreateConnectionAsync(CancellationToken cancellationToken)
{
var innerExceptions = new List<Exception>();
await foreach (var address in this.domainResolver.ResolveAsync(this.endPoint))
await foreach (var address in this.domainResolver.ResolveAsync(this.endPoint, cancellationToken))
{
var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try
{
using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
await socket.ConnectAsync(address, this.endPoint.Port, timeoutTokenSource.Token);
using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
await socket.ConnectAsync(address, this.endPoint.Port, linkedTokenSource.Token);
return new NetworkStream(socket, ownsSocket: false);
}
catch (Exception ex)
{
socket.Dispose();
cancellationToken.ThrowIfCancellationRequested();
innerExceptions.Add(ex);
}
}