diff --git a/FastGithub.HttpServer/TcpMiddlewares/HttpProxyMiddleware.cs b/FastGithub.HttpServer/TcpMiddlewares/HttpProxyMiddleware.cs index 57eb764..98df5c3 100644 --- a/FastGithub.HttpServer/TcpMiddlewares/HttpProxyMiddleware.cs +++ b/FastGithub.HttpServer/TcpMiddlewares/HttpProxyMiddleware.cs @@ -26,56 +26,77 @@ namespace FastGithub.HttpServer.TcpMiddlewares /// public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext context) { - var result = await context.Transport.Input.ReadAsync(); - var httpRequest = this.GetHttpRequestHandler(result, out var consumed); + var input = context.Transport.Input; + var output = context.Transport.Output; + var request = new HttpRequestHandler(); - // 协议错误 - if (consumed == 0L) + while (context.ConnectionClosed.IsCancellationRequested == false) { - await context.Transport.Output.WriteAsync(this.http400, context.ConnectionClosed); - } - else - { - // 隧道代理连接请求 - if (httpRequest.ProxyProtocol == ProxyProtocol.TunnelProxy) + var result = await input.ReadAsync(); + if (result.IsCanceled) { - var position = result.Buffer.GetPosition(consumed); - context.Transport.Input.AdvanceTo(position); - await context.Transport.Output.WriteAsync(this.http200, context.ConnectionClosed); - } - else - { - var position = result.Buffer.Start; - context.Transport.Input.AdvanceTo(position); + break; } - context.Features.Set(httpRequest); - await next(context); + try + { + if (this.ParseRequest(result, request, out var consumed)) + { + if (request.ProxyProtocol == ProxyProtocol.TunnelProxy) + { + input.AdvanceTo(consumed); + await output.WriteAsync(this.http200, context.ConnectionClosed); + } + else + { + input.AdvanceTo(result.Buffer.Start); + } + + context.Features.Set(request); + await next(context); + + break; + } + else + { + input.AdvanceTo(result.Buffer.Start, result.Buffer.End); + } + + if (result.IsCompleted) + { + break; + } + } + catch (Exception) + { + await output.WriteAsync(this.http400, context.ConnectionClosed); + break; + } } } /// - /// 获取http请求处理者 + /// 解析http请求 /// /// + /// /// /// - private HttpRequestHandler GetHttpRequestHandler(ReadResult result, out long consumed) + private bool ParseRequest(ReadResult result, HttpRequestHandler request, out SequencePosition consumed) { - var handler = new HttpRequestHandler(); var reader = new SequenceReader(result.Buffer); - - if (this.httpParser.ParseRequestLine(handler, ref reader) && - this.httpParser.ParseHeaders(handler, ref reader)) + if (this.httpParser.ParseRequestLine(request, ref reader) && + this.httpParser.ParseHeaders(request, ref reader)) { - consumed = reader.Consumed; + consumed = reader.Position; + return true; } else { - consumed = 0L; + consumed = default; + return false; } - return handler; }