验证服务证书DNS

This commit is contained in:
陈国伟 2021-07-20 09:42:38 +08:00
parent f5698ef1e1
commit f749200bfd
7 changed files with 134 additions and 77 deletions

View File

@ -40,9 +40,9 @@ namespace FastGithub
/// Sni自定义值表达式 /// Sni自定义值表达式
/// </summary> /// </summary>
/// <param name="value">表示式值</param> /// <param name="value">表示式值</param>
public TlsSniPattern(string value) public TlsSniPattern(string? value)
{ {
this.Value = value; this.Value = value ?? string.Empty;
} }
/// <summary> /// <summary>

View File

@ -32,8 +32,12 @@ namespace FastGithub.ReverseProxy
/// <returns></returns> /// <returns></returns>
public override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) public override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{ {
var isHttps = request.RequestUri?.Scheme == Uri.UriSchemeHttps; request.SetRequestContext(new RequestContext
request.SetTlsSniContext(new TlsSniContext(isHttps, this.tlsSniPattern)); {
Host = request.RequestUri?.Host,
IsHttps = request.RequestUri?.Scheme == Uri.UriSchemeHttps,
TlsSniPattern = this.tlsSniPattern,
});
return base.SendAsync(request, cancellationToken); return base.SendAsync(request, cancellationToken);
} }
} }

View File

@ -1,4 +1,7 @@
using System; using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http; using System.Net.Http;
using System.Net.Security; using System.Net.Security;
using System.Net.Sockets; using System.Net.Sockets;
@ -42,8 +45,8 @@ namespace FastGithub.ReverseProxy
await socket.ConnectAsync(context.DnsEndPoint, cancellationToken); await socket.ConnectAsync(context.DnsEndPoint, cancellationToken);
var stream = new NetworkStream(socket, ownsSocket: true); var stream = new NetworkStream(socket, ownsSocket: true);
var tlsSniContext = context.InitialRequestMessage.GetTlsSniContext(); var requestContext = context.InitialRequestMessage.GetRequestContext();
if (tlsSniContext.IsHttps == false) if (requestContext.IsHttps == false)
{ {
return stream; return stream;
} }
@ -51,20 +54,78 @@ namespace FastGithub.ReverseProxy
var sslStream = new SslStream(stream, leaveInnerStreamOpen: false); var sslStream = new SslStream(stream, leaveInnerStreamOpen: false);
await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions
{ {
TargetHost = tlsSniContext.TlsSniPattern.Value, TargetHost = requestContext.TlsSniPattern.Value,
RemoteCertificateValidationCallback = ValidateServerCertificate RemoteCertificateValidationCallback = ValidateServerCertificate
}, cancellationToken); }, cancellationToken);
return sslStream; return sslStream;
// 这里最好需要验证证书的使用者和所有使用者可选名称
static bool ValidateServerCertificate(object sender, X509Certificate? cert, X509Chain? chain, SslPolicyErrors errors) bool ValidateServerCertificate(object sender, X509Certificate? cert, X509Chain? chain, SslPolicyErrors errors)
{ {
return errors == SslPolicyErrors.None || errors == SslPolicyErrors.RemoteCertificateNameMismatch; if (errors == SslPolicyErrors.RemoteCertificateNameMismatch)
{
var host = requestContext.Host;
var dnsNames = ReadDnsNames(cert);
return dnsNames.Any(dns => IsMatch(dns, host));
}
return errors == SslPolicyErrors.None;
} }
} }
}; };
} }
/// <summary>
/// 读取使用的DNS名称
/// </summary>
/// <param name="cert"></param>
/// <returns></returns>
private static IEnumerable<string> ReadDnsNames(X509Certificate? cert)
{
if (cert == null)
{
yield break;
}
var parser = new Org.BouncyCastle.X509.X509CertificateParser();
var x509Cert = parser.ReadCertificate(cert.GetRawCertData());
var subjects = x509Cert.GetSubjectAlternativeNames();
foreach (var subject in subjects)
{
if (subject is IList list)
{
var type = (int)list[0]!;
if (type == 2) // DNS
{
yield return list[list.Count - 1]!.ToString()!;
}
}
}
}
/// <summary>
/// 比较域名
/// </summary>
/// <param name="dnsName"></param>
/// <param name="host"></param>
/// <returns></returns>
private static bool IsMatch(string dnsName, string? host)
{
if (host == null)
{
return false;
}
if (dnsName == host)
{
return true;
}
if (dnsName[0] == '*')
{
return host.EndsWith(dnsName[1..]);
}
return false;
}
/// <summary> /// <summary>
/// 替换域名为ip /// 替换域名为ip
/// </summary> /// </summary>
@ -85,7 +146,7 @@ namespace FastGithub.ReverseProxy
request.RequestUri = builder.Uri; request.RequestUri = builder.Uri;
request.Headers.Host = uri.Host; request.Headers.Host = uri.Host;
var context = request.GetTlsSniContext(); var context = request.GetRequestContext();
context.TlsSniPattern = context.TlsSniPattern.WithDomain(uri.Host).WithIPAddress(address).WithRandom(); context.TlsSniPattern = context.TlsSniPattern.WithDomain(uri.Host).WithIPAddress(address).WithRandom();
} }
return await base.SendAsync(request, cancellationToken); return await base.SendAsync(request, cancellationToken);

