提交 f021ad14 编写于 作者: J Jonathon Marolf

Merge pull request #5328 from diryboy/FixInsertAwait

Fix InsertAwait where LeadingTrivia exists
Fixes #5240
......@@ -15,7 +15,8 @@ public partial class AddAwaitTests : AbstractCSharpDiagnosticProviderBasedUserDi
public void BadAsyncReturnOperand1()
{
var initial =
@"using System.Threading.Tasks;
@"using System;
using System.Threading.Tasks;
class Program
{
......@@ -31,7 +32,8 @@ async Task<int> Test2()
}";
var expected =
@"using System.Threading.Tasks;
@"using System;
using System.Threading.Tasks;
class Program
{
......@@ -48,11 +50,261 @@ async Task<int> Test2()
Test(initial, expected);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_WithLeadingTrivia1()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test()
{
return 3;
}
async Task<int> Test2()
{
return
// Useful comment
[|Test()|];
}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test()
{
return 3;
}
async Task<int> Test2()
{
return
// Useful comment
await Test();
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_ConditionalExpressionWithTrailingTrivia_SingleLine()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{[|
return true ? Test() /* true */ : Test() /* false */;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{
return await (true ? Test() /* true */ : Test() /* false */);
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_ConditionalExpressionWithTrailingTrivia_Multiline()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{[|
return true ? Test() // aaa
: Test() // bbb
;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{
return await (true ? Test() // aaa
: Test()) // bbb
;
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_NullCoalescingExpressionWithTrailingTrivia_SingleLine()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{[|
return null /* 0 */ ?? Test() /* 1 */;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{
return await (null /* 0 */ ?? Test() /* 1 */);
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_NullCoalescingExpressionWithTrailingTrivia_Multiline()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{[|
return null // aaa
?? Test() // bbb
;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{
return await (null // aaa
?? Test()) // bbb
;
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_AsExpressionWithTrailingTrivia_SingleLine()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test2()
{[|
return null /* 0 */ as Task<int> /* 1 */;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test2()
{
return await (null /* 0 */ as Task<int> /* 1 */);
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_AsExpressionWithTrailingTrivia_Multiline()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{[|
return null // aaa
as Task<int> // bbb
;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{
return await (null // aaa
as Task<int>) // bbb
;
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TaskNotAwaited()
{
var initial =
@"using System.Threading.Tasks;
@"using System;
using System.Threading.Tasks;
class Program
{
async void Test()
......@@ -62,7 +314,8 @@ async void Test()
}";
var expected =
@"using System.Threading.Tasks;
@"using System;
using System.Threading.Tasks;
class Program
{
async void Test()
......@@ -73,11 +326,43 @@ async void Test()
Test(initial, expected);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TaskNotAwaited_WithLeadingTrivia()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async void Test()
{
// Useful comment
[|Task.Delay(3);|]
}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async void Test()
{
// Useful comment
await Task.Delay(3);
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void FunctionNotAwaited()
{
var initial =
@"using System.Threading.Tasks;
@"using System;
using System.Threading.Tasks;
class Program
{
Task AwaitableFunction()
......@@ -92,7 +377,8 @@ async void Test()
}";
var expected =
@"using System.Threading.Tasks;
@"using System;
using System.Threading.Tasks;
class Program
{
Task AwaitableFunction()
......@@ -108,6 +394,88 @@ async void Test()
Test(initial, expected);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void FunctionNotAwaited_WithLeadingTrivia()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
Task AwaitableFunction()
{
return Task.FromResult(true);
}
async void Test()
{
// Useful comment
[|AwaitableFunction();|]
}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
Task AwaitableFunction()
{
return Task.FromResult(true);
}
async void Test()
{
// Useful comment
await AwaitableFunction();
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void FunctionNotAwaited_WithLeadingTrivia1()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
Task AwaitableFunction()
{
return Task.FromResult(true);
}
async void Test()
{
var i = 0;
[|AwaitableFunction();|]
}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
Task AwaitableFunction()
{
return Task.FromResult(true);
}
async void Test()
{
var i = 0;
await AwaitableFunction();
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TestAssignmentExpression()
{
......@@ -199,6 +567,30 @@ public void TestAssignmentExpression8()
@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Func < Task > @delegate = delegate { int myInt = MyIntM [||] ethodAsync ( ) ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } ");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TestTernaryOperator()
{
Test(
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return [|true ? Task . FromResult ( 0 ) : Task . FromResult ( 1 )|] ; } } ",
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return await ( true ? Task . FromResult ( 0 ) : Task . FromResult ( 1 ) ) ; } } ");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TestNullCoalescingOperator()
{
Test(
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return [|null ?? Task . FromResult ( 1 )|] } } ",
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return await ( null ?? Task . FromResult ( 1 ) ) } } ");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TestAsExpression()
{
Test(
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return [|null as Task < int >|] } } ",
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return await ( null as Task < int > ) } } ");
}
internal override Tuple<DiagnosticAnalyzer, CodeFixProvider> CreateDiagnosticProviderAndFixer(Workspace workspace)
{
return new Tuple<DiagnosticAnalyzer, CodeFixProvider>(null, new CSharpAddAwaitCodeFixProvider());
......
......@@ -4,7 +4,7 @@ Imports Microsoft.CodeAnalysis.CodeFixes
Imports Microsoft.CodeAnalysis.Diagnostics
Imports Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async
Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.AddAsync
Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.Async
Public Class AddAsyncTests
Inherits AbstractVisualBasicDiagnosticProviderBasedUserDiagnosticTest
......
......@@ -5,7 +5,7 @@ Imports Microsoft.CodeAnalysis.Diagnostics
Imports Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async
Imports Roslyn.Test.Utilities
Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.AddAsync
Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.Async
Public Class AddAwaitTests
Inherits AbstractVisualBasicDiagnosticProviderBasedUserDiagnosticTest
......@@ -17,6 +17,35 @@ Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.AddAsy
)
End Sub
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub TaskNotAwaited_WithLeadingTrivia()
Dim initial =
<File>
Imports System
Imports System.Threading.Tasks
Module Program
Async Sub M()
' Useful comment
[|Task.Delay(3)|]
End Sub
End Module
</File>
Dim expected =
<File>
Imports System
Imports System.Threading.Tasks
Module Program
Async Sub M()
' Useful comment
Await Task.Delay(3)
End Sub
End Module
</File>
Test(initial, expected, compareTokens:=False)
End Sub
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub BadAsyncReturnOperand1()
Dim initial =
......@@ -96,6 +125,50 @@ End Module
Test(initial, expected)
End Sub
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub FunctionNotAwaited_WithLeadingTrivia()
Dim initial =
<File>
Imports System
Imports System.Collections.Generic
Imports System.Linq
Imports System.Threading.Tasks
Module Program
Function AwaitableFunction() As Task
Return New Task()
End Function
Async Sub MySub()
' Useful comment
[|AwaitableFunction()|]
End Sub
End Module
</File>
Dim expected =
<File>
Imports System
Imports System.Collections.Generic
Imports System.Linq
Imports System.Threading.Tasks
Module Program
Function AwaitableFunction() As Task
Return New Task()
End Function
Async Sub MySub()
' Useful comment
Await AwaitableFunction()
End Sub
End Module
</File>
Test(initial, expected, compareTokens:=True)
End Sub
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub SubLambdaNotAwaited()
Dim initial =
......@@ -144,6 +217,7 @@ Imports System.Threading.Tasks
Module Program
Sub MySub()
Dim a = Async Function()
' Useful comment
[|Task.Delay(1)|]
End Function
End Sub
......@@ -159,13 +233,14 @@ Imports System.Threading.Tasks
Module Program
Sub MySub()
Dim a = Async Function()
' Useful comment
Await Task.Delay(1)
End Function
End Sub
End Module
</File>
Test(initial, expected)
Test(initial, expected, compareTokens:=True)
End Sub
<Fact(), Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
......@@ -216,6 +291,27 @@ NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1A
NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1Async() \n Dim myInt As Long \n Dim lambda = Async Sub() myInt = Await MyIntMethodAsync() \n End Sub \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module"))
End Sub
<Fact(), Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub TestTernaryOperator()
Test(
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return [|If(True, Task.FromResult(0), Task.FromResult(1))|] \n End Function \n End Module"),
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return Await If(True, Task.FromResult(0), Task.FromResult(1)) \n End Function \n End Module"))
End Sub
<Fact(), Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub TestTernaryOperator2()
Test(
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return [|If(Nothing, Task.FromResult(1))|] \n End Function \n End Module"),
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return Await If(Nothing, Task.FromResult(1)) \n End Function \n End Module"))
End Sub
<Fact(), Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub TestCastExpression()
Test(
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return [|TryCast(Nothing, Task(Of Integer)|] \n End Function \n End Module"),
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return Await TryCast(Nothing, Task(Of Integer) \n End Function \n End Module"))
End Sub
Friend Overrides Function CreateDiagnosticProviderAndFixer(workspace As Workspace) As Tuple(Of DiagnosticAnalyzer, CodeFixProvider)
Return Tuple.Create(Of DiagnosticAnalyzer, CodeFixProvider)(
Nothing,
......
......@@ -4,7 +4,7 @@ Imports Microsoft.CodeAnalysis.CodeFixes
Imports Microsoft.CodeAnalysis.Diagnostics
Imports Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async
Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.AddAsync
Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.Async
Public Class ChangeToAsyncTests
Inherits AbstractVisualBasicDiagnosticProviderBasedUserDiagnosticTest
......
......@@ -3,16 +3,18 @@
using System;
using System.Collections.Immutable;
using System.Composition;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CodeFixes.Async;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.LanguageServices;
using Microsoft.CodeAnalysis.Simplification;
using Roslyn.Utilities;
using Resources = Microsoft.CodeAnalysis.CSharp.CSharpFeaturesResources;
using Microsoft.CodeAnalysis.LanguageServices;
using System.Linq;
namespace Microsoft.CodeAnalysis.CSharp.CodeFixes.Async
{
......@@ -20,12 +22,12 @@ namespace Microsoft.CodeAnalysis.CSharp.CodeFixes.Async
internal class CSharpAddAwaitCodeFixProvider : AbstractAddAsyncAwaitCodeFixProvider
{
/// <summary>
/// Since this is an async method, the return expression must be of type 'blah' rather than 'baz'
/// Because this call is not awaited, execution of the current method continues before the call is completed.
/// </summary>
private const string CS4014 = "CS4014";
/// <summary>
/// Because this call is not awaited, execution of the current method continues before the call is completed.
/// Since this is an async method, the return expression must be of type 'blah' rather than 'baz'
/// </summary>
private const string CS4016 = "CS4016";
......@@ -48,15 +50,14 @@ internal class CSharpAddAwaitCodeFixProvider : AbstractAddAsyncAwaitCodeFixProvi
CancellationToken cancellationToken)
{
var expression = oldNode as ExpressionSyntax;
if (expression == null)
{
return SpecializedTasks.Default<SyntaxNode>();
}
switch (diagnostic.Id)
{
case CS4014:
if (expression == null)
{
return Task.FromResult<SyntaxNode>(null);
}
return Task.FromResult(root.ReplaceNode(oldNode, ConvertToAwaitExpression(expression)));
case CS4016:
......@@ -74,33 +75,27 @@ internal class CSharpAddAwaitCodeFixProvider : AbstractAddAsyncAwaitCodeFixProvi
}
return Task.FromResult(root.ReplaceNode(oldNode, ConvertToAwaitExpression(expression)));
default:
return SpecializedTasks.Default<SyntaxNode>();
}
}
private static bool DoesExpressionReturnTask(ExpressionSyntax expression, SemanticModel semanticModel)
{
if (expression == null)
INamedTypeSymbol taskType = null;
if (!TryGetTaskType(semanticModel, out taskType))
{
return false;
}
INamedTypeSymbol taskType = null;
INamedTypeSymbol returnType = null;
return TryGetTaskAndExpressionTypes(expression, semanticModel, out taskType, out returnType) &&
return TryGetExpressionType(expression, semanticModel, out returnType) &&
semanticModel.Compilation.ClassifyConversion(taskType, returnType).Exists;
}
private static bool DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(ExpressionSyntax expression, SemanticModel semanticModel, Project project, CancellationToken cancellationToken)
{
if (expression == null)
{
return false;
}
if (!IsInAsyncFunction(expression))
{
return false;
......@@ -108,7 +103,8 @@ private static bool DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(E
INamedTypeSymbol taskType = null;
INamedTypeSymbol rightSideType = null;
if (!TryGetTaskAndExpressionTypes(expression, semanticModel, out taskType, out rightSideType))
if (!TryGetTaskType(semanticModel, out taskType) ||
!TryGetExpressionType(expression, semanticModel, out rightSideType))
{
return false;
}
......@@ -119,7 +115,7 @@ private static bool DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(E
return false;
}
if(!rightSideType.IsGenericType)
if (!rightSideType.IsGenericType)
{
return false;
}
......@@ -150,11 +146,23 @@ private static bool IsInAsyncFunction(ExpressionSyntax expression)
return false;
}
private static ExpressionSyntax ConvertToAwaitExpression(ExpressionSyntax expression)
private static SyntaxNode ConvertToAwaitExpression(ExpressionSyntax expression)
{
return SyntaxFactory.AwaitExpression(expression)
if ((expression is BinaryExpressionSyntax || expression is ConditionalExpressionSyntax) && expression.HasTrailingTrivia)
{
var expWithTrailing = expression.WithoutLeadingTrivia();
var span = expWithTrailing.GetLocation().GetLineSpan().Span;
if (span.Start.Line == span.End.Line && !expWithTrailing.DescendantTrivia().Any(trivia => trivia.IsKind(SyntaxKind.SingleLineCommentTrivia)))
{
return SyntaxFactory.AwaitExpression(SyntaxFactory.ParenthesizedExpression(expWithTrailing))
.WithLeadingTrivia(expression.GetLeadingTrivia())
.WithAdditionalAnnotations(Formatter.Annotation);
}
}
return SyntaxFactory.AwaitExpression(expression.WithoutTrivia().Parenthesize())
.WithTriviaFrom(expression)
.WithAdditionalAnnotations(Formatter.Annotation);
.WithAdditionalAnnotations(Simplifier.Annotation, Formatter.Annotation);
}
}
}
......@@ -4,7 +4,6 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.Internal.Log;
namespace Microsoft.CodeAnalysis.CodeFixes.Async
{
......@@ -28,38 +27,21 @@ protected override async Task<CodeAction> GetCodeFix(SyntaxNode root, SyntaxNode
return null;
}
protected static bool TryGetTaskAndExpressionTypes(
protected static bool TryGetExpressionType(
SyntaxNode expression,
SemanticModel semanticModel,
out INamedTypeSymbol source,
out INamedTypeSymbol destination)
out INamedTypeSymbol returnType)
{
source = null;
destination = null;
var info = semanticModel.GetSymbolInfo(expression);
var methodSymbol = info.Symbol as IMethodSymbol;
if (methodSymbol == null)
{
return false;
}
var typeInfo = semanticModel.GetTypeInfo(expression);
returnType = typeInfo.Type as INamedTypeSymbol;
return returnType != null;
}
protected static bool TryGetTaskType(SemanticModel semanticModel, out INamedTypeSymbol taskType)
{
var compilation = semanticModel.Compilation;
var taskType = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task");
if (taskType == null)
{
return false;
}
var returnType = methodSymbol.ReturnType as INamedTypeSymbol;
if (returnType == null)
{
return false;
}
source = taskType;
destination = returnType;
return true;
taskType = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task");
return taskType != null;
}
private class MyCodeAction : CodeAction.DocumentChangeAction
......
......@@ -8,6 +8,7 @@ Imports Microsoft.CodeAnalysis.CodeFixes
Imports Microsoft.CodeAnalysis.CodeFixes.Async
Imports Microsoft.CodeAnalysis.Formatting
Imports Microsoft.CodeAnalysis.LanguageServices
Imports Microsoft.CodeAnalysis.Simplification
Imports Microsoft.CodeAnalysis.VisualBasic.Syntax
Imports Resources = Microsoft.CodeAnalysis.VisualBasic.VBFeaturesResources.VBFeaturesResources
......@@ -35,40 +36,37 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async
Protected Overrides Function GetNewRoot(root As SyntaxNode, oldNode As SyntaxNode, semanticModel As SemanticModel, diagnostic As Diagnostic, document As Document, cancellationToken As CancellationToken) As Task(Of SyntaxNode)
Dim expression = TryCast(oldNode, ExpressionSyntax)
If expression Is Nothing Then
Return SpecializedTasks.Default(Of SyntaxNode)()
End If
Select Case diagnostic.Id
Case BC30311
If Not DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(expression, semanticModel, document.Project, cancellationToken) Then
Return Task.FromResult(Of SyntaxNode)(Nothing)
End If
Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression)))
Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression, semanticModel, cancellationToken)))
Case BC37055
If Not DoesExpressionReturnTask(expression, semanticModel) Then
Return Task.FromResult(Of SyntaxNode)(Nothing)
End If
Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression)))
Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression, semanticModel, cancellationToken)))
Case BC42358
If expression Is Nothing Then
Return Task.FromResult(Of SyntaxNode)(Nothing)
End If
Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression)))
Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression, semanticModel, cancellationToken)))
Case Else
Return Task.FromResult(Of SyntaxNode)(Nothing)
Return SpecializedTasks.Default(Of SyntaxNode)()
End Select
End Function
Private Function DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(expression As ExpressionSyntax, semanticModel As SemanticModel, project As Project, cancellationToken As CancellationToken) As Boolean
If expression Is Nothing Then
Return False
End If
If Not IsInAsyncBlock(expression) Then
Return False
End If
Dim taskType As INamedTypeSymbol = Nothing
Dim rightSideType As INamedTypeSymbol = Nothing
If Not TryGetTaskAndExpressionTypes(expression, semanticModel, taskType, rightSideType) Then
If Not TryGetTaskType(semanticModel, taskType) OrElse
Not TryGetExpressionType(expression, semanticModel, rightSideType) Then
Return False
End If
......@@ -109,19 +107,35 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async
End Function
Private Function DoesExpressionReturnTask(expression As ExpressionSyntax, semanticModel As SemanticModel) As Boolean
If expression Is Nothing Then
Return Nothing
End If
Dim taskType As INamedTypeSymbol = Nothing
Dim returnType As INamedTypeSymbol = Nothing
Return TryGetTaskAndExpressionTypes(expression, semanticModel, taskType, returnType) AndAlso
Return TryGetTaskType(semanticModel, taskType) AndAlso
TryGetExpressionType(expression, semanticModel, returnType) AndAlso
semanticModel.Compilation.ClassifyConversion(taskType, returnType).Exists
End Function
Private Function ConverToAwaitExpression(expression As ExpressionSyntax) As ExpressionSyntax
Return SyntaxFactory.AwaitExpression(expression).WithAdditionalAnnotations(Formatter.Annotation)
Private Shared Function ConverToAwaitExpression(expression As ExpressionSyntax, semanticModel As SemanticModel, cancellationToken As CancellationToken) As ExpressionSyntax
Dim root = expression.Ancestors().Last()
If Not RequiresParenthesis(expression, root, semanticModel, cancellationToken) Then
expression = expression.Parenthesize()
End If
Return SyntaxFactory.AwaitExpression(expression.WithoutTrivia()) _
.WithTriviaFrom(expression) _
.WithAdditionalAnnotations(Formatter.Annotation, Simplifier.Annotation)
End Function
Private Shared Function RequiresParenthesis(expression As ExpressionSyntax, root As SyntaxNode, semanticModel As SemanticModel, cancellationToken As CancellationToken) As Boolean
Dim parenthesizedExpression = SyntaxFactory.ParenthesizedExpression(expression)
Dim newRoot = root.ReplaceNode(expression, parenthesizedExpression)
Dim newNode = newRoot.FindNode(expression.Span)
Dim result = newNode _
.DescendantNodesAndSelf(Function(n) n.Kind <> SyntaxKind.ParenthesizedExpression) _
.OfType(Of ParenthesizedExpressionSyntax).FirstOrDefault
If result IsNot Nothing Then
Return Not result.CanRemoveParentheses(semanticModel, cancellationToken)
End If
Return False
End Function
End Class
End Namespace
......@@ -25,7 +25,14 @@ public static class TokenUtilities
for (var i = 0; i < Math.Min(expectedTokens.Count, actualTokens.Count); i++)
{
Assert.Equal(expectedTokens[i].ToString(), actualTokens[i].ToString());
var expectedToken = expectedTokens[i].ToString();
var actualToken = actualTokens[i].ToString();
if (!String.Equals(expectedToken, actualToken))
{
var prev = (i - 1 > -1) ? actualTokens[i - 1].ToString() : "^";
var next = (i + 1 < actualTokens.Count) ? actualTokens[i + 1].ToString() : "$";
AssertEx.Fail($"Unexpected token at index {i} near \"{prev} {actualToken} {next}\". Expected '{expectedToken}', Actual '{actualToken}'");
}
}
if (expectedTokens.Count != actualTokens.Count)
......
......@@ -43,27 +43,22 @@ public static ExpressionSyntax WalkDownParentheses(this ExpressionSyntax express
public static ExpressionSyntax Parenthesize(this ExpressionSyntax expression, bool includeElasticTrivia = true)
{
var leadingTrivia = expression.GetLeadingTrivia();
var trailingTrivia = expression.GetTrailingTrivia();
expression = expression.WithoutLeadingTrivia()
.WithoutTrailingTrivia();
if (includeElasticTrivia)
{
return SyntaxFactory.ParenthesizedExpression(expression)
.WithLeadingTrivia(leadingTrivia)
.WithTrailingTrivia(trailingTrivia)
.WithAdditionalAnnotations(Simplifier.Annotation);
return SyntaxFactory.ParenthesizedExpression(expression.WithoutTrivia())
.WithTriviaFrom(expression)
.WithAdditionalAnnotations(Simplifier.Annotation);
}
else
{
return SyntaxFactory.ParenthesizedExpression(
return SyntaxFactory.ParenthesizedExpression
(
SyntaxFactory.Token(SyntaxTriviaList.Empty, SyntaxKind.OpenParenToken, SyntaxTriviaList.Empty),
expression,
SyntaxFactory.Token(SyntaxTriviaList.Empty, SyntaxKind.CloseParenToken, SyntaxTriviaList.Empty))
.WithLeadingTrivia(leadingTrivia)
.WithTrailingTrivia(trailingTrivia)
.WithAdditionalAnnotations(Simplifier.Annotation);
expression.WithoutTrivia(),
SyntaxFactory.Token(SyntaxTriviaList.Empty, SyntaxKind.CloseParenToken, SyntaxTriviaList.Empty)
)
.WithTriviaFrom(expression)
.WithAdditionalAnnotations(Simplifier.Annotation);
}
}
......@@ -2338,7 +2333,7 @@ public static OperatorPrecedence GetOperatorPrecedence(this ExpressionSyntax exp
case SyntaxKind.UncheckedExpression:
case SyntaxKind.AnonymousMethodExpression:
// From C# spec, 7.3.1:
// Primary: x.y f(x) a[x] x++ x-- new typeof default checked unchecked delegate
// Primary: x.y x?.y x?[y] f(x) a[x] x++ x-- new typeof default checked unchecked delegate
return OperatorPrecedence.Primary;
......@@ -2349,8 +2344,9 @@ public static OperatorPrecedence GetOperatorPrecedence(this ExpressionSyntax exp
case SyntaxKind.PreIncrementExpression:
case SyntaxKind.PreDecrementExpression:
case SyntaxKind.CastExpression:
case SyntaxKind.AwaitExpression:
// From C# spec, 7.3.1:
// Unary: + - ! ~ ++x --x (T)x
// Unary: + - ! ~ ++x --x (T)x await Task
return OperatorPrecedence.Unary;
......
......@@ -52,15 +52,9 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.Extensions
<Extension()>
Public Function Parenthesize(expression As ExpressionSyntax) As ParenthesizedExpressionSyntax
Dim leadingTrivia = expression.GetLeadingTrivia()
Dim trailingTrivia = expression.GetTrailingTrivia()
Dim strippedExpression = expression.WithoutLeadingTrivia().WithoutTrailingTrivia()
Return SyntaxFactory.ParenthesizedExpression(strippedExpression) _
.WithLeadingTrivia(leadingTrivia) _
.WithTrailingTrivia(trailingTrivia) _
.WithAdditionalAnnotations(Simplifier.Annotation)
Return SyntaxFactory.ParenthesizedExpression(expression.WithoutTrivia()) _
.WithTriviaFrom(expression) _
.WithAdditionalAnnotations(Simplifier.Annotation)
End Function
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册