提交 fea05ade 编写于 作者: E Evan Hauck

Support recursive closures

上级 b0de720f
......@@ -420,13 +420,18 @@ private BoundStatement BindLocalFunctionStatement(LocalFunctionStatementSyntax n
// This occurs through the semantic model. In that case concoct a plausible result.
if (localSymbol == null)
{
localSymbol = new LocalFunctionMethodSymbol(this, this.ContainingType, node, node.Identifier.GetLocation());
localSymbol = new LocalFunctionMethodSymbol(this, this.ContainingType, this.ContainingMemberOrLambda, node, node.Identifier.GetLocation());
}
else
{
hasErrors |= this.ValidateDeclarationNameConflictsInScope(localSymbol, diagnostics);
}
if (localSymbol.IsGenericMethod)
{
diagnostics.Add(ErrorCode.ERR_InvalidMemberDecl, node.TypeParameterListOpt.Location, node.TypeParameterListOpt);
}
var binder = this.GetBinder(node);
// Binder could be null in error scenarios (as above)
......
......@@ -889,6 +889,19 @@ private NamedTypeSymbol ConstructNamedTypeUnlessTypeArgumentOmitted(CSharpSyntax
Debug.Assert(members.Count > 0);
if (!hasErrors && members[0] is LocalFunctionMethodSymbol)
{
Debug.Assert(members.Count == 1 && members[0].Locations.Length == 1);
var localSymbolLocation = members[0].Locations[0];
bool usedBeforeDecl =
syntax.SyntaxTree == localSymbolLocation.SourceTree &&
syntax.SpanStart < localSymbolLocation.SourceSpan.Start;
if (usedBeforeDecl)
{
Error(diagnostics, ErrorCode.ERR_VariableUsedBeforeDeclaration, syntax, syntax);
}
}
switch (members[0].Kind)
{
case SymbolKind.Method:
......
......@@ -230,6 +230,7 @@ protected LocalFunctionMethodSymbol MakeLocalFunction(LocalFunctionStatementSynt
return new LocalFunctionMethodSymbol(
this,
this.ContainingType,
this.ContainingMemberOrLambda,
declaration,
declaration.Identifier.GetLocation());
}
......
......@@ -328,7 +328,7 @@ public override BoundNode VisitCall(BoundCall node)
var localFunction = node.Method as LocalFunctionMethodSymbol;
if (localFunction != null)
{
//ReferenceVariable(node.Syntax, localFunction);
ReferenceVariable(node.Syntax, localFunction.OriginalDefinition);
}
return base.VisitCall(node);
}
......@@ -338,9 +338,18 @@ public override BoundNode VisitLambda(BoundLambda node)
return VisitLambdaOrFunction(node);
}
public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationExpression node)
{
if (node.MethodOpt.MethodKind == MethodKind.LocalFunction)
{
ReferenceVariable(node.Syntax, node.MethodOpt);
}
return base.VisitDelegateCreationExpression(node);
}
public override BoundNode VisitLocalFunctionStatement(BoundLocalFunctionStatement node)
{
//variableScope[node.Symbol] = _currentScope;
variableScope[node.Symbol] = _currentScope;
return VisitLambdaOrFunction(node);
}
......@@ -426,6 +435,10 @@ public override BoundNode VisitConversion(BoundConversion node)
{
if (node.ConversionKind == ConversionKind.MethodGroup)
{
if (node.SymbolOpt?.MethodKind == MethodKind.LocalFunction)
{
ReferenceVariable(node.Syntax, node.SymbolOpt);
}
if (node.IsExtensionMethod || ((object)node.SymbolOpt != null && !node.SymbolOpt.IsStatic))
{
return VisitSyntaxWithReceiver(node.Syntax, ((BoundMethodGroup)node.Operand).ReceiverOpt);
......
......@@ -67,19 +67,9 @@ internal sealed partial class LambdaRewriter : MethodToClassRewriter
// A mapping from every lambda parameter to its corresponding method's parameter.
private readonly Dictionary<ParameterSymbol, ParameterSymbol> _parameterMap = new Dictionary<ParameterSymbol, ParameterSymbol>();
private struct LocalFunctionMapValue
{
public readonly BoundExpression Receiver;
public readonly MethodSymbol Symbol;
public LocalFunctionMapValue(BoundExpression receiver, MethodSymbol symbol)
{
Receiver = receiver;
Symbol = symbol;
}
}
// A mapping from every local function to its lowered method
private readonly Dictionary<LocalFunctionMethodSymbol, LocalFunctionMapValue> _localFunctionMap = new Dictionary<LocalFunctionMethodSymbol, LocalFunctionMapValue>();
private readonly Dictionary<LocalFunctionMethodSymbol, MethodSymbol> _localFunctionMap = new Dictionary<LocalFunctionMethodSymbol, MethodSymbol>();
// for each block with lifted (captured) variables, the corresponding frame type
private readonly Dictionary<BoundNode, LambdaFrame> _frames = new Dictionary<BoundNode, LambdaFrame>();
......@@ -166,7 +156,8 @@ public LocalFunctionMapValue(BoundExpression receiver, MethodSymbol symbol)
protected override bool NeedsProxy(Symbol localOrParameter)
{
Debug.Assert(localOrParameter is LocalSymbol || localOrParameter is ParameterSymbol);
Debug.Assert(localOrParameter is LocalSymbol || localOrParameter is ParameterSymbol ||
(localOrParameter as MethodSymbol)?.MethodKind == MethodKind.LocalFunction);
return _analysis.capturedVariables.ContainsKey(localOrParameter);
}
......@@ -292,17 +283,20 @@ private void MakeFrames(ArrayBuilder<ClosureDebugInfo> closureDebugInfo)
}
LambdaFrame frame = GetFrameForScope(scope, closureDebugInfo);
var hoistedField = LambdaCapturedVariable.Create(frame, captured, ref _synthesizedFieldNameIdDispenser);
proxies.Add(captured, new CapturedToFrameSymbolReplacement(hoistedField, isReusable: false));
CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(frame, hoistedField);
if (hoistedField.Type.IsRestrictedType())
if (captured.Kind != SymbolKind.Method)
{
foreach (CSharpSyntaxNode syntax in kvp.Value)
var hoistedField = LambdaCapturedVariable.Create(frame, captured, ref _synthesizedFieldNameIdDispenser);
proxies.Add(captured, new CapturedToFrameSymbolReplacement(hoistedField, isReusable: false));
CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(frame, hoistedField);
if (hoistedField.Type.IsRestrictedType())
{
// CS4013: Instance of type '{0}' cannot be used inside an anonymous function, query expression, iterator block or async method
this.Diagnostics.Add(ErrorCode.ERR_SpecialByRefInLambda, syntax.Location, hoistedField.Type);
foreach (CSharpSyntaxNode syntax in kvp.Value)
{
// CS4013: Instance of type '{0}' cannot be used inside an anonymous function, query expression, iterator block or async method
this.Diagnostics.Add(ErrorCode.ERR_SpecialByRefInLambda, syntax.Location, hoistedField.Type);
}
}
}
}
......@@ -354,7 +348,7 @@ private LambdaFrame GetStaticFrame(DiagnosticBag diagnostics, BoundNode lambda)
methodId = GetTopLevelMethodId();
}
DebugId closureId = default(DebugId);
DebugId closureId = default(DebugId);
_lazyStaticLambdaFrame = new LambdaFrame(_topLevelMethod, scopeSyntaxOpt: null, methodId: methodId, closureId: closureId);
// nongeneric static lambdas can share the frame
......@@ -592,17 +586,14 @@ private void InitVariableProxy(CSharpSyntaxNode syntax, Symbol symbol, LocalSymb
value = new BoundLocal(syntax, localToUse, null, localToUse.Type);
break;
default:
throw ExceptionUtilities.UnexpectedValue(symbol.Kind);
}
if (value != null)
{
var left = proxy.Replacement(syntax, frameType1 => new BoundLocal(syntax, framePointer, null, framePointer.Type));
var assignToProxy = new BoundAssignmentOperator(syntax, left, value, value.Type);
prologue.Add(assignToProxy);
}
var left = proxy.Replacement(syntax, frameType1 => new BoundLocal(syntax, framePointer, null, framePointer.Type));
var assignToProxy = new BoundAssignmentOperator(syntax, left, value, value.Type);
prologue.Add(assignToProxy);
}
}
......@@ -642,31 +633,57 @@ public override BoundNode VisitBaseReference(BoundBaseReference node)
? node
: FramePointer(node.Syntax, _topLevelMethod.ContainingType); // technically, not the correct static type
}
private BoundCall RewriteLocalFunctionCall(BoundCall node)
{
var localFunction = node.Method as LocalFunctionMethodSymbol;
if (localFunction == null)
return node;
var mapped = _localFunctionMap[localFunction];
var receiver = mapped.Receiver;
var method = mapped.Symbol;
private void RemapLocalFunction(CSharpSyntaxNode syntax, MethodSymbol symbol, out BoundExpression receiver, out MethodSymbol method)
{
Debug.Assert(symbol.MethodKind == MethodKind.LocalFunction);
if (receiver == null && _currentMethod == method)
var constructed = symbol as ConstructedMethodSymbol;
if (constructed != null)
{
receiver = new BoundParameter(node.Syntax, method.ThisParameter);
RemapLocalFunction(syntax, constructed.ConstructedFrom, out receiver, out method);
//method = method.Construct(constructed.TypeArguments);
return;
}
// fzoo
Debug.Assert(receiver != null, "Mutually recursive local functions with closures are not yet supported");
method = _localFunctionMap[(LocalFunctionMethodSymbol)symbol];
var translatedLambdaContainer = method.ContainingType;
var containerAsFrame = translatedLambdaContainer as LambdaFrame;
// Rewrite the lambda expression (and the enclosing anonymous method conversion) as a delegate creation expression
NamedTypeSymbol constructedFrame = (object)containerAsFrame != null ?
translatedLambdaContainer.ConstructIfGeneric(StaticCast<TypeSymbol>.From(_currentTypeParameters)) :
translatedLambdaContainer;
// for instance lambdas, receiver is the frame
// for static lambdas, get the singleton receiver
if (containerAsFrame?.SingletonCache == null)
{
receiver = FrameOfType(syntax, constructedFrame);
}
else
{
var field = containerAsFrame.SingletonCache.AsMember(constructedFrame);
receiver = new BoundFieldAccess(syntax, null, field, constantValueOpt: null);
}
return node.Update(receiver, method, node.Arguments);
method = method.AsMember(constructedFrame);
if (method.IsGenericMethod)
{
method = method.Construct(StaticCast<TypeSymbol>.From(_currentTypeParameters));
}
}
public override BoundNode VisitCall(BoundCall node)
{
var visited = base.VisitCall(RewriteLocalFunctionCall(node));
if (node.Method.MethodKind == MethodKind.LocalFunction)
{
BoundExpression receiver;
MethodSymbol method;
RemapLocalFunction(node.Syntax, node.Method, out receiver, out method);
node = node.Update(receiver, method, node.Arguments);
}
var visited = base.VisitCall(node);
if (visited.Kind != BoundKind.Call)
{
return visited;
......@@ -914,6 +931,14 @@ public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationE
}
else
{
if (node.MethodOpt?.MethodKind == MethodKind.LocalFunction)
{
BoundExpression receiver;
MethodSymbol method;
RemapLocalFunction(node.Syntax, node.MethodOpt, out receiver, out method);
var result = new BoundDelegateCreationExpression(node.Syntax, receiver, method, isExtensionMethod: false, type: node.Type);
return result;
}
return base.VisitDelegateCreationExpression(node);
}
}
......@@ -941,6 +966,14 @@ public override BoundNode VisitConversion(BoundConversion conversion)
}
else
{
if (conversion.SymbolOpt?.MethodKind == MethodKind.LocalFunction)
{
BoundExpression receiver;
MethodSymbol method;
RemapLocalFunction(conversion.Syntax, conversion.SymbolOpt, out receiver, out method);
var result = new BoundDelegateCreationExpression(conversion.Syntax, receiver, method, isExtensionMethod: false, type: conversion.Type);
return result;
}
return base.VisitConversion(conversion);
}
}
......@@ -976,7 +1009,7 @@ public override BoundNode VisitLocalFunctionStatement(BoundLocalFunctionStatemen
// Move the body of the lambda to a freshly generated synthetic method on its frame.
DebugId topLevelMethodId = GetTopLevelMethodId();
DebugId lambdaId = GetLambdaId(node.Syntax, closureKind, closureOrdinal);
var synthesizedMethod = new SynthesizedLambdaMethod(translatedLambdaContainer, closureKind, _topLevelMethod, topLevelMethodId, node, lambdaId);
CompilationState.ModuleBuilderOpt.AddSynthesizedDefinition(translatedLambdaContainer, synthesizedMethod);
......@@ -984,33 +1017,8 @@ public override BoundNode VisitLocalFunctionStatement(BoundLocalFunctionStatemen
{
_parameterMap.Add(parameter, synthesizedMethod.Parameters[parameter.Ordinal]);
}
// Rewrite the lambda expression (and the enclosing anonymous method conversion) as a delegate creation expression
NamedTypeSymbol constructedFrame = (object)containerAsFrame != null ?
translatedLambdaContainer.ConstructIfGeneric(StaticCast<TypeSymbol>.From(_currentTypeParameters)) :
translatedLambdaContainer;
// for instance lambdas, receiver is the frame
// for static lambdas, get the singleton receiver
BoundExpression receiver;
if (closureKind != ClosureKind.Static)
{
receiver = FrameOfType(node.Syntax, constructedFrame);
}
else
{
var field = containerAsFrame.SingletonCache.AsMember(constructedFrame);
receiver = new BoundFieldAccess(node.Syntax, null, field, constantValueOpt: null);
}
MethodSymbol referencedMethod = synthesizedMethod.AsMember(constructedFrame);
if (referencedMethod.IsGenericMethod)
{
referencedMethod = referencedMethod.Construct(StaticCast<TypeSymbol>.From(_currentTypeParameters));
}
// null means "use closed value"
_localFunctionMap.Add(node.Symbol, new LocalFunctionMapValue(null, referencedMethod));
_localFunctionMap[node.Symbol] = synthesizedMethod;
// rewrite the lambda body as the generated method's body
var oldMethod = _currentMethod;
......@@ -1063,8 +1071,6 @@ public override BoundNode VisitLocalFunctionStatement(BoundLocalFunctionStatemen
_addedLocals = oldAddedLocals;
_addedStatements = oldAddedStatements;
_localFunctionMap[node.Symbol] = new LocalFunctionMapValue(receiver, _localFunctionMap[node.Symbol].Symbol);
return new BoundNoOpStatement(node.Syntax, NoOpStatementFlavor.Default);
}
......
......@@ -2,7 +2,6 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.CSharp.Symbols
{
......@@ -10,6 +9,7 @@ internal class LocalFunctionMethodSymbol : SourceMethodSymbol
{
private readonly Binder _binder;
private readonly LocalFunctionStatementSyntax _syntax;
private readonly Symbol _containingSymbol;
private ImmutableArray<ParameterSymbol> _parameters;
private ImmutableArray<TypeParameterSymbol> _typeParameters;
private TypeSymbol _returnType;
......@@ -18,6 +18,7 @@ internal class LocalFunctionMethodSymbol : SourceMethodSymbol
public LocalFunctionMethodSymbol(
Binder binder,
NamedTypeSymbol containingType,
Symbol containingSymbol,
LocalFunctionStatementSyntax syntax,
Location location) :
base(
......@@ -28,6 +29,7 @@ internal class LocalFunctionMethodSymbol : SourceMethodSymbol
{
_binder = binder;
_syntax = syntax;
_containingSymbol = containingSymbol;
// It is an error to be an extension method, but we need to compute it to report it
var firstParam = syntax.ParameterList.Parameters.FirstOrDefault();
......@@ -37,11 +39,19 @@ internal class LocalFunctionMethodSymbol : SourceMethodSymbol
this.MakeFlags(
MethodKind.LocalFunction,
DeclarationModifiers.Static | syntax.Modifiers.ToDeclarationModifiers(), // TODO: Will change when we allow local captures (also change in LocalFunctionRewriter)
(_containingSymbol.IsStatic ? DeclarationModifiers.Static : 0) | syntax.Modifiers.ToDeclarationModifiers(),
returnsVoid: false, // will be fixed in MethodChecks
isExtensionMethod: isExtensionMethod);
}
public sealed override Symbol ContainingSymbol
{
get
{
return _containingSymbol;
}
}
public override string Name
{
get
......@@ -96,7 +106,7 @@ public override ImmutableArray<TypeParameterSymbol> TypeParameters
return _typeParameters;
}
}
internal override bool GenerateDebugInfo
{
get
......@@ -145,7 +155,7 @@ private void MethodChecks(DiagnosticBag diagnostics, Binder parameterBinder)
diagnostics.Add(ErrorCode.ERR_BadExtensionAgg, Locations[0]);
}
}
private ImmutableArray<TypeParameterSymbol> MakeTypeParameters(DiagnosticBag diagnostics)
{
var result = ArrayBuilder<TypeParameterSymbol>.GetInstance();
......@@ -156,7 +166,7 @@ private ImmutableArray<TypeParameterSymbol> MakeTypeParameters(DiagnosticBag dia
var identifier = parameter.Identifier;
var location = identifier.GetLocation();
var name = identifier.ValueText;
// TODO: Add diagnostic checks for nested local functions (and containing method)
if (name == this.Name)
{
......@@ -178,7 +188,7 @@ private ImmutableArray<TypeParameterSymbol> MakeTypeParameters(DiagnosticBag dia
// Type parameter '{0}' has the same name as the type parameter from outer type '{1}'
diagnostics.Add(ErrorCode.WRN_TypeParameterSameAsOuterTypeParameter, location, name, tpEnclosing.ContainingType);
}
var typeParameter = new LocalFunctionTypeParameterSymbol(
this,
name,
......@@ -206,7 +216,7 @@ public override int GetHashCode()
public sealed override bool Equals(object symbol)
{
if ((object)this == symbol) return true;
var localFunction = symbol as LocalFunctionMethodSymbol;
return (object)localFunction != null
&& localFunction._syntax == _syntax
......
......@@ -334,7 +334,7 @@ protected virtual void LazyAsyncMethodChecks(CancellationToken cancellationToken
state.NotePartComplete(CompletionPart.FinishAsyncMethodChecks);
}
public sealed override Symbol ContainingSymbol
public override Symbol ContainingSymbol
{
get
{
......
......@@ -50,6 +50,76 @@ static void Main(string[] args)
");
}
[Fact]
public void StandardMethodFeatures()
{
var source = @"
using System;
class Program
{
static void Main(string[] args)
{
void Params(params int[] x)
{
Console.WriteLine(string.Join("","", x));
}
void RefOut(ref int x, out int y)
{
y = ++x;
}
void NamedOptional(int x = 2)
{
Console.WriteLine(x);
}
Params(2);
int a = 1;
int b;
RefOut(ref a, out b);
Console.WriteLine(a);
Console.WriteLine(b);
NamedOptional(x: 2);
NamedOptional();
}
}
";
var comp = CompileAndVerify(source, expectedOutput: @"
2
2
2
2
2
");
}
[Fact]
public void Delegate()
{
var source = @"
using System;
class Program
{
static void Main(string[] args)
{
int Local(int x) => x;
Func<int, int> local = Local;
Console.WriteLine(local(2));
void Local2()
{
Console.WriteLine(2);
}
var local2 = new Action(Local2);
local2();
}
}
";
var comp = CompileAndVerify(source, expectedOutput: @"
2
2
");
}
[Fact]
public void Closure()
{
......@@ -81,6 +151,45 @@ static void Main(string[] args)
");
}
[Fact]
public void InstanceClosure()
{
var source = @"
using System;
class Program
{
int x;
void A(int y)
{
void Local()
{
A(x + y);
}
if (y != 0)
{
Console.WriteLine(y);
}
else
{
Local();
}
}
static void Main(string[] args)
{
var prog = new Program();
prog.x = 2;
prog.A(0);
}
}
";
var comp = CompileAndVerify(source, expectedOutput: @"
2
");
}
[Fact]
public void SelfClosure()
{
......@@ -148,7 +257,6 @@ public void MutualRecursion()
{
var source = @"
using System;
using System.Collections.Generic;
class Program
{
......@@ -257,7 +365,6 @@ public void Generic()
{
var source = @"
using System;
using System.Collections.Generic;
class Program
{
......@@ -271,9 +378,13 @@ T Local<T>(T val)
}
}
";
var comp = CompileAndVerify(source, expectedOutput: @"
2
");
// TODO: Eventually support this
var option = TestOptions.ReleaseExe.WithWarningLevel(0);
CreateCompilationWithMscorlibAndSystemCore(source, options: option).VerifyDiagnostics(
// (8,16): error CS1519: Invalid token '<T>' in class, struct, or interface member declaration
// T Local<T>(T val)
Diagnostic(ErrorCode.ERR_InvalidMemberDecl, "<T>").WithArguments("<T>").WithLocation(8, 16)
);
}
[Fact]
......@@ -304,6 +415,35 @@ static void Main(string[] args)
");
}
[Fact]
public void Shadows()
{
var source = @"
using System;
using System.Collections.Generic;
class Program
{
static void Local()
{
Console.WriteLine(""bad"");
}
static void Main(string[] args)
{
void Local()
{
Console.WriteLine(2);
}
Local();
}
}
";
var comp = CompileAndVerify(source, expectedOutput: @"
2
");
}
[Fact]
public void NoBody()
{
......@@ -390,7 +530,7 @@ IEnumerable<int> Local(__arglist)
}
[Fact]
public void PartialExpressionBody()
public void ForwardReference()
{
var source = @"
using System;
......@@ -400,21 +540,16 @@ class Program
{
static void Main(string[] args)
{
void Local() => Console.Writ
Console.WriteLine(Local());
int Local() => 2;
}
}
";
var option = TestOptions.ReleaseExe.WithWarningLevel(0);
CreateCompilationWithMscorlibAndSystemCore(source, options: option).VerifyDiagnostics(
// (9,37): error CS1002: ; expected
// void Local() => Console.Writ
Diagnostic(ErrorCode.ERR_SemicolonExpected, "").WithLocation(9, 37),
// (9,33): error CS0117: 'Console' does not contain a definition for 'Writ'
// void Local() => Console.Writ
Diagnostic(ErrorCode.ERR_NoSuchMember, "Writ").WithArguments("System.Console", "Writ").WithLocation(9, 33),
// (9,25): error CS0201: Only assignment, call, increment, decrement, and new object expressions can be used as a statement
// void Local() => Console.Writ
Diagnostic(ErrorCode.ERR_IllegalStatement, "Console.Writ").WithLocation(9, 25)
// (9,27): error CS0841: Cannot use local variable 'Local' before it is declared
// Console.WriteLine(Local());
Diagnostic(ErrorCode.ERR_VariableUsedBeforeDeclaration, "Local").WithArguments("Local").WithLocation(9, 27)
);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册