更换windivert

This commit is contained in:
陈国伟 2022-09-28 14:50:33 +08:00
parent 893e9729ca
commit e493765433
3 changed files with 73 additions and 84 deletions

View File

@ -1,7 +1,6 @@
using DNS.Protocol; using DNS.Protocol;
using DNS.Protocol.ResourceRecords; using DNS.Protocol.ResourceRecords;
using FastGithub.Configuration; using FastGithub.Configuration;
using FastGithub.WinDiverts;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using System; using System;
@ -13,6 +12,7 @@ using System.Runtime.InteropServices;
using System.Runtime.Versioning; using System.Runtime.Versioning;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using WindivertDotnet;
namespace FastGithub.PacketIntercept.Dns namespace FastGithub.PacketIntercept.Dns
{ {
@ -22,7 +22,7 @@ namespace FastGithub.PacketIntercept.Dns
[SupportedOSPlatform("windows")] [SupportedOSPlatform("windows")]
sealed class DnsInterceptor : IDnsInterceptor sealed class DnsInterceptor : IDnsInterceptor
{ {
private const string DNS_FILTER = "udp.DstPort == 53"; private static readonly Filter filter = Filter.True.And(f => f.Udp.DstPort == 53);
private readonly FastGithubConfig fastGithubConfig; private readonly FastGithubConfig fastGithubConfig;
private readonly ILogger<DnsInterceptor> logger; private readonly ILogger<DnsInterceptor> logger;
@ -40,8 +40,11 @@ namespace FastGithub.PacketIntercept.Dns
/// </summary> /// </summary>
static DnsInterceptor() static DnsInterceptor()
{ {
var handle = WinDivert.WinDivertOpen("false", WinDivertLayer.Network, 0, WinDivertOpenFlags.None); try
WinDivert.WinDivertClose(handle); {
using (new WinDivert(Filter.False, WinDivertLayer.Network)) { }
}
catch (Exception) { }
} }
/// <summary> /// <summary>
@ -71,33 +74,23 @@ namespace FastGithub.PacketIntercept.Dns
{ {
await Task.Yield(); await Task.Yield();
var handle = WinDivert.WinDivertOpen(DNS_FILTER, WinDivertLayer.Network, 0, WinDivertOpenFlags.None); using var divert = new WinDivert(filter, WinDivertLayer.Network);
if (handle == new IntPtr(unchecked((long)ulong.MaxValue))) cancellationToken.Register(d =>
{ {
throw new Win32Exception(); ((WinDivert)d!).Dispose();
}
cancellationToken.Register(hwnd =>
{
WinDivert.WinDivertClose((IntPtr)hwnd!);
DnsFlushResolverCache(); DnsFlushResolverCache();
}, handle); }, divert);
var packetLength = 0U; var addr = new WinDivertAddress();
using var winDivertBuffer = new WinDivertBuffer(); using var packet = new WinDivertPacket();
var winDivertAddress = new WinDivertAddress();
DnsFlushResolverCache(); DnsFlushResolverCache();
while (cancellationToken.IsCancellationRequested == false) while (cancellationToken.IsCancellationRequested == false)
{ {
if (WinDivert.WinDivertRecv(handle, winDivertBuffer, ref winDivertAddress, ref packetLength) == false) divert.Recv(packet, ref addr);
{
throw new Win32Exception();
}
try try
{ {
this.ModifyDnsPacket(winDivertBuffer, ref winDivertAddress, ref packetLength); this.ModifyDnsPacket(packet, ref addr);
} }
catch (Exception ex) catch (Exception ex)
{ {
@ -105,7 +98,7 @@ namespace FastGithub.PacketIntercept.Dns
} }
finally finally
{ {
WinDivert.WinDivertSend(handle, winDivertBuffer, packetLength, ref winDivertAddress); divert.Send(packet, ref addr);
} }
} }
} }
@ -113,13 +106,12 @@ namespace FastGithub.PacketIntercept.Dns
/// <summary> /// <summary>
/// 修改DNS数据包 /// 修改DNS数据包
/// </summary> /// </summary>
/// <param name="winDivertBuffer"></param> /// <param name="packet"></param>
/// <param name="winDivertAddress"></param> /// <param name="addr"></param>
/// <param name="packetLength"></param> unsafe private void ModifyDnsPacket(WinDivertPacket packet, ref WinDivertAddress addr)
unsafe private void ModifyDnsPacket(WinDivertBuffer winDivertBuffer, ref WinDivertAddress winDivertAddress, ref uint packetLength)
{ {
var packet = WinDivert.WinDivertHelperParsePacket(winDivertBuffer, packetLength); var result = packet.GetParseResult();
var requestPayload = new Span<byte>(packet.PacketPayload, (int)packet.PacketPayloadLength).ToArray(); var requestPayload = result.DataSpan.ToArray();
if (TryParseRequest(requestPayload, out var request) == false || if (TryParseRequest(requestPayload, out var request) == false ||
request.OperationCode != OperationCode.Query || request.OperationCode != OperationCode.Query ||
@ -148,38 +140,43 @@ namespace FastGithub.PacketIntercept.Dns
var responsePayload = response.ToArray(); var responsePayload = response.ToArray();
// 修改payload和包长 // 修改payload和包长
responsePayload.CopyTo(new Span<byte>(packet.PacketPayload, responsePayload.Length)); responsePayload.CopyTo(new Span<byte>(result.Data, responsePayload.Length));
packetLength = (uint)((int)packetLength + responsePayload.Length - requestPayload.Length); packet.Length = packet.Length + responsePayload.Length - requestPayload.Length;
// 修改ip包 // 修改ip包
IPAddress destAddress; IPAddress destAddress;
if (packet.IPv4Header != null) if (result.IPV4Header != null)
{ {
destAddress = packet.IPv4Header->DstAddr; destAddress = result.IPV4Header->DstAddr;
packet.IPv4Header->DstAddr = packet.IPv4Header->SrcAddr; result.IPV4Header->DstAddr = result.IPV4Header->SrcAddr;
packet.IPv4Header->SrcAddr = destAddress; result.IPV4Header->SrcAddr = destAddress;
packet.IPv4Header->Length = (ushort)packetLength; result.IPV4Header->Length = (ushort)packet.Length;
} }
else else
{ {
destAddress = packet.IPv6Header->DstAddr; destAddress = result.IPV6Header->DstAddr;
packet.IPv6Header->DstAddr = packet.IPv6Header->SrcAddr; result.IPV6Header->DstAddr = result.IPV6Header->SrcAddr;
packet.IPv6Header->SrcAddr = destAddress; result.IPV6Header->SrcAddr = destAddress;
packet.IPv6Header->Length = (ushort)(packetLength - sizeof(IPv6Header)); result.IPV6Header->Length = (ushort)(packet.Length - sizeof(IPV6Header));
} }
// 修改udp包 // 修改udp包
var destPort = packet.UdpHeader->DstPort; var destPort = result.UdpHeader->DstPort;
packet.UdpHeader->DstPort = packet.UdpHeader->SrcPort; result.UdpHeader->DstPort = result.UdpHeader->SrcPort;
packet.UdpHeader->SrcPort = destPort; result.UdpHeader->SrcPort = destPort;
packet.UdpHeader->Length = (ushort)(sizeof(UdpHeader) + responsePayload.Length); result.UdpHeader->Length = (ushort)(sizeof(UdpHeader) + responsePayload.Length);
winDivertAddress.Impostor = true; addr.Flags |= WinDivertAddressFlag.Impostor;
winDivertAddress.Direction = winDivertAddress.Loopback if (addr.Flags.HasFlag(WinDivertAddressFlag.Loopback))
? WinDivertDirection.Outbound {
: WinDivertDirection.Inbound; addr.Flags |= WinDivertAddressFlag.Outbound;
}
else
{
addr.Flags ^= WinDivertAddressFlag.Outbound;
}
WinDivert.WinDivertHelperCalcChecksums(winDivertBuffer, packetLength, ref winDivertAddress, WinDivertChecksumHelperParam.All); packet.CalcChecksums(ref addr);
this.logger.LogInformation($"{domain}->{loopback}"); this.logger.LogInformation($"{domain}->{loopback}");
} }

View File

@ -7,7 +7,7 @@
<ItemGroup> <ItemGroup>
<FrameworkReference Include="Microsoft.AspNetCore.App" /> <FrameworkReference Include="Microsoft.AspNetCore.App" />
<PackageReference Include="DNS" Version="7.0.0" /> <PackageReference Include="DNS" Version="7.0.0" />
<PackageReference Include="FastGithub.WinDiverts" Version="1.4.1" /> <PackageReference Include="WindivertDotnet" Version="1.0.0-beta1" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@ -1,5 +1,4 @@
using FastGithub.WinDiverts; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging;
using System; using System;
using System.ComponentModel; using System.ComponentModel;
using System.Net; using System.Net;
@ -7,6 +6,7 @@ using System.Net.Sockets;
using System.Runtime.Versioning; using System.Runtime.Versioning;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using WindivertDotnet;
namespace FastGithub.PacketIntercept.Tcp namespace FastGithub.PacketIntercept.Tcp
{ {
@ -16,7 +16,7 @@ namespace FastGithub.PacketIntercept.Tcp
[SupportedOSPlatform("windows")] [SupportedOSPlatform("windows")]
abstract class TcpInterceptor : ITcpInterceptor abstract class TcpInterceptor : ITcpInterceptor
{ {
private readonly string filter; private readonly Filter filter;
private readonly ushort oldServerPort; private readonly ushort oldServerPort;
private readonly ushort newServerPort; private readonly ushort newServerPort;
private readonly ILogger logger; private readonly ILogger logger;
@ -29,7 +29,10 @@ namespace FastGithub.PacketIntercept.Tcp
/// <param name="logger"></param> /// <param name="logger"></param>
public TcpInterceptor(int oldServerPort, int newServerPort, ILogger logger) public TcpInterceptor(int oldServerPort, int newServerPort, ILogger logger)
{ {
this.filter = $"loopback and (tcp.DstPort == {oldServerPort} or tcp.SrcPort == {newServerPort})"; this.filter = Filter.True
.And(f => f.Network.Loopback)
.And(f => f.Tcp.DstPort == oldServerPort || f.Tcp.SrcPort == newServerPort);
this.oldServerPort = (ushort)oldServerPort; this.oldServerPort = (ushort)oldServerPort;
this.newServerPort = (ushort)newServerPort; this.newServerPort = (ushort)newServerPort;
this.logger = logger; this.logger = logger;
@ -49,12 +52,7 @@ namespace FastGithub.PacketIntercept.Tcp
await Task.Yield(); await Task.Yield();
var handle = WinDivert.WinDivertOpen(this.filter, WinDivertLayer.Network, 0, WinDivertOpenFlags.None); using var divert = new WinDivert(this.filter, WinDivertLayer.Network, 0, WinDivertFlag.None);
if (handle == new IntPtr(unchecked((long)ulong.MaxValue)))
{
throw new Win32Exception();
}
if (Socket.OSSupportsIPv4) if (Socket.OSSupportsIPv4)
{ {
this.logger.LogInformation($"{IPAddress.Loopback}:{this.oldServerPort} <=> {IPAddress.Loopback}:{this.newServerPort}"); this.logger.LogInformation($"{IPAddress.Loopback}:{this.oldServerPort} <=> {IPAddress.Loopback}:{this.newServerPort}");
@ -63,23 +61,18 @@ namespace FastGithub.PacketIntercept.Tcp
{ {
this.logger.LogInformation($"{IPAddress.IPv6Loopback}:{this.oldServerPort} <=> {IPAddress.IPv6Loopback}:{this.newServerPort}"); this.logger.LogInformation($"{IPAddress.IPv6Loopback}:{this.oldServerPort} <=> {IPAddress.IPv6Loopback}:{this.newServerPort}");
} }
cancellationToken.Register(hwnd => WinDivert.WinDivertClose((IntPtr)hwnd!), handle); cancellationToken.Register(d => ((WinDivert)d!).Dispose(), divert);
var packetLength = 0U;
using var winDivertBuffer = new WinDivertBuffer();
var winDivertAddress = new WinDivertAddress();
var addr = new WinDivertAddress();
using var packet = new WinDivertPacket();
while (cancellationToken.IsCancellationRequested == false) while (cancellationToken.IsCancellationRequested == false)
{ {
winDivertAddress.Reset(); addr.Clear();
if (WinDivert.WinDivertRecv(handle, winDivertBuffer, ref winDivertAddress, ref packetLength) == false) divert.Recv(packet, ref addr);
{
throw new Win32Exception();
}
try try
{ {
this.ModifyTcpPacket(winDivertBuffer, ref winDivertAddress, ref packetLength); this.ModifyTcpPacket(packet, ref addr);
} }
catch (Exception ex) catch (Exception ex)
{ {
@ -87,7 +80,7 @@ namespace FastGithub.PacketIntercept.Tcp
} }
finally finally
{ {
WinDivert.WinDivertSend(handle, winDivertBuffer, packetLength, ref winDivertAddress); divert.Send(packet, ref addr);
} }
} }
} }
@ -95,31 +88,30 @@ namespace FastGithub.PacketIntercept.Tcp
/// <summary> /// <summary>
/// 修改tcp数据端口的端口 /// 修改tcp数据端口的端口
/// </summary> /// </summary>
/// <param name="winDivertBuffer"></param> /// <param name="packet"></param>
/// <param name="winDivertAddress"></param> /// <param name="addr"></param>
/// <param name="packetLength"></param> unsafe private void ModifyTcpPacket(WinDivertPacket packet, ref WinDivertAddress addr)
unsafe private void ModifyTcpPacket(WinDivertBuffer winDivertBuffer, ref WinDivertAddress winDivertAddress, ref uint packetLength)
{ {
var packet = WinDivert.WinDivertHelperParsePacket(winDivertBuffer, packetLength); var result = packet.GetParseResult();
if (packet.IPv4Header != null && packet.IPv4Header->SrcAddr.Equals(IPAddress.Loopback) == false) if (result.IPV4Header != null && result.IPV4Header->SrcAddr.Equals(IPAddress.Loopback) == false)
{ {
return; return;
} }
if (packet.IPv6Header != null && packet.IPv6Header->SrcAddr.Equals(IPAddress.IPv6Loopback) == false) if (result.IPV6Header != null && result.IPV6Header->SrcAddr.Equals(IPAddress.IPv6Loopback) == false)
{ {
return; return;
} }
if (packet.TcpHeader->DstPort == oldServerPort) if (result.TcpHeader->DstPort == oldServerPort)
{ {
packet.TcpHeader->DstPort = this.newServerPort; result.TcpHeader->DstPort = this.newServerPort;
} }
else else
{ {
packet.TcpHeader->SrcPort = oldServerPort; result.TcpHeader->SrcPort = oldServerPort;
} }
winDivertAddress.Impostor = true; addr.Flags |= WinDivertAddressFlag.Impostor;
WinDivert.WinDivertHelperCalcChecksums(winDivertBuffer, packetLength, ref winDivertAddress, WinDivertChecksumHelperParam.All); packet.CalcChecksums(ref addr);
} }
} }
} }