提交 47c9e23e 编写于 作者: A Alireza Habibi

Add a code refactoring to convert local function to method

上级 2715ecbf
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.CodeRefactorings;
using Microsoft.CodeAnalysis.CSharp.CodeRefactorings.ConvertLocalFunctionToMethod;
using Roslyn.Test.Utilities;
using Xunit;
namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.CodeActions.ConvertLocalFunctionToMethod
{
public class ConvertLocalFunctionToMethodTests : AbstractCSharpCodeActionTest
{
protected override CodeRefactoringProvider CreateCodeRefactoringProvider(Workspace workspace, TestParameters parameters)
=> new CSharpConvertLocalFunctionToMethodCodeRefactoringProvider();
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsConvertLocalFunctionToMethod)]
public async Task TestCaptures()
{
await TestInRegularAndScriptAsync(
@"class C
{
static void Use<T>(T a) {}
static void Use<T>(ref T a) {}
static void LocalFunction() {} // trigger rename
void M<T1, T2>(T1 param1, T2 param2)
where T1 : struct
where T2 : struct
{
var local1 = 0;
var local2 = 0;
void [||]LocalFunction()
{
Use(param1);
Use(ref param2);
Use(local1);
Use(ref local2);
Use(this);
}
LocalFunction();
}
}",
@"class C
{
static void Use<T>(T a) {}
static void Use<T>(ref T a) {}
static void LocalFunction() {} // trigger rename
void M<T1, T2>(T1 param1, T2 param2)
where T1 : struct
where T2 : struct
{
var local1 = 0;
var local2 = 0;
LocalFunction1<T1, T2>(param1, ref param2, local1, ref local2);
}
private void LocalFunction1<T1, T2>(T1 param1, ref T2 param2, int local1, ref int local2)
where T1 : struct
where T2 : struct
{
Use(param1);
Use(ref param2);
Use(local1);
Use(ref local2);
Use(this);
}
}");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsConvertLocalFunctionToMethod)]
public async Task TestTypeParameters()
{
await TestInRegularAndScriptAsync(
@"class C<T0>
{
static void LocalFunction() {} // trigger rename
void M<T1, T2>(int i)
where T1 : struct
{
void Local1<T3, T4>()
where T4 : struct
{
void [||]LocalFunction<T5, T6>(T5 a, T6 b)
where T5 : struct
{
_ = typeof(T2);
_ = typeof(T4);
}
LocalFunction<byte, int>(5, 6);
LocalFunction(5, 6);
}
}
}",
@"class C<T0>
{
static void LocalFunction() {} // trigger rename
void M<T1, T2>(int i)
where T1 : struct
{
void Local1<T3, T4>()
where T4 : struct
{
LocalFunction1<T2, T4, byte, int>(5, 6);
LocalFunction1<T2, T4, int, int>(5, 6);
}
}
private static void LocalFunction1<T2, T4, T5, T6>(T5 a, T6 b)
where T4 : struct
where T5 : struct
{
_ = typeof(T2);
_ = typeof(T4);
}
}");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsConvertLocalFunctionToMethod)]
public async Task TestNameConflict()
{
await TestInRegularAndScriptAsync(
@"class C
{
void LocalFunction() {} // trigger rename
void M()
{
void [||]LocalFunction() => M();
LocalFunction();
System.Action x = LocalFunction;
}
}",
@"class C
{
void LocalFunction() {} // trigger rename
void M()
{
LocalFunction1();
System.Action x = LocalFunction1;
}
private void LocalFunction1() => M();
}");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsConvertLocalFunctionToMethod)]
public async Task TestNamedArguments1()
{
await TestAsync(
@"class C
{
void LocalFunction() {} // trigger rename
void M()
{
int var = 2;
int [||]LocalFunction(int i)
{
return var;
}
LocalFunction(i: 0);
}
}",
@"class C
{
void LocalFunction() {} // trigger rename
void M()
{
int var = 2;
LocalFunction1(i: 0, var);
}
private static int LocalFunction1(int i, int var)
{
return var;
}
}", parseOptions: CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.CSharp7_2));
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsConvertLocalFunctionToMethod)]
public async Task TestNamedArguments2()
{
await TestAsync(
@"class C
{
void LocalFunction() {} // trigger rename
void M()
{
int var = 2;
int [||]LocalFunction(int i)
{
return var;
}
LocalFunction(i: 0);
}
}",
@"class C
{
void LocalFunction() {} // trigger rename
void M()
{
int var = 2;
LocalFunction1(i: 0, var: var);
}
private static int LocalFunction1(int i, int var)
{
return var;
}
}", parseOptions: CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.CSharp7));
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsConvertLocalFunctionToMethod)]
public async Task TestCaretPositon()
{
await TestAsync("C [||]LocalFunction(C c)");
await TestAsync("C Local[||]Function(C c)");
await TestAsync("C [|LocalFunction|](C c)");
await TestAsync("C LocalFunction[||](C c)");
await TestMissingAsync("C Local[|Function|](C c)");
await TestMissingAsync("[||]C LocalFunction(C c)");
await TestMissingAsync("[|C|] LocalFunction(C c)");
await TestMissingAsync("C[||] LocalFunction(C c)");
await TestMissingAsync("C LocalFunction([||]C c)");
await TestMissingAsync("C LocalFunction(C [||]c)");
async Task TestAsync(string signature)
{
await TestInRegularAndScriptAsync(
$@"class C
{{
void M()
{{
{signature}
{{
return null;
}}
}}
}}",
@"class C
{
void M()
{
}
private static C LocalFunction(C c)
{
return null;
}
}");
}
async Task TestMissingAsync(string signature)
{
await this.TestMissingAsync(
$@"class C
{{
void M()
{{
{signature}
{{
return null;
}}
}}
}}");
}
}
}
}
......@@ -314,6 +314,15 @@ internal class CSharpFeaturesResources {
}
}
/// <summary>
/// Looks up a localized string similar to Convert to regular method.
/// </summary>
internal static string Convert_to_regular_method {
get {
return ResourceManager.GetString("Convert_to_regular_method", resourceCulture);
}
}
/// <summary>
/// Looks up a localized string similar to deconstruction.
/// </summary>
......
......@@ -527,4 +527,7 @@
<data name="Convert_to_for" xml:space="preserve">
<value>Convert to 'for'</value>
</data>
<data name="Convert_to_regular_method" xml:space="preserve">
<value>Convert to regular method</value>
</data>
</root>
\ No newline at end of file
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.CSharp.CodeGeneration;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.CSharp.CodeRefactorings.ConvertLocalFunctionToMethod
{
[ExportCodeRefactoringProvider(LanguageNames.CSharp, Name = nameof(CSharpConvertLocalFunctionToMethodCodeRefactoringProvider)), Shared]
internal sealed class CSharpConvertLocalFunctionToMethodCodeRefactoringProvider : CodeRefactoringProvider
{
public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
{
var document = context.Document;
if (document.Project.Solution.Workspace.Kind == WorkspaceKind.MiscellaneousFiles)
{
return;
}
var cancellationToken = context.CancellationToken;
var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var identifier = await root.SyntaxTree.GetTouchingTokenAsync(context.Span.Start,
token => token.Parent is LocalFunctionStatementSyntax, cancellationToken).ConfigureAwait(false);
if (identifier == default)
{
return;
}
var localFunction = (LocalFunctionStatementSyntax)identifier.Parent;
if (localFunction.Identifier != identifier)
{
return;
}
if (context.Span.Length > 0 &&
context.Span != identifier.Span)
{
return;
}
if (localFunction.ContainsDiagnostics)
{
return;
}
if (!localFunction.Parent.IsKind(SyntaxKind.Block, out BlockSyntax parentBlock))
{
return;
}
context.RegisterRefactoring(new MyCodeAction(c => UpdateDocumentAsync(root, document, parentBlock, localFunction, c)));
}
private static async Task<Document> UpdateDocumentAsync(
SyntaxNode root,
Document document,
BlockSyntax parentBlock,
LocalFunctionStatementSyntax localFunction,
CancellationToken cancellationToken)
{
var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var declaredSymbol = (IMethodSymbol)semanticModel.GetDeclaredSymbol(localFunction, cancellationToken);
var dataFlow = semanticModel.AnalyzeDataFlow(
localFunction.Body ?? (SyntaxNode)localFunction.ExpressionBody.Expression);
var captures = dataFlow.Captured;
var capturesAsParameters = captures
.Where(capture => !capture.IsThisParameter())
.Select(capture => CodeGenerationSymbolFactory.CreateParameterSymbol(
attributes: default,
refKind: dataFlow.WrittenInside.Contains(capture) ? RefKind.Ref : RefKind.None,
isParams: false,
type: capture.GetSymbolType(),
name: capture.Name)).ToList();
var typeParameters = new List<ITypeParameterSymbol>();
GetCapturedTypeParameters(declaredSymbol, typeParameters);
// We explicitly preserve captures' types, in case they were not spelt out in the function body
var captureTypes = captures.Select(capture => capture.GetSymbolType()).OfType<ITypeParameterSymbol>();
RemoveUnusedTypeParameters(localFunction, semanticModel, typeParameters, reservedTypeParameters: captureTypes);
var container = localFunction.GetAncestor<MemberDeclarationSyntax>();
var containerSymbol = semanticModel.GetDeclaredSymbol(container, cancellationToken);
var isStatic = containerSymbol.IsStatic || captures.All(capture => !capture.IsThisParameter());
var methodName = GenerateUniqueMethodName(declaredSymbol);
var methodSymbol = CodeGenerationSymbolFactory.CreateMethodSymbol(
containingType: declaredSymbol.ContainingType,
attributes: default,
accessibility: Accessibility.Private,
modifiers: isStatic ? DeclarationModifiers.Static : default,
returnType: declaredSymbol.ReturnType,
refKind: default,
explicitInterfaceImplementations: default,
name: methodName,
typeParameters: typeParameters.ToImmutableArray(),
parameters: declaredSymbol.Parameters.AddRange(capturesAsParameters));
var method = MethodGenerator.GenerateMethodDeclaration(methodSymbol, CodeGenerationDestination.Unspecified,
document.Project.Solution.Workspace, CodeGenerationOptions.Default, root.SyntaxTree.Options);
method = WithBodyFrom(method, localFunction);
var generator = CSharpSyntaxGenerator.Instance;
var editor = new SyntaxEditor(root, generator);
editor.InsertAfter(container, method);
editor.RemoveNode(localFunction, SyntaxRemoveOptions.KeepNoTrivia);
var needsRename = methodName != declaredSymbol.Name;
var identifierToken = needsRename ? methodName.ToIdentifierToken() : default;
var supportsNonTrailing = SupportsNonTrailingNamedArguments(root.SyntaxTree.Options);
var hasAdditionalArguments = !capturesAsParameters.IsEmpty();
var hasAdditionalTypeArguments = !typeParameters.IsEmpty();
var additionalTypeArguments = hasAdditionalTypeArguments
? typeParameters.Except(declaredSymbol.TypeParameters)
.Select(p => (TypeSyntax)p.Name.ToIdentifierName()).ToArray()
: null;
// Update callers' name, arguments and type arguments
foreach (var node in parentBlock.DescendantNodes())
{
// A local function reference can only be an identifier or a generic name.
switch (node.Kind())
{
case SyntaxKind.IdentifierName:
case SyntaxKind.GenericName:
break;
default:
continue;
}
// Using symbol to get type arguments, since it could be inferred and not present in the source
var symbol = semanticModel.GetSymbolInfo(node, cancellationToken).Symbol as IMethodSymbol;
if (symbol?.OriginalDefinition != declaredSymbol)
{
continue;
}
var currentNode = node;
if (currentNode.Parent.IsKind(SyntaxKind.InvocationExpression, out InvocationExpressionSyntax invocation))
{
if (hasAdditionalArguments)
{
var shouldUseNamedArguments =
!supportsNonTrailing && invocation.ArgumentList.Arguments.Any(arg => arg.NameColon != null);
var additionalArguments = capturesAsParameters.Select(parameter =>
(ArgumentSyntax)generator.Argument(
name: shouldUseNamedArguments ? parameter.Name : null,
refKind: parameter.RefKind,
expression: parameter.Name.ToIdentifierName())).ToArray();
editor.ReplaceNode(invocation.ArgumentList,
invocation.ArgumentList.AddArguments(additionalArguments));
}
if (hasAdditionalTypeArguments)
{
var existingTypeArguments = symbol.TypeArguments.Select(x => x.GenerateTypeSyntax());
// Prepend additional type arguments to preserve lexical order in which they are defined
var typeArguments = additionalTypeArguments.Concat(existingTypeArguments);
currentNode = generator.WithTypeArguments(currentNode, typeArguments);
}
}
if (needsRename)
{
currentNode = ((SimpleNameSyntax)currentNode).WithIdentifier(identifierToken);
}
editor.ReplaceNode(node, currentNode);
}
return document.WithSyntaxRoot(editor.GetChangedRoot());
}
private static bool SupportsNonTrailingNamedArguments(ParseOptions options)
=> ((CSharpParseOptions)options).LanguageVersion >= LanguageVersion.CSharp7_2;
private static MethodDeclarationSyntax WithBodyFrom(
MethodDeclarationSyntax method, LocalFunctionStatementSyntax localFunction)
{
return method
.WithExpressionBody(localFunction.ExpressionBody)
.WithSemicolonToken(localFunction.SemicolonToken)
.WithBody(localFunction.Body);
}
private static void GetCapturedTypeParameters(ISymbol symbol, List<ITypeParameterSymbol> typeParameters)
{
var containingSymbol = symbol.ContainingSymbol;
if (containingSymbol != null &&
containingSymbol.Kind != SymbolKind.NamedType)
{
GetCapturedTypeParameters(containingSymbol, typeParameters);
}
typeParameters.AddRange(symbol.GetTypeParameters());
}
private static void RemoveUnusedTypeParameters(
SyntaxNode localFunction,
SemanticModel semanticModel,
List<ITypeParameterSymbol> typeParameters,
IEnumerable<ITypeParameterSymbol> reservedTypeParameters)
{
var unusedTypeParameters = typeParameters.ToList();
foreach (var id in localFunction.DescendantNodes().OfType<IdentifierNameSyntax>())
{
var symbol = semanticModel.GetSymbolInfo(id).Symbol;
if (symbol != null && symbol.OriginalDefinition is ITypeParameterSymbol typeParameter)
{
unusedTypeParameters.Remove(typeParameter);
}
}
typeParameters.RemoveRange(unusedTypeParameters.Except(reservedTypeParameters));
}
private static string GenerateUniqueMethodName(ISymbol declaredSymbol)
{
return NameGenerator.EnsureUniqueness(
baseName: declaredSymbol.Name,
reservedNames: declaredSymbol.ContainingType.GetMembers().Select(m => m.Name));
}
private sealed class MyCodeAction : CodeActions.CodeAction.DocumentChangeAction
{
public MyCodeAction(Func<CancellationToken, Task<Document>> createChangedDocument)
: base(CSharpFeaturesResources.Convert_to_regular_method, createChangedDocument)
{
}
}
}
}
......@@ -48,6 +48,7 @@ public static class Features
public const string CodeActionsChangeToIEnumerable = "CodeActions.ChangeToIEnumerable";
public const string CodeActionsChangeToYield = "CodeActions.ChangeToYield";
public const string CodeActionsConvertNumericLiteral = "CodeActions.ConvertNumericLiteral";
public const string CodeActionsConvertLocalFunctionToMethod = "CodeActions.ConvertLocalFunctionToMethod";
public const string CodeActionsConvertToInterpolatedString = "CodeActions.ConvertToInterpolatedString";
public const string CodeActionsConvertToIterator = "CodeActions.ConvertToIterator";
public const string CodeActionsConvertForToForEach = "CodeActions.ConvertForToForEach";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册