未验证 提交 a7294b31 编写于 作者: S Sam Harwell 提交者: GitHub

Merge pull request #30470 from Neme12/useLocalFunctionSemantics

Fix for UseLocalFunction breaking semantics
......@@ -3100,6 +3100,213 @@ public void Caller()
U local(U x) => x;
Callee(local);
}
}");
}
[WorkItem(26526, "https://github.com/dotnet/roslyn/issues/26526")]
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)]
public async Task TestAvailableWithCastIntroducedIfAssignedToVar()
{
await TestInRegularAndScript1Async(
@"using System;
class C
{
void M()
{
Func<string> [||]f = () => null;
var f2 = f;
}
}",
@"using System;
class C
{
void M()
{
string f() => null;
var f2 = (Func<string>)f;
}
}");
}
[WorkItem(26526, "https://github.com/dotnet/roslyn/issues/26526")]
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)]
public async Task TestAvailableWithCastIntroducedForGenericTypeInference1()
{
await TestInRegularAndScript1Async(
@"using System;
class C
{
void M()
{
Func<int, string> [||]f = _ => null;
Method(f);
}
void Method<T>(Func<T, string> o)
{
}
}",
@"using System;
class C
{
void M()
{
string f(int _) => null;
Method((Func<int, string>)f);
}
void Method<T>(Func<T, string> o)
{
}
}");
}
[WorkItem(26526, "https://github.com/dotnet/roslyn/issues/26526")]
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)]
public async Task TestAvailableWithCastIntroducedForGenericTypeInference2()
{
await TestInRegularAndScript1Async(
@"using System;
class C
{
void M()
{
Func<int, string> [||]f = _ => null;
Method(f);
}
void Method<T>(Func<T, string> o)
{
}
void Method(string o)
{
}
}",
@"using System;
class C
{
void M()
{
string f(int _) => null;
Method((Func<int, string>)f);
}
void Method<T>(Func<T, string> o)
{
}
void Method(string o)
{
}
}");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)]
public async Task TestAvailableWithCastIntroducedForOverloadResolution()
{
await TestInRegularAndScript1Async(
@"using System;
delegate string CustomDelegate();
class C
{
void M()
{
Func<string> [||]f = () => null;
Method(f);
}
void Method(Func<string> o)
{
}
void Method(CustomDelegate o)
{
}
}",
@"using System;
delegate string CustomDelegate();
class C
{
void M()
{
string f() => null;
Method((Func<string>)f);
}
void Method(Func<string> o)
{
}
void Method(CustomDelegate o)
{
}
}");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)]
public async Task TestAvailableWithoutCastIfUnnecessaryForOverloadResolution()
{
await TestInRegularAndScript1Async(
@"using System;
delegate string CustomDelegate(object arg);
class C
{
void M()
{
Func<string> [||]f = () => null;
Method(f);
}
void Method(Func<string> o)
{
}
void Method(CustomDelegate o)
{
}
}",
@"using System;
delegate string CustomDelegate(object arg);
class C
{
void M()
{
string f() => null;
Method(f);
}
void Method(Func<string> o)
{
}
void Method(CustomDelegate o)
{
}
}");
}
}
......
......@@ -49,7 +49,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
var nodesFromDiagnostics = new List<(
LocalDeclarationStatementSyntax declaration,
AnonymousFunctionExpressionSyntax function,
List<InvocationExpressionSyntax> invocations)>(diagnostics.Length);
List<ExpressionSyntax> references)>(diagnostics.Length);
var nodesToTrack = new HashSet<SyntaxNode>();
......@@ -58,18 +58,18 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
var localDeclaration = (LocalDeclarationStatementSyntax)diagnostic.AdditionalLocations[0].FindNode(cancellationToken);
var anonymousFunction = (AnonymousFunctionExpressionSyntax)diagnostic.AdditionalLocations[1].FindNode(cancellationToken);
var invocations = new List<InvocationExpressionSyntax>(diagnostic.AdditionalLocations.Count - 2);
var references = new List<ExpressionSyntax>(diagnostic.AdditionalLocations.Count - 2);
for (var i = 2; i < diagnostic.AdditionalLocations.Count; i++)
{
invocations.Add((InvocationExpressionSyntax)diagnostic.AdditionalLocations[i].FindNode(getInnermostNodeForTie: true, cancellationToken));
references.Add((ExpressionSyntax)diagnostic.AdditionalLocations[i].FindNode(getInnermostNodeForTie: true, cancellationToken));
}
nodesFromDiagnostics.Add((localDeclaration, anonymousFunction, invocations));
nodesFromDiagnostics.Add((localDeclaration, anonymousFunction, references));
nodesToTrack.Add(localDeclaration);
nodesToTrack.Add(anonymousFunction);
nodesToTrack.AddRange(invocations);
nodesToTrack.AddRange(references);
}
var root = editor.OriginalRoot;
......@@ -77,7 +77,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
// Process declarations in reverse order so that we see the effects of nested
// declarations befor processing the outer decls.
foreach (var (localDeclaration, anonymousFunction, invocations) in nodesFromDiagnostics.OrderByDescending(nodes => nodes.function.SpanStart))
foreach (var (localDeclaration, anonymousFunction, references) in nodesFromDiagnostics.OrderByDescending(nodes => nodes.function.SpanStart))
{
var delegateType = (INamedTypeSymbol)semanticModel.GetTypeInfo(anonymousFunction, cancellationToken).ConvertedType;
var parameterList = GenerateParameterList(anonymousFunction, delegateType.DelegateInvokeMethod);
......@@ -91,10 +91,10 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
delegateType.DelegateInvokeMethod, parameterList);
// these invocations might actually be inside the local function! so we have to do this separately
currentRoot = ReplaceInvocations(
document.Project.Solution.Workspace, currentRoot,
delegateType.DelegateInvokeMethod, parameterList,
invocations.Select(node => currentRoot.GetCurrentNode(node)).ToImmutableArray());
currentRoot = ReplaceReferences(
document, currentRoot,
delegateType, parameterList,
references.Select(node => currentRoot.GetCurrentNode(node)).ToImmutableArray());
}
editor.ReplaceNode(root, currentRoot);
......@@ -123,18 +123,25 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
return editor.GetChangedRoot();
}
private static SyntaxNode ReplaceInvocations(
Workspace workspace, SyntaxNode currentRoot,
IMethodSymbol delegateMethod, ParameterListSyntax parameterList,
ImmutableArray<InvocationExpressionSyntax> invocations)
private static SyntaxNode ReplaceReferences(
Document document, SyntaxNode currentRoot,
INamedTypeSymbol delegateType, ParameterListSyntax parameterList,
ImmutableArray<ExpressionSyntax> references)
{
return currentRoot.ReplaceNodes(invocations, (_ /* nested invocations! */, invocation) =>
return currentRoot.ReplaceNodes(references, (_ /* nested invocations! */, reference) =>
{
var directInvocation = invocation.Expression is MemberAccessExpressionSyntax memberAccessExpression // it's a .Invoke call
? invocation.WithExpression(memberAccessExpression.Expression).WithTriviaFrom(invocation) // remove it
: invocation;
if (reference is InvocationExpressionSyntax invocation)
{
var directInvocation = invocation.Expression is MemberAccessExpressionSyntax memberAccess // it's a .Invoke call
? invocation.WithExpression(memberAccess.Expression).WithTriviaFrom(invocation) // remove it
: invocation;
return WithNewParameterNames(directInvocation, delegateType.DelegateInvokeMethod, parameterList);
}
return WithNewParameterNames(directInvocation, delegateMethod, parameterList);
// It's not an invocation. Wrap the identifier in a cast (which will be remove by the simplifier if unnecessary)
// to ensure we preserve semantics in cases like overload resolution or generic type inference.
return SyntaxGenerator.GetGenerator(document).CastExpression(delegateType, reference);
});
}
......
......@@ -120,7 +120,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
return;
}
if (!CanReplaceAnonymousWithLocalFunction(semanticModel, expressionTypeOpt, local, block, anonymousFunction, out var invocationLocations, cancellationToken))
if (!CanReplaceAnonymousWithLocalFunction(semanticModel, expressionTypeOpt, local, block, anonymousFunction, out var referenceLocations, cancellationToken))
{
return;
}
......@@ -130,7 +130,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
localDeclaration.GetLocation(),
anonymousFunction.GetLocation());
additionalLocations = additionalLocations.AddRange(invocationLocations);
additionalLocations = additionalLocations.AddRange(referenceLocations);
if (severity.WithDefaultSeverity(DiagnosticSeverity.Hidden) < ReportDiagnostic.Hidden)
{
......@@ -209,12 +209,12 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
private bool CanReplaceAnonymousWithLocalFunction(
SemanticModel semanticModel, INamedTypeSymbol expressionTypeOpt, ISymbol local, BlockSyntax block,
AnonymousFunctionExpressionSyntax anonymousFunction, out ImmutableArray<Location> invocationLocations, CancellationToken cancellationToken)
AnonymousFunctionExpressionSyntax anonymousFunction, out ImmutableArray<Location> referenceLocations, CancellationToken cancellationToken)
{
// Check all the references to the anonymous function and disallow the conversion if
// they're used in certain ways.
var invocations = ArrayBuilder<Location>.GetInstance();
invocationLocations = ImmutableArray<Location>.Empty;
var references = ArrayBuilder<Location>.GetInstance();
referenceLocations = ImmutableArray<Location>.Empty;
var anonymousFunctionStart = anonymousFunction.SpanStart;
foreach (var descendentNode in block.DescendantNodes())
{
......@@ -248,14 +248,14 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
if (nodeToCheck.Parent is InvocationExpressionSyntax invocationExpression)
{
invocations.Add(invocationExpression.GetLocation());
references.Add(invocationExpression.GetLocation());
}
else if (nodeToCheck.Parent is MemberAccessExpressionSyntax memberAccessExpression)
{
if (memberAccessExpression.Parent is InvocationExpressionSyntax explicitInvocationExpression &&
memberAccessExpression.Name.Identifier.ValueText == WellKnownMemberNames.DelegateInvokeName)
{
invocations.Add(explicitInvocationExpression.GetLocation());
references.Add(explicitInvocationExpression.GetLocation());
}
else
{
......@@ -264,6 +264,10 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
return false;
}
}
else
{
references.Add(nodeToCheck.GetLocation());
}
var convertedType = semanticModel.GetTypeInfo(nodeToCheck, cancellationToken).ConvertedType;
if (!convertedType.IsDelegateType())
......@@ -284,7 +288,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
}
}
invocationLocations = invocations.ToImmutableAndFree();
referenceLocations = references.ToImmutableAndFree();
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册