提交 45cb3c46 编写于 作者: V Victor Z

Enabled 'Use Local Function' refactoring for delegate which is invoked with Invoke

上级 715aa4a6
......@@ -1436,25 +1436,41 @@ public void Caller()
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)]
[WorkItem(22672, "https://github.com/dotnet/roslyn/issues/22672")]
public async Task TestMissingIfUsedInMemberAccess()
[WorkItem(23150, "https://github.com/dotnet/roslyn/issues/23150")]
public async Task TestWithInvokeMethod()
{
await TestMissingAsync(
await TestInRegularAndScript1Async(
@"using System;
class Enclosing<T> where T : class
{
delegate T MyDelegate(T t = null);
delegate T MyDelegate(T t = null);
public class Class
{
public void Caller()
public class Class
{
MyDelegate [||]local = x => x;
public void Caller()
{
MyDelegate [||]local = x => x;
local.Invoke();
local.Invoke();
}
}
}",
@"using System;
class Enclosing<T> where T : class
{
delegate T MyDelegate(T t = null);
public class Class
{
public void Caller()
{
T local(T x = null) => x;
local();
}
}
}
}");
}
......
......@@ -9,6 +9,7 @@
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp.CodeGeneration;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
......@@ -47,6 +48,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
var localDeclarationToLambda = new Dictionary<LocalDeclarationStatementSyntax, LambdaExpressionSyntax>();
var nodesToTrack = new HashSet<SyntaxNode>();
var memberAccess = new List<MemberAccessExpressionSyntax>();
foreach (var diagnostic in diagnostics)
{
var localDeclaration = (LocalDeclarationStatementSyntax)diagnostic.AdditionalLocations[0].FindNode(cancellationToken);
......@@ -56,8 +58,17 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
nodesToTrack.Add(localDeclaration);
nodesToTrack.Add(lambda);
if (diagnostic.AdditionalLocations.Count > 2)
{
for (var i = 2; i < diagnostic.AdditionalLocations.Count; i++)
{
memberAccess.Add((MemberAccessExpressionSyntax)diagnostic.AdditionalLocations[i].FindNode(cancellationToken));
}
}
}
nodesToTrack.AddRange(memberAccess);
var root = editor.OriginalRoot;
var currentRoot = root.TrackNodes(nodesToTrack);
......@@ -66,7 +77,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
foreach (var (originalLocalDeclaration, originalLambda) in localDeclarationToLambda.OrderByDescending(kvp => kvp.Value.SpanStart))
{
var delegateType = (INamedTypeSymbol)semanticModel.GetTypeInfo(originalLambda, cancellationToken).ConvertedType;
var parameterList = GenerateParameterList(semanticModel, originalLambda, cancellationToken);
var parameterList = GenerateParameterList(semanticModel, originalLambda, delegateType, cancellationToken);
var currentLocalDeclaration = currentRoot.GetCurrentNode(originalLocalDeclaration);
var currentLambda = currentRoot.GetCurrentNode(originalLambda);
......@@ -74,7 +85,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
currentRoot = ReplaceAnonymousWithLocalFunction(
document.Project.Solution.Workspace, currentRoot,
currentLocalDeclaration, currentLambda,
delegateType, parameterList,
delegateType, parameterList, memberAccess.Select(node => currentRoot.GetCurrentNode(node)).ToImmutableArray(),
cancellationToken);
}
......@@ -85,6 +96,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
Workspace workspace, SyntaxNode currentRoot,
LocalDeclarationStatementSyntax localDeclaration, LambdaExpressionSyntax lambda,
INamedTypeSymbol delegateType, ParameterListSyntax parameterList,
ImmutableArray<MemberAccessExpressionSyntax> notDirectInvocationUsages,
CancellationToken cancellationToken)
{
var newLocalFunctionStatement = CreateLocalFunctionStatement(
......@@ -104,6 +116,16 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
editor.RemoveNode(lambdaStatement);
}
if (!notDirectInvocationUsages.IsEmpty)
{
foreach (var usage in notDirectInvocationUsages)
{
editor.ReplaceNode(
usage.Parent,
SyntaxFactory.InvocationExpression(usage.Expression, (usage.Parent as InvocationExpressionSyntax).ArgumentList));
}
}
return editor.GetChangedRoot();
}
......@@ -144,33 +166,40 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
}
private ParameterListSyntax GenerateParameterList(
SemanticModel semanticModel, AnonymousFunctionExpressionSyntax anonymousFunction, CancellationToken cancellationToken)
SemanticModel semanticModel, AnonymousFunctionExpressionSyntax anonymousFunction, INamedTypeSymbol namedTypeSymbol, CancellationToken cancellationToken)
{
switch (anonymousFunction)
{
case SimpleLambdaExpressionSyntax simpleLambda:
return GenerateSimpleLambdaParameterList(semanticModel, simpleLambda, cancellationToken);
return GenerateSimpleLambdaParameterList(semanticModel, simpleLambda, namedTypeSymbol, cancellationToken);
case ParenthesizedLambdaExpressionSyntax parenthesizedLambda:
return GenerateParenthesizedLambdaParameterList(semanticModel, parenthesizedLambda, cancellationToken);
return GenerateParenthesizedLambdaParameterList(semanticModel, parenthesizedLambda, namedTypeSymbol, cancellationToken);
default:
throw ExceptionUtilities.UnexpectedValue(anonymousFunction);
}
}
private ParameterListSyntax GenerateSimpleLambdaParameterList(
SemanticModel semanticModel, SimpleLambdaExpressionSyntax lambdaExpression, CancellationToken cancellationToken)
SemanticModel semanticModel, SimpleLambdaExpressionSyntax lambdaExpression, INamedTypeSymbol namedTypeSymbol, CancellationToken cancellationToken)
{
var parameter = semanticModel.GetDeclaredSymbol(lambdaExpression.Parameter, cancellationToken);
var type = parameter?.Type.GenerateTypeSyntax() ?? s_objectType;
var parameterSyntax = SyntaxFactory.Parameter(lambdaExpression.Parameter.Identifier).WithType(type);
var param = namedTypeSymbol.DelegateInvokeMethod.Parameters[0];
if (param.HasExplicitDefaultValue)
{
parameterSyntax = parameterSyntax.WithDefault(GetDefaultValue(param));
}
return SyntaxFactory.ParameterList(
SyntaxFactory.SeparatedList<ParameterSyntax>().Add(
SyntaxFactory.Parameter(lambdaExpression.Parameter.Identifier).WithType(type)));
SyntaxFactory.SeparatedList<ParameterSyntax>().Add(parameterSyntax));
}
private ParameterListSyntax GenerateParenthesizedLambdaParameterList(
SemanticModel semanticModel, ParenthesizedLambdaExpressionSyntax lambdaExpression, CancellationToken cancellationToken)
SemanticModel semanticModel, ParenthesizedLambdaExpressionSyntax lambdaExpression, INamedTypeSymbol namedTypeSymbol, CancellationToken cancellationToken)
{
int i = 0;
return lambdaExpression.ParameterList.ReplaceNodes(
lambdaExpression.ParameterList.Parameters,
(parameterNode, _) =>
......@@ -181,10 +210,22 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
}
var parameter = semanticModel.GetDeclaredSymbol(parameterNode, cancellationToken);
return parameterNode.WithType(parameter?.Type.GenerateTypeSyntax() ?? s_objectType);
parameterNode = parameterNode.WithType(parameter?.Type.GenerateTypeSyntax() ?? s_objectType);
var param = namedTypeSymbol.DelegateInvokeMethod.Parameters[i++];
if (param.HasExplicitDefaultValue)
{
parameterNode = parameterNode.WithDefault(GetDefaultValue(param));
}
return parameterNode;
});
}
private static EqualsValueClauseSyntax GetDefaultValue(IParameterSymbol parameter)
{
return SyntaxFactory.EqualsValueClause(ExpressionGenerator.GenerateExpression(parameter.Type, parameter.ExplicitDefaultValue, canUseFieldReference: true));
}
private class MyCodeAction : CodeAction.DocumentChangeAction
{
public MyCodeAction(Func<CancellationToken, Task<Document>> createChangedDocument)
......
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq.Expressions;
using System.Threading;
......@@ -116,7 +117,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
return;
}
if (!CanReplaceAnonymousWithLocalFunction(semanticModel, expressionTypeOpt, local, block, anonymousFunction, cancellationToken))
if (!CanReplaceAnonymousWithLocalFunction(semanticModel, expressionTypeOpt, local, block, anonymousFunction, out var locations, cancellationToken))
{
return;
}
......@@ -126,6 +127,11 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
localDeclaration.GetLocation(),
anonymousFunction.GetLocation());
if (!locations.IsEmpty)
{
additionalLocations = additionalLocations.AddRange(locations);
}
if (severity != DiagnosticSeverity.Hidden)
{
// If the diagnostic is not hidden, then just place the user visible part
......@@ -197,10 +203,12 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
private bool CanReplaceAnonymousWithLocalFunction(
SemanticModel semanticModel, INamedTypeSymbol expressionTypeOpt, ISymbol local, BlockSyntax block,
AnonymousFunctionExpressionSyntax anonymousFunction, CancellationToken cancellationToken)
AnonymousFunctionExpressionSyntax anonymousFunction, out ImmutableArray<Location> locations, CancellationToken cancellationToken)
{
// Check all the references to the anonymous function and disallow the conversion if
// they're used in certain ways.
List<Location> delegateInvokeMethodLocations = null;
locations = ImmutableArray<Location>.Empty;
var anonymousFunctionStart = anonymousFunction.SpanStart;
foreach (var descendentNode in block.DescendantNodes())
{
......@@ -232,11 +240,19 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
return false;
}
if (nodeToCheck.Parent is MemberAccessExpressionSyntax)
if (nodeToCheck.Parent is MemberAccessExpressionSyntax memberAccessExpression)
{
// They're doing something like "del.ToString()". Can't do this with a
// local function.
return false;
if (memberAccessExpression.Name.Identifier.Text != WellKnownMemberNames.DelegateInvokeName)
{
// They're doing something like "del.ToString()". Can't do this with a
// local function.
return false;
}
else
{
delegateInvokeMethodLocations = delegateInvokeMethodLocations ?? new List<Location>();
delegateInvokeMethodLocations.Add(memberAccessExpression.GetLocation());
}
}
var convertedType = semanticModel.GetTypeInfo(nodeToCheck, cancellationToken).ConvertedType;
......@@ -258,6 +274,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
}
}
locations = delegateInvokeMethodLocations.AsImmutableOrEmpty();
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册