未验证 提交 a7294b31 编写于 作者: S Sam Harwell 提交者: GitHub

Merge pull request #30470 from Neme12/useLocalFunctionSemantics

Fix for UseLocalFunction breaking semantics
...@@ -3100,6 +3100,213 @@ public void Caller() ...@@ -3100,6 +3100,213 @@ public void Caller()
U local(U x) => x; U local(U x) => x;
Callee(local); Callee(local);
} }
}");
}
[WorkItem(26526, "https://github.com/dotnet/roslyn/issues/26526")]
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)]
public async Task TestAvailableWithCastIntroducedIfAssignedToVar()
{
await TestInRegularAndScript1Async(
@"using System;
class C
{
void M()
{
Func<string> [||]f = () => null;
var f2 = f;
}
}",
@"using System;
class C
{
void M()
{
string f() => null;
var f2 = (Func<string>)f;
}
}");
}
[WorkItem(26526, "https://github.com/dotnet/roslyn/issues/26526")]
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)]
public async Task TestAvailableWithCastIntroducedForGenericTypeInference1()
{
await TestInRegularAndScript1Async(
@"using System;
class C
{
void M()
{
Func<int, string> [||]f = _ => null;
Method(f);
}
void Method<T>(Func<T, string> o)
{
}
}",
@"using System;
class C
{
void M()
{
string f(int _) => null;
Method((Func<int, string>)f);
}
void Method<T>(Func<T, string> o)
{
}
}");
}
[WorkItem(26526, "https://github.com/dotnet/roslyn/issues/26526")]
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)]
public async Task TestAvailableWithCastIntroducedForGenericTypeInference2()
{
await TestInRegularAndScript1Async(
@"using System;
class C
{
void M()
{
Func<int, string> [||]f = _ => null;
Method(f);
}
void Method<T>(Func<T, string> o)
{
}
void Method(string o)
{
}
}",
@"using System;
class C
{
void M()
{
string f(int _) => null;
Method((Func<int, string>)f);
}
void Method<T>(Func<T, string> o)
{
}
void Method(string o)
{
}
}");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)]
public async Task TestAvailableWithCastIntroducedForOverloadResolution()
{
await TestInRegularAndScript1Async(
@"using System;
delegate string CustomDelegate();
class C
{
void M()
{
Func<string> [||]f = () => null;
Method(f);
}
void Method(Func<string> o)
{
}
void Method(CustomDelegate o)
{
}
}",
@"using System;
delegate string CustomDelegate();
class C
{
void M()
{
string f() => null;
Method((Func<string>)f);
}
void Method(Func<string> o)
{
}
void Method(CustomDelegate o)
{
}
}");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseLocalFunction)]
public async Task TestAvailableWithoutCastIfUnnecessaryForOverloadResolution()
{
await TestInRegularAndScript1Async(
@"using System;
delegate string CustomDelegate(object arg);
class C
{
void M()
{
Func<string> [||]f = () => null;
Method(f);
}
void Method(Func<string> o)
{
}
void Method(CustomDelegate o)
{
}
}",
@"using System;
delegate string CustomDelegate(object arg);
class C
{
void M()
{
string f() => null;
Method(f);
}
void Method(Func<string> o)
{
}
void Method(CustomDelegate o)
{
}
}"); }");
} }
} }
......
...@@ -49,7 +49,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) ...@@ -49,7 +49,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
var nodesFromDiagnostics = new List<( var nodesFromDiagnostics = new List<(
LocalDeclarationStatementSyntax declaration, LocalDeclarationStatementSyntax declaration,
AnonymousFunctionExpressionSyntax function, AnonymousFunctionExpressionSyntax function,
List<InvocationExpressionSyntax> invocations)>(diagnostics.Length); List<ExpressionSyntax> references)>(diagnostics.Length);
var nodesToTrack = new HashSet<SyntaxNode>(); var nodesToTrack = new HashSet<SyntaxNode>();
...@@ -58,18 +58,18 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) ...@@ -58,18 +58,18 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
var localDeclaration = (LocalDeclarationStatementSyntax)diagnostic.AdditionalLocations[0].FindNode(cancellationToken); var localDeclaration = (LocalDeclarationStatementSyntax)diagnostic.AdditionalLocations[0].FindNode(cancellationToken);
var anonymousFunction = (AnonymousFunctionExpressionSyntax)diagnostic.AdditionalLocations[1].FindNode(cancellationToken); var anonymousFunction = (AnonymousFunctionExpressionSyntax)diagnostic.AdditionalLocations[1].FindNode(cancellationToken);
var invocations = new List<InvocationExpressionSyntax>(diagnostic.AdditionalLocations.Count - 2); var references = new List<ExpressionSyntax>(diagnostic.AdditionalLocations.Count - 2);
for (var i = 2; i < diagnostic.AdditionalLocations.Count; i++) for (var i = 2; i < diagnostic.AdditionalLocations.Count; i++)
{ {
invocations.Add((InvocationExpressionSyntax)diagnostic.AdditionalLocations[i].FindNode(getInnermostNodeForTie: true, cancellationToken)); references.Add((ExpressionSyntax)diagnostic.AdditionalLocations[i].FindNode(getInnermostNodeForTie: true, cancellationToken));
} }
nodesFromDiagnostics.Add((localDeclaration, anonymousFunction, invocations)); nodesFromDiagnostics.Add((localDeclaration, anonymousFunction, references));
nodesToTrack.Add(localDeclaration); nodesToTrack.Add(localDeclaration);
nodesToTrack.Add(anonymousFunction); nodesToTrack.Add(anonymousFunction);
nodesToTrack.AddRange(invocations); nodesToTrack.AddRange(references);
} }
var root = editor.OriginalRoot; var root = editor.OriginalRoot;
...@@ -77,7 +77,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) ...@@ -77,7 +77,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
// Process declarations in reverse order so that we see the effects of nested // Process declarations in reverse order so that we see the effects of nested
// declarations befor processing the outer decls. // declarations befor processing the outer decls.
foreach (var (localDeclaration, anonymousFunction, invocations) in nodesFromDiagnostics.OrderByDescending(nodes => nodes.function.SpanStart)) foreach (var (localDeclaration, anonymousFunction, references) in nodesFromDiagnostics.OrderByDescending(nodes => nodes.function.SpanStart))
{ {
var delegateType = (INamedTypeSymbol)semanticModel.GetTypeInfo(anonymousFunction, cancellationToken).ConvertedType; var delegateType = (INamedTypeSymbol)semanticModel.GetTypeInfo(anonymousFunction, cancellationToken).ConvertedType;
var parameterList = GenerateParameterList(anonymousFunction, delegateType.DelegateInvokeMethod); var parameterList = GenerateParameterList(anonymousFunction, delegateType.DelegateInvokeMethod);
...@@ -91,10 +91,10 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) ...@@ -91,10 +91,10 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
delegateType.DelegateInvokeMethod, parameterList); delegateType.DelegateInvokeMethod, parameterList);
// these invocations might actually be inside the local function! so we have to do this separately // these invocations might actually be inside the local function! so we have to do this separately
currentRoot = ReplaceInvocations( currentRoot = ReplaceReferences(
document.Project.Solution.Workspace, currentRoot, document, currentRoot,
delegateType.DelegateInvokeMethod, parameterList, delegateType, parameterList,
invocations.Select(node => currentRoot.GetCurrentNode(node)).ToImmutableArray()); references.Select(node => currentRoot.GetCurrentNode(node)).ToImmutableArray());
} }
editor.ReplaceNode(root, currentRoot); editor.ReplaceNode(root, currentRoot);
...@@ -123,18 +123,25 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) ...@@ -123,18 +123,25 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context)
return editor.GetChangedRoot(); return editor.GetChangedRoot();
} }
private static SyntaxNode ReplaceInvocations( private static SyntaxNode ReplaceReferences(
Workspace workspace, SyntaxNode currentRoot, Document document, SyntaxNode currentRoot,
IMethodSymbol delegateMethod, ParameterListSyntax parameterList, INamedTypeSymbol delegateType, ParameterListSyntax parameterList,
ImmutableArray<InvocationExpressionSyntax> invocations) ImmutableArray<ExpressionSyntax> references)
{
return currentRoot.ReplaceNodes(references, (_ /* nested invocations! */, reference) =>
{ {
return currentRoot.ReplaceNodes(invocations, (_ /* nested invocations! */, invocation) => if (reference is InvocationExpressionSyntax invocation)
{ {
var directInvocation = invocation.Expression is MemberAccessExpressionSyntax memberAccessExpression // it's a .Invoke call var directInvocation = invocation.Expression is MemberAccessExpressionSyntax memberAccess // it's a .Invoke call
? invocation.WithExpression(memberAccessExpression.Expression).WithTriviaFrom(invocation) // remove it ? invocation.WithExpression(memberAccess.Expression).WithTriviaFrom(invocation) // remove it
: invocation; : invocation;
return WithNewParameterNames(directInvocation, delegateMethod, parameterList); return WithNewParameterNames(directInvocation, delegateType.DelegateInvokeMethod, parameterList);
}
// It's not an invocation. Wrap the identifier in a cast (which will be remove by the simplifier if unnecessary)
// to ensure we preserve semantics in cases like overload resolution or generic type inference.
return SyntaxGenerator.GetGenerator(document).CastExpression(delegateType, reference);
}); });
} }
......
...@@ -120,7 +120,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp ...@@ -120,7 +120,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
return; return;
} }
if (!CanReplaceAnonymousWithLocalFunction(semanticModel, expressionTypeOpt, local, block, anonymousFunction, out var invocationLocations, cancellationToken)) if (!CanReplaceAnonymousWithLocalFunction(semanticModel, expressionTypeOpt, local, block, anonymousFunction, out var referenceLocations, cancellationToken))
{ {
return; return;
} }
...@@ -130,7 +130,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp ...@@ -130,7 +130,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
localDeclaration.GetLocation(), localDeclaration.GetLocation(),
anonymousFunction.GetLocation()); anonymousFunction.GetLocation());
additionalLocations = additionalLocations.AddRange(invocationLocations); additionalLocations = additionalLocations.AddRange(referenceLocations);
if (severity.WithDefaultSeverity(DiagnosticSeverity.Hidden) < ReportDiagnostic.Hidden) if (severity.WithDefaultSeverity(DiagnosticSeverity.Hidden) < ReportDiagnostic.Hidden)
{ {
...@@ -209,12 +209,12 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp ...@@ -209,12 +209,12 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
private bool CanReplaceAnonymousWithLocalFunction( private bool CanReplaceAnonymousWithLocalFunction(
SemanticModel semanticModel, INamedTypeSymbol expressionTypeOpt, ISymbol local, BlockSyntax block, SemanticModel semanticModel, INamedTypeSymbol expressionTypeOpt, ISymbol local, BlockSyntax block,
AnonymousFunctionExpressionSyntax anonymousFunction, out ImmutableArray<Location> invocationLocations, CancellationToken cancellationToken) AnonymousFunctionExpressionSyntax anonymousFunction, out ImmutableArray<Location> referenceLocations, CancellationToken cancellationToken)
{ {
// Check all the references to the anonymous function and disallow the conversion if // Check all the references to the anonymous function and disallow the conversion if
// they're used in certain ways. // they're used in certain ways.
var invocations = ArrayBuilder<Location>.GetInstance(); var references = ArrayBuilder<Location>.GetInstance();
invocationLocations = ImmutableArray<Location>.Empty; referenceLocations = ImmutableArray<Location>.Empty;
var anonymousFunctionStart = anonymousFunction.SpanStart; var anonymousFunctionStart = anonymousFunction.SpanStart;
foreach (var descendentNode in block.DescendantNodes()) foreach (var descendentNode in block.DescendantNodes())
{ {
...@@ -248,14 +248,14 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp ...@@ -248,14 +248,14 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
if (nodeToCheck.Parent is InvocationExpressionSyntax invocationExpression) if (nodeToCheck.Parent is InvocationExpressionSyntax invocationExpression)
{ {
invocations.Add(invocationExpression.GetLocation()); references.Add(invocationExpression.GetLocation());
} }
else if (nodeToCheck.Parent is MemberAccessExpressionSyntax memberAccessExpression) else if (nodeToCheck.Parent is MemberAccessExpressionSyntax memberAccessExpression)
{ {
if (memberAccessExpression.Parent is InvocationExpressionSyntax explicitInvocationExpression && if (memberAccessExpression.Parent is InvocationExpressionSyntax explicitInvocationExpression &&
memberAccessExpression.Name.Identifier.ValueText == WellKnownMemberNames.DelegateInvokeName) memberAccessExpression.Name.Identifier.ValueText == WellKnownMemberNames.DelegateInvokeName)
{ {
invocations.Add(explicitInvocationExpression.GetLocation()); references.Add(explicitInvocationExpression.GetLocation());
} }
else else
{ {
...@@ -264,6 +264,10 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp ...@@ -264,6 +264,10 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
return false; return false;
} }
} }
else
{
references.Add(nodeToCheck.GetLocation());
}
var convertedType = semanticModel.GetTypeInfo(nodeToCheck, cancellationToken).ConvertedType; var convertedType = semanticModel.GetTypeInfo(nodeToCheck, cancellationToken).ConvertedType;
if (!convertedType.IsDelegateType()) if (!convertedType.IsDelegateType())
...@@ -284,7 +288,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp ...@@ -284,7 +288,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext, INamedTyp
} }
} }
invocationLocations = invocations.ToImmutableAndFree(); referenceLocations = references.ToImmutableAndFree();
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册