未验证 提交 c63f06cd 编写于 作者: D Dixin 提交者: GitHub

Add LINQ APIs for Index and Range (#28776) (#48559)

* Implement dotnet/runtime#28776: Implement LINQ APIs for index and range.

* Implement dotnet#28776: Implement LINQ APIs for index and range.

* Implement dotnet#28776: LINQ APIs for index and range.

* Implement dotnet#28776: LINQ APIs for index and range.

* Implement dotnet#28776: LINQ APIs for index and range.

* Implement dotnet#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range. Code review update.

* Implement dotnet/runtime#28776: LINQ APIs for index and range. Code review update.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range. Code review update.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range.

* Implement dotnet/runtime#28776: LINQ APIs for index and range. Update ElementAt, keep the original behavior.

* Implement dotnet#28776: LINQ APIs for index and range. Update ElementAt, keep the original behavior.

* Implement dotnet/runtime#28776: LINQ APIs for index and range. Update ElementAt, keep the original behavior.

* Implement dotnet/runtime#28776: LINQ APIs for index and range. Add unit tests for ElementAt, ElementAtOrDefault.

* Implement dotnet/runtime#28776: LINQ APIs for index and range. Update unit tests.

* Implement dotnet/runtime#28776: LINQ APIs for index and range. Update unit tests.

* Implement dotnet/runtime#28776: LINQ APIs for index and range. Update unit tests.

* Implement dotnet/runtime#28776: LINQ APIs for index and range. Update for merge.

* Implement #28776: LINQ APIs for index and range. Update unit tests.
上级 5bc40f58
......@@ -80,7 +80,9 @@ public static partial class Queryable
public static System.Linq.IQueryable<TSource> Distinct<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
public static System.Linq.IQueryable<TSource> Distinct<TSource>(this System.Linq.IQueryable<TSource> source, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
public static TSource? ElementAtOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, int index) { throw null; }
public static TSource? ElementAtOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Index index) { throw null; }
public static TSource ElementAt<TSource>(this System.Linq.IQueryable<TSource> source, int index) { throw null; }
public static TSource ElementAt<TSource>(this System.Linq.IQueryable<TSource> source, System.Index index) { throw null; }
public static System.Linq.IQueryable<TSource> Except<TSource>(this System.Linq.IQueryable<TSource> source1, System.Collections.Generic.IEnumerable<TSource> source2) { throw null; }
public static System.Linq.IQueryable<TSource> Except<TSource>(this System.Linq.IQueryable<TSource> source1, System.Collections.Generic.IEnumerable<TSource> source2, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
public static TSource? FirstOrDefault<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
......@@ -158,6 +160,7 @@ public static partial class Queryable
public static System.Linq.IQueryable<TSource> TakeWhile<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
public static System.Linq.IQueryable<TSource> TakeWhile<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, int, bool>> predicate) { throw null; }
public static System.Linq.IQueryable<TSource> Take<TSource>(this System.Linq.IQueryable<TSource> source, int count) { throw null; }
public static System.Linq.IQueryable<TSource> Take<TSource>(this System.Linq.IQueryable<TSource> source, System.Range range) { throw null; }
public static System.Linq.IOrderedQueryable<TSource> ThenByDescending<TSource, TKey>(this System.Linq.IOrderedQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, TKey>> keySelector) { throw null; }
public static System.Linq.IOrderedQueryable<TSource> ThenByDescending<TSource, TKey>(this System.Linq.IOrderedQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, TKey>> keySelector, System.Collections.Generic.IComparer<TKey>? comparer) { throw null; }
public static System.Linq.IOrderedQueryable<TSource> ThenBy<TSource, TKey>(this System.Linq.IOrderedQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, TKey>> keySelector) { throw null; }
......
......@@ -388,11 +388,32 @@ public static IQueryable<TSource> Take<TSource>(this IQueryable<TSource> source,
return source.Provider.CreateQuery<TSource>(
Expression.Call(
null,
CachedReflectionInfo.Take_TSource_2(typeof(TSource)),
CachedReflectionInfo.Take_Int32_TSource_2(typeof(TSource)),
source.Expression, Expression.Constant(count)
));
}
/// <summary>Returns a specified range of contiguous elements from a sequence.</summary>
/// <param name="source">The sequence to return elements from.</param>
/// <param name="range">The range of elements to return, which has start and end indexes either from the start or the end.</param>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <exception cref="ArgumentNullException">
/// <paramref name="source" /> is <see langword="null" />.
/// </exception>
/// <returns>An <see cref="IQueryable{T}" /> that contains the specified <paramref name="range" /> of elements from the <paramref name="source" /> sequence.</returns>
[DynamicDependency("Take`1", typeof(Enumerable))]
public static IQueryable<TSource> Take<TSource>(this IQueryable<TSource> source, Range range)
{
if (source == null)
throw Error.ArgumentNull(nameof(source));
return source.Provider.CreateQuery<TSource>(
Expression.Call(
null,
CachedReflectionInfo.Take_Range_TSource_2(typeof(TSource)),
source.Expression, Expression.Constant(range)
));
}
[DynamicDependency("TakeWhile`1", typeof(Enumerable))]
public static IQueryable<TSource> TakeWhile<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate)
{
......@@ -972,7 +993,33 @@ public static TSource ElementAt<TSource>(this IQueryable<TSource> source, int in
return source.Provider.Execute<TSource>(
Expression.Call(
null,
CachedReflectionInfo.ElementAt_TSource_2(typeof(TSource)),
CachedReflectionInfo.ElementAt_Int32_TSource_2(typeof(TSource)),
source.Expression, Expression.Constant(index)
));
}
/// <summary>Returns the element at a specified index in a sequence.</summary>
/// <param name="source">An <see cref="IQueryable{T}" /> to return an element from.</param>
/// <param name="index">The index of the element to retrieve, which is either from the start or the end.</param>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <exception cref="ArgumentNullException">
/// <paramref name="source" /> is <see langword="null" />.
/// </exception>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="index" /> is outside the bounds of the <paramref name="source" /> sequence.
/// </exception>
/// <returns>The element at the specified position in the <paramref name="source" /> sequence.</returns>
[DynamicDependency("ElementAt`1", typeof(Enumerable))]
public static TSource ElementAt<TSource>(this IQueryable<TSource> source, Index index)
{
if (source == null)
throw Error.ArgumentNull(nameof(source));
if (index.IsFromEnd && index.Value == 0)
throw Error.ArgumentOutOfRange(nameof(index));
return source.Provider.Execute<TSource>(
Expression.Call(
null,
CachedReflectionInfo.ElementAt_Index_TSource_2(typeof(TSource)),
source.Expression, Expression.Constant(index)
));
}
......@@ -985,7 +1032,30 @@ public static TSource ElementAt<TSource>(this IQueryable<TSource> source, int in
return source.Provider.Execute<TSource>(
Expression.Call(
null,
CachedReflectionInfo.ElementAtOrDefault_TSource_2(typeof(TSource)),
CachedReflectionInfo.ElementAtOrDefault_Int32_TSource_2(typeof(TSource)),
source.Expression, Expression.Constant(index)
));
}
/// <summary>Returns the element at a specified index in a sequence or a default value if the index is out of range.</summary>
/// <param name="source">An <see cref="IQueryable{T}" /> to return an element from.</param>
/// <param name="index">The index of the element to retrieve, which is either from the start or the end.</param>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <exception cref="ArgumentNullException">
/// <paramref name="source" /> is <see langword="null" />.
/// </exception>
/// <returns>
/// <see langword="default" /> if index is outside the bounds of the <paramref name="source" /> sequence; otherwise, the element at the specified position in the <paramref name="source" /> sequence.
/// </returns>
[DynamicDependency("ElementAtOrDefault`1", typeof(Enumerable))]
public static TSource? ElementAtOrDefault<TSource>(this IQueryable<TSource> source, Index index)
{
if (source == null)
throw Error.ArgumentNull(nameof(source));
return source.Provider.Execute<TSource>(
Expression.Call(
null,
CachedReflectionInfo.ElementAtOrDefault_Index_TSource_2(typeof(TSource)),
source.Expression, Expression.Constant(index)
));
}
......
......@@ -8,11 +8,19 @@ namespace System.Linq.Tests
public class ElementAtOrDefaultTests : EnumerableBasedTests
{
[Fact]
public void IndexNegative()
public void IndexInvalid()
{
int?[] source = { 9, 8 };
Assert.Null(source.AsQueryable().ElementAtOrDefault(-1));
Assert.Null(source.AsQueryable().ElementAtOrDefault(int.MinValue));
Assert.Null(source.AsQueryable().ElementAtOrDefault(3));
Assert.Null(source.AsQueryable().ElementAtOrDefault(int.MaxValue));
Assert.Null(source.AsQueryable().ElementAtOrDefault(^3));
Assert.Null(source.AsQueryable().ElementAtOrDefault(^int.MaxValue));
Assert.Null(source.AsQueryable().ElementAtOrDefault(new Index(3)));
Assert.Null(source.AsQueryable().ElementAtOrDefault(new Index(int.MaxValue)));
}
[Fact]
......@@ -20,7 +28,9 @@ public void IndexEqualsCount()
{
int[] source = { 1, 2, 3, 4 };
Assert.Equal(default(int), source.AsQueryable().ElementAtOrDefault(source.Length));
Assert.Equal(default, source.AsQueryable().ElementAtOrDefault(source.Length));
Assert.Equal(default, source.AsQueryable().ElementAtOrDefault(new Index(source.Length)));
Assert.Equal(default, source.AsQueryable().ElementAtOrDefault(^0));
}
[Fact]
......@@ -28,7 +38,9 @@ public void EmptyIndexZero()
{
int[] source = { };
Assert.Equal(default(int), source.AsQueryable().ElementAtOrDefault(0));
Assert.Equal(default, source.AsQueryable().ElementAtOrDefault(0));
Assert.Equal(default, source.AsQueryable().ElementAtOrDefault(new Index(0)));
Assert.Equal(default, source.AsQueryable().ElementAtOrDefault(^0));
}
[Fact]
......@@ -37,6 +49,8 @@ public void SingleElementIndexZero()
int[] source = { -4 };
Assert.Equal(-4, source.ElementAtOrDefault(0));
Assert.Equal(-4, source.ElementAtOrDefault(new Index(0)));
Assert.Equal(-4, source.ElementAtOrDefault(^1));
}
[Fact]
......@@ -45,19 +59,29 @@ public void ManyElementsIndexTargetsLast()
int[] source = { 9, 8, 0, -5, 10 };
Assert.Equal(10, source.AsQueryable().ElementAtOrDefault(source.Length - 1));
Assert.Equal(10, source.AsQueryable().ElementAtOrDefault(new Index(source.Length - 1)));
Assert.Equal(10, source.AsQueryable().ElementAtOrDefault(^1));
}
[Fact]
public void NullSource()
{
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<int>)null).ElementAtOrDefault(2));
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<int>)null).ElementAtOrDefault(new Index(2)));
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<int>)null).ElementAtOrDefault(^2));
}
[Fact]
public void ElementAtOrDefault()
{
var val = (new int[] { 0, 2, 1 }).AsQueryable().ElementAtOrDefault(1);
Assert.Equal(2, val);
var val1 = new[] { 0, 2, 1 }.AsQueryable().ElementAtOrDefault(1);
Assert.Equal(2, val1);
var val2 = new[] { 0, 2, 1 }.AsQueryable().ElementAtOrDefault(new Index(1));
Assert.Equal(2, val2);
var val3 = new[] { 0, 2, 1 }.AsQueryable().ElementAtOrDefault(^2);
Assert.Equal(2, val3);
}
}
}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using Xunit;
namespace System.Linq.Tests
......@@ -14,6 +13,7 @@ public void IndexNegative()
int?[] source = { 9, 8 };
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => source.AsQueryable().ElementAt(-1));
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => source.AsQueryable().ElementAt(^3));
}
[Fact]
......@@ -22,6 +22,8 @@ public void IndexEqualsCount()
int[] source = { 1, 2, 3, 4 };
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => source.AsQueryable().ElementAt(source.Length));
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => source.AsQueryable().ElementAt(new Index(source.Length)));
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => source.AsQueryable().ElementAt(^0));
}
[Fact]
......@@ -30,6 +32,8 @@ public void EmptyIndexZero()
int[] source = { };
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => source.AsQueryable().ElementAt(0));
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => source.AsQueryable().ElementAt(new Index(0)));
AssertExtensions.Throws<ArgumentOutOfRangeException>("index", () => source.AsQueryable().ElementAt(^0));
}
[Fact]
......@@ -38,6 +42,8 @@ public void SingleElementIndexZero()
int[] source = { -4 };
Assert.Equal(-4, source.AsQueryable().ElementAt(0));
Assert.Equal(-4, source.AsQueryable().ElementAt(new Index(0)));
Assert.Equal(-4, source.AsQueryable().ElementAt(^1));
}
[Fact]
......@@ -46,19 +52,29 @@ public void ManyElementsIndexTargetsLast()
int[] source = { 9, 8, 0, -5, 10 };
Assert.Equal(10, source.AsQueryable().ElementAt(source.Length - 1));
Assert.Equal(10, source.AsQueryable().ElementAt(source.Length - 1));
Assert.Equal(10, source.AsQueryable().ElementAt(^1));
}
[Fact]
public void NullSource()
{
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<int>)null).ElementAt(2));
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<int>)null).ElementAt(new Index(2)));
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<int>)null).ElementAt(^2));
}
[Fact]
public void ElementAt()
{
var val = (new int[] { 0, 2, 1 }).AsQueryable().ElementAt(1);
Assert.Equal(2, val);
var val1 = new[] { 0, 2, 1 }.AsQueryable().ElementAt(1);
Assert.Equal(2, val1);
var val2 = new[] { 0, 2, 1 }.AsQueryable().ElementAt(new Index(1));
Assert.Equal(2, val2);
var val3 = new[] { 0, 2, 1 }.AsQueryable().ElementAt(^2);
Assert.Equal(2, val3);
}
}
}
......@@ -14,6 +14,10 @@ public void SourceNonEmptyTakeAllButOne()
int[] expected = { 2, 5, 9 };
Assert.Equal(expected, source.AsQueryable().Take(3));
Assert.Equal(expected, source.AsQueryable().Take(0..3));
Assert.Equal(expected, source.AsQueryable().Take(^4..3));
Assert.Equal(expected, source.AsQueryable().Take(0..^1));
Assert.Equal(expected, source.AsQueryable().Take(^4..^1));
}
[Fact]
......@@ -21,13 +25,41 @@ public void ThrowsOnNullSource()
{
IQueryable<int> source = null;
AssertExtensions.Throws<ArgumentNullException>("source", () => source.Take(5));
AssertExtensions.Throws<ArgumentNullException>("source", () => source.Take(0..5));
AssertExtensions.Throws<ArgumentNullException>("source", () => source.Take(^5..5));
AssertExtensions.Throws<ArgumentNullException>("source", () => source.Take(0..^0));
AssertExtensions.Throws<ArgumentNullException>("source", () => source.Take(^5..^0));
}
[Fact]
public void Take()
{
var count = (new int[] { 0, 1, 2 }).AsQueryable().Take(2).Count();
Assert.Equal(2, count);
var count1 = new[] { 0, 1, 2 }.AsQueryable().Take(2).Count();
Assert.Equal(2, count1);
var count2 = new[] { 0, 1, 2 }.AsQueryable().Take(0..2).Count();
Assert.Equal(2, count2);
var count3 = new[] { 0, 1, 2 }.AsQueryable().Take(^3..2).Count();
Assert.Equal(2, count3);
var count4 = new[] { 0, 1, 2 }.AsQueryable().Take(0..^1).Count();
Assert.Equal(2, count4);
var count5 = new[] { 0, 1, 2 }.AsQueryable().Take(^3..^1).Count();
Assert.Equal(2, count5);
var count6 = new[] { 0, 1, 2 }.AsQueryable().Take(1..3).Count();
Assert.Equal(2, count6);
var count7 = new[] { 0, 1, 2 }.AsQueryable().Take(^2..3).Count();
Assert.Equal(2, count7);
var count8 = new[] { 0, 1, 2 }.AsQueryable().Take(1..^0).Count();
Assert.Equal(2, count8);
var count9 = new[] { 0, 1, 2 }.AsQueryable().Take(^2..^0).Count();
Assert.Equal(2, count9);
}
}
}
......@@ -61,7 +61,7 @@ public static void CachedReflectionInfoMethodsNoAnnotations()
.Where(m => m.GetParameters().Length > 0);
// If you are adding a new method to this class, ensure the method meets these requirements
Assert.Equal(108, methods.Count());
Assert.Equal(111, methods.Count());
foreach (MethodInfo method in methods)
{
ParameterInfo[] parameters = method.GetParameters();
......
......@@ -52,7 +52,9 @@ public static partial class Enumerable
public static System.Collections.Generic.IEnumerable<TSource> Distinct<TSource>(this System.Collections.Generic.IEnumerable<TSource> source) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Distinct<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
public static TSource? ElementAtOrDefault<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, int index) { throw null; }
public static TSource? ElementAtOrDefault<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Index index) { throw null; }
public static TSource ElementAt<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, int index) { throw null; }
public static TSource ElementAt<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Index index) { throw null; }
public static System.Collections.Generic.IEnumerable<TResult> Empty<TResult>() { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Except<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Except<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
......@@ -173,6 +175,7 @@ public static partial class Enumerable
public static System.Collections.Generic.IEnumerable<TSource> TakeWhile<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, bool> predicate) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> TakeWhile<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, int, bool> predicate) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Take<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, int count) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Take<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Range range) { throw null; }
public static System.Linq.IOrderedEnumerable<TSource> ThenByDescending<TSource, TKey>(this System.Linq.IOrderedEnumerable<TSource> source, System.Func<TSource, TKey> keySelector) { throw null; }
public static System.Linq.IOrderedEnumerable<TSource> ThenByDescending<TSource, TKey>(this System.Linq.IOrderedEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Collections.Generic.IComparer<TKey>? comparer) { throw null; }
public static System.Linq.IOrderedEnumerable<TSource> ThenBy<TSource, TKey>(this System.Linq.IOrderedEnumerable<TSource> source, System.Func<TSource, TKey> keySelector) { throw null; }
......
......@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
namespace System.Linq
......@@ -23,34 +24,54 @@ public static TSource ElementAt<TSource>(this IEnumerable<TSource> source, int i
return element!;
}
}
else
else if (source is IList<TSource> list)
{
if (source is IList<TSource> list)
{
return list[index];
}
if (index >= 0)
{
using (IEnumerator<TSource> e = source.GetEnumerator())
{
while (e.MoveNext())
{
if (index == 0)
{
return e.Current;
}
index--;
}
}
}
return list[index];
}
else if (TryGetElement(source, index, out TSource? element))
{
return element;
}
ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.index);
return default;
}
/// <summary>Returns the element at a specified index in a sequence.</summary>
/// <param name="source">An <see cref="IEnumerable{T}" /> to return an element from.</param>
/// <param name="index">The index of the element to retrieve, which is either from the start or the end.</param>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <exception cref="ArgumentNullException">
/// <paramref name="source" /> is <see langword="null" />.</exception>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="index" /> is outside the bounds of the <paramref name="source" /> sequence.
/// </exception>
/// <returns>The element at the specified position in the <paramref name="source" /> sequence.</returns>
public static TSource ElementAt<TSource>(this IEnumerable<TSource> source, Index index)
{
if (source == null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
}
if (!index.IsFromEnd)
{
return source.ElementAt(index.Value);
}
if (source.TryGetNonEnumeratedCount(out int count))
{
return source.ElementAt(count - index.Value);
}
if (!TryGetElementFromEnd(source, index.Value, out TSource? element))
{
ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.index);
}
return element;
}
public static TSource? ElementAtOrDefault<TSource>(this IEnumerable<TSource> source, int index)
{
if (source == null)
......@@ -63,33 +84,101 @@ public static TSource ElementAt<TSource>(this IEnumerable<TSource> source, int i
return partition.TryGetElementAt(index, out bool _);
}
if (source is IList<TSource> list)
{
return index >= 0 && index < list.Count ? list[index] : default;
}
TryGetElement(source, index, out TSource? element);
return element;
}
/// <summary>Returns the element at a specified index in a sequence or a default value if the index is out of range.</summary>
/// <param name="source">An <see cref="IEnumerable{T}" /> to return an element from.</param>
/// <param name="index">The index of the element to retrieve, which is either from the start or the end.</param>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <exception cref="ArgumentNullException">
/// <paramref name="source" /> is <see langword="null" />.
/// </exception>
/// <returns>
/// <see langword="default" /> if <paramref name="index" /> is outside the bounds of the <paramref name="source" /> sequence; otherwise, the element at the specified position in the <paramref name="source" /> sequence.
/// </returns>
public static TSource? ElementAtOrDefault<TSource>(this IEnumerable<TSource> source, Index index)
{
if (source == null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
}
if (!index.IsFromEnd)
{
return source.ElementAtOrDefault(index.Value);
}
if (source.TryGetNonEnumeratedCount(out int count))
{
return source.ElementAtOrDefault(count - index.Value);
}
TryGetElementFromEnd(source, index.Value, out TSource? element);
return element;
}
private static bool TryGetElement<TSource>(IEnumerable<TSource> source, int index, [MaybeNullWhen(false)] out TSource element)
{
Debug.Assert(source != null);
if (index >= 0)
{
if (source is IList<TSource> list)
using IEnumerator<TSource> e = source.GetEnumerator();
while (e.MoveNext())
{
if (index < list.Count)
if (index == 0)
{
return list[index];
element = e.Current;
return true;
}
index--;
}
else
}
element = default;
return false;
}
private static bool TryGetElementFromEnd<TSource>(IEnumerable<TSource> source, int indexFromEnd, [MaybeNullWhen(false)] out TSource element)
{
Debug.Assert(source != null);
if (indexFromEnd > 0)
{
using IEnumerator<TSource> e = source.GetEnumerator();
if (e.MoveNext())
{
using (IEnumerator<TSource> e = source.GetEnumerator())
Queue<TSource> queue = new();
queue.Enqueue(e.Current);
while (e.MoveNext())
{
while (e.MoveNext())
if (queue.Count == indexFromEnd)
{
if (index == 0)
{
return e.Current;
}
index--;
queue.Dequeue();
}
queue.Enqueue(e.Current);
}
if (queue.Count == indexFromEnd)
{
element = queue.Dequeue();
return true;
}
}
}
return default;
element = default;
return false;
}
}
}
......@@ -20,6 +20,158 @@ public static IEnumerable<TSource> Take<TSource>(this IEnumerable<TSource> sourc
TakeIterator<TSource>(source, count);
}
/// <summary>Returns a specified range of contiguous elements from a sequence.</summary>
/// <param name="source">The sequence to return elements from.</param>
/// <param name="range">The range of elements to return, which has start and end indexes either from the start or the end.</param>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <exception cref="ArgumentNullException">
/// <paramref name="source" /> is <see langword="null" />.
/// </exception>
/// <returns>An <see cref="IEnumerable{T}" /> that contains the specified <paramref name="range" /> of elements from the <paramref name="source" /> sequence.</returns>
public static IEnumerable<TSource> Take<TSource>(this IEnumerable<TSource> source, Range range)
{
if (source == null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
}
Index start = range.Start;
Index end = range.End;
bool isStartIndexFromEnd = start.IsFromEnd;
bool isEndIndexFromEnd = end.IsFromEnd;
int startIndex = start.Value;
int endIndex = end.Value;
Debug.Assert(startIndex >= 0);
Debug.Assert(endIndex >= 0);
if (isStartIndexFromEnd)
{
if (startIndex == 0 || (isEndIndexFromEnd && endIndex >= startIndex))
{
return Empty<TSource>();
}
}
else if (!isEndIndexFromEnd)
{
return startIndex >= endIndex
? Empty<TSource>()
: source.Skip(startIndex).Take(endIndex - startIndex);
}
return TakeIterator(source, isStartIndexFromEnd, startIndex, isEndIndexFromEnd, endIndex);
}
private static IEnumerable<TSource> TakeIterator<TSource>(
IEnumerable<TSource> source, bool isStartIndexFromEnd, int startIndex, bool isEndIndexFromEnd, int endIndex)
{
Debug.Assert(source != null);
Debug.Assert(isStartIndexFromEnd
? startIndex > 0 && (!isEndIndexFromEnd || startIndex > endIndex)
: startIndex >= 0 && (isEndIndexFromEnd || startIndex < endIndex));
Debug.Assert(endIndex >= 0);
using IEnumerator<TSource> e = source.GetEnumerator();
if (isStartIndexFromEnd)
{
if (!e.MoveNext())
{
yield break;
}
int index = 0;
Queue<TSource> queue = new();
queue.Enqueue(e.Current);
while (e.MoveNext())
{
checked
{
index++;
}
if (queue.Count == startIndex)
{
queue.Dequeue();
}
queue.Enqueue(e.Current);
}
int count = checked(index + 1);
Debug.Assert(queue.Count == Math.Min(count, startIndex));
startIndex = count - startIndex;
if (startIndex < 0)
{
startIndex = 0;
}
if (isEndIndexFromEnd)
{
endIndex = count - endIndex;
}
else if (endIndex > count)
{
endIndex = count;
}
Debug.Assert(endIndex - startIndex <= queue.Count);
for (int rangeIndex = startIndex; rangeIndex < endIndex; rangeIndex++)
{
yield return queue.Dequeue();
}
}
else
{
int index = 0;
while (index <= startIndex)
{
if (!e.MoveNext())
{
yield break;
}
checked
{
index++;
}
}
if (isEndIndexFromEnd)
{
if (endIndex > 0)
{
Queue<TSource> queue = new();
do
{
if (queue.Count == endIndex)
{
yield return queue.Dequeue();
}
queue.Enqueue(e.Current);
} while (e.MoveNext());
}
else
{
do
{
yield return e.Current;
} while (e.MoveNext());
}
}
else
{
Debug.Assert(index < endIndex);
yield return e.Current;
while (checked(++index) < endIndex && e.MoveNext())
{
yield return e.Current;
}
}
}
}
public static IEnumerable<TSource> TakeWhile<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
{
if (source == null)
......
......@@ -3,6 +3,7 @@
using System.Collections;
using System.Collections.Generic;
using Xunit;
namespace System.Linq.Tests
{
......@@ -241,6 +242,28 @@ protected static IEnumerable<T> FlipIsCollection<T>(IEnumerable<T> source)
{
return source is ICollection<T> ? ForceNotCollection(source) : new List<T>(source);
}
protected static T[] Repeat<T>(Func<int, T> factory, int count)
{
T[] results = new T[count];
for (int index = 0; index < results.Length; index++)
{
results[index] = factory(index);
}
return results;
}
protected static IEnumerable<T> ListPartitionOrEmpty<T>(IList<T> source) // Or Empty
{
var listPartition = source.Skip(0);
return listPartition;
}
protected static IEnumerable<T> EnumerablePartitionOrEmpty<T>(IEnumerable<T> source) // Or Empty
{
var enumerablePartition = ForceNotCollection(source).Skip(0);
return enumerablePartition;
}
protected struct StringWithIntArray
{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册