未验证 提交 c5b6881e 编写于 作者: J Jan Kotas 提交者: GitHub

Simplify DynamicWinsockMethods (#43190)

Co-authored-by: NStephen Toub <stoub@microsoft.com>
上级 bae8b42f
......@@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
using System.Threading;
......@@ -38,86 +39,24 @@ public static DynamicWinsockMethods GetMethods(AddressFamily addressFamily, Sock
private readonly AddressFamily _addressFamily;
private readonly SocketType _socketType;
private readonly ProtocolType _protocolType;
private readonly object _lockObject;
private AcceptExDelegate? _acceptEx;
private GetAcceptExSockaddrsDelegate? _getAcceptExSockaddrs;
private ConnectExDelegate? _connectEx;
private TransmitPacketsDelegate? _transmitPackets;
private DisconnectExDelegate? _disconnectEx;
private DisconnectExDelegateBlocking? _disconnectExBlocking;
private WSARecvMsgDelegate? _recvMsg;
private WSARecvMsgDelegateBlocking? _recvMsgBlocking;
private DynamicWinsockMethods(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
_addressFamily = addressFamily;
_socketType = socketType;
_protocolType = protocolType;
_lockObject = new object();
}
public T GetDelegate<T>(SafeSocketHandle socketHandle)
where T : class
{
if (typeof(T) == typeof(AcceptExDelegate))
{
EnsureAcceptEx(socketHandle);
Debug.Assert(_acceptEx != null);
return (T)(object)_acceptEx;
}
else if (typeof(T) == typeof(GetAcceptExSockaddrsDelegate))
{
EnsureGetAcceptExSockaddrs(socketHandle);
Debug.Assert(_getAcceptExSockaddrs != null);
return (T)(object)_getAcceptExSockaddrs;
}
else if (typeof(T) == typeof(ConnectExDelegate))
{
EnsureConnectEx(socketHandle);
Debug.Assert(_connectEx != null);
return (T)(object)_connectEx;
}
else if (typeof(T) == typeof(DisconnectExDelegate))
{
EnsureDisconnectEx(socketHandle);
Debug.Assert(_disconnectEx != null);
return (T)(object)_disconnectEx;
}
else if (typeof(T) == typeof(DisconnectExDelegateBlocking))
{
EnsureDisconnectEx(socketHandle);
Debug.Assert(_disconnectExBlocking != null);
return (T)(object)_disconnectExBlocking;
}
else if (typeof(T) == typeof(WSARecvMsgDelegate))
{
EnsureWSARecvMsg(socketHandle);
Debug.Assert(_recvMsg != null);
return (T)(object)_recvMsg;
}
else if (typeof(T) == typeof(WSARecvMsgDelegateBlocking))
{
EnsureWSARecvMsgBlocking(socketHandle);
Debug.Assert(_recvMsgBlocking != null);
return (T)(object)_recvMsgBlocking;
}
else if (typeof(T) == typeof(TransmitPacketsDelegate))
{
EnsureTransmitPackets(socketHandle);
Debug.Assert(_transmitPackets != null);
return (T)(object)_transmitPackets;
}
Debug.Fail("Invalid type passed to DynamicWinsockMethods.GetDelegate");
return null;
}
// Private methods that actually load the function pointers.
private IntPtr LoadDynamicFunctionPointer(SafeSocketHandle socketHandle, ref Guid guid)
private static T CreateDelegate<T>([NotNull] ref T? cache, SafeSocketHandle socketHandle, string guidString) where T: Delegate
{
Guid guid = new Guid(guidString);
IntPtr ptr = IntPtr.Zero;
int length;
SocketError errorCode;
......@@ -141,125 +80,27 @@ private IntPtr LoadDynamicFunctionPointer(SafeSocketHandle socketHandle, ref Gui
throw new SocketException();
}
return ptr;
Interlocked.CompareExchange(ref cache, Marshal.GetDelegateForFunctionPointer<T>(ptr), null);
return cache;
}
// NOTE: the volatile writes in the functions below are necessary to ensure that all writes
// to the fields of the delegate instances are visible before the write to the field
// that holds the reference to the delegate instance.
internal AcceptExDelegate GetAcceptExDelegate(SafeSocketHandle socketHandle)
=> _acceptEx ?? CreateDelegate(ref _acceptEx, socketHandle, "b5367df1cbac11cf95ca00805f48a192");
private void EnsureAcceptEx(SafeSocketHandle socketHandle)
{
if (_acceptEx == null)
{
lock (_lockObject)
{
if (_acceptEx == null)
{
Guid guid = new Guid("{0xb5367df1,0xcbac,0x11cf,{0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92}}");
IntPtr ptrAcceptEx = LoadDynamicFunctionPointer(socketHandle, ref guid);
Volatile.Write(ref _acceptEx, Marshal.GetDelegateForFunctionPointer<AcceptExDelegate>(ptrAcceptEx));
}
}
}
}
private void EnsureGetAcceptExSockaddrs(SafeSocketHandle socketHandle)
{
if (_getAcceptExSockaddrs == null)
{
lock (_lockObject)
{
if (_getAcceptExSockaddrs == null)
{
Guid guid = new Guid("{0xb5367df2,0xcbac,0x11cf,{0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92}}");
IntPtr ptrGetAcceptExSockaddrs = LoadDynamicFunctionPointer(socketHandle, ref guid);
Volatile.Write(ref _getAcceptExSockaddrs, Marshal.GetDelegateForFunctionPointer<GetAcceptExSockaddrsDelegate>(ptrGetAcceptExSockaddrs));
}
}
}
}
internal GetAcceptExSockaddrsDelegate GetGetAcceptExSockaddrsDelegate(SafeSocketHandle socketHandle)
=> _getAcceptExSockaddrs ?? CreateDelegate(ref _getAcceptExSockaddrs, socketHandle, "b5367df2cbac11cf95ca00805f48a192");
private void EnsureConnectEx(SafeSocketHandle socketHandle)
{
if (_connectEx == null)
{
lock (_lockObject)
{
if (_connectEx == null)
{
Guid guid = new Guid("{0x25a207b9,0x0ddf3,0x4660,{0x8e,0xe9,0x76,0xe5,0x8c,0x74,0x06,0x3e}}");
IntPtr ptrConnectEx = LoadDynamicFunctionPointer(socketHandle, ref guid);
Volatile.Write(ref _connectEx, Marshal.GetDelegateForFunctionPointer<ConnectExDelegate>(ptrConnectEx));
}
}
}
}
internal ConnectExDelegate GetConnectExDelegate(SafeSocketHandle socketHandle)
=> _connectEx ?? CreateDelegate(ref _connectEx, socketHandle, "25a207b9ddf346608ee976e58c74063e");
private void EnsureDisconnectEx(SafeSocketHandle socketHandle)
{
if (_disconnectEx == null)
{
lock (_lockObject)
{
if (_disconnectEx == null)
{
Guid guid = new Guid("{0x7fda2e11,0x8630,0x436f,{0xa0, 0x31, 0xf5, 0x36, 0xa6, 0xee, 0xc1, 0x57}}");
IntPtr ptrDisconnectEx = LoadDynamicFunctionPointer(socketHandle, ref guid);
_disconnectExBlocking = Marshal.GetDelegateForFunctionPointer<DisconnectExDelegateBlocking>(ptrDisconnectEx);
Volatile.Write(ref _disconnectEx, Marshal.GetDelegateForFunctionPointer<DisconnectExDelegate>(ptrDisconnectEx));
}
}
}
}
private void EnsureWSARecvMsg(SafeSocketHandle socketHandle)
{
if (_recvMsg == null)
{
lock (_lockObject)
{
if (_recvMsg == null)
{
Guid guid = new Guid("{0xf689d7c8,0x6f1f,0x436b,{0x8a,0x53,0xe5,0x4f,0xe3,0x51,0xc3,0x22}}");
IntPtr ptrWSARecvMsg = LoadDynamicFunctionPointer(socketHandle, ref guid);
_recvMsgBlocking = Marshal.GetDelegateForFunctionPointer<WSARecvMsgDelegateBlocking>(ptrWSARecvMsg);
Volatile.Write(ref _recvMsg, Marshal.GetDelegateForFunctionPointer<WSARecvMsgDelegate>(ptrWSARecvMsg));
}
}
}
}
internal DisconnectExDelegate GetDisconnectExDelegate(SafeSocketHandle socketHandle)
=> _disconnectEx ?? CreateDelegate(ref _disconnectEx, socketHandle, "7fda2e118630436fa031f536a6eec157");
private void EnsureWSARecvMsgBlocking(SafeSocketHandle socketHandle)
{
if (_recvMsgBlocking == null)
{
lock (_lockObject)
{
if (_recvMsgBlocking == null)
{
Guid guid = new Guid("{0xf689d7c8,0x6f1f,0x436b,{0x8a,0x53,0xe5,0x4f,0xe3,0x51,0xc3,0x22}}");
IntPtr ptrWSARecvMsg = LoadDynamicFunctionPointer(socketHandle, ref guid);
Volatile.Write(ref _recvMsgBlocking, Marshal.GetDelegateForFunctionPointer<WSARecvMsgDelegateBlocking>(ptrWSARecvMsg));
}
}
}
}
internal WSARecvMsgDelegate GetWSARecvMsgDelegate(SafeSocketHandle socketHandle)
=> _recvMsg ?? CreateDelegate(ref _recvMsg, socketHandle, "f689d7c86f1f436b8a53e54fe351c322");
private void EnsureTransmitPackets(SafeSocketHandle socketHandle)
{
if (_transmitPackets == null)
{
lock (_lockObject)
{
if (_transmitPackets == null)
{
Guid guid = new Guid("{0xd9689da0,0x1f90,0x11d3,{0x99,0x71,0x00,0xc0,0x4f,0x68,0xc8,0x76}}");
IntPtr ptrTransmitPackets = LoadDynamicFunctionPointer(socketHandle, ref guid);
Volatile.Write(ref _transmitPackets, Marshal.GetDelegateForFunctionPointer<TransmitPacketsDelegate>(ptrTransmitPackets));
}
}
}
}
internal TransmitPacketsDelegate GetTransmitPacketsDelegate(SafeSocketHandle socketHandle)
=> _transmitPackets ?? CreateDelegate(ref _transmitPackets, socketHandle, "d9689da01f9011d3997100c04f68c876");
}
[UnmanagedFunctionPointer(CallingConvention.StdCall, SetLastError = true)]
......@@ -302,13 +143,6 @@ private void EnsureTransmitPackets(SafeSocketHandle socketHandle)
int flags,
int reserved);
[UnmanagedFunctionPointer(CallingConvention.StdCall, SetLastError = true)]
internal delegate bool DisconnectExDelegateBlocking(
SafeSocketHandle socketHandle,
IntPtr overlapped,
int flags,
int reserved);
[UnmanagedFunctionPointer(CallingConvention.StdCall, SetLastError = true)]
internal unsafe delegate SocketError WSARecvMsgDelegate(
SafeSocketHandle socketHandle,
......@@ -317,14 +151,6 @@ private void EnsureTransmitPackets(SafeSocketHandle socketHandle)
NativeOverlapped* overlapped,
IntPtr completionRoutine);
[UnmanagedFunctionPointer(CallingConvention.StdCall, SetLastError = true)]
internal delegate SocketError WSARecvMsgDelegateBlocking(
SafeSocketHandle socketHandle,
IntPtr msg,
out int bytesTransferred,
IntPtr overlapped,
IntPtr completionRoutine);
[UnmanagedFunctionPointer(CallingConvention.StdCall, SetLastError = true)]
internal unsafe delegate bool TransmitPacketsDelegate(
SafeSocketHandle socketHandle,
......
......@@ -151,12 +151,9 @@ public Socket EndAccept(out byte[] buffer, out int bytesTransferred, IAsyncResul
return EndAcceptCommon(out buffer!, out bytesTransferred, asyncResult);
}
private void EnsureDynamicWinsockMethods()
private DynamicWinsockMethods GetDynamicWinsockMethods()
{
if (_dynamicWinsockMethods == null)
{
_dynamicWinsockMethods = DynamicWinsockMethods.GetMethods(_addressFamily, _socketType, _protocolType);
}
return _dynamicWinsockMethods ??= DynamicWinsockMethods.GetMethods(_addressFamily, _socketType, _protocolType);
}
internal unsafe bool AcceptEx(SafeSocketHandle listenSocketHandle,
......@@ -168,8 +165,7 @@ private void EnsureDynamicWinsockMethods()
out int bytesReceived,
NativeOverlapped* overlapped)
{
EnsureDynamicWinsockMethods();
AcceptExDelegate acceptEx = _dynamicWinsockMethods!.GetDelegate<AcceptExDelegate>(listenSocketHandle);
AcceptExDelegate acceptEx = GetDynamicWinsockMethods().GetAcceptExDelegate(listenSocketHandle);
return acceptEx(listenSocketHandle,
acceptSocketHandle,
......@@ -190,8 +186,7 @@ private void EnsureDynamicWinsockMethods()
out IntPtr remoteSocketAddress,
out int remoteSocketAddressLength)
{
EnsureDynamicWinsockMethods();
GetAcceptExSockaddrsDelegate getAcceptExSockaddrs = _dynamicWinsockMethods!.GetDelegate<GetAcceptExSockaddrsDelegate>(_handle);
GetAcceptExSockaddrsDelegate getAcceptExSockaddrs = GetDynamicWinsockMethods().GetGetAcceptExSockaddrsDelegate(_handle);
getAcceptExSockaddrs(buffer,
receiveDataLength,
......@@ -205,18 +200,16 @@ private void EnsureDynamicWinsockMethods()
internal unsafe bool DisconnectEx(SafeSocketHandle socketHandle, NativeOverlapped* overlapped, int flags, int reserved)
{
EnsureDynamicWinsockMethods();
DisconnectExDelegate disconnectEx = _dynamicWinsockMethods!.GetDelegate<DisconnectExDelegate>(socketHandle);
DisconnectExDelegate disconnectEx = GetDynamicWinsockMethods().GetDisconnectExDelegate(socketHandle);
return disconnectEx(socketHandle, overlapped, flags, reserved);
}
internal bool DisconnectExBlocking(SafeSocketHandle socketHandle, IntPtr overlapped, int flags, int reserved)
internal unsafe bool DisconnectExBlocking(SafeSocketHandle socketHandle, int flags, int reserved)
{
EnsureDynamicWinsockMethods();
DisconnectExDelegateBlocking disconnectEx_Blocking = _dynamicWinsockMethods!.GetDelegate<DisconnectExDelegateBlocking>(socketHandle);
DisconnectExDelegate disconnectEx = GetDynamicWinsockMethods().GetDisconnectExDelegate(socketHandle);
return disconnectEx_Blocking(socketHandle, overlapped, flags, reserved);
return disconnectEx(socketHandle, null, flags, reserved);
}
partial void WildcardBindForConnectIfNecessary(AddressFamily addressFamily)
......@@ -257,32 +250,28 @@ internal bool DisconnectExBlocking(SafeSocketHandle socketHandle, IntPtr overlap
out int bytesSent,
NativeOverlapped* overlapped)
{
EnsureDynamicWinsockMethods();
ConnectExDelegate connectEx = _dynamicWinsockMethods!.GetDelegate<ConnectExDelegate>(socketHandle);
ConnectExDelegate connectEx = GetDynamicWinsockMethods().GetConnectExDelegate(socketHandle);
return connectEx(socketHandle, socketAddress, socketAddressSize, buffer, dataLength, out bytesSent, overlapped);
}
internal unsafe SocketError WSARecvMsg(SafeSocketHandle socketHandle, IntPtr msg, out int bytesTransferred, NativeOverlapped* overlapped, IntPtr completionRoutine)
{
EnsureDynamicWinsockMethods();
WSARecvMsgDelegate recvMsg = _dynamicWinsockMethods!.GetDelegate<WSARecvMsgDelegate>(socketHandle);
WSARecvMsgDelegate recvMsg = GetDynamicWinsockMethods().GetWSARecvMsgDelegate(socketHandle);
return recvMsg(socketHandle, msg, out bytesTransferred, overlapped, completionRoutine);
}
internal SocketError WSARecvMsgBlocking(SafeSocketHandle socketHandle, IntPtr msg, out int bytesTransferred, IntPtr overlapped, IntPtr completionRoutine)
internal unsafe SocketError WSARecvMsgBlocking(SafeSocketHandle socketHandle, IntPtr msg, out int bytesTransferred)
{
EnsureDynamicWinsockMethods();
WSARecvMsgDelegateBlocking recvMsg_Blocking = _dynamicWinsockMethods!.GetDelegate<WSARecvMsgDelegateBlocking>(_handle);
WSARecvMsgDelegate recvMsg = GetDynamicWinsockMethods().GetWSARecvMsgDelegate(_handle);
return recvMsg_Blocking(socketHandle, msg, out bytesTransferred, overlapped, completionRoutine);
return recvMsg(socketHandle, msg, out bytesTransferred, null, IntPtr.Zero);
}
internal unsafe bool TransmitPackets(SafeSocketHandle socketHandle, IntPtr packetArray, int elementCount, int sendSize, NativeOverlapped* overlapped, TransmitFileOptions flags)
{
EnsureDynamicWinsockMethods();
TransmitPacketsDelegate transmitPackets = _dynamicWinsockMethods!.GetDelegate<TransmitPacketsDelegate>(socketHandle);
TransmitPacketsDelegate transmitPackets = GetDynamicWinsockMethods().GetTransmitPacketsDelegate(socketHandle);
return transmitPackets(socketHandle, packetArray, elementCount, sendSize, overlapped, flags);
}
......
......@@ -480,9 +480,7 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan
if (socket.WSARecvMsgBlocking(
handle,
(IntPtr)(&wsaMsg),
out bytesTransferred,
IntPtr.Zero,
IntPtr.Zero) == SocketError.SocketError)
out bytesTransferred) == SocketError.SocketError)
{
return GetLastSocketError();
}
......@@ -498,9 +496,7 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan
if (socket.WSARecvMsgBlocking(
handle,
(IntPtr)(&wsaMsg),
out bytesTransferred,
IntPtr.Zero,
IntPtr.Zero) == SocketError.SocketError)
out bytesTransferred) == SocketError.SocketError)
{
return GetLastSocketError();
}
......@@ -515,9 +511,7 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan
if (socket.WSARecvMsgBlocking(
handle,
(IntPtr)(&wsaMsg),
out bytesTransferred,
IntPtr.Zero,
IntPtr.Zero) == SocketError.SocketError)
out bytesTransferred) == SocketError.SocketError)
{
return GetLastSocketError();
}
......@@ -1372,7 +1366,7 @@ internal static SocketError Disconnect(Socket socket, SafeSocketHandle handle, b
SocketError errorCode = SocketError.Success;
// This can throw ObjectDisposedException (handle, and retrieving the delegate).
if (!socket.DisconnectExBlocking(handle, IntPtr.Zero, (int)(reuseSocket ? TransmitFileOptions.ReuseSocket : 0), 0))
if (!socket.DisconnectExBlocking(handle, (int)(reuseSocket ? TransmitFileOptions.ReuseSocket : 0), 0))
{
errorCode = GetLastSocketError();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册