提交 65d81321 编写于 作者: A Andy Gocke 提交者: GitHub

Implement proper 'this' capturing for local functions (#14736)

Currently, when a local function is captured inside another local
function or lambda it can capture 'this' without generating a frame.
This is useful, but when that lambda is itself captured then the
capturing closure must also capture the frame pointer, namely 'this'.
Currently, local function frame pointers are not correctly captured when
the captured local function itself captures something from a "higher"
scope than the capturing local function.

This change solves this problem by:

1) Considering a local function's captured variables when deciding its
scope. If the local function captures variables from a higher scope,
that local function will be analyzed as belonging to the "higher" scope,
causing that local function to register for frame capturing.

2) Since the proxies for capturing frames are not available at the time
of local function reference rewriting, the proxies must be saved. There
is a new temporary bound node for this purpose,
PartiallyLoweredLocalFunctionReference, that stores the proxies and the
underlying node for later use during the rewriting phase. This node
should never make it past LocalFunctionReferenceRewriting.

When these steps are completed, local functions should act very
similarly to all other captured variables with different frames, where
the frame pointers are captured and walked in a linked list in order to
access the target with the proper receiver/frame pointer.
上级 c88a6718
......@@ -402,6 +402,7 @@
<Compile Include="Lowering\LambdaRewriter\LambdaRewriter.cs" />
<Compile Include="Lowering\LambdaRewriter\LambdaFrame.cs" />
<Compile Include="Lowering\LambdaRewriter\LambdaRewriter.LocalFunctionReferenceRewriter.cs" />
<Compile Include="Lowering\LambdaRewriter\PartiallyLoweredLocalFunctionReference.cs" />
<Compile Include="Lowering\LambdaRewriter\SynthesizedLambdaMethod.cs" />
<Compile Include="Lowering\LocalRewriter\DynamicSiteContainer.cs" />
<Compile Include="Lowering\LocalRewriter\LocalRewriter.cs" />
......
......@@ -118,6 +118,11 @@ public bool Emitting
public ArrayBuilder<MethodWithBody> SynthesizedMethods
{
get { return _synthesizedMethods; }
set
{
Debug.Assert(_synthesizedMethods == null);
_synthesizedMethods = value;
}
}
/// <summary>
......
......@@ -7,6 +7,7 @@
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;
using Microsoft.CodeAnalysis;
namespace Microsoft.CodeAnalysis.CSharp
{
......@@ -63,7 +64,7 @@ internal sealed class Analysis : BoundTreeWalkerWithStackGuardWithoutRecursionOn
/// <summary>
/// For each lambda in the code, the set of variables that it captures.
/// </summary>
public MultiDictionary<MethodSymbol, Symbol> CapturedVariablesByLambda = new MultiDictionary<MethodSymbol, Symbol>();
public OrderedMultiDictionary<MethodSymbol, Symbol> CapturedVariablesByLambda = new OrderedMultiDictionary<MethodSymbol, Symbol>();
/// <summary>
/// If a local function is in the set, at some point in the code it is converted to a delegate and should then not be optimized to a struct closure.
......@@ -219,21 +220,7 @@ private void RemoveUnneededReferences()
}
}
var capturedVariablesNew = new MultiDictionary<Symbol, SyntaxNode>();
foreach (var old in CapturedVariables)
{
var method = old.Key as MethodSymbol;
// don't add if it's a method that only captures 'this'
if (method == null || capturesVariable.Contains(method))
{
foreach (var oldValue in old.Value)
{
capturedVariablesNew.Add(old.Key, oldValue);
}
}
}
CapturedVariables = capturedVariablesNew;
var capturedVariablesByLambdaNew = new MultiDictionary<MethodSymbol, Symbol>();
var capturedVariablesByLambdaNew = new OrderedMultiDictionary<MethodSymbol, Symbol>();
foreach (var old in CapturedVariablesByLambda)
{
if (capturesVariable.Contains(old.Key))
......@@ -263,22 +250,43 @@ internal void ComputeLambdaScopesAndFrameCaptures()
foreach (var kvp in CapturedVariablesByLambda)
{
// get innermost and outermost scopes from which a lambda captures
var lambda = kvp.Key;
var capturedVars = kvp.Value;
var allCapturedVars = ArrayBuilder<Symbol>.GetInstance(capturedVars.Count);
allCapturedVars.AddRange(capturedVars);
// If any of the captured variables are local functions we'll need
// to add the captured variables of that local function to the current
// set. This has the effect of ensuring that if the local function
// captures anything "above" the current scope then parent frame
// is itself captured (so that the current lambda can call that
// local function).
foreach (var captured in capturedVars)
{
var capturedLocalFunction = captured as LocalFunctionSymbol;
if (capturedLocalFunction != null)
{
allCapturedVars.AddRange(
CapturedVariablesByLambda[capturedLocalFunction]);
}
}
// get innermost and outermost scopes from which a lambda captures
int innermostScopeDepth = -1;
BoundNode innermostScope = null;
int outermostScopeDepth = int.MaxValue;
BoundNode outermostScope = null;
foreach (var variables in kvp.Value)
foreach (var captured in allCapturedVars)
{
BoundNode curBlock = null;
int curBlockDepth;
if (!VariableScope.TryGetValue(variables, out curBlock))
if (!VariableScope.TryGetValue(captured, out curBlock))
{
// this is something that is not defined in a block, like "Me"
// this is something that is not defined in a block, like "this"
// Since it is defined outside of the method, the depth is -1
curBlockDepth = -1;
}
......@@ -300,22 +308,24 @@ internal void ComputeLambdaScopesAndFrameCaptures()
}
}
allCapturedVars.Free();
// 1) if there is innermost scope, lambda goes there as we cannot go any higher.
// 2) scopes in [innermostScope, outermostScope) chain need to have access to the parent scope.
//
// Example:
// if a lambda captures a method//s parameter and Me,
// if a lambda captures a method's parameter and `this`,
// its innermost scope depth is 0 (method locals and parameters)
// and outermost scope is -1
// Such lambda will be placed in a closure frame that corresponds to the method//s outer block
// and this frame will also lift original Me as a field when created by its parent.
// Such lambda will be placed in a closure frame that corresponds to the method's outer block
// and this frame will also lift original `this` as a field when created by its parent.
// Note that it is completely irrelevant how deeply the lexical scope of the lambda was originally nested.
if (innermostScope != null)
{
LambdaScopes.Add(kvp.Key, innermostScope);
LambdaScopes.Add(lambda, innermostScope);
// Disable struct closures on methods converted to delegates, as well as on async and iterator methods.
var markAsNoStruct = MethodsConvertedToDelegates.Contains(kvp.Key) || kvp.Key.IsAsync || kvp.Key.IsIterator;
var markAsNoStruct = MethodsConvertedToDelegates.Contains(lambda) || lambda.IsAsync || lambda.IsIterator;
if (markAsNoStruct)
{
ScopesThatCantBeStructs.Add(innermostScope);
......
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.CodeAnalysis.CSharp.Symbols;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
......@@ -27,6 +28,27 @@ public LocalFunctionReferenceRewriter(LambdaRewriter lambdaRewriter)
_lambdaRewriter = lambdaRewriter;
}
public override BoundNode Visit(BoundNode node)
{
var partiallyLowered = node as PartiallyLoweredLocalFunctionReference;
if (partiallyLowered != null)
{
var underlying = partiallyLowered.UnderlyingNode;
Debug.Assert(underlying.Kind == BoundKind.Call ||
underlying.Kind == BoundKind.DelegateCreationExpression ||
underlying.Kind == BoundKind.Conversion);
var oldProxies = _lambdaRewriter.proxies;
_lambdaRewriter.proxies = partiallyLowered.Proxies;
var result = base.Visit(underlying);
_lambdaRewriter.proxies = oldProxies;
return result;
}
return base.Visit(node);
}
public override BoundNode VisitCall(BoundCall node)
{
if (node.Method.MethodKind == MethodKind.LocalFunction)
......@@ -76,6 +98,74 @@ public override BoundNode VisitConversion(BoundConversion conversion)
}
}
/// <summary>
/// Visit all references to local functions (calls, delegate
/// conversions, delegate creations) and rewrite them to point
/// to the rewritten local function method instead of the original.
/// </summary>
public BoundStatement RewriteLocalFunctionReferences(BoundStatement loweredBody)
{
var rewriter = new LocalFunctionReferenceRewriter(this);
Debug.Assert(_currentMethod == _topLevelMethod);
// Visit the body first since the state is already set
// for the top-level method
var newBody = (BoundStatement)rewriter.Visit(loweredBody);
// Visit all the rewritten methods as well
var synthesizedMethods = _synthesizedMethods;
if (synthesizedMethods != null)
{
var newMethods = ArrayBuilder<TypeCompilationState.MethodWithBody>.GetInstance(
synthesizedMethods.Count);
foreach (var oldMethod in synthesizedMethods)
{
var synthesizedLambda = oldMethod.Method as SynthesizedLambdaMethod;
if (synthesizedLambda == null)
{
// The only methods synthesized by the rewriter should
// be lowered closures and frame constructors
Debug.Assert(oldMethod.Method.MethodKind == MethodKind.Constructor ||
oldMethod.Method.MethodKind == MethodKind.StaticConstructor);
newMethods.Add(oldMethod);
continue;
}
_currentMethod = synthesizedLambda;
var closureKind = synthesizedLambda.ClosureKind;
if (closureKind == ClosureKind.Static || closureKind == ClosureKind.Singleton)
{
// no link from a static lambda to its container
_innermostFramePointer = _currentFrameThis = null;
}
else
{
_currentFrameThis = synthesizedLambda.ThisParameter;
_innermostFramePointer = null;
_framePointers.TryGetValue(synthesizedLambda.ContainingType, out _innermostFramePointer);
}
_currentTypeParameters = synthesizedLambda.ContainingType
?.TypeParameters.Concat(synthesizedLambda.TypeParameters)
?? synthesizedLambda.TypeParameters;
_currentLambdaBodyTypeMap = synthesizedLambda.TypeMap;
var rewrittenBody = (BoundStatement)rewriter.Visit(oldMethod.Body);
var newMethod = new TypeCompilationState.MethodWithBody(
synthesizedLambda, rewrittenBody, oldMethod.ImportChainOpt);
newMethods.Add(newMethod);
}
_synthesizedMethods = newMethods;
synthesizedMethods.Free();
}
return newBody;
}
private void RemapLocalFunction(
SyntaxNode syntax,
......
......@@ -12,6 +12,8 @@
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Roslyn.Utilities;
using System.Linq;
using Microsoft.CodeAnalysis.Collections;
namespace Microsoft.CodeAnalysis.CSharp
{
......@@ -140,6 +142,12 @@ public MappedLocalFunction(SynthesizedLambdaMethod symbol, ClosureKind closureKi
// top-level block initializing those variables to null.
private ArrayBuilder<BoundStatement> _addedStatements;
/// <summary>
/// Temporary bag for methods synthesized by the rewriting. Added to
/// <see cref="TypeCompilationState.SynthesizedMethods"/> at the end of rewriting.
/// </summary>
private ArrayBuilder<TypeCompilationState.MethodWithBody> _synthesizedMethods;
private LambdaRewriter(
Analysis analysis,
NamedTypeSymbol thisType,
......@@ -257,6 +265,20 @@ protected override bool NeedsProxy(Symbol localOrParameter)
body = rewriter.RewriteLocalFunctionReferences(body);
}
// Add the completed methods to the compilation state
if (rewriter._synthesizedMethods != null)
{
if (compilationState.SynthesizedMethods == null)
{
compilationState.SynthesizedMethods = rewriter._synthesizedMethods;
}
else
{
compilationState.SynthesizedMethods.AddRange(rewriter._synthesizedMethods);
rewriter._synthesizedMethods.Free();
}
}
CheckLocalsDefined(body);
return body;
......@@ -305,38 +327,120 @@ protected override NamedTypeSymbol ContainingType
/// </summary>
private void MakeFrames(ArrayBuilder<ClosureDebugInfo> closureDebugInfo)
{
NamedTypeSymbol containingType = this.ContainingType;
var closures = _analysis.CapturedVariablesByLambda.Keys;
foreach (var kvp in _analysis.CapturedVariables)
foreach (var closure in closures)
{
var captured = kvp.Key;
var capturedVars = _analysis.CapturedVariablesByLambda[closure];
BoundNode scope;
if (!_analysis.VariableScope.TryGetValue(captured, out scope))
if (closure.MethodKind == MethodKind.LocalFunction &&
OnlyCapturesThis((LocalFunctionSymbol)closure, capturedVars))
{
continue;
}
LambdaFrame frame = GetFrameForScope(scope, closureDebugInfo);
if (captured.Kind != SymbolKind.Method)
foreach (var captured in capturedVars)
{
var hoistedField = LambdaCapturedVariable.Create(frame, captured, ref _synthesizedFieldNameIdDispenser);
proxies.Add(captured, new CapturedToFrameSymbolReplacement(hoistedField, isReusable: false));
CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(frame, hoistedField);
BoundNode scope;
if (!_analysis.VariableScope.TryGetValue(captured, out scope))
{
continue;
}
LambdaFrame frame = GetFrameForScope(scope, closureDebugInfo);
if (hoistedField.Type.IsRestrictedType())
if (captured.Kind != SymbolKind.Method && !proxies.ContainsKey(captured))
{
foreach (CSharpSyntaxNode syntax in kvp.Value)
var hoistedField = LambdaCapturedVariable.Create(frame, captured, ref _synthesizedFieldNameIdDispenser);
proxies.Add(captured, new CapturedToFrameSymbolReplacement(hoistedField, isReusable: false));
CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(frame, hoistedField);
if (hoistedField.Type.IsRestrictedType())
{
// CS4013: Instance of type '{0}' cannot be used inside an anonymous function, query expression, iterator block or async method
this.Diagnostics.Add(ErrorCode.ERR_SpecialByRefInLambda, syntax.Location, hoistedField.Type);
foreach (CSharpSyntaxNode syntax in _analysis.CapturedVariables[captured])
{
// CS4013: Instance of type '{0}' cannot be used inside an anonymous function, query expression, iterator block or async method
this.Diagnostics.Add(ErrorCode.ERR_SpecialByRefInLambda, syntax.Location, hoistedField.Type);
}
}
}
}
}
}
private SmallDictionary<LocalFunctionSymbol, bool> _onlyCapturesThisMemoTable;
/// <summary>
/// Helper for determining whether a local function transitively
/// only captures this (only captures this or other local functions
/// which only capture this).
/// </summary>
private bool OnlyCapturesThis<T>(
LocalFunctionSymbol closure,
T capturedVars,
PooledHashSet<LocalFunctionSymbol> localFuncsInProgress = null)
where T : IEnumerable<Symbol>
{
bool result = false;
if (_onlyCapturesThisMemoTable?.TryGetValue(closure, out result) == true)
{
return result;
}
result = true;
foreach (var captured in capturedVars)
{
var param = captured as ParameterSymbol;
if (param != null && param.IsThis)
{
continue;
}
var localFunc = captured as LocalFunctionSymbol;
if (localFunc != null)
{
bool freePool = false;
if (localFuncsInProgress == null)
{
localFuncsInProgress = PooledHashSet<LocalFunctionSymbol>.GetInstance();
freePool = true;
}
else if (localFuncsInProgress.Contains(localFunc))
{
continue;
}
localFuncsInProgress.Add(localFunc);
bool transitivelyTrue = OnlyCapturesThis(
localFunc,
_analysis.CapturedVariablesByLambda[localFunc],
localFuncsInProgress);
if (freePool)
{
localFuncsInProgress.Free();
localFuncsInProgress = null;
}
if (transitivelyTrue)
{
continue;
}
}
result = false;
break;
}
if (_onlyCapturesThisMemoTable == null)
{
_onlyCapturesThisMemoTable = new SmallDictionary<LocalFunctionSymbol, bool>();
}
_onlyCapturesThisMemoTable[closure] = result;
return result;
}
private LambdaFrame GetFrameForScope(BoundNode scope, ArrayBuilder<ClosureDebugInfo> closureDebugInfo)
{
LambdaFrame frame;
......@@ -361,12 +465,12 @@ private LambdaFrame GetFrameForScope(BoundNode scope, ArrayBuilder<ClosureDebugI
CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(this.ContainingType, frame);
if (frame.Constructor != null)
{
CompilationState.AddSynthesizedMethod(
frame.Constructor,
FlowAnalysisPass.AppendImplicitReturn(
MethodCompiler.BindMethodBody(frame.Constructor, CompilationState, null),
frame.Constructor));
}
AddSynthesizedMethod(
frame.Constructor,
FlowAnalysisPass.AppendImplicitReturn(
MethodCompiler.BindMethodBody(frame.Constructor, CompilationState, null),
frame.Constructor));
}
}
return frame;
......@@ -411,7 +515,7 @@ private LambdaFrame GetStaticFrame(DiagnosticBag diagnostics, IBoundLambdaOrFunc
CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(this.ContainingType, frame);
// add its ctor (note Constructor can be null if TypeKind.Struct is passed in to LambdaFrame.ctor, but Class is passed in above)
CompilationState.AddSynthesizedMethod(
AddSynthesizedMethod(
frame.Constructor,
FlowAnalysisPass.AppendImplicitReturn(
MethodCompiler.BindMethodBody(frame.Constructor, CompilationState, null),
......@@ -432,7 +536,7 @@ private LambdaFrame GetStaticFrame(DiagnosticBag diagnostics, IBoundLambdaOrFunc
F.New(frame.Constructor)),
new BoundReturnStatement(syntax, RefKind.None, null));
CompilationState.AddSynthesizedMethod(frame.StaticConstructor, body);
AddSynthesizedMethod(frame.StaticConstructor, body);
}
}
......@@ -496,7 +600,7 @@ protected override BoundExpression FramePointer(SyntaxNode syntax, NamedTypeSymb
return proxyField.Replacement(syntax, frameType => FramePointer(syntax, frameType));
}
var localFrame = framePointer as LocalSymbol;
var localFrame = (LocalSymbol)framePointer;
return new BoundLocal(syntax, localFrame, null, localFrame.Type);
}
......@@ -517,7 +621,7 @@ private static void InsertAndFreePrologue(ArrayBuilder<BoundStatement> result, A
/// <param name="frame">The frame for the translated node</param>
/// <param name="F">A function that computes the translation of the node. It receives lists of added statements and added symbols</param>
/// <returns>The translated statement, as returned from F</returns>
private T IntroduceFrame<T>(BoundNode node, LambdaFrame frame, Func<ArrayBuilder<BoundExpression>, ArrayBuilder<LocalSymbol>, T> F)
private BoundNode IntroduceFrame(BoundNode node, LambdaFrame frame, Func<ArrayBuilder<BoundExpression>, ArrayBuilder<LocalSymbol>, BoundNode> F)
{
var frameTypeParameters = ImmutableArray.Create(StaticCast<TypeSymbol>.From(_currentTypeParameters).SelectAsArray(TypeMap.TypeSymbolAsTypeWithModifiers), 0, frame.Arity);
NamedTypeSymbol frameType = frame.ConstructIfGeneric(frameTypeParameters);
......@@ -539,8 +643,8 @@ private T IntroduceFrame<T>(BoundNode node, LambdaFrame frame, Func<ArrayBuilder
}
else
{
MethodSymbol constructor = frame.Constructor.AsMember(frameType);
Debug.Assert(frameType == constructor.ContainingType);
MethodSymbol constructor = frame.Constructor.AsMember(frameType);
Debug.Assert(frameType == constructor.ContainingType);
newFrame = new BoundObjectCreationExpression(
syntax: syntax,
constructor: constructor);
......@@ -708,70 +812,6 @@ public override BoundNode VisitBaseReference(BoundBaseReference node)
: FramePointer(node.Syntax, _topLevelMethod.ContainingType); // technically, not the correct static type
}
/// <summary>
/// Visit all references to local functions (calls, delegete
/// conversions, delegate creations) and rewrite them to point
/// to the rewritten local function method instead of the original.
/// </summary>
public BoundStatement RewriteLocalFunctionReferences(BoundStatement loweredBody)
{
var rewriter = new LocalFunctionReferenceRewriter(this);
Debug.Assert(_currentMethod == _topLevelMethod);
// Visit the body first since the state is already set
// for the top-level method
var newBody = (BoundStatement)rewriter.Visit(loweredBody);
// Visit all the rewritten methods as well
var synthesizedMethods = CompilationState.SynthesizedMethods;
if (synthesizedMethods != null)
{
// Dump the existing methods for rewriting
var oldMethods = synthesizedMethods.ToImmutable();
synthesizedMethods.Clear();
foreach (var oldMethod in oldMethods)
{
var synthesizedLambda = oldMethod.Method as SynthesizedLambdaMethod;
if (synthesizedLambda == null)
{
// The only methods synthesized by the rewriter should
// be lowered closures and frame constructors
Debug.Assert(oldMethod.Method.MethodKind == MethodKind.Constructor ||
oldMethod.Method.MethodKind == MethodKind.StaticConstructor);
CompilationState.AddSynthesizedMethod(oldMethod.Method, oldMethod.Body);
continue;
}
_currentMethod = synthesizedLambda;
var closureKind = synthesizedLambda.ClosureKind;
if (closureKind == ClosureKind.Static || closureKind == ClosureKind.Singleton)
{
// no link from a static lambda to its container
_innermostFramePointer = _currentFrameThis = null;
}
else
{
_currentFrameThis = synthesizedLambda.ThisParameter;
_innermostFramePointer = null;
_framePointers.TryGetValue(synthesizedLambda.ContainingType, out _innermostFramePointer);
}
_currentTypeParameters = synthesizedLambda.ContainingType
?.TypeParameters.Concat(synthesizedLambda.TypeParameters)
?? synthesizedLambda.TypeParameters;
_currentLambdaBodyTypeMap = synthesizedLambda.TypeMap;
var rewrittenBody = (BoundStatement)rewriter.Visit(oldMethod.Body);
CompilationState.AddSynthesizedMethod(synthesizedLambda, rewrittenBody);
}
}
return newBody;
}
private void RemapLambdaOrLocalFunction(
SyntaxNode syntax,
MethodSymbol originalMethod,
......@@ -845,7 +885,7 @@ public override BoundNode VisitCall(BoundCall node)
{
var rewrittenArguments = this.VisitList(node.Arguments);
return node.Update(
var withArguments = node.Update(
node.ReceiverOpt,
node.Method,
rewrittenArguments,
......@@ -857,6 +897,8 @@ public override BoundNode VisitCall(BoundCall node)
node.ArgsToParamsOpt,
node.ResultKind,
node.Type);
return PartiallyLowerLocalFunctionReference(withArguments);
}
var visited = base.VisitCall(node);
......@@ -887,6 +929,17 @@ public override BoundNode VisitCall(BoundCall node)
return rewritten;
}
private PartiallyLoweredLocalFunctionReference PartiallyLowerLocalFunctionReference(
BoundExpression underlyingNode)
{
Debug.Assert(underlyingNode.Kind == BoundKind.Call ||
underlyingNode.Kind == BoundKind.DelegateCreationExpression ||
underlyingNode.Kind == BoundKind.Conversion);
return new PartiallyLoweredLocalFunctionReference(
underlyingNode,
new Dictionary<Symbol, CapturedSymbolReplacement>(proxies));
}
private BoundSequence RewriteSequence(BoundSequence node, ArrayBuilder<BoundExpression> prologue, ArrayBuilder<LocalSymbol> newLocals)
{
RewriteLocals(node.Locals, newLocals);
......@@ -1084,14 +1137,12 @@ public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationE
{
return RewriteLambdaConversion((BoundLambda)node.Argument);
}
else
if (node.MethodOpt?.MethodKind == MethodKind.LocalFunction)
{
if (node.MethodOpt?.MethodKind == MethodKind.LocalFunction)
{
return node;
}
return base.VisitDelegateCreationExpression(node);
return PartiallyLowerLocalFunctionReference(node);
}
return base.VisitDelegateCreationExpression(node);
}
public override BoundNode VisitConversion(BoundConversion conversion)
......@@ -1115,15 +1166,13 @@ public override BoundNode VisitConversion(BoundConversion conversion)
return result;
}
else
if (conversion.ConversionKind == ConversionKind.MethodGroup &&
conversion.SymbolOpt?.MethodKind == MethodKind.LocalFunction)
{
if (conversion.ConversionKind == ConversionKind.MethodGroup &&
conversion.SymbolOpt?.MethodKind == MethodKind.LocalFunction)
{
return conversion;
}
return base.VisitConversion(conversion);
return PartiallyLowerLocalFunctionReference(conversion);
}
return base.VisitConversion(conversion);
}
public override BoundNode VisitLocalFunctionStatement(BoundLocalFunctionStatement node)
......@@ -1353,7 +1402,7 @@ private DebugId GetLambdaId(SyntaxNode syntax, ClosureKind closureKind, int clos
var body = AddStatementsIfNeeded((BoundStatement)VisitBlock(node.Body));
CheckLocalsDefined(body);
CompilationState.AddSynthesizedMethod(synthesizedMethod, body);
AddSynthesizedMethod(synthesizedMethod, body);
// return to the old method
......@@ -1368,8 +1417,22 @@ private DebugId GetLambdaId(SyntaxNode syntax, ClosureKind closureKind, int clos
return synthesizedMethod;
}
private BoundNode RewriteLambdaConversion(BoundLambda node)
private void AddSynthesizedMethod(MethodSymbol method, BoundStatement body)
{
if (_synthesizedMethods == null)
{
_synthesizedMethods = ArrayBuilder<TypeCompilationState.MethodWithBody>.GetInstance();
}
_synthesizedMethods.Add(
new TypeCompilationState.MethodWithBody(
method,
body,
CompilationState.CurrentImportChain));
}
private BoundNode RewriteLambdaConversion(BoundLambda node)
{
var wasInExpressionLambda = _inExpressionLambda;
_inExpressionLambda = _inExpressionLambda || node.Type.IsExpressionTree();
......
using System;
using System.Collections.Generic;
using Microsoft.CodeAnalysis.Semantics;
namespace Microsoft.CodeAnalysis.CSharp
{
/// <summary>
/// This represents a partially lowered local function reference (e.g.,
/// a local function call or delegate conversion) with relevant proxies
/// attached. It will later be rewritten by the
/// <see cref="LambdaRewriter.LocalFunctionReferenceRewriter"/> into a
/// proper call.
/// </summary>
internal class PartiallyLoweredLocalFunctionReference : BoundExpression
{
private const BoundKind s_privateKind = (BoundKind)byte.MaxValue;
public BoundExpression UnderlyingNode { get; }
public Dictionary<Symbol, CapturedSymbolReplacement> Proxies { get; }
public PartiallyLoweredLocalFunctionReference(
BoundExpression underlying,
Dictionary<Symbol, CapturedSymbolReplacement> proxies)
: base(s_privateKind, underlying.Syntax, underlying.Type)
{
UnderlyingNode = underlying;
Proxies = proxies;
}
public override BoundNode Accept(BoundTreeVisitor visitor) =>
visitor.Visit(this);
protected override OperationKind ExpressionKind
{
get
{
throw new InvalidOperationException();
}
}
public override void Accept(OperationVisitor visitor)
{
throw new InvalidOperationException();
}
public override TResult Accept<TArgument, TResult>(OperationVisitor<TArgument, TResult> visitor, TArgument argument)
{
throw new InvalidOperationException();
}
}
}
......@@ -17,7 +17,10 @@ internal abstract partial class MethodToClassRewriter : BoundTreeRewriterWithSta
// For each captured variable, information about its replacement. May be populated lazily (that is, not all
// upfront) by subclasses. Specifically, the async rewriter produces captured symbols for temps, including
// ref locals, lazily.
protected readonly Dictionary<Symbol, CapturedSymbolReplacement> proxies = new Dictionary<Symbol, CapturedSymbolReplacement>();
// The lambda rewriter also saves/restores the proxies across passes, since local function
// reference rewriting is done in a separate pass but still requires the frame proxies
// created in the first pass.
protected Dictionary<Symbol, CapturedSymbolReplacement> proxies = new Dictionary<Symbol, CapturedSymbolReplacement>();
// A mapping from every local variable to its replacement local variable. Local variables are replaced when
// their types change due to being inside of a generic method. Otherwise we reuse the original local (even
......
......@@ -29,6 +29,323 @@ public static IMethodSymbol FindLocalFunction(this CommonTestBase.CompilationVer
[CompilerTrait(CompilerFeature.LocalFunctions)]
public class CodeGenLocalFunctionTests : CSharpTestBase
{
[Fact]
public void DeepNestedLocalFuncsWithDifferentCaptures()
{
var src = @"
using System;
class C
{
int P = 100000;
void M()
{
C Local1() => this;
int capture1 = 1;
Func<int> f1 = () => capture1 + Local1().P;
Console.WriteLine(f1());
{
C Local2() => Local1();
int capture2 = 10;
Func<int> f2 = () => capture2 + Local2().P;
Console.WriteLine(f2());
{
C Local3() => Local2();
int capture3 = 100;
Func<int> f3 = () => capture1 + capture2 + capture3 + Local3().P;
Console.WriteLine(f3());
Console.WriteLine(Local3().P);
}
}
}
public static void Main() => new C().M();
}";
VerifyOutput(src, @"100001
100010
100111
100000");
}
[Fact]
public void LotsOfMutuallyRecursiveLocalFunctions()
{
var src = @"
class C
{
int P = 0;
public void M()
{
int Local1() => this.P;
int Local2() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1();
int Local3() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1();
int Local4() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1();
int Local5() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1();
int Local6() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1();
int Local7() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1();
int Local8() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1();
int Local9() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1();
int Local10() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1();
int Local11() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1();
int Local12() => Local12() + Local11() + Local10() + Local9() + Local8() + Local7() + Local6() + Local5() + Local4() + Local3() + Local2() + Local1();
Local1();
Local2();
Local3();
Local4();
Local5();
Local6();
Local7();
Local8();
Local9();
Local10();
Local11();
Local12();
}
}
";
var comp = CreateCompilationWithMscorlib(src);
comp.VerifyEmitDiagnostics();
}
[Fact]
public void LocalFuncAndLambdaWithDifferentThis()
{
var src = @"
using System;
class C
{
private int P = 1;
public void M()
{
int Local(int x) => x + this.P;
int y = 10;
var a = new Func<int>(() => Local(y));
Console.WriteLine(a());
}
public static void Main(string[] args)
{
var c = new C();
c.M();
}
}";
VerifyOutput(src, "11");
}
[Fact]
public void LocalFuncAndLambdaWithDifferentThis2()
{
var src = @"
using System;
class C
{
private int P = 1;
public void M()
{
int Local() => 10 + this.P;
int Local2(int x) => x + Local();
int y = 100;
var a = new Func<int>(() => Local2(y));
Console.WriteLine(a());
}
public static void Main(string[] args)
{
var c = new C();
c.M();
}
}";
VerifyOutput(src, "111");
}
[Fact]
public void LocalFuncAndLambdaWithDifferentThis3()
{
var src = @"
using System;
class C
{
private int P = 1;
public void M()
{
int Local()
{
if (this.P < 5)
{
return Local2(this.P++);
}
else
{
return 1;
}
}
int Local2(int x) => x + Local();
int y = 100;
var a = new Func<int>(() => Local2(y));
Console.WriteLine(a());
}
public static void Main(string[] args)
{
var c = new C();
c.M();
}
}";
VerifyOutput(src, "111");
}
[Fact]
public void LocalFuncAndLambdaWithDifferentThis4()
{
var src = @"
using System;
class C
{
private int P = 1;
public void M()
{
int Local(int x) => x + this.P;
int y = 10;
var a = new Func<int>(() =>
{
var b = (Func<int, int>)Local;
return b(y);
});
Console.WriteLine(a());
}
public static void Main(string[] args)
{
var c = new C();
c.M();
}
}";
VerifyOutput(src, "11");
}
[Fact]
public void LocalFuncAndLambdaWithDifferentThis5()
{
var src = @"
using System;
class C
{
private int P = 1;
public void M()
{
int Local(int x) => x + this.P;
int y = 10;
var a = new Func<int>(() =>
{
var b = new Func<int, int>(Local);
return b(y);
});
Console.WriteLine(a());
}
public static void Main(string[] args)
{
var c = new C();
c.M();
}
}";
VerifyOutput(src, "11");
}
[Fact]
public void TwoFrames()
{
var src = @"
using System;
class C
{
private int P = 0;
public void M()
{
int x = 0;
var a = new Func<int>(() =>
{
int Local() => x + this.P;
int z = 0;
int Local3() => z + Local();
return Local3();
});
Console.WriteLine(a());
}
public static void Main(string[] args)
{
var c = new C();
c.M();
}
}";
VerifyOutput(src, "0");
}
[Fact]
public void SameFrame()
{
var src = @"
using System;
class C
{
private int P = 1;
public void M()
{
int x = 10;
int Local() => x + this.P;
int y = 100;
int Local2() => y + Local();
Console.WriteLine(Local2());
}
public static void Main(string[] args)
{
var c = new C();
c.M();
}
}";
VerifyOutput(src, "111");
}
[Fact]
public void MutuallyRecursiveThisCapture()
{
var src = @"
using System;
class C
{
private int P = 1;
public void M()
{
int Local()
{
if (this.P < 5)
{
return Local2(this.P++);
}
else
{
return 1;
}
}
int Local2(int x) => x + Local();
Console.WriteLine(Local());
}
public static void Main() => new C().M();
}";
VerifyOutput(src, "11");
}
[Fact]
[CompilerTrait(CompilerFeature.Dynamic)]
public void DynamicParameterLocalFunction()
......
......@@ -714,10 +714,10 @@ public void F()
// no new synthesized members generated (with #1 in names):
diff1.VerifySynthesizedMembers(
"C.<>c__DisplayClass1_0: {a, <F>b__1}",
"C.<>c__DisplayClass1_1: {b, <F>b__3}",
"C.<>c__DisplayClass1_2: {a, b, <F>b__5}",
"C: {<F>b__1_0, <F>b__1_2, <F>b__1_4, <>c__DisplayClass1_0, <>c__DisplayClass1_1, <>c__DisplayClass1_2}");
"C.<>c__DisplayClass1_2: {a, b, <F>b__5}",
"C.<>c__DisplayClass1_1: {b, <F>b__3}",
"C: {<F>b__1_0, <F>b__1_2, <F>b__1_4, <>c__DisplayClass1_0, <>c__DisplayClass1_1, <>c__DisplayClass1_2}",
"C.<>c__DisplayClass1_0: {a, <F>b__1}");
var md1 = diff1.GetMetadata();
var reader1 = md1.Reader;
......
......@@ -507,10 +507,9 @@ static void F()
var emitContext = new EmitContext(peAssemblyBuilder, null, new DiagnosticBag());
var fields = displayClass.GetFields(emitContext).ToArray();
var x1 = fields[0];
var x2 = fields[1];
Assert.Equal("x1", x1.Name);
Assert.Equal("x2", x2.Name);
AssertEx.SetEqual(fields.Select(f => f.Name), new[] { "x1", "x2" });
var x1 = fields.Where(f => f.Name == "x1").Single();
var x2 = fields.Where(f => f.Name == "x2").Single();
var matcher = new CSharpSymbolMatcher(anonymousTypeMap0, compilation1.SourceAssembly, emitContext, peAssemblySymbol0);
......
......@@ -46,6 +46,7 @@
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</Content>
<Compile Include="InternalUtilities\OrderedMultiDictionary.cs" />
<Compile Include="Syntax\InternalSyntax\GreenNodeExtensions.cs" />
<Compile Include="Syntax\InternalSyntax\SyntaxListPool.cs" />
<Compile Include="Syntax\SyntaxList.SeparatedWithManyWeakChildren.cs" />
......
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
namespace Roslyn.Utilities
{
// Note that this is not threadsafe for concurrent reading and writing.
internal sealed class OrderedMultiDictionary<K, V> : IEnumerable<KeyValuePair<K, SetWithInsertionOrder<V>>>
{
private readonly Dictionary<K, SetWithInsertionOrder<V>> _dictionary;
private readonly List<K> _keys;
public int Count => _dictionary.Count;
public IEnumerable<K> Keys => _keys;
// Returns an empty set if there is no such key in the dictionary.
public SetWithInsertionOrder<V> this[K k]
{
get
{
SetWithInsertionOrder<V> set;
return _dictionary.TryGetValue(k, out set)
? set : new SetWithInsertionOrder<V>();
}
}
public OrderedMultiDictionary()
{
_dictionary = new Dictionary<K, SetWithInsertionOrder<V>>();
_keys = new List<K>();
}
public void Add(K k, V v)
{
SetWithInsertionOrder<V> set;
if (!_dictionary.TryGetValue(k, out set))
{
_keys.Add(k);
set = new SetWithInsertionOrder<V>();
}
set.Add(v);
_dictionary[k] = set;
}
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
public IEnumerator<KeyValuePair<K, SetWithInsertionOrder<V>>> GetEnumerator()
{
foreach (var key in _keys)
{
yield return new KeyValuePair<K, SetWithInsertionOrder<V>>(
key, _dictionary[key]);
}
}
}
}
......@@ -14,7 +14,7 @@ namespace Roslyn.Utilities
/// A set that returns the inserted values in insertion order.
/// The mutation operations are not thread-safe.
/// </summary>
internal sealed class SetWithInsertionOrder<T> : IEnumerable<T>
internal sealed class SetWithInsertionOrder<T> : IEnumerable<T>, IReadOnlySet<T>
{
private HashSet<T> _set = new HashSet<T>();
private uint _nextElementValue = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册