diff --git a/src/Compilers/CSharp/Portable/Compilation/CSharpSemanticModel.cs b/src/Compilers/CSharp/Portable/Compilation/CSharpSemanticModel.cs index da0d262648d8e83147f5ccca020c2883e81e7742..4024c2ed21eeb090d29178fa8198d7de5ba6dc18 100644 --- a/src/Compilers/CSharp/Portable/Compilation/CSharpSemanticModel.cs +++ b/src/Compilers/CSharp/Portable/Compilation/CSharpSemanticModel.cs @@ -3019,10 +3019,19 @@ public ILocalSymbol GetDeclaredSymbol(ForEachStatementSyntax forEachStatement, C LocalSymbol local = foreachBinder.GetDeclaredLocalsForScope(forEachStatement).FirstOrDefault(); return ((object)local != null && local.DeclarationKind == LocalDeclarationKind.ForEachIterationVariable) - ? local + ? GetAdjustedLocalSymbol(local, forEachStatement.SpanStart) : null; } + /// + /// Given a local symbol, gets an updated version of that local symbol adjusted for nullability analysis + /// if the analysis affects the local. + /// + /// The original symbol from initial binding. + /// The position the local was declared at. + /// The nullability-adjusted local, or the original symbol if the nullability analysis made no adjustments or was not run. + internal abstract LocalSymbol GetAdjustedLocalSymbol(LocalSymbol originalSymbol, int position); + /// /// Given a catch declaration, get the symbol for the exception variable /// diff --git a/src/Compilers/CSharp/Portable/Compilation/MemberSemanticModel.cs b/src/Compilers/CSharp/Portable/Compilation/MemberSemanticModel.cs index 48adf5b86b3b19a8ac98697c7162351f11027246..b24264f8c75f48af08b9e1a93dc1dd09f9f34b05 100644 --- a/src/Compilers/CSharp/Portable/Compilation/MemberSemanticModel.cs +++ b/src/Compilers/CSharp/Portable/Compilation/MemberSemanticModel.cs @@ -661,33 +661,7 @@ private LocalSymbol GetDeclaredLocal(CSharpSyntaxNode declarationSyntax, SyntaxT { if (local.IdentifierToken == declaredIdentifier) { - Debug.Assert(local is SourceLocalSymbol); - LocalSymbol adjustedLocal; - if (Compilation.NullableSemanticAnalysisEnabled) - { - if (!_analyzedVariableTypesOpt.TryGetValue(local, out adjustedLocal)) - { - var types = GetSnapshotManager().GetVariableTypesForPosition(declarationSyntax.SpanStart); - - // If the local was not inferred, it does not get an entry in this dictionary. Save the local mapped - // to itself to avoid needing to enter this code path in the future. - if (types.TryGetValue(local, out TypeWithAnnotations type)) - { - adjustedLocal = _analyzedVariableTypesOpt.GetOrAdd(local, ((SourceLocalSymbol)local).WithAnalyzedType(type)); - } - else - { - _analyzedVariableTypesOpt.TryAdd(local, local); - adjustedLocal = local; - } - } - } - else - { - adjustedLocal = local; - } - - return adjustedLocal; + return GetAdjustedLocalSymbol(local, declarationSyntax.SpanStart); } } } @@ -695,6 +669,37 @@ private LocalSymbol GetDeclaredLocal(CSharpSyntaxNode declarationSyntax, SyntaxT return null; } + internal override LocalSymbol GetAdjustedLocalSymbol(LocalSymbol local, int position) + { + Debug.Assert(local is SourceLocalSymbol); + LocalSymbol adjustedLocal; + if (Compilation.NullableSemanticAnalysisEnabled) + { + if (!_analyzedVariableTypesOpt.TryGetValue(local, out adjustedLocal)) + { + var types = GetSnapshotManager().GetVariableTypesForPosition(position); + + // If the local was not inferred, it does not get an entry in this dictionary. Save the local mapped + // to itself to avoid needing to enter this code path in the future. + if (types.TryGetValue(local, out TypeWithAnnotations type)) + { + adjustedLocal = _analyzedVariableTypesOpt.GetOrAdd(local, ((SourceLocalSymbol)local).WithAnalyzedType(type)); + } + else + { + _analyzedVariableTypesOpt.TryAdd(local, local); + adjustedLocal = local; + } + } + } + else + { + adjustedLocal = local; + } + + return adjustedLocal; + } + private LocalFunctionSymbol GetDeclaredLocalFunction(LocalFunctionStatementSyntax declarationSyntax, SyntaxToken declaredIdentifier) { return GetDeclaredLocalFunction(this.GetEnclosingBinder(GetAdjustedNodePosition(declarationSyntax)), declaredIdentifier); diff --git a/src/Compilers/CSharp/Portable/Compilation/SyntaxTreeSemanticModel.cs b/src/Compilers/CSharp/Portable/Compilation/SyntaxTreeSemanticModel.cs index 1bf8bfdf8a8aa25adbacba6af27912934636eac6..b3e1115d0734a9359e930637bb1e5264fdf8a031 100644 --- a/src/Compilers/CSharp/Portable/Compilation/SyntaxTreeSemanticModel.cs +++ b/src/Compilers/CSharp/Portable/Compilation/SyntaxTreeSemanticModel.cs @@ -1769,6 +1769,9 @@ public override ISymbol GetDeclaredSymbol(SingleVariableDesignationSyntax declar return binder?.LookupDeclaredField(declarationSyntax); } + internal override LocalSymbol GetAdjustedLocalSymbol(LocalSymbol originalSymbol, int position) => + GetMemberModel(position)?.GetAdjustedLocalSymbol(originalSymbol, position) ?? originalSymbol; + /// /// Given a labeled statement syntax, get the corresponding label symbol. /// diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs index 29428940c09da785251e10e5b9da81e84049cab2..10b3ce74de4a225aebd8d01c2c6b94dd17598ec5 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs @@ -6427,6 +6427,12 @@ public override void VisitForEachIterationVariables(BoundForEachStatement node) // In non-error cases we'll only run this loop a single time. In error cases we'll set the nullability of the VariableType multiple times, but at least end up with something SetAnalyzedNullability(node.IterationVariableType, new VisitResult(result, destinationType), isLvalue: true); state = result.State; + + // If we inferred the type of this variable, then record the inferred type of the variable for later use by the SemanticModel. + if (node.Syntax is ForEachStatementSyntax { Type: { IsVar: true } }) + { + _variableTypes[iterationVariable] = result.ToTypeWithAnnotations(); + } } int slot = GetOrCreateSlot(iterationVariable); diff --git a/src/Compilers/CSharp/Test/Symbol/Symbols/Source/NullablePublicAPITests.cs b/src/Compilers/CSharp/Test/Symbol/Symbols/Source/NullablePublicAPITests.cs index 76fa87bea1e971c9002ad0ed152d1cb147b6fe8f..1d7d4852a574bce046834e2341495739a2db8db0 100644 --- a/src/Compilers/CSharp/Test/Symbol/Symbols/Source/NullablePublicAPITests.cs +++ b/src/Compilers/CSharp/Test/Symbol/Symbols/Source/NullablePublicAPITests.cs @@ -2049,5 +2049,140 @@ void assertAnnotation(VariableDeclaratorSyntax variable, PublicNullableAnnotatio Assert.Equal(expectedAnnotation, symbol.NullableAnnotation); } } + + [Fact] + public void GetDeclaredSymbol_Foreach_Inferred() + { + var source = @" +#pragma warning disable CS8600 +using System.Collections.Generic; +class C +{ + List GetList(T t) => throw null!; + void M(object o1, object? o2) + { + foreach (var o in GetList(o1)) {} + foreach (var o in GetList(o2)) {} + o1 = null; + foreach (var o in GetList(o1)) {} + _ = o2 ?? throw null!; + foreach (var o in GetList(o2)) {} + } +}"; + + var comp = CreateCompilation(source, options: WithNonNullTypesTrue()); + comp.VerifyDiagnostics(); + + var syntaxTree = comp.SyntaxTrees[0]; + var root = syntaxTree.GetRoot(); + var model = comp.GetSemanticModel(syntaxTree); + + var declarations = root.DescendantNodes().OfType().ToList(); + + assertAnnotation(declarations[0], PublicNullableAnnotation.NotAnnotated); + assertAnnotation(declarations[1], PublicNullableAnnotation.Annotated); + assertAnnotation(declarations[2], PublicNullableAnnotation.Annotated); + assertAnnotation(declarations[3], PublicNullableAnnotation.NotAnnotated); + + void assertAnnotation(ForEachStatementSyntax variable, PublicNullableAnnotation expectedAnnotation) + { + var symbol = model.GetDeclaredSymbol(variable); + Assert.Equal(expectedAnnotation, symbol.NullableAnnotation); + } + } + + [Fact] + public void GetDeclaredSymbol_Foreach_NoInference() + { + var source = @" +#pragma warning disable CS8600 +using System.Collections.Generic; +class C +{ + List GetList(T t) => throw null!; + void M(object o1, object? o2) + { + foreach (object? o in GetList(o1)) {} + foreach (object o in GetList(o2)) {} + o1 = null; + foreach (object o in GetList(o1)) {} + _ = o2 ?? throw null!; + foreach (object? o in GetList(o2)) {} + } +}"; + + var comp = CreateCompilation(source, options: WithNonNullTypesTrue()); + comp.VerifyDiagnostics( + // (10,25): warning CS8606: Possible null reference assignment to iteration variable + // foreach (object o in GetList(o2)) {} + Diagnostic(ErrorCode.WRN_NullReferenceIterationVariable, "o").WithLocation(10, 25), + // (12,25): warning CS8606: Possible null reference assignment to iteration variable + // foreach (object o in GetList(o1)) {} + Diagnostic(ErrorCode.WRN_NullReferenceIterationVariable, "o").WithLocation(12, 25)); + + var syntaxTree = comp.SyntaxTrees[0]; + var root = syntaxTree.GetRoot(); + var model = comp.GetSemanticModel(syntaxTree); + + var declarations = root.DescendantNodes().OfType().ToList(); + + assertAnnotation(declarations[0], PublicNullableAnnotation.Annotated); + assertAnnotation(declarations[1], PublicNullableAnnotation.NotAnnotated); + assertAnnotation(declarations[2], PublicNullableAnnotation.NotAnnotated); + assertAnnotation(declarations[3], PublicNullableAnnotation.Annotated); + + void assertAnnotation(ForEachStatementSyntax variable, PublicNullableAnnotation expectedAnnotation) + { + var symbol = model.GetDeclaredSymbol(variable); + Assert.Equal(expectedAnnotation, symbol.NullableAnnotation); + } + } + + [Fact] + public void GetDeclaredSymbol_Foreach_Tuples_MixedInference() + { + var source = @" +#pragma warning disable CS8600 +using System.Collections.Generic; +class C +{ + List<(T, T)> GetList(T t) => throw null!; + void M(object o1, object? o2) + { + foreach ((var o3, object? o4) in GetList(o1)) {} + foreach ((var o3, object o4) in GetList(o2)) { o3.ToString(); } + o1 = null; + foreach ((var o3, object o4) in GetList(o1)) {} + _ = o2 ?? throw null!; + foreach ((var o3, object? o4) in GetList(o2)) {} + } +}"; + + var comp = CreateCompilation(source, options: WithNonNullTypesTrue()); + comp.VerifyDiagnostics(); + + var syntaxTree = comp.SyntaxTrees[0]; + var root = syntaxTree.GetRoot(); + var model = comp.GetSemanticModel(syntaxTree); + + var declarations = root.DescendantNodes().OfType().ToList(); + + // Some annotations are incorrect because of https://github.com/dotnet/roslyn/issues/37491 + + assertAnnotation(declarations[0], PublicNullableAnnotation.NotAnnotated); + assertAnnotation(declarations[1], PublicNullableAnnotation.Annotated); + assertAnnotation(declarations[2], PublicNullableAnnotation.NotAnnotated); // Should be Annotated + assertAnnotation(declarations[3], PublicNullableAnnotation.NotAnnotated); + assertAnnotation(declarations[4], PublicNullableAnnotation.NotAnnotated); // Should be Annotated + assertAnnotation(declarations[5], PublicNullableAnnotation.NotAnnotated); + assertAnnotation(declarations[6], PublicNullableAnnotation.NotAnnotated); + assertAnnotation(declarations[7], PublicNullableAnnotation.Annotated); + + void assertAnnotation(SingleVariableDesignationSyntax variable, PublicNullableAnnotation expectedAnnotation) + { + var symbol = (ILocalSymbol)model.GetDeclaredSymbol(variable); + Assert.Equal(expectedAnnotation, symbol.NullableAnnotation); + } + } } }