未验证 提交 31e0bd18 编写于 作者: S Stephen Toub 提交者: GitHub

Improve vectorization of Enumerable.Min/Max (#76144)

- Expand it to all supported types with `Enumerable.Min<T>`/`Max<T>`
- Combine Min/Max into a single method using a static abstract interface with generic specialization to differentiate
- Use `Vector128<T>` and `Vector256<T>` instead of `Vector<T>`
- Improve test coverage
上级 c7c6aa03
......@@ -74,6 +74,7 @@
<Compile Include="System\Linq\Last.cs" />
<Compile Include="System\Linq\Lookup.cs" />
<Compile Include="System\Linq\Max.cs" />
<Compile Include="System\Linq\MaxMin.cs" />
<Compile Include="System\Linq\Min.cs" />
<Compile Include="System\Linq\OrderBy.cs" />
<Compile Include="System\Linq\OrderedEnumerable.cs" />
......@@ -104,5 +105,6 @@
<Reference Include="System.Numerics.Vectors" />
<Reference Include="System.Runtime" />
<Reference Include="System.Runtime.InteropServices" />
<Reference Include="System.Runtime.Intrinsics" />
</ItemGroup>
</Project>
\ No newline at end of file
......@@ -3,94 +3,26 @@
using System.Collections.Generic;
using System.Numerics;
using System.Runtime.Intrinsics;
namespace System.Linq
{
public static partial class Enumerable
{
public static int Max(this IEnumerable<int> source) => MaxInteger(source);
public static int Max(this IEnumerable<int> source) => MinMaxInteger<int, MaxCalc<int>>(source);
public static int? Max(this IEnumerable<int?> source) => MaxInteger(source);
public static long Max(this IEnumerable<long> source) => MaxInteger(source);
public static long? Max(this IEnumerable<long?> source) => MaxInteger(source);
public static long Max(this IEnumerable<long> source) => MinMaxInteger<long, MaxCalc<long>>(source);
private static T MaxInteger<T>(this IEnumerable<T> source) where T : struct, IBinaryInteger<T>
private struct MaxCalc<T> : IMinMaxCalc<T> where T : struct, IBinaryInteger<T>
{
T value;
if (source.TryGetSpan(out ReadOnlySpan<T> span))
{
if (span.IsEmpty)
{
ThrowHelper.ThrowNoElementsException();
}
// Vectorize the search if possible.
int index;
if (Vector.IsHardwareAccelerated && span.Length >= Vector<T>.Count * 2)
{
// The span is at least two vectors long. Create a vector from the first N elements,
// and then repeatedly compare that against the next vector from the span. At the end,
// the resulting vector will contain the maximum values found, and we then need only
// to find the max of those.
var maxes = new Vector<T>(span);
index = Vector<T>.Count;
do
{
maxes = Vector.Max(maxes, new Vector<T>(span.Slice(index)));
index += Vector<T>.Count;
}
while (index + Vector<T>.Count <= span.Length);
value = maxes[0];
for (int i = 1; i < Vector<T>.Count; i++)
{
if (maxes[i] > value)
{
value = maxes[i];
}
}
}
else
{
value = span[0];
index = 1;
}
// Iterate through the remaining elements, comparing against the max.
for (int i = index; (uint)i < (uint)span.Length; i++)
{
if (span[i] > value)
{
value = span[i];
}
}
return value;
}
using (IEnumerator<T> e = source.GetEnumerator())
{
if (!e.MoveNext())
{
ThrowHelper.ThrowNoElementsException();
}
public static bool Compare(T left, T right) => left > right;
public static Vector128<T> Compare(Vector128<T> left, Vector128<T> right) => Vector128.Max(left, right);
public static Vector256<T> Compare(Vector256<T> left, Vector256<T> right) => Vector256.Max(left, right);
}
value = e.Current;
while (e.MoveNext())
{
T x = e.Current;
if (x > value)
{
value = x;
}
}
}
public static int? Max(this IEnumerable<int?> source) => MaxInteger(source);
return value;
}
public static long? Max(this IEnumerable<long?> source) => MaxInteger(source);
private static T? MaxInteger<T>(this IEnumerable<T?> source) where T : struct, IBinaryInteger<T>
{
......@@ -386,6 +318,17 @@ public static decimal Max(this IEnumerable<decimal> source)
comparer ??= Comparer<TSource>.Default;
if (typeof(TSource) == typeof(byte) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<byte, MaxCalc<byte>>((IEnumerable<byte>)source);
if (typeof(TSource) == typeof(sbyte) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<sbyte, MaxCalc<sbyte>>((IEnumerable<sbyte>)source);
if (typeof(TSource) == typeof(ushort) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<ushort, MaxCalc<ushort>>((IEnumerable<ushort>)source);
if (typeof(TSource) == typeof(short) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<short, MaxCalc<short>>((IEnumerable<short>)source);
if (typeof(TSource) == typeof(uint) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<uint, MaxCalc<uint>>((IEnumerable<uint>)source);
if (typeof(TSource) == typeof(int) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<int, MaxCalc<int>>((IEnumerable<int>)source);
if (typeof(TSource) == typeof(ulong) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<ulong, MaxCalc<ulong>>((IEnumerable<ulong>)source);
if (typeof(TSource) == typeof(long) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<long, MaxCalc<long>>((IEnumerable<long>)source);
if (typeof(TSource) == typeof(nuint) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<nuint, MaxCalc<nuint>>((IEnumerable<nuint>)source);
if (typeof(TSource) == typeof(nint) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<nint, MaxCalc<nint>>((IEnumerable<nint>)source);
TSource? value = default;
using (IEnumerator<TSource> e = source.GetEnumerator())
{
......
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
namespace System.Linq
{
public static partial class Enumerable
{
private interface IMinMaxCalc<T> where T : struct, IBinaryInteger<T>
{
public static abstract bool Compare(T left, T right);
public static abstract Vector128<T> Compare(Vector128<T> left, Vector128<T> right);
public static abstract Vector256<T> Compare(Vector256<T> left, Vector256<T> right);
}
private static T MinMaxInteger<T, TMinMax>(this IEnumerable<T> source)
where T : struct, IBinaryInteger<T>
where TMinMax : IMinMaxCalc<T>
{
T value;
if (source.TryGetSpan(out ReadOnlySpan<T> span))
{
if (span.IsEmpty)
{
ThrowHelper.ThrowNoElementsException();
}
if (!Vector128.IsHardwareAccelerated || span.Length < Vector128<T>.Count)
{
value = span[0];
for (int i = 1; i < span.Length; i++)
{
if (TMinMax.Compare(span[i], value))
{
value = span[i];
}
}
}
else if (!Vector256.IsHardwareAccelerated || span.Length < Vector256<T>.Count)
{
ref T current = ref MemoryMarshal.GetReference(span);
ref T lastVectorStart = ref Unsafe.Add(ref current, span.Length - Vector128<T>.Count);
Vector128<T> best = Vector128.LoadUnsafe(ref current);
current = ref Unsafe.Add(ref current, Vector128<T>.Count);
while (Unsafe.IsAddressLessThan(ref current, ref lastVectorStart))
{
best = TMinMax.Compare(best, Vector128.LoadUnsafe(ref current));
current = ref Unsafe.Add(ref current, Vector128<T>.Count);
}
best = TMinMax.Compare(best, Vector128.LoadUnsafe(ref lastVectorStart));
value = best[0];
for (int i = 1; i < Vector128<T>.Count; i++)
{
if (TMinMax.Compare(best[i], value))
{
value = best[i];
}
}
}
else
{
ref T current = ref MemoryMarshal.GetReference(span);
ref T lastVectorStart = ref Unsafe.Add(ref current, span.Length - Vector256<T>.Count);
Vector256<T> best = Vector256.LoadUnsafe(ref current);
current = ref Unsafe.Add(ref current, Vector256<T>.Count);
while (Unsafe.IsAddressLessThan(ref current, ref lastVectorStart))
{
best = TMinMax.Compare(best, Vector256.LoadUnsafe(ref current));
current = ref Unsafe.Add(ref current, Vector256<T>.Count);
}
best = TMinMax.Compare(best, Vector256.LoadUnsafe(ref lastVectorStart));
value = best[0];
for (int i = 1; i < Vector256<T>.Count; i++)
{
if (TMinMax.Compare(best[i], value))
{
value = best[i];
}
}
}
}
else
{
using (IEnumerator<T> e = source.GetEnumerator())
{
if (!e.MoveNext())
{
ThrowHelper.ThrowNoElementsException();
}
value = e.Current;
while (e.MoveNext())
{
T x = e.Current;
if (TMinMax.Compare(x, value))
{
value = x;
}
}
}
}
return value;
}
}
}
......@@ -3,94 +3,26 @@
using System.Collections.Generic;
using System.Numerics;
using System.Runtime.Intrinsics;
namespace System.Linq
{
public static partial class Enumerable
{
public static int Min(this IEnumerable<int> source) => MinInteger(source);
public static int Min(this IEnumerable<int> source) => MinMaxInteger<int, MinCalc<int>>(source);
public static long Min(this IEnumerable<long> source) => MinInteger(source);
public static long Min(this IEnumerable<long> source) => MinMaxInteger<long, MinCalc<long>>(source);
public static int? Min(this IEnumerable<int?> source) => MinInteger(source);
public static long? Min(this IEnumerable<long?> source) => MinInteger(source);
private static T MinInteger<T>(this IEnumerable<T> source) where T : struct, IBinaryInteger<T>
private struct MinCalc<T> : IMinMaxCalc<T> where T : struct, IBinaryInteger<T>
{
T value;
if (source.TryGetSpan(out ReadOnlySpan<T> span))
{
if (span.IsEmpty)
{
ThrowHelper.ThrowNoElementsException();
}
// Vectorize the search if possible.
int index;
if (Vector.IsHardwareAccelerated && span.Length >= Vector<T>.Count * 2)
{
// The span is at least two vectors long. Create a vector from the first N elements,
// and then repeatedly compare that against the next vector from the span. At the end,
// the resulting vector will contain the minimum values found, and we then need only
// to find the min of those.
var mins = new Vector<T>(span);
index = Vector<T>.Count;
do
{
mins = Vector.Min(mins, new Vector<T>(span.Slice(index)));
index += Vector<T>.Count;
}
while (index + Vector<T>.Count <= span.Length);
value = mins[0];
for (int i = 1; i < Vector<T>.Count; i++)
{
if (mins[i] < value)
{
value = mins[i];
}
}
}
else
{
value = span[0];
index = 1;
}
// Iterate through the remaining elements, comparing against the min.
for (int i = index; (uint)i < (uint)span.Length; i++)
{
if (span[i] < value)
{
value = span[i];
}
}
return value;
}
using (IEnumerator<T> e = source.GetEnumerator())
{
if (!e.MoveNext())
{
ThrowHelper.ThrowNoElementsException();
}
public static bool Compare(T left, T right) => left < right;
public static Vector128<T> Compare(Vector128<T> left, Vector128<T> right) => Vector128.Min(left, right);
public static Vector256<T> Compare(Vector256<T> left, Vector256<T> right) => Vector256.Min(left, right);
}
value = e.Current;
while (e.MoveNext())
{
T x = e.Current;
if (x < value)
{
value = x;
}
}
}
public static int? Min(this IEnumerable<int?> source) => MinInteger(source);
return value;
}
public static long? Min(this IEnumerable<long?> source) => MinInteger(source);
private static T? MinInteger<T>(this IEnumerable<T?> source) where T : struct, IBinaryInteger<T>
{
......@@ -351,9 +283,9 @@ public static decimal Min(this IEnumerable<decimal> source)
/// <exception cref="ArgumentNullException"><paramref name="source" /> is <see langword="null" />.</exception>
/// <exception cref="ArgumentException">No object in <paramref name="source" /> implements the <see cref="System.IComparable" /> or <see cref="System.IComparable{T}" /> interface.</exception>
/// <remarks>
/// <para>If type <typeparamref name="TSource" /> implements <see cref="System.IComparable{T}" />, the <see cref="Max{T}(IEnumerable{T})" /> method uses that implementation to compare values. Otherwise, if type <typeparamref name="TSource" /> implements <see cref="System.IComparable" />, that implementation is used to compare values.</para>
/// <para>If type <typeparamref name="TSource" /> implements <see cref="System.IComparable{T}" />, the <see cref="Min{T}(IEnumerable{T})" /> method uses that implementation to compare values. Otherwise, if type <typeparamref name="TSource" /> implements <see cref="System.IComparable" />, that implementation is used to compare values.</para>
/// <para>If <typeparamref name="TSource" /> is a reference type and the source sequence is empty or contains only values that are <see langword="null" />, this method returns <see langword="null" />.</para>
/// <para>In Visual Basic query expression syntax, an `Aggregate Into Max()` clause translates to an invocation of <see cref="O:Enumerable.Max" />.</para>
/// <para>In Visual Basic query expression syntax, an `Aggregate Into Min()` clause translates to an invocation of <see cref="O:Enumerable.Min" />.</para>
/// </remarks>
public static TSource? Min<TSource>(this IEnumerable<TSource> source, IComparer<TSource>? comparer)
{
......@@ -364,6 +296,17 @@ public static decimal Min(this IEnumerable<decimal> source)
comparer ??= Comparer<TSource>.Default;
if (typeof(TSource) == typeof(byte) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<byte, MinCalc<byte>>((IEnumerable<byte>)source);
if (typeof(TSource) == typeof(sbyte) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<sbyte, MinCalc<sbyte>>((IEnumerable<sbyte>)source);
if (typeof(TSource) == typeof(ushort) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<ushort, MinCalc<ushort>>((IEnumerable<ushort>)source);
if (typeof(TSource) == typeof(short) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<short, MinCalc<short>>((IEnumerable<short>)source);
if (typeof(TSource) == typeof(uint) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<uint, MinCalc<uint>>((IEnumerable<uint>)source);
if (typeof(TSource) == typeof(int) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<int, MinCalc<int>>((IEnumerable<int>)source);
if (typeof(TSource) == typeof(ulong) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<ulong, MinCalc<ulong>>((IEnumerable<ulong>)source);
if (typeof(TSource) == typeof(long) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<long, MinCalc<long>>((IEnumerable<long>)source);
if (typeof(TSource) == typeof(nuint) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<nuint, MinCalc<nuint>>((IEnumerable<nuint>)source);
if (typeof(TSource) == typeof(nint) && comparer == Comparer<TSource>.Default) return (TSource)(object)MinMaxInteger<nint, MinCalc<nint>>((IEnumerable<nint>)source);
TSource? value = default;
using (IEnumerator<TSource> e = source.GetEnumerator())
{
......
......@@ -2,12 +2,74 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
using System.Numerics;
using Xunit;
namespace System.Linq.Tests
{
public class MaxTests : EnumerableTests
{
public static IEnumerable<object[]> Max_AllTypes_TestData()
{
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i)), (byte)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i).ToArray()), (byte)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i)), (ushort)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i).ToArray()), (ushort)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (short)i)), (short)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (short)i).ToArray()), (short)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (uint)i)), (uint)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (uint)i).ToArray()), (uint)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (int)i)), (int)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (int)i).ToArray()), (int)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ulong)i)), (ulong)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ulong)i).ToArray()), (ulong)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (long)i)), (long)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (long)i).ToArray()), (long)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (float)i)), (float)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (float)i).ToArray()), (float)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (double)i)), (double)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (double)i).ToArray()), (double)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (decimal)i)), (decimal)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (decimal)i).ToArray()), (decimal)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (nuint)i)), (nuint)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (nuint)i).ToArray()), (nuint)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (nint)i)), (nint)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (nint)i).ToArray()), (nint)(length + length - 1) };
}
}
[Theory]
[MemberData(nameof(Max_AllTypes_TestData))]
public void Max_AllTypes<T>(IEnumerable<T> source, T expected) where T : INumber<T>
{
Assert.Equal(expected, source.Max());
Assert.Equal(expected, source.Max(comparer: null));
Assert.Equal(expected, source.Max(Comparer<T>.Default));
Assert.Equal(expected, source.Max(Comparer<T>.Create(Comparer<T>.Default.Compare)));
T first = source.First();
Assert.Equal(first, source.Max(Comparer<T>.Create((x, y) => x == first ? 1 : -1)));
Assert.Equal(expected + T.One, source.Max(x => x + T.One));
}
[Fact]
public void SameResultsRepeatCallsIntQuery()
{
......@@ -64,12 +126,6 @@ public static IEnumerable<object[]> Max_Int_TestData()
yield return new object[] { new TestEnumerable<int>(array), expected };
yield return new object[] { array, expected };
}
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length)), length + length - 1 };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).ToArray()), length + length - 1 };
}
}
[Theory]
......@@ -100,12 +156,6 @@ public static IEnumerable<object[]> Max_Long_TestData()
yield return new object[] { new TestEnumerable<long>(array), expected };
yield return new object[] { array, expected };
}
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (long)i)), (long)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (long)i).ToArray()), (long)(length + length - 1) };
}
}
[Theory]
......@@ -167,12 +217,6 @@ public static IEnumerable<object[]> Max_Float_TestData()
yield return new object[] { new TestEnumerable<float>(array), expected };
yield return new object[] { array, expected };
}
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (float)i)), (float)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (float)i).ToArray()), (float)(length + length - 1) };
}
}
[Theory]
......@@ -195,6 +239,8 @@ public void Max_Float_EmptySource_ThrowsInvalidOperationException()
{
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<float>().Max());
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<float>().Max(x => x));
Assert.Throws<InvalidOperationException>(() => Array.Empty<float>().Max());
Assert.Throws<InvalidOperationException>(() => new List<float>().Max());
}
[Fact]
......@@ -251,12 +297,6 @@ public static IEnumerable<object[]> Max_Double_TestData()
yield return new object[] { new TestEnumerable<double>(array), expected };
yield return new object[] { array, expected };
}
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (double)i)), (double)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (double)i).ToArray()), (double)(length + length - 1) };
}
}
[Theory]
......@@ -279,6 +319,8 @@ public void Max_Double_EmptySource_ThrowsInvalidOperationException()
{
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<double>().Max());
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<double>().Max(x => x));
Assert.Throws<InvalidOperationException>(() => Array.Empty<double>().Max());
Assert.Throws<InvalidOperationException>(() => new List<double>().Max());
}
[Fact]
......@@ -321,12 +363,6 @@ public static IEnumerable<object[]> Max_Decimal_TestData()
yield return new object[] { new TestEnumerable<decimal>(array), expected };
yield return new object[] { array, expected };
}
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (decimal)i)), (decimal)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (decimal)i).ToArray()), (decimal)(length + length - 1) };
}
}
[Theory]
......
......@@ -2,12 +2,74 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
using System.Numerics;
using Xunit;
namespace System.Linq.Tests
{
public class MinTests : EnumerableTests
{
public static IEnumerable<object[]> Min_AllTypes_TestData()
{
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i)), (byte)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i).ToArray()), (byte)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i)), (ushort)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i).ToArray()), (ushort)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (short)i)), (short)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (short)i).ToArray()), (short)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (uint)i)), (uint)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (uint)i).ToArray()), (uint)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (int)i)), (int)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (int)i).ToArray()), (int)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ulong)i)), (ulong)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ulong)i).ToArray()), (ulong)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (long)i)), (long)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (long)i).ToArray()), (long)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (float)i)), (float)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (float)i).ToArray()), (float)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (double)i)), (double)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (double)i).ToArray()), (double)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (decimal)i)), (decimal)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (decimal)i).ToArray()), (decimal)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (nuint)i)), (nuint)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (nuint)i).ToArray()), (nuint)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (nint)i)), (nint)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (nint)i).ToArray()), (nint)length };
}
}
[Theory]
[MemberData(nameof(Min_AllTypes_TestData))]
public void Min_AllTypes<T>(IEnumerable<T> source, T expected) where T : INumber<T>
{
Assert.Equal(expected, source.Min());
Assert.Equal(expected, source.Min(comparer: null));
Assert.Equal(expected, source.Min(Comparer<T>.Default));
Assert.Equal(expected, source.Min(Comparer<T>.Create(Comparer<T>.Default.Compare)));
T first = source.First();
Assert.Equal(first, source.Min(Comparer<T>.Create((x, y) => x == first ? -1 : 1)));
Assert.Equal(expected + T.One, source.Min(x => x + T.One));
}
[Fact]
public void SameResultsRepeatCallsIntQuery()
{
......@@ -49,12 +111,6 @@ public static IEnumerable<object[]> Min_Int_TestData()
yield return new object[] { new TestEnumerable<int>(array), expected };
yield return new object[] { array, expected };
}
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length)), length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).ToArray()), length };
}
}
[Theory]
......@@ -101,12 +157,6 @@ public static IEnumerable<object[]> Min_Long_TestData()
yield return new object[] { new TestEnumerable<long>(array), expected };
yield return new object[] { array, expected };
}
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (long)i)), (long)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (long)i).ToArray()), (long)length };
}
}
[Theory]
......@@ -175,12 +225,6 @@ public static IEnumerable<object[]> Min_Float_TestData()
// a long time.
yield return new object[] { Enumerable.Repeat(float.NaN, int.MaxValue), float.NaN };
yield return new object[] { Enumerable.Repeat(float.NaN, 3).ToArray(), float.NaN };
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (float)i)), (float)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (float)i).ToArray()), (float)length };
}
}
[Theory]
......@@ -247,12 +291,6 @@ public static IEnumerable<object[]> Min_Double_TestData()
// Without this optimization, we would iterate through int.MaxValue elements, which takes
// a long time.
yield return new object[] { Enumerable.Repeat(double.NaN, int.MaxValue), double.NaN };
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (double)i)), (double)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (double)i).ToArray()), (double)length };
}
}
[Theory]
......@@ -299,12 +337,6 @@ public static IEnumerable<object[]> Min_Decimal_TestData()
yield return new object[] { new TestEnumerable<decimal>(array), expected };
yield return new object[] { array, expected };
}
for (int length = 2; length < 33; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (decimal)i)), (decimal)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (decimal)i).ToArray()), (decimal)length };
}
}
[Theory]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册