提交 8df386bb 编写于 作者: C CyrusNajmabadi 提交者: GitHub

Merge pull request #14225 from CyrusNajmabadi/removeAwaitFromCallers

Remove 'await' from callers when we make methods synchronous.

Fixes #13961
......@@ -352,5 +352,203 @@ int Bar()
}
}", compareTokens: false, fixAllActionEquivalenceKey: AbstractMakeMethodSynchronousCodeFixProvider.EquivalenceKey);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)]
[WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")]
public async Task TestRemoveAwaitFromCaller1()
{
await TestAsync(
@"using System.Threading.Tasks;
public class Class1
{
async Task [|FooAsync|]()
{
}
async void BarAsync()
{
await FooAsync();
}
}",
@"using System.Threading.Tasks;
public class Class1
{
void Foo()
{
}
async void BarAsync()
{
Foo();
}
}", compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)]
[WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")]
public async Task TestRemoveAwaitFromCaller2()
{
await TestAsync(
@"using System.Threading.Tasks;
public class Class1
{
async Task [|FooAsync|]()
{
}
async void BarAsync()
{
await FooAsync().ConfigureAwait(false);
}
}",
@"using System.Threading.Tasks;
public class Class1
{
void Foo()
{
}
async void BarAsync()
{
Foo();
}
}", compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)]
[WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")]
public async Task TestRemoveAwaitFromCaller3()
{
await TestAsync(
@"using System.Threading.Tasks;
public class Class1
{
async Task [|FooAsync|]()
{
}
async void BarAsync()
{
await this.FooAsync();
}
}",
@"using System.Threading.Tasks;
public class Class1
{
void Foo()
{
}
async void BarAsync()
{
this.Foo();
}
}", compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)]
[WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")]
public async Task TestRemoveAwaitFromCaller4()
{
await TestAsync(
@"using System.Threading.Tasks;
public class Class1
{
async Task [|FooAsync|]()
{
}
async void BarAsync()
{
await this.FooAsync().ConfigureAwait(false);
}
}",
@"using System.Threading.Tasks;
public class Class1
{
void Foo()
{
}
async void BarAsync()
{
this.Foo();
}
}", compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)]
[WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")]
public async Task TestRemoveAwaitFromCallerNested1()
{
await TestAsync(
@"using System.Threading.Tasks;
public class Class1
{
async Task<int> [|FooAsync|](int i)
{
}
async void BarAsync()
{
await this.FooAsync(await this.FooAsync(0));
}
}",
@"using System.Threading.Tasks;
public class Class1
{
int Foo(int i)
{
}
async void BarAsync()
{
this.Foo(this.Foo(0));
}
}", compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)]
[WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")]
public async Task TestRemoveAwaitFromCallerNested()
{
await TestAsync(
@"using System.Threading.Tasks;
public class Class1
{
async Task<int> [|FooAsync|](int i)
{
}
async void BarAsync()
{
await this.FooAsync(await this.FooAsync(0).ConfigureAwait(false)).ConfigureAwait(false);
}
}",
@"using System.Threading.Tasks;
public class Class1
{
int Foo(int i)
{
}
async void BarAsync()
{
this.Foo(this.Foo(0));
}
}", compareTokens: false);
}
}
}
\ No newline at end of file
}
......@@ -233,5 +233,161 @@ Class C
End Class",
compareTokens:=False)
End Function
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)>
<WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")>
Public Async Function TestRemoveAwaitFromCaller1() As Task
Await TestAsync(
"Imports System.Threading.Tasks;
Public Class Class1
Async Function [|FooAsync|]() As Task
End Function
Async Sub BarAsync()
Await FooAsync()
End Sub
End Class",
"Imports System.Threading.Tasks;
Public Class Class1
Sub Foo()
End Sub
Async Sub BarAsync()
Foo()
End Sub
End Class", compareTokens:=False)
End Function
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)>
<WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")>
Public Async Function TestRemoveAwaitFromCaller2() As Task
Await TestAsync(
"Imports System.Threading.Tasks;
Public Class Class1
Async Function [|FooAsync|]() As Task
End Function
Async Sub BarAsync()
Await FooAsync().ConfigureAwait(false)
End Sub
End Class",
"Imports System.Threading.Tasks;
Public Class Class1
Sub Foo()
End Sub
Async Sub BarAsync()
Foo()
End Sub
End Class", compareTokens:=False)
End Function
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)>
<WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")>
Public Async Function TestRemoveAwaitFromCaller3() As Task
Await TestAsync(
"Imports System.Threading.Tasks;
Public Class Class1
Async Function [|FooAsync|]() As Task
End Function
Async Sub BarAsync()
Await Me.FooAsync()
End Sub
End Class",
"Imports System.Threading.Tasks;
Public Class Class1
Sub Foo()
End Sub
Async Sub BarAsync()
Me.Foo()
End Sub
End Class", compareTokens:=False)
End Function
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)>
<WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")>
Public Async Function TestRemoveAwaitFromCaller4() As Task
Await TestAsync(
"Imports System.Threading.Tasks;
Public Class Class1
Async Function [|FooAsync|]() As Task
End Function
Async Sub BarAsync()
Await Me.FooAsync().ConfigureAwait(false)
End Sub
End Class",
"Imports System.Threading.Tasks;
Public Class Class1
Sub Foo()
End Sub
Async Sub BarAsync()
Me.Foo()
End Sub
End Class", compareTokens:=False)
End Function
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)>
<WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")>
Public Async Function TestRemoveAwaitFromCallerNested1() As Task
Await TestAsync(
"Imports System.Threading.Tasks;
Public Class Class1
Async Function [|FooAsync|](i As Integer) As Task(Of Integer)
End Function
Async Sub BarAsync()
Await FooAsync(Await FooAsync(0))
End Sub
End Class",
"Imports System.Threading.Tasks;
Public Class Class1
Function Foo(i As Integer) As Integer
End Function
Async Sub BarAsync()
Foo(Foo(0))
End Sub
End Class", compareTokens:=False)
End Function
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)>
<WorkItem(13961, "https://github.com/dotnet/roslyn/issues/13961")>
Public Async Function TestRemoveAwaitFromCallerNested2() As Task
Await TestAsync(
"Imports System.Threading.Tasks;
Public Class Class1
Async Function [|FooAsync|](i As Integer) As Task(Of Integer)
End Function
Async Sub BarAsync()
Await Me.FooAsync(Await Me.FooAsync(0).ConfigureAwait(false)).ConfigureAwait(false)
End Sub
End Class",
"Imports System.Threading.Tasks;
Public Class Class1
Function Foo(i As Integer) As Integer
End Function
Async Sub BarAsync()
Me.Foo(Me.Foo(0))
End Sub
End Class", compareTokens:=False)
End Function
End Class
End Namespace
\ No newline at end of file
......@@ -250,7 +250,7 @@ protected override bool CanAddImport(SyntaxNode node, CancellationToken cancella
return false;
}
if (!syntaxFacts.IsMemberAccessExpressionName(node))
if (!syntaxFacts.IsNameOfMemberAccessExpression(node))
{
return false;
}
......
......@@ -545,7 +545,7 @@ private async Task<IList<SymbolReference>> GetNamespacesForCollectionInitializer
// 'Black' did not bind. We want to find a type called 'Color' that will actually
// allow 'Black' to bind.
var syntaxFacts = this._document.GetLanguageService<ISyntaxFactsService>();
if (!syntaxFacts.IsMemberAccessExpressionName(nameNode))
if (!syntaxFacts.IsNameOfMemberAccessExpression(nameNode))
{
return null;
}
......
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.FindSymbols;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.LanguageServices;
using Microsoft.CodeAnalysis.Rename;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
......@@ -71,7 +76,7 @@ private async Task<Solution> RenameThenRemoveAsyncTokenAsync(Document document,
var newRoot = await newDocument.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
SyntaxNode newNode;
if (syntaxPath.TryResolve<SyntaxNode>(newRoot, out newNode))
if (syntaxPath.TryResolve(newRoot, out newNode))
{
var semanticModel = await newDocument.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var newMethod = (IMethodSymbol)semanticModel.GetDeclaredSymbol(newNode, cancellationToken);
......@@ -81,20 +86,167 @@ private async Task<Solution> RenameThenRemoveAsyncTokenAsync(Document document,
return newSolution;
}
private async Task<Solution> RemoveAsyncTokenAsync(Document document, IMethodSymbol methodSymbolOpt, SyntaxNode node, CancellationToken cancellationToken)
private async Task<Solution> RemoveAsyncTokenAsync(
Document document, IMethodSymbol methodSymbolOpt, SyntaxNode node, CancellationToken cancellationToken)
{
var compilation = await document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
var taskType = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task");
var taskOfTType = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1");
var annotation = new SyntaxAnnotation();
var newNode = RemoveAsyncTokenAndFixReturnType(methodSymbolOpt, node, taskType, taskOfTType)
.WithAdditionalAnnotations(Formatter.Annotation);
.WithAdditionalAnnotations(Formatter.Annotation, annotation);
var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var newRoot = root.ReplaceNode(node, newNode);
var newDocument = document.WithSyntaxRoot(newRoot);
return newDocument.Project.Solution;
var newSolution = newDocument.Project.Solution;
if (methodSymbolOpt == null)
{
return newSolution;
}
return await RemoveAwaitFromCallersAsync(
newDocument, annotation, cancellationToken).ConfigureAwait(false) ;
}
private async Task<Solution> RemoveAwaitFromCallersAsync(
Document document, SyntaxAnnotation annotation, CancellationToken cancellationToken)
{
var syntaxRoot = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var methodDeclaration = syntaxRoot.GetAnnotatedNodes(annotation).FirstOrDefault();
if (methodDeclaration != null)
{
var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var methodSymbol = semanticModel.GetDeclaredSymbol(methodDeclaration) as IMethodSymbol;
if (methodSymbol != null)
{
var references = await SymbolFinder.FindRenamableReferencesAsync(
new SymbolAndProjectId(methodSymbol, document.Project.Id),
document.Project.Solution, cancellationToken).ConfigureAwait(false);
var referencedSymbol = references.FirstOrDefault(r => Equals(r.Definition, methodSymbol));
if (referencedSymbol != null)
{
return await RemoveAwaitFromCallersAsync(
document.Project.Solution, referencedSymbol.Locations.ToImmutableArray(), cancellationToken).ConfigureAwait(false);
}
}
}
return document.Project.Solution;
}
private async Task<Solution> RemoveAwaitFromCallersAsync(
Solution solution, ImmutableArray<ReferenceLocation> locations, CancellationToken cancellationToken)
{
var currentSolution = solution;
var groupedLocations = locations.GroupBy(loc => loc.Document);
foreach (var group in groupedLocations)
{
currentSolution = await RemoveAwaitFromCallersAsync(
currentSolution, group, cancellationToken).ConfigureAwait(false);
}
return currentSolution;
}
private async Task<Solution> RemoveAwaitFromCallersAsync(
Solution currentSolution, IGrouping<Document, ReferenceLocation> group, CancellationToken cancellationToken)
{
var document = group.Key;
var syntaxFactsService = document.GetLanguageService<ISyntaxFactsService>();
var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var editor = new SyntaxEditor(root, currentSolution.Workspace);
foreach (var location in group)
{
RemoveAwaitFromCallerIfPresent(editor, syntaxFactsService, root, location, cancellationToken);
}
var newRoot = editor.GetChangedRoot();
return currentSolution.WithDocumentSyntaxRoot(document.Id, newRoot);
}
private void RemoveAwaitFromCallerIfPresent(
SyntaxEditor editor, ISyntaxFactsService syntaxFacts,
SyntaxNode root, ReferenceLocation referenceLocation,
CancellationToken cancellationToken)
{
if (referenceLocation.IsImplicit)
{
return;
}
var location = referenceLocation.Location;
var token = location.FindToken(cancellationToken);
var nameNode = token.Parent;
if (nameNode == null)
{
return;
}
// Look for the following forms:
// await M(...)
// await <expr>.M(...)
// await M(...).ConfigureAwait(...)
// await <expr>.M(...).ConfigureAwait(...)
var expressionNode = nameNode;
if (syntaxFacts.IsNameOfMemberAccessExpression(nameNode))
{
expressionNode = nameNode.Parent;
}
if (!syntaxFacts.IsExpressionOfInvocationExpression(expressionNode))
{
return;
}
// We now either have M(...) or <expr>.M(...)
var invocationExpression = expressionNode.Parent;
Debug.Assert(syntaxFacts.IsInvocationExpression(invocationExpression));
if (syntaxFacts.IsExpressionOfAwaitExpression(invocationExpression))
{
// Handle the case where we're directly awaited.
var awaitExpression = invocationExpression.Parent;
editor.ReplaceNode(awaitExpression, (currentAwaitExpression, generator) =>
syntaxFacts.GetExpressionOfAwaitExpression(currentAwaitExpression)
.WithTriviaFrom(currentAwaitExpression));
}
else if (syntaxFacts.IsExpressionOfMemberAccessExpression(invocationExpression))
{
// Check for the .ConfigureAwait case.
var parentMemberAccessExpression = invocationExpression.Parent;
var parentMemberAccessExpressionNameNode = syntaxFacts.GetNameOfMemberAccessExpression(
parentMemberAccessExpression);
var parentMemberAccessExpressionName = syntaxFacts.GetIdentifierOfSimpleName(parentMemberAccessExpressionNameNode).ValueText;
if (parentMemberAccessExpressionName == nameof(Task.ConfigureAwait))
{
var parentExpression = parentMemberAccessExpression.Parent;
if (syntaxFacts.IsExpressionOfAwaitExpression(parentExpression))
{
var awaitExpression = parentExpression.Parent;
editor.ReplaceNode(awaitExpression, (currentAwaitExpression, generator) =>
{
var currentConfigureAwaitInvocation = syntaxFacts.GetExpressionOfAwaitExpression(currentAwaitExpression);
var currentMemberAccess = syntaxFacts.GetExpressionOfInvocationExpression(currentConfigureAwaitInvocation);
var currentInvocationExpression = syntaxFacts.GetExpressionOfMemberAccessExpression(currentMemberAccess);
return currentInvocationExpression.WithTriviaFrom(currentAwaitExpression);
});
}
}
}
}
private class MyCodeAction : CodeAction.SolutionChangeAction
......
......@@ -94,7 +94,7 @@ private struct ReferenceReplacer
_identifierName = (TIdentifierNameSyntax)nameToken.Parent;
_expression = _identifierName;
if (_syntaxFacts.IsMemberAccessExpressionName(_expression))
if (_syntaxFacts.IsNameOfMemberAccessExpression(_expression))
{
_expression = _expression.Parent as TExpressionSyntax;
}
......
......@@ -150,7 +150,7 @@ public bool IsRightSideOfQualifiedName(SyntaxNode node)
return name.IsRightSideOfQualifiedName();
}
public bool IsMemberAccessExpressionName(SyntaxNode node)
public bool IsNameOfMemberAccessExpression(SyntaxNode node)
{
var name = node as SimpleNameSyntax;
return name.IsMemberAccessExpressionName();
......@@ -1839,6 +1839,31 @@ public bool AreEquivalent(SyntaxNode node1, SyntaxNode node2)
return SyntaxFactory.AreEquivalent(node1, node2);
}
public bool IsExpressionOfInvocationExpression(SyntaxNode node)
{
return node != null && (node.Parent as InvocationExpressionSyntax)?.Expression == node;
}
public bool IsExpressionOfAwaitExpression(SyntaxNode node)
{
return node != null && (node.Parent as AwaitExpressionSyntax)?.Expression == node;
}
public bool IsExpressionOfMemberAccessExpression(SyntaxNode node)
{
return node != null && (node.Parent as MemberAccessExpressionSyntax)?.Expression == node;
}
public SyntaxNode GetExpressionOfInvocationExpression(SyntaxNode node)
{
return ((InvocationExpressionSyntax)node).Expression;
}
public SyntaxNode GetExpressionOfAwaitExpression(SyntaxNode node)
{
return ((AwaitExpressionSyntax)node).Expression;
}
private class AddFirstMissingCloseBaceRewriter: CSharpSyntaxRewriter
{
private readonly SyntaxNode _contextNode;
......
......@@ -52,6 +52,11 @@ internal interface ISyntaxFactsService : ILanguageService
SyntaxNode GetObjectCreationInitializer(SyntaxNode objectCreationExpression);
bool IsInvocationExpression(SyntaxNode node);
bool IsExpressionOfInvocationExpression(SyntaxNode node);
SyntaxNode GetExpressionOfInvocationExpression(SyntaxNode node);
bool IsExpressionOfAwaitExpression(SyntaxNode node);
SyntaxNode GetExpressionOfAwaitExpression(SyntaxNode node);
// Left side of = assignment.
bool IsLeftSideOfAssignment(SyntaxNode node);
......@@ -72,7 +77,9 @@ internal interface ISyntaxFactsService : ILanguageService
bool IsRightSideOfQualifiedName(SyntaxNode node);
bool IsMemberAccessExpressionName(SyntaxNode node);
bool IsNameOfMemberAccessExpression(SyntaxNode node);
bool IsExpressionOfMemberAccessExpression(SyntaxNode node);
SyntaxNode GetNameOfMemberAccessExpression(SyntaxNode memberAccessExpression);
SyntaxNode GetExpressionOfMemberAccessExpression(SyntaxNode memberAccessExpression);
SyntaxToken GetOperatorTokenOfMemberAccessExpression(SyntaxNode memberAccessExpression);
......
......@@ -142,7 +142,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic
Return vbNode IsNot Nothing AndAlso vbNode.IsRightSideOfQualifiedName()
End Function
Public Function IsMemberAccessExpressionName(node As SyntaxNode) As Boolean Implements ISyntaxFactsService.IsMemberAccessExpressionName
Public Function IsNameOfMemberAccessExpression(node As SyntaxNode) As Boolean Implements ISyntaxFactsService.IsNameOfMemberAccessExpression
Dim vbNode = TryCast(node, SimpleNameSyntax)
Return vbNode IsNot Nothing AndAlso vbNode.IsMemberAccessExpressionName()
End Function
......@@ -1516,5 +1516,25 @@ Namespace Microsoft.CodeAnalysis.VisualBasic
Public Function AreEquivalent(node1 As SyntaxNode, node2 As SyntaxNode) As Boolean Implements ISyntaxFactsService.AreEquivalent
Return SyntaxFactory.AreEquivalent(node1, node2)
End Function
Public Function IsExpressionOfInvocationExpression(node As SyntaxNode) As Boolean Implements ISyntaxFactsService.IsExpressionOfInvocationExpression
Return node IsNot Nothing AndAlso TryCast(node.Parent, InvocationExpressionSyntax)?.Expression Is node
End Function
Public Function IsExpressionOfAwaitExpression(node As SyntaxNode) As Boolean Implements ISyntaxFactsService.IsExpressionOfAwaitExpression
Return node IsNot Nothing AndAlso TryCast(node.Parent, AwaitExpressionSyntax)?.Expression Is node
End Function
Public Function IsExpressionOfMemberAccessExpression(node As SyntaxNode) As Boolean Implements ISyntaxFactsService.IsExpressionOfMemberAccessExpression
Return node IsNot Nothing AndAlso TryCast(node.Parent, MemberAccessExpressionSyntax)?.Expression Is node
End Function
Public Function GetExpressionOfInvocationExpression(node As SyntaxNode) As SyntaxNode Implements ISyntaxFactsService.GetExpressionOfInvocationExpression
Return DirectCast(node, InvocationExpressionSyntax).Expression
End Function
Public Function GetExpressionOfAwaitExpression(node As SyntaxNode) As SyntaxNode Implements ISyntaxFactsService.GetExpressionOfAwaitExpression
Return DirectCast(node, AwaitExpressionSyntax).Expression
End Function
End Class
End Namespace
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册