未验证 提交 f15d8765 编写于 作者: F Fred Silberberg 提交者: GitHub

Update immutable arrays of symbols in the NullabilityRewriter (#38736)

Update immutable arrays of symbols in the NullabilityRewriter
......@@ -143,7 +143,7 @@
<Field Name="ResultKind" PropertyOverrides="true" Type="LookupResultKind"/>
<!-- These symbols will be returned from the GetSemanticInfo API if it examines this bound node. -->
<Field Name="Symbols" Type="ImmutableArray&lt;Symbol&gt;"/>
<Field Name="Symbols" Type="ImmutableArray&lt;Symbol?&gt;" SkipInNullabilityRewriter="true"/>
<!-- Any child bound nodes that we need to preserve are put here. -->
<Field Name="ChildBoundNodes" Type="ImmutableArray&lt;BoundExpression&gt;"/>
......@@ -273,7 +273,7 @@
<!--The set of method symbols from which this operator's method was chosen.
Only kept in the tree if the operator was an error and overload resolution
was unable to choose a best method.-->
<Field Name="OriginalUserDefinedOperatorsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow"/>
<Field Name="OriginalUserDefinedOperatorsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow" SkipInNullabilityRewriter="true"/>
</Node>
<Node Name="BoundIncrementOperator" Base="BoundExpression">
......@@ -290,7 +290,7 @@
<!--The set of method symbols from which this operator's method was chosen.
Only kept in the tree if the operator was an error and overload resolution
was unable to choose a best method.-->
<Field Name="OriginalUserDefinedOperatorsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow"/>
<Field Name="OriginalUserDefinedOperatorsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow" SkipInNullabilityRewriter="true"/>
</Node>
<!-- Not really an operator since overload resolution is never required. -->
......@@ -373,7 +373,7 @@
<!--The set of method symbols from which this operator's method was chosen.
Only kept in the tree if the operator was an error and overload resolution
was unable to choose a best method.-->
<Field Name="OriginalUserDefinedOperatorsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow"/>
<Field Name="OriginalUserDefinedOperatorsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow" SkipInNullabilityRewriter="true"/>
</Node>
<Node Name="BoundTupleBinaryOperator" Base="BoundExpression">
......@@ -394,7 +394,7 @@
<!--The set of method symbols from which this operator's method was chosen.
Only kept in the tree if the operator was an error and overload resolution
was unable to choose a best method.-->
<Field Name="OriginalUserDefinedOperatorsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow"/>
<Field Name="OriginalUserDefinedOperatorsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow" SkipInNullabilityRewriter="true"/>
</Node>
<Node Name="BoundCompoundAssignmentOperator" Base="BoundExpression">
......@@ -436,7 +436,7 @@
<!--The set of method symbols from which this operator's method was chosen.
Only kept in the tree if the operator was an error and overload resolution
was unable to choose a best method.-->
<Field Name="OriginalUserDefinedOperatorsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow"/>
<Field Name="OriginalUserDefinedOperatorsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow" SkipInNullabilityRewriter="true"/>
</Node>
<Node Name="BoundAssignmentOperator" Base="BoundExpression">
......@@ -675,7 +675,7 @@
<!--The set of method symbols from which this conversion's method was chosen.
Only kept in the tree if the conversion was an error and overload resolution
was unable to choose a best method.-->
<Field Name="OriginalUserDefinedConversionsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow"/>
<Field Name="OriginalUserDefinedConversionsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow" SkipInNullabilityRewriter="true"/>
</Node>
<!--
......@@ -1486,7 +1486,7 @@
<!--The set of method symbols from which this call's method was chosen.
Only kept in the tree if the call was an error and overload resolution
was unable to choose a best method.-->
<Field Name="OriginalMethodsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow"/>
<Field Name="OriginalMethodsOpt" Type="ImmutableArray&lt;MethodSymbol&gt;" Null="Allow" SkipInNullabilityRewriter="true"/>
<!-- BinderOpt is added as a temporary solution for IOperation implementation and should probably be removed in the future -->
<Field Name="BinderOpt" Type="Binder?"/>
......@@ -1810,7 +1810,7 @@
<!--The set of indexer symbols from which this call's indexer was chosen.
Only kept in the tree if the call was an error and overload resolution
was unable to choose a best indexer.-->
<Field Name="OriginalIndexersOpt" Type="ImmutableArray&lt;PropertySymbol&gt;" Null="allow"/>
<Field Name="OriginalIndexersOpt" Type="ImmutableArray&lt;PropertySymbol&gt;" Null="allow" SkipInNullabilityRewriter="true"/>
</Node>
<Node Name="BoundIndexOrRangePatternIndexerAccess" Base="BoundExpression">
......
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
#nullable enable
using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.PooledObjects;
......@@ -83,5 +84,38 @@ private BoundNode VisitBinaryOperatorBase(BoundBinaryOperatorBase binaryOperator
return sym;
}
private ImmutableArray<T> GetUpdatedArray<T>(BoundNode expr, ImmutableArray<T> symbols) where T : Symbol
{
if (symbols.IsDefaultOrEmpty)
{
return symbols;
}
var builder = ArrayBuilder<T>.GetInstance(symbols.Length);
bool foundUpdate = false;
foreach (var originalSymbol in symbols)
{
if (_updatedSymbols.TryGetValue((expr, originalSymbol), out var updatedSymbol))
{
foundUpdate = true;
builder.Add((T)updatedSymbol);
}
else
{
builder.Add(originalSymbol);
}
}
if (foundUpdate)
{
return builder.ToImmutableAndFree();
}
else
{
builder.Free();
return symbols;
}
}
}
}
......@@ -609,7 +609,7 @@ public BoundPassByCopy Update(BoundExpression expression, TypeSymbol? type)
internal sealed partial class BoundBadExpression : BoundExpression
{
public BoundBadExpression(SyntaxNode syntax, LookupResultKind resultKind, ImmutableArray<Symbol> symbols, ImmutableArray<BoundExpression> childBoundNodes, TypeSymbol? type, bool hasErrors = false)
public BoundBadExpression(SyntaxNode syntax, LookupResultKind resultKind, ImmutableArray<Symbol?> symbols, ImmutableArray<BoundExpression> childBoundNodes, TypeSymbol? type, bool hasErrors = false)
: base(BoundKind.BadExpression, syntax, type, hasErrors || childBoundNodes.HasErrors())
{
......@@ -625,13 +625,13 @@ public BoundBadExpression(SyntaxNode syntax, LookupResultKind resultKind, Immuta
private readonly LookupResultKind _ResultKind;
public override LookupResultKind ResultKind { get { return _ResultKind;} }
public ImmutableArray<Symbol> Symbols { get; }
public ImmutableArray<Symbol?> Symbols { get; }
public ImmutableArray<BoundExpression> ChildBoundNodes { get; }
[DebuggerStepThrough]
public override BoundNode? Accept(BoundTreeVisitor visitor) => visitor.VisitBadExpression(this);
public BoundBadExpression Update(LookupResultKind resultKind, ImmutableArray<Symbol> symbols, ImmutableArray<BoundExpression> childBoundNodes, TypeSymbol? type)
public BoundBadExpression Update(LookupResultKind resultKind, ImmutableArray<Symbol?> symbols, ImmutableArray<BoundExpression> childBoundNodes, TypeSymbol? type)
{
if (resultKind != this.ResultKind || symbols != this.Symbols || childBoundNodes != this.ChildBoundNodes || !TypeSymbol.Equals(type, this.Type, TypeCompareKind.ConsiderEverything))
{
......@@ -11092,12 +11092,12 @@ public NullabilityRewriter(ImmutableDictionary<BoundExpression, (NullabilityInfo
if (_updatedNullabilities.TryGetValue(node, out (NullabilityInfo Info, TypeSymbol Type) infoAndType))
{
updatedNode = node.Update(node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, node.ApplicableMethods, expression, arguments, infoAndType.Type);
updatedNode = node.Update(node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, GetUpdatedArray(node, node.ApplicableMethods), expression, arguments, infoAndType.Type);
updatedNode.TopLevelNullability = infoAndType.Info;
}
else
{
updatedNode = node.Update(node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, node.ApplicableMethods, expression, arguments, node.Type);
updatedNode = node.Update(node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, GetUpdatedArray(node, node.ApplicableMethods), expression, arguments, node.Type);
}
return updatedNode;
}
......@@ -11176,12 +11176,12 @@ public NullabilityRewriter(ImmutableDictionary<BoundExpression, (NullabilityInfo
if (_updatedNullabilities.TryGetValue(node, out (NullabilityInfo Info, TypeSymbol Type) infoAndType))
{
updatedNode = node.Update(node.TypeArgumentsOpt, node.Name, node.Methods, GetUpdatedSymbol(node, node.LookupSymbolOpt), node.LookupError, node.Flags, receiverOpt, node.ResultKind);
updatedNode = node.Update(node.TypeArgumentsOpt, node.Name, GetUpdatedArray(node, node.Methods), GetUpdatedSymbol(node, node.LookupSymbolOpt), node.LookupError, node.Flags, receiverOpt, node.ResultKind);
updatedNode.TopLevelNullability = infoAndType.Info;
}
else
{
updatedNode = node.Update(node.TypeArgumentsOpt, node.Name, node.Methods, GetUpdatedSymbol(node, node.LookupSymbolOpt), node.LookupError, node.Flags, receiverOpt, node.ResultKind);
updatedNode = node.Update(node.TypeArgumentsOpt, node.Name, GetUpdatedArray(node, node.Methods), GetUpdatedSymbol(node, node.LookupSymbolOpt), node.LookupError, node.Flags, receiverOpt, node.ResultKind);
}
return updatedNode;
}
......@@ -11193,12 +11193,12 @@ public NullabilityRewriter(ImmutableDictionary<BoundExpression, (NullabilityInfo
if (_updatedNullabilities.TryGetValue(node, out (NullabilityInfo Info, TypeSymbol Type) infoAndType))
{
updatedNode = node.Update(node.Properties, receiverOpt, node.ResultKind);
updatedNode = node.Update(GetUpdatedArray(node, node.Properties), receiverOpt, node.ResultKind);
updatedNode.TopLevelNullability = infoAndType.Info;
}
else
{
updatedNode = node.Update(node.Properties, receiverOpt, node.ResultKind);
updatedNode = node.Update(GetUpdatedArray(node, node.Properties), receiverOpt, node.ResultKind);
}
return updatedNode;
}
......@@ -11265,12 +11265,12 @@ public NullabilityRewriter(ImmutableDictionary<BoundExpression, (NullabilityInfo
if (_updatedNullabilities.TryGetValue(node, out (NullabilityInfo Info, TypeSymbol Type) infoAndType))
{
updatedNode = node.Update(GetUpdatedSymbol(node, node.Constructor), node.ConstructorsGroup, arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, node.Expanded, node.ArgsToParamsOpt, node.ConstantValueOpt, initializerExpressionOpt, node.BinderOpt, infoAndType.Type);
updatedNode = node.Update(GetUpdatedSymbol(node, node.Constructor), GetUpdatedArray(node, node.ConstructorsGroup), arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, node.Expanded, node.ArgsToParamsOpt, node.ConstantValueOpt, initializerExpressionOpt, node.BinderOpt, infoAndType.Type);
updatedNode.TopLevelNullability = infoAndType.Info;
}
else
{
updatedNode = node.Update(GetUpdatedSymbol(node, node.Constructor), node.ConstructorsGroup, arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, node.Expanded, node.ArgsToParamsOpt, node.ConstantValueOpt, initializerExpressionOpt, node.BinderOpt, node.Type);
updatedNode = node.Update(GetUpdatedSymbol(node, node.Constructor), GetUpdatedArray(node, node.ConstructorsGroup), arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, node.Expanded, node.ArgsToParamsOpt, node.ConstantValueOpt, initializerExpressionOpt, node.BinderOpt, node.Type);
}
return updatedNode;
}
......@@ -11318,12 +11318,12 @@ public NullabilityRewriter(ImmutableDictionary<BoundExpression, (NullabilityInfo
if (_updatedNullabilities.TryGetValue(node, out (NullabilityInfo Info, TypeSymbol Type) infoAndType))
{
updatedNode = node.Update(node.Name, arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, initializerExpressionOpt, node.ApplicableMethods, infoAndType.Type);
updatedNode = node.Update(node.Name, arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, initializerExpressionOpt, GetUpdatedArray(node, node.ApplicableMethods), infoAndType.Type);
updatedNode.TopLevelNullability = infoAndType.Info;
}
else
{
updatedNode = node.Update(node.Name, arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, initializerExpressionOpt, node.ApplicableMethods, node.Type);
updatedNode = node.Update(node.Name, arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, initializerExpressionOpt, GetUpdatedArray(node, node.ApplicableMethods), node.Type);
}
return updatedNode;
}
......@@ -11434,12 +11434,12 @@ public NullabilityRewriter(ImmutableDictionary<BoundExpression, (NullabilityInfo
if (_updatedNullabilities.TryGetValue(node, out (NullabilityInfo Info, TypeSymbol Type) infoAndType))
{
updatedNode = node.Update(node.ApplicableMethods, expression, arguments, infoAndType.Type);
updatedNode = node.Update(GetUpdatedArray(node, node.ApplicableMethods), expression, arguments, infoAndType.Type);
updatedNode.TopLevelNullability = infoAndType.Info;
}
else
{
updatedNode = node.Update(node.ApplicableMethods, expression, arguments, node.Type);
updatedNode = node.Update(GetUpdatedArray(node, node.ApplicableMethods), expression, arguments, node.Type);
}
return updatedNode;
}
......@@ -11698,12 +11698,12 @@ public NullabilityRewriter(ImmutableDictionary<BoundExpression, (NullabilityInfo
if (_updatedNullabilities.TryGetValue(node, out (NullabilityInfo Info, TypeSymbol Type) infoAndType))
{
updatedNode = node.Update(receiverOpt, arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, node.ApplicableIndexers, infoAndType.Type);
updatedNode = node.Update(receiverOpt, arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, GetUpdatedArray(node, node.ApplicableIndexers), infoAndType.Type);
updatedNode.TopLevelNullability = infoAndType.Info;
}
else
{
updatedNode = node.Update(receiverOpt, arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, node.ApplicableIndexers, node.Type);
updatedNode = node.Update(receiverOpt, arguments, node.ArgumentNamesOpt, node.ArgumentRefKindsOpt, GetUpdatedArray(node, node.ApplicableIndexers), node.Type);
}
return updatedNode;
}
......
......@@ -991,7 +991,8 @@ string wasUpdatedCheck(Field field)
private static bool TypeIsTypeSymbol(Field field) => field.Type.TrimEnd('?') == "TypeSymbol";
private static bool TypeIsSymbol(Field field) => field.Type.TrimEnd('?').EndsWith("Symbol");
private static bool TypeIsSymbol(Field field) => TypeIsSymbol(field.Type);
private static bool TypeIsSymbol(string type) => type.TrimEnd('?').EndsWith("Symbol");
private string StripBound(string name)
{
......@@ -1200,7 +1201,7 @@ private void WriteTreeDumperNodeProducer()
Write("new TreeDumperNode(\"{0}\", null, new TreeDumperNode[] {{ Visit(node.{1}, null) }})", ToCamelCase(field.Name), field.Name);
else if (IsListOfDerived("BoundNode", field.Type))
{
if (IsImmutableArray(field.Type) && FieldNullHandling(node, field.Name) == NullHandling.Disallow)
if (IsImmutableArray(field.Type, out _) && FieldNullHandling(node, field.Name) == NullHandling.Disallow)
{
Write("new TreeDumperNode(\"{0}\", null, from x in node.{1} select Visit(x, null))", ToCamelCase(field.Name), field.Name);
}
......@@ -1499,7 +1500,11 @@ void writeUpdate(bool updatedType)
allSpecifiableFields,
field =>
{
if (IsDerivedOrListOfDerived("BoundNode", field.Type))
if (SkipInNullabilityRewriter(field))
{
return $"node.{field.Name}";
}
else if (IsDerivedOrListOfDerived("BoundNode", field.Type))
{
return ToCamelCase(field.Name);
}
......@@ -1511,6 +1516,10 @@ void writeUpdate(bool updatedType)
{
return $"GetUpdatedSymbol(node, node.{field.Name})";
}
else if (IsImmutableArray(field.Type, out var elementType) && TypeIsSymbol(elementType) && typeIsUpdated(elementType))
{
return $"GetUpdatedArray(node, node.{field.Name})";
}
else
{
return $"node.{field.Name}";
......@@ -1529,7 +1538,12 @@ static bool symbolIsPotentiallyUpdated(Field f)
if (f.Name == "Type") return false;
switch (f.Type.TrimEnd('?'))
return typeIsUpdated(f.Type);
}
static bool typeIsUpdated(string type)
{
switch (type.TrimEnd('?'))
{
case "LocalSymbol":
case "LabelSymbol":
......@@ -1563,17 +1577,23 @@ private bool IsListOfDerived(string baseType, string derivedType)
return IsNodeList(derivedType) && IsDerivedType(baseType, GetElementType(derivedType));
}
private bool IsImmutableArray(string typeName)
private bool IsImmutableArray(string typeName, out string elementType)
{
switch (_targetLang)
string immutableArrayPrefix = _targetLang switch
{
case TargetLanguage.CSharp:
return typeName.StartsWith("ImmutableArray<", StringComparison.Ordinal);
case TargetLanguage.VB:
return typeName.StartsWith("ImmutableArray(Of", StringComparison.OrdinalIgnoreCase);
default:
throw new ArgumentException("Unexpected target language", nameof(_targetLang));
TargetLanguage.CSharp => "ImmutableArray<",
TargetLanguage.VB => "ImmutableArray(Of ",
_ => throw new InvalidOperationException($"Unknown target language {_targetLang}")
};
if (typeName.StartsWith(immutableArrayPrefix, StringComparison.Ordinal))
{
elementType = typeName[immutableArrayPrefix.Length..^1];
return true;
}
elementType = null;
return false;
}
private bool IsNodeList(string typeName)
......@@ -1714,6 +1734,11 @@ private static bool SkipInNullabilityRewriter(Node n)
return string.Compare(n.SkipInNullabilityRewriter, "true", true) == 0;
}
private static bool SkipInNullabilityRewriter(Field f)
{
return string.Compare(f.SkipInNullabilityRewriter, "true", ignoreCase: true) == 0;
}
private string ToCamelCase(string name)
{
if (char.IsUpper(name[0]))
......
......@@ -9,7 +9,7 @@
<RootNamespace>Roslyn.Compilers.Internal.BoundTreeGenerator</RootNamespace>
<AssemblyName>BoundTreeGenerator</AssemblyName>
<AutoGenerateBindingRedirects>True</AutoGenerateBindingRedirects>
<TargetFramework>netcoreapp2.1</TargetFramework>
<TargetFramework>netcoreapp3.0</TargetFramework>
<IsShipping>false</IsShipping>
</PropertyGroup>
<ItemGroup>
......
......@@ -91,6 +91,9 @@ public class Field
[XmlAttribute]
public string SkipInVisitor;
[XmlAttribute]
public string SkipInNullabilityRewriter;
}
public class EnumType : TreeType
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册