提交 037f4a62 编写于 作者: E Evan Hauck

Optimize local function closures into structs

上级 21fe5580
......@@ -15,6 +15,7 @@ namespace Microsoft.CodeAnalysis.CSharp
/// </summary>
internal sealed class LambdaFrame : SynthesizedContainer, ISynthesizedMethodBodyImplementationSymbol
{
private readonly TypeKind _typeKind;
private readonly MethodSymbol _topLevelMethod;
private readonly MethodSymbol _containingMethod;
private readonly MethodSymbol _constructor;
......@@ -23,12 +24,13 @@ internal sealed class LambdaFrame : SynthesizedContainer, ISynthesizedMethodBody
internal readonly CSharpSyntaxNode ScopeSyntaxOpt;
internal readonly int ClosureOrdinal;
internal LambdaFrame(MethodSymbol topLevelMethod, MethodSymbol containingMethod, CSharpSyntaxNode scopeSyntaxOpt, DebugId methodId, DebugId closureId)
internal LambdaFrame(MethodSymbol topLevelMethod, MethodSymbol containingMethod, TypeKind typeKind, CSharpSyntaxNode scopeSyntaxOpt, DebugId methodId, DebugId closureId)
: base(MakeName(scopeSyntaxOpt, methodId, closureId), containingMethod)
{
_typeKind = typeKind;
_topLevelMethod = topLevelMethod;
_containingMethod = containingMethod;
_constructor = new LambdaFrameConstructor(this);
_constructor = typeKind == TypeKind.Class ? new LambdaFrameConstructor(this) : null;
this.ClosureOrdinal = closureId.Ordinal;
// static lambdas technically have the class scope so the scope syntax is null
......@@ -77,7 +79,7 @@ private static void AssertIsClosureScopeSyntax(CSharpSyntaxNode syntaxOpt)
public override TypeKind TypeKind
{
get { return TypeKind.Class; }
get { return _typeKind; }
}
internal override MethodSymbol Constructor
......
......@@ -58,12 +58,23 @@ internal sealed class Analysis : BoundTreeWalker
/// <summary>
/// The syntax nodes associated with each captured variable.
/// </summary>
public readonly MultiDictionary<Symbol, CSharpSyntaxNode> capturedVariables = new MultiDictionary<Symbol, CSharpSyntaxNode>();
public MultiDictionary<Symbol, CSharpSyntaxNode> capturedVariables = new MultiDictionary<Symbol, CSharpSyntaxNode>();
/// <summary>
/// For each lambda in the code, the set of variables that it captures.
/// </summary>
public readonly MultiDictionary<MethodSymbol, Symbol> capturedVariablesByLambda = new MultiDictionary<MethodSymbol, Symbol>();
public MultiDictionary<MethodSymbol, Symbol> capturedVariablesByLambda = new MultiDictionary<MethodSymbol, Symbol>();
/// <summary>
/// If a local function is in the set, at some point in the code it is converted to a delegate and should then not be optimized to a struct closure.
/// Also contains all lambdas (as they are converted to delegates implicitly).
/// </summary>
private readonly HashSet<MethodSymbol> methodsConvertedToDelegates = new HashSet<MethodSymbol>();
/// <summary>
/// Any scope that a method in <see cref="methodsConvertedToDelegates"/> closes over. If a scope is in this set, don't use a struct closure.
/// </summary>
public readonly HashSet<BoundNode> scopesThatCantBeStructs = new HashSet<BoundNode>();
/// <summary>
/// Blocks that are positioned between a block declaring some lifted variables
......@@ -162,6 +173,61 @@ private static BoundNode FindNodeToAnalyze(BoundNode node)
}
}
/// <summary>
/// Optimizes local functions that reference other local functions (even themselves) to not need closures if they aren't required.
/// </summary>
private void RemoveUnneededReferences()
{
var capturedVariablesByLambdaNew = new MultiDictionary<MethodSymbol, Symbol>();
var capturedVariablesKeepSet = new HashSet<Symbol>();
foreach (var methodKvp in capturedVariablesByLambda)
{
var isOnlyThis = false;
var isGeneral = false;
foreach (var reference in capturedVariablesByLambda.TransitiveClosure(methodKvp.Key))
{
if (reference.Kind != SymbolKind.Method)
{
if (reference == _topLevelMethod.ThisParameter)
{
isOnlyThis = true;
}
else
{
isGeneral = true;
break;
}
}
}
if (isGeneral)
{
foreach (var value in methodKvp.Value)
{
capturedVariablesByLambdaNew.Add(methodKvp.Key, value);
capturedVariablesKeepSet.Add(value);
}
}
else if (isOnlyThis)
{
capturedVariablesByLambdaNew.Add(methodKvp.Key, _topLevelMethod.ThisParameter);
capturedVariablesKeepSet.Add(_topLevelMethod.ThisParameter);
}
}
capturedVariablesByLambda = capturedVariablesByLambdaNew;
var capturedVariablesNew = new MultiDictionary<Symbol, CSharpSyntaxNode>();
foreach (var oldCaptured in capturedVariables)
{
if (capturedVariablesKeepSet.Contains(oldCaptured.Key))
{
foreach (var value in oldCaptured.Value)
{
capturedVariablesNew.Add(oldCaptured.Key, value);
}
}
}
capturedVariables = capturedVariablesNew;
}
/// <summary>
/// Create the optimized plan for the location of lambda methods and whether scopes need access to parent scopes
/// </summary>
......@@ -170,6 +236,8 @@ internal void ComputeLambdaScopesAndFrameCaptures()
lambdaScopes = new Dictionary<MethodSymbol, BoundNode>(ReferenceEqualityComparer.Instance);
needsParentFrame = new HashSet<BoundNode>();
RemoveUnneededReferences();
foreach (var kvp in capturedVariablesByLambda)
{
// get innermost and outermost scopes from which a lambda captures
......@@ -223,10 +291,20 @@ internal void ComputeLambdaScopesAndFrameCaptures()
{
lambdaScopes.Add(kvp.Key, innermostScope);
var markAsNoStruct = methodsConvertedToDelegates.Contains(kvp.Key);
if (markAsNoStruct)
{
scopesThatCantBeStructs.Add(innermostScope);
}
while (innermostScope != outermostScope)
{
needsParentFrame.Add(innermostScope);
scopeParent.TryGetValue(innermostScope, out innermostScope);
if (markAsNoStruct)
{
scopesThatCantBeStructs.Add(innermostScope);
}
}
}
}
......@@ -355,6 +433,7 @@ public override BoundNode VisitCall(BoundCall node)
public override BoundNode VisitLambda(BoundLambda node)
{
methodsConvertedToDelegates.Add(node.Symbol);
return VisitLambdaOrFunction(node);
}
......@@ -363,6 +442,7 @@ public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationE
if (node.MethodOpt?.MethodKind == MethodKind.LocalFunction)
{
ReferenceVariable(node.Syntax, node.MethodOpt);
methodsConvertedToDelegates.Add(node.MethodOpt);
}
return base.VisitDelegateCreationExpression(node);
}
......@@ -459,6 +539,7 @@ public override BoundNode VisitConversion(BoundConversion node)
if (node.SymbolOpt?.MethodKind == MethodKind.LocalFunction)
{
ReferenceVariable(node.Syntax, node.SymbolOpt);
methodsConvertedToDelegates.Add(node.SymbolOpt);
}
if (node.IsExtensionMethod || ((object)node.SymbolOpt != null && !node.SymbolOpt.IsStatic))
{
......
......@@ -329,19 +329,24 @@ private LambdaFrame GetFrameForScope(BoundNode scope, ArrayBuilder<ClosureDebugI
DebugId methodId = GetTopLevelMethodId();
DebugId closureId = GetClosureId(syntax, closureDebugInfo);
var canBeStruct = !_analysis.scopesThatCantBeStructs.Contains(scope);
var containingMethod = _analysis.scopeOwner[scope];
if (_substitutedSourceMethod != null && containingMethod == _topLevelMethod)
{
containingMethod = _substitutedSourceMethod;
}
frame = new LambdaFrame(_topLevelMethod, containingMethod, syntax, methodId, closureId);
frame = new LambdaFrame(_topLevelMethod, containingMethod, canBeStruct ? TypeKind.Struct : TypeKind.Class, syntax, methodId, closureId);
_frames.Add(scope, frame);
CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(this.ContainingType, frame);
CompilationState.AddSynthesizedMethod(
frame.Constructor,
FlowAnalysisPass.AppendImplicitReturn(MethodCompiler.BindMethodBody(frame.Constructor, CompilationState, null),
frame.Constructor));
if (frame.Constructor != null)
{
CompilationState.AddSynthesizedMethod(
frame.Constructor,
FlowAnalysisPass.AppendImplicitReturn(MethodCompiler.BindMethodBody(frame.Constructor, CompilationState, null),
frame.Constructor));
}
}
return frame;
......@@ -372,7 +377,7 @@ private LambdaFrame GetStaticFrame(DiagnosticBag diagnostics, IBoundLambdaOrFunc
DebugId closureId = default(DebugId);
// using _topLevelMethod as containing member because the static frame does not have generic parameters, except for the top level method's
var containingMethod = isNonGeneric ? null : (_substitutedSourceMethod ?? _topLevelMethod);
_lazyStaticLambdaFrame = new LambdaFrame(_topLevelMethod, containingMethod, scopeSyntaxOpt: null, methodId: methodId, closureId: closureId);
_lazyStaticLambdaFrame = new LambdaFrame(_topLevelMethod, containingMethod, TypeKind.Class, scopeSyntaxOpt: null, methodId: methodId, closureId: closureId);
// nongeneric static lambdas can share the frame
if (isNonGeneric)
......@@ -385,7 +390,7 @@ private LambdaFrame GetStaticFrame(DiagnosticBag diagnostics, IBoundLambdaOrFunc
// add frame type
CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(this.ContainingType, frame);
// add its ctor
// add its ctor (note Constructor can be null if TypeKind.Struct is passed in to LambdaFrame.ctor, but Class is passed in above)
CompilationState.AddSynthesizedMethod(
frame.Constructor,
FlowAnalysisPass.AppendImplicitReturn(MethodCompiler.BindMethodBody(frame.Constructor, CompilationState, null),
......@@ -490,11 +495,20 @@ private T IntroduceFrame<T>(BoundNode node, LambdaFrame frame, Func<ArrayBuilder
var prologue = ArrayBuilder<BoundExpression>.GetInstance();
MethodSymbol constructor = frame.Constructor.AsMember(frameType);
Debug.Assert(frameType == constructor.ContainingType);
var newFrame = new BoundObjectCreationExpression(
syntax: syntax,
constructor: constructor);
BoundExpression newFrame;
if (frame.Constructor == null)
{
Debug.Assert(frame.TypeKind == TypeKind.Struct);
newFrame = new BoundDefaultOperator(syntax: syntax, type: frameType);
}
else
{
MethodSymbol constructor = frame.Constructor.AsMember(frameType);
Debug.Assert(frameType == constructor.ContainingType);
newFrame = new BoundObjectCreationExpression(
syntax: syntax,
constructor: constructor);
}
prologue.Add(new BoundAssignmentOperator(syntax,
new BoundLocal(syntax, framePointer, null, frameType),
......@@ -1155,6 +1169,8 @@ private DebugId GetLambdaId(SyntaxNode syntax, ClosureKind closureKind, int clos
}
else if (_analysis.capturedVariablesByLambda[node.Symbol].Count == 0)
{
// TODO: Check the following and don't use a static frame if true (just emit a static method)
// _analysis.methodsConvertedToDelegates.Contains(node.Symbol)
translatedLambdaContainer = containerAsFrame = GetStaticFrame(Diagnostics, node);
closureKind = ClosureKind.Static;
closureOrdinal = LambdaDebugInfo.StaticClosureOrdinal;
......
......@@ -837,6 +837,22 @@ static void Main(string[] args)
VerifyOutput(source, "2");
}
[Fact]
public void StructClosure()
{
var source = @"
int x = 2;
void Foo()
{
Console.Write(x);
Console.Write(' ');
Console.Write(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType.BaseType);
}
Foo();
";
VerifyOutputInMain(source, "2 System.ValueType", "System");
}
[Fact]
public void Recursion()
{
......@@ -877,6 +893,41 @@ void Bar(int depth2)
VerifyOutputInMain(source, "2", "System");
}
[Fact]
public void RecursionThisOnlyClosure()
{
var source = @"
using System;
class Program
{
int _x;
void Outer()
{
void Inner()
{
if (_x == 0)
{
// Ensure we're in a this-only closure. Should NOT print a display class.
Console.Write(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType);
return;
}
Console.Write(_x);
Console.Write(' ');
_x = 0;
Inner();
}
Inner();
}
public static void Main()
{
new Program() { _x = 2 }.Outer();
}
}
";
VerifyOutput(source, "2 Program");
}
[Fact]
public void IteratorBasic()
{
......
......@@ -194,4 +194,28 @@ internal void Clear()
_dictionary.Clear();
}
}
internal static class MultiDictionaryExtensions
{
// Adapted from FunctionExtensions.TransitiveClosure
public static HashSet<V> TransitiveClosure<K, V>(this MultiDictionary<K, V> relation, K item) where K : V
{
var closure = new HashSet<V>();
var stack = new Stack<K>();
stack.Push(item);
while (stack.Count > 0)
{
var current = stack.Pop();
foreach (var newItem in relation[current])
{
if (closure.Add(newItem) && newItem is K)
{
stack.Push((K)newItem);
}
}
}
return closure;
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册