提交 fed66000 编写于 作者: E Evan Hauck

Implement by-ref struct closure optimization

上级 7b4098f7
......@@ -4,6 +4,7 @@
using Microsoft.CodeAnalysis.CodeGen;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Roslyn.Utilities;
using System.Collections.Immutable;
namespace Microsoft.CodeAnalysis.CSharp
{
......@@ -79,7 +80,13 @@ private static TypeSymbol GetCapturedVariableFieldType(SynthesizedContainer fram
var lambdaFrame = local.Type.OriginalDefinition as LambdaFrame;
if ((object)lambdaFrame != null)
{
return lambdaFrame.ConstructIfGeneric(frame.TypeArgumentsNoUseSiteDiagnostics);
// lambdaFrame may have less generic type parameters than frame, so trim them down (the first N will always match)
var typeArguments = frame.TypeArgumentsNoUseSiteDiagnostics;
if (typeArguments.Length > lambdaFrame.Arity)
{
typeArguments = ImmutableArray.Create(typeArguments, 0, lambdaFrame.Arity);
}
return lambdaFrame.ConstructIfGeneric(typeArguments);
}
}
......
......@@ -251,6 +251,31 @@ private void RemoveUnneededReferences()
CapturedVariablesByLambda = capturedVariablesByLambdaNew;
}
/// <summary>
/// Finds all generic methods and forces them to be classes. TODO: We should be able to handle this, but for this prototype it's too complicated.
/// </summary>
private void MarkGenericMethodsAsClass()
{
foreach (var scope in this.LambdaScopes)
{
var isGeneric = false;
var method = scope.Key;
while (method != null && !isGeneric)
{
isGeneric = method.Arity != 0;
method = method.ContainingSymbol as MethodSymbol;
}
if (isGeneric)
{
var node = scope.Value;
do
{
ScopesThatCantBeStructs.Add(node);
} while (this.ScopeParent.TryGetValue(node, out node));
}
}
}
/// <summary>
/// Create the optimized plan for the location of lambda methods and whether scopes need access to parent scopes
/// </summary>
......@@ -314,7 +339,8 @@ internal void ComputeLambdaScopesAndFrameCaptures()
{
LambdaScopes.Add(kvp.Key, innermostScope);
var markAsNoStruct = MethodsConvertedToDelegates.Contains(kvp.Key);
// Disable struct closures on methods converted to delegates, as well as on async and iterator methods.
var markAsNoStruct = MethodsConvertedToDelegates.Contains(kvp.Key) || kvp.Key.IsAsync || kvp.Key.IsIterator;
if (markAsNoStruct)
{
ScopesThatCantBeStructs.Add(innermostScope);
......@@ -332,21 +358,7 @@ internal void ComputeLambdaScopesAndFrameCaptures()
}
}
// Note the following is temporary, if we do end up changing the signature of closures to take all
// parent frames as a parameter list (where structs are by ref) instead of a linked list of frames,
// then we no longer need to worry about double-nested struct closures not being passed by reference.
// It is illegal to have any parent of a scope be a struct (due to by-value parent fields).
// So, find the parent of every closure and check if it's a struct. If so, make it a class.
foreach (var kvp in LambdaScopes)
{
var scope = kvp.Value;
while (NeedsParentFrame.Contains(scope) && ScopeParent.TryGetValue(scope, out scope) && !ScopesThatCantBeStructs.Contains(scope))
{
// The parent is a struct. Mark it a class. Keep going along the tree.
ScopesThatCantBeStructs.Add(scope);
}
}
MarkGenericMethodsAsClass();
}
/// <summary>
......
......@@ -71,9 +71,9 @@ internal sealed partial class LambdaRewriter : MethodToClassRewriter
// A mapping from every local function to its lowered method
private struct MappedLocalFunction
{
public readonly MethodSymbol Symbol;
public readonly SynthesizedLambdaMethod Symbol;
public readonly ClosureKind ClosureKind;
public MappedLocalFunction(MethodSymbol symbol, ClosureKind closureKind)
public MappedLocalFunction(SynthesizedLambdaMethod symbol, ClosureKind closureKind)
{
Symbol = symbol;
ClosureKind = closureKind;
......@@ -449,6 +449,21 @@ protected override BoundExpression FramePointer(CSharpSyntaxNode syntax, NamedTy
return new BoundThisReference(syntax, frameClass);
}
// If the current method has by-ref struct closure parameters, and one of them is correct, use it.
var lambda = _currentMethod as SynthesizedLambdaMethod;
if (lambda != null)
{
var start = lambda.ParameterCount - lambda.ExtraSynthesizedParameterCount;
for (var i = start; i < lambda.ParameterCount; i++)
{
var potentialParameter = lambda.Parameters[i];
if (potentialParameter.Type.OriginalDefinition == frameClass)
{
return new BoundParameter(syntax, potentialParameter);
}
}
}
// Otherwise we need to return the value from a frame pointer local variable...
Symbol framePointer = _framePointers[frameClass];
CapturedSymbolReplacement proxyField;
......@@ -668,7 +683,7 @@ public override BoundNode VisitThisReference(BoundThisReference node)
public override BoundNode VisitBaseReference(BoundBaseReference node)
{
return (_currentMethod.ContainingType == _topLevelMethod.ContainingType)
return (!_currentMethod.IsStatic && _currentMethod.ContainingType == _topLevelMethod.ContainingType)
? node
: FramePointer(node.Syntax, _topLevelMethod.ContainingType); // technically, not the correct static type
}
......@@ -735,6 +750,7 @@ public override BoundNode VisitBaseReference(BoundBaseReference node)
private void RemapLocalFunction(
CSharpSyntaxNode syntax, MethodSymbol symbol,
out BoundExpression receiver, out MethodSymbol method,
ref ImmutableArray<BoundExpression> parameters,
ImmutableArray<TypeSymbol> typeArguments = default(ImmutableArray<TypeSymbol>))
{
Debug.Assert(symbol.MethodKind == MethodKind.LocalFunction);
......@@ -742,13 +758,30 @@ public override BoundNode VisitBaseReference(BoundBaseReference node)
var constructed = symbol as ConstructedMethodSymbol;
if (constructed != null)
{
RemapLocalFunction(syntax, constructed.ConstructedFrom, out receiver, out method, this.TypeMap.SubstituteTypes(constructed.TypeArguments));
RemapLocalFunction(syntax, constructed.ConstructedFrom, out receiver, out method, ref parameters, this.TypeMap.SubstituteTypes(constructed.TypeArguments));
return;
}
var mappedLocalFunction = _localFunctionMap[(LocalFunctionSymbol)symbol];
method = mappedLocalFunction.Symbol;
var lambda = mappedLocalFunction.Symbol;
var frameCount = lambda.ExtraSynthesizedParameterCount;
if (frameCount != 0)
{
Debug.Assert(!parameters.IsDefault);
var builder = ArrayBuilder<BoundExpression>.GetInstance();
builder.AddRange(parameters);
var start = lambda.ParameterCount - frameCount;
for (int i = start; i < lambda.ParameterCount; i++)
{
// will always be a NamedTypeSymbol, it's always a closure class
var frame = FrameOfType(syntax, (NamedTypeSymbol)lambda.Parameters[i].Type);
builder.Add(frame);
}
parameters = builder.ToImmutableAndFree();
}
method = lambda;
NamedTypeSymbol constructedFrame;
RemapLambdaOrLocalFunction(syntax, symbol, typeArguments, mappedLocalFunction.ClosureKind, ref method, out receiver, out constructedFrame);
}
......@@ -759,8 +792,9 @@ public override BoundNode VisitCall(BoundCall node)
{
BoundExpression receiver;
MethodSymbol method;
RemapLocalFunction(node.Syntax, node.Method, out receiver, out method);
node = node.Update(receiver, method, node.Arguments);
var arguments = node.Arguments;
RemapLocalFunction(node.Syntax, node.Method, out receiver, out method, ref arguments);
node = node.Update(receiver, method, arguments);
}
var visited = base.VisitCall(node);
if (visited.Kind != BoundKind.Call)
......@@ -1014,7 +1048,8 @@ public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationE
{
BoundExpression receiver;
MethodSymbol method;
RemapLocalFunction(node.Syntax, node.MethodOpt, out receiver, out method);
var arguments = default(ImmutableArray<BoundExpression>);
RemapLocalFunction(node.Syntax, node.MethodOpt, out receiver, out method, ref arguments);
var result = new BoundDelegateCreationExpression(node.Syntax, receiver, method, isExtensionMethod: false, type: node.Type);
return result;
}
......@@ -1049,7 +1084,8 @@ public override BoundNode VisitConversion(BoundConversion conversion)
{
BoundExpression receiver;
MethodSymbol method;
RemapLocalFunction(conversion.Syntax, conversion.SymbolOpt, out receiver, out method);
var arguments = default(ImmutableArray<BoundExpression>);
RemapLocalFunction(conversion.Syntax, conversion.SymbolOpt, out receiver, out method, ref arguments);
var result = new BoundDelegateCreationExpression(conversion.Syntax, receiver, method, isExtensionMethod: false, type: conversion.Type);
return result;
}
......@@ -1165,12 +1201,41 @@ private DebugId GetLambdaId(SyntaxNode syntax, ClosureKind closureKind, int clos
out DebugId topLevelMethodId,
out DebugId lambdaId)
{
ImmutableArray<TypeSymbol> structClosures;
int closureOrdinal;
if (_analysis.LambdaScopes.TryGetValue(node.Symbol, out lambdaScope))
{
translatedLambdaContainer = containerAsFrame = _frames[lambdaScope];
closureKind = ClosureKind.General;
closureOrdinal = containerAsFrame.ClosureOrdinal;
containerAsFrame = _frames[lambdaScope];
var structClosureParamBuilder = ArrayBuilder<TypeSymbol>.GetInstance();
while (containerAsFrame != null && containerAsFrame.IsValueType)
{
structClosureParamBuilder.Add(containerAsFrame);
if (this._analysis.NeedsParentFrame.Contains(lambdaScope) && this._analysis.ScopeParent.TryGetValue(lambdaScope, out lambdaScope))
{
containerAsFrame = _frames[lambdaScope];
}
else
{
// can happen when scope no longer needs parent frame, or we're at the outermost level and the "parent frame" is top level "this".
lambdaScope = null;
containerAsFrame = null;
}
}
// Reverse it because we're going from inner to outer, and parameters are in order of outer to inner
structClosureParamBuilder.ReverseContents();
structClosures = structClosureParamBuilder.ToImmutableAndFree();
if (containerAsFrame == null)
{
closureKind = ClosureKind.Static; // not exactly... but we've rewritten the receiver to be a by-ref parameter
translatedLambdaContainer = _topLevelMethod.ContainingType;
closureOrdinal = LambdaDebugInfo.StaticClosureOrdinal;
}
else
{
closureKind = ClosureKind.General;
translatedLambdaContainer = containerAsFrame;
closureOrdinal = containerAsFrame.ClosureOrdinal;
}
}
else if (_analysis.CapturedVariablesByLambda[node.Symbol].Count == 0)
{
......@@ -1187,6 +1252,7 @@ private DebugId GetLambdaId(SyntaxNode syntax, ClosureKind closureKind, int clos
closureKind = ClosureKind.Static;
closureOrdinal = LambdaDebugInfo.StaticClosureOrdinal;
}
structClosures = default(ImmutableArray<TypeSymbol>);
}
else
{
......@@ -1194,13 +1260,14 @@ private DebugId GetLambdaId(SyntaxNode syntax, ClosureKind closureKind, int clos
translatedLambdaContainer = _topLevelMethod.ContainingType;
closureKind = ClosureKind.ThisOnly;
closureOrdinal = LambdaDebugInfo.ThisOnlyClosureOrdinal;
structClosures = default(ImmutableArray<TypeSymbol>);
}
// Move the body of the lambda to a freshly generated synthetic method on its frame.
topLevelMethodId = GetTopLevelMethodId();
lambdaId = GetLambdaId(node.Syntax, closureKind, closureOrdinal);
var synthesizedMethod = new SynthesizedLambdaMethod(translatedLambdaContainer, closureKind, _topLevelMethod, topLevelMethodId, node, lambdaId);
var synthesizedMethod = new SynthesizedLambdaMethod(translatedLambdaContainer, structClosures, closureKind, _topLevelMethod, topLevelMethodId, node, lambdaId);
CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(translatedLambdaContainer, synthesizedMethod);
foreach (var parameter in node.Symbol.Parameters)
......
......@@ -14,9 +14,11 @@ namespace Microsoft.CodeAnalysis.CSharp
internal sealed class SynthesizedLambdaMethod : SynthesizedMethodBaseSymbol, ISynthesizedMethodBodyImplementationSymbol
{
private readonly MethodSymbol _topLevelMethod;
private readonly ImmutableArray<TypeSymbol> _structClosures;
internal SynthesizedLambdaMethod(
NamedTypeSymbol containingType,
ImmutableArray<TypeSymbol> structClosures,
ClosureKind closureKind,
MethodSymbol topLevelMethod,
DebugId topLevelMethodId,
......@@ -35,6 +37,7 @@ internal sealed class SynthesizedLambdaMethod : SynthesizedMethodBaseSymbol, ISy
| (lambdaNode.Symbol.IsAsync ? DeclarationModifiers.Async : 0))
{
_topLevelMethod = topLevelMethod;
_structClosures = structClosures;
TypeMap typeMap;
ImmutableArray<TypeParameterSymbol> typeParameters;
......@@ -87,8 +90,6 @@ private static string MakeName(string topLevelMethodName, DebugId topLevelMethod
lambdaId.Generation);
}
internal override int ParameterCount => this.BaseMethod.ParameterCount;
// The lambda symbol might have declared no parameters in the case
//
// D d = delegate {};
......@@ -102,6 +103,9 @@ private static string MakeName(string topLevelMethodName, DebugId topLevelMethod
// UNDONE: names from the delegate. Does it really matter?
protected override ImmutableArray<ParameterSymbol> BaseMethodParameters => this.BaseMethod.Parameters;
protected override ImmutableArray<TypeSymbol> ExtraSynthesizedRefParameters => _structClosures;
internal int ExtraSynthesizedParameterCount => this._structClosures.IsDefault ? 0 : this._structClosures.Length;
internal override bool GenerateDebugInfo => !this.IsAsync;
internal override bool IsExpressionBodied => false;
internal MethodSymbol TopLevelMethod => _topLevelMethod;
......
......@@ -81,7 +81,7 @@ public sealed override ImmutableArray<TypeParameterSymbol> TypeParameters
internal override int ParameterCount
{
get { return this.BaseMethod.ParameterCount; }
get { return this.Parameters.Length; }
}
public sealed override ImmutableArray<ParameterSymbol> Parameters
......@@ -96,6 +96,11 @@ public sealed override ImmutableArray<ParameterSymbol> Parameters
}
}
protected virtual ImmutableArray<TypeSymbol> ExtraSynthesizedRefParameters
{
get { return default(ImmutableArray<TypeSymbol>); }
}
protected virtual ImmutableArray<ParameterSymbol> BaseMethodParameters
{
get { return this.BaseMethod.Parameters; }
......@@ -110,6 +115,14 @@ private ImmutableArray<ParameterSymbol> MakeParameters()
{
builder.Add(new SynthesizedParameterSymbol(this, this.TypeMap.SubstituteType(p.OriginalDefinition.Type), ordinal++, p.RefKind, p.Name));
}
var extraSynthed = ExtraSynthesizedRefParameters;
if (!extraSynthed.IsDefaultOrEmpty)
{
foreach (var extra in extraSynthed)
{
builder.Add(new SynthesizedParameterSymbol(this, this.TypeMap.SubstituteType(extra), ordinal++, RefKind.Ref));
}
}
return builder.ToImmutableAndFree();
}
......
......@@ -949,7 +949,30 @@ void Foo()
";
var verify = VerifyOutputInMain(source, "2", "System");
var foo = verify.FindLocalFunction("Foo");
Assert.True(foo.ContainingType.IsValueType);
var program = verify.Compilation.GetTypeByMetadataName("Program");
Assert.Equal(program, foo.ContainingType);
Assert.True(foo.IsStatic);
Assert.Equal(RefKind.Ref, foo.Parameters[0].RefKind);
Assert.True(foo.Parameters[0].Type.IsValueType);
}
[Fact]
public void StructClosureGeneric()
{
var source = @"
int x = 2;
void Foo<T1>()
{
int y = x;
void Bar<T2>()
{
Console.Write(x + y);
}
Bar<T1>();
}
Foo<int>();
";
var verify = VerifyOutputInMain(source, "4", "System");
}
[Fact]
......@@ -979,9 +1002,25 @@ void Inner()
Outer();
";
var verify = VerifyOutputInMain(source, "2", "System");
Assert.True(verify.FindLocalFunction("Inner").ContainingType.IsValueType);
Assert.True(verify.FindLocalFunction("Middle").ContainingType.IsReferenceType);
Assert.True(verify.FindLocalFunction("Outer").ContainingType.IsReferenceType);
var inner = verify.FindLocalFunction("Inner");
var middle = verify.FindLocalFunction("Middle");
var outer = verify.FindLocalFunction("Outer");
var program = verify.Compilation.GetTypeByMetadataName("Program");
Assert.Equal(program, inner.ContainingType);
Assert.Equal(program, middle.ContainingType);
Assert.Equal(program, outer.ContainingType);
Assert.True(inner.IsStatic);
Assert.True(middle.IsStatic);
Assert.True(outer.IsStatic);
Assert.Equal(2, inner.Parameters.Length);
Assert.Equal(1, middle.Parameters.Length);
Assert.Equal(0, outer.Parameters.Length);
Assert.Equal(RefKind.Ref, inner.Parameters[0].RefKind);
Assert.Equal(RefKind.Ref, inner.Parameters[1].RefKind);
Assert.Equal(RefKind.Ref, middle.Parameters[0].RefKind);
Assert.True(inner.Parameters[0].Type.IsValueType);
Assert.True(inner.Parameters[1].Type.IsValueType);
Assert.True(middle.Parameters[0].Type.IsValueType);
}
[Fact]
......@@ -1038,7 +1077,11 @@ void Foo()
";
var verify = VerifyOutputInMain(source, "2", "System");
var foo = verify.FindLocalFunction("Foo");
Assert.True(foo.ContainingType.IsValueType);
var program = verify.Compilation.GetTypeByMetadataName("Program");
Assert.Equal(program, foo.ContainingType);
Assert.True(foo.IsStatic);
Assert.Equal(RefKind.Ref, foo.Parameters[0].RefKind);
Assert.True(foo.Parameters[0].Type.IsValueType);
}
[Fact]
......@@ -1067,9 +1110,21 @@ void Bar(int depth2)
Foo(0);
";
var verify = VerifyOutputInMain(source, "2", "System");
// should be class (due to by-value passing). See bottom of LambdaRewriter.Analysis.ComputeLambdaScopesAndFrameCaptures
Assert.True(verify.FindLocalFunction("Foo").ContainingType.IsReferenceType);
Assert.True(verify.FindLocalFunction("Bar").ContainingType.IsValueType);
var program = verify.Compilation.GetTypeByMetadataName("Program");
var foo = verify.FindLocalFunction("Foo");
var bar = verify.FindLocalFunction("Bar");
Assert.Equal(program, foo.ContainingType);
Assert.Equal(program, bar.ContainingType);
Assert.True(foo.IsStatic);
Assert.True(bar.IsStatic);
Assert.Equal(2, foo.Parameters.Length);
Assert.Equal(3, bar.Parameters.Length);
Assert.Equal(RefKind.Ref, foo.Parameters[1].RefKind);
Assert.Equal(RefKind.Ref, bar.Parameters[1].RefKind);
Assert.Equal(RefKind.Ref, bar.Parameters[2].RefKind);
Assert.True(foo.Parameters[1].Type.IsValueType);
Assert.True(bar.Parameters[2].Type.IsValueType);
Assert.True(bar.Parameters[2].Type.IsValueType);
}
[Fact]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册