提交 d09d3100 编写于 作者: A Andy Gocke 提交者: GitHub

Fix calling generic local functions recursively (#15968)

There were two root causes here:

    1) The rewriter was treating the difference between the symbol and
       `symbol.ConstructedFrom` as whether or not there were any type
       parameters remaining that may need substitution. This is invalid
       for recursive local functions.

    2) The local function reference rewriter was including type
       parameters from the containing type in the list of parameters to
       substitute. This should happen iff the containing type is a
       lambda frame used to capture variables, which is not always the
       case for local functions (although it is always the case for
       lambdas).

Fixes #15751
上级 71c97dbc
......@@ -56,7 +56,12 @@ public override BoundNode VisitCall(BoundCall node)
BoundExpression receiver;
MethodSymbol method;
var arguments = node.Arguments;
_lambdaRewriter.RemapLocalFunction(node.Syntax, node.Method, out receiver, out method, ref arguments);
_lambdaRewriter.RemapLocalFunction(
node.Syntax,
node.Method,
out receiver,
out method,
ref arguments);
node = node.Update(receiver, method, arguments);
}
......@@ -71,10 +76,16 @@ public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationE
MethodSymbol method;
var arguments = default(ImmutableArray<BoundExpression>);
_lambdaRewriter.RemapLocalFunction(
node.Syntax, node.MethodOpt, out receiver, out method, ref arguments);
node.Syntax,
node.MethodOpt,
out receiver,
out method,
ref arguments);
var newType = _lambdaRewriter.VisitType(node.Type);
return new BoundDelegateCreationExpression(
node.Syntax, receiver, method, isExtensionMethod: false, type: node.Type);
node.Syntax, receiver, method, isExtensionMethod: false, type: newType);
}
return base.VisitDelegateCreationExpression(node);
......@@ -89,10 +100,15 @@ public override BoundNode VisitConversion(BoundConversion conversion)
MethodSymbol method;
var arguments = default(ImmutableArray<BoundExpression>);
_lambdaRewriter.RemapLocalFunction(
conversion.Syntax, conversion.SymbolOpt, out receiver, out method, ref arguments);
conversion.Syntax,
conversion.SymbolOpt,
out receiver,
out method,
ref arguments);
var newType = _lambdaRewriter.VisitType(conversion.Type);
return new BoundDelegateCreationExpression(
conversion.Syntax, receiver, method, isExtensionMethod: false, type: conversion.Type);
conversion.Syntax, receiver, method, isExtensionMethod: false, type: newType);
}
return base.VisitConversion(conversion);
}
......@@ -147,8 +163,16 @@ public BoundStatement RewriteLocalFunctionReferences(BoundStatement loweredBody)
_framePointers.TryGetValue(synthesizedLambda.ContainingType, out _innermostFramePointer);
}
_currentTypeParameters = synthesizedLambda.ContainingType
?.TypeParameters.Concat(synthesizedLambda.TypeParameters)
var containerAsFrame = synthesizedLambda.ContainingType as LambdaFrame;
// Includes type parameters from the containing type iff
// the containing type is a frame. If it is a frame then
// the type parameters are captured, meaning that the
// type parameters should be included.
// If it is not a frame then the local function is being
// directly lowered into the method's containing type and
// the parameters should never be substituted.
_currentTypeParameters = containerAsFrame?.TypeParameters.Concat(synthesizedLambda.TypeParameters)
?? synthesizedLambda.TypeParameters;
_currentLambdaBodyTypeMap = synthesizedLambda.TypeMap;
......@@ -166,43 +190,39 @@ public BoundStatement RewriteLocalFunctionReferences(BoundStatement loweredBody)
return newBody;
}
/// <summary>
/// Rewrites a reference to an unlowered local function to the newly
/// lowered local function.
/// </summary>
private void RemapLocalFunction(
SyntaxNode syntax,
MethodSymbol symbol,
MethodSymbol localFunc,
out BoundExpression receiver,
out MethodSymbol method,
ref ImmutableArray<BoundExpression> parameters,
ImmutableArray<TypeSymbol> typeArguments = default(ImmutableArray<TypeSymbol>))
ref ImmutableArray<BoundExpression> parameters)
{
Debug.Assert(symbol.MethodKind == MethodKind.LocalFunction);
Debug.Assert(localFunc.MethodKind == MethodKind.LocalFunction);
if ((object)symbol != symbol.ConstructedFrom)
{
RemapLocalFunction(syntax,
symbol.ConstructedFrom,
out receiver,
out method,
ref parameters,
TypeMap.SubstituteTypes(symbol.TypeArguments)
.SelectAsArray(t => t.Type));
return;
}
var mappedLocalFunction = _localFunctionMap[(LocalFunctionSymbol)localFunc.OriginalDefinition];
var loweredSymbol = mappedLocalFunction.Symbol;
var mappedLocalFunction = _localFunctionMap[(LocalFunctionSymbol)symbol];
var lambda = mappedLocalFunction.Symbol;
var frameCount = lambda.ExtraSynthesizedParameterCount;
// If the local function captured variables then they will be stored
// in frames and the frames need to be passed as extra parameters.
var frameCount = loweredSymbol.ExtraSynthesizedParameterCount;
if (frameCount != 0)
{
Debug.Assert(!parameters.IsDefault);
var builder = ArrayBuilder<BoundExpression>.GetInstance();
builder.AddRange(parameters);
var start = lambda.ParameterCount - frameCount;
for (int i = start; i < lambda.ParameterCount; i++)
// Build a new list of parameters to pass to the local function
// call that includes any necessary capture frames
var parametersBuilder = ArrayBuilder<BoundExpression>.GetInstance();
parametersBuilder.AddRange(parameters);
var start = loweredSymbol.ParameterCount - frameCount;
for (int i = start; i < loweredSymbol.ParameterCount; i++)
{
// will always be a LambdaFrame, it's always a closure class
var frameType = (NamedTypeSymbol)lambda.Parameters[i].Type.OriginalDefinition;
// will always be a LambdaFrame, it's always a capture frame
var frameType = (NamedTypeSymbol)loweredSymbol.Parameters[i].Type.OriginalDefinition;
Debug.Assert(frameType is LambdaFrame);
......@@ -213,21 +233,76 @@ public BoundStatement RewriteLocalFunctionReferences(BoundStatement loweredBody)
var subst = this.TypeMap.SubstituteTypeParameters(typeParameters);
frameType = frameType.Construct(subst);
}
var frame = FrameOfType(syntax, frameType);
builder.Add(frame);
parametersBuilder.Add(frame);
}
parameters = builder.ToImmutableAndFree();
parameters = parametersBuilder.ToImmutableAndFree();
}
method = lambda;
method = loweredSymbol;
NamedTypeSymbol constructedFrame;
RemapLambdaOrLocalFunction(syntax,
symbol,
typeArguments,
localFunc,
SubstituteTypeArguments(localFunc.TypeArguments),
mappedLocalFunction.ClosureKind,
ref method,
out receiver,
out constructedFrame);
}
/// <summary>
/// Substitutes references from old type arguments to new type arguments
/// in the lowered methods.
/// </summary>
/// <example>
/// Consider the following method:
/// void M() {
/// void L&lt;T&gt;(T t) => Console.Write(t);
/// L("A");
/// }
///
/// In this example, L&lt;T&gt; is a local function that will be
/// lowered into its own method and the type parameter T will be
/// alpha renamed to something else (let's call it T'). In this case,
/// all references to the original type parameter T in L must be
/// rewritten to the renamed parameter, T'.
/// </example>
private ImmutableArray<TypeSymbol> SubstituteTypeArguments(ImmutableArray<TypeSymbol> typeArguments)
{
Debug.Assert(!typeArguments.IsDefault);
if (typeArguments.IsEmpty)
{
return typeArguments;
}
// We must perform this process repeatedly as local
// functions may nest inside one another and capture type
// parameters from the enclosing local functions. Each
// iteration of nesting will cause alpha-renaming of the captured
// parameters, meaning that we must replace until there are no
// more alpha-rename mappings.
var builder = ArrayBuilder<TypeSymbol>.GetInstance();
foreach (var typeArg in typeArguments)
{
TypeSymbol oldTypeArg;
TypeSymbol newTypeArg = typeArg;
do
{
oldTypeArg = newTypeArg;
newTypeArg = this.TypeMap.SubstituteType(typeArg).Type;
}
while (oldTypeArg != newTypeArg);
Debug.Assert((object)oldTypeArg == newTypeArg);
builder.Add(newTypeArg);
}
return builder.ToImmutableAndFree();
}
}
}
......@@ -3173,6 +3173,148 @@ public class C {
VerifyOutput(src, "7");
}
[Fact]
[WorkItem(15751, "https://github.com/dotnet/roslyn/issues/15751")]
public void RecursiveGenericLocalFunction()
{
var src = @"
void Local<T>(T t, int count)
{
if (count > 0)
{
Console.Write(t);
Local(t, count - 1);
}
}
Local(""A"", 5);
";
VerifyOutputInMain(src, "AAAAA", "System");
}
[Fact]
[WorkItem(15751, "https://github.com/dotnet/roslyn/issues/15751")]
public void RecursiveGenericLocalFunction2()
{
var src = @"
void Local<T>(T t, int count)
{
if (count > 0)
{
Console.Write(t);
var action = new Action<T, int>(Local);
action(t, count - 1);
}
}
Local(""A"", 5);
";
VerifyOutputInMain(src, "AAAAA", "System");
}
[Fact]
[WorkItem(15751, "https://github.com/dotnet/roslyn/issues/15751")]
public void RecursiveGenericLocalFunction3()
{
var src = @"
void Local<T>(T t, int count)
{
if (count > 0)
{
Console.Write(t);
var action = (Action<T, int>)Local;
action(t, count - 1);
}
}
Local(""A"", 5);
";
VerifyOutputInMain(src, "AAAAA", "System");
}
[Fact]
[WorkItem(15751, "https://github.com/dotnet/roslyn/issues/15751")]
public void RecursiveGenericLocalFunction4()
{
var src = @"
using System;
class C
{
public static void M<T>(T t)
{
void Local<U>(U u, int count)
{
if (count > 0)
{
Console.Write(t);
Console.Write(u);
Local(u, count - 1);
}
}
Local(""A"", 5);
}
public static void Main()
{
C.M(""B"");
}
}";
VerifyOutput(src, "BABABABABA");
}
[Fact]
[WorkItem(15751, "https://github.com/dotnet/roslyn/issues/15751")]
public void RecursiveGenericLocalFunction5()
{
var src = @"
using System;
class C<T1>
{
T1 t1;
public C(T1 t1)
{
this.t1 = t1;
}
public void M<T2>(T2 t2)
{
void L1<T3>(T3 t3)
{
void L2<T4>(T4 t4)
{
void L3<U>(U u, int count)
{
if (count > 0)
{
Console.Write(t1);
Console.Write(t2);
Console.Write(t3);
Console.Write(t4);
Console.Write(u);
L3(u, count - 1);
}
}
L3(""A"", 5);
}
L2(""B"");
}
L1(""C"");
}
}
class Program
{
public static void Main()
{
var c = new C<string>(""D"");
c.M(""E"");
}
}";
VerifyOutput(src, "DECBADECBADECBADECBADECBA");
}
internal CompilationVerifier VerifyOutput(string source, string output, CSharpCompilationOptions options)
{
var comp = CreateCompilationWithMscorlib45AndCSruntime(source, options: options);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册