未验证 提交 36f71913 编写于 作者: C CyrusNajmabadi 提交者: GitHub

Merge pull request #41680 from ThadHouse/refstructdisallowIEquatable

Disallow generating IEquatable for Ref Struct
...@@ -1940,6 +1940,58 @@ public bool Equals(Program other) ...@@ -1940,6 +1940,58 @@ public bool Equals(Program other)
parameters: CSharp6Implicit); parameters: CSharp6Implicit);
} }
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsGenerateEqualsAndGetHashCode)]
[WorkItem(25708, "https://github.com/dotnet/roslyn/issues/25708")]
public async Task TestOverrideEqualsOnRefStructReturnsFalse()
{
await TestWithPickMembersDialogAsync(
@"
ref struct Program
{
public string s;
[||]
}",
@"
ref struct Program
{
public string s;
public override bool Equals(object obj)
{
return false;
}
}",
chosenSymbols: null);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsGenerateEqualsAndGetHashCode)]
[WorkItem(25708, "https://github.com/dotnet/roslyn/issues/25708")]
public async Task TestImplementIEquatableOnRefStructSkipsIEquatable()
{
await TestWithPickMembersDialogAsync(
@"
ref struct Program
{
public string s;
[||]
}",
@"
ref struct Program
{
public string s;
public override bool Equals(object obj)
{
return false;
}
}",
chosenSymbols: null,
// We are forcefully enabling the ImplementIEquatable option, as that is our way
// to test that the option does nothing. The VS mode will ensure if the option
// is not available it will not be shown.
optionsCallback: options => EnableOption(options, ImplementIEquatableId));
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsGenerateEqualsAndGetHashCode)] [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsGenerateEqualsAndGetHashCode)]
public async Task TestImplementIEquatableOnStructInNullableContextWithUnannotatedMetadata() public async Task TestImplementIEquatableOnStructInNullableContextWithUnannotatedMetadata()
{ {
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license. // The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information. // See the LICENSE file in the project root for more information.
#nullable enable
using System; using System;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Composition; using System.Composition;
...@@ -35,7 +37,7 @@ internal partial class GenerateEqualsAndGetHashCodeFromMembersCodeRefactoringPro ...@@ -35,7 +37,7 @@ internal partial class GenerateEqualsAndGetHashCodeFromMembersCodeRefactoringPro
private const string EqualsName = nameof(object.Equals); private const string EqualsName = nameof(object.Equals);
private const string GetHashCodeName = nameof(object.GetHashCode); private const string GetHashCodeName = nameof(object.GetHashCode);
private readonly IPickMembersService _pickMembersService_forTestingPurposes; private readonly IPickMembersService? _pickMembersService_forTestingPurposes;
[ImportingConstructor] [ImportingConstructor]
[Obsolete(MefConstruction.ImportingConstructorMessage, error: true)] [Obsolete(MefConstruction.ImportingConstructorMessage, error: true)]
...@@ -45,7 +47,7 @@ public GenerateEqualsAndGetHashCodeFromMembersCodeRefactoringProvider() ...@@ -45,7 +47,7 @@ public GenerateEqualsAndGetHashCodeFromMembersCodeRefactoringProvider()
} }
[SuppressMessage("RoslynDiagnosticsReliability", "RS0034:Exported parts should have [ImportingConstructor]", Justification = "Used incorrectly by tests")] [SuppressMessage("RoslynDiagnosticsReliability", "RS0034:Exported parts should have [ImportingConstructor]", Justification = "Used incorrectly by tests")]
public GenerateEqualsAndGetHashCodeFromMembersCodeRefactoringProvider(IPickMembersService pickMembersService) public GenerateEqualsAndGetHashCodeFromMembersCodeRefactoringProvider(IPickMembersService? pickMembersService)
=> _pickMembersService_forTestingPurposes = pickMembersService; => _pickMembersService_forTestingPurposes = pickMembersService;
public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context) public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
...@@ -69,9 +71,9 @@ private async Task HandleNonSelectionAsync(CodeRefactoringContext context) ...@@ -69,9 +71,9 @@ private async Task HandleNonSelectionAsync(CodeRefactoringContext context)
{ {
var (document, textSpan, cancellationToken) = context; var (document, textSpan, cancellationToken) = context;
var syntaxFacts = document.GetLanguageService<ISyntaxFactsService>(); var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
var sourceText = await document.GetTextAsync(cancellationToken).ConfigureAwait(false); var sourceText = await document.GetTextAsync(cancellationToken).ConfigureAwait(false);
var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false); var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
// We offer the refactoring when the user is either on the header of a class/struct, // We offer the refactoring when the user is either on the header of a class/struct,
// or if they're between any members of a class/struct and are on a blank line. // or if they're between any members of a class/struct and are on a blank line.
...@@ -81,7 +83,7 @@ private async Task HandleNonSelectionAsync(CodeRefactoringContext context) ...@@ -81,7 +83,7 @@ private async Task HandleNonSelectionAsync(CodeRefactoringContext context)
return; return;
} }
var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false); var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
// Only supported on classes/structs. // Only supported on classes/structs.
var containingType = semanticModel.GetDeclaredSymbol(typeDeclaration) as INamedTypeSymbol; var containingType = semanticModel.GetDeclaredSymbol(typeDeclaration) as INamedTypeSymbol;
...@@ -134,14 +136,22 @@ private bool HasOperator(INamedTypeSymbol containingType, string operatorName) ...@@ -134,14 +136,22 @@ private bool HasOperator(INamedTypeSymbol containingType, string operatorName)
private bool CanImplementIEquatable( private bool CanImplementIEquatable(
SemanticModel semanticModel, INamedTypeSymbol containingType, SemanticModel semanticModel, INamedTypeSymbol containingType,
[NotNullWhen(true)] out INamedTypeSymbol constructedType) [NotNullWhen(true)] out INamedTypeSymbol? constructedType)
{
// A ref struct can never implement an interface, therefore never add IEquatable to the selection
// options if the type is a ref struct.
if (!containingType.IsRefLikeType)
{ {
var equatableTypeOpt = semanticModel.Compilation.GetTypeByMetadataName(typeof(IEquatable<>).FullName); var equatableTypeOpt = semanticModel.Compilation.GetTypeByMetadataName(typeof(IEquatable<>).FullName!);
if (equatableTypeOpt != null) if (equatableTypeOpt != null)
{ {
constructedType = equatableTypeOpt.Construct(containingType); constructedType = equatableTypeOpt.Construct(containingType);
// A ref struct can never implement an interface, therefore never add IEquatable to the selection
// options if the type is a ref struct.
return !containingType.AllInterfaces.Contains(constructedType); return !containingType.AllInterfaces.Contains(constructedType);
} }
}
constructedType = null; constructedType = null;
return false; return false;
...@@ -174,8 +184,8 @@ private void GetExistingMemberInfo(INamedTypeSymbol containingType, out bool has ...@@ -174,8 +184,8 @@ private void GetExistingMemberInfo(INamedTypeSymbol containingType, out bool has
GetExistingMemberInfo( GetExistingMemberInfo(
info.ContainingType, out var hasEquals, out var hasGetHashCode); info.ContainingType, out var hasEquals, out var hasGetHashCode);
var syntaxFacts = document.GetLanguageService<ISyntaxFactsService>(); var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false); var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var typeDeclaration = syntaxFacts.GetContainingTypeDeclaration(root, textSpan.Start); var typeDeclaration = syntaxFacts.GetContainingTypeDeclaration(root, textSpan.Start);
return await CreateActionsAsync( return await CreateActionsAsync(
...@@ -240,15 +250,12 @@ private void GetExistingMemberInfo(INamedTypeSymbol containingType, out bool has ...@@ -240,15 +250,12 @@ private void GetExistingMemberInfo(INamedTypeSymbol containingType, out bool has
Document document, SyntaxNode typeDeclaration, INamedTypeSymbol containingType, ImmutableArray<ISymbol> members, Document document, SyntaxNode typeDeclaration, INamedTypeSymbol containingType, ImmutableArray<ISymbol> members,
bool generateEquals, bool generateGetHashCode, CancellationToken cancellationToken) bool generateEquals, bool generateGetHashCode, CancellationToken cancellationToken)
{ {
var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false); var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var options = await document.GetOptionsAsync(cancellationToken).ConfigureAwait(false); var options = await document.GetOptionsAsync(cancellationToken).ConfigureAwait(false);
using var _ = ArrayBuilder<PickMembersOption>.GetInstance(out var pickMembersOptions); using var _ = ArrayBuilder<PickMembersOption>.GetInstance(out var pickMembersOptions);
var canImplementIEquatable = CanImplementIEquatable(semanticModel, containingType, out var equatableTypeOpt); if (CanImplementIEquatable(semanticModel, containingType, out var equatableTypeOpt))
var hasExistingOperators = HasOperators(containingType);
if (canImplementIEquatable)
{ {
var value = options.GetOption(GenerateEqualsAndGetHashCodeFromMembersOptions.ImplementIEquatable); var value = options.GetOption(GenerateEqualsAndGetHashCodeFromMembersOptions.ImplementIEquatable);
...@@ -262,7 +269,7 @@ private void GetExistingMemberInfo(INamedTypeSymbol containingType, out bool has ...@@ -262,7 +269,7 @@ private void GetExistingMemberInfo(INamedTypeSymbol containingType, out bool has
value)); value));
} }
if (!hasExistingOperators) if (!HasOperators(containingType))
{ {
var value = options.GetOption(GenerateEqualsAndGetHashCodeFromMembersOptions.GenerateOperators); var value = options.GetOption(GenerateEqualsAndGetHashCodeFromMembersOptions.GenerateOperators);
pickMembersOptions.Add(new PickMembersOption( pickMembersOptions.Add(new PickMembersOption(
...@@ -287,7 +294,7 @@ private void GetExistingMemberInfo(INamedTypeSymbol containingType, out bool has ...@@ -287,7 +294,7 @@ private void GetExistingMemberInfo(INamedTypeSymbol containingType, out bool has
{ {
// if we're generating equals for a struct, then also add IEquatable<S> support as // if we're generating equals for a struct, then also add IEquatable<S> support as
// well as operators (as long as the struct does not already have them). // well as operators (as long as the struct does not already have them).
var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false); var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
implementIEquatable = CanImplementIEquatable(semanticModel, containingType, out _); implementIEquatable = CanImplementIEquatable(semanticModel, containingType, out _);
generateOperators = !HasOperators(containingType); generateOperators = !HasOperators(containingType);
} }
......
...@@ -99,7 +99,16 @@ public static IMethodSymbol CreateEqualsMethod(this Compilation compilation, Imm ...@@ -99,7 +99,16 @@ public static IMethodSymbol CreateEqualsMethod(this Compilation compilation, Imm
ImmutableArray<ISymbol> members, ImmutableArray<ISymbol> members,
string localNameOpt) string localNameOpt)
{ {
var statements = ArrayBuilder<SyntaxNode>.GetInstance(); using var _1 = ArrayBuilder<SyntaxNode>.GetInstance(out var statements);
// A ref like type can not be boxed. Because of this an overloaded Equals taking object in the general case
// can never be true, because an equivalent object can never be boxed into the object itself. Therefore only
// need to return false.
if (containingType.IsRefLikeType)
{
statements.Add(factory.ReturnStatement(factory.FalseLiteralExpression()));
return statements.ToImmutable();
}
// Come up with a good name for the local variable we're going to compare against. // Come up with a good name for the local variable we're going to compare against.
// For example, if the class name is "CustomerOrder" then we'll generate: // For example, if the class name is "CustomerOrder" then we'll generate:
...@@ -113,7 +122,7 @@ public static IMethodSymbol CreateEqualsMethod(this Compilation compilation, Imm ...@@ -113,7 +122,7 @@ public static IMethodSymbol CreateEqualsMethod(this Compilation compilation, Imm
// These will be all the expressions that we'll '&&' together inside the final // These will be all the expressions that we'll '&&' together inside the final
// return statement of 'Equals'. // return statement of 'Equals'.
using var _ = ArrayBuilder<SyntaxNode>.GetInstance(out var expressions); using var _2 = ArrayBuilder<SyntaxNode>.GetInstance(out var expressions);
if (factory.SupportsPatterns(parseOptions)) if (factory.SupportsPatterns(parseOptions))
{ {
...@@ -189,7 +198,7 @@ public static IMethodSymbol CreateEqualsMethod(this Compilation compilation, Imm ...@@ -189,7 +198,7 @@ public static IMethodSymbol CreateEqualsMethod(this Compilation compilation, Imm
statements.Add(factory.ReturnStatement( statements.Add(factory.ReturnStatement(
expressions.Aggregate(factory.LogicalAndExpression))); expressions.Aggregate(factory.LogicalAndExpression)));
return statements.ToImmutableAndFree(); return statements.ToImmutable();
} }
private static void AddMemberChecks( private static void AddMemberChecks(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册