未验证 提交 b42626f3 编写于 作者: S Stephen Toub 提交者: GitHub

Use CollectionsMarshal.SetCount in LINQ to deduplicate ToArray/ToList implementations (#85288)

* Use CollectionsMarshal.SetCount in LINQ to deduplicate ToArray/ToList implementations

* Fix BinaryFormatter test

The test was implemented using ToList, and `List<T>` serialization serializes out its version field.  As a result, because we're now optimizing creation and not incrementing the version as much in ToList, the blob didn't match.
上级 8e01ad8c
......@@ -11,6 +11,17 @@ public static partial class Enumerable
{
public static IEnumerable<TSource> AsEnumerable<TSource>(this IEnumerable<TSource> source) => source;
/// <summary>
/// Sets the <paramref name="list"/>'s <see cref="List{T}.Count"/> to be <paramref name="count"/>
/// and returns the relevant portion of the list's backing array as a span.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static Span<T> SetCountAndGetSpan<T>(List<T> list, int count)
{
CollectionsMarshal.SetCount(list, count);
return CollectionsMarshal.AsSpan(list);
}
/// <summary>Validates that source is not null and then tries to extract a span from the source.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)] // fast type checks that don't add a lot of overhead
private static bool TryGetSpan<TSource>(this IEnumerable<TSource> source, out ReadOnlySpan<TSource> span)
......
......@@ -20,12 +20,7 @@ public virtual TElement[] ToArray()
}
TElement[] array = new TElement[count];
int[] map = SortedMap(buffer);
for (int i = 0; i < array.Length; i++)
{
array[i] = buffer._items[map[i]];
}
Fill(buffer, array);
return array;
}
......@@ -36,16 +31,21 @@ public virtual List<TElement> ToList()
List<TElement> list = new List<TElement>(count);
if (count > 0)
{
int[] map = SortedMap(buffer);
for (int i = 0; i != count; i++)
{
list.Add(buffer._items[map[i]]);
}
Fill(buffer, Enumerable.SetCountAndGetSpan(list, count));
}
return list;
}
private void Fill(Buffer<TElement> buffer, Span<TElement> destination)
{
int[] map = SortedMap(buffer);
for (int i = 0; i < destination.Length; i++)
{
destination[i] = buffer._items[map[i]];
}
}
public int GetCount(bool onlyIfCheap)
{
if (_source is IIListProvider<TElement> listProv)
......@@ -75,15 +75,9 @@ internal TElement[] ToArray(int minIdx, int maxIdx)
return new TElement[] { GetEnumerableSorter().ElementAt(buffer._items, count, minIdx) };
}
int[] map = SortedMap(buffer, minIdx, maxIdx);
TElement[] array = new TElement[maxIdx - minIdx + 1];
int idx = 0;
while (minIdx <= maxIdx)
{
array[idx] = buffer._items[map[minIdx]];
++idx;
++minIdx;
}
Fill(minIdx, maxIdx, buffer, array);
return array;
}
......@@ -107,15 +101,21 @@ internal List<TElement> ToList(int minIdx, int maxIdx)
return new List<TElement>(1) { GetEnumerableSorter().ElementAt(buffer._items, count, minIdx) };
}
int[] map = SortedMap(buffer, minIdx, maxIdx);
List<TElement> list = new List<TElement>(maxIdx - minIdx + 1);
Fill(minIdx, maxIdx, buffer, Enumerable.SetCountAndGetSpan(list, maxIdx - minIdx + 1));
return list;
}
private void Fill(int minIdx, int maxIdx, Buffer<TElement> buffer, Span<TElement> destination)
{
int[] map = SortedMap(buffer, minIdx, maxIdx);
int idx = 0;
while (minIdx <= maxIdx)
{
list.Add(buffer._items[map[minIdx]]);
destination[idx] = buffer._items[map[minIdx]];
++idx;
++minIdx;
}
return list;
}
internal int GetCount(int minIdx, int maxIdx, bool onlyIfCheap)
......
......@@ -254,11 +254,7 @@ public TSource[] ToArray()
}
TSource[] array = new TSource[count];
for (int i = 0, curIdx = _minIndexInclusive; i < array.Length; ++i, ++curIdx)
{
array[i] = _source[curIdx];
}
Fill(_source, array, _minIndexInclusive);
return array;
}
......@@ -271,13 +267,16 @@ public List<TSource> ToList()
}
List<TSource> list = new List<TSource>(count);
int end = _minIndexInclusive + count;
for (int i = _minIndexInclusive; i != end; ++i)
Fill(_source, SetCountAndGetSpan(list, count), _minIndexInclusive);
return list;
}
private static void Fill(IList<TSource> source, Span<TSource> destination, int sourceIndex)
{
for (int i = 0; i < destination.Length; i++, sourceIndex++)
{
list.Add(_source[i]);
destination[i] = source[sourceIndex];
}
return list;
}
public int GetCount(bool onlyIfCheap) => Count;
......
......@@ -17,25 +17,23 @@ public override IEnumerable<TResult> Select<TResult>(Func<int, TResult> selector
public int[] ToArray()
{
int[] array = new int[_end - _start];
int cur = _start;
for (int i = 0; i < array.Length; ++i)
{
array[i] = cur;
++cur;
}
Fill(array, _start);
return array;
}
public List<int> ToList()
{
List<int> list = new List<int>(_end - _start);
for (int cur = _start; cur != _end; cur++)
Fill(SetCountAndGetSpan(list, _end - _start), _start);
return list;
}
private static void Fill(Span<int> destination, int value)
{
for (int i = 0; i < destination.Length; i++, value++)
{
list.Add(cur);
destination[i] = value;
}
return list;
}
public int GetCount(bool onlyIfCheap) => unchecked(_end - _start);
......
......@@ -28,10 +28,7 @@ public TResult[] ToArray()
public List<TResult> ToList()
{
List<TResult> list = new List<TResult>(_count);
for (int i = 0; i != _count; ++i)
{
list.Add(_current);
}
SetCountAndGetSpan(list, _count).Fill(_current);
return list;
}
......
......@@ -3,7 +3,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
using static System.Linq.Utilities;
namespace System.Linq
......@@ -75,13 +75,11 @@ public TResult[] ToArray()
{
// See assert in constructor.
// Since _source should never be empty, we don't check for 0/return Array.Empty.
Debug.Assert(_source.Length > 0);
TSource[] source = _source;
Debug.Assert(source.Length > 0);
var results = new TResult[_source.Length];
for (int i = 0; i < results.Length; i++)
{
results[i] = _selector(_source[i]);
}
var results = new TResult[source.Length];
Fill(source, results, _selector);
return results;
}
......@@ -89,15 +87,22 @@ public TResult[] ToArray()
public List<TResult> ToList()
{
TSource[] source = _source;
Debug.Assert(source.Length > 0);
var results = new List<TResult>(source.Length);
for (int i = 0; i < source.Length; i++)
{
results.Add(_selector(source[i]));
}
Fill(source, SetCountAndGetSpan(results, source.Length), _selector);
return results;
}
private static void Fill(ReadOnlySpan<TSource> source, Span<TResult> destination, Func<TSource, TResult> func)
{
for (int i = 0; i < destination.Length; i++)
{
destination[i] = func(source[i]);
}
}
public int GetCount(bool onlyIfCheap)
{
// In case someone uses Count() to force evaluation of
......@@ -202,11 +207,7 @@ public override bool MoveNext()
public TResult[] ToArray()
{
var results = new TResult[_end - _start];
int srcIndex = _start;
for (int i = 0; i < results.Length; i++)
{
results[i] = _selector(srcIndex++);
}
Fill(results, _start, _selector);
return results;
}
......@@ -214,14 +215,19 @@ public TResult[] ToArray()
public List<TResult> ToList()
{
var results = new List<TResult>(_end - _start);
for (int i = _start; i != _end; i++)
{
results.Add(_selector(i));
}
Fill(SetCountAndGetSpan(results, _end - _start), _start, _selector);
return results;
}
private static void Fill(Span<TResult> results, int start, Func<int, TResult> func)
{
for (int i = 0; i < results.Length; i++, start++)
{
results[i] = func(start);
}
}
public int GetCount(bool onlyIfCheap)
{
// In case someone uses Count() to force evaluation of the selector,
......@@ -292,33 +298,36 @@ private sealed partial class SelectListIterator<TSource, TResult> : IPartition<T
{
public TResult[] ToArray()
{
int count = _source.Count;
if (count == 0)
ReadOnlySpan<TSource> source = CollectionsMarshal.AsSpan(_source);
if (source.Length == 0)
{
return Array.Empty<TResult>();
}
var results = new TResult[count];
for (int i = 0; i < results.Length; i++)
{
results[i] = _selector(_source[i]);
}
var results = new TResult[source.Length];
Fill(source, results, _selector);
return results;
}
public List<TResult> ToList()
{
int count = _source.Count;
var results = new List<TResult>(count);
for (int i = 0; i < count; i++)
{
results.Add(_selector(_source[i]));
}
ReadOnlySpan<TSource> source = CollectionsMarshal.AsSpan(_source);
var results = new List<TResult>(source.Length);
Fill(source, SetCountAndGetSpan(results, source.Length), _selector);
return results;
}
private static void Fill(ReadOnlySpan<TSource> source, Span<TResult> destination, Func<TSource, TResult> func)
{
for (int i = 0; i < destination.Length; i++)
{
destination[i] = func(source[i]);
}
}
public int GetCount(bool onlyIfCheap)
{
// In case someone uses Count() to force evaluation of
......@@ -398,26 +407,30 @@ public TResult[] ToArray()
}
var results = new TResult[count];
for (int i = 0; i < results.Length; i++)
{
results[i] = _selector(_source[i]);
}
Fill(_source, results, _selector);
return results;
}
public List<TResult> ToList()
{
IList<TSource> source = _source;
int count = _source.Count;
var results = new List<TResult>(count);
for (int i = 0; i < count; i++)
{
results.Add(_selector(_source[i]));
}
Fill(source, SetCountAndGetSpan(results, count), _selector);
return results;
}
private static void Fill(IList<TSource> source, Span<TResult> results, Func<TSource, TResult> func)
{
for (int i = 0; i < results.Length; i++)
{
results[i] = func(source[i]);
}
}
public int GetCount(bool onlyIfCheap)
{
// In case someone uses Count() to force evaluation of
......@@ -789,10 +802,7 @@ public TResult[] ToArray()
}
TResult[] array = new TResult[count];
for (int i = 0, curIdx = _minIndexInclusive; i < array.Length; ++i, ++curIdx)
{
array[i] = _selector(_source[curIdx]);
}
Fill(_source, array, _selector, _minIndexInclusive);
return array;
}
......@@ -806,15 +816,19 @@ public List<TResult> ToList()
}
List<TResult> list = new List<TResult>(count);
int end = _minIndexInclusive + count;
for (int i = _minIndexInclusive; i != end; ++i)
{
list.Add(_selector(_source[i]));
}
Fill(_source, SetCountAndGetSpan(list, count), _selector, _minIndexInclusive);
return list;
}
private static void Fill(IList<TSource> source, Span<TResult> destination, Func<TSource, TResult> func, int sourceIndex)
{
for (int i = 0; i < destination.Length; i++, sourceIndex++)
{
destination[i] = func(source[sourceIndex]);
}
}
public int GetCount(bool onlyIfCheap)
{
// In case someone uses Count() to force evaluation of
......
......@@ -509,7 +509,7 @@ private class IntAsObject
[Fact]
public void ListSetCount()
{
List<int> list = null!;
List<int> list = null;
Assert.Throws<NullReferenceException>(() => CollectionsMarshal.SetCount(list, 3));
Assert.Throws<ArgumentOutOfRangeException>(() => CollectionsMarshal.SetCount(list, -1));
......@@ -522,45 +522,38 @@ public void ListSetCount()
list = new() { 1, 2, 3, 4, 5 };
ref int intRef = ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list));
// make sure that size decrease preserves content
CollectionsMarshal.SetCount(list, 3);
Assert.Equal(3, list.Count);
Assert.Throws<ArgumentOutOfRangeException>(() => list[3]);
SequenceEquals<int>(CollectionsMarshal.AsSpan(list), new int[] { 1, 2, 3 });
AssertExtensions.SequenceEqual(CollectionsMarshal.AsSpan(list), new int[] { 1, 2, 3 });
Assert.True(Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list))));
// make sure that size increase preserves content and doesn't clear
CollectionsMarshal.SetCount(list, 5);
SequenceEquals<int>(CollectionsMarshal.AsSpan(list), new int[] { 1, 2, 3, 4, 5 });
AssertExtensions.SequenceEqual(CollectionsMarshal.AsSpan(list), new int[] { 1, 2, 3, 4, 5 });
Assert.True(Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list))));
// make sure that reallocations preserve content
int newCount = list.Capacity * 2;
CollectionsMarshal.SetCount(list, newCount);
Assert.Equal(newCount, list.Count);
SequenceEquals<int>(CollectionsMarshal.AsSpan(list)[..3], new int[] { 1, 2, 3 });
AssertExtensions.SequenceEqual(CollectionsMarshal.AsSpan(list)[..3], new int[] { 1, 2, 3 });
Assert.True(!Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list))));
List<string> listReference = new() { "a", "b", "c", "d", "e" };
ref string stringRef = ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference));
CollectionsMarshal.SetCount(listReference, 3);
// verify that reference types aren't cleared
SequenceEquals<string>(CollectionsMarshal.AsSpan(listReference), new string[] { "a", "b", "c" });
AssertExtensions.SequenceEqual(CollectionsMarshal.AsSpan(listReference), new string[] { "a", "b", "c" });
Assert.True(Unsafe.AreSame(ref stringRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference))));
CollectionsMarshal.SetCount(listReference, 5);
// verify that removed reference types are cleared
SequenceEquals<string>(CollectionsMarshal.AsSpan(listReference), new string[] { "a", "b", "c", null, null });
AssertExtensions.SequenceEqual(CollectionsMarshal.AsSpan(listReference), new string[] { "a", "b", "c", null, null });
Assert.True(Unsafe.AreSame(ref stringRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference))));
static void SequenceEquals<T>(ReadOnlySpan<T> actual, ReadOnlySpan<T> expected)
{
Assert.Equal(actual.Length, expected.Length);
for (int i = 0; i < actual.Length; i++)
{
Assert.Equal(actual[i], expected[i]);
}
}
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册