From 45cb3c4644e380bee23085b6b13703ace48d2b3f Mon Sep 17 00:00:00 2001 From: Victor Z Date: Mon, 20 Nov 2017 17:50:24 +0300 Subject: [PATCH] Enabled 'Use Local Function' refactoring for delegate which is invoked with Invoke --- .../UseLocalFunction/UseLocalFunctionTests.cs | 36 ++++++++--- .../CSharpUseLocalFunctionCodeFixProvider.cs | 61 ++++++++++++++++--- ...SharpUseLocalFunctionDiagnosticAnalyzer.cs | 29 +++++++-- 3 files changed, 100 insertions(+), 26 deletions(-) diff --git a/src/EditorFeatures/CSharpTest/UseLocalFunction/UseLocalFunctionTests.cs b/src/EditorFeatures/CSharpTest/UseLocalFunction/UseLocalFunctionTests.cs index 3b4809b2922..e06bfc9d398 100644 --- a/src/EditorFeatures/CSharpTest/UseLocalFunction/UseLocalFunctionTests.cs +++ b/src/EditorFeatures/CSharpTest/UseLocalFunction/UseLocalFunctionTests.cs @@ -1436,25 +1436,41 @@ public void Caller() } [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)] - [WorkItem(22672, "https://github.com/dotnet/roslyn/issues/22672")] - public async Task TestMissingIfUsedInMemberAccess() + [WorkItem(23150, "https://github.com/dotnet/roslyn/issues/23150")] + public async Task TestWithInvokeMethod() { - await TestMissingAsync( + await TestInRegularAndScript1Async( @"using System; class Enclosing where T : class { - delegate T MyDelegate(T t = null); + delegate T MyDelegate(T t = null); - public class Class - { - public void Caller() + public class Class { - MyDelegate [||]local = x => x; + public void Caller() + { + MyDelegate [||]local = x => x; - local.Invoke(); + local.Invoke(); + } + } +}", +@"using System; + +class Enclosing where T : class +{ + delegate T MyDelegate(T t = null); + + public class Class + { + public void Caller() + { + T local(T x = null) => x; + + local(); + } } - } }"); } diff --git a/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionCodeFixProvider.cs b/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionCodeFixProvider.cs index ae3995e7963..4b184e78b4c 100644 --- a/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionCodeFixProvider.cs +++ b/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionCodeFixProvider.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks; using Microsoft.CodeAnalysis.CodeActions; using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.CSharp.CodeGeneration; using Microsoft.CodeAnalysis.CSharp.Extensions; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; @@ -47,6 +48,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) var localDeclarationToLambda = new Dictionary(); var nodesToTrack = new HashSet(); + var memberAccess = new List(); foreach (var diagnostic in diagnostics) { var localDeclaration = (LocalDeclarationStatementSyntax)diagnostic.AdditionalLocations[0].FindNode(cancellationToken); @@ -56,8 +58,17 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) nodesToTrack.Add(localDeclaration); nodesToTrack.Add(lambda); + + if (diagnostic.AdditionalLocations.Count > 2) + { + for (var i = 2; i < diagnostic.AdditionalLocations.Count; i++) + { + memberAccess.Add((MemberAccessExpressionSyntax)diagnostic.AdditionalLocations[i].FindNode(cancellationToken)); + } + } } + nodesToTrack.AddRange(memberAccess); var root = editor.OriginalRoot; var currentRoot = root.TrackNodes(nodesToTrack); @@ -66,7 +77,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) foreach (var (originalLocalDeclaration, originalLambda) in localDeclarationToLambda.OrderByDescending(kvp => kvp.Value.SpanStart)) { var delegateType = (INamedTypeSymbol)semanticModel.GetTypeInfo(originalLambda, cancellationToken).ConvertedType; - var parameterList = GenerateParameterList(semanticModel, originalLambda, cancellationToken); + var parameterList = GenerateParameterList(semanticModel, originalLambda, delegateType, cancellationToken); var currentLocalDeclaration = currentRoot.GetCurrentNode(originalLocalDeclaration); var currentLambda = currentRoot.GetCurrentNode(originalLambda); @@ -74,7 +85,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) currentRoot = ReplaceAnonymousWithLocalFunction( document.Project.Solution.Workspace, currentRoot, currentLocalDeclaration, currentLambda, - delegateType, parameterList, + delegateType, parameterList, memberAccess.Select(node => currentRoot.GetCurrentNode(node)).ToImmutableArray(), cancellationToken); } @@ -85,6 +96,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) Workspace workspace, SyntaxNode currentRoot, LocalDeclarationStatementSyntax localDeclaration, LambdaExpressionSyntax lambda, INamedTypeSymbol delegateType, ParameterListSyntax parameterList, + ImmutableArray notDirectInvocationUsages, CancellationToken cancellationToken) { var newLocalFunctionStatement = CreateLocalFunctionStatement( @@ -104,6 +116,16 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) editor.RemoveNode(lambdaStatement); } + if (!notDirectInvocationUsages.IsEmpty) + { + foreach (var usage in notDirectInvocationUsages) + { + editor.ReplaceNode( + usage.Parent, + SyntaxFactory.InvocationExpression(usage.Expression, (usage.Parent as InvocationExpressionSyntax).ArgumentList)); + } + } + return editor.GetChangedRoot(); } @@ -144,33 +166,40 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) } private ParameterListSyntax GenerateParameterList( - SemanticModel semanticModel, AnonymousFunctionExpressionSyntax anonymousFunction, CancellationToken cancellationToken) + SemanticModel semanticModel, AnonymousFunctionExpressionSyntax anonymousFunction, INamedTypeSymbol namedTypeSymbol, CancellationToken cancellationToken) { switch (anonymousFunction) { case SimpleLambdaExpressionSyntax simpleLambda: - return GenerateSimpleLambdaParameterList(semanticModel, simpleLambda, cancellationToken); + return GenerateSimpleLambdaParameterList(semanticModel, simpleLambda, namedTypeSymbol, cancellationToken); case ParenthesizedLambdaExpressionSyntax parenthesizedLambda: - return GenerateParenthesizedLambdaParameterList(semanticModel, parenthesizedLambda, cancellationToken); + return GenerateParenthesizedLambdaParameterList(semanticModel, parenthesizedLambda, namedTypeSymbol, cancellationToken); default: throw ExceptionUtilities.UnexpectedValue(anonymousFunction); } } private ParameterListSyntax GenerateSimpleLambdaParameterList( - SemanticModel semanticModel, SimpleLambdaExpressionSyntax lambdaExpression, CancellationToken cancellationToken) + SemanticModel semanticModel, SimpleLambdaExpressionSyntax lambdaExpression, INamedTypeSymbol namedTypeSymbol, CancellationToken cancellationToken) { var parameter = semanticModel.GetDeclaredSymbol(lambdaExpression.Parameter, cancellationToken); var type = parameter?.Type.GenerateTypeSyntax() ?? s_objectType; + var parameterSyntax = SyntaxFactory.Parameter(lambdaExpression.Parameter.Identifier).WithType(type); + var param = namedTypeSymbol.DelegateInvokeMethod.Parameters[0]; + if (param.HasExplicitDefaultValue) + { + parameterSyntax = parameterSyntax.WithDefault(GetDefaultValue(param)); + } + return SyntaxFactory.ParameterList( - SyntaxFactory.SeparatedList().Add( - SyntaxFactory.Parameter(lambdaExpression.Parameter.Identifier).WithType(type))); + SyntaxFactory.SeparatedList().Add(parameterSyntax)); } private ParameterListSyntax GenerateParenthesizedLambdaParameterList( - SemanticModel semanticModel, ParenthesizedLambdaExpressionSyntax lambdaExpression, CancellationToken cancellationToken) + SemanticModel semanticModel, ParenthesizedLambdaExpressionSyntax lambdaExpression, INamedTypeSymbol namedTypeSymbol, CancellationToken cancellationToken) { + int i = 0; return lambdaExpression.ParameterList.ReplaceNodes( lambdaExpression.ParameterList.Parameters, (parameterNode, _) => @@ -181,10 +210,22 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) } var parameter = semanticModel.GetDeclaredSymbol(parameterNode, cancellationToken); - return parameterNode.WithType(parameter?.Type.GenerateTypeSyntax() ?? s_objectType); + parameterNode = parameterNode.WithType(parameter?.Type.GenerateTypeSyntax() ?? s_objectType); + var param = namedTypeSymbol.DelegateInvokeMethod.Parameters[i++]; + if (param.HasExplicitDefaultValue) + { + parameterNode = parameterNode.WithDefault(GetDefaultValue(param)); + } + + return parameterNode; }); } + private static EqualsValueClauseSyntax GetDefaultValue(IParameterSymbol parameter) + { + return SyntaxFactory.EqualsValueClause(ExpressionGenerator.GenerateExpression(parameter.Type, parameter.ExplicitDefaultValue, canUseFieldReference: true)); + } + private class MyCodeAction : CodeAction.DocumentChangeAction { public MyCodeAction(Func> createChangedDocument) diff --git a/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionDiagnosticAnalyzer.cs b/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionDiagnosticAnalyzer.cs index 2fa84f56c74..e9b80c02578 100644 --- a/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionDiagnosticAnalyzer.cs +++ b/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionDiagnosticAnalyzer.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Collections.Immutable; using System.Linq.Expressions; using System.Threading; @@ -116,7 +117,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp return; } - if (!CanReplaceAnonymousWithLocalFunction(semanticModel, expressionTypeOpt, local, block, anonymousFunction, cancellationToken)) + if (!CanReplaceAnonymousWithLocalFunction(semanticModel, expressionTypeOpt, local, block, anonymousFunction, out var locations, cancellationToken)) { return; } @@ -126,6 +127,11 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp localDeclaration.GetLocation(), anonymousFunction.GetLocation()); + if (!locations.IsEmpty) + { + additionalLocations = additionalLocations.AddRange(locations); + } + if (severity != DiagnosticSeverity.Hidden) { // If the diagnostic is not hidden, then just place the user visible part @@ -197,10 +203,12 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp private bool CanReplaceAnonymousWithLocalFunction( SemanticModel semanticModel, INamedTypeSymbol expressionTypeOpt, ISymbol local, BlockSyntax block, - AnonymousFunctionExpressionSyntax anonymousFunction, CancellationToken cancellationToken) + AnonymousFunctionExpressionSyntax anonymousFunction, out ImmutableArray locations, CancellationToken cancellationToken) { // Check all the references to the anonymous function and disallow the conversion if // they're used in certain ways. + List delegateInvokeMethodLocations = null; + locations = ImmutableArray.Empty; var anonymousFunctionStart = anonymousFunction.SpanStart; foreach (var descendentNode in block.DescendantNodes()) { @@ -232,11 +240,19 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp return false; } - if (nodeToCheck.Parent is MemberAccessExpressionSyntax) + if (nodeToCheck.Parent is MemberAccessExpressionSyntax memberAccessExpression) { - // They're doing something like "del.ToString()". Can't do this with a - // local function. - return false; + if (memberAccessExpression.Name.Identifier.Text != WellKnownMemberNames.DelegateInvokeName) + { + // They're doing something like "del.ToString()". Can't do this with a + // local function. + return false; + } + else + { + delegateInvokeMethodLocations = delegateInvokeMethodLocations ?? new List(); + delegateInvokeMethodLocations.Add(memberAccessExpression.GetLocation()); + } } var convertedType = semanticModel.GetTypeInfo(nodeToCheck, cancellationToken).ConvertedType; @@ -258,6 +274,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp } } + locations = delegateInvokeMethodLocations.AsImmutableOrEmpty(); return true; } -- GitLab