diff --git a/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlArrayTypeMapping.cs b/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlArrayTypeMapping.cs index 0b99a3168727e2ed25af52a1ba08005e869b0fd6..56cc4d6a63f12c08725d8a67db51d8f96d64da24 100644 --- a/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlArrayTypeMapping.cs +++ b/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlArrayTypeMapping.cs @@ -23,7 +23,12 @@ using System; using System.Collections.Generic; +using System.Diagnostics; +using System.Linq.Expressions; using System.Text; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.ChangeTracking; +using Microsoft.EntityFrameworkCore.Migrations.Operations; namespace Microsoft.EntityFrameworkCore.Storage.Internal { @@ -35,11 +40,13 @@ public sealed class NpgsqlArrayTypeMapping : RelationalTypeMapping /// Creates the default array mapping (i.e. for the single-dimensional CLR array type) /// internal NpgsqlArrayTypeMapping(RelationalTypeMapping elementMapping) - : this(elementMapping, elementMapping.ClrType.MakeArrayType()) - {} + : this(elementMapping, elementMapping.ClrType.MakeArrayType()) {} internal NpgsqlArrayTypeMapping(RelationalTypeMapping elementMapping, Type arrayType) - : base(GenerateArrayTypeName(elementMapping.StoreType), arrayType) + : this(elementMapping, arrayType, CreateComparer(elementMapping, arrayType)) {} + + NpgsqlArrayTypeMapping(RelationalTypeMapping elementMapping, Type arrayType, ValueComparer comparer) + : base(GenerateArrayTypeName(elementMapping.StoreType), arrayType, null, comparer) { ElementMapping = elementMapping; } @@ -74,7 +81,7 @@ static string GenerateArrayTypeName(string elementTypeName) } public override RelationalTypeMapping Clone(string storeType, int? size) - => new NpgsqlArrayTypeMapping(ElementMapping); + => new NpgsqlArrayTypeMapping(ElementMapping, ClrType, Comparer); protected override string GenerateNonNullSqlLiteral(object value) { @@ -95,5 +102,145 @@ protected override string GenerateNonNullSqlLiteral(object value) sb.Append("]"); return sb.ToString(); } + + #region Value Comparison + + static ValueComparer CreateComparer(RelationalTypeMapping elementMapping, Type arrayType) + { + Debug.Assert(arrayType.IsArray); + var elementType = arrayType.GetElementType(); + + // In .NET, single-dimensional arrays implement IList and can therefore be accessed generically + // (i.e. efficiently). Multi-dimensional arrays don't, and can only be accessed non-generically via IList. + + if (!typeof(IList<>).MakeGenericType(elementType).IsAssignableFrom(arrayType)) + return null; // TODO: Implement multi-dimensional array support (#314) + + // We usee different comparer implementations based on whether we have a non-null element comparer, + // and if not, whether the element is IEquatable + + if (elementMapping.Comparer != null) + return (ValueComparer)Activator.CreateInstance( + typeof(SingleDimComparerWithComparer<>).MakeGenericType(elementType), elementMapping); + + if (typeof(IEquatable<>).MakeGenericType(elementType).IsAssignableFrom(elementType)) + return (ValueComparer)Activator.CreateInstance(typeof(SingleDimComparerWithIEquatable<>).MakeGenericType(elementType)); + + // There's no custom comparer, and the element type doesn't implement IEquatable. We have + // no choice but to use the non-generic Equals method. + return (ValueComparer)Activator.CreateInstance(typeof(SingleDimComparerWithEquals<>).MakeGenericType(elementType)); + } + + class SingleDimComparerWithComparer : ValueComparer> + { + public SingleDimComparerWithComparer(RelationalTypeMapping elementMapping) : base( + (a, b) => Compare(a, b, elementMapping.Comparer.CompareFunc), + source => Snapshot(source, elementMapping.Comparer.SnapshotFunc)) {} + + static bool Compare(IList a, IList b, Func elementComparer) + { + if (a.Count != b.Count) + return false; + + // Note: the following currently boxes every element access because ValueComparer isn't really + // generic (see https://github.com/aspnet/EntityFrameworkCore/issues/11072) + for (var i = 0; i < a.Count; i++) + if (!elementComparer(a[i], b[i])) + return false; + + return true; + } + + static IList Snapshot(IList source, Func elementSnapshotFunc) + { + var snapshot = new TElem[source.Count]; + // Note: the following currently boxes every element access because ValueComparer isn't really + // generic (see https://github.com/aspnet/EntityFrameworkCore/issues/11072) + for (var i = 0; i < source.Count; i++) + snapshot[i] = (TElem)elementSnapshotFunc(source[i]); + return snapshot; + } + } + + class SingleDimComparerWithIEquatable : ValueComparer> + where TElem : IEquatable + { + public SingleDimComparerWithIEquatable(): base( + (a, b) => Compare(a, b), + source => Snapshot(source)) {} + + static bool Compare(IList a, IList b) + { + if (a.Count != b.Count) + return false; + + for (var i = 0; i < a.Count; i++) + { + var elem1 = a[i]; + var elem2 = b[i]; + if (elem1 == null) + { + if (elem2 == null) + continue; + return false; + } + if (!elem1.Equals(elem2)) + return false; + } + + return true; + } + + static IList Snapshot(IList source) + { + var snapshot = new TElem[source.Count]; + for (var i = 0; i < source.Count; i++) + snapshot[i] = source[i]; + return snapshot; + } + } + + class SingleDimComparerWithEquals : ValueComparer> + { + public SingleDimComparerWithEquals() : base( + (a, b) => Compare(a, b), + source => Snapshot(source)) {} + + static bool Compare(IList a, IList b) + { + if (a.Count != b.Count) + return false; + + // Note: the following currently boxes every element access because ValueComparer isn't really + // generic (see https://github.com/aspnet/EntityFrameworkCore/issues/11072) + for (var i = 0; i < a.Count; i++) + { + var elem1 = a[i]; + var elem2 = b[i]; + if (elem1 == null) + { + if (elem2 == null) + continue; + return false; + } + if (!elem1.Equals(elem2)) + return false; + } + + return true; + } + + static IList Snapshot(IList source) + { + var snapshot = new TElem[source.Count]; + // Note: the following currently boxes every element access because ValueComparer isn't really + // generic (see https://github.com/aspnet/EntityFrameworkCore/issues/11072) + for (var i = 0; i < source.Count; i++) + snapshot[i] = source[i]; + return snapshot; + } + } + + #endregion Value Comparison } } diff --git a/test/EFCore.PG.Tests/Storage/NpgsqlTypeMappingTest.cs b/test/EFCore.PG.Tests/Storage/NpgsqlTypeMappingTest.cs index 5b87cb87d93a3eff201961090a867d60fd39b4dd..bd89145bd82938e940f37522a0e708ce4f713bc7 100644 --- a/test/EFCore.PG.Tests/Storage/NpgsqlTypeMappingTest.cs +++ b/test/EFCore.PG.Tests/Storage/NpgsqlTypeMappingTest.cs @@ -163,6 +163,39 @@ public void GenerateSqlLiteral_returns_bit_literal() public void GenerateSqlLiteral_returns_array_literal() => Assert.Equal("ARRAY[3,4]", GetMapping(typeof(int[])).GenerateSqlLiteral(new[] {3, 4})); + [Fact] + public void ValueComparer_int_array() + { + // This exercises array's comparer when the element doesn't have a comparer, but it implements + // IEquatable + var source = new[] { 2, 3, 4 }; + + var comparer = GetMapping(typeof(int[])).Comparer; + var snapshot = (int[])comparer.SnapshotFunc(source); + Assert.Equal(source, snapshot); + Assert.True(comparer.CompareFunc(source, snapshot)); + snapshot[1] = 8; + Assert.False(comparer.CompareFunc(source, snapshot)); + } + + [Fact] + public void ValueComparer_hstore_array() + { + // This exercises array's comparer when the element has its own non-null comparer + var source = new[] + { + new Dictionary { { "k1", "v1"} }, + new Dictionary { { "k2", "v2"} }, + }; + + var comparer = GetMapping(typeof(Dictionary[])).Comparer; + var snapshot = (Dictionary[])comparer.SnapshotFunc(source); + Assert.Equal(source, snapshot); + Assert.True(comparer.CompareFunc(source, snapshot)); + snapshot[1]["k2"] = "v8"; + Assert.False(comparer.CompareFunc(source, snapshot)); + } + [Fact] public void GenerateSqlLiteral_returns_bytea_literal() => Assert.Equal(@"BYTEA E'\\xDEADBEEF'", GetMapping("bytea").GenerateSqlLiteral(new byte[] { 222, 173, 190, 239 })); @@ -184,7 +217,7 @@ public void ValueComparer_hstore() { "k1", "v1"}, { "k2", "v2"} }; - + var comparer = GetMapping("hstore").Comparer; var snapshot = (Dictionary)comparer.SnapshotFunc(source); Assert.Equal(source, snapshot);