diff --git a/src/EditorFeatures/CSharpTest/Diagnostics/MakeStatementAsynchronous/CSharpMakeStatementAsynchronousCodeFixTests.cs b/src/EditorFeatures/CSharpTest/Diagnostics/MakeStatementAsynchronous/CSharpMakeStatementAsynchronousCodeFixTests.cs new file mode 100644 index 0000000000000000000000000000000000000000..14de112ecf0eda0afe34950b5bdd9a9e33bc0924 --- /dev/null +++ b/src/EditorFeatures/CSharpTest/Diagnostics/MakeStatementAsynchronous/CSharpMakeStatementAsynchronousCodeFixTests.cs @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.CodeFixes.MakeStatementAsynchronous; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.Test.Utilities; +using Xunit; + +namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.Diagnostics.MakeStatementAsynchronous +{ + [Trait(Traits.Feature, Traits.Features.CodeActionsMakeStatementAsynchronous)] + public class CSharpMakeStatementAsynchronousCodeFixTests : AbstractCSharpDiagnosticProviderBasedUserDiagnosticTest + { + internal override (DiagnosticAnalyzer, CodeFixProvider) CreateDiagnosticProviderAndFixer(Workspace workspace) + => (null, new CSharpMakeStatementAsynchronousCodeFixProvider()); + + private static readonly TestParameters s_asyncStreamsFeature = new TestParameters(parseOptions: new CSharpParseOptions(LanguageVersion.CSharp8)); + + private readonly string AsyncStreams = @" +namespace System.Collections.Generic +{ + public interface IAsyncEnumerable + { + IAsyncEnumerator GetAsyncEnumerator(); + } + + public interface IAsyncEnumerator : System.IAsyncDisposable + { + System.Threading.Tasks.ValueTask MoveNextAsync(); + T Current { get; } + } +} +namespace System +{ + public interface IAsyncDisposable + { + System.Threading.Tasks.ValueTask DisposeAsync(); + } +} +"; + + [Fact] + public async Task FixAllForeach() + { + await TestInRegularAndScript1Async( +AsyncStreams + @" +class Program +{ + void M(System.Collections.Generic.IAsyncEnumerable collection) + { + foreach (var i in {|FixAllInDocument:collection|}) { } + foreach (var j in collection) { } + } +}", +AsyncStreams + @" +class Program +{ + void M(System.Collections.Generic.IAsyncEnumerable collection) + { + await foreach (var i in collection) { } + await foreach (var j in collection) { } + } +}", parameters: s_asyncStreamsFeature); + } + + [Fact] + public async Task FixAllForeachDeconstruction() + { + await TestInRegularAndScript1Async( +AsyncStreams + @" +class Program +{ + void M(System.Collections.Generic.IAsyncEnumerable<(int, int)> collection) + { + foreach (var (i, j) in {|FixAllInDocument:collection|}) { } + foreach (var (k, l) in collection) { } + } +}", +AsyncStreams + @" +class Program +{ + void M(System.Collections.Generic.IAsyncEnumerable<(int, int)> collection) + { + await foreach (var (i, j) in collection) { } + await foreach (var (k, l) in collection) { } + } +}", parameters: s_asyncStreamsFeature); + } + + [Fact] + public async Task FixAllUsingStatement() + { + await TestInRegularAndScript1Async( +AsyncStreams + @" +class Program +{ + void M(System.IAsyncDisposable disposable) + { + using (var i = {|FixAllInDocument:disposable|}) { } + using (var j = disposable) { } + } +}", +AsyncStreams + @" +class Program +{ + void M(System.IAsyncDisposable disposable) + { + await using (var i = disposable) { } + await using (var j = disposable) { } + } +}", parameters: s_asyncStreamsFeature); + } + + [Fact] + public async Task FixAllUsingDeclaration() + { + await TestInRegularAndScript1Async( +AsyncStreams + @" +class Program +{ + void M(System.IAsyncDisposable disposable) + { + using var i = {|FixAllInDocument:disposable|}; + using var j = disposable; + } +}", +AsyncStreams + @" +class Program +{ + void M(System.IAsyncDisposable disposable) + { + await using var i = disposable; + await using var j = disposable; + } +}", parameters: s_asyncStreamsFeature); + } + + [Fact] + public async Task FixForeach() + { + await TestInRegularAndScript1Async( +AsyncStreams + @" +class Program +{ + void M(System.Collections.Generic.IAsyncEnumerable collection) + { + foreach (var i in [|collection|]) + { + } + } +}", +AsyncStreams + @" +class Program +{ + void M(System.Collections.Generic.IAsyncEnumerable collection) + { + await foreach (var i in collection) + { + } + } +}", parameters: s_asyncStreamsFeature); + } + + [Fact] + public async Task FixForeachDeconstruction() + { + await TestInRegularAndScript1Async( +AsyncStreams + @" +class Program +{ + void M(System.Collections.Generic.IAsyncEnumerable<(int, int)> collection) + { + foreach (var (i, j) in collection[||]) + { + } + } +}", +AsyncStreams + @" +class Program +{ + void M(System.Collections.Generic.IAsyncEnumerable<(int, int)> collection) + { + await foreach (var (i, j) in collection) + { + } + } +}", parameters: s_asyncStreamsFeature); + } + + [Fact] + public async Task FixUsingStatement() + { + await TestInRegularAndScript1Async( +AsyncStreams + @" +class Program +{ + void M(System.IAsyncDisposable disposable) + { + using (var i = disposable[||]) + { + } + } +}", +AsyncStreams + @" +class Program +{ + void M(System.IAsyncDisposable disposable) + { + await using (var i = disposable) + { + } + } +}", parameters: s_asyncStreamsFeature); + } + + [Fact] + public async Task FixUsingDeclaration() + { + await TestInRegularAndScript1Async( +AsyncStreams + @" +class Program +{ + void M(System.IAsyncDisposable disposable) + { + using var i = disposable[||]; + } +}", +AsyncStreams + @" +class Program +{ + void M(System.IAsyncDisposable disposable) + { + await using var i = disposable; + } +}", parameters: s_asyncStreamsFeature); + } + } +} diff --git a/src/Features/CSharp/Portable/CodeFixes/MakeStatementAsynchronous/CSharpMakeStatementAsynchronousCodeFixProvider.cs b/src/Features/CSharp/Portable/CodeFixes/MakeStatementAsynchronous/CSharpMakeStatementAsynchronousCodeFixProvider.cs new file mode 100644 index 0000000000000000000000000000000000000000..8921617b34b9538ac430fb191b19ed44981f507f --- /dev/null +++ b/src/Features/CSharp/Portable/CodeFixes/MakeStatementAsynchronous/CSharpMakeStatementAsynchronousCodeFixProvider.cs @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Immutable; +using System.Composition; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis.CodeActions; +using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.CSharp.Extensions; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Editing; +using Microsoft.CodeAnalysis.Shared.Extensions; +using Roslyn.Utilities; + +namespace Microsoft.CodeAnalysis.CSharp.CodeFixes.MakeStatementAsynchronous +{ + [ExportCodeFixProvider(LanguageNames.CSharp, Name = PredefinedCodeFixProviderNames.MakeStatementAsynchronous), Shared] + internal class CSharpMakeStatementAsynchronousCodeFixProvider : SyntaxEditorBasedCodeFixProvider + { + // error CS8414: foreach statement cannot operate on variables of type 'IAsyncEnumerable' because 'IAsyncEnumerable' does not contain a public instance definition for 'GetEnumerator'. Did you mean 'await foreach'? + // error CS8418: 'IAsyncDisposable': type used in a using statement must be implicitly convertible to 'System.IDisposable'. Did you mean 'await using' rather than 'using'? + public sealed override ImmutableArray FixableDiagnosticIds => ImmutableArray.Create("CS8414", "CS8418"); + + public override async Task RegisterCodeFixesAsync(CodeFixContext context) + { + var diagnostic = context.Diagnostics.First(); + var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false); + var node = root.FindNode(diagnostic.Location.SourceSpan, getInnermostNodeForTie: true); + + var constructToFix = TryGetStatementToFix(node); + if (constructToFix == null) + { + return; + } + + context.RegisterCodeFix(new MyCodeAction( + c => FixAsync(context.Document, diagnostic, c)), + context.Diagnostics); + } + + protected override Task FixAllAsync( + Document document, ImmutableArray diagnostics, + SyntaxEditor editor, CancellationToken cancellationToken) + { + var root = editor.OriginalRoot; + + foreach (var diagnostic in diagnostics) + { + var node = diagnostic.Location.FindNode(getInnermostNodeForTie: true, cancellationToken); + var statementToFix = TryGetStatementToFix(node); + if (statementToFix != null) + { + MakeStatementAsynchronous(editor, statementToFix); + } + } + + return Task.CompletedTask; + } + + private static void MakeStatementAsynchronous(SyntaxEditor editor, SyntaxNode statementToFix) + { + SyntaxNode newStatement; + switch (statementToFix) + { + case ForEachStatementSyntax forEach: + newStatement = forEach + .WithForEachKeyword(forEach.ForEachKeyword.WithLeadingTrivia()) + .WithAwaitKeyword(SyntaxFactory.Token(SyntaxKind.AwaitKeyword).WithLeadingTrivia(forEach.GetLeadingTrivia())); + break; + case ForEachVariableStatementSyntax forEachDeconstruction: + newStatement = forEachDeconstruction + .WithForEachKeyword(forEachDeconstruction.ForEachKeyword.WithLeadingTrivia()) + .WithAwaitKeyword(SyntaxFactory.Token(SyntaxKind.AwaitKeyword).WithLeadingTrivia(forEachDeconstruction.GetLeadingTrivia())); + break; + case UsingStatementSyntax usingStatement: + newStatement = usingStatement + .WithUsingKeyword(usingStatement.UsingKeyword.WithLeadingTrivia()) + .WithAwaitKeyword(SyntaxFactory.Token(SyntaxKind.AwaitKeyword).WithLeadingTrivia(usingStatement.GetLeadingTrivia())); + break; + case LocalDeclarationStatementSyntax localDeclaration: + newStatement = localDeclaration + .WithUsingKeyword(localDeclaration.UsingKeyword.WithLeadingTrivia()) + .WithAwaitKeyword(SyntaxFactory.Token(SyntaxKind.AwaitKeyword).WithLeadingTrivia(localDeclaration.GetLeadingTrivia())); + break; + default: + return; + } + + editor.ReplaceNode(statementToFix, newStatement); + } + + private static SyntaxNode TryGetStatementToFix(SyntaxNode node) + { + if (node.IsParentKind(SyntaxKind.ForEachStatement, SyntaxKind.ForEachVariableStatement, SyntaxKind.UsingStatement)) + { + return node.Parent; + } + + if (node is LocalDeclarationStatementSyntax localDeclaration && localDeclaration.UsingKeyword != default) + { + return node; + } + + return null; + } + + private class MyCodeAction : CodeAction.DocumentChangeAction + { + public MyCodeAction(Func> createChangedDocument) : + base(CSharpFeaturesResources.Add_await, + createChangedDocument, + CSharpFeaturesResources.Add_await) + { + } + } + } +} diff --git a/src/Features/Core/Portable/CodeFixes/PredefinedCodeFixProviderNames.cs b/src/Features/Core/Portable/CodeFixes/PredefinedCodeFixProviderNames.cs index b25630a5755ef94dcff6362d8f9ad07a02a96997..661b50e2fbf4a5d62a0eb82f59634d07c09cc435 100644 --- a/src/Features/Core/Portable/CodeFixes/PredefinedCodeFixProviderNames.cs +++ b/src/Features/Core/Portable/CodeFixes/PredefinedCodeFixProviderNames.cs @@ -38,6 +38,7 @@ internal static class PredefinedCodeFixProviderNames public const string ImplementInterface = nameof(ImplementInterface); public const string InsertMissingCast = nameof(InsertMissingCast); public const string MakeFieldReadonly = nameof(MakeFieldReadonly); + public const string MakeStatementAsynchronous = nameof(MakeStatementAsynchronous); public const string MakeMethodSynchronous = nameof(MakeMethodSynchronous); public const string MoveToTopOfFile = nameof(MoveToTopOfFile); public const string PopulateSwitch = nameof(PopulateSwitch); diff --git a/src/Test/Utilities/Portable/Traits/Traits.cs b/src/Test/Utilities/Portable/Traits/Traits.cs index 066fc83cf4cfab9496cd035ffcde63a8611e2fc4..1726c951f31f231101a2c0900a14902bbfb668d3 100644 --- a/src/Test/Utilities/Portable/Traits/Traits.cs +++ b/src/Test/Utilities/Portable/Traits/Traits.cs @@ -102,6 +102,7 @@ public static class Features public const string CodeActionsInvokeDelegateWithConditionalAccess = "CodeActions.InvokeDelegateWithConditionalAccess"; public const string CodeActionsLambdaSimplifier = "CodeActions.LambdaSimplifier"; public const string CodeActionsMakeFieldReadonly = "CodeActions.MakeFieldReadonly"; + public const string CodeActionsMakeStatementAsynchronous = "CodeActions.MakeStatementAsynchronous"; public const string CodeActionsMakeStructFieldsWritable = "CodeActions.MakeStructFieldsWritable"; public const string CodeActionsMakeLocalFunctionStatic = "CodeActions.MakeLocalFunctionStatic"; public const string CodeActionsMakeMethodAsynchronous = "CodeActions.MakeMethodAsynchronous";