diff --git a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaFrame.cs b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaFrame.cs
index b7098ffa81abc01a1091c9e269039ea6a11c0275..e04d5d2f82fc4474748cf690f7862aa69c23690e 100644
--- a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaFrame.cs
+++ b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaFrame.cs
@@ -15,6 +15,7 @@ namespace Microsoft.CodeAnalysis.CSharp
///
internal sealed class LambdaFrame : SynthesizedContainer, ISynthesizedMethodBodyImplementationSymbol
{
+ private readonly TypeKind _typeKind;
private readonly MethodSymbol _topLevelMethod;
private readonly MethodSymbol _containingMethod;
private readonly MethodSymbol _constructor;
@@ -23,12 +24,13 @@ internal sealed class LambdaFrame : SynthesizedContainer, ISynthesizedMethodBody
internal readonly CSharpSyntaxNode ScopeSyntaxOpt;
internal readonly int ClosureOrdinal;
- internal LambdaFrame(MethodSymbol topLevelMethod, MethodSymbol containingMethod, CSharpSyntaxNode scopeSyntaxOpt, DebugId methodId, DebugId closureId)
+ internal LambdaFrame(MethodSymbol topLevelMethod, MethodSymbol containingMethod, TypeKind typeKind, CSharpSyntaxNode scopeSyntaxOpt, DebugId methodId, DebugId closureId)
: base(MakeName(scopeSyntaxOpt, methodId, closureId), containingMethod)
{
+ _typeKind = typeKind;
_topLevelMethod = topLevelMethod;
_containingMethod = containingMethod;
- _constructor = new LambdaFrameConstructor(this);
+ _constructor = typeKind == TypeKind.Class ? new LambdaFrameConstructor(this) : null;
this.ClosureOrdinal = closureId.Ordinal;
// static lambdas technically have the class scope so the scope syntax is null
@@ -77,7 +79,7 @@ private static void AssertIsClosureScopeSyntax(CSharpSyntaxNode syntaxOpt)
public override TypeKind TypeKind
{
- get { return TypeKind.Class; }
+ get { return _typeKind; }
}
internal override MethodSymbol Constructor
diff --git a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs
index 42a578efc708b22f68cb665f9f3dc1396ef93e86..e39ec5724b2f78c8ada80526c4e7a11737536e70 100644
--- a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs
+++ b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.Analysis.cs
@@ -58,12 +58,23 @@ internal sealed class Analysis : BoundTreeWalker
///
/// The syntax nodes associated with each captured variable.
///
- public readonly MultiDictionary capturedVariables = new MultiDictionary();
+ public MultiDictionary capturedVariables = new MultiDictionary();
///
/// For each lambda in the code, the set of variables that it captures.
///
- public readonly MultiDictionary capturedVariablesByLambda = new MultiDictionary();
+ public MultiDictionary capturedVariablesByLambda = new MultiDictionary();
+
+ ///
+ /// If a local function is in the set, at some point in the code it is converted to a delegate and should then not be optimized to a struct closure.
+ /// Also contains all lambdas (as they are converted to delegates implicitly).
+ ///
+ private readonly HashSet methodsConvertedToDelegates = new HashSet();
+
+ ///
+ /// Any scope that a method in closes over. If a scope is in this set, don't use a struct closure.
+ ///
+ public readonly HashSet scopesThatCantBeStructs = new HashSet();
///
/// Blocks that are positioned between a block declaring some lifted variables
@@ -162,6 +173,61 @@ private static BoundNode FindNodeToAnalyze(BoundNode node)
}
}
+ ///
+ /// Optimizes local functions that reference other local functions (even themselves) to not need closures if they aren't required.
+ ///
+ private void RemoveUnneededReferences()
+ {
+ var capturedVariablesByLambdaNew = new MultiDictionary();
+ var capturedVariablesKeepSet = new HashSet();
+ foreach (var methodKvp in capturedVariablesByLambda)
+ {
+ var isOnlyThis = false;
+ var isGeneral = false;
+ foreach (var reference in capturedVariablesByLambda.TransitiveClosure(methodKvp.Key))
+ {
+ if (reference.Kind != SymbolKind.Method)
+ {
+ if (reference == _topLevelMethod.ThisParameter)
+ {
+ isOnlyThis = true;
+ }
+ else
+ {
+ isGeneral = true;
+ break;
+ }
+ }
+ }
+ if (isGeneral)
+ {
+ foreach (var value in methodKvp.Value)
+ {
+ capturedVariablesByLambdaNew.Add(methodKvp.Key, value);
+ capturedVariablesKeepSet.Add(value);
+ }
+ }
+ else if (isOnlyThis)
+ {
+ capturedVariablesByLambdaNew.Add(methodKvp.Key, _topLevelMethod.ThisParameter);
+ capturedVariablesKeepSet.Add(_topLevelMethod.ThisParameter);
+ }
+ }
+ capturedVariablesByLambda = capturedVariablesByLambdaNew;
+ var capturedVariablesNew = new MultiDictionary();
+ foreach (var oldCaptured in capturedVariables)
+ {
+ if (capturedVariablesKeepSet.Contains(oldCaptured.Key))
+ {
+ foreach (var value in oldCaptured.Value)
+ {
+ capturedVariablesNew.Add(oldCaptured.Key, value);
+ }
+ }
+ }
+ capturedVariables = capturedVariablesNew;
+ }
+
///
/// Create the optimized plan for the location of lambda methods and whether scopes need access to parent scopes
///
@@ -170,6 +236,8 @@ internal void ComputeLambdaScopesAndFrameCaptures()
lambdaScopes = new Dictionary(ReferenceEqualityComparer.Instance);
needsParentFrame = new HashSet();
+ RemoveUnneededReferences();
+
foreach (var kvp in capturedVariablesByLambda)
{
// get innermost and outermost scopes from which a lambda captures
@@ -223,10 +291,20 @@ internal void ComputeLambdaScopesAndFrameCaptures()
{
lambdaScopes.Add(kvp.Key, innermostScope);
+ var markAsNoStruct = methodsConvertedToDelegates.Contains(kvp.Key);
+ if (markAsNoStruct)
+ {
+ scopesThatCantBeStructs.Add(innermostScope);
+ }
+
while (innermostScope != outermostScope)
{
needsParentFrame.Add(innermostScope);
scopeParent.TryGetValue(innermostScope, out innermostScope);
+ if (markAsNoStruct)
+ {
+ scopesThatCantBeStructs.Add(innermostScope);
+ }
}
}
}
@@ -355,6 +433,7 @@ public override BoundNode VisitCall(BoundCall node)
public override BoundNode VisitLambda(BoundLambda node)
{
+ methodsConvertedToDelegates.Add(node.Symbol);
return VisitLambdaOrFunction(node);
}
@@ -363,6 +442,7 @@ public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationE
if (node.MethodOpt?.MethodKind == MethodKind.LocalFunction)
{
ReferenceVariable(node.Syntax, node.MethodOpt);
+ methodsConvertedToDelegates.Add(node.MethodOpt);
}
return base.VisitDelegateCreationExpression(node);
}
@@ -459,6 +539,7 @@ public override BoundNode VisitConversion(BoundConversion node)
if (node.SymbolOpt?.MethodKind == MethodKind.LocalFunction)
{
ReferenceVariable(node.Syntax, node.SymbolOpt);
+ methodsConvertedToDelegates.Add(node.SymbolOpt);
}
if (node.IsExtensionMethod || ((object)node.SymbolOpt != null && !node.SymbolOpt.IsStatic))
{
diff --git a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs
index 17200b0b9eea0764dce33cc11fe6bf30cd01002e..4dfdeeace20849c54cce2abfece1ab0332e185d4 100644
--- a/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs
+++ b/src/Compilers/CSharp/Portable/Lowering/LambdaRewriter/LambdaRewriter.cs
@@ -329,19 +329,24 @@ private LambdaFrame GetFrameForScope(BoundNode scope, ArrayBuilder(BoundNode node, LambdaFrame frame, Func.GetInstance();
- MethodSymbol constructor = frame.Constructor.AsMember(frameType);
- Debug.Assert(frameType == constructor.ContainingType);
- var newFrame = new BoundObjectCreationExpression(
- syntax: syntax,
- constructor: constructor);
+ BoundExpression newFrame;
+ if (frame.Constructor == null)
+ {
+ Debug.Assert(frame.TypeKind == TypeKind.Struct);
+ newFrame = new BoundDefaultOperator(syntax: syntax, type: frameType);
+ }
+ else
+ {
+ MethodSymbol constructor = frame.Constructor.AsMember(frameType);
+ Debug.Assert(frameType == constructor.ContainingType);
+ newFrame = new BoundObjectCreationExpression(
+ syntax: syntax,
+ constructor: constructor);
+ }
prologue.Add(new BoundAssignmentOperator(syntax,
new BoundLocal(syntax, framePointer, null, frameType),
@@ -1155,6 +1169,8 @@ private DebugId GetLambdaId(SyntaxNode syntax, ClosureKind closureKind, int clos
}
else if (_analysis.capturedVariablesByLambda[node.Symbol].Count == 0)
{
+ // TODO: Check the following and don't use a static frame if true (just emit a static method)
+ // _analysis.methodsConvertedToDelegates.Contains(node.Symbol)
translatedLambdaContainer = containerAsFrame = GetStaticFrame(Diagnostics, node);
closureKind = ClosureKind.Static;
closureOrdinal = LambdaDebugInfo.StaticClosureOrdinal;
diff --git a/src/Compilers/CSharp/Test/Semantic/Semantics/LocalFunctionTests.cs b/src/Compilers/CSharp/Test/Semantic/Semantics/LocalFunctionTests.cs
index fa20f87d25c3d328077e50b4983bc66322baf9f4..37abeac6f78456da2b9f19f52114b2dfd7a82092 100644
--- a/src/Compilers/CSharp/Test/Semantic/Semantics/LocalFunctionTests.cs
+++ b/src/Compilers/CSharp/Test/Semantic/Semantics/LocalFunctionTests.cs
@@ -837,6 +837,22 @@ static void Main(string[] args)
VerifyOutput(source, "2");
}
+ [Fact]
+ public void StructClosure()
+ {
+ var source = @"
+int x = 2;
+void Foo()
+{
+ Console.Write(x);
+ Console.Write(' ');
+ Console.Write(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType.BaseType);
+}
+Foo();
+";
+ VerifyOutputInMain(source, "2 System.ValueType", "System");
+ }
+
[Fact]
public void Recursion()
{
@@ -877,6 +893,41 @@ void Bar(int depth2)
VerifyOutputInMain(source, "2", "System");
}
+ [Fact]
+ public void RecursionThisOnlyClosure()
+ {
+ var source = @"
+using System;
+
+class Program
+{
+ int _x;
+ void Outer()
+ {
+ void Inner()
+ {
+ if (_x == 0)
+ {
+ // Ensure we're in a this-only closure. Should NOT print a display class.
+ Console.Write(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType);
+ return;
+ }
+ Console.Write(_x);
+ Console.Write(' ');
+ _x = 0;
+ Inner();
+ }
+ Inner();
+ }
+ public static void Main()
+ {
+ new Program() { _x = 2 }.Outer();
+ }
+}
+";
+ VerifyOutput(source, "2 Program");
+ }
+
[Fact]
public void IteratorBasic()
{
diff --git a/src/Compilers/Core/Portable/InternalUtilities/MultiDictionary.cs b/src/Compilers/Core/Portable/InternalUtilities/MultiDictionary.cs
index 2118e7cda9335f06b103375adc1aa00f6f5a0297..2602282796f811d0357bc7935a8cdd355fc93e1a 100644
--- a/src/Compilers/Core/Portable/InternalUtilities/MultiDictionary.cs
+++ b/src/Compilers/Core/Portable/InternalUtilities/MultiDictionary.cs
@@ -194,4 +194,28 @@ internal void Clear()
_dictionary.Clear();
}
}
+
+ internal static class MultiDictionaryExtensions
+ {
+ // Adapted from FunctionExtensions.TransitiveClosure
+ public static HashSet TransitiveClosure(this MultiDictionary relation, K item) where K : V
+ {
+ var closure = new HashSet();
+ var stack = new Stack();
+ stack.Push(item);
+ while (stack.Count > 0)
+ {
+ var current = stack.Pop();
+ foreach (var newItem in relation[current])
+ {
+ if (closure.Add(newItem) && newItem is K)
+ {
+ stack.Push((K)newItem);
+ }
+ }
+ }
+
+ return closure;
+ }
+ }
}