提交 e1ac8e4d 编写于 作者: A Andy Gocke 提交者: Julien Couvreur

Improve nullable analysis of local functions (#40422)

* Improve nullable analysis of local functions

This design tries to meld better analysis of nullable reference types in
local functions with performance. To keep the common case one pass,
local functions are analyzed using the starting state that is an
intersection of all the states before its usages (calls, delegate
conversions, etc), but the results of variables made nullable or
non-nullable inside the local function do not propagate to the callers.

* Respond to PR comments
上级 e94cd673
......@@ -200,6 +200,7 @@ protected void Unsplit()
this.Diagnostics = DiagnosticBag.GetInstance();
this.compilation = compilation;
_symbol = symbol;
CurrentSymbol = symbol;
this.methodMainNode = node;
this.firstInRegion = firstInRegion;
this.lastInRegion = lastInRegion;
......@@ -1167,11 +1168,10 @@ public override BoundNode VisitCall(BoundCall node)
return null;
}
private void VisitLocalFunctionUse(LocalFunctionSymbol symbol, SyntaxNode syntax, bool isCall)
protected void VisitLocalFunctionUse(LocalFunctionSymbol symbol, SyntaxNode syntax, bool isCall)
{
var localFuncState = GetOrCreateLocalFuncUsages(symbol);
VisitLocalFunctionUse(symbol, localFuncState, syntax, isCall);
localFuncState.Visited = true;
}
protected virtual void VisitLocalFunctionUse(
......@@ -1185,6 +1185,7 @@ private void VisitLocalFunctionUse(LocalFunctionSymbol symbol, SyntaxNode syntax
Join(ref State, ref localFunctionState.StateFromBottom);
Meet(ref State, ref localFunctionState.StateFromTop);
}
localFunctionState.Visited = true;
}
private void VisitReceiverBeforeCall(BoundExpression receiverOpt, MethodSymbol method)
......
......@@ -77,22 +77,6 @@ private void CheckIfAssignedDuringLocalFunctionReplay(Symbol symbol, SyntaxNode
}
}
private int RootSlot(int slot)
{
while (true)
{
var varInfo = variableBySlot[slot];
if (varInfo.ContainingSlot == 0)
{
return slot;
}
else
{
slot = varInfo.ContainingSlot;
}
}
}
private void RecordReadInLocalFunction(int slot)
{
var localFunc = GetNearestLocalFunctionOpt(CurrentSymbol);
......
......@@ -132,7 +132,6 @@ internal partial class DefiniteAssignmentPass : LocalDataFlowPass<
{
this.initiallyAssignedVariables = null;
_sourceAssembly = ((object)member == null) ? null : (SourceAssemblySymbol)member.ContainingAssembly;
this.CurrentSymbol = member;
_unassignedVariableAddressOfSyntaxes = unassignedVariableAddressOfSyntaxes;
_requireOutParamsAssigned = requireOutParamsAssigned;
_trackClassFields = trackClassFields;
......
......@@ -286,5 +286,21 @@ protected int MakeMemberSlot(BoundExpression receiverOpt, Symbol member)
}
return GetOrCreateSlot(member, containingSlot);
}
protected int RootSlot(int slot)
{
while (true)
{
ref var varInfo = ref variableBySlot[slot];
if (varInfo.ContainingSlot == 0)
{
return slot;
}
else
{
slot = varInfo.ContainingSlot;
}
}
}
}
}
......@@ -153,17 +153,21 @@ public VisitArgumentResult(VisitResult visitResult, Optional<LocalState> stateFo
/// </summary>
private static readonly TypeWithState _invalidType = TypeWithState.Create(ErrorTypeSymbol.UnknownResultType, NullableFlowState.NotNull);
#nullable enable
/// <summary>
/// Contains the map of expressions to inferred nullabilities and types used by the optional rewriter phase of the
/// compiler.
/// </summary>
private readonly ImmutableDictionary<BoundExpression, (NullabilityInfo Info, TypeSymbol Type)>.Builder _analyzedNullabilityMapOpt;
private readonly ImmutableDictionary<BoundExpression, (NullabilityInfo Info, TypeSymbol Type)>.Builder? _analyzedNullabilityMapOpt;
/// <summary>
/// Manages creating snapshots of the walker as appropriate. Null if we're not taking snapshots of
/// this walker.
/// </summary>
private readonly SnapshotManager.Builder _snapshotBuilderOpt;
private readonly SnapshotManager.Builder? _snapshotBuilderOpt;
#nullable disable
// https://github.com/dotnet/roslyn/issues/35043: remove this when all expression are supported
private bool _disableNullabilityAnalysis;
......@@ -1379,7 +1383,7 @@ protected override LocalState ReachableBottomState()
private void EnterParameters()
{
if (!(_symbol is MethodSymbol methodSymbol))
if (!(CurrentSymbol is MethodSymbol methodSymbol))
{
return;
}
......@@ -1485,7 +1489,7 @@ private TypeWithState VisitRefExpression(BoundExpression expr, TypeWithAnnotatio
private bool TryGetReturnType(out TypeWithAnnotations type)
{
var method = _symbol as MethodSymbol;
var method = CurrentSymbol as MethodSymbol;
if (method is null)
{
type = default;
......@@ -1540,14 +1544,149 @@ public override BoundNode VisitLocal(BoundLocal node)
public override BoundNode VisitBlock(BoundBlock node)
{
DeclareLocals(node.Locals);
foreach (var statement in node.Statements)
VisitStatementsWithLocalFunctions(node);
return null;
}
#nullable enable
private void VisitStatementsWithLocalFunctions(BoundBlock block)
{
// Since the nullable flow state affects type information, and types can be queried by
// the semantic model, there needs to be a single flow state input to a local function
// that cannot be path-dependent. To decide the local starting state we Meet the state
// of captured variables from all the uses of the local function, computing the
// conservative combination of all potential starting states.
//
// For performance we split the analysis into two phases: the first phase where we
// analyze everything except the local functions, hoping to visit all of the uses of the
// local function, and then a pass where we visit the local functions. If there's no
// recursion or calls between the local functions, the starting state of the local
// function should be stable and we don't need a second pass.
if (!TrackingRegions && !block.LocalFunctions.IsDefaultOrEmpty)
{
// First visit everything else
foreach (var stmt in block.Statements)
{
if (stmt.Kind != BoundKind.LocalFunctionStatement)
{
VisitStatement(stmt);
}
}
// Now visit the local function bodies
foreach (var stmt in block.Statements)
{
if (stmt is BoundLocalFunctionStatement localFunc)
{
VisitLocalFunctionStatement(localFunc);
}
}
}
else
{
foreach (var stmt in block.Statements)
{
VisitStatement(stmt);
}
}
}
public override BoundNode? VisitLocalFunctionStatement(BoundLocalFunctionStatement localFunc)
{
var oldSymbol = this.CurrentSymbol;
var localFuncSymbol = localFunc.Symbol;
this.CurrentSymbol = localFuncSymbol;
var oldPending = SavePending(); // we do not support branches into a lambda
var savedState = this.State;
var localFunctionState = GetOrCreateLocalFuncUsages(localFuncSymbol);
// The starting state is the top state, but with captured
// variables set according to Joining the state at all the
// local function use sites
State = TopState().Clone();
for (int slot = 1; slot < localFunctionState.StartingState.Capacity; slot++)
{
var symbol = variableBySlot[RootSlot(slot)].Symbol;
if (Symbol.IsCaptured(symbol, localFunc.Symbol))
{
State[slot] = localFunctionState.StartingState[slot];
}
}
localFunctionState.Visited = true;
if (!localFunc.WasCompilerGenerated) EnterParameters(localFuncSymbol.Parameters);
// State changes to captured variables are recorded, as calls to local functions
// transition the state of captured variables if the variables have state changes
// across all branches leaving the local function
var oldPending2 = SavePending();
// If this is an iterator, there's an implicit branch before the first statement
// of the function where the enumerable is returned.
if (localFuncSymbol.IsIterator)
{
PendingBranches.Add(new PendingBranch(null, this.State, null));
}
VisitAlways(localFunc.Body);
RestorePending(oldPending2); // process any forward branches within the lambda body
ImmutableArray<PendingBranch> pendingReturns = RemoveReturns();
RestorePending(oldPending);
Location? location = null;
if (!localFuncSymbol.Locations.IsDefaultOrEmpty)
{
VisitStatement(statement);
location = localFuncSymbol.Locations[0];
}
LeaveParameters(localFuncSymbol.Parameters, localFunc.Syntax, location);
// Intersect the state of all branches out of the local function
var stateAtReturn = this.State;
foreach (PendingBranch pending in pendingReturns)
{
this.State = pending.State;
BoundNode branch = pending.Branch;
// Pass the local function identifier as a location if the branch
// is null or compiler generated.
LeaveParameters(localFuncSymbol.Parameters,
branch?.Syntax,
branch?.WasCompilerGenerated == false ? null : location);
Join(ref stateAtReturn, ref this.State);
}
this.State = savedState;
this.CurrentSymbol = oldSymbol;
SetInvalidResult();
return null;
}
protected override void VisitLocalFunctionUse(
LocalFunctionSymbol symbol,
LocalFunctionState localFunctionState,
SyntaxNode syntax,
bool isCall)
{
if (Join(ref localFunctionState.StartingState, ref State) &&
localFunctionState.Visited)
{
// If the starting state of the local function has changed and we've already visited
// the local function, we need another pass
stateChangedAfterUse = true;
}
}
#nullable restore
public override BoundNode VisitDoStatement(BoundDoStatement node)
{
DeclareLocals(node.Locals);
......@@ -3191,6 +3330,10 @@ public override BoundNode VisitCall(BoundCall node)
// Note: we analyze even omitted calls
TypeWithState receiverType = VisitCallReceiver(node);
ReinferMethodAndVisitArguments(node, receiverType);
if (node.Method?.OriginalDefinition is LocalFunctionSymbol localFunc)
{
VisitLocalFunctionUse(localFunc, node.Syntax, isCall: true);
}
return null;
}
......@@ -3213,12 +3356,6 @@ private void ReinferMethodAndVisitArguments(BoundCall node, TypeWithState receiv
LearnFromCompareExchangeMethod(method, node, results);
if (method.MethodKind == MethodKind.LocalFunction)
{
var localFunc = (LocalFunctionSymbol)method.OriginalDefinition;
ReplayReadsAndWrites(localFunc, node.Syntax, writes: true);
}
var returnState = GetReturnTypeWithState(method);
if (returnNotNull)
{
......@@ -4428,13 +4565,6 @@ private void CheckMethodConstraints(SyntaxNode syntax, MethodSymbol method)
diagnosticsBuilder.Free();
}
private void ReplayReadsAndWrites(LocalFunctionSymbol localFunc,
SyntaxNode syntax,
bool writes)
{
// https://github.com/dotnet/roslyn/issues/27233 Support field initializers in local functions.
}
/// <summary>
/// Returns the expression without the top-most conversion plus the conversion.
/// If the expression is not a conversion, returns the original expression plus
......@@ -5149,6 +5279,10 @@ private static BoundConversion GetConversionIfApplicable(BoundExpression convers
var method = conversion.Method;
if (group != null)
{
if (method?.OriginalDefinition is LocalFunctionSymbol localFunc)
{
VisitLocalFunctionUse(localFunc, group.Syntax, isCall: false);
}
method = CheckMethodGroupReceiverNullability(group, delegateType, method, conversion.IsExtensionMethod);
}
if (reportRemainingWarnings)
......@@ -5698,11 +5832,9 @@ public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationE
{
Debug.Assert(node.Type.IsDelegateType());
if (node.MethodOpt?.MethodKind == MethodKind.LocalFunction)
if (node.MethodOpt?.OriginalDefinition is LocalFunctionSymbol localFunc)
{
var syntax = node.Syntax;
var localFunc = (LocalFunctionSymbol)node.MethodOpt.OriginalDefinition;
ReplayReadsAndWrites(localFunc, syntax, writes: false);
VisitLocalFunctionUse(localFunc, node.Syntax, isCall: true);
}
var delegateType = (NamedTypeSymbol)node.Type;
......@@ -5895,36 +6027,6 @@ public override BoundNode VisitUnboundLambda(UnboundLambda node)
return null;
}
public override BoundNode VisitLocalFunctionStatement(BoundLocalFunctionStatement node)
{
var body = node.Body;
if (body != null)
{
var analyzedNullabilityMap = _analyzedNullabilityMapOpt;
var snapshotBuilder = _snapshotBuilderOpt;
if (_disableNullabilityAnalysis)
{
analyzedNullabilityMap = null;
snapshotBuilder = null;
}
Analyze(compilation,
node.Symbol,
body,
_binder,
_conversions,
Diagnostics,
useMethodSignatureParameterTypes: false,
delegateInvokeMethodOpt: null,
initialState: GetVariableState(this.TopState()),
analyzedNullabilityMap,
snapshotBuilder,
returnTypesOpt: null);
}
SetInvalidResult();
return null;
}
public override BoundNode VisitThisReference(BoundThisReference node)
{
VisitThisOrBaseReference(node);
......@@ -7827,7 +7929,7 @@ public override BoundNode VisitYieldReturnStatement(BoundYieldReturnStatement no
{
return null;
}
var method = _delegateInvokeMethod ?? (MethodSymbol)_symbol;
var method = _delegateInvokeMethod ?? (MethodSymbol)CurrentSymbol;
TypeWithAnnotations elementType = InMethodBinder.GetIteratorElementTypeFromReturnType(compilation, RefKind.None,
method.ReturnType, errorLocation: null, diagnostics: null);
......@@ -8075,9 +8177,15 @@ internal string GetDebuggerDisplay()
internal sealed class LocalFunctionState : AbstractLocalFunctionState
{
/// <summary>
/// Defines the starting state used in the local function body to
/// produce diagnostics and determine types.
/// </summary>
public LocalState StartingState;
public LocalFunctionState(LocalState unreachableState)
: base(unreachableState)
{
StartingState = unreachableState;
}
}
......
......@@ -43286,9 +43286,6 @@ static void F4(object? x4)
// (17,9): warning CS8602: Dereference of a possibly null reference.
// z2.ToString();
Diagnostic(ErrorCode.WRN_NullReferenceReceiver, "z2").WithLocation(17, 9),
// (28,18): warning CS8600: Converting null literal or possible null value to non-nullable type.
// z3 = x3;
Diagnostic(ErrorCode.WRN_ConvertingNullableToNonNullable, "x3").WithLocation(28, 18),
// (36,9): warning CS8602: Dereference of a possibly null reference.
// f().ToString(); // warning
Diagnostic(ErrorCode.WRN_NullReferenceReceiver, "f()").WithLocation(36, 9),
......@@ -43297,6 +43294,106 @@ static void F4(object? x4)
Diagnostic(ErrorCode.WRN_NullReferenceReceiver, "f()").WithLocation(37, 25));
}
[Fact]
[WorkItem(29892, "https://github.com/dotnet/roslyn/issues/29892")]
public void LocalFunction_02()
{
var source =
@"class C
{
static void F1()
{
string? x = """";
f();
x = """";
g();
void f()
{
x.ToString(); // warn
x = null;
f();
}
void g()
{
x.ToString();
x = null;
}
}
}";
var comp = CreateCompilation(new[] { source }, options: WithNonNullTypesTrue());
comp.VerifyDiagnostics(
// (11,13): warning CS8602: Dereference of a possibly null reference.
// x.ToString(); // warn
Diagnostic(ErrorCode.WRN_NullReferenceReceiver, "x").WithLocation(11, 13)
);
}
[Fact]
[WorkItem(29892, "https://github.com/dotnet/roslyn/issues/29892")]
public void LocalFunction_03()
{
var source =
@"class C
{
static void F1()
{
string? x = """";
f();
h();
void f()
{
x.ToString();
}
void g()
{
x.ToString(); // warn
}
void h()
{
x = null;
g();
}
}
}";
var comp = CreateCompilation(new[] { source }, options: WithNonNullTypesTrue());
comp.VerifyDiagnostics(
// (14,13): warning CS8602: Dereference of a possibly null reference.
// x.ToString(); // warn
Diagnostic(ErrorCode.WRN_NullReferenceReceiver, "x").WithLocation(14, 13)
);
}
[Fact]
[WorkItem(29892, "https://github.com/dotnet/roslyn/issues/29892")]
public void LocalFunction_04()
{
var source =
@"class C
{
static void F1()
{
string? x = """";
f();
void f()
{
x.ToString(); // warn
if (string.Empty == """") // non-constant
{
x = null;
f();
}
}
}
}";
var comp = CreateCompilation(new[] { source }, options: WithNonNullTypesTrue());
comp.VerifyDiagnostics(
// (9,13): warning CS8602: Dereference of a possibly null reference.
// x.ToString(); // warn
Diagnostic(ErrorCode.WRN_NullReferenceReceiver, "x").WithLocation(9, 13)
);
}
/// <summary>
/// Should report warnings within unused local functions.
/// </summary>
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册