diff --git a/src/EditorFeatures/CSharpTest/GenerateFromMembers/GenerateEqualsAndGetHashCodeFromMembers/GenerateEqualsAndGetHashCodeFromMembersTests.cs b/src/EditorFeatures/CSharpTest/GenerateFromMembers/GenerateEqualsAndGetHashCodeFromMembers/GenerateEqualsAndGetHashCodeFromMembersTests.cs index 843d9c7e26ffd96f486013560196511b259199dc..087ded38dff6560791e3b0ce7fa79131e165501e 100644 --- a/src/EditorFeatures/CSharpTest/GenerateFromMembers/GenerateEqualsAndGetHashCodeFromMembers/GenerateEqualsAndGetHashCodeFromMembersTests.cs +++ b/src/EditorFeatures/CSharpTest/GenerateFromMembers/GenerateEqualsAndGetHashCodeFromMembers/GenerateEqualsAndGetHashCodeFromMembersTests.cs @@ -22,9 +22,6 @@ public class GenerateEqualsAndGetHashCodeFromMembersTests : AbstractCSharpCodeAc private static readonly TestParameters CSharp6 = new TestParameters(parseOptions: TestOptions.Regular.WithLanguageVersion(LanguageVersion.CSharp6)); - private static readonly TestParameters CSharp8 = - new TestParameters(parseOptions: TestOptions.Regular.WithLanguageVersion(LanguageVersion.CSharp8)); - protected override CodeRefactoringProvider CreateCodeRefactoringProvider(Workspace workspace, TestParameters parameters) => new GenerateEqualsAndGetHashCodeFromMembersCodeRefactoringProvider((IPickMembersService)parameters.fixProviderData); @@ -88,6 +85,45 @@ public override bool Equals(object obj) parameters: CSharp6); } + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsGenerateEqualsAndGetHashCode)] + public async Task TestNullableReferenceIEquatable() + { + await TestInRegularAndScript1Async( +@"#nullable enable + +using System; +using System.Collections.Generic; + +class S : IEquatable { } + +class Program +{ + [|S? a;|] +}", +@"#nullable enable + +using System; +using System.Collections.Generic; + +class S : IEquatable { } + +class Program +{ + S? a; + + public override bool Equals(object? obj) + { + return obj is Program program && + EqualityComparer.Default.Equals(a, program.a); + } + + public override int GetHashCode() + { + return -1757793268 + EqualityComparer.Default.GetHashCode(a); + } +}", index: 1); + } + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsGenerateEqualsAndGetHashCode)] public async Task TestValueIEquatable() { @@ -422,8 +458,7 @@ public override bool Equals(object? obj) return obj is Program program && a == program.a; } -}", -parameters: CSharp8); +}"); } [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsGenerateEqualsAndGetHashCode)] diff --git a/src/Workspaces/Core/Portable/Shared/Extensions/SyntaxGeneratorExtensions_CreateEqualsMethod.cs b/src/Workspaces/Core/Portable/Shared/Extensions/SyntaxGeneratorExtensions_CreateEqualsMethod.cs index f577b7184536d1cd858fa679d9798ed164e3fea6..dac28588ba3c8fe0bfd6f870bd78fa6942c81a99 100644 --- a/src/Workspaces/Core/Portable/Shared/Extensions/SyntaxGeneratorExtensions_CreateEqualsMethod.cs +++ b/src/Workspaces/Core/Portable/Shared/Extensions/SyntaxGeneratorExtensions_CreateEqualsMethod.cs @@ -365,7 +365,7 @@ private static bool IsPrimitiveValueType(ITypeSymbol typeSymbol) ITypeSymbol type) { var equalityComparerType = compilation.EqualityComparerOfTType(); - var constructedType = equalityComparerType.Construct(type); + var constructedType = equalityComparerType.ConstructWithNullability(type); return factory.MemberAccessExpression( factory.TypeExpression(constructedType), factory.IdentifierName(DefaultName)); @@ -375,8 +375,8 @@ private static ITypeSymbol GetType(Compilation compilation, ISymbol symbol) { switch (symbol) { - case IFieldSymbol field: return field.Type; - case IPropertySymbol property: return property.Type; + case IFieldSymbol field: return field.GetTypeWithAnnotatedNullability(); + case IPropertySymbol property: return property.GetTypeWithAnnotatedNullability(); default: return compilation.GetSpecialType(SpecialType.System_Object); } }