diff --git a/FastGithub.Core/TlsSniPattern.cs b/FastGithub.Core/TlsSniPattern.cs
index f1af0f9..888fe83 100644
--- a/FastGithub.Core/TlsSniPattern.cs
+++ b/FastGithub.Core/TlsSniPattern.cs
@@ -40,9 +40,9 @@ namespace FastGithub
/// Sni自定义值表达式
///
/// 表示式值
- public TlsSniPattern(string value)
+ public TlsSniPattern(string? value)
{
- this.Value = value;
+ this.Value = value ?? string.Empty;
}
///
diff --git a/FastGithub.ReverseProxy/HttpClient.cs b/FastGithub.ReverseProxy/HttpClient.cs
index a8c4887..8c41199 100644
--- a/FastGithub.ReverseProxy/HttpClient.cs
+++ b/FastGithub.ReverseProxy/HttpClient.cs
@@ -32,8 +32,12 @@ namespace FastGithub.ReverseProxy
///
public override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
- var isHttps = request.RequestUri?.Scheme == Uri.UriSchemeHttps;
- request.SetTlsSniContext(new TlsSniContext(isHttps, this.tlsSniPattern));
+ request.SetRequestContext(new RequestContext
+ {
+ Host = request.RequestUri?.Host,
+ IsHttps = request.RequestUri?.Scheme == Uri.UriSchemeHttps,
+ TlsSniPattern = this.tlsSniPattern,
+ });
return base.SendAsync(request, cancellationToken);
}
}
diff --git a/FastGithub.ReverseProxy/HttpClientHanlder.cs b/FastGithub.ReverseProxy/HttpClientHanlder.cs
index 455ac3f..6301c26 100644
--- a/FastGithub.ReverseProxy/HttpClientHanlder.cs
+++ b/FastGithub.ReverseProxy/HttpClientHanlder.cs
@@ -1,4 +1,7 @@
using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
using System.Net.Http;
using System.Net.Security;
using System.Net.Sockets;
@@ -42,8 +45,8 @@ namespace FastGithub.ReverseProxy
await socket.ConnectAsync(context.DnsEndPoint, cancellationToken);
var stream = new NetworkStream(socket, ownsSocket: true);
- var tlsSniContext = context.InitialRequestMessage.GetTlsSniContext();
- if (tlsSniContext.IsHttps == false)
+ var requestContext = context.InitialRequestMessage.GetRequestContext();
+ if (requestContext.IsHttps == false)
{
return stream;
}
@@ -51,20 +54,78 @@ namespace FastGithub.ReverseProxy
var sslStream = new SslStream(stream, leaveInnerStreamOpen: false);
await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions
{
- TargetHost = tlsSniContext.TlsSniPattern.Value,
+ TargetHost = requestContext.TlsSniPattern.Value,
RemoteCertificateValidationCallback = ValidateServerCertificate
}, cancellationToken);
return sslStream;
- // 这里最好需要验证证书的使用者和所有使用者可选名称
- static bool ValidateServerCertificate(object sender, X509Certificate? cert, X509Chain? chain, SslPolicyErrors errors)
+
+ bool ValidateServerCertificate(object sender, X509Certificate? cert, X509Chain? chain, SslPolicyErrors errors)
{
- return errors == SslPolicyErrors.None || errors == SslPolicyErrors.RemoteCertificateNameMismatch;
+ if (errors == SslPolicyErrors.RemoteCertificateNameMismatch)
+ {
+ var host = requestContext.Host;
+ var dnsNames = ReadDnsNames(cert);
+ return dnsNames.Any(dns => IsMatch(dns, host));
+ }
+
+ return errors == SslPolicyErrors.None;
}
}
};
}
+ ///
+ /// 读取使用的DNS名称
+ ///
+ ///
+ ///
+ private static IEnumerable ReadDnsNames(X509Certificate? cert)
+ {
+ if (cert == null)
+ {
+ yield break;
+ }
+ var parser = new Org.BouncyCastle.X509.X509CertificateParser();
+ var x509Cert = parser.ReadCertificate(cert.GetRawCertData());
+ var subjects = x509Cert.GetSubjectAlternativeNames();
+
+ foreach (var subject in subjects)
+ {
+ if (subject is IList list)
+ {
+ var type = (int)list[0]!;
+ if (type == 2) // DNS
+ {
+ yield return list[list.Count - 1]!.ToString()!;
+ }
+ }
+ }
+ }
+
+ ///
+ /// 比较域名
+ ///
+ ///
+ ///
+ ///
+ private static bool IsMatch(string dnsName, string? host)
+ {
+ if (host == null)
+ {
+ return false;
+ }
+ if (dnsName == host)
+ {
+ return true;
+ }
+ if (dnsName[0] == '*')
+ {
+ return host.EndsWith(dnsName[1..]);
+ }
+ return false;
+ }
+
///
/// 替换域名为ip
///
@@ -85,7 +146,7 @@ namespace FastGithub.ReverseProxy
request.RequestUri = builder.Uri;
request.Headers.Host = uri.Host;
- var context = request.GetTlsSniContext();
+ var context = request.GetRequestContext();
context.TlsSniPattern = context.TlsSniPattern.WithDomain(uri.Host).WithIPAddress(address).WithRandom();
}
return await base.SendAsync(request, cancellationToken);
diff --git a/FastGithub.ReverseProxy/RequestContext.cs b/FastGithub.ReverseProxy/RequestContext.cs
new file mode 100644
index 0000000..23fb808
--- /dev/null
+++ b/FastGithub.ReverseProxy/RequestContext.cs
@@ -0,0 +1,23 @@
+namespace FastGithub.ReverseProxy
+{
+ ///
+ /// 表示请求上下文
+ ///
+ sealed class RequestContext
+ {
+ ///
+ /// 获取或设置是否为https请求
+ ///
+ public bool IsHttps { get; set; }
+
+ ///
+ /// 请求的主机
+ ///
+ public string? Host { get; set; }
+
+ ///
+ /// 获取或设置Sni值的表达式
+ ///
+ public TlsSniPattern TlsSniPattern { get; set; }
+ }
+}
diff --git a/FastGithub.ReverseProxy/RequestContextExtensions.cs b/FastGithub.ReverseProxy/RequestContextExtensions.cs
new file mode 100644
index 0000000..68e6291
--- /dev/null
+++ b/FastGithub.ReverseProxy/RequestContextExtensions.cs
@@ -0,0 +1,35 @@
+using System;
+using System.Net.Http;
+
+namespace FastGithub.ReverseProxy
+{
+ ///
+ /// 请求上下文扩展
+ ///
+ static class RequestContextExtensions
+ {
+ private static readonly HttpRequestOptionsKey key = new(nameof(RequestContext));
+
+ ///
+ /// 设置RequestContext
+ ///
+ ///
+ ///
+ public static void SetRequestContext(this HttpRequestMessage httpRequestMessage, RequestContext requestContext)
+ {
+ httpRequestMessage.Options.Set(key, requestContext);
+ }
+
+ ///
+ /// 获取RequestContext
+ ///
+ ///
+ ///
+ public static RequestContext GetRequestContext(this HttpRequestMessage httpRequestMessage)
+ {
+ return httpRequestMessage.Options.TryGetValue(key, out var requestContext)
+ ? requestContext
+ : throw new InvalidOperationException($"请先调用{nameof(SetRequestContext)}");
+ }
+ }
+}
diff --git a/FastGithub.ReverseProxy/TlsSniContext.cs b/FastGithub.ReverseProxy/TlsSniContext.cs
deleted file mode 100644
index 0e308d5..0000000
--- a/FastGithub.ReverseProxy/TlsSniContext.cs
+++ /dev/null
@@ -1,29 +0,0 @@
-namespace FastGithub.ReverseProxy
-{
- ///
- /// Sni上下文
- ///
- sealed class TlsSniContext
- {
- ///
- /// 获取是否为https请求
- ///
- public bool IsHttps { get; }
-
- ///
- /// 获取或设置Sni值的表达式
- ///
- public TlsSniPattern TlsSniPattern { get; set; }
-
- ///
- /// Sni上下文
- ///
- ///
- ///
- public TlsSniContext(bool isHttps, TlsSniPattern tlsSniPattern)
- {
- this.IsHttps = isHttps;
- this.TlsSniPattern = tlsSniPattern;
- }
- }
-}
diff --git a/FastGithub.ReverseProxy/TlsSniContextExtensions.cs b/FastGithub.ReverseProxy/TlsSniContextExtensions.cs
deleted file mode 100644
index 0d44351..0000000
--- a/FastGithub.ReverseProxy/TlsSniContextExtensions.cs
+++ /dev/null
@@ -1,37 +0,0 @@
-using System;
-using System.Net.Http;
-
-namespace FastGithub.ReverseProxy
-{
- ///
- /// SniContext扩展
- ///
- static class TlsSniContextExtensions
- {
- private static readonly HttpRequestOptionsKey key = new(nameof(TlsSniContext));
-
- ///
- /// 设置TlsSniContext
- ///
- ///
- ///
- public static void SetTlsSniContext(this HttpRequestMessage httpRequestMessage, TlsSniContext context)
- {
- httpRequestMessage.Options.Set(key, context);
- }
-
- ///
- /// 获取TlsSniContext
- ///
- ///
- ///
- public static TlsSniContext GetTlsSniContext(this HttpRequestMessage httpRequestMessage)
- {
- if (httpRequestMessage.Options.TryGetValue(key, out var value))
- {
- return value;
- }
- throw new InvalidOperationException($"请先调用{nameof(SetTlsSniContext)}");
- }
- }
-}