From d09d31003f6bf477d0dc33de285f860bb92f439e Mon Sep 17 00:00:00 2001 From: Andy Gocke Date: Mon, 19 Dec 2016 16:07:06 -0800 Subject: [PATCH] Fix calling generic local functions recursively (#15968) There were two root causes here: 1) The rewriter was treating the difference between the symbol and `symbol.ConstructedFrom` as whether or not there were any type parameters remaining that may need substitution. This is invalid for recursive local functions. 2) The local function reference rewriter was including type parameters from the containing type in the list of parameters to substitute. This should happen iff the containing type is a lambda frame used to capture variables, which is not always the case for local functions (although it is always the case for lambdas). Fixes #15751 --- ...Rewriter.LocalFunctionReferenceRewriter.cs | 151 +++++++++++++----- .../Emit/CodeGen/CodeGenLocalFunctionTests.cs | 142 ++++++++++++++++ 2 files changed, 255 insertions(+), 38 deletions(-) diff --git a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.LocalFunctionReferenceRewriter.cs b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.LocalFunctionReferenceRewriter.cs index cc6f19b590f..6df9273c39e 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.LocalFunctionReferenceRewriter.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.LocalFunctionReferenceRewriter.cs @@ -56,7 +56,12 @@ public override BoundNode VisitCall(BoundCall node) BoundExpression receiver; MethodSymbol method; var arguments = node.Arguments; - _lambdaRewriter.RemapLocalFunction(node.Syntax, node.Method, out receiver, out method, ref arguments); + _lambdaRewriter.RemapLocalFunction( + node.Syntax, + node.Method, + out receiver, + out method, + ref arguments); node = node.Update(receiver, method, arguments); } @@ -71,10 +76,16 @@ public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationE MethodSymbol method; var arguments = default(ImmutableArray); _lambdaRewriter.RemapLocalFunction( - node.Syntax, node.MethodOpt, out receiver, out method, ref arguments); + node.Syntax, + node.MethodOpt, + out receiver, + out method, + ref arguments); + + var newType = _lambdaRewriter.VisitType(node.Type); return new BoundDelegateCreationExpression( - node.Syntax, receiver, method, isExtensionMethod: false, type: node.Type); + node.Syntax, receiver, method, isExtensionMethod: false, type: newType); } return base.VisitDelegateCreationExpression(node); @@ -89,10 +100,15 @@ public override BoundNode VisitConversion(BoundConversion conversion) MethodSymbol method; var arguments = default(ImmutableArray); _lambdaRewriter.RemapLocalFunction( - conversion.Syntax, conversion.SymbolOpt, out receiver, out method, ref arguments); + conversion.Syntax, + conversion.SymbolOpt, + out receiver, + out method, + ref arguments); + var newType = _lambdaRewriter.VisitType(conversion.Type); return new BoundDelegateCreationExpression( - conversion.Syntax, receiver, method, isExtensionMethod: false, type: conversion.Type); + conversion.Syntax, receiver, method, isExtensionMethod: false, type: newType); } return base.VisitConversion(conversion); } @@ -147,8 +163,16 @@ public BoundStatement RewriteLocalFunctionReferences(BoundStatement loweredBody) _framePointers.TryGetValue(synthesizedLambda.ContainingType, out _innermostFramePointer); } - _currentTypeParameters = synthesizedLambda.ContainingType - ?.TypeParameters.Concat(synthesizedLambda.TypeParameters) + var containerAsFrame = synthesizedLambda.ContainingType as LambdaFrame; + + // Includes type parameters from the containing type iff + // the containing type is a frame. If it is a frame then + // the type parameters are captured, meaning that the + // type parameters should be included. + // If it is not a frame then the local function is being + // directly lowered into the method's containing type and + // the parameters should never be substituted. + _currentTypeParameters = containerAsFrame?.TypeParameters.Concat(synthesizedLambda.TypeParameters) ?? synthesizedLambda.TypeParameters; _currentLambdaBodyTypeMap = synthesizedLambda.TypeMap; @@ -166,43 +190,39 @@ public BoundStatement RewriteLocalFunctionReferences(BoundStatement loweredBody) return newBody; } - + /// + /// Rewrites a reference to an unlowered local function to the newly + /// lowered local function. + /// private void RemapLocalFunction( SyntaxNode syntax, - MethodSymbol symbol, + MethodSymbol localFunc, out BoundExpression receiver, out MethodSymbol method, - ref ImmutableArray parameters, - ImmutableArray typeArguments = default(ImmutableArray)) + ref ImmutableArray parameters) { - Debug.Assert(symbol.MethodKind == MethodKind.LocalFunction); + Debug.Assert(localFunc.MethodKind == MethodKind.LocalFunction); - if ((object)symbol != symbol.ConstructedFrom) - { - RemapLocalFunction(syntax, - symbol.ConstructedFrom, - out receiver, - out method, - ref parameters, - TypeMap.SubstituteTypes(symbol.TypeArguments) - .SelectAsArray(t => t.Type)); - return; - } + var mappedLocalFunction = _localFunctionMap[(LocalFunctionSymbol)localFunc.OriginalDefinition]; + var loweredSymbol = mappedLocalFunction.Symbol; - var mappedLocalFunction = _localFunctionMap[(LocalFunctionSymbol)symbol]; - - var lambda = mappedLocalFunction.Symbol; - var frameCount = lambda.ExtraSynthesizedParameterCount; + // If the local function captured variables then they will be stored + // in frames and the frames need to be passed as extra parameters. + var frameCount = loweredSymbol.ExtraSynthesizedParameterCount; if (frameCount != 0) { Debug.Assert(!parameters.IsDefault); - var builder = ArrayBuilder.GetInstance(); - builder.AddRange(parameters); - var start = lambda.ParameterCount - frameCount; - for (int i = start; i < lambda.ParameterCount; i++) + + // Build a new list of parameters to pass to the local function + // call that includes any necessary capture frames + var parametersBuilder = ArrayBuilder.GetInstance(); + parametersBuilder.AddRange(parameters); + + var start = loweredSymbol.ParameterCount - frameCount; + for (int i = start; i < loweredSymbol.ParameterCount; i++) { - // will always be a LambdaFrame, it's always a closure class - var frameType = (NamedTypeSymbol)lambda.Parameters[i].Type.OriginalDefinition; + // will always be a LambdaFrame, it's always a capture frame + var frameType = (NamedTypeSymbol)loweredSymbol.Parameters[i].Type.OriginalDefinition; Debug.Assert(frameType is LambdaFrame); @@ -213,21 +233,76 @@ public BoundStatement RewriteLocalFunctionReferences(BoundStatement loweredBody) var subst = this.TypeMap.SubstituteTypeParameters(typeParameters); frameType = frameType.Construct(subst); } + var frame = FrameOfType(syntax, frameType); - builder.Add(frame); + parametersBuilder.Add(frame); } - parameters = builder.ToImmutableAndFree(); + parameters = parametersBuilder.ToImmutableAndFree(); } - method = lambda; + method = loweredSymbol; NamedTypeSymbol constructedFrame; + RemapLambdaOrLocalFunction(syntax, - symbol, - typeArguments, + localFunc, + SubstituteTypeArguments(localFunc.TypeArguments), mappedLocalFunction.ClosureKind, ref method, out receiver, out constructedFrame); } + + /// + /// Substitutes references from old type arguments to new type arguments + /// in the lowered methods. + /// + /// + /// Consider the following method: + /// void M() { + /// void L<T>(T t) => Console.Write(t); + /// L("A"); + /// } + /// + /// In this example, L<T> is a local function that will be + /// lowered into its own method and the type parameter T will be + /// alpha renamed to something else (let's call it T'). In this case, + /// all references to the original type parameter T in L must be + /// rewritten to the renamed parameter, T'. + /// + private ImmutableArray SubstituteTypeArguments(ImmutableArray typeArguments) + { + Debug.Assert(!typeArguments.IsDefault); + + if (typeArguments.IsEmpty) + { + return typeArguments; + } + + // We must perform this process repeatedly as local + // functions may nest inside one another and capture type + // parameters from the enclosing local functions. Each + // iteration of nesting will cause alpha-renaming of the captured + // parameters, meaning that we must replace until there are no + // more alpha-rename mappings. + + var builder = ArrayBuilder.GetInstance(); + foreach (var typeArg in typeArguments) + { + TypeSymbol oldTypeArg; + TypeSymbol newTypeArg = typeArg; + do + { + oldTypeArg = newTypeArg; + newTypeArg = this.TypeMap.SubstituteType(typeArg).Type; + } + while (oldTypeArg != newTypeArg); + + Debug.Assert((object)oldTypeArg == newTypeArg); + + builder.Add(newTypeArg); + } + + return builder.ToImmutableAndFree(); + } } } diff --git a/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenLocalFunctionTests.cs b/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenLocalFunctionTests.cs index 9f56787b47e..2dac388262b 100644 --- a/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenLocalFunctionTests.cs +++ b/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenLocalFunctionTests.cs @@ -3173,6 +3173,148 @@ public class C { VerifyOutput(src, "7"); } + [Fact] + [WorkItem(15751, "https://github.com/dotnet/roslyn/issues/15751")] + public void RecursiveGenericLocalFunction() + { + var src = @" +void Local(T t, int count) +{ + if (count > 0) + { + Console.Write(t); + Local(t, count - 1); + } +} + +Local(""A"", 5); +"; + VerifyOutputInMain(src, "AAAAA", "System"); + } + + [Fact] + [WorkItem(15751, "https://github.com/dotnet/roslyn/issues/15751")] + public void RecursiveGenericLocalFunction2() + { + var src = @" +void Local(T t, int count) +{ + if (count > 0) + { + Console.Write(t); + var action = new Action(Local); + action(t, count - 1); + } +} + +Local(""A"", 5); +"; + VerifyOutputInMain(src, "AAAAA", "System"); + } + + [Fact] + [WorkItem(15751, "https://github.com/dotnet/roslyn/issues/15751")] + public void RecursiveGenericLocalFunction3() + { + var src = @" +void Local(T t, int count) +{ + if (count > 0) + { + Console.Write(t); + var action = (Action)Local; + action(t, count - 1); + } +} + +Local(""A"", 5); +"; + VerifyOutputInMain(src, "AAAAA", "System"); + } + + [Fact] + [WorkItem(15751, "https://github.com/dotnet/roslyn/issues/15751")] + public void RecursiveGenericLocalFunction4() + { + var src = @" +using System; +class C +{ + public static void M(T t) + { + void Local(U u, int count) + { + if (count > 0) + { + Console.Write(t); + Console.Write(u); + Local(u, count - 1); + } + } + Local(""A"", 5); + } + + public static void Main() + { + C.M(""B""); + } +}"; + VerifyOutput(src, "BABABABABA"); + } + + [Fact] + [WorkItem(15751, "https://github.com/dotnet/roslyn/issues/15751")] + public void RecursiveGenericLocalFunction5() + { + var src = @" +using System; +class C +{ + T1 t1; + + public C(T1 t1) + { + this.t1 = t1; + } + + public void M(T2 t2) + { + void L1(T3 t3) + { + void L2(T4 t4) + { + void L3(U u, int count) + { + if (count > 0) + { + Console.Write(t1); + Console.Write(t2); + Console.Write(t3); + Console.Write(t4); + Console.Write(u); + L3(u, count - 1); + } + } + L3(""A"", 5); + } + L2(""B""); + } + L1(""C""); + } + +} + +class Program +{ + public static void Main() + { + var c = new C(""D""); + c.M(""E""); + } +}"; + VerifyOutput(src, "DECBADECBADECBADECBADECBA"); + } + internal CompilationVerifier VerifyOutput(string source, string output, CSharpCompilationOptions options) { var comp = CreateCompilationWithMscorlib45AndCSruntime(source, options: options); -- GitLab