diff --git a/src/EditorFeatures/CSharpTest/Diagnostics/Async/AddAwaitTests.cs b/src/EditorFeatures/CSharpTest/Diagnostics/Async/AddAwaitTests.cs index 15c3c365b25566c8d558f50f1b5a19fea60b440b..59d64ba59dffbf38f1b347b476113caa10c50b6f 100644 --- a/src/EditorFeatures/CSharpTest/Diagnostics/Async/AddAwaitTests.cs +++ b/src/EditorFeatures/CSharpTest/Diagnostics/Async/AddAwaitTests.cs @@ -108,6 +108,97 @@ async void Test() Test(initial, expected); } + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void TestAssignmentExpression() + { + Test( +@"using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { int myInt = [|MyIntMethodAsync ( )|] ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } ", +@"using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { int myInt = await MyIntMethodAsync ( ) ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void TestAssignmentExpressionWithConversion() + { + Test( +@"using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { long myInt = [|MyIntMethodAsync ( )|] ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } ", +@"using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { long myInt = await MyIntMethodAsync ( ) ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void TestAssignmentExpressionWithConversionInNonAsyncFunction() + { + TestMissing( +@"using System . Threading . Tasks ; class TestClass { private Task MyTestMethod1Async ( ) { long myInt = [|MyIntMethodAsync ( )|] ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void TestAssignmentExpressionWithConversionInAsyncFunction() + { + Test( +@"using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { long myInt = [|MyIntMethodAsync ( )|] ; } private Task < object > MyIntMethodAsync ( ) { return Task . FromResult ( new object ( ) ) ; } } ", +@"using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { long myInt = await MyIntMethodAsync ( ) ; } private Task < object > MyIntMethodAsync ( ) { return Task . FromResult ( new object ( ) ) ; } } "); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void Test() + { + Test( +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Action lambda = async ( ) => { int myInt = [|MyIntMethodAsync ( )|] ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } ", +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Action lambda = async ( ) => { int myInt = await MyIntMethodAsync ( ) ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void TestAssignmentExpression2() + { + Test( +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Func < Task > lambda = async ( ) => { int myInt = [|MyIntMethodAsync ( )|] ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } ", +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Func < Task > lambda = async ( ) => { int myInt = await MyIntMethodAsync ( ) ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void TestAssignmentExpression3() + { + TestMissing( +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Func < Task > lambda = ( ) => { int myInt = MyInt [||] MethodAsync ( ) ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void TestAssignmentExpression4() + { + TestMissing( +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Action lambda = ( ) => { int myInt = MyIntM [||] ethodAsync ( ) ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void TestAssignmentExpression5() + { + Test( +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Action @delegate = async delegate { int myInt = [|MyIntMethodAsync ( )|] ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } ", +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Action @delegate = async delegate { int myInt = await MyIntMethodAsync ( ) ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void TestAssignmentExpression6() + { + Test( +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Func < Task > @delegate = async delegate { int myInt = [|MyIntMethodAsync ( )|] ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } ", +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Func < Task > @delegate = async delegate { int myInt = await MyIntMethodAsync ( ) ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void TestAssignmentExpression7() + { + TestMissing( +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Action @delegate = delegate { int myInt = MyInt [||] MethodAsync ( ) ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsAddAwait)] + public void TestAssignmentExpression8() + { + TestMissing( +@"using System ; using System . Threading . Tasks ; class TestClass { private async Task MyTestMethod1Async ( ) { Func < Task > @delegate = delegate { int myInt = MyIntM [||] ethodAsync ( ) ; } ; } private Task < int > MyIntMethodAsync ( ) { return Task . FromResult ( result : 1 ) ; } } "); + } + internal override Tuple CreateDiagnosticProviderAndFixer(Workspace workspace) { return new Tuple(null, new CSharpAddAwaitCodeFixProvider()); diff --git a/src/EditorFeatures/VisualBasicTest/Diagnostics/Async/AddAwaitTests.vb b/src/EditorFeatures/VisualBasicTest/Diagnostics/Async/AddAwaitTests.vb index 5946540664269007e52a008451159eb76cd1c8ae..1283cce8947b07b4157bc4652d36da9a2ab2c495 100644 --- a/src/EditorFeatures/VisualBasicTest/Diagnostics/Async/AddAwaitTests.vb +++ b/src/EditorFeatures/VisualBasicTest/Diagnostics/Async/AddAwaitTests.vb @@ -168,6 +168,54 @@ End Module Test(initial, expected) End Sub + + Public Sub TestAddAwaitOnAssignment() + Test( +NewLines("Imports System.Threading.Tasks \n Module Program \n Async Function MyTestMethod1Async() As Task \n Dim myInt As Integer = [|MyIntMethodAsync()|] \n End Function \n Private Function MyIntMethodAsync() As Task(Of Integer) \n Return Task.FromResult(1) \n End Function \n End Module"), +NewLines("Imports System.Threading.Tasks \n Module Program \n Async Function MyTestMethod1Async() As Task \n Dim myInt As Integer = Await MyIntMethodAsync() \n End Function \n Private Function MyIntMethodAsync() As Task(Of Integer) \n Return Task.FromResult(1) \n End Function \n End Module")) + End Sub + + + Public Sub TestAddAwaitOnAssignment2() + Test( +NewLines("Imports System.Threading.Tasks \n Module Program \n Async Function MyTestMethod1Async() As Task \n Dim myInt As Long = [|MyIntMethodAsync()|] \n End Function \n Private Function MyIntMethodAsync() As Task(Of Integer) \n Return Task.FromResult(1) \n End Function \n End Module"), +NewLines("Imports System.Threading.Tasks \n Module Program \n Async Function MyTestMethod1Async() As Task \n Dim myInt As Long = Await MyIntMethodAsync() \n End Function \n Private Function MyIntMethodAsync() As Task(Of Integer) \n Return Task.FromResult(1) \n End Function \n End Module")) + End Sub + + + Public Sub TestAddAwaitOnAssignment3() + TestMissing( +NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1Async() \n Dim myInt As Long = MyInt[||]MethodAsync() \n End Sub \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module")) + End Sub + + + Public Sub TestAddAwaitOnAssignment4() + Test( +NewLines("Imports System.Threading.Tasks \n Module Program \n Async Function MyTestMethod1Async() As Task \n Dim myInt As Long = [|MyIntMethodAsync()|] \n End Function \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module"), +NewLines("Imports System.Threading.Tasks \n Module Program \n Async Function MyTestMethod1Async() As Task \n Dim myInt As Long = Await MyIntMethodAsync() \n End Function \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module")) + End Sub + + + Public Sub TestAddAwaitOnAssignment5() + Test( +NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1Async() \n Dim lambda = Async Sub() \n Dim myInt As Long = [|MyIntMethodAsync()|] \n End Sub \n End Sub \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module"), +NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1Async() \n Dim lambda = Async Sub() \n Dim myInt As Long = Await MyIntMethodAsync() \n End Sub \n End Sub \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module")) + End Sub + + + Public Sub TestAddAwaitOnAssignment6() + Test( +NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1Async() \n Dim lambda = Async Function() As Task \n Dim myInt As Long = [|MyIntMethodAsync()|] \n End Function \n End Sub \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module"), +NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1Async() \n Dim lambda = Async Function() As Task \n Dim myInt As Long = Await MyIntMethodAsync() \n End Function \n End Sub \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module")) + End Sub + + + Public Sub TestAddAwaitOnAssignment7() + Test( +NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1Async() \n Dim myInt As Long \n Dim lambda = Async Sub() myInt = [|MyIntMethodAsync()|] \n End Sub \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module"), +NewLines("Imports System.Threading.Tasks \n Module Program \n Sub MyTestMethod1Async() \n Dim myInt As Long \n Dim lambda = Async Sub() myInt = Await MyIntMethodAsync() \n End Sub \n Private Function MyIntMethodAsync() As Task(Of Object) \n Return Task.FromResult(New Object()) \n End Function \n End Module")) + End Sub + Friend Overrides Function CreateDiagnosticProviderAndFixer(workspace As Workspace) As Tuple(Of DiagnosticAnalyzer, CodeFixProvider) Return Tuple.Create(Of DiagnosticAnalyzer, CodeFixProvider)( Nothing, diff --git a/src/Features/CSharp/Portable/CodeFixes/Async/CSharpAddAwaitCodeFixProvider.cs b/src/Features/CSharp/Portable/CodeFixes/Async/CSharpAddAwaitCodeFixProvider.cs index bee42e5f99670c2bbf31aebe701b88f67b598e96..723a85356ccc347f0c5993edfeb9849ee9eb2e44 100644 --- a/src/Features/CSharp/Portable/CodeFixes/Async/CSharpAddAwaitCodeFixProvider.cs +++ b/src/Features/CSharp/Portable/CodeFixes/Async/CSharpAddAwaitCodeFixProvider.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Immutable; using System.Composition; using System.Threading; @@ -10,6 +11,8 @@ using Microsoft.CodeAnalysis.Formatting; using Roslyn.Utilities; using Resources = Microsoft.CodeAnalysis.CSharp.CSharpFeaturesResources; +using Microsoft.CodeAnalysis.LanguageServices; +using System.Linq; namespace Microsoft.CodeAnalysis.CSharp.CodeFixes.Async { @@ -26,17 +29,23 @@ internal class CSharpAddAwaitCodeFixProvider : AbstractAddAsyncAwaitCodeFixProvi /// private const string CS4016 = "CS4016"; - public override ImmutableArray FixableDiagnosticIds - { - get { return ImmutableArray.Create(CS4014, CS4016); } - } + /// + /// cannot implicitly convert from 'X' to 'Y'. + /// + private const string CS0029 = "CS0029"; - protected override string GetDescription(Diagnostic diagnostic, SyntaxNode node, SemanticModel semanticModel, CancellationToken cancellationToken) - { - return Resources.InsertAwait; - } + public override ImmutableArray FixableDiagnosticIds => ImmutableArray.Create(CS0029, CS4014, CS4016); - protected override Task GetNewRoot(SyntaxNode root, SyntaxNode oldNode, SemanticModel semanticModel, Diagnostic diagnostic, Document document, CancellationToken cancellationToken) + + protected override string GetDescription(Diagnostic diagnostic, SyntaxNode node, SemanticModel semanticModel, CancellationToken cancellationToken) => Resources.InsertAwait; + + protected override Task GetNewRoot( + SyntaxNode root, + SyntaxNode oldNode, + SemanticModel semanticModel, + Diagnostic diagnostic, + Document document, + CancellationToken cancellationToken) { var expression = oldNode as ExpressionSyntax; @@ -49,13 +58,17 @@ protected override Task GetNewRoot(SyntaxNode root, SyntaxNode oldNo } return Task.FromResult(root.ReplaceNode(oldNode, ConvertToAwaitExpression(expression))); + case CS4016: - if (expression == null) + if (!DoesExpressionReturnTask(expression, semanticModel)) { return SpecializedTasks.Default(); } - if (!IsCorrectReturnType(expression, semanticModel)) + return Task.FromResult(root.ReplaceNode(oldNode, ConvertToAwaitExpression(expression))); + + case CS0029: + if (!DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(expression, semanticModel, document.Project, cancellationToken)) { return SpecializedTasks.Default(); } @@ -66,17 +79,81 @@ protected override Task GetNewRoot(SyntaxNode root, SyntaxNode oldNo } } - private bool IsCorrectReturnType(ExpressionSyntax expression, SemanticModel semanticModel) + + + private static bool DoesExpressionReturnTask(ExpressionSyntax expression, SemanticModel semanticModel) { + if (expression == null) + { + return false; + } + INamedTypeSymbol taskType = null; INamedTypeSymbol returnType = null; - return TryGetTypes(expression, semanticModel, out taskType, out returnType) && + return TryGetTaskAndExpressionTypes(expression, semanticModel, out taskType, out returnType) && semanticModel.Compilation.ClassifyConversion(taskType, returnType).Exists; } + private static bool DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(ExpressionSyntax expression, SemanticModel semanticModel, Project project, CancellationToken cancellationToken) + { + if (expression == null) + { + return false; + } + + if (!IsInAsyncFunction(expression)) + { + return false; + } + + INamedTypeSymbol taskType = null; + INamedTypeSymbol rightSideType = null; + if (!TryGetTaskAndExpressionTypes(expression, semanticModel, out taskType, out rightSideType)) + { + return false; + } + + var compilation = semanticModel.Compilation; + if (!compilation.ClassifyConversion(taskType, rightSideType).Exists) + { + return false; + } + + if(!rightSideType.IsGenericType) + { + return false; + } + + var typeArguments = rightSideType.TypeArguments; + var typeInferer = project.LanguageServices.GetService(); + var inferredTypes = typeInferer.InferTypes(semanticModel, expression, cancellationToken); + return typeArguments.Any(ta => inferredTypes.Any(it => compilation.ClassifyConversion(it, ta).Exists)); + } + + private static bool IsInAsyncFunction(ExpressionSyntax expression) + { + foreach (var node in expression.Ancestors()) + { + switch (node.Kind()) + { + case SyntaxKind.ParenthesizedLambdaExpression: + case SyntaxKind.SimpleLambdaExpression: + case SyntaxKind.AnonymousMethodExpression: + return (node as AnonymousFunctionExpressionSyntax)?.AsyncKeyword.IsMissing == true; + case SyntaxKind.MethodDeclaration: + return (node as MethodDeclarationSyntax)?.Modifiers.Any(SyntaxKind.AsyncKeyword) == true; + default: + continue; + } + } + + return false; + } + private static ExpressionSyntax ConvertToAwaitExpression(ExpressionSyntax expression) { return SyntaxFactory.AwaitExpression(expression) + .WithTriviaFrom(expression) .WithAdditionalAnnotations(Formatter.Annotation); } } diff --git a/src/Features/CSharp/Portable/CodeFixes/Async/CSharpConvertToAsyncMethodCodeFixProvider.cs b/src/Features/CSharp/Portable/CodeFixes/Async/CSharpConvertToAsyncMethodCodeFixProvider.cs index 5d7554f53708cd5b343fa1c688b7094441e66a50..d81aa2959203792ea0e7fb2a6c42f3ece435ef4b 100644 --- a/src/Features/CSharp/Portable/CodeFixes/Async/CSharpConvertToAsyncMethodCodeFixProvider.cs +++ b/src/Features/CSharp/Portable/CodeFixes/Async/CSharpConvertToAsyncMethodCodeFixProvider.cs @@ -89,8 +89,7 @@ private MethodDeclarationSyntax ConvertToAsyncFunction(MethodDeclarationSyntax m { return methodDeclaration.WithReturnType( SyntaxFactory.ParseTypeName("Task") - .WithLeadingTrivia(methodDeclaration.ReturnType.GetLeadingTrivia()) - .WithTrailingTrivia(methodDeclaration.ReturnType.GetTrailingTrivia())); + .WithTriviaFrom(methodDeclaration)); } } } diff --git a/src/Features/Core/Portable/CodeFixes/Async/AbstractAddAsyncAwaitCodeFixProvider.cs b/src/Features/Core/Portable/CodeFixes/Async/AbstractAddAsyncAwaitCodeFixProvider.cs index 80831e24a2f7d9baae83b7067a53de5ef4871cf3..9f24c0eea5c4cb13ef79b3cbb2658fabd04d9ede 100644 --- a/src/Features/Core/Portable/CodeFixes/Async/AbstractAddAsyncAwaitCodeFixProvider.cs +++ b/src/Features/Core/Portable/CodeFixes/Async/AbstractAddAsyncAwaitCodeFixProvider.cs @@ -28,7 +28,7 @@ protected override async Task GetCodeFix(SyntaxNode root, SyntaxNode return null; } - protected bool TryGetTypes( + protected static bool TryGetTaskAndExpressionTypes( SyntaxNode expression, SemanticModel semanticModel, out INamedTypeSymbol source, diff --git a/src/Features/VisualBasic/Portable/CodeFixes/Async/VisualBasicAddAwaitCodeFixProvider.vb b/src/Features/VisualBasic/Portable/CodeFixes/Async/VisualBasicAddAwaitCodeFixProvider.vb index b04b3f183cc14f681dcf287f72db3115470e3c25..5df152b7e9052858b99c5c5c7dddf299306d07b4 100644 --- a/src/Features/VisualBasic/Portable/CodeFixes/Async/VisualBasicAddAwaitCodeFixProvider.vb +++ b/src/Features/VisualBasic/Portable/CodeFixes/Async/VisualBasicAddAwaitCodeFixProvider.vb @@ -7,6 +7,7 @@ Imports Microsoft.CodeAnalysis Imports Microsoft.CodeAnalysis.CodeFixes Imports Microsoft.CodeAnalysis.CodeFixes.Async Imports Microsoft.CodeAnalysis.Formatting +Imports Microsoft.CodeAnalysis.LanguageServices Imports Microsoft.CodeAnalysis.VisualBasic.Syntax Imports Resources = Microsoft.CodeAnalysis.VisualBasic.VBFeaturesResources.VBFeaturesResources @@ -16,10 +17,11 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async Friend Class VisualBasicAddAwaitCodeFixProvider Inherits AbstractAddAsyncAwaitCodeFixProvider + Friend Const BC30311 As String = "BC30311" ' error BC30311: Value of type 'X' cannot be converted to 'Y'. Friend Const BC37055 As String = "BC37055" ' error BC37055: Since this is an async method, the return expression must be of type 'blah' rather than 'baz' Friend Const BC42358 As String = "BC42358" ' error BC42358: Because this call is not awaited, execution of the current method continues before the call is completed. - Friend ReadOnly Ids As ImmutableArray(Of String) = ImmutableArray.Create(BC37055, BC42358) + Friend ReadOnly Ids As ImmutableArray(Of String) = ImmutableArray.Create(BC30311, BC37055, BC42358) Public Overrides ReadOnly Property FixableDiagnosticIds As ImmutableArray(Of String) Get @@ -35,11 +37,13 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async Dim expression = TryCast(oldNode, ExpressionSyntax) Select Case diagnostic.Id - Case BC37055 - If expression Is Nothing Then + Case BC30311 + If Not DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(expression, semanticModel, document.Project, cancellationToken) Then Return Task.FromResult(Of SyntaxNode)(Nothing) End If - If Not IsCorrectReturnType(expression, semanticModel) Then + Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression))) + Case BC37055 + If Not DoesExpressionReturnTask(expression, semanticModel) Then Return Task.FromResult(Of SyntaxNode)(Nothing) End If Return Task.FromResult(root.ReplaceNode(oldNode, ConverToAwaitExpression(expression))) @@ -53,10 +57,65 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Async End Select End Function - Private Function IsCorrectReturnType(expression As ExpressionSyntax, semanticModel As SemanticModel) As Boolean + Private Function DoesExpressionReturnGenericTaskWhoseArgumentsMatchLeftSide(expression As ExpressionSyntax, semanticModel As SemanticModel, project As Project, cancellationToken As CancellationToken) As Boolean + If expression Is Nothing Then + Return False + End If + + If Not IsInAsyncBlock(expression) Then + Return False + End If + + Dim taskType As INamedTypeSymbol = Nothing + Dim rightSideType As INamedTypeSymbol = Nothing + If Not TryGetTaskAndExpressionTypes(expression, semanticModel, taskType, rightSideType) Then + Return False + End If + + Dim compilation = semanticModel.Compilation + If Not compilation.ClassifyConversion(taskType, rightSideType).Exists Then + Return False + End If + + If Not rightSideType.IsGenericType Then + Return False + End If + + Dim typeArguments = rightSideType.TypeArguments + Dim typeInferer = project.LanguageServices.GetService(Of ITypeInferenceService) + Dim inferredTypes = typeInferer.InferTypes(semanticModel, expression, cancellationToken) + Return typeArguments.Any(Function(ta) inferredTypes.Any(Function(it) compilation.ClassifyConversion(it, ta).Exists)) + End Function + + Private Function IsInAsyncBlock(expression As ExpressionSyntax) As Boolean + + For Each ancestor In expression.Ancestors + Select Case ancestor.Kind + Case SyntaxKind.MultiLineFunctionLambdaExpression, + SyntaxKind.MultiLineSubLambdaExpression, + SyntaxKind.SingleLineFunctionLambdaExpression, + SyntaxKind.SingleLineSubLambdaExpression + Dim result = TryCast(ancestor, LambdaExpressionSyntax)?.SubOrFunctionHeader?.Modifiers.Any(SyntaxKind.AsyncKeyword) + Return result.HasValue AndAlso result.Value + Case SyntaxKind.SubBlock, + SyntaxKind.FunctionBlock + Dim result = TryCast(ancestor, MethodBlockBaseSyntax)?.BlockStatement?.Modifiers.Any(SyntaxKind.AsyncKeyword) + Return result.HasValue AndAlso result.Value + Case Else + Continue For + End Select + Next + Return False + End Function + + Private Function DoesExpressionReturnTask(expression As ExpressionSyntax, semanticModel As SemanticModel) As Boolean + If expression Is Nothing Then + Return Nothing + End If + Dim taskType As INamedTypeSymbol = Nothing Dim returnType As INamedTypeSymbol = Nothing - Return TryGetTypes(expression, semanticModel, taskType, returnType) AndAlso + Return TryGetTaskAndExpressionTypes(expression, semanticModel, taskType, returnType) AndAlso semanticModel.Compilation.ClassifyConversion(taskType, returnType).Exists End Function