应用CancellationToken

This commit is contained in:
陈国伟 2021-11-24 08:57:47 +08:00
parent 1ed3dee689
commit 2f9827a5fd
5 changed files with 35 additions and 23 deletions

View File

@ -5,13 +5,13 @@ namespace FastGithub.HttpServer
/// <summary> /// <summary>
/// github的git代理处理者 /// github的git代理处理者
/// </summary> /// </summary>
sealed class GitReverseProxyHandler : TcpReverseProxyHandler sealed class GithubGitReverseProxyHandler : TcpReverseProxyHandler
{ {
/// <summary> /// <summary>
/// github的git代理处理者 /// github的git代理处理者
/// </summary> /// </summary>
/// <param name="domainResolver"></param> /// <param name="domainResolver"></param>
public GitReverseProxyHandler(IDomainResolver domainResolver) public GithubGitReverseProxyHandler(IDomainResolver domainResolver)
: base(domainResolver, new("github.com", 9418)) : base(domainResolver, new("github.com", 9418))
{ {
} }

View File

@ -5,13 +5,13 @@ namespace FastGithub.HttpServer
/// <summary> /// <summary>
/// github的ssh代理处理者 /// github的ssh代理处理者
/// </summary> /// </summary>
sealed class SshReverseProxyHandler : TcpReverseProxyHandler sealed class GithubSshReverseProxyHandler : TcpReverseProxyHandler
{ {
/// <summary> /// <summary>
/// github的ssh代理处理者 /// github的ssh代理处理者
/// </summary> /// </summary>
/// <param name="domainResolver"></param> /// <param name="domainResolver"></param>
public SshReverseProxyHandler(IDomainResolver domainResolver) public GithubSshReverseProxyHandler(IDomainResolver domainResolver)
: base(domainResolver, new("github.com", 22)) : base(domainResolver, new("github.com", 22))
{ {
} }

View File

@ -12,6 +12,7 @@ using System.Net;
using System.Net.Http; using System.Net.Http;
using System.Net.Sockets; using System.Net.Sockets;
using System.Reflection; using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -88,7 +89,8 @@ namespace FastGithub.HttpServer
} }
else if (context.Request.Method == HttpMethods.Connect) else if (context.Request.Method == HttpMethods.Connect)
{ {
using var connection = await this.CreateConnectionAsync(host); var cancellationToken = context.RequestAborted;
using var connection = await this.CreateConnectionAsync(host, cancellationToken);
var responseFeature = context.Features.Get<IHttpResponseFeature>(); var responseFeature = context.Features.Get<IHttpResponseFeature>();
if (responseFeature != null) if (responseFeature != null)
{ {
@ -100,7 +102,9 @@ namespace FastGithub.HttpServer
var transport = context.Features.Get<IConnectionTransportFeature>()?.Transport; var transport = context.Features.Get<IConnectionTransportFeature>()?.Transport;
if (transport != null) if (transport != null)
{ {
await Task.WhenAny(connection.CopyToAsync(transport.Output), transport.Input.CopyToAsync(connection)); var task1 = connection.CopyToAsync(transport.Output, cancellationToken);
var task2 = transport.Input.CopyToAsync(connection, cancellationToken);
await Task.WhenAny(task1, task2);
} }
} }
else else
@ -156,23 +160,26 @@ namespace FastGithub.HttpServer
/// 创建连接 /// 创建连接
/// </summary> /// </summary>
/// <param name="host"></param> /// <param name="host"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns> /// <returns></returns>
/// <exception cref="AggregateException"></exception> /// <exception cref="AggregateException"></exception>
private async Task<Stream> CreateConnectionAsync(HostString host) private async Task<Stream> CreateConnectionAsync(HostString host, CancellationToken cancellationToken)
{ {
var innerExceptions = new List<Exception>(); var innerExceptions = new List<Exception>();
await foreach (var endPoint in this.GetTargetEndPointsAsync(host)) await foreach (var endPoint in this.GetUpstreamEndPointsAsync(host, cancellationToken))
{ {
var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
try try
{ {
using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout); using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
await socket.ConnectAsync(endPoint, timeoutTokenSource.Token); using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
await socket.ConnectAsync(endPoint, linkedTokenSource.Token);
return new NetworkStream(socket, ownsSocket: false); return new NetworkStream(socket, ownsSocket: false);
} }
catch (Exception ex) catch (Exception ex)
{ {
socket.Dispose(); socket.Dispose();
cancellationToken.ThrowIfCancellationRequested();
innerExceptions.Add(ex); innerExceptions.Add(ex);
} }
} }
@ -183,8 +190,9 @@ namespace FastGithub.HttpServer
/// 获取目标终节点 /// 获取目标终节点
/// </summary> /// </summary>
/// <param name="host"></param> /// <param name="host"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns> /// <returns></returns>
private async IAsyncEnumerable<EndPoint> GetTargetEndPointsAsync(HostString host) private async IAsyncEnumerable<EndPoint> GetUpstreamEndPointsAsync(HostString host, [EnumeratorCancellation] CancellationToken cancellationToken)
{ {
var targetHost = host.Host; var targetHost = host.Host;
var targetPort = host.Port ?? HTTPS_PORT; var targetPort = host.Port ?? HTTPS_PORT;
@ -215,7 +223,7 @@ namespace FastGithub.HttpServer
} }
var dnsEndPoint = new DnsEndPoint(targetHost, targetPort); var dnsEndPoint = new DnsEndPoint(targetHost, targetPort);
await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint)) await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint, cancellationToken))
{ {
yield return new IPEndPoint(item, targetPort); yield return new IPEndPoint(item, targetPort);
} }

View File

@ -48,7 +48,7 @@ namespace FastGithub
} }
/// <summary> /// <summary>
/// 尝试监听ssh反向代理 /// 监听ssh反向代理
/// </summary> /// </summary>
/// <param name="kestrel"></param> /// <param name="kestrel"></param>
public static void ListenSshReverseProxy(this KestrelServerOptions kestrel) public static void ListenSshReverseProxy(this KestrelServerOptions kestrel)
@ -57,14 +57,14 @@ namespace FastGithub
kestrel.ListenLocalhost(sshPort, listen => kestrel.ListenLocalhost(sshPort, listen =>
{ {
listen.UseFlowAnalyze(); listen.UseFlowAnalyze();
listen.UseConnectionHandler<SshReverseProxyHandler>(); listen.UseConnectionHandler<GithubSshReverseProxyHandler>();
}); });
kestrel.GetLogger().LogInformation($"已监听ssh://localhost:{sshPort}github的ssh反向代理服务启动完成"); kestrel.GetLogger().LogInformation($"已监听ssh://localhost:{sshPort}github的ssh反向代理服务启动完成");
} }
/// <summary> /// <summary>
/// 尝试监听git反向代理 /// 监听git反向代理
/// </summary> /// </summary>
/// <param name="kestrel"></param> /// <param name="kestrel"></param>
public static void ListenGitReverseProxy(this KestrelServerOptions kestrel) public static void ListenGitReverseProxy(this KestrelServerOptions kestrel)
@ -73,14 +73,14 @@ namespace FastGithub
kestrel.ListenLocalhost(gitPort, listen => kestrel.ListenLocalhost(gitPort, listen =>
{ {
listen.UseFlowAnalyze(); listen.UseFlowAnalyze();
listen.UseConnectionHandler<GitReverseProxyHandler>(); listen.UseConnectionHandler<GithubGitReverseProxyHandler>();
}); });
kestrel.GetLogger().LogInformation($"已监听git://localhost:{gitPort}github的git反向代理服务启动完成"); kestrel.GetLogger().LogInformation($"已监听git://localhost:{gitPort}github的git反向代理服务启动完成");
} }
/// <summary> /// <summary>
/// 尝试监听http反向代理 /// 监听http反向代理
/// </summary> /// </summary>
/// <param name="kestrel"></param> /// <param name="kestrel"></param>
public static void ListenHttpReverseProxy(this KestrelServerOptions kestrel) public static void ListenHttpReverseProxy(this KestrelServerOptions kestrel)

View File

@ -32,38 +32,42 @@ namespace FastGithub.HttpServer
} }
/// <summary> /// <summary>
/// ssh连接后 /// tcp连接后
/// </summary> /// </summary>
/// <param name="context"></param> /// <param name="context"></param>
/// <returns></returns> /// <returns></returns>
public override async Task OnConnectedAsync(ConnectionContext context) public override async Task OnConnectedAsync(ConnectionContext context)
{ {
using var connection = await this.CreateConnectionAsync(); var cancellationToken = context.ConnectionClosed;
var task1 = connection.CopyToAsync(context.Transport.Output); using var connection = await this.CreateConnectionAsync(cancellationToken);
var task2 = context.Transport.Input.CopyToAsync(connection); var task1 = connection.CopyToAsync(context.Transport.Output, cancellationToken);
var task2 = context.Transport.Input.CopyToAsync(connection, cancellationToken);
await Task.WhenAny(task1, task2); await Task.WhenAny(task1, task2);
} }
/// <summary> /// <summary>
/// 创建连接 /// 创建连接
/// </summary> /// </summary>
/// <param name="cancellationToken"></param>
/// <returns></returns> /// <returns></returns>
/// <exception cref="AggregateException"></exception> /// <exception cref="AggregateException"></exception>
private async Task<Stream> CreateConnectionAsync() private async Task<Stream> CreateConnectionAsync(CancellationToken cancellationToken)
{ {
var innerExceptions = new List<Exception>(); var innerExceptions = new List<Exception>();
await foreach (var address in this.domainResolver.ResolveAsync(this.endPoint)) await foreach (var address in this.domainResolver.ResolveAsync(this.endPoint, cancellationToken))
{ {
var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp); var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try try
{ {
using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout); using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
await socket.ConnectAsync(address, this.endPoint.Port, timeoutTokenSource.Token); using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
await socket.ConnectAsync(address, this.endPoint.Port, linkedTokenSource.Token);
return new NetworkStream(socket, ownsSocket: false); return new NetworkStream(socket, ownsSocket: false);
} }
catch (Exception ex) catch (Exception ex)
{ {
socket.Dispose(); socket.Dispose();
cancellationToken.ThrowIfCancellationRequested();
innerExceptions.Add(ex); innerExceptions.Add(ex);
} }
} }