优化GetLocalAddress

This commit is contained in:
陈国伟 2021-07-15 17:33:01 +08:00
parent 2a298eaf48
commit cfeb5b2404
2 changed files with 15 additions and 45 deletions

View File

@ -77,7 +77,7 @@ namespace FastGithub.Dns
var result = await this.socket.ReceiveFromAsync(this.buffer, SocketFlags.None, remoteEndPoint); var result = await this.socket.ReceiveFromAsync(this.buffer, SocketFlags.None, remoteEndPoint);
var datas = new byte[result.ReceivedBytes]; var datas = new byte[result.ReceivedBytes];
this.buffer.AsSpan(0, datas.Length).CopyTo(datas); this.buffer.AsSpan(0, datas.Length).CopyTo(datas);
this.HandleRequestAsync(datas, (IPEndPoint)result.RemoteEndPoint, stoppingToken); this.HandleRequestAsync(datas, result.RemoteEndPoint, stoppingToken);
} }
} }
@ -87,12 +87,12 @@ namespace FastGithub.Dns
/// <param name="datas"></param> /// <param name="datas"></param>
/// <param name="remoteEndPoint"></param> /// <param name="remoteEndPoint"></param>
/// <param name="cancellationToken"></param> /// <param name="cancellationToken"></param>
private async void HandleRequestAsync(byte[] datas, IPEndPoint remoteEndPoint, CancellationToken cancellationToken) private async void HandleRequestAsync(byte[] datas, EndPoint remoteEndPoint, CancellationToken cancellationToken)
{ {
try try
{ {
var request = Request.FromArray(datas); var request = Request.FromArray(datas);
var remoteRequest = new RemoteRequest(request, remoteEndPoint.Address); var remoteRequest = new RemoteRequest(request, remoteEndPoint);
var response = await this.requestResolver.Resolve(remoteRequest, cancellationToken); var response = await this.requestResolver.Resolve(remoteRequest, cancellationToken);
await this.socket.SendToAsync(response.ToArray(), SocketFlags.None, remoteEndPoint); await this.socket.SendToAsync(response.ToArray(), SocketFlags.None, remoteEndPoint);
} }

View File

@ -1,7 +1,7 @@
using DNS.Protocol; using DNS.Protocol;
using System.Buffers.Binary; using System;
using System.Net; using System.Net;
using System.Net.NetworkInformation; using System.Net.Sockets;
namespace FastGithub.Dns namespace FastGithub.Dns
{ {
@ -13,17 +13,17 @@ namespace FastGithub.Dns
/// <summary> /// <summary>
/// 获取远程地址 /// 获取远程地址
/// </summary> /// </summary>
public IPAddress RemoteAddress { get; } public EndPoint RemoteEndPoint { get; }
/// <summary> /// <summary>
/// 远程请求 /// 远程请求
/// </summary> /// </summary>
/// <param name="request"></param> /// <param name="request"></param>
/// <param name="remoteAddress"></param> /// <param name="remoteEndPoint"></param>
public RemoteRequest(Request request, IPAddress remoteAddress) public RemoteRequest(Request request, EndPoint remoteEndPoint)
: base(request) : base(request)
{ {
this.RemoteAddress = remoteAddress; this.RemoteEndPoint = remoteEndPoint;
} }
/// <summary> /// <summary>
@ -32,45 +32,15 @@ namespace FastGithub.Dns
/// <returns></returns> /// <returns></returns>
public IPAddress? GetLocalAddress() public IPAddress? GetLocalAddress()
{ {
foreach (var @interface in NetworkInterface.GetAllNetworkInterfaces()) try
{ {
var addresses = @interface.GetIPProperties().UnicastAddresses; using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
foreach (var item in addresses) socket.Connect(this.RemoteEndPoint);
{ return socket.LocalEndPoint is IPEndPoint localEndPoint ? localEndPoint.Address : default;
if (IsInSubNet(item.IPv4Mask, item.Address, this.RemoteAddress))
{
return item.Address;
}
}
} }
return default; catch (Exception)
}
/// <summary>
/// 是否在相同的子网里
/// </summary>
/// <param name="mask"></param>
/// <param name="local"></param>
/// <param name="remote"></param>
/// <returns></returns>
private static bool IsInSubNet(IPAddress mask, IPAddress local, IPAddress remote)
{
if (local.AddressFamily != remote.AddressFamily)
{ {
return false; return default;
}
var maskValue = GetValue(mask);
var localValue = GetValue(local);
var remoteValue = GetValue(remote);
return (maskValue & localValue) == (maskValue & remoteValue);
static long GetValue(IPAddress address)
{
var bytes = address.GetAddressBytes();
return bytes.Length == sizeof(int)
? BinaryPrimitives.ReadInt32BigEndian(bytes)
: BinaryPrimitives.ReadInt64BigEndian(bytes);
} }
} }
} }