提交 f021ad14 编写于 作者: J Jonathon Marolf

Merge pull request #5328 from diryboy/FixInsertAwait

Fix InsertAwait where LeadingTrivia exists
Fixes #5240
...@@ -15,7 +15,8 @@ public partial class AddAwaitTests : AbstractCSharpDiagnosticProviderBasedUserDi ...@@ -15,7 +15,8 @@ public partial class AddAwaitTests : AbstractCSharpDiagnosticProviderBasedUserDi
public void BadAsyncReturnOperand1() public void BadAsyncReturnOperand1()
{ {
var initial = var initial =
@"using System.Threading.Tasks; @"using System;
using System.Threading.Tasks;
class Program class Program
{ {
...@@ -31,7 +32,8 @@ async Task<int> Test2() ...@@ -31,7 +32,8 @@ async Task<int> Test2()
}"; }";
var expected = var expected =
@"using System.Threading.Tasks; @"using System;
using System.Threading.Tasks;
class Program class Program
{ {
...@@ -48,11 +50,261 @@ async Task<int> Test2() ...@@ -48,11 +50,261 @@ async Task<int> Test2()
Test(initial, expected); Test(initial, expected);
} }
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_WithLeadingTrivia1()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test()
{
return 3;
}
async Task<int> Test2()
{
return
// Useful comment
[|Test()|];
}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test()
{
return 3;
}
async Task<int> Test2()
{
return
// Useful comment
await Test();
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_ConditionalExpressionWithTrailingTrivia_SingleLine()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{[|
return true ? Test() /* true */ : Test() /* false */;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{
return await (true ? Test() /* true */ : Test() /* false */);
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_ConditionalExpressionWithTrailingTrivia_Multiline()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{[|
return true ? Test() // aaa
: Test() // bbb
;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{
return await (true ? Test() // aaa
: Test()) // bbb
;
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_NullCoalescingExpressionWithTrailingTrivia_SingleLine()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{[|
return null /* 0 */ ?? Test() /* 1 */;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{
return await (null /* 0 */ ?? Test() /* 1 */);
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_NullCoalescingExpressionWithTrailingTrivia_Multiline()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{[|
return null // aaa
?? Test() // bbb
;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{
return await (null // aaa
?? Test()) // bbb
;
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_AsExpressionWithTrailingTrivia_SingleLine()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test2()
{[|
return null /* 0 */ as Task<int> /* 1 */;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test2()
{
return await (null /* 0 */ as Task<int> /* 1 */);
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void BadAsyncReturnOperand_AsExpressionWithTrailingTrivia_Multiline()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{[|
return null // aaa
as Task<int> // bbb
;
|]}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async Task<int> Test() => 3;
async Task<int> Test2()
{
return await (null // aaa
as Task<int>) // bbb
;
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TaskNotAwaited() public void TaskNotAwaited()
{ {
var initial = var initial =
@"using System.Threading.Tasks; @"using System;
using System.Threading.Tasks;
class Program class Program
{ {
async void Test() async void Test()
...@@ -62,7 +314,8 @@ async void Test() ...@@ -62,7 +314,8 @@ async void Test()
}"; }";
var expected = var expected =
@"using System.Threading.Tasks; @"using System;
using System.Threading.Tasks;
class Program class Program
{ {
async void Test() async void Test()
...@@ -73,11 +326,43 @@ async void Test() ...@@ -73,11 +326,43 @@ async void Test()
Test(initial, expected); Test(initial, expected);
} }
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TaskNotAwaited_WithLeadingTrivia()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
async void Test()
{
// Useful comment
[|Task.Delay(3);|]
}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
async void Test()
{
// Useful comment
await Task.Delay(3);
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void FunctionNotAwaited() public void FunctionNotAwaited()
{ {
var initial = var initial =
@"using System.Threading.Tasks; @"using System;
using System.Threading.Tasks;
class Program class Program
{ {
Task AwaitableFunction() Task AwaitableFunction()
...@@ -92,7 +377,8 @@ async void Test() ...@@ -92,7 +377,8 @@ async void Test()
}"; }";
var expected = var expected =
@"using System.Threading.Tasks; @"using System;
using System.Threading.Tasks;
class Program class Program
{ {
Task AwaitableFunction() Task AwaitableFunction()
...@@ -108,6 +394,88 @@ async void Test() ...@@ -108,6 +394,88 @@ async void Test()
Test(initial, expected); Test(initial, expected);
} }
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void FunctionNotAwaited_WithLeadingTrivia()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
Task AwaitableFunction()
{
return Task.FromResult(true);
}
async void Test()
{
// Useful comment
[|AwaitableFunction();|]
}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
Task AwaitableFunction()
{
return Task.FromResult(true);
}
async void Test()
{
// Useful comment
await AwaitableFunction();
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void FunctionNotAwaited_WithLeadingTrivia1()
{
var initial =
@"using System;
using System.Threading.Tasks;
class Program
{
Task AwaitableFunction()
{
return Task.FromResult(true);
}
async void Test()
{
var i = 0;
[|AwaitableFunction();|]
}
}";
var expected =
@"using System;
using System.Threading.Tasks;
class Program
{
Task AwaitableFunction()
{
return Task.FromResult(true);
}
async void Test()
{
var i = 0;
await AwaitableFunction();
}
}";
Test(initial, expected, compareTokens: false);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TestAssignmentExpression() public void TestAssignmentExpression()
{ {
...@@ -199,6 +567,30 @@ public void TestAssignmentExpression8() ...@@ -199,6 +567,30 @@ public void TestAssignmentExpression8()
@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Func < Task > @delegate = delegate { int myInt = MyIntM [||] ethodAsync ( ) ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); @"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Func < Task > @delegate = delegate { int myInt = MyIntM [||] ethodAsync ( ) ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } ");
} }
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TestTernaryOperator()
{
Test(
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return [|true ? Task . FromResult ( 0 ) : Task . FromResult ( 1 )|] ; } } ",
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return await ( true ? Task . FromResult ( 0 ) : Task . FromResult ( 1 ) ) ; } } ");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TestNullCoalescingOperator()
{
Test(
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return [|null ?? Task . FromResult ( 1 )|] } } ",
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return await ( null ?? Task . FromResult ( 1 ) ) } } ");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)]
public void TestAsExpression()
{
Test(
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return [|null as Task < int >|] } } ",
@"using System ; using System . Threading . Tasks ; class Program { async Task < int > A ( ) { return await ( null as Task < int > ) } } ");
}
internal override Tuple<DiagnosticAnalyzer, CodeFixProvider> CreateDiagnosticProviderAndFixer(Workspace workspace) internal override Tuple<DiagnosticAnalyzer, CodeFixProvider> CreateDiagnosticProviderAndFixer(Workspace workspace)
{ {
return new Tuple<DiagnosticAnalyzer, CodeFixProvider>(null, new CSharpAddAwaitCodeFixProvider()); return new Tuple<DiagnosticAnalyzer, CodeFixProvider>(null, new CSharpAddAwaitCodeFixProvider());
......
...@@ -4,7 +4,7 @@ Imports Microsoft.CodeAnalysis.CodeFixes ...@@ -4,7 +4,7 @@ Imports Microsoft.CodeAnalysis.CodeFixes
Imports Microsoft.CodeAnalysis.Diagnostics Imports Microsoft.CodeAnalysis.Diagnostics
Imports Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async Imports Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async
Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.AddAsync Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.Async
Public Class AddAsyncTests Public Class AddAsyncTests
Inherits AbstractVisualBasicDiagnosticProviderBasedUserDiagnosticTest Inherits AbstractVisualBasicDiagnosticProviderBasedUserDiagnosticTest
......
...@@ -5,7 +5,7 @@ Imports Microsoft.CodeAnalysis.Diagnostics ...@@ -5,7 +5,7 @@ Imports Microsoft.CodeAnalysis.Diagnostics
Imports Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async Imports Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async
Imports Roslyn.Test.Utilities Imports Roslyn.Test.Utilities
Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.AddAsync Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.Async
Public Class AddAwaitTests Public Class AddAwaitTests
Inherits AbstractVisualBasicDiagnosticProviderBasedUserDiagnosticTest Inherits AbstractVisualBasicDiagnosticProviderBasedUserDiagnosticTest
...@@ -17,6 +17,35 @@ Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.AddAsy ...@@ -17,6 +17,35 @@ Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.AddAsy
) )
End Sub End Sub
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub TaskNotAwaited_WithLeadingTrivia()
Dim initial =
<File>
Imports System
Imports System.Threading.Tasks
Module Program
Async Sub M()
' Useful comment
[|Task.Delay(3)|]
End Sub
End Module
</File>
Dim expected =
<File>
Imports System
Imports System.Threading.Tasks
Module Program
Async Sub M()
' Useful comment
Await Task.Delay(3)
End Sub
End Module
</File>
Test(initial, expected, compareTokens:=False)
End Sub
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)> <Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub BadAsyncReturnOperand1() Public Sub BadAsyncReturnOperand1()
Dim initial = Dim initial =
...@@ -96,6 +125,50 @@ End Module ...@@ -96,6 +125,50 @@ End Module
Test(initial, expected) Test(initial, expected)
End Sub End Sub
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub FunctionNotAwaited_WithLeadingTrivia()
Dim initial =
<File>
Imports System
Imports System.Collections.Generic
Imports System.Linq
Imports System.Threading.Tasks
Module Program
Function AwaitableFunction() As Task
Return New Task()
End Function
Async Sub MySub()
' Useful comment
[|AwaitableFunction()|]
End Sub
End Module
</File>
Dim expected =
<File>
Imports System
Imports System.Collections.Generic
Imports System.Linq
Imports System.Threading.Tasks
Module Program
Function AwaitableFunction() As Task
Return New Task()
End Function
Async Sub MySub()
' Useful comment
Await AwaitableFunction()
End Sub
End Module
</File>
Test(initial, expected, compareTokens:=True)
End Sub
<Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)> <Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub SubLambdaNotAwaited() Public Sub SubLambdaNotAwaited()
Dim initial = Dim initial =
...@@ -144,6 +217,7 @@ Imports System.Threading.Tasks ...@@ -144,6 +217,7 @@ Imports System.Threading.Tasks
Module Program Module Program
Sub MySub() Sub MySub()
Dim a = Async Function() Dim a = Async Function()
' Useful comment
[|Task.Delay(1)|] [|Task.Delay(1)|]
End Function End Function
End Sub End Sub
...@@ -159,13 +233,14 @@ Imports System.Threading.Tasks ...@@ -159,13 +233,14 @@ Imports System.Threading.Tasks
Module Program Module Program
Sub MySub() Sub MySub()
Dim a = Async Function() Dim a = Async Function()
' Useful comment
Await Task.Delay(1) Await Task.Delay(1)
End Function End Function
End Sub End Sub
End Module End Module
</File> </File>
Test(initial, expected) Test(initial, expected, compareTokens:=True)
End Sub End Sub
<Fact(), Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)> <Fact(), Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
...@@ -216,6 +291,27 @@ NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1A ...@@ -216,6 +291,27 @@ NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1A
NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1Async() \n Dim myInt As Long \n Dim lambda = Async Sub() myInt = Await MyIntMethodAsync() \n End Sub \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module")) NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1Async() \n Dim myInt As Long \n Dim lambda = Async Sub() myInt = Await MyIntMethodAsync() \n End Sub \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module"))
End Sub End Sub
<Fact(), Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub TestTernaryOperator()
Test(
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return [|If(True, Task.FromResult(0), Task.FromResult(1))|] \n End Function \n End Module"),
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return Await If(True, Task.FromResult(0), Task.FromResult(1)) \n End Function \n End Module"))
End Sub
<Fact(), Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub TestTernaryOperator2()
Test(
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return [|If(Nothing, Task.FromResult(1))|] \n End Function \n End Module"),
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return Await If(Nothing, Task.FromResult(1)) \n End Function \n End Module"))
End Sub
<Fact(), Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)>
Public Sub TestCastExpression()
Test(
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return [|TryCast(Nothing, Task(Of Integer)|] \n End Function \n End Module"),
NewLines("Imports System.Threading.Tasks \n Module M \n Async Function A() As Task(Of Integer) \n Return Await TryCast(Nothing, Task(Of Integer) \n End Function \n End Module"))
End Sub
Friend Overrides Function CreateDiagnosticProviderAndFixer(workspace As Workspace) As Tuple(Of DiagnosticAnalyzer, CodeFixProvider) Friend Overrides Function CreateDiagnosticProviderAndFixer(workspace As Workspace) As Tuple(Of DiagnosticAnalyzer, CodeFixProvider)
Return Tuple.Create(Of DiagnosticAnalyzer, CodeFixProvider)( Return Tuple.Create(Of DiagnosticAnalyzer, CodeFixProvider)(
Nothing, Nothing,
......
...@@ -4,7 +4,7 @@ Imports Microsoft.CodeAnalysis.CodeFixes ...@@ -4,7 +4,7 @@ Imports Microsoft.CodeAnalysis.CodeFixes
Imports Microsoft.CodeAnalysis.Diagnostics Imports Microsoft.CodeAnalysis.Diagnostics
Imports Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async Imports Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async
Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.AddAsync Namespace Microsoft.CodeAnalysis.Editor.VisualBasic.UnitTests.Diagnostics.Async
Public Class ChangeToAsyncTests Public Class ChangeToAsyncTests
Inherits AbstractVisualBasicDiagnosticProviderBasedUserDiagnosticTest Inherits AbstractVisualBasicDiagnosticProviderBasedUserDiagnosticTest
......
...@@ -3,16 +3,18 @@ ...@@ -3,16 +3,18 @@
using System; using System;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Composition; using System.Composition;
using System.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeFixes; using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CodeFixes.Async; using Microsoft.CodeAnalysis.CodeFixes.Async;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Formatting; using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.LanguageServices;
using Microsoft.CodeAnalysis.Simplification;
using Roslyn.Utilities; using Roslyn.Utilities;
using Resources = Microsoft.CodeAnalysis.CSharp.CSharpFeaturesResources; using Resources = Microsoft.CodeAnalysis.CSharp.CSharpFeaturesResources;
using Microsoft.CodeAnalysis.LanguageServices;
using System.Linq;
namespace Microsoft.CodeAnalysis.CSharp.CodeFixes.Async namespace Microsoft.CodeAnalysis.CSharp.CodeFixes.Async
{ {
...@@ -20,12 +22,12 @@ namespace Microsoft.CodeAnalysis.CSharp.CodeFixes.Async ...@@ -20,12 +22,12 @@ namespace Microsoft.CodeAnalysis.CSharp.CodeFixes.Async
internal class CSharpAddAwaitCodeFixProvider : AbstractAddAsyncAwaitCodeFixProvider internal class CSharpAddAwaitCodeFixProvider : AbstractAddAsyncAwaitCodeFixProvider
{ {
/// <summary> /// <summary>
/// Since this is an async method, the return expression must be of type 'blah' rather than 'baz' /// Because this call is not awaited, execution of the current method continues before the call is completed.
/// </summary> /// </summary>
private const string CS4014 = "CS4014"; private const string CS4014 = "CS4014";
/// <summary> /// <summary>
/// Because this call is not awaited, execution of the current method continues before the call is completed. /// Since this is an async method, the return expression must be of type 'blah' rather than 'baz'
/// </summary> /// </summary>
private const string CS4016 = "CS4016"; private const string CS4016 = "CS4016";
...@@ -48,15 +50,14 @@ internal class CSharpAddAwaitCodeFixProvider : AbstractAddAsyncAwaitCodeFixProvi ...@@ -48,15 +50,14 @@ internal class CSharpAddAwaitCodeFixProvider : AbstractAddAsyncAwaitCodeFixProvi
CancellationToken cancellationToken) CancellationToken cancellationToken)
{ {
var expression = oldNode as ExpressionSyntax; var expression = oldNode as ExpressionSyntax;
if (expression == null)
{
return SpecializedTasks.Default<SyntaxNode>();
}
switch (diagnostic.Id) switch (diagnostic.Id)
{ {
case CS4014: case CS4014:
if (expression == null)
{
return Task.FromResult<SyntaxNode>(null);
}
return Task.FromResult(root.ReplaceNode(oldNode, ConvertToAwaitExpression(expression))); return Task.FromResult(root.ReplaceNode(oldNode, ConvertToAwaitExpression(expression)));
case CS4016: case CS4016:
...@@ -74,33 +75,27 @@ internal class CSharpAddAwaitCodeFixProvider : AbstractAddAsyncAwaitCodeFixProvi ...@@ -74,33 +75,27 @@ internal class CSharpAddAwaitCodeFixProvider : AbstractAddAsyncAwaitCodeFixProvi
} }
return Task.FromResult(root.ReplaceNode(oldNode, ConvertToAwaitExpression(expression))); return Task.FromResult(root.ReplaceNode(oldNode, ConvertToAwaitExpression(expression)));
default: default:
return SpecializedTasks.Default<SyntaxNode>(); return SpecializedTasks.Default<SyntaxNode>();
} }
} }
private static bool DoesExpressionReturnTask(ExpressionSyntax expression, SemanticModel semanticModel) private static bool DoesExpressionReturnTask(ExpressionSyntax expression, SemanticModel semanticModel)
{ {
if (expression == null) INamedTypeSymbol taskType = null;
if (!TryGetTaskType(semanticModel, out taskType))
{ {
return false; return false;
} }
INamedTypeSymbol taskType = null;
INamedTypeSymbol returnType = null; INamedTypeSymbol returnType = null;
return TryGetTaskAndExpressionTypes(expression, semanticModel, out taskType, out returnType) && return TryGetExpressionType(expression, semanticModel, out returnType) &&
semanticModel.Compilation.ClassifyConversion(taskType, returnType).Exists; semanticModel.Compilation.ClassifyConversion(taskType, returnType).Exists;
} }
private static bool DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(ExpressionSyntax expression, SemanticModel semanticModel, Project project, CancellationToken cancellationToken) private static bool DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(ExpressionSyntax expression, SemanticModel semanticModel, Project project, CancellationToken cancellationToken)
{ {
if (expression == null)
{
return false;
}
if (!IsInAsyncFunction(expression)) if (!IsInAsyncFunction(expression))
{ {
return false; return false;
...@@ -108,7 +103,8 @@ private static bool DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(E ...@@ -108,7 +103,8 @@ private static bool DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(E
INamedTypeSymbol taskType = null; INamedTypeSymbol taskType = null;
INamedTypeSymbol rightSideType = null; INamedTypeSymbol rightSideType = null;
if (!TryGetTaskAndExpressionTypes(expression, semanticModel, out taskType, out rightSideType)) if (!TryGetTaskType(semanticModel, out taskType) ||
!TryGetExpressionType(expression, semanticModel, out rightSideType))
{ {
return false; return false;
} }
...@@ -119,7 +115,7 @@ private static bool DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(E ...@@ -119,7 +115,7 @@ private static bool DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(E
return false; return false;
} }
if(!rightSideType.IsGenericType) if (!rightSideType.IsGenericType)
{ {
return false; return false;
} }
...@@ -150,11 +146,23 @@ private static bool IsInAsyncFunction(ExpressionSyntax expression) ...@@ -150,11 +146,23 @@ private static bool IsInAsyncFunction(ExpressionSyntax expression)
return false; return false;
} }
private static ExpressionSyntax ConvertToAwaitExpression(ExpressionSyntax expression) private static SyntaxNode ConvertToAwaitExpression(ExpressionSyntax expression)
{ {
return SyntaxFactory.AwaitExpression(expression) if ((expression is BinaryExpressionSyntax || expression is ConditionalExpressionSyntax) && expression.HasTrailingTrivia)
{
var expWithTrailing = expression.WithoutLeadingTrivia();
var span = expWithTrailing.GetLocation().GetLineSpan().Span;
if (span.Start.Line == span.End.Line && !expWithTrailing.DescendantTrivia().Any(trivia => trivia.IsKind(SyntaxKind.SingleLineCommentTrivia)))
{
return SyntaxFactory.AwaitExpression(SyntaxFactory.ParenthesizedExpression(expWithTrailing))
.WithLeadingTrivia(expression.GetLeadingTrivia())
.WithAdditionalAnnotations(Formatter.Annotation);
}
}
return SyntaxFactory.AwaitExpression(expression.WithoutTrivia().Parenthesize())
.WithTriviaFrom(expression) .WithTriviaFrom(expression)
.WithAdditionalAnnotations(Formatter.Annotation); .WithAdditionalAnnotations(Simplifier.Annotation, Formatter.Annotation);
} }
} }
} }
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions; using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.Internal.Log;
namespace Microsoft.CodeAnalysis.CodeFixes.Async namespace Microsoft.CodeAnalysis.CodeFixes.Async
{ {
...@@ -28,38 +27,21 @@ protected override async Task<CodeAction> GetCodeFix(SyntaxNode root, SyntaxNode ...@@ -28,38 +27,21 @@ protected override async Task<CodeAction> GetCodeFix(SyntaxNode root, SyntaxNode
return null; return null;
} }
protected static bool TryGetTaskAndExpressionTypes( protected static bool TryGetExpressionType(
SyntaxNode expression, SyntaxNode expression,
SemanticModel semanticModel, SemanticModel semanticModel,
out INamedTypeSymbol source, out INamedTypeSymbol returnType)
out INamedTypeSymbol destination)
{ {
source = null; var typeInfo = semanticModel.GetTypeInfo(expression);
destination = null; returnType = typeInfo.Type as INamedTypeSymbol;
return returnType != null;
var info = semanticModel.GetSymbolInfo(expression); }
var methodSymbol = info.Symbol as IMethodSymbol;
if (methodSymbol == null)
{
return false;
}
protected static bool TryGetTaskType(SemanticModel semanticModel, out INamedTypeSymbol taskType)
{
var compilation = semanticModel.Compilation; var compilation = semanticModel.Compilation;
var taskType = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task"); taskType = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task");
if (taskType == null) return taskType != null;
{
return false;
}
var returnType = methodSymbol.ReturnType as INamedTypeSymbol;
if (returnType == null)
{
return false;
}
source = taskType;
destination = returnType;
return true;
} }
private class MyCodeAction : CodeAction.DocumentChangeAction private class MyCodeAction : CodeAction.DocumentChangeAction
......
...@@ -8,6 +8,7 @@ Imports Microsoft.CodeAnalysis.CodeFixes ...@@ -8,6 +8,7 @@ Imports Microsoft.CodeAnalysis.CodeFixes
Imports Microsoft.CodeAnalysis.CodeFixes.Async Imports Microsoft.CodeAnalysis.CodeFixes.Async
Imports Microsoft.CodeAnalysis.Formatting Imports Microsoft.CodeAnalysis.Formatting
Imports Microsoft.CodeAnalysis.LanguageServices Imports Microsoft.CodeAnalysis.LanguageServices
Imports Microsoft.CodeAnalysis.Simplification
Imports Microsoft.CodeAnalysis.VisualBasic.Syntax Imports Microsoft.CodeAnalysis.VisualBasic.Syntax
Imports Resources = Microsoft.CodeAnalysis.VisualBasic.VBFeaturesResources.VBFeaturesResources Imports Resources = Microsoft.CodeAnalysis.VisualBasic.VBFeaturesResources.VBFeaturesResources
...@@ -35,40 +36,37 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async ...@@ -35,40 +36,37 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async
Protected Overrides Function GetNewRoot(root As SyntaxNode, oldNode As SyntaxNode, semanticModel As SemanticModel, diagnostic As Diagnostic, document As Document, cancellationToken As CancellationToken) As Task(Of SyntaxNode) Protected Overrides Function GetNewRoot(root As SyntaxNode, oldNode As SyntaxNode, semanticModel As SemanticModel, diagnostic As Diagnostic, document As Document, cancellationToken As CancellationToken) As Task(Of SyntaxNode)
Dim expression = TryCast(oldNode, ExpressionSyntax) Dim expression = TryCast(oldNode, ExpressionSyntax)
If expression Is Nothing Then
Return SpecializedTasks.Default(Of SyntaxNode)()
End If
Select Case diagnostic.Id Select Case diagnostic.Id
Case BC30311 Case BC30311
If Not DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(expression, semanticModel, document.Project, cancellationToken) Then If Not DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(expression, semanticModel, document.Project, cancellationToken) Then
Return Task.FromResult(Of SyntaxNode)(Nothing) Return Task.FromResult(Of SyntaxNode)(Nothing)
End If End If
Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression))) Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression, semanticModel, cancellationToken)))
Case BC37055 Case BC37055
If Not DoesExpressionReturnTask(expression, semanticModel) Then If Not DoesExpressionReturnTask(expression, semanticModel) Then
Return Task.FromResult(Of SyntaxNode)(Nothing) Return Task.FromResult(Of SyntaxNode)(Nothing)
End If End If
Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression))) Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression, semanticModel, cancellationToken)))
Case BC42358 Case BC42358
If expression Is Nothing Then Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression, semanticModel, cancellationToken)))
Return Task.FromResult(Of SyntaxNode)(Nothing)
End If
Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression)))
Case Else Case Else
Return Task.FromResult(Of SyntaxNode)(Nothing) Return SpecializedTasks.Default(Of SyntaxNode)()
End Select End Select
End Function End Function
Private Function DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(expression As ExpressionSyntax, semanticModel As SemanticModel, project As Project, cancellationToken As CancellationToken) As Boolean Private Function DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(expression As ExpressionSyntax, semanticModel As SemanticModel, project As Project, cancellationToken As CancellationToken) As Boolean
If expression Is Nothing Then
Return False
End If
If Not IsInAsyncBlock(expression) Then If Not IsInAsyncBlock(expression) Then
Return False Return False
End If End If
Dim taskType As INamedTypeSymbol = Nothing Dim taskType As INamedTypeSymbol = Nothing
Dim rightSideType As INamedTypeSymbol = Nothing Dim rightSideType As INamedTypeSymbol = Nothing
If Not TryGetTaskAndExpressionTypes(expression, semanticModel, taskType, rightSideType) Then If Not TryGetTaskType(semanticModel, taskType) OrElse
Not TryGetExpressionType(expression, semanticModel, rightSideType) Then
Return False Return False
End If End If
...@@ -109,19 +107,35 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async ...@@ -109,19 +107,35 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async
End Function End Function
Private Function DoesExpressionReturnTask(expression As ExpressionSyntax, semanticModel As SemanticModel) As Boolean Private Function DoesExpressionReturnTask(expression As ExpressionSyntax, semanticModel As SemanticModel) As Boolean
If expression Is Nothing Then
Return Nothing
End If
Dim taskType As INamedTypeSymbol = Nothing Dim taskType As INamedTypeSymbol = Nothing
Dim returnType As INamedTypeSymbol = Nothing Dim returnType As INamedTypeSymbol = Nothing
Return TryGetTaskAndExpressionTypes(expression, semanticModel, taskType, returnType) AndAlso Return TryGetTaskType(semanticModel, taskType) AndAlso
TryGetExpressionType(expression, semanticModel, returnType) AndAlso
semanticModel.Compilation.ClassifyConversion(taskType, returnType).Exists semanticModel.Compilation.ClassifyConversion(taskType, returnType).Exists
End Function End Function
Private Function ConverToAwaitExpression(expression As ExpressionSyntax) As ExpressionSyntax Private Shared Function ConverToAwaitExpression(expression As ExpressionSyntax, semanticModel As SemanticModel, cancellationToken As CancellationToken) As ExpressionSyntax
Return SyntaxFactory.AwaitExpression(expression).WithAdditionalAnnotations(Formatter.Annotation) Dim root = expression.Ancestors().Last()
If Not RequiresParenthesis(expression, root, semanticModel, cancellationToken) Then
expression = expression.Parenthesize()
End If
Return SyntaxFactory.AwaitExpression(expression.WithoutTrivia()) _
.WithTriviaFrom(expression) _
.WithAdditionalAnnotations(Formatter.Annotation, Simplifier.Annotation)
End Function End Function
Private Shared Function RequiresParenthesis(expression As ExpressionSyntax, root As SyntaxNode, semanticModel As SemanticModel, cancellationToken As CancellationToken) As Boolean
Dim parenthesizedExpression = SyntaxFactory.ParenthesizedExpression(expression)
Dim newRoot = root.ReplaceNode(expression, parenthesizedExpression)
Dim newNode = newRoot.FindNode(expression.Span)
Dim result = newNode _
.DescendantNodesAndSelf(Function(n) n.Kind <> SyntaxKind.ParenthesizedExpression) _
.OfType(Of ParenthesizedExpressionSyntax).FirstOrDefault
If result IsNot Nothing Then
Return Not result.CanRemoveParentheses(semanticModel, cancellationToken)
End If
Return False
End Function
End Class End Class
End Namespace End Namespace
...@@ -25,7 +25,14 @@ public static class TokenUtilities ...@@ -25,7 +25,14 @@ public static class TokenUtilities
for (var i = 0; i < Math.Min(expectedTokens.Count, actualTokens.Count); i++) for (var i = 0; i < Math.Min(expectedTokens.Count, actualTokens.Count); i++)
{ {
Assert.Equal(expectedTokens[i].ToString(), actualTokens[i].ToString()); var expectedToken = expectedTokens[i].ToString();
var actualToken = actualTokens[i].ToString();
if (!String.Equals(expectedToken, actualToken))
{
var prev = (i - 1 > -1) ? actualTokens[i - 1].ToString() : "^";
var next = (i + 1 < actualTokens.Count) ? actualTokens[i + 1].ToString() : "$";
AssertEx.Fail($"Unexpected token at index {i} near \"{prev} {actualToken} {next}\". Expected '{expectedToken}', Actual '{actualToken}'");
}
} }
if (expectedTokens.Count != actualTokens.Count) if (expectedTokens.Count != actualTokens.Count)
......
...@@ -43,27 +43,22 @@ public static ExpressionSyntax WalkDownParentheses(this ExpressionSyntax express ...@@ -43,27 +43,22 @@ public static ExpressionSyntax WalkDownParentheses(this ExpressionSyntax express
public static ExpressionSyntax Parenthesize(this ExpressionSyntax expression, bool includeElasticTrivia = true) public static ExpressionSyntax Parenthesize(this ExpressionSyntax expression, bool includeElasticTrivia = true)
{ {
var leadingTrivia = expression.GetLeadingTrivia();
var trailingTrivia = expression.GetTrailingTrivia();
expression = expression.WithoutLeadingTrivia()
.WithoutTrailingTrivia();
if (includeElasticTrivia) if (includeElasticTrivia)
{ {
return SyntaxFactory.ParenthesizedExpression(expression) return SyntaxFactory.ParenthesizedExpression(expression.WithoutTrivia())
.WithLeadingTrivia(leadingTrivia) .WithTriviaFrom(expression)
.WithTrailingTrivia(trailingTrivia) .WithAdditionalAnnotations(Simplifier.Annotation);
.WithAdditionalAnnotations(Simplifier.Annotation);
} }
else else
{ {
return SyntaxFactory.ParenthesizedExpression( return SyntaxFactory.ParenthesizedExpression
(
SyntaxFactory.Token(SyntaxTriviaList.Empty, SyntaxKind.OpenParenToken, SyntaxTriviaList.Empty), SyntaxFactory.Token(SyntaxTriviaList.Empty, SyntaxKind.OpenParenToken, SyntaxTriviaList.Empty),
expression, expression.WithoutTrivia(),
SyntaxFactory.Token(SyntaxTriviaList.Empty, SyntaxKind.CloseParenToken, SyntaxTriviaList.Empty)) SyntaxFactory.Token(SyntaxTriviaList.Empty, SyntaxKind.CloseParenToken, SyntaxTriviaList.Empty)
.WithLeadingTrivia(leadingTrivia) )
.WithTrailingTrivia(trailingTrivia) .WithTriviaFrom(expression)
.WithAdditionalAnnotations(Simplifier.Annotation); .WithAdditionalAnnotations(Simplifier.Annotation);
} }
} }
...@@ -2338,7 +2333,7 @@ public static OperatorPrecedence GetOperatorPrecedence(this ExpressionSyntax exp ...@@ -2338,7 +2333,7 @@ public static OperatorPrecedence GetOperatorPrecedence(this ExpressionSyntax exp
case SyntaxKind.UncheckedExpression: case SyntaxKind.UncheckedExpression:
case SyntaxKind.AnonymousMethodExpression: case SyntaxKind.AnonymousMethodExpression:
// From C# spec, 7.3.1: // From C# spec, 7.3.1:
// Primary: x.y f(x) a[x] x++ x-- new typeof default checked unchecked delegate // Primary: x.y x?.y x?[y] f(x) a[x] x++ x-- new typeof default checked unchecked delegate
return OperatorPrecedence.Primary; return OperatorPrecedence.Primary;
...@@ -2349,8 +2344,9 @@ public static OperatorPrecedence GetOperatorPrecedence(this ExpressionSyntax exp ...@@ -2349,8 +2344,9 @@ public static OperatorPrecedence GetOperatorPrecedence(this ExpressionSyntax exp
case SyntaxKind.PreIncrementExpression: case SyntaxKind.PreIncrementExpression:
case SyntaxKind.PreDecrementExpression: case SyntaxKind.PreDecrementExpression:
case SyntaxKind.CastExpression: case SyntaxKind.CastExpression:
case SyntaxKind.AwaitExpression:
// From C# spec, 7.3.1: // From C# spec, 7.3.1:
// Unary: + - ! ~ ++x --x (T)x // Unary: + - ! ~ ++x --x (T)x await Task
return OperatorPrecedence.Unary; return OperatorPrecedence.Unary;
......
...@@ -52,15 +52,9 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.Extensions ...@@ -52,15 +52,9 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.Extensions
<Extension()> <Extension()>
Public Function Parenthesize(expression As ExpressionSyntax) As ParenthesizedExpressionSyntax Public Function Parenthesize(expression As ExpressionSyntax) As ParenthesizedExpressionSyntax
Dim leadingTrivia = expression.GetLeadingTrivia() Return SyntaxFactory.ParenthesizedExpression(expression.WithoutTrivia()) _
Dim trailingTrivia = expression.GetTrailingTrivia() .WithTriviaFrom(expression) _
.WithAdditionalAnnotations(Simplifier.Annotation)
Dim strippedExpression = expression.WithoutLeadingTrivia().WithoutTrailingTrivia()
Return SyntaxFactory.ParenthesizedExpression(strippedExpression) _
.WithLeadingTrivia(leadingTrivia) _
.WithTrailingTrivia(trailingTrivia) _
.WithAdditionalAnnotations(Simplifier.Annotation)
End Function End Function
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册