diff --git a/FastGithub.Core/TlsSniPattern.cs b/FastGithub.Core/TlsSniPattern.cs index f1af0f9..888fe83 100644 --- a/FastGithub.Core/TlsSniPattern.cs +++ b/FastGithub.Core/TlsSniPattern.cs @@ -40,9 +40,9 @@ namespace FastGithub /// Sni自定义值表达式 /// /// 表示式值 - public TlsSniPattern(string value) + public TlsSniPattern(string? value) { - this.Value = value; + this.Value = value ?? string.Empty; } /// diff --git a/FastGithub.ReverseProxy/HttpClient.cs b/FastGithub.ReverseProxy/HttpClient.cs index a8c4887..8c41199 100644 --- a/FastGithub.ReverseProxy/HttpClient.cs +++ b/FastGithub.ReverseProxy/HttpClient.cs @@ -32,8 +32,12 @@ namespace FastGithub.ReverseProxy /// public override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { - var isHttps = request.RequestUri?.Scheme == Uri.UriSchemeHttps; - request.SetTlsSniContext(new TlsSniContext(isHttps, this.tlsSniPattern)); + request.SetRequestContext(new RequestContext + { + Host = request.RequestUri?.Host, + IsHttps = request.RequestUri?.Scheme == Uri.UriSchemeHttps, + TlsSniPattern = this.tlsSniPattern, + }); return base.SendAsync(request, cancellationToken); } } diff --git a/FastGithub.ReverseProxy/HttpClientHanlder.cs b/FastGithub.ReverseProxy/HttpClientHanlder.cs index 455ac3f..6301c26 100644 --- a/FastGithub.ReverseProxy/HttpClientHanlder.cs +++ b/FastGithub.ReverseProxy/HttpClientHanlder.cs @@ -1,4 +1,7 @@ using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; using System.Net.Http; using System.Net.Security; using System.Net.Sockets; @@ -42,8 +45,8 @@ namespace FastGithub.ReverseProxy await socket.ConnectAsync(context.DnsEndPoint, cancellationToken); var stream = new NetworkStream(socket, ownsSocket: true); - var tlsSniContext = context.InitialRequestMessage.GetTlsSniContext(); - if (tlsSniContext.IsHttps == false) + var requestContext = context.InitialRequestMessage.GetRequestContext(); + if (requestContext.IsHttps == false) { return stream; } @@ -51,20 +54,78 @@ namespace FastGithub.ReverseProxy var sslStream = new SslStream(stream, leaveInnerStreamOpen: false); await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions { - TargetHost = tlsSniContext.TlsSniPattern.Value, + TargetHost = requestContext.TlsSniPattern.Value, RemoteCertificateValidationCallback = ValidateServerCertificate }, cancellationToken); return sslStream; - // 这里最好需要验证证书的使用者和所有使用者可选名称 - static bool ValidateServerCertificate(object sender, X509Certificate? cert, X509Chain? chain, SslPolicyErrors errors) + + bool ValidateServerCertificate(object sender, X509Certificate? cert, X509Chain? chain, SslPolicyErrors errors) { - return errors == SslPolicyErrors.None || errors == SslPolicyErrors.RemoteCertificateNameMismatch; + if (errors == SslPolicyErrors.RemoteCertificateNameMismatch) + { + var host = requestContext.Host; + var dnsNames = ReadDnsNames(cert); + return dnsNames.Any(dns => IsMatch(dns, host)); + } + + return errors == SslPolicyErrors.None; } } }; } + /// + /// 读取使用的DNS名称 + /// + /// + /// + private static IEnumerable ReadDnsNames(X509Certificate? cert) + { + if (cert == null) + { + yield break; + } + var parser = new Org.BouncyCastle.X509.X509CertificateParser(); + var x509Cert = parser.ReadCertificate(cert.GetRawCertData()); + var subjects = x509Cert.GetSubjectAlternativeNames(); + + foreach (var subject in subjects) + { + if (subject is IList list) + { + var type = (int)list[0]!; + if (type == 2) // DNS + { + yield return list[list.Count - 1]!.ToString()!; + } + } + } + } + + /// + /// 比较域名 + /// + /// + /// + /// + private static bool IsMatch(string dnsName, string? host) + { + if (host == null) + { + return false; + } + if (dnsName == host) + { + return true; + } + if (dnsName[0] == '*') + { + return host.EndsWith(dnsName[1..]); + } + return false; + } + /// /// 替换域名为ip /// @@ -85,7 +146,7 @@ namespace FastGithub.ReverseProxy request.RequestUri = builder.Uri; request.Headers.Host = uri.Host; - var context = request.GetTlsSniContext(); + var context = request.GetRequestContext(); context.TlsSniPattern = context.TlsSniPattern.WithDomain(uri.Host).WithIPAddress(address).WithRandom(); } return await base.SendAsync(request, cancellationToken); diff --git a/FastGithub.ReverseProxy/RequestContext.cs b/FastGithub.ReverseProxy/RequestContext.cs new file mode 100644 index 0000000..23fb808 --- /dev/null +++ b/FastGithub.ReverseProxy/RequestContext.cs @@ -0,0 +1,23 @@ +namespace FastGithub.ReverseProxy +{ + /// + /// 表示请求上下文 + /// + sealed class RequestContext + { + /// + /// 获取或设置是否为https请求 + /// + public bool IsHttps { get; set; } + + /// + /// 请求的主机 + /// + public string? Host { get; set; } + + /// + /// 获取或设置Sni值的表达式 + /// + public TlsSniPattern TlsSniPattern { get; set; } + } +} diff --git a/FastGithub.ReverseProxy/RequestContextExtensions.cs b/FastGithub.ReverseProxy/RequestContextExtensions.cs new file mode 100644 index 0000000..68e6291 --- /dev/null +++ b/FastGithub.ReverseProxy/RequestContextExtensions.cs @@ -0,0 +1,35 @@ +using System; +using System.Net.Http; + +namespace FastGithub.ReverseProxy +{ + /// + /// 请求上下文扩展 + /// + static class RequestContextExtensions + { + private static readonly HttpRequestOptionsKey key = new(nameof(RequestContext)); + + /// + /// 设置RequestContext + /// + /// + /// + public static void SetRequestContext(this HttpRequestMessage httpRequestMessage, RequestContext requestContext) + { + httpRequestMessage.Options.Set(key, requestContext); + } + + /// + /// 获取RequestContext + /// + /// + /// + public static RequestContext GetRequestContext(this HttpRequestMessage httpRequestMessage) + { + return httpRequestMessage.Options.TryGetValue(key, out var requestContext) + ? requestContext + : throw new InvalidOperationException($"请先调用{nameof(SetRequestContext)}"); + } + } +} diff --git a/FastGithub.ReverseProxy/TlsSniContext.cs b/FastGithub.ReverseProxy/TlsSniContext.cs deleted file mode 100644 index 0e308d5..0000000 --- a/FastGithub.ReverseProxy/TlsSniContext.cs +++ /dev/null @@ -1,29 +0,0 @@ -namespace FastGithub.ReverseProxy -{ - /// - /// Sni上下文 - /// - sealed class TlsSniContext - { - /// - /// 获取是否为https请求 - /// - public bool IsHttps { get; } - - /// - /// 获取或设置Sni值的表达式 - /// - public TlsSniPattern TlsSniPattern { get; set; } - - /// - /// Sni上下文 - /// - /// - /// - public TlsSniContext(bool isHttps, TlsSniPattern tlsSniPattern) - { - this.IsHttps = isHttps; - this.TlsSniPattern = tlsSniPattern; - } - } -} diff --git a/FastGithub.ReverseProxy/TlsSniContextExtensions.cs b/FastGithub.ReverseProxy/TlsSniContextExtensions.cs deleted file mode 100644 index 0d44351..0000000 --- a/FastGithub.ReverseProxy/TlsSniContextExtensions.cs +++ /dev/null @@ -1,37 +0,0 @@ -using System; -using System.Net.Http; - -namespace FastGithub.ReverseProxy -{ - /// - /// SniContext扩展 - /// - static class TlsSniContextExtensions - { - private static readonly HttpRequestOptionsKey key = new(nameof(TlsSniContext)); - - /// - /// 设置TlsSniContext - /// - /// - /// - public static void SetTlsSniContext(this HttpRequestMessage httpRequestMessage, TlsSniContext context) - { - httpRequestMessage.Options.Set(key, context); - } - - /// - /// 获取TlsSniContext - /// - /// - /// - public static TlsSniContext GetTlsSniContext(this HttpRequestMessage httpRequestMessage) - { - if (httpRequestMessage.Options.TryGetValue(key, out var value)) - { - return value; - } - throw new InvalidOperationException($"请先调用{nameof(SetTlsSniContext)}"); - } - } -}