From 65d8132113f900b209dfb319e2cf57f0607b771d Mon Sep 17 00:00:00 2001 From: Andy Gocke Date: Fri, 28 Oct 2016 12:36:01 -0700 Subject: [PATCH] Implement proper 'this' capturing for local functions (#14736) Currently, when a local function is captured inside another local function or lambda it can capture 'this' without generating a frame. This is useful, but when that lambda is itself captured then the capturing closure must also capture the frame pointer, namely 'this'. Currently, local function frame pointers are not correctly captured when the captured local function itself captures something from a "higher" scope than the capturing local function. This change solves this problem by: 1) Considering a local function's captured variables when deciding its scope. If the local function captures variables from a higher scope, that local function will be analyzed as belonging to the "higher" scope, causing that local function to register for frame capturing. 2) Since the proxies for capturing frames are not available at the time of local function reference rewriting, the proxies must be saved. There is a new temporary bound node for this purpose, PartiallyLoweredLocalFunctionReference, that stores the proxies and the underlying node for later use during the rewriting phase. This node should never make it past LocalFunctionReferenceRewriting. When these steps are completed, local functions should act very similarly to all other captured variables with different frames, where the frame pointers are captured and walked in a linked list in order to access the target with the proper receiver/frame pointer. --- .../CSharp/Portable/CSharpCodeAnalysis.csproj | 1 + .../Portable/Compiler/TypeCompilationState.cs | 5 + .../LambdaRewriter/LambdaRewriter.Analysis.cs | 60 ++-- ...Rewriter.LocalFunctionReferenceRewriter.cs | 90 +++++ .../Lowering/LambdaRewriter/LambdaRewriter.cs | 277 +++++++++------ .../PartiallyLoweredLocalFunctionReference.cs | 51 +++ .../Lowering/MethodToClassRewriter.cs | 5 +- .../Emit/CodeGen/CodeGenLocalFunctionTests.cs | 317 ++++++++++++++++++ .../EditAndContinueClosureTests.cs | 8 +- .../EditAndContinue/SymbolMatcherTests.cs | 7 +- .../Core/Portable/CodeAnalysis.csproj | 1 + .../OrderedMultiDictionary.cs | 60 ++++ .../SetWithInsertionOrder.cs | 2 +- 13 files changed, 742 insertions(+), 142 deletions(-) create mode 100644 src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/PartiallyLoweredLocalFunctionReference.cs create mode 100644 src/Compilers/Core/Portable/InternalUtilities/OrderedMultiDictionary.cs diff --git a/src/Compilers/CSharp/Portable/CSharpCodeAnalysis.csproj b/src/Compilers/CSharp/Portable/CSharpCodeAnalysis.csproj index de8183202d0..e5f29410fb3 100644 --- a/src/Compilers/CSharp/Portable/CSharpCodeAnalysis.csproj +++ b/src/Compilers/CSharp/Portable/CSharpCodeAnalysis.csproj @@ -402,6 +402,7 @@ + diff --git a/src/Compilers/CSharp/Portable/Compiler/TypeCompilationState.cs b/src/Compilers/CSharp/Portable/Compiler/TypeCompilationState.cs index e197e732855..97be4e0c93f 100644 --- a/src/Compilers/CSharp/Portable/Compiler/TypeCompilationState.cs +++ b/src/Compilers/CSharp/Portable/Compiler/TypeCompilationState.cs @@ -118,6 +118,11 @@ public bool Emitting public ArrayBuilder SynthesizedMethods { get { return _synthesizedMethods; } + set + { + Debug.Assert(_synthesizedMethods == null); + _synthesizedMethods = value; + } } /// diff --git a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs index 5b74eea49f7..7cb3de477d6 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs @@ -7,6 +7,7 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; using Roslyn.Utilities; +using Microsoft.CodeAnalysis; namespace Microsoft.CodeAnalysis.CSharp { @@ -63,7 +64,7 @@ internal sealed class Analysis : BoundTreeWalkerWithStackGuardWithoutRecursionOn /// /// For each lambda in the code, the set of variables that it captures. /// - public MultiDictionary CapturedVariablesByLambda = new MultiDictionary(); + public OrderedMultiDictionary CapturedVariablesByLambda = new OrderedMultiDictionary(); /// /// If a local function is in the set, at some point in the code it is converted to a delegate and should then not be optimized to a struct closure. @@ -219,21 +220,7 @@ private void RemoveUnneededReferences() } } - var capturedVariablesNew = new MultiDictionary(); - foreach (var old in CapturedVariables) - { - var method = old.Key as MethodSymbol; - // don't add if it's a method that only captures 'this' - if (method == null || capturesVariable.Contains(method)) - { - foreach (var oldValue in old.Value) - { - capturedVariablesNew.Add(old.Key, oldValue); - } - } - } - CapturedVariables = capturedVariablesNew; - var capturedVariablesByLambdaNew = new MultiDictionary(); + var capturedVariablesByLambdaNew = new OrderedMultiDictionary(); foreach (var old in CapturedVariablesByLambda) { if (capturesVariable.Contains(old.Key)) @@ -263,22 +250,43 @@ internal void ComputeLambdaScopesAndFrameCaptures() foreach (var kvp in CapturedVariablesByLambda) { - // get innermost and outermost scopes from which a lambda captures + var lambda = kvp.Key; + var capturedVars = kvp.Value; + + var allCapturedVars = ArrayBuilder.GetInstance(capturedVars.Count); + allCapturedVars.AddRange(capturedVars); + + // If any of the captured variables are local functions we'll need + // to add the captured variables of that local function to the current + // set. This has the effect of ensuring that if the local function + // captures anything "above" the current scope then parent frame + // is itself captured (so that the current lambda can call that + // local function). + foreach (var captured in capturedVars) + { + var capturedLocalFunction = captured as LocalFunctionSymbol; + if (capturedLocalFunction != null) + { + allCapturedVars.AddRange( + CapturedVariablesByLambda[capturedLocalFunction]); + } + } + // get innermost and outermost scopes from which a lambda captures int innermostScopeDepth = -1; BoundNode innermostScope = null; int outermostScopeDepth = int.MaxValue; BoundNode outermostScope = null; - foreach (var variables in kvp.Value) + foreach (var captured in allCapturedVars) { BoundNode curBlock = null; int curBlockDepth; - if (!VariableScope.TryGetValue(variables, out curBlock)) + if (!VariableScope.TryGetValue(captured, out curBlock)) { - // this is something that is not defined in a block, like "Me" + // this is something that is not defined in a block, like "this" // Since it is defined outside of the method, the depth is -1 curBlockDepth = -1; } @@ -300,22 +308,24 @@ internal void ComputeLambdaScopesAndFrameCaptures() } } + allCapturedVars.Free(); + // 1) if there is innermost scope, lambda goes there as we cannot go any higher. // 2) scopes in [innermostScope, outermostScope) chain need to have access to the parent scope. // // Example: - // if a lambda captures a method//s parameter and Me, + // if a lambda captures a method's parameter and `this`, // its innermost scope depth is 0 (method locals and parameters) // and outermost scope is -1 - // Such lambda will be placed in a closure frame that corresponds to the method//s outer block - // and this frame will also lift original Me as a field when created by its parent. + // Such lambda will be placed in a closure frame that corresponds to the method's outer block + // and this frame will also lift original `this` as a field when created by its parent. // Note that it is completely irrelevant how deeply the lexical scope of the lambda was originally nested. if (innermostScope != null) { - LambdaScopes.Add(kvp.Key, innermostScope); + LambdaScopes.Add(lambda, innermostScope); // Disable struct closures on methods converted to delegates, as well as on async and iterator methods. - var markAsNoStruct = MethodsConvertedToDelegates.Contains(kvp.Key) || kvp.Key.IsAsync || kvp.Key.IsIterator; + var markAsNoStruct = MethodsConvertedToDelegates.Contains(lambda) || lambda.IsAsync || lambda.IsIterator; if (markAsNoStruct) { ScopesThatCantBeStructs.Add(innermostScope); diff --git a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.LocalFunctionReferenceRewriter.cs b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.LocalFunctionReferenceRewriter.cs index aece1bd0035..afece7deab1 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.LocalFunctionReferenceRewriter.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.LocalFunctionReferenceRewriter.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Microsoft.CodeAnalysis.CSharp.Symbols; +using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; @@ -27,6 +28,27 @@ public LocalFunctionReferenceRewriter(LambdaRewriter lambdaRewriter) _lambdaRewriter = lambdaRewriter; } + public override BoundNode Visit(BoundNode node) + { + var partiallyLowered = node as PartiallyLoweredLocalFunctionReference; + if (partiallyLowered != null) + { + var underlying = partiallyLowered.UnderlyingNode; + Debug.Assert(underlying.Kind == BoundKind.Call || + underlying.Kind == BoundKind.DelegateCreationExpression || + underlying.Kind == BoundKind.Conversion); + var oldProxies = _lambdaRewriter.proxies; + _lambdaRewriter.proxies = partiallyLowered.Proxies; + + var result = base.Visit(underlying); + + _lambdaRewriter.proxies = oldProxies; + + return result; + } + return base.Visit(node); + } + public override BoundNode VisitCall(BoundCall node) { if (node.Method.MethodKind == MethodKind.LocalFunction) @@ -76,6 +98,74 @@ public override BoundNode VisitConversion(BoundConversion conversion) } } + /// + /// Visit all references to local functions (calls, delegate + /// conversions, delegate creations) and rewrite them to point + /// to the rewritten local function method instead of the original. + /// + public BoundStatement RewriteLocalFunctionReferences(BoundStatement loweredBody) + { + var rewriter = new LocalFunctionReferenceRewriter(this); + + Debug.Assert(_currentMethod == _topLevelMethod); + + // Visit the body first since the state is already set + // for the top-level method + var newBody = (BoundStatement)rewriter.Visit(loweredBody); + + // Visit all the rewritten methods as well + var synthesizedMethods = _synthesizedMethods; + if (synthesizedMethods != null) + { + var newMethods = ArrayBuilder.GetInstance( + synthesizedMethods.Count); + + foreach (var oldMethod in synthesizedMethods) + { + var synthesizedLambda = oldMethod.Method as SynthesizedLambdaMethod; + if (synthesizedLambda == null) + { + // The only methods synthesized by the rewriter should + // be lowered closures and frame constructors + Debug.Assert(oldMethod.Method.MethodKind == MethodKind.Constructor || + oldMethod.Method.MethodKind == MethodKind.StaticConstructor); + newMethods.Add(oldMethod); + continue; + } + + _currentMethod = synthesizedLambda; + var closureKind = synthesizedLambda.ClosureKind; + if (closureKind == ClosureKind.Static || closureKind == ClosureKind.Singleton) + { + // no link from a static lambda to its container + _innermostFramePointer = _currentFrameThis = null; + } + else + { + _currentFrameThis = synthesizedLambda.ThisParameter; + _innermostFramePointer = null; + _framePointers.TryGetValue(synthesizedLambda.ContainingType, out _innermostFramePointer); + } + + _currentTypeParameters = synthesizedLambda.ContainingType + ?.TypeParameters.Concat(synthesizedLambda.TypeParameters) + ?? synthesizedLambda.TypeParameters; + _currentLambdaBodyTypeMap = synthesizedLambda.TypeMap; + + var rewrittenBody = (BoundStatement)rewriter.Visit(oldMethod.Body); + + var newMethod = new TypeCompilationState.MethodWithBody( + synthesizedLambda, rewrittenBody, oldMethod.ImportChainOpt); + newMethods.Add(newMethod); + } + + _synthesizedMethods = newMethods; + synthesizedMethods.Free(); + } + + return newBody; + } + private void RemapLocalFunction( SyntaxNode syntax, diff --git a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs index 522716c6ff3..3840746da36 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs @@ -12,6 +12,8 @@ using Microsoft.CodeAnalysis.CSharp.Symbols; using Microsoft.CodeAnalysis.CSharp.Syntax; using Roslyn.Utilities; +using System.Linq; +using Microsoft.CodeAnalysis.Collections; namespace Microsoft.CodeAnalysis.CSharp { @@ -140,6 +142,12 @@ public MappedLocalFunction(SynthesizedLambdaMethod symbol, ClosureKind closureKi // top-level block initializing those variables to null. private ArrayBuilder _addedStatements; + /// + /// Temporary bag for methods synthesized by the rewriting. Added to + /// at the end of rewriting. + /// + private ArrayBuilder _synthesizedMethods; + private LambdaRewriter( Analysis analysis, NamedTypeSymbol thisType, @@ -257,6 +265,20 @@ protected override bool NeedsProxy(Symbol localOrParameter) body = rewriter.RewriteLocalFunctionReferences(body); } + // Add the completed methods to the compilation state + if (rewriter._synthesizedMethods != null) + { + if (compilationState.SynthesizedMethods == null) + { + compilationState.SynthesizedMethods = rewriter._synthesizedMethods; + } + else + { + compilationState.SynthesizedMethods.AddRange(rewriter._synthesizedMethods); + rewriter._synthesizedMethods.Free(); + } + } + CheckLocalsDefined(body); return body; @@ -305,38 +327,120 @@ protected override NamedTypeSymbol ContainingType /// private void MakeFrames(ArrayBuilder closureDebugInfo) { - NamedTypeSymbol containingType = this.ContainingType; + var closures = _analysis.CapturedVariablesByLambda.Keys; - foreach (var kvp in _analysis.CapturedVariables) + foreach (var closure in closures) { - var captured = kvp.Key; + var capturedVars = _analysis.CapturedVariablesByLambda[closure]; - BoundNode scope; - if (!_analysis.VariableScope.TryGetValue(captured, out scope)) + if (closure.MethodKind == MethodKind.LocalFunction && + OnlyCapturesThis((LocalFunctionSymbol)closure, capturedVars)) { continue; } - LambdaFrame frame = GetFrameForScope(scope, closureDebugInfo); - - if (captured.Kind != SymbolKind.Method) + foreach (var captured in capturedVars) { - var hoistedField = LambdaCapturedVariable.Create(frame, captured, ref _synthesizedFieldNameIdDispenser); - proxies.Add(captured, new CapturedToFrameSymbolReplacement(hoistedField, isReusable: false)); - CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(frame, hoistedField); + BoundNode scope; + if (!_analysis.VariableScope.TryGetValue(captured, out scope)) + { + continue; + } + + LambdaFrame frame = GetFrameForScope(scope, closureDebugInfo); - if (hoistedField.Type.IsRestrictedType()) + if (captured.Kind != SymbolKind.Method && !proxies.ContainsKey(captured)) { - foreach (CSharpSyntaxNode syntax in kvp.Value) + var hoistedField = LambdaCapturedVariable.Create(frame, captured, ref _synthesizedFieldNameIdDispenser); + proxies.Add(captured, new CapturedToFrameSymbolReplacement(hoistedField, isReusable: false)); + CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(frame, hoistedField); + + if (hoistedField.Type.IsRestrictedType()) { - // CS4013: Instance of type '{0}' cannot be used inside an anonymous function, query expression, iterator block or async method - this.Diagnostics.Add(ErrorCode.ERR_SpecialByRefInLambda, syntax.Location, hoistedField.Type); + foreach (CSharpSyntaxNode syntax in _analysis.CapturedVariables[captured]) + { + // CS4013: Instance of type '{0}' cannot be used inside an anonymous function, query expression, iterator block or async method + this.Diagnostics.Add(ErrorCode.ERR_SpecialByRefInLambda, syntax.Location, hoistedField.Type); + } } } } } } + + private SmallDictionary _onlyCapturesThisMemoTable; + /// + /// Helper for determining whether a local function transitively + /// only captures this (only captures this or other local functions + /// which only capture this). + /// + private bool OnlyCapturesThis( + LocalFunctionSymbol closure, + T capturedVars, + PooledHashSet localFuncsInProgress = null) + where T : IEnumerable + { + bool result = false; + if (_onlyCapturesThisMemoTable?.TryGetValue(closure, out result) == true) + { + return result; + } + + result = true; + foreach (var captured in capturedVars) + { + var param = captured as ParameterSymbol; + if (param != null && param.IsThis) + { + continue; + } + + var localFunc = captured as LocalFunctionSymbol; + if (localFunc != null) + { + bool freePool = false; + if (localFuncsInProgress == null) + { + localFuncsInProgress = PooledHashSet.GetInstance(); + freePool = true; + } + else if (localFuncsInProgress.Contains(localFunc)) + { + continue; + } + + localFuncsInProgress.Add(localFunc); + bool transitivelyTrue = OnlyCapturesThis( + localFunc, + _analysis.CapturedVariablesByLambda[localFunc], + localFuncsInProgress); + + if (freePool) + { + localFuncsInProgress.Free(); + localFuncsInProgress = null; + } + + if (transitivelyTrue) + { + continue; + } + } + + result = false; + break; + } + + if (_onlyCapturesThisMemoTable == null) + { + _onlyCapturesThisMemoTable = new SmallDictionary(); + } + + _onlyCapturesThisMemoTable[closure] = result; + return result; + } + private LambdaFrame GetFrameForScope(BoundNode scope, ArrayBuilder closureDebugInfo) { LambdaFrame frame; @@ -361,12 +465,12 @@ private LambdaFrame GetFrameForScope(BoundNode scope, ArrayBuilder FramePointer(syntax, frameType)); } - var localFrame = framePointer as LocalSymbol; + var localFrame = (LocalSymbol)framePointer; return new BoundLocal(syntax, localFrame, null, localFrame.Type); } @@ -517,7 +621,7 @@ private static void InsertAndFreePrologue(ArrayBuilder result, A /// The frame for the translated node /// A function that computes the translation of the node. It receives lists of added statements and added symbols /// The translated statement, as returned from F - private T IntroduceFrame(BoundNode node, LambdaFrame frame, Func, ArrayBuilder, T> F) + private BoundNode IntroduceFrame(BoundNode node, LambdaFrame frame, Func, ArrayBuilder, BoundNode> F) { var frameTypeParameters = ImmutableArray.Create(StaticCast.From(_currentTypeParameters).SelectAsArray(TypeMap.TypeSymbolAsTypeWithModifiers), 0, frame.Arity); NamedTypeSymbol frameType = frame.ConstructIfGeneric(frameTypeParameters); @@ -539,8 +643,8 @@ private T IntroduceFrame(BoundNode node, LambdaFrame frame, Func - /// Visit all references to local functions (calls, delegete - /// conversions, delegate creations) and rewrite them to point - /// to the rewritten local function method instead of the original. - /// - public BoundStatement RewriteLocalFunctionReferences(BoundStatement loweredBody) - { - var rewriter = new LocalFunctionReferenceRewriter(this); - - Debug.Assert(_currentMethod == _topLevelMethod); - - // Visit the body first since the state is already set - // for the top-level method - var newBody = (BoundStatement)rewriter.Visit(loweredBody); - - // Visit all the rewritten methods as well - var synthesizedMethods = CompilationState.SynthesizedMethods; - if (synthesizedMethods != null) - { - // Dump the existing methods for rewriting - var oldMethods = synthesizedMethods.ToImmutable(); - synthesizedMethods.Clear(); - - foreach (var oldMethod in oldMethods) - { - var synthesizedLambda = oldMethod.Method as SynthesizedLambdaMethod; - if (synthesizedLambda == null) - { - // The only methods synthesized by the rewriter should - // be lowered closures and frame constructors - Debug.Assert(oldMethod.Method.MethodKind == MethodKind.Constructor || - oldMethod.Method.MethodKind == MethodKind.StaticConstructor); - CompilationState.AddSynthesizedMethod(oldMethod.Method, oldMethod.Body); - continue; - } - - _currentMethod = synthesizedLambda; - var closureKind = synthesizedLambda.ClosureKind; - if (closureKind == ClosureKind.Static || closureKind == ClosureKind.Singleton) - { - // no link from a static lambda to its container - _innermostFramePointer = _currentFrameThis = null; - } - else - { - _currentFrameThis = synthesizedLambda.ThisParameter; - _innermostFramePointer = null; - _framePointers.TryGetValue(synthesizedLambda.ContainingType, out _innermostFramePointer); - } - - _currentTypeParameters = synthesizedLambda.ContainingType - ?.TypeParameters.Concat(synthesizedLambda.TypeParameters) - ?? synthesizedLambda.TypeParameters; - _currentLambdaBodyTypeMap = synthesizedLambda.TypeMap; - - var rewrittenBody = (BoundStatement)rewriter.Visit(oldMethod.Body); - - CompilationState.AddSynthesizedMethod(synthesizedLambda, rewrittenBody); - } - } - - return newBody; - } - private void RemapLambdaOrLocalFunction( SyntaxNode syntax, MethodSymbol originalMethod, @@ -845,7 +885,7 @@ public override BoundNode VisitCall(BoundCall node) { var rewrittenArguments = this.VisitList(node.Arguments); - return node.Update( + var withArguments = node.Update( node.ReceiverOpt, node.Method, rewrittenArguments, @@ -857,6 +897,8 @@ public override BoundNode VisitCall(BoundCall node) node.ArgsToParamsOpt, node.ResultKind, node.Type); + + return PartiallyLowerLocalFunctionReference(withArguments); } var visited = base.VisitCall(node); @@ -887,6 +929,17 @@ public override BoundNode VisitCall(BoundCall node) return rewritten; } + private PartiallyLoweredLocalFunctionReference PartiallyLowerLocalFunctionReference( + BoundExpression underlyingNode) + { + Debug.Assert(underlyingNode.Kind == BoundKind.Call || + underlyingNode.Kind == BoundKind.DelegateCreationExpression || + underlyingNode.Kind == BoundKind.Conversion); + return new PartiallyLoweredLocalFunctionReference( + underlyingNode, + new Dictionary(proxies)); + } + private BoundSequence RewriteSequence(BoundSequence node, ArrayBuilder prologue, ArrayBuilder newLocals) { RewriteLocals(node.Locals, newLocals); @@ -1084,14 +1137,12 @@ public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationE { return RewriteLambdaConversion((BoundLambda)node.Argument); } - else + + if (node.MethodOpt?.MethodKind == MethodKind.LocalFunction) { - if (node.MethodOpt?.MethodKind == MethodKind.LocalFunction) - { - return node; - } - return base.VisitDelegateCreationExpression(node); + return PartiallyLowerLocalFunctionReference(node); } + return base.VisitDelegateCreationExpression(node); } public override BoundNode VisitConversion(BoundConversion conversion) @@ -1115,15 +1166,13 @@ public override BoundNode VisitConversion(BoundConversion conversion) return result; } - else + + if (conversion.ConversionKind == ConversionKind.MethodGroup && + conversion.SymbolOpt?.MethodKind == MethodKind.LocalFunction) { - if (conversion.ConversionKind == ConversionKind.MethodGroup && - conversion.SymbolOpt?.MethodKind == MethodKind.LocalFunction) - { - return conversion; - } - return base.VisitConversion(conversion); + return PartiallyLowerLocalFunctionReference(conversion); } + return base.VisitConversion(conversion); } public override BoundNode VisitLocalFunctionStatement(BoundLocalFunctionStatement node) @@ -1353,7 +1402,7 @@ private DebugId GetLambdaId(SyntaxNode syntax, ClosureKind closureKind, int clos var body = AddStatementsIfNeeded((BoundStatement)VisitBlock(node.Body)); CheckLocalsDefined(body); - CompilationState.AddSynthesizedMethod(synthesizedMethod, body); + AddSynthesizedMethod(synthesizedMethod, body); // return to the old method @@ -1368,8 +1417,22 @@ private DebugId GetLambdaId(SyntaxNode syntax, ClosureKind closureKind, int clos return synthesizedMethod; } - private BoundNode RewriteLambdaConversion(BoundLambda node) + private void AddSynthesizedMethod(MethodSymbol method, BoundStatement body) + { + if (_synthesizedMethods == null) { + _synthesizedMethods = ArrayBuilder.GetInstance(); + } + + _synthesizedMethods.Add( + new TypeCompilationState.MethodWithBody( + method, + body, + CompilationState.CurrentImportChain)); + } + + private BoundNode RewriteLambdaConversion(BoundLambda node) + { var wasInExpressionLambda = _inExpressionLambda; _inExpressionLambda = _inExpressionLambda || node.Type.IsExpressionTree(); diff --git a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/PartiallyLoweredLocalFunctionReference.cs b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/PartiallyLoweredLocalFunctionReference.cs new file mode 100644 index 00000000000..d7f39860d33 --- /dev/null +++ b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/PartiallyLoweredLocalFunctionReference.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using Microsoft.CodeAnalysis.Semantics; + +namespace Microsoft.CodeAnalysis.CSharp +{ + /// + /// This represents a partially lowered local function reference (e.g., + /// a local function call or delegate conversion) with relevant proxies + /// attached. It will later be rewritten by the + /// into a + /// proper call. + /// + internal class PartiallyLoweredLocalFunctionReference : BoundExpression + { + private const BoundKind s_privateKind = (BoundKind)byte.MaxValue; + + public BoundExpression UnderlyingNode { get; } + public Dictionary Proxies { get; } + + public PartiallyLoweredLocalFunctionReference( + BoundExpression underlying, + Dictionary proxies) + : base(s_privateKind, underlying.Syntax, underlying.Type) + { + UnderlyingNode = underlying; + Proxies = proxies; + } + + public override BoundNode Accept(BoundTreeVisitor visitor) => + visitor.Visit(this); + + protected override OperationKind ExpressionKind + { + get + { + throw new InvalidOperationException(); + } + } + + public override void Accept(OperationVisitor visitor) + { + throw new InvalidOperationException(); + } + + public override TResult Accept(OperationVisitor visitor, TArgument argument) + { + throw new InvalidOperationException(); + } + } +} diff --git a/src/Compilers/CSharp/Portable/Lowering/MethodToClassRewriter.cs b/src/Compilers/CSharp/Portable/Lowering/MethodToClassRewriter.cs index d717ddce9ab..606d19e2a1f 100644 --- a/src/Compilers/CSharp/Portable/Lowering/MethodToClassRewriter.cs +++ b/src/Compilers/CSharp/Portable/Lowering/MethodToClassRewriter.cs @@ -17,7 +17,10 @@ internal abstract partial class MethodToClassRewriter : BoundTreeRewriterWithSta // For each captured variable, information about its replacement. May be populated lazily (that is, not all // upfront) by subclasses. Specifically, the async rewriter produces captured symbols for temps, including // ref locals, lazily. - protected readonly Dictionary proxies = new Dictionary(); + // The lambda rewriter also saves/restores the proxies across passes, since local function + // reference rewriting is done in a separate pass but still requires the frame proxies + // created in the first pass. + protected Dictionary proxies = new Dictionary(); // A mapping from every local variable to its replacement local variable. Local variables are replaced when // their types change due to being inside of a generic method. Otherwise we reuse the original local (even diff --git a/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenLocalFunctionTests.cs b/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenLocalFunctionTests.cs index 9d5d4f21348..009c6376b4a 100644 --- a/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenLocalFunctionTests.cs +++ b/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenLocalFunctionTests.cs @@ -29,6 +29,323 @@ public static IMethodSymbol FindLocalFunction(this CommonTestBase.CompilationVer [CompilerTrait(CompilerFeature.LocalFunctions)] public class CodeGenLocalFunctionTests : CSharpTestBase { + [Fact] + public void DeepNestedLocalFuncsWithDifferentCaptures() + { + var src = @" +using System; +class C +{ + int P = 100000; + void M() + { + C Local1() => this; + int capture1 = 1; + Func f1 = () => capture1 + Local1().P; + Console.WriteLine(f1()); + { + C Local2() => Local1(); + int capture2 = 10; + Func f2 = () => capture2 + Local2().P; + Console.WriteLine(f2()); + { + C Local3() => Local2(); + + int capture3 = 100; + Func f3 = () => capture1 + capture2 + capture3 + Local3().P; + Console.WriteLine(f3()); + + Console.WriteLine(Local3().P); + } + } + } + public static void Main() => new C().M(); +}"; + VerifyOutput(src, @"100001 +100010 +100111 +100000"); + } + + [Fact] + public void LotsOfMutuallyRecursiveLocalFunctions() + { + var src = @" +class C +{ + int P = 0; + public void M() + { + int Local1() => this.P; + int Local2() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1(); + int Local3() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1(); + int Local4() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1(); + int Local5() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1(); + int Local6() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1(); + int Local7() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1(); + int Local8() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1(); + int Local9() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1(); + int Local10() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1(); + int Local11() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1(); + int Local12() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1(); + + Local1(); + Local2(); + Local3(); + Local4(); + Local5(); + Local6(); + Local7(); + Local8(); + Local9(); + Local10(); + Local11(); + Local12(); + } +} +"; + var comp = CreateCompilationWithMscorlib(src); + comp.VerifyEmitDiagnostics(); + } + + [Fact] + public void LocalFuncAndLambdaWithDifferentThis() + { + var src = @" +using System; +class C +{ + private int P = 1; + public void M() + { + int Local(int x) => x + this.P; + + int y = 10; + var a = new Func(() => Local(y)); + Console.WriteLine(a()); + } + + public static void Main(string[] args) + { + var c = new C(); + c.M(); + } +}"; + VerifyOutput(src, "11"); + } + + [Fact] + public void LocalFuncAndLambdaWithDifferentThis2() + { + var src = @" +using System; +class C +{ + private int P = 1; + public void M() + { + int Local() => 10 + this.P; + int Local2(int x) => x + Local(); + + int y = 100; + var a = new Func(() => Local2(y)); + Console.WriteLine(a()); + } + + public static void Main(string[] args) + { + var c = new C(); + c.M(); + } +}"; + VerifyOutput(src, "111"); + } + + [Fact] + public void LocalFuncAndLambdaWithDifferentThis3() + { + var src = @" +using System; +class C +{ + private int P = 1; + public void M() + { + int Local() + { + if (this.P < 5) + { + return Local2(this.P++); + } + else + { + return 1; + } + } + int Local2(int x) => x + Local(); + + int y = 100; + var a = new Func(() => Local2(y)); + Console.WriteLine(a()); + } + + public static void Main(string[] args) + { + var c = new C(); + c.M(); + } +}"; + VerifyOutput(src, "111"); + + } + + [Fact] + public void LocalFuncAndLambdaWithDifferentThis4() + { + var src = @" +using System; +class C +{ + private int P = 1; + public void M() + { + int Local(int x) => x + this.P; + + int y = 10; + var a = new Func(() => + { + var b = (Func)Local; + return b(y); + }); + Console.WriteLine(a()); + } + + public static void Main(string[] args) + { + var c = new C(); + c.M(); + } +}"; + VerifyOutput(src, "11"); + } + + [Fact] + public void LocalFuncAndLambdaWithDifferentThis5() + { + var src = @" +using System; +class C +{ + private int P = 1; + public void M() + { + int Local(int x) => x + this.P; + + int y = 10; + var a = new Func(() => + { + var b = new Func(Local); + return b(y); + }); + Console.WriteLine(a()); + } + + public static void Main(string[] args) + { + var c = new C(); + c.M(); + } +}"; + VerifyOutput(src, "11"); + } + + [Fact] + public void TwoFrames() + { + var src = @" +using System; +class C +{ + private int P = 0; + public void M() + { + int x = 0; + + var a = new Func(() => + { + int Local() => x + this.P; + int z = 0; + int Local3() => z + Local(); + return Local3(); + }); + Console.WriteLine(a()); + } + + public static void Main(string[] args) + { + var c = new C(); + c.M(); + } +}"; + VerifyOutput(src, "0"); + } + + [Fact] + public void SameFrame() + { + var src = @" +using System; +class C +{ + private int P = 1; + public void M() + { + int x = 10; + int Local() => x + this.P; + + int y = 100; + int Local2() => y + Local(); + Console.WriteLine(Local2()); + } + + public static void Main(string[] args) + { + var c = new C(); + c.M(); + } +}"; + VerifyOutput(src, "111"); + } + + [Fact] + public void MutuallyRecursiveThisCapture() + { + var src = @" +using System; +class C +{ + private int P = 1; + public void M() + { + int Local() + { + if (this.P < 5) + { + return Local2(this.P++); + } + else + { + return 1; + } + } + int Local2(int x) => x + Local(); + Console.WriteLine(Local()); + } + public static void Main() => new C().M(); +}"; + VerifyOutput(src, "11"); + } + [Fact] [CompilerTrait(CompilerFeature.Dynamic)] public void DynamicParameterLocalFunction() diff --git a/src/Compilers/CSharp/Test/Emit/Emit/EditAndContinue/EditAndContinueClosureTests.cs b/src/Compilers/CSharp/Test/Emit/Emit/EditAndContinue/EditAndContinueClosureTests.cs index 66f17370626..fe06bc1461e 100644 --- a/src/Compilers/CSharp/Test/Emit/Emit/EditAndContinue/EditAndContinueClosureTests.cs +++ b/src/Compilers/CSharp/Test/Emit/Emit/EditAndContinue/EditAndContinueClosureTests.cs @@ -714,10 +714,10 @@ public void F() // no new synthesized members generated (with #1 in names): diff1.VerifySynthesizedMembers( - "C.<>c__DisplayClass1_0: {a, b__1}", - "C.<>c__DisplayClass1_1: {b, b__3}", - "C.<>c__DisplayClass1_2: {a, b, b__5}", - "C: {b__1_0, b__1_2, b__1_4, <>c__DisplayClass1_0, <>c__DisplayClass1_1, <>c__DisplayClass1_2}"); + "C.<>c__DisplayClass1_2: {a, b, b__5}", + "C.<>c__DisplayClass1_1: {b, b__3}", + "C: {b__1_0, b__1_2, b__1_4, <>c__DisplayClass1_0, <>c__DisplayClass1_1, <>c__DisplayClass1_2}", + "C.<>c__DisplayClass1_0: {a, b__1}"); var md1 = diff1.GetMetadata(); var reader1 = md1.Reader; diff --git a/src/Compilers/CSharp/Test/Emit/Emit/EditAndContinue/SymbolMatcherTests.cs b/src/Compilers/CSharp/Test/Emit/Emit/EditAndContinue/SymbolMatcherTests.cs index d9c3234b956..373588419ad 100644 --- a/src/Compilers/CSharp/Test/Emit/Emit/EditAndContinue/SymbolMatcherTests.cs +++ b/src/Compilers/CSharp/Test/Emit/Emit/EditAndContinue/SymbolMatcherTests.cs @@ -507,10 +507,9 @@ static void F() var emitContext = new EmitContext(peAssemblyBuilder, null, new DiagnosticBag()); var fields = displayClass.GetFields(emitContext).ToArray(); - var x1 = fields[0]; - var x2 = fields[1]; - Assert.Equal("x1", x1.Name); - Assert.Equal("x2", x2.Name); + AssertEx.SetEqual(fields.Select(f => f.Name), new[] { "x1", "x2" }); + var x1 = fields.Where(f => f.Name == "x1").Single(); + var x2 = fields.Where(f => f.Name == "x2").Single(); var matcher = new CSharpSymbolMatcher(anonymousTypeMap0, compilation1.SourceAssembly, emitContext, peAssemblySymbol0); diff --git a/src/Compilers/Core/Portable/CodeAnalysis.csproj b/src/Compilers/Core/Portable/CodeAnalysis.csproj index a863dc80436..d566780405f 100644 --- a/src/Compilers/Core/Portable/CodeAnalysis.csproj +++ b/src/Compilers/Core/Portable/CodeAnalysis.csproj @@ -46,6 +46,7 @@ PreserveNewest false + diff --git a/src/Compilers/Core/Portable/InternalUtilities/OrderedMultiDictionary.cs b/src/Compilers/Core/Portable/InternalUtilities/OrderedMultiDictionary.cs new file mode 100644 index 00000000000..44be7b79917 --- /dev/null +++ b/src/Compilers/Core/Portable/InternalUtilities/OrderedMultiDictionary.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; + +namespace Roslyn.Utilities +{ + // Note that this is not threadsafe for concurrent reading and writing. + internal sealed class OrderedMultiDictionary : IEnumerable>> + { + private readonly Dictionary> _dictionary; + private readonly List _keys; + + public int Count => _dictionary.Count; + + public IEnumerable Keys => _keys; + + // Returns an empty set if there is no such key in the dictionary. + public SetWithInsertionOrder this[K k] + { + get + { + SetWithInsertionOrder set; + return _dictionary.TryGetValue(k, out set) + ? set : new SetWithInsertionOrder(); + } + } + + public OrderedMultiDictionary() + { + _dictionary = new Dictionary>(); + _keys = new List(); + } + + public void Add(K k, V v) + { + SetWithInsertionOrder set; + if (!_dictionary.TryGetValue(k, out set)) + { + _keys.Add(k); + set = new SetWithInsertionOrder(); + } + set.Add(v); + _dictionary[k] = set; + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public IEnumerator>> GetEnumerator() + { + foreach (var key in _keys) + { + yield return new KeyValuePair>( + key, _dictionary[key]); + } + } + } +} diff --git a/src/Compilers/Core/Portable/InternalUtilities/SetWithInsertionOrder.cs b/src/Compilers/Core/Portable/InternalUtilities/SetWithInsertionOrder.cs index d00af3ec348..9fc12c3d1f6 100644 --- a/src/Compilers/Core/Portable/InternalUtilities/SetWithInsertionOrder.cs +++ b/src/Compilers/Core/Portable/InternalUtilities/SetWithInsertionOrder.cs @@ -14,7 +14,7 @@ namespace Roslyn.Utilities /// A set that returns the inserted values in insertion order. /// The mutation operations are not thread-safe. /// - internal sealed class SetWithInsertionOrder : IEnumerable + internal sealed class SetWithInsertionOrder : IEnumerable, IReadOnlySet { private HashSet _set = new HashSet(); private uint _nextElementValue = 0; -- GitLab