提交 21151a0e 编写于 作者: J Jason Malinowski

Merge pull request #706 from jasonmalinowski/fix-asynclazy-cancellationtoken

Fix AsyncLazy CancellationToken handling
......@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.ErrorReporting;
......@@ -80,13 +81,6 @@ public AsyncLazy(T value)
_cachedResult = Task.FromResult(value);
}
/// <summary>
/// Important: callers of this constructor should ensure that the compute function returns
/// a task in a non-blocking fashion. i.e. the function should *not* synchronously compute
/// a value and then return it using Task.FromResult. Instead, it should return an actual
/// task that operates asynchronously. If this function synchronously computes a value
/// then that will cause locks to be held in this type for excessive periods of time.
/// </summary>
public AsyncLazy(Func<CancellationToken, Task<T>> asynchronousComputeFunction, bool cacheResult)
: this(asynchronousComputeFunction, synchronousComputeFunction: null, cacheResult: cacheResult)
{
......@@ -226,7 +220,7 @@ public override T GetValue(CancellationToken cancellationToken)
// cancelled this new computation if we were the only requestor.
if (newAsynchronousComputation != null)
{
StartAsynchronousComputation(newAsynchronousComputation.Value, requestToCompleteSynchronously: request);
StartAsynchronousComputation(newAsynchronousComputation.Value, requestToCompleteSynchronously: request, requestToCompleteSynchronouslyCancellationToken: cancellationToken);
}
return request.Task.WaitAndGetResult(cancellationToken);
......@@ -260,7 +254,7 @@ public override T GetValue(CancellationToken cancellationToken)
if (newAsynchronousComputation != null)
{
StartAsynchronousComputation(newAsynchronousComputation.Value, requestToCompleteSynchronously: null);
StartAsynchronousComputation(newAsynchronousComputation.Value, requestToCompleteSynchronously: null, requestToCompleteSynchronouslyCancellationToken: cancellationToken);
}
throw;
......@@ -330,7 +324,7 @@ public override Task<T> GetValueAsync(CancellationToken cancellationToken)
if (newAsynchronousComputation != null)
{
StartAsynchronousComputation(newAsynchronousComputation.Value, requestToCompleteSynchronously: request);
StartAsynchronousComputation(newAsynchronousComputation.Value, requestToCompleteSynchronously: request, requestToCompleteSynchronouslyCancellationToken: cancellationToken);
}
return request.Task;
......@@ -358,7 +352,7 @@ public AsynchronousComputationToStart(Func<CancellationToken, Task<T>> asynchron
}
}
private void StartAsynchronousComputation(AsynchronousComputationToStart computationToStart, Request requestToCompleteSynchronously)
private void StartAsynchronousComputation(AsynchronousComputationToStart computationToStart, Request requestToCompleteSynchronously, CancellationToken requestToCompleteSynchronouslyCancellationToken)
{
var cancellationToken = computationToStart.CancellationTokenSource.Token;
......@@ -396,19 +390,24 @@ private void StartAsynchronousComputation(AsynchronousComputationToStart computa
requestToCompleteSynchronously.CompleteFromTaskSynchronously(task);
}
}
catch (Exception e) when(FatalError.ReportUnlessCanceled(e))
{
throw ExceptionUtilities.Unreachable;
}
}
catch (OperationCanceledException oce) when(CrashIfCanceledWithDifferentToken(oce, cancellationToken))
catch (Exception e) when (FatalError.ReportUnlessCanceled(e))
{
// As long as it's the right token, this means that our thread was the first thread
// to start an asynchronous computation, but the requestor cancelled as we were starting up
// the computation.
throw ExceptionUtilities.Unreachable;
}
}
}
catch (OperationCanceledException oce) when (CrashIfCanceledWithDifferentToken(oce, cancellationToken))
{
// The underlying computation cancelled with the correct token, but we must ourselves ensure that the caller
// on our stack gets an OperationCanceledException thrown with the right token
requestToCompleteSynchronouslyCancellationToken.ThrowIfCancellationRequested();
// We can only be here if the computation was cancelled, which means all requests for the value
// must have been cancelled. Therefore, the ThrowIfCancellationRequested above must have thrown
// because that token from the requestor was cancelled.
throw ExceptionUtilities.Unreachable;
}
}
private static bool CrashIfCanceledWithDifferentToken(OperationCanceledException exception, CancellationToken cancellationToken)
{
if (exception.CancellationToken != cancellationToken)
......@@ -416,7 +415,7 @@ private static bool CrashIfCanceledWithDifferentToken(OperationCanceledException
FatalError.Report(exception);
}
return false;
return true;
}
private void CompleteWithTask(Task<T> task, CancellationToken cancellationToken)
......@@ -508,13 +507,38 @@ private void OnAsynchronousRequestCancelled(object state)
}
}
// Using inheritance instead of wrapping a TaskCompletionSource to avoid a second allocation
private class Request : TaskCompletionSource<T>
private sealed class Request
{
/// <summary>
/// The <see cref="CancellationToken"/> associated with this request. This field will be initialized before
/// any cancellation is observed from the token.
/// </summary>
private CancellationToken _cancellationToken;
private CancellationTokenRegistration _cancellationTokenRegistration;
// We use a AsyncTaskMethodBuilder so we have the ability to cancel the task with a given cancellation token
// TODO: remove this once we're on .NET 4.6 and can move back to using TaskCompletionSource.
// WARNING: this is a mutable struct, and thus cannot be made readonly
private AsyncTaskMethodBuilder<T> _taskBuilder;
public Request()
{
// .Task on AsyncTaskMethodBuilder is lazily created in a non-synchronized way, so we must request it
// once before we start doing fancy stuff
var ignored = _taskBuilder.Task;
}
public Task<T> Task
{
get
{
return _taskBuilder.Task;
}
}
public void RegisterForCancellation(Action<object> callback, CancellationToken cancellationToken)
{
_cancellationToken = cancellationToken;
_cancellationTokenRegistration = cancellationToken.Register(callback, this);
}
......@@ -530,24 +554,36 @@ private void CompleteFromTaskSynchronouslyStub(object task)
public void CompleteFromTaskSynchronously(Task<T> task)
{
if (task.Status == TaskStatus.RanToCompletion)
// AsyncTaskMethodBuilder doesn't give us Try* methods, and the Set methods may throw if the task
// is already completed. The belief is that the race is somewhere between rare to impossible, and
// so we'll do a quick check to see if the task is already completed or otherwise just give it a shot
// and catch it if it fails
if (_taskBuilder.Task.IsCompleted)
{
if (TrySetResult(task.Result))
{
_cancellationTokenRegistration.Dispose();
}
return;
}
else if (task.Status == TaskStatus.Faulted)
try
{
if (TrySetException(task.Exception))
if (task.Status == TaskStatus.RanToCompletion)
{
_cancellationTokenRegistration.Dispose();
_taskBuilder.SetResult(task.Result);
}
else if (task.Status == TaskStatus.Faulted)
{
_taskBuilder.SetException(task.Exception);
}
else
{
CancelSynchronously();
}
}
else
catch (InvalidOperationException)
{
CancelSynchronously();
// Something else beat us to setting the state, so bail
}
_cancellationTokenRegistration.Dispose();
}
public void CancelAsynchronously()
......@@ -559,11 +595,22 @@ public void CancelAsynchronously()
private void CancelSynchronously()
{
if (TrySetCanceled())
// AsyncTaskMethodBuilder doesn't give us Try* methods, and the Set methods may throw if the task
// is already completed. The belief is that the race is somewhere between rare to impossible, and
// so we'll do a quick check to see if the task is already completed or otherwise just give it a shot
// and catch it if it fails
if (_taskBuilder.Task.IsCompleted)
{
return;
}
try
{
_taskBuilder.SetException(new OperationCanceledException(_cancellationToken));
}
catch (InvalidOperationException)
{
// Paranoia: the only reason we should ever get here is if the CancellationToken that
// we registered against was cancelled, but just in case, dispose the registration
_cancellationTokenRegistration.Dispose();
// Something else beat us to setting the state, so bail
}
}
}
......
......@@ -4,6 +4,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.Test.Utilities;
using Roslyn.Test.Utilities;
using Roslyn.Utilities;
using Xunit;
......@@ -216,5 +217,114 @@ public void SynchronousRequestShouldCacheValueWithAsynchronousComputeFunction()
Assert.Same(secondRequestResult, firstRequestResult);
}
[Fact]
[Trait(Traits.Feature, Traits.Features.AsyncLazy)]
public void GetValueThrowsCorrectExceptionDuringCancellation()
{
GetValueOrGetValueAsyncThrowsCorrectExceptionDuringCancellation((lazy, ct) => lazy.GetValue(ct), includeSynchronousComputation: false);
}
[Fact]
[Trait(Traits.Feature, Traits.Features.AsyncLazy)]
public void GetValueThrowsCorrectExceptionDuringCancellationWithSynchronousComputation()
{
GetValueOrGetValueAsyncThrowsCorrectExceptionDuringCancellation((lazy, ct) => lazy.GetValue(ct), includeSynchronousComputation: true);
}
[Fact]
[Trait(Traits.Feature, Traits.Features.AsyncLazy)]
public void GetValueAsyncThrowsCorrectExceptionDuringCancellation()
{
// NOTE: since GetValueAsync inlines the call to the async computation, the GetValueAsync call will throw
// immediately instead of returning a task that transitions to the cancelled state
GetValueOrGetValueAsyncThrowsCorrectExceptionDuringCancellation((lazy, ct) => lazy.GetValueAsync(ct), includeSynchronousComputation: false);
}
[Fact]
[Trait(Traits.Feature, Traits.Features.AsyncLazy)]
public void GetValueAsyncThrowsCorrectExceptionDuringCancellationWithSynchronousComputation()
{
// In theory the synchronous computation isn't used during GetValueAsync, but just in case...
GetValueOrGetValueAsyncThrowsCorrectExceptionDuringCancellation((lazy, ct) => lazy.GetValueAsync(ct), includeSynchronousComputation: true);
}
private static void GetValueOrGetValueAsyncThrowsCorrectExceptionDuringCancellation(Action<AsyncLazy<object>, CancellationToken> doGetValue, bool includeSynchronousComputation)
{
// A call to GetValue/GetValueAsync with a token that is cancelled should throw an OperationCancelledException, but it's
// important to make sure the correct token is cancelled. It should be cancelled with the token passed
// to GetValue, not the cancellation that was thrown by the computation function
var computeFunctionRunning = new ManualResetEvent(initialState: false);
AsyncLazy<object> lazy;
Func<CancellationToken, object> synchronousComputation = null;
if (includeSynchronousComputation)
{
synchronousComputation = c =>
{
computeFunctionRunning.Set();
while (true)
{
c.ThrowIfCancellationRequested();
}
};
}
lazy = new AsyncLazy<object>(c =>
{
computeFunctionRunning.Set();
while (true)
{
c.ThrowIfCancellationRequested();
}
}, synchronousComputeFunction: synchronousComputation, cacheResult: false);
var cancellationTokenSource = new CancellationTokenSource();
// Create a task that will cancel the request once it's started
Task.Run(() => { computeFunctionRunning.WaitOne(); cancellationTokenSource.Cancel(); });
try
{
doGetValue(lazy, cancellationTokenSource.Token);
AssertEx.Fail(nameof(AsyncLazy<object>.GetValue) + " did not throw an exception.");
}
catch (OperationCanceledException oce)
{
Assert.Equal(cancellationTokenSource.Token, oce.CancellationToken);
}
}
[Fact]
[Trait(Traits.Feature, Traits.Features.AsyncLazy)]
public void GetValueAsyncThatIsCancelledReturnsTaskCancelledWithCorrectToken()
{
var cancellationTokenSource = new CancellationTokenSource();
var lazy = new AsyncLazy<object>(c => Task.Run((Func<object>)(() =>
{
cancellationTokenSource.Cancel();
while (true)
{
c.ThrowIfCancellationRequested();
}
}), c), cacheResult: true);
var task = lazy.GetValueAsync(cancellationTokenSource.Token);
// Now wait until the task completes
try
{
task.Wait();
AssertEx.Fail(nameof(AsyncLazy<object>.GetValueAsync) + " did not throw an exception.");
}
catch (AggregateException ex)
{
var operationCancelledException = (OperationCanceledException)ex.Flatten().InnerException;
Assert.Equal(cancellationTokenSource.Token, operationCancelledException.CancellationToken);
}
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册