From 037f4a6229d7b3a93d144b0b2e882a3b96ac20fc Mon Sep 17 00:00:00 2001 From: Evan Hauck Date: Wed, 1 Jul 2015 07:37:35 -0700 Subject: [PATCH] Optimize local function closures into structs --- .../Lowering/LambdaRewriter/LambdaFrame.cs | 8 +- .../LambdaRewriter/LambdaRewriter.Analysis.cs | 85 ++++++++++++++++++- .../Lowering/LambdaRewriter/LambdaRewriter.cs | 40 ++++++--- .../Semantic/Semantics/LocalFunctionTests.cs | 51 +++++++++++ .../InternalUtilities/MultiDictionary.cs | 24 ++++++ 5 files changed, 191 insertions(+), 17 deletions(-) diff --git a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaFrame.cs b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaFrame.cs index b7098ffa81a..e04d5d2f82f 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaFrame.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaFrame.cs @@ -15,6 +15,7 @@ namespace Microsoft.CodeAnalysis.CSharp /// internal sealed class LambdaFrame : SynthesizedContainer, ISynthesizedMethodBodyImplementationSymbol { + private readonly TypeKind _typeKind; private readonly MethodSymbol _topLevelMethod; private readonly MethodSymbol _containingMethod; private readonly MethodSymbol _constructor; @@ -23,12 +24,13 @@ internal sealed class LambdaFrame : SynthesizedContainer, ISynthesizedMethodBody internal readonly CSharpSyntaxNode ScopeSyntaxOpt; internal readonly int ClosureOrdinal; - internal LambdaFrame(MethodSymbol topLevelMethod, MethodSymbol containingMethod, CSharpSyntaxNode scopeSyntaxOpt, DebugId methodId, DebugId closureId) + internal LambdaFrame(MethodSymbol topLevelMethod, MethodSymbol containingMethod, TypeKind typeKind, CSharpSyntaxNode scopeSyntaxOpt, DebugId methodId, DebugId closureId) : base(MakeName(scopeSyntaxOpt, methodId, closureId), containingMethod) { + _typeKind = typeKind; _topLevelMethod = topLevelMethod; _containingMethod = containingMethod; - _constructor = new LambdaFrameConstructor(this); + _constructor = typeKind == TypeKind.Class ? new LambdaFrameConstructor(this) : null; this.ClosureOrdinal = closureId.Ordinal; // static lambdas technically have the class scope so the scope syntax is null @@ -77,7 +79,7 @@ private static void AssertIsClosureScopeSyntax(CSharpSyntaxNode syntaxOpt) public override TypeKind TypeKind { - get { return TypeKind.Class; } + get { return _typeKind; } } internal override MethodSymbol Constructor diff --git a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs index 42a578efc70..e39ec5724b2 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs @@ -58,12 +58,23 @@ internal sealed class Analysis : BoundTreeWalker /// /// The syntax nodes associated with each captured variable. /// - public readonly MultiDictionary capturedVariables = new MultiDictionary(); + public MultiDictionary capturedVariables = new MultiDictionary(); /// /// For each lambda in the code, the set of variables that it captures. /// - public readonly MultiDictionary capturedVariablesByLambda = new MultiDictionary(); + public MultiDictionary capturedVariablesByLambda = new MultiDictionary(); + + /// + /// If a local function is in the set, at some point in the code it is converted to a delegate and should then not be optimized to a struct closure. + /// Also contains all lambdas (as they are converted to delegates implicitly). + /// + private readonly HashSet methodsConvertedToDelegates = new HashSet(); + + /// + /// Any scope that a method in closes over. If a scope is in this set, don't use a struct closure. + /// + public readonly HashSet scopesThatCantBeStructs = new HashSet(); /// /// Blocks that are positioned between a block declaring some lifted variables @@ -162,6 +173,61 @@ private static BoundNode FindNodeToAnalyze(BoundNode node) } } + /// + /// Optimizes local functions that reference other local functions (even themselves) to not need closures if they aren't required. + /// + private void RemoveUnneededReferences() + { + var capturedVariablesByLambdaNew = new MultiDictionary(); + var capturedVariablesKeepSet = new HashSet(); + foreach (var methodKvp in capturedVariablesByLambda) + { + var isOnlyThis = false; + var isGeneral = false; + foreach (var reference in capturedVariablesByLambda.TransitiveClosure(methodKvp.Key)) + { + if (reference.Kind != SymbolKind.Method) + { + if (reference == _topLevelMethod.ThisParameter) + { + isOnlyThis = true; + } + else + { + isGeneral = true; + break; + } + } + } + if (isGeneral) + { + foreach (var value in methodKvp.Value) + { + capturedVariablesByLambdaNew.Add(methodKvp.Key, value); + capturedVariablesKeepSet.Add(value); + } + } + else if (isOnlyThis) + { + capturedVariablesByLambdaNew.Add(methodKvp.Key, _topLevelMethod.ThisParameter); + capturedVariablesKeepSet.Add(_topLevelMethod.ThisParameter); + } + } + capturedVariablesByLambda = capturedVariablesByLambdaNew; + var capturedVariablesNew = new MultiDictionary(); + foreach (var oldCaptured in capturedVariables) + { + if (capturedVariablesKeepSet.Contains(oldCaptured.Key)) + { + foreach (var value in oldCaptured.Value) + { + capturedVariablesNew.Add(oldCaptured.Key, value); + } + } + } + capturedVariables = capturedVariablesNew; + } + /// /// Create the optimized plan for the location of lambda methods and whether scopes need access to parent scopes /// @@ -170,6 +236,8 @@ internal void ComputeLambdaScopesAndFrameCaptures() lambdaScopes = new Dictionary(ReferenceEqualityComparer.Instance); needsParentFrame = new HashSet(); + RemoveUnneededReferences(); + foreach (var kvp in capturedVariablesByLambda) { // get innermost and outermost scopes from which a lambda captures @@ -223,10 +291,20 @@ internal void ComputeLambdaScopesAndFrameCaptures() { lambdaScopes.Add(kvp.Key, innermostScope); + var markAsNoStruct = methodsConvertedToDelegates.Contains(kvp.Key); + if (markAsNoStruct) + { + scopesThatCantBeStructs.Add(innermostScope); + } + while (innermostScope != outermostScope) { needsParentFrame.Add(innermostScope); scopeParent.TryGetValue(innermostScope, out innermostScope); + if (markAsNoStruct) + { + scopesThatCantBeStructs.Add(innermostScope); + } } } } @@ -355,6 +433,7 @@ public override BoundNode VisitCall(BoundCall node) public override BoundNode VisitLambda(BoundLambda node) { + methodsConvertedToDelegates.Add(node.Symbol); return VisitLambdaOrFunction(node); } @@ -363,6 +442,7 @@ public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationE if (node.MethodOpt?.MethodKind == MethodKind.LocalFunction) { ReferenceVariable(node.Syntax, node.MethodOpt); + methodsConvertedToDelegates.Add(node.MethodOpt); } return base.VisitDelegateCreationExpression(node); } @@ -459,6 +539,7 @@ public override BoundNode VisitConversion(BoundConversion node) if (node.SymbolOpt?.MethodKind == MethodKind.LocalFunction) { ReferenceVariable(node.Syntax, node.SymbolOpt); + methodsConvertedToDelegates.Add(node.SymbolOpt); } if (node.IsExtensionMethod || ((object)node.SymbolOpt != null && !node.SymbolOpt.IsStatic)) { diff --git a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs index 17200b0b9ee..4dfdeeace20 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs @@ -329,19 +329,24 @@ private LambdaFrame GetFrameForScope(BoundNode scope, ArrayBuilder(BoundNode node, LambdaFrame frame, Func.GetInstance(); - MethodSymbol constructor = frame.Constructor.AsMember(frameType); - Debug.Assert(frameType == constructor.ContainingType); - var newFrame = new BoundObjectCreationExpression( - syntax: syntax, - constructor: constructor); + BoundExpression newFrame; + if (frame.Constructor == null) + { + Debug.Assert(frame.TypeKind == TypeKind.Struct); + newFrame = new BoundDefaultOperator(syntax: syntax, type: frameType); + } + else + { + MethodSymbol constructor = frame.Constructor.AsMember(frameType); + Debug.Assert(frameType == constructor.ContainingType); + newFrame = new BoundObjectCreationExpression( + syntax: syntax, + constructor: constructor); + } prologue.Add(new BoundAssignmentOperator(syntax, new BoundLocal(syntax, framePointer, null, frameType), @@ -1155,6 +1169,8 @@ private DebugId GetLambdaId(SyntaxNode syntax, ClosureKind closureKind, int clos } else if (_analysis.capturedVariablesByLambda[node.Symbol].Count == 0) { + // TODO: Check the following and don't use a static frame if true (just emit a static method) + // _analysis.methodsConvertedToDelegates.Contains(node.Symbol) translatedLambdaContainer = containerAsFrame = GetStaticFrame(Diagnostics, node); closureKind = ClosureKind.Static; closureOrdinal = LambdaDebugInfo.StaticClosureOrdinal; diff --git a/src/Compilers/CSharp/Test/Semantic/Semantics/LocalFunctionTests.cs b/src/Compilers/CSharp/Test/Semantic/Semantics/LocalFunctionTests.cs index fa20f87d25c..37abeac6f78 100644 --- a/src/Compilers/CSharp/Test/Semantic/Semantics/LocalFunctionTests.cs +++ b/src/Compilers/CSharp/Test/Semantic/Semantics/LocalFunctionTests.cs @@ -837,6 +837,22 @@ static void Main(string[] args) VerifyOutput(source, "2"); } + [Fact] + public void StructClosure() + { + var source = @" +int x = 2; +void Foo() +{ + Console.Write(x); + Console.Write(' '); + Console.Write(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType.BaseType); +} +Foo(); +"; + VerifyOutputInMain(source, "2 System.ValueType", "System"); + } + [Fact] public void Recursion() { @@ -877,6 +893,41 @@ void Bar(int depth2) VerifyOutputInMain(source, "2", "System"); } + [Fact] + public void RecursionThisOnlyClosure() + { + var source = @" +using System; + +class Program +{ + int _x; + void Outer() + { + void Inner() + { + if (_x == 0) + { + // Ensure we're in a this-only closure. Should NOT print a display class. + Console.Write(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType); + return; + } + Console.Write(_x); + Console.Write(' '); + _x = 0; + Inner(); + } + Inner(); + } + public static void Main() + { + new Program() { _x = 2 }.Outer(); + } +} +"; + VerifyOutput(source, "2 Program"); + } + [Fact] public void IteratorBasic() { diff --git a/src/Compilers/Core/Portable/InternalUtilities/MultiDictionary.cs b/src/Compilers/Core/Portable/InternalUtilities/MultiDictionary.cs index 2118e7cda93..2602282796f 100644 --- a/src/Compilers/Core/Portable/InternalUtilities/MultiDictionary.cs +++ b/src/Compilers/Core/Portable/InternalUtilities/MultiDictionary.cs @@ -194,4 +194,28 @@ internal void Clear() _dictionary.Clear(); } } + + internal static class MultiDictionaryExtensions + { + // Adapted from FunctionExtensions.TransitiveClosure + public static HashSet TransitiveClosure(this MultiDictionary relation, K item) where K : V + { + var closure = new HashSet(); + var stack = new Stack(); + stack.Push(item); + while (stack.Count > 0) + { + var current = stack.Pop(); + foreach (var newItem in relation[current]) + { + if (closure.Add(newItem) && newItem is K) + { + stack.Push((K)newItem); + } + } + } + + return closure; + } + } } -- GitLab