diff --git a/src/EditorFeatures/CSharpTest/UseLocalFunction/UseLocalFunctionTests.cs b/src/EditorFeatures/CSharpTest/UseLocalFunction/UseLocalFunctionTests.cs index 3b4809b292290c3f5dae1fa220b747829693685a..0d1be07e7e451507af47afacd395c843d7747307 100644 --- a/src/EditorFeatures/CSharpTest/UseLocalFunction/UseLocalFunctionTests.cs +++ b/src/EditorFeatures/CSharpTest/UseLocalFunction/UseLocalFunctionTests.cs @@ -1437,24 +1437,50 @@ public void Caller() [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)] [WorkItem(22672, "https://github.com/dotnet/roslyn/issues/22672")] - public async Task TestMissingIfUsedInMemberAccess() + public async Task TestMissingIfUsedInMemberAccess1() { await TestMissingAsync( @"using System; class Enclosing where T : class { - delegate T MyDelegate(T t = null); + delegate T MyDelegate(T t = null); - public class Class - { - public void Caller() + public class Class { - MyDelegate [||]local = x => x; + public void Caller() + { + MyDelegate [||]local = x => x; - local.Invoke(); + var str = local.ToString(); + } + } +}"); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)] + [WorkItem(23150, "https://github.com/dotnet/roslyn/issues/23150")] + public async Task TestMissingIfUsedInMemberAccess2() + { + await TestMissingAsync( +@"using System; + +class Enclosing where T : class +{ + delegate T MyDelegate(T t = null); + + public class Class + { + public void Caller(T t) + { + MyDelegate[||]local = x => x; + + Console.Write(local.Invoke(t)); + + var str = local.ToString(); + local.Invoke(t); + } } - } }"); } @@ -1479,6 +1505,174 @@ public void Caller() Expression expression = () => local(null); } } +}"); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)] + [WorkItem(23150, "https://github.com/dotnet/roslyn/issues/23150")] + public async Task TestWithInvokeMethod1() + { + await TestInRegularAndScript1Async( +@"using System; + +class Enclosing where T : class +{ + delegate T MyDelegate(T t = null); + + public class Class + { + public void Caller() + { + MyDelegate [||]local = x => x; + + local.Invoke(); + } + } +}", +@"using System; + +class Enclosing where T : class +{ + delegate T MyDelegate(T t = null); + + public class Class + { + public void Caller() + { + T local(T x = null) => x; + + local(); + } + } +}"); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)] + [WorkItem(23150, "https://github.com/dotnet/roslyn/issues/23150")] + public async Task TestWithInvokeMethod2() + { + await TestInRegularAndScript1Async( +@"using System; + +class Enclosing where T : class +{ + delegate T MyDelegate(T t = null); + + public class Class + { + public void Caller(T t) + { + MyDelegate [||]local = x => x; + + Console.Write(local.Invoke(t)); + } + } +}", +@"using System; + +class Enclosing where T : class +{ + delegate T MyDelegate(T t = null); + + public class Class + { + public void Caller(T t) + { + T local(T x = null) => x; + + Console.Write(local(t)); + } + } +}"); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)] + [WorkItem(23150, "https://github.com/dotnet/roslyn/issues/23150")] + public async Task TestWithInvokeMethod3() + { + await TestInRegularAndScript1Async( +@"using System; + +class Enclosing where T : class +{ + delegate T MyDelegate(T t = null); + + public class Class + { + public void Caller(T t) + { + MyDelegate [||]local = x => x; + + Console.Write(local.Invoke(t)); + + var val = local.Invoke(t); + local.Invoke(t); + } + } +}", +@"using System; + +class Enclosing where T : class +{ + delegate T MyDelegate(T t = null); + + public class Class + { + public void Caller(T t) + { + T local(T x = null) => x; + + Console.Write(local(t)); + + var val = local(t); + local(t); + } + } +}"); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)] + [WorkItem(23150, "https://github.com/dotnet/roslyn/issues/23150")] + public async Task TestWithInvokeMethod4() + { + await TestInRegularAndScript1Async( +@"using System; + +class Enclosing where T : class +{ + delegate T MyDelegate(T t = null); + + public class Class + { + public void Caller(T t) + { + MyDelegate [||]local = x => x; + + Console.Write(local.Invoke(t)); + + var val = local.Invoke(t); + local(t); + } + } +}", +@"using System; + +class Enclosing where T : class +{ + delegate T MyDelegate(T t = null); + + public class Class + { + public void Caller(T t) + { + T local(T x = null) => x; + + Console.Write(local(t)); + + var val = local(t); + local(t); + } + } }"); } } diff --git a/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionCodeFixProvider.cs b/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionCodeFixProvider.cs index ae3995e796339db30bf51abfbc681e2fd038978b..9a7299e73fe6cc0f1b003781f077e9a56fcf857e 100644 --- a/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionCodeFixProvider.cs +++ b/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionCodeFixProvider.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks; using Microsoft.CodeAnalysis.CodeActions; using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.CSharp.CodeGeneration; using Microsoft.CodeAnalysis.CSharp.Extensions; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; @@ -47,6 +48,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) var localDeclarationToLambda = new Dictionary(); var nodesToTrack = new HashSet(); + var explicitInvokeCalls = new List(); foreach (var diagnostic in diagnostics) { var localDeclaration = (LocalDeclarationStatementSyntax)diagnostic.AdditionalLocations[0].FindNode(cancellationToken); @@ -56,8 +58,14 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) nodesToTrack.Add(localDeclaration); nodesToTrack.Add(lambda); + + for (var i = 2; i < diagnostic.AdditionalLocations.Count; i++) + { + explicitInvokeCalls.Add((MemberAccessExpressionSyntax)diagnostic.AdditionalLocations[i].FindNode(getInnermostNodeForTie: true, cancellationToken)); + } } + nodesToTrack.AddRange(explicitInvokeCalls); var root = editor.OriginalRoot; var currentRoot = root.TrackNodes(nodesToTrack); @@ -66,7 +74,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) foreach (var (originalLocalDeclaration, originalLambda) in localDeclarationToLambda.OrderByDescending(kvp => kvp.Value.SpanStart)) { var delegateType = (INamedTypeSymbol)semanticModel.GetTypeInfo(originalLambda, cancellationToken).ConvertedType; - var parameterList = GenerateParameterList(semanticModel, originalLambda, cancellationToken); + var parameterList = GenerateParameterList(semanticModel, originalLambda, delegateType, cancellationToken); var currentLocalDeclaration = currentRoot.GetCurrentNode(originalLocalDeclaration); var currentLambda = currentRoot.GetCurrentNode(originalLambda); @@ -74,7 +82,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) currentRoot = ReplaceAnonymousWithLocalFunction( document.Project.Solution.Workspace, currentRoot, currentLocalDeclaration, currentLambda, - delegateType, parameterList, + delegateType, parameterList, explicitInvokeCalls.Select(node => currentRoot.GetCurrentNode(node)).ToImmutableArray(), cancellationToken); } @@ -85,6 +93,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) Workspace workspace, SyntaxNode currentRoot, LocalDeclarationStatementSyntax localDeclaration, LambdaExpressionSyntax lambda, INamedTypeSymbol delegateType, ParameterListSyntax parameterList, + ImmutableArray explicitInvokeCalls, CancellationToken cancellationToken) { var newLocalFunctionStatement = CreateLocalFunctionStatement( @@ -104,6 +113,13 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) editor.RemoveNode(lambdaStatement); } + foreach (var usage in explicitInvokeCalls) + { + editor.ReplaceNode( + usage.Parent, + (usage.Parent as InvocationExpressionSyntax).WithExpression(usage.Expression).WithTriviaFrom(usage.Parent)); + } + return editor.GetChangedRoot(); } @@ -144,33 +160,40 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) } private ParameterListSyntax GenerateParameterList( - SemanticModel semanticModel, AnonymousFunctionExpressionSyntax anonymousFunction, CancellationToken cancellationToken) + SemanticModel semanticModel, AnonymousFunctionExpressionSyntax anonymousFunction, INamedTypeSymbol delegateType, CancellationToken cancellationToken) { switch (anonymousFunction) { case SimpleLambdaExpressionSyntax simpleLambda: - return GenerateSimpleLambdaParameterList(semanticModel, simpleLambda, cancellationToken); + return GenerateSimpleLambdaParameterList(semanticModel, simpleLambda, delegateType.DelegateInvokeMethod, cancellationToken); case ParenthesizedLambdaExpressionSyntax parenthesizedLambda: - return GenerateParenthesizedLambdaParameterList(semanticModel, parenthesizedLambda, cancellationToken); + return GenerateParenthesizedLambdaParameterList(semanticModel, parenthesizedLambda, delegateType.DelegateInvokeMethod, cancellationToken); default: throw ExceptionUtilities.UnexpectedValue(anonymousFunction); } } private ParameterListSyntax GenerateSimpleLambdaParameterList( - SemanticModel semanticModel, SimpleLambdaExpressionSyntax lambdaExpression, CancellationToken cancellationToken) + SemanticModel semanticModel, SimpleLambdaExpressionSyntax lambdaExpression, IMethodSymbol delegateInvokeMethod, CancellationToken cancellationToken) { var parameter = semanticModel.GetDeclaredSymbol(lambdaExpression.Parameter, cancellationToken); var type = parameter?.Type.GenerateTypeSyntax() ?? s_objectType; + var parameterSyntax = SyntaxFactory.Parameter(lambdaExpression.Parameter.Identifier).WithType(type); + var param = delegateInvokeMethod.Parameters[0]; + if (param.HasExplicitDefaultValue) + { + parameterSyntax = parameterSyntax.WithDefault(GetDefaultValue(param)); + } + return SyntaxFactory.ParameterList( - SyntaxFactory.SeparatedList().Add( - SyntaxFactory.Parameter(lambdaExpression.Parameter.Identifier).WithType(type))); + SyntaxFactory.SeparatedList().Add(parameterSyntax)); } private ParameterListSyntax GenerateParenthesizedLambdaParameterList( - SemanticModel semanticModel, ParenthesizedLambdaExpressionSyntax lambdaExpression, CancellationToken cancellationToken) + SemanticModel semanticModel, ParenthesizedLambdaExpressionSyntax lambdaExpression, IMethodSymbol delegateInvokeMethod, CancellationToken cancellationToken) { + int i = 0; return lambdaExpression.ParameterList.ReplaceNodes( lambdaExpression.ParameterList.Parameters, (parameterNode, _) => @@ -181,10 +204,20 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) } var parameter = semanticModel.GetDeclaredSymbol(parameterNode, cancellationToken); - return parameterNode.WithType(parameter?.Type.GenerateTypeSyntax() ?? s_objectType); + parameterNode = parameterNode.WithType(parameter?.Type.GenerateTypeSyntax() ?? s_objectType); + var param = delegateInvokeMethod.Parameters[i++]; + if (param.HasExplicitDefaultValue) + { + parameterNode = parameterNode.WithDefault(GetDefaultValue(param)); + } + + return parameterNode; }); } + private static EqualsValueClauseSyntax GetDefaultValue(IParameterSymbol parameter) + => SyntaxFactory.EqualsValueClause(ExpressionGenerator.GenerateExpression(parameter.Type, parameter.ExplicitDefaultValue, canUseFieldReference: true)); + private class MyCodeAction : CodeAction.DocumentChangeAction { public MyCodeAction(Func> createChangedDocument) diff --git a/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionDiagnosticAnalyzer.cs b/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionDiagnosticAnalyzer.cs index 2fa84f56c749e1b21435064b5f013e48201cb6be..726fd7d45173b2ca37ccd48a9bc162ae15746761 100644 --- a/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionDiagnosticAnalyzer.cs +++ b/src/Features/CSharp/Portable/UseLocalFunction/CSharpUseLocalFunctionDiagnosticAnalyzer.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Collections.Immutable; using System.Linq.Expressions; using System.Threading; @@ -9,6 +10,7 @@ using Microsoft.CodeAnalysis.CSharp.Extensions; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.PooledObjects; using Microsoft.CodeAnalysis.Shared.Extensions; namespace Microsoft.CodeAnalysis.CSharp.UseLocalFunction @@ -116,7 +118,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp return; } - if (!CanReplaceAnonymousWithLocalFunction(semanticModel, expressionTypeOpt, local, block, anonymousFunction, cancellationToken)) + if (!CanReplaceAnonymousWithLocalFunction(semanticModel, expressionTypeOpt, local, block, anonymousFunction, out var explicitInvokeCallLocations, cancellationToken)) { return; } @@ -126,6 +128,8 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp localDeclaration.GetLocation(), anonymousFunction.GetLocation()); + additionalLocations = additionalLocations.AddRange(explicitInvokeCallLocations); + if (severity != DiagnosticSeverity.Hidden) { // If the diagnostic is not hidden, then just place the user visible part @@ -197,10 +201,12 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp private bool CanReplaceAnonymousWithLocalFunction( SemanticModel semanticModel, INamedTypeSymbol expressionTypeOpt, ISymbol local, BlockSyntax block, - AnonymousFunctionExpressionSyntax anonymousFunction, CancellationToken cancellationToken) + AnonymousFunctionExpressionSyntax anonymousFunction, out ImmutableArray explicitInvokeCallLocations, CancellationToken cancellationToken) { // Check all the references to the anonymous function and disallow the conversion if // they're used in certain ways. + var explicitInvokeCalls = ArrayBuilder.GetInstance(); + explicitInvokeCallLocations = ImmutableArray.Empty; var anonymousFunctionStart = anonymousFunction.SpanStart; foreach (var descendentNode in block.DescendantNodes()) { @@ -232,11 +238,18 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp return false; } - if (nodeToCheck.Parent is MemberAccessExpressionSyntax) + if (nodeToCheck.Parent is MemberAccessExpressionSyntax memberAccessExpression) { - // They're doing something like "del.ToString()". Can't do this with a - // local function. - return false; + if (memberAccessExpression.Name.Identifier.Text != WellKnownMemberNames.DelegateInvokeName) + { + // They're doing something like "del.ToString()". Can't do this with a + // local function. + return false; + } + else + { + explicitInvokeCalls.Add(memberAccessExpression.GetLocation()); + } } var convertedType = semanticModel.GetTypeInfo(nodeToCheck, cancellationToken).ConvertedType; @@ -258,6 +271,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp } } + explicitInvokeCallLocations = explicitInvokeCalls.ToImmutableAndFree(); return true; }