diff --git a/src/Compilers/CSharp/Portable/Compilation/CSharpSemanticModel.cs b/src/Compilers/CSharp/Portable/Compilation/CSharpSemanticModel.cs
index da0d262648d8e83147f5ccca020c2883e81e7742..4024c2ed21eeb090d29178fa8198d7de5ba6dc18 100644
--- a/src/Compilers/CSharp/Portable/Compilation/CSharpSemanticModel.cs
+++ b/src/Compilers/CSharp/Portable/Compilation/CSharpSemanticModel.cs
@@ -3019,10 +3019,19 @@ public ILocalSymbol GetDeclaredSymbol(ForEachStatementSyntax forEachStatement, C
LocalSymbol local = foreachBinder.GetDeclaredLocalsForScope(forEachStatement).FirstOrDefault();
return ((object)local != null && local.DeclarationKind == LocalDeclarationKind.ForEachIterationVariable)
- ? local
+ ? GetAdjustedLocalSymbol(local, forEachStatement.SpanStart)
: null;
}
+ ///
+ /// Given a local symbol, gets an updated version of that local symbol adjusted for nullability analysis
+ /// if the analysis affects the local.
+ ///
+ /// The original symbol from initial binding.
+ /// The position the local was declared at.
+ /// The nullability-adjusted local, or the original symbol if the nullability analysis made no adjustments or was not run.
+ internal abstract LocalSymbol GetAdjustedLocalSymbol(LocalSymbol originalSymbol, int position);
+
///
/// Given a catch declaration, get the symbol for the exception variable
///
diff --git a/src/Compilers/CSharp/Portable/Compilation/MemberSemanticModel.cs b/src/Compilers/CSharp/Portable/Compilation/MemberSemanticModel.cs
index 48adf5b86b3b19a8ac98697c7162351f11027246..b24264f8c75f48af08b9e1a93dc1dd09f9f34b05 100644
--- a/src/Compilers/CSharp/Portable/Compilation/MemberSemanticModel.cs
+++ b/src/Compilers/CSharp/Portable/Compilation/MemberSemanticModel.cs
@@ -661,33 +661,7 @@ private LocalSymbol GetDeclaredLocal(CSharpSyntaxNode declarationSyntax, SyntaxT
{
if (local.IdentifierToken == declaredIdentifier)
{
- Debug.Assert(local is SourceLocalSymbol);
- LocalSymbol adjustedLocal;
- if (Compilation.NullableSemanticAnalysisEnabled)
- {
- if (!_analyzedVariableTypesOpt.TryGetValue(local, out adjustedLocal))
- {
- var types = GetSnapshotManager().GetVariableTypesForPosition(declarationSyntax.SpanStart);
-
- // If the local was not inferred, it does not get an entry in this dictionary. Save the local mapped
- // to itself to avoid needing to enter this code path in the future.
- if (types.TryGetValue(local, out TypeWithAnnotations type))
- {
- adjustedLocal = _analyzedVariableTypesOpt.GetOrAdd(local, ((SourceLocalSymbol)local).WithAnalyzedType(type));
- }
- else
- {
- _analyzedVariableTypesOpt.TryAdd(local, local);
- adjustedLocal = local;
- }
- }
- }
- else
- {
- adjustedLocal = local;
- }
-
- return adjustedLocal;
+ return GetAdjustedLocalSymbol(local, declarationSyntax.SpanStart);
}
}
}
@@ -695,6 +669,37 @@ private LocalSymbol GetDeclaredLocal(CSharpSyntaxNode declarationSyntax, SyntaxT
return null;
}
+ internal override LocalSymbol GetAdjustedLocalSymbol(LocalSymbol local, int position)
+ {
+ Debug.Assert(local is SourceLocalSymbol);
+ LocalSymbol adjustedLocal;
+ if (Compilation.NullableSemanticAnalysisEnabled)
+ {
+ if (!_analyzedVariableTypesOpt.TryGetValue(local, out adjustedLocal))
+ {
+ var types = GetSnapshotManager().GetVariableTypesForPosition(position);
+
+ // If the local was not inferred, it does not get an entry in this dictionary. Save the local mapped
+ // to itself to avoid needing to enter this code path in the future.
+ if (types.TryGetValue(local, out TypeWithAnnotations type))
+ {
+ adjustedLocal = _analyzedVariableTypesOpt.GetOrAdd(local, ((SourceLocalSymbol)local).WithAnalyzedType(type));
+ }
+ else
+ {
+ _analyzedVariableTypesOpt.TryAdd(local, local);
+ adjustedLocal = local;
+ }
+ }
+ }
+ else
+ {
+ adjustedLocal = local;
+ }
+
+ return adjustedLocal;
+ }
+
private LocalFunctionSymbol GetDeclaredLocalFunction(LocalFunctionStatementSyntax declarationSyntax, SyntaxToken declaredIdentifier)
{
return GetDeclaredLocalFunction(this.GetEnclosingBinder(GetAdjustedNodePosition(declarationSyntax)), declaredIdentifier);
diff --git a/src/Compilers/CSharp/Portable/Compilation/SyntaxTreeSemanticModel.cs b/src/Compilers/CSharp/Portable/Compilation/SyntaxTreeSemanticModel.cs
index 1bf8bfdf8a8aa25adbacba6af27912934636eac6..b3e1115d0734a9359e930637bb1e5264fdf8a031 100644
--- a/src/Compilers/CSharp/Portable/Compilation/SyntaxTreeSemanticModel.cs
+++ b/src/Compilers/CSharp/Portable/Compilation/SyntaxTreeSemanticModel.cs
@@ -1769,6 +1769,9 @@ public override ISymbol GetDeclaredSymbol(SingleVariableDesignationSyntax declar
return binder?.LookupDeclaredField(declarationSyntax);
}
+ internal override LocalSymbol GetAdjustedLocalSymbol(LocalSymbol originalSymbol, int position) =>
+ GetMemberModel(position)?.GetAdjustedLocalSymbol(originalSymbol, position) ?? originalSymbol;
+
///
/// Given a labeled statement syntax, get the corresponding label symbol.
///
diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs
index 29428940c09da785251e10e5b9da81e84049cab2..10b3ce74de4a225aebd8d01c2c6b94dd17598ec5 100644
--- a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs
+++ b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs
@@ -6427,6 +6427,12 @@ public override void VisitForEachIterationVariables(BoundForEachStatement node)
// In non-error cases we'll only run this loop a single time. In error cases we'll set the nullability of the VariableType multiple times, but at least end up with something
SetAnalyzedNullability(node.IterationVariableType, new VisitResult(result, destinationType), isLvalue: true);
state = result.State;
+
+ // If we inferred the type of this variable, then record the inferred type of the variable for later use by the SemanticModel.
+ if (node.Syntax is ForEachStatementSyntax { Type: { IsVar: true } })
+ {
+ _variableTypes[iterationVariable] = result.ToTypeWithAnnotations();
+ }
}
int slot = GetOrCreateSlot(iterationVariable);
diff --git a/src/Compilers/CSharp/Test/Symbol/Symbols/Source/NullablePublicAPITests.cs b/src/Compilers/CSharp/Test/Symbol/Symbols/Source/NullablePublicAPITests.cs
index 76fa87bea1e971c9002ad0ed152d1cb147b6fe8f..1d7d4852a574bce046834e2341495739a2db8db0 100644
--- a/src/Compilers/CSharp/Test/Symbol/Symbols/Source/NullablePublicAPITests.cs
+++ b/src/Compilers/CSharp/Test/Symbol/Symbols/Source/NullablePublicAPITests.cs
@@ -2049,5 +2049,140 @@ void assertAnnotation(VariableDeclaratorSyntax variable, PublicNullableAnnotatio
Assert.Equal(expectedAnnotation, symbol.NullableAnnotation);
}
}
+
+ [Fact]
+ public void GetDeclaredSymbol_Foreach_Inferred()
+ {
+ var source = @"
+#pragma warning disable CS8600
+using System.Collections.Generic;
+class C
+{
+ List GetList(T t) => throw null!;
+ void M(object o1, object? o2)
+ {
+ foreach (var o in GetList(o1)) {}
+ foreach (var o in GetList(o2)) {}
+ o1 = null;
+ foreach (var o in GetList(o1)) {}
+ _ = o2 ?? throw null!;
+ foreach (var o in GetList(o2)) {}
+ }
+}";
+
+ var comp = CreateCompilation(source, options: WithNonNullTypesTrue());
+ comp.VerifyDiagnostics();
+
+ var syntaxTree = comp.SyntaxTrees[0];
+ var root = syntaxTree.GetRoot();
+ var model = comp.GetSemanticModel(syntaxTree);
+
+ var declarations = root.DescendantNodes().OfType().ToList();
+
+ assertAnnotation(declarations[0], PublicNullableAnnotation.NotAnnotated);
+ assertAnnotation(declarations[1], PublicNullableAnnotation.Annotated);
+ assertAnnotation(declarations[2], PublicNullableAnnotation.Annotated);
+ assertAnnotation(declarations[3], PublicNullableAnnotation.NotAnnotated);
+
+ void assertAnnotation(ForEachStatementSyntax variable, PublicNullableAnnotation expectedAnnotation)
+ {
+ var symbol = model.GetDeclaredSymbol(variable);
+ Assert.Equal(expectedAnnotation, symbol.NullableAnnotation);
+ }
+ }
+
+ [Fact]
+ public void GetDeclaredSymbol_Foreach_NoInference()
+ {
+ var source = @"
+#pragma warning disable CS8600
+using System.Collections.Generic;
+class C
+{
+ List GetList(T t) => throw null!;
+ void M(object o1, object? o2)
+ {
+ foreach (object? o in GetList(o1)) {}
+ foreach (object o in GetList(o2)) {}
+ o1 = null;
+ foreach (object o in GetList(o1)) {}
+ _ = o2 ?? throw null!;
+ foreach (object? o in GetList(o2)) {}
+ }
+}";
+
+ var comp = CreateCompilation(source, options: WithNonNullTypesTrue());
+ comp.VerifyDiagnostics(
+ // (10,25): warning CS8606: Possible null reference assignment to iteration variable
+ // foreach (object o in GetList(o2)) {}
+ Diagnostic(ErrorCode.WRN_NullReferenceIterationVariable, "o").WithLocation(10, 25),
+ // (12,25): warning CS8606: Possible null reference assignment to iteration variable
+ // foreach (object o in GetList(o1)) {}
+ Diagnostic(ErrorCode.WRN_NullReferenceIterationVariable, "o").WithLocation(12, 25));
+
+ var syntaxTree = comp.SyntaxTrees[0];
+ var root = syntaxTree.GetRoot();
+ var model = comp.GetSemanticModel(syntaxTree);
+
+ var declarations = root.DescendantNodes().OfType().ToList();
+
+ assertAnnotation(declarations[0], PublicNullableAnnotation.Annotated);
+ assertAnnotation(declarations[1], PublicNullableAnnotation.NotAnnotated);
+ assertAnnotation(declarations[2], PublicNullableAnnotation.NotAnnotated);
+ assertAnnotation(declarations[3], PublicNullableAnnotation.Annotated);
+
+ void assertAnnotation(ForEachStatementSyntax variable, PublicNullableAnnotation expectedAnnotation)
+ {
+ var symbol = model.GetDeclaredSymbol(variable);
+ Assert.Equal(expectedAnnotation, symbol.NullableAnnotation);
+ }
+ }
+
+ [Fact]
+ public void GetDeclaredSymbol_Foreach_Tuples_MixedInference()
+ {
+ var source = @"
+#pragma warning disable CS8600
+using System.Collections.Generic;
+class C
+{
+ List<(T, T)> GetList(T t) => throw null!;
+ void M(object o1, object? o2)
+ {
+ foreach ((var o3, object? o4) in GetList(o1)) {}
+ foreach ((var o3, object o4) in GetList(o2)) { o3.ToString(); }
+ o1 = null;
+ foreach ((var o3, object o4) in GetList(o1)) {}
+ _ = o2 ?? throw null!;
+ foreach ((var o3, object? o4) in GetList(o2)) {}
+ }
+}";
+
+ var comp = CreateCompilation(source, options: WithNonNullTypesTrue());
+ comp.VerifyDiagnostics();
+
+ var syntaxTree = comp.SyntaxTrees[0];
+ var root = syntaxTree.GetRoot();
+ var model = comp.GetSemanticModel(syntaxTree);
+
+ var declarations = root.DescendantNodes().OfType().ToList();
+
+ // Some annotations are incorrect because of https://github.com/dotnet/roslyn/issues/37491
+
+ assertAnnotation(declarations[0], PublicNullableAnnotation.NotAnnotated);
+ assertAnnotation(declarations[1], PublicNullableAnnotation.Annotated);
+ assertAnnotation(declarations[2], PublicNullableAnnotation.NotAnnotated); // Should be Annotated
+ assertAnnotation(declarations[3], PublicNullableAnnotation.NotAnnotated);
+ assertAnnotation(declarations[4], PublicNullableAnnotation.NotAnnotated); // Should be Annotated
+ assertAnnotation(declarations[5], PublicNullableAnnotation.NotAnnotated);
+ assertAnnotation(declarations[6], PublicNullableAnnotation.NotAnnotated);
+ assertAnnotation(declarations[7], PublicNullableAnnotation.Annotated);
+
+ void assertAnnotation(SingleVariableDesignationSyntax variable, PublicNullableAnnotation expectedAnnotation)
+ {
+ var symbol = (ILocalSymbol)model.GetDeclaredSymbol(variable);
+ Assert.Equal(expectedAnnotation, symbol.NullableAnnotation);
+ }
+ }
}
}