未验证 提交 548b70db 编写于 作者: A Anton Firszov 提交者: GitHub

Fix DNS cancellation deadlock (#63904)

Avoid taking a lock, and address the use-after-free race condition by guarding GetAddrInfoExContext with a SafeHandle.
上级 135e566b
......@@ -8,6 +8,7 @@
using System.Threading;
using System.Threading.Tasks;
using System.Diagnostics;
using Microsoft.Win32.SafeHandles;
namespace System.Net
{
......@@ -138,17 +139,14 @@ public static unsafe string GetHostName()
{
Interop.Winsock.EnsureInitialized();
GetAddrInfoExContext* context = GetAddrInfoExContext.AllocateContext();
GetAddrInfoExState state;
GetAddrInfoExState? state = null;
try
{
state = new GetAddrInfoExState(context, hostName, justAddresses);
context->QueryStateHandle = state.CreateHandle();
state = new GetAddrInfoExState(hostName, justAddresses);
}
catch
{
GetAddrInfoExContext.FreeContext(context);
state?.Dispose();
throw;
}
......@@ -158,6 +156,8 @@ public static unsafe string GetHostName()
hints.ai_flags = AddressInfoHints.AI_CANONNAME;
}
GetAddrInfoExContext* context = state.Context;
SocketError errorCode = (SocketError)Interop.Winsock.GetAddrInfoExW(
hostName, null, Interop.Winsock.NS_ALL, IntPtr.Zero, &hints, &context->Result, IntPtr.Zero, &context->Overlapped, &GetAddressInfoExCallback, &context->CancelHandle);
......@@ -172,7 +172,7 @@ public static unsafe string GetHostName()
// and final result would be posted via overlapped IO.
// synchronous failure here may signal issue when GetAddrInfoExW does not work from
// impersonated context. Windows 8 and Server 2012 fail for same reason with different errorCode.
GetAddrInfoExContext.FreeContext(context);
state.Dispose();
return null;
}
else
......@@ -194,10 +194,10 @@ private static unsafe void GetAddressInfoExCallback(int error, int bytes, Native
private static unsafe void ProcessResult(SocketError errorCode, GetAddrInfoExContext* context)
{
GetAddrInfoExState state = GetAddrInfoExState.FromHandleAndFree(context->QueryStateHandle);
try
{
GetAddrInfoExState state = GetAddrInfoExState.FromHandleAndFree(context->QueryStateHandle);
CancellationToken cancellationToken = state.UnregisterAndGetCancellationToken();
if (errorCode == SocketError.Success)
......@@ -222,7 +222,7 @@ private static unsafe void ProcessResult(SocketError errorCode, GetAddrInfoExCon
}
finally
{
GetAddrInfoExContext.FreeContext(context);
state.Dispose();
}
}
......@@ -360,18 +360,21 @@ private static unsafe IPAddress CreateIPv6Address(ReadOnlySpan<byte> socketAddre
return new IPAddress(address, scope);
}
private sealed unsafe class GetAddrInfoExState : IThreadPoolWorkItem
// GetAddrInfoExState is a SafeHandle that manages the lifetime of GetAddrInfoExContext*
// to make sure GetAddrInfoExCancel always takes a valid memory address regardless of the race
// between cancellation and completion callbacks.
private sealed unsafe class GetAddrInfoExState : SafeHandleZeroOrMinusOneIsInvalid, IThreadPoolWorkItem
{
private GetAddrInfoExContext* _cancellationContext;
private CancellationTokenRegistration _cancellationRegistration;
private AsyncTaskMethodBuilder<IPHostEntry> IPHostEntryBuilder;
private AsyncTaskMethodBuilder<IPAddress[]> IPAddressArrayBuilder;
private object? _result;
private volatile bool _completed;
public GetAddrInfoExState(GetAddrInfoExContext *context, string hostName, bool justAddresses)
public GetAddrInfoExState(string hostName, bool justAddresses)
: base(true)
{
_cancellationContext = context;
HostName = hostName;
JustAddresses = justAddresses;
if (justAddresses)
......@@ -384,6 +387,10 @@ public GetAddrInfoExState(GetAddrInfoExContext *context, string hostName, bool j
IPHostEntryBuilder = AsyncTaskMethodBuilder<IPHostEntry>.Create();
_ = IPHostEntryBuilder.Task; // force initialization
}
GetAddrInfoExContext* context = GetAddrInfoExContext.AllocateContext();
context->QueryStateHandle = CreateHandle();
SetHandle((IntPtr)context);
}
public string HostName { get; }
......@@ -392,52 +399,62 @@ public GetAddrInfoExState(GetAddrInfoExContext *context, string hostName, bool j
public Task Task => JustAddresses ? (Task)IPAddressArrayBuilder.Task : IPHostEntryBuilder.Task;
internal GetAddrInfoExContext* Context => (GetAddrInfoExContext*)handle;
public void RegisterForCancellation(CancellationToken cancellationToken)
{
if (!cancellationToken.CanBeCanceled) return;
lock (this)
if (_completed)
{
if (_cancellationContext == null)
// The operation completed before registration could be done.
return;
}
_cancellationRegistration = cancellationToken.UnsafeRegister(static o =>
{
var @this = (GetAddrInfoExState)o!;
if (@this._completed)
{
// The operation completed before registration could be done.
// Escape early and avoid ObjectDisposedException in DangerousAddRef
return;
}
_cancellationRegistration = cancellationToken.UnsafeRegister(o =>
bool needRelease = false;
try
{
var @this = (GetAddrInfoExState)o!;
int cancelResult = 0;
@this.DangerousAddRef(ref needRelease);
lock (@this)
{
GetAddrInfoExContext* context = @this._cancellationContext;
if (context != null)
{
// An outstanding operation will be completed with WSA_E_CANCELLED, and GetAddrInfoExCancel will return NO_ERROR.
// If this thread has lost the race between cancellation and completion, this will be a NOP
// with GetAddrInfoExCancel returning WSA_INVALID_HANDLE.
cancelResult = Interop.Winsock.GetAddrInfoExCancel(&context->CancelHandle);
}
}
// If DangerousAddRef didn't throw ODE, the handle should contain a valid pointer.
GetAddrInfoExContext* context = @this.Context;
if (cancelResult != 0 && cancelResult != Interop.Winsock.WSA_INVALID_HANDLE && NetEventSource.Log.IsEnabled())
// An outstanding operation will be completed with WSA_E_CANCELLED, and GetAddrInfoExCancel will return NO_ERROR.
// If this thread has lost the race between cancellation and completion, this will be a NOP
// with GetAddrInfoExCancel returning WSA_INVALID_HANDLE.
int cancelResult = Interop.Winsock.GetAddrInfoExCancel(&context->CancelHandle);
if (cancelResult != Interop.Winsock.WSA_INVALID_HANDLE && NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(@this, $"GetAddrInfoExCancel returned error {cancelResult}");
}
}, this);
}
}
finally
{
if (needRelease)
{
@this.DangerousRelease();
}
}
}, this);
}
public CancellationToken UnregisterAndGetCancellationToken()
{
lock (this)
{
_cancellationContext = null;
_cancellationRegistration.Unregister();
}
_completed = true;
// We should not wait for pending cancellation callbacks with CTR.Dispose(),
// since we are in a completion routine and GetAddrInfoExCancel may get blocked until it's finished.
_cancellationRegistration.Unregister();
return _cancellationRegistration.Token;
}
......@@ -479,8 +496,6 @@ void IThreadPoolWorkItem.Execute()
}
}
public IntPtr CreateHandle() => GCHandle.ToIntPtr(GCHandle.Alloc(this, GCHandleType.Normal));
public static GetAddrInfoExState FromHandleAndFree(IntPtr handle)
{
GCHandle gcHandle = GCHandle.FromIntPtr(handle);
......@@ -488,6 +503,15 @@ public static GetAddrInfoExState FromHandleAndFree(IntPtr handle)
gcHandle.Free();
return state;
}
protected override bool ReleaseHandle()
{
GetAddrInfoExContext.FreeContext(Context);
return true;
}
private IntPtr CreateHandle() => GCHandle.ToIntPtr(GCHandle.Alloc(this, GCHandleType.Normal));
}
[StructLayout(LayoutKind.Sequential)]
......@@ -498,12 +522,7 @@ private unsafe struct GetAddrInfoExContext
public IntPtr CancelHandle;
public IntPtr QueryStateHandle;
public static GetAddrInfoExContext* AllocateContext()
{
var context = (GetAddrInfoExContext*)Marshal.AllocHGlobal(sizeof(GetAddrInfoExContext));
*context = default;
return context;
}
public static GetAddrInfoExContext* AllocateContext() => (GetAddrInfoExContext*)NativeMemory.AllocZeroed((nuint)sizeof(GetAddrInfoExContext));
public static void FreeContext(GetAddrInfoExContext* context)
{
......@@ -511,8 +530,7 @@ public static void FreeContext(GetAddrInfoExContext* context)
{
Interop.Winsock.FreeAddrInfoExW(context->Result);
}
Marshal.FreeHGlobal((IntPtr)context);
NativeMemory.Free(context);
}
}
}
......
......@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
using System.Linq;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
......@@ -170,10 +171,13 @@ public async Task DnsGetHostAddresses_PreCancelledToken_Throws()
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(() => Dns.GetHostAddressesAsync(TestSettings.LocalHost, cts.Token));
Assert.Equal(cts.Token, oce.CancellationToken);
}
}
[OuterLoop]
// Cancellation tests are sequential to reduce the chance of timing issues.
[Collection(nameof(DisableParallelization))]
public class GetHostAddressesTest_Cancellation
{
[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/43816")] // Race condition outlined below.
[ActiveIssue("https://github.com/dotnet/runtime/issues/33378", TestPlatforms.AnyUnix)] // Cancellation of an outstanding getaddrinfo is not supported on *nix.
public async Task DnsGetHostAddresses_PostCancelledToken_Throws()
{
......@@ -188,5 +192,35 @@ public async Task DnsGetHostAddresses_PostCancelledToken_Throws()
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(() => task);
Assert.Equal(cts.Token, oce.CancellationToken);
}
// This is a regression test for https://github.com/dotnet/runtime/issues/63552
[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/33378", TestPlatforms.AnyUnix)] // Cancellation of an outstanding getaddrinfo is not supported on *nix.
public async Task DnsGetHostAddresses_ResolveParallelCancelOnFailure_AllCallsReturn()
{
string invalidAddress = TestSettings.UncachedHost;
await ResolveManyAsync(invalidAddress);
await ResolveManyAsync(invalidAddress, TestSettings.LocalHost)
.WaitAsync(TestSettings.PassingTestTimeout);
static async Task ResolveManyAsync(params string[] addresses)
{
using CancellationTokenSource cts = new();
Task[] resolveTasks = addresses.Select(a => ResolveOneAsync(a, cts)).ToArray();
await Task.WhenAll(resolveTasks);
}
static async Task ResolveOneAsync(string address, CancellationTokenSource cancellationTokenSource)
{
try
{
await Dns.GetHostAddressesAsync(address, cancellationTokenSource.Token);
}
catch (Exception)
{
cancellationTokenSource.Cancel();
}
}
}
}
}
......@@ -309,13 +309,21 @@ public async Task DnsGetHostEntry_PreCancelledToken_Throws()
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(() => Dns.GetHostEntryAsync(TestSettings.LocalHost, cts.Token));
Assert.Equal(cts.Token, oce.CancellationToken);
}
}
// Cancellation tests are sequential to reduce the chance of timing issues.
[Collection(nameof(DisableParallelization))]
public class GetHostEntryTest_Cancellation
{
[OuterLoop]
[ActiveIssue("https://github.com/dotnet/runtime/issues/43816")] // Race condition outlined below.
[ActiveIssue("https://github.com/dotnet/runtime/issues/33378", TestPlatforms.AnyUnix)] // Cancellation of an outstanding getaddrinfo is not supported on *nix.
[Fact]
public async Task DnsGetHostEntry_PostCancelledToken_Throws()
{
// Windows 7 name resolution is synchronous and does not respect cancellation.
if (PlatformDetection.IsWindows7)
return;
using var cts = new CancellationTokenSource();
Task task = Dns.GetHostEntryAsync(TestSettings.UncachedHost, cts.Token);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册