View File

@ -0,0 +1,23 @@
namespace FastGithub.ReverseProxy
{
/// <summary>
/// 表示请求上下文
/// </summary>
sealed class RequestContext
{
/// <summary>
/// 获取或设置是否为https请求
/// </summary>
public bool IsHttps { get; set; }
/// <summary>
/// 请求的主机
/// </summary>
public string? Host { get; set; }
/// <summary>
/// 获取或设置Sni值的表达式
/// </summary>
public TlsSniPattern TlsSniPattern { get; set; }
}
}

View File

@ -0,0 +1,35 @@
using System;
using System.Net.Http;
namespace FastGithub.ReverseProxy
{
/// <summary>
/// 请求上下文扩展
/// </summary>
static class RequestContextExtensions
{
private static readonly HttpRequestOptionsKey<RequestContext> key = new(nameof(RequestContext));
/// <summary>
/// 设置RequestContext
/// </summary>
/// <param name="httpRequestMessage"></param>
/// <param name="requestContext"></param>
public static void SetRequestContext(this HttpRequestMessage httpRequestMessage, RequestContext requestContext)
{
httpRequestMessage.Options.Set(key, requestContext);
}
/// <summary>
/// 获取RequestContext
/// </summary>
/// <param name="httpRequestMessage"></param>
/// <returns></returns>
public static RequestContext GetRequestContext(this HttpRequestMessage httpRequestMessage)
{
return httpRequestMessage.Options.TryGetValue(key, out var requestContext)
? requestContext
: throw new InvalidOperationException($"请先调用{nameof(SetRequestContext)}");
}
}
}

View File

@ -1,29 +0,0 @@
namespace FastGithub.ReverseProxy
{
/// <summary>
/// Sni上下文
/// </summary>
sealed class TlsSniContext
{
/// <summary>
/// 获取是否为https请求
/// </summary>
public bool IsHttps { get; }
/// <summary>
/// 获取或设置Sni值的表达式
/// </summary>
public TlsSniPattern TlsSniPattern { get; set; }
/// <summary>
/// Sni上下文
/// </summary>
/// <param name="isHttps"></param>
/// <param name="tlsSniPattern"></param>
public TlsSniContext(bool isHttps, TlsSniPattern tlsSniPattern)
{
this.IsHttps = isHttps;
this.TlsSniPattern = tlsSniPattern;
}
}
}

View File

@ -1,37 +0,0 @@
using System;
using System.Net.Http;
namespace FastGithub.ReverseProxy
{
/// <summary>
/// SniContext扩展
/// </summary>
static class TlsSniContextExtensions
{
private static readonly HttpRequestOptionsKey<TlsSniContext> key = new(nameof(TlsSniContext));
/// <summary>
/// 设置TlsSniContext
/// </summary>
/// <param name="httpRequestMessage"></param>
/// <param name="context"></param>
public static void SetTlsSniContext(this HttpRequestMessage httpRequestMessage, TlsSniContext context)
{
httpRequestMessage.Options.Set(key, context);
}
/// <summary>
/// 获取TlsSniContext
/// </summary>
/// <param name="httpRequestMessage"></param>
/// <returns></returns>
public static TlsSniContext GetTlsSniContext(this HttpRequestMessage httpRequestMessage)
{
if (httpRequestMessage.Options.TryGetValue(key, out var value))
{
return value;
}
throw new InvalidOperationException($"请先调用{nameof(SetTlsSniContext)}");
}
}
}