diff --git a/FastGithub.HttpServer/TcpMiddlewares/HttpProxyMiddleware.cs b/FastGithub.HttpServer/TcpMiddlewares/HttpProxyMiddleware.cs
index 57eb764..98df5c3 100644
--- a/FastGithub.HttpServer/TcpMiddlewares/HttpProxyMiddleware.cs
+++ b/FastGithub.HttpServer/TcpMiddlewares/HttpProxyMiddleware.cs
@@ -26,56 +26,77 @@ namespace FastGithub.HttpServer.TcpMiddlewares
///
public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext context)
{
- var result = await context.Transport.Input.ReadAsync();
- var httpRequest = this.GetHttpRequestHandler(result, out var consumed);
+ var input = context.Transport.Input;
+ var output = context.Transport.Output;
+ var request = new HttpRequestHandler();
- // 协议错误
- if (consumed == 0L)
+ while (context.ConnectionClosed.IsCancellationRequested == false)
{
- await context.Transport.Output.WriteAsync(this.http400, context.ConnectionClosed);
- }
- else
- {
- // 隧道代理连接请求
- if (httpRequest.ProxyProtocol == ProxyProtocol.TunnelProxy)
+ var result = await input.ReadAsync();
+ if (result.IsCanceled)
{
- var position = result.Buffer.GetPosition(consumed);
- context.Transport.Input.AdvanceTo(position);
- await context.Transport.Output.WriteAsync(this.http200, context.ConnectionClosed);
- }
- else
- {
- var position = result.Buffer.Start;
- context.Transport.Input.AdvanceTo(position);
+ break;
}
- context.Features.Set(httpRequest);
- await next(context);
+ try
+ {
+ if (this.ParseRequest(result, request, out var consumed))
+ {
+ if (request.ProxyProtocol == ProxyProtocol.TunnelProxy)
+ {
+ input.AdvanceTo(consumed);
+ await output.WriteAsync(this.http200, context.ConnectionClosed);
+ }
+ else
+ {
+ input.AdvanceTo(result.Buffer.Start);
+ }
+
+ context.Features.Set(request);
+ await next(context);
+
+ break;
+ }
+ else
+ {
+ input.AdvanceTo(result.Buffer.Start, result.Buffer.End);
+ }
+
+ if (result.IsCompleted)
+ {
+ break;
+ }
+ }
+ catch (Exception)
+ {
+ await output.WriteAsync(this.http400, context.ConnectionClosed);
+ break;
+ }
}
}
///
- /// 获取http请求处理者
+ /// 解析http请求
///
///
+ ///
///
///
- private HttpRequestHandler GetHttpRequestHandler(ReadResult result, out long consumed)
+ private bool ParseRequest(ReadResult result, HttpRequestHandler request, out SequencePosition consumed)
{
- var handler = new HttpRequestHandler();
var reader = new SequenceReader(result.Buffer);
-
- if (this.httpParser.ParseRequestLine(handler, ref reader) &&
- this.httpParser.ParseHeaders(handler, ref reader))
+ if (this.httpParser.ParseRequestLine(request, ref reader) &&
+ this.httpParser.ParseHeaders(request, ref reader))
{
- consumed = reader.Consumed;
+ consumed = reader.Position;
+ return true;
}
else
{
- consumed = 0L;
+ consumed = default;
+ return false;
}
- return handler;
}