提交 568ef788 编写于 作者: E Evgenii Pravda

Add MSD radix sort

上级 5ebb192f
...@@ -136,7 +136,7 @@ class QuantileTDigest ...@@ -136,7 +136,7 @@ class QuantileTDigest
{ {
if (unmerged > 0) if (unmerged > 0)
{ {
RadixSort<RadixSortTraits>::execute(summary.data(), summary.size()); RadixSort<RadixSortTraits>::executeLsd(summary.data(), summary.size());
if (summary.size() > 3) if (summary.size() > 3)
{ {
......
...@@ -120,7 +120,7 @@ void ColumnVector<T>::getPermutation(bool reverse, size_t limit, int nan_directi ...@@ -120,7 +120,7 @@ void ColumnVector<T>::getPermutation(bool reverse, size_t limit, int nan_directi
for (UInt32 i = 0; i < s; ++i) for (UInt32 i = 0; i < s; ++i)
pairs[i] = {data[i], i}; pairs[i] = {data[i], i};
RadixSort<RadixSortTraits<T>>::execute(pairs.data(), s); RadixSort<RadixSortTraits<T>>::executeLsd(pairs.data(), s);
/// Radix sort treats all NaNs to be greater than all numbers. /// Radix sort treats all NaNs to be greater than all numbers.
/// If the user needs the opposite, we must move them accordingly. /// If the user needs the opposite, we must move them accordingly.
......
...@@ -43,61 +43,36 @@ struct RadixSortMallocAllocator ...@@ -43,61 +43,36 @@ struct RadixSortMallocAllocator
}; };
/** A transformation that transforms the bit representation of a key into an unsigned integer number,
* that the order relation over the keys will match the order relation over the obtained unsigned numbers.
* For floats this conversion does the following:
* if the signed bit is set, it flips all other bits.
* In this case, NaN-s are bigger than all normal numbers.
*/
template <typename KeyBits> template <typename KeyBits>
struct RadixSortFloatTransform struct RadixSortIdentityTransform
{ {
/// Is it worth writing the result in memory, or is it better to do calculation every time again? static constexpr bool transform_is_simple = true;
static constexpr bool transform_is_simple = false;
static KeyBits forward(KeyBits x)
{
return x ^ ((-(x >> (sizeof(KeyBits) * 8 - 1))) | (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)));
}
static KeyBits backward(KeyBits x) static KeyBits forward(KeyBits x) { return x; }
{ static KeyBits backward(KeyBits x) { return x; }
return x ^ (((x >> (sizeof(KeyBits) * 8 - 1)) - 1) | (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)));
}
}; };
template <typename TElement> template <typename TElement>
struct RadixSortFloatTraits struct RadixSortUIntTraits
{ {
using Element = TElement; /// The type of the element. It can be a structure with a key and some other payload. Or just a key. using Element = TElement;
using Key = Element; /// The key to sort by. using Key = Element;
using CountType = uint32_t; /// Type for calculating histograms. In the case of a known small number of elements, it can be less than size_t. using CountType = uint32_t;
using KeyBits = Key;
/// The type to which the key is transformed to do bit operations. This UInt is the same size as the key.
using KeyBits = std::conditional_t<sizeof(Key) == 8, uint64_t, uint32_t>;
static constexpr size_t PART_SIZE_BITS = 8; /// With what pieces of the key, in bits, to do one pass - reshuffle of the array.
/// Converting a key into KeyBits is such that the order relation over the key corresponds to the order relation over KeyBits. static constexpr size_t PART_SIZE_BITS = 8;
using Transform = RadixSortFloatTransform<KeyBits>;
/// An object with the functions allocate and deallocate. using Transform = RadixSortIdentityTransform<KeyBits>;
/// Can be used, for example, to allocate memory for a temporary array on the stack.
/// To do this, the allocator itself is created on the stack.
using Allocator = RadixSortMallocAllocator; using Allocator = RadixSortMallocAllocator;
/// The function to get the key from an array element.
static Key & extractKey(Element & elem) { return elem; } static Key & extractKey(Element & elem) { return elem; }
};
template <typename KeyBits>
struct RadixSortIdentityTransform
{
static constexpr bool transform_is_simple = true;
static KeyBits forward(KeyBits x) { return x; } static bool compare(TElement x, TElement y)
static KeyBits backward(KeyBits x) { return x; } {
return x < y;
}
}; };
...@@ -110,38 +85,80 @@ struct RadixSortSignedTransform ...@@ -110,38 +85,80 @@ struct RadixSortSignedTransform
static KeyBits backward(KeyBits x) { return x ^ (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)); } static KeyBits backward(KeyBits x) { return x ^ (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)); }
}; };
template <typename TElement> template <typename TElement>
struct RadixSortUIntTraits struct RadixSortIntTraits
{ {
using Element = TElement; using Element = TElement;
using Key = Element; using Key = Element;
using CountType = uint32_t; using CountType = uint32_t;
using KeyBits = Key; using KeyBits = std::make_unsigned_t<Key>;
static constexpr size_t PART_SIZE_BITS = 8; static constexpr size_t PART_SIZE_BITS = 8;
using Transform = RadixSortIdentityTransform<KeyBits>; using Transform = RadixSortSignedTransform<KeyBits>;
using Allocator = RadixSortMallocAllocator; using Allocator = RadixSortMallocAllocator;
static Key & extractKey(Element & elem) { return elem; } static Key & extractKey(Element & elem) { return elem; }
static bool compare(TElement x, TElement y)
{
return x < y;
}
};
/** A transformation that maps the bit representation of a key to an unsigned integer number
 * such that the order relation over the keys matches the order relation over the obtained unsigned numbers.
 * For floats this conversion does the following:
 * if the sign bit is set, it flips all bits; otherwise it flips only the sign bit.
 * In this case, NaN-s are bigger than all normal numbers.
 */
template <typename KeyBits>
struct RadixSortFloatTransform
{
/// Is it worth writing the result in memory, or is it better to do calculation every time again?
static constexpr bool transform_is_simple = false;
static KeyBits forward(KeyBits x)
{
return x ^ ((-(x >> (sizeof(KeyBits) * 8 - 1))) | (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)));
}
static KeyBits backward(KeyBits x)
{
return x ^ (((x >> (sizeof(KeyBits) * 8 - 1)) - 1) | (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)));
}
}; };
template <typename TElement> template <typename TElement>
struct RadixSortIntTraits struct RadixSortFloatTraits
{ {
using Element = TElement; using Element = TElement; /// The type of the element. It can be a structure with a key and some other payload. Or just a key.
using Key = Element; using Key = Element; /// The key to sort by.
using CountType = uint32_t; using CountType = uint32_t; /// Type for calculating histograms. In the case of a known small number of elements, it can be less than size_t.
using KeyBits = std::make_unsigned_t<Key>;
static constexpr size_t PART_SIZE_BITS = 8; /// The type to which the key is transformed to do bit operations. This UInt is the same size as the key.
using KeyBits = std::conditional_t<sizeof(Key) == 8, uint64_t, uint32_t>;
using Transform = RadixSortSignedTransform<KeyBits>; static constexpr size_t PART_SIZE_BITS = 8; /// With what pieces of the key, in bits, to do one pass - reshuffle of the array.
/// Converting a key into KeyBits is such that the order relation over the key corresponds to the order relation over KeyBits.
using Transform = RadixSortFloatTransform<KeyBits>;
/// An object with the functions allocate and deallocate.
/// Can be used, for example, to allocate memory for a temporary array on the stack.
/// To do this, the allocator itself is created on the stack.
using Allocator = RadixSortMallocAllocator; using Allocator = RadixSortMallocAllocator;
/// The function to get the key from an array element.
static Key & extractKey(Element & elem) { return elem; } static Key & extractKey(Element & elem) { return elem; }
// TODO: Correct handling of NaNs, NULLs, etc
static bool compare(TElement x, TElement y)
{
return x < y;
}
}; };
...@@ -163,6 +180,8 @@ private: ...@@ -163,6 +180,8 @@ private:
using CountType = typename Traits::CountType; using CountType = typename Traits::CountType;
using KeyBits = typename Traits::KeyBits; using KeyBits = typename Traits::KeyBits;
static constexpr size_t INSERT_SORT_THRESHOLD = 64;
static constexpr size_t HISTOGRAM_SIZE = 1 << Traits::PART_SIZE_BITS; static constexpr size_t HISTOGRAM_SIZE = 1 << Traits::PART_SIZE_BITS;
static constexpr size_t PART_BITMASK = HISTOGRAM_SIZE - 1; static constexpr size_t PART_BITMASK = HISTOGRAM_SIZE - 1;
static constexpr size_t KEY_BITS = sizeof(Key) * 8; static constexpr size_t KEY_BITS = sizeof(Key) * 8;
...@@ -179,8 +198,101 @@ private: ...@@ -179,8 +198,101 @@ private:
static KeyBits keyToBits(Key x) { return ext::bit_cast<KeyBits>(x); } static KeyBits keyToBits(Key x) { return ext::bit_cast<KeyBits>(x); }
static Key bitsToKey(KeyBits x) { return ext::bit_cast<Key>(x); } static Key bitsToKey(KeyBits x) { return ext::bit_cast<Key>(x); }
/// In-place insertion sort of arr[0 .. size), ordered by Traits::compare.
/// Used for short ranges, where it is cheaper than another radix pass.
static inline void insertSortInternal(Element * arr, size_t size)
{
    Element * const stop = arr + size;
    for (Element * cur = arr + 1; cur < stop; ++cur)
    {
        /// Already in order relative to its left neighbour - nothing to move.
        if (!Traits::compare(*cur, *(cur - 1)))
            continue;

        /// Lift the element out and shift larger predecessors one slot to the right
        /// until the position where it belongs is found.
        Element moving = *cur;
        Element * hole = cur;
        do
        {
            *hole = *(hole - 1);
            --hole;
        } while (hole > arr && Traits::compare(moving, *(hole - 1)));

        *hole = moving;
    }
}
/** One level of in-place MSD radix sort: distributes the elements of arr[0 .. size) into buckets
  * by the digit number PASS and then recurses into the buckets with the next (less significant) digit.
  *
  * Only the first `limit` positions of the result have to end up in order, therefore:
  *  - buckets that lie entirely before position `limit` are sorted completely;
  *  - the single bucket that contains position `limit - 1` (the "boundary" bucket) is sorted
  *    with a correspondingly smaller limit;
  *  - buckets after the boundary bucket are neither filled nor sorted.
  *
  * arr   - beginning of the range;
  * size  - number of elements in the range;
  * limit - number of leading positions that must be sorted (values above `size` are treated as `size`).
  */
template <int PASS>
static inline void msdRadixSortInternal(Element * arr, size_t size, size_t limit)
{
    /// Nothing has to be ordered.
    if (size == 0 || limit == 0)
        return;
    if (limit > size)
        limit = size;

    /// count[b] is the number of elements whose digit number PASS equals b.
    size_t count[HISTOGRAM_SIZE] = {0};
    for (Element * it = arr; it < arr + size; ++it)
        ++count[getPart(PASS, *it)];

    /// cursor[b] is the next not yet filled slot of bucket b; initially the beginning of the bucket.
    /// The cursors of ALL buckets are needed, because while filling the first buckets
    /// misplaced elements are swapped into any of the later ones.
    Element * cursor[HISTOGRAM_SIZE];
    cursor[0] = arr;

    /// Number of leading buckets that intersect [arr, arr + limit).
    /// Only the FIRST bucket that starts at or after the limit defines this boundary,
    /// all following buckets start after the limit as well and must not move it.
    size_t bucketsForRecursion = HISTOGRAM_SIZE;
    bool limit_reached = false;
    for (size_t i = 1; i < HISTOGRAM_SIZE; ++i)
    {
        cursor[i] = cursor[i - 1] + count[i - 1];
        if (!limit_reached && cursor[i] >= arr + limit)
        {
            bucketsForRecursion = i;
            limit_reached = true;
        }
    }

    /// Scatter: fill the first bucketsForRecursion buckets with the elements that belong to them.
    Element * const array_end = arr + size;
    Element * bucket_start = arr;
    for (size_t i = 0; i < bucketsForRecursion; ++i)
    {
        Element * const bucket_end = bucket_start + count[i];

        /// All previous buckets are complete and all following ones are empty,
        /// so everything that is left already belongs to this bucket.
        if (bucket_end == array_end)
            break;

        while (cursor[i] != bucket_end)
        {
            Element swapper = *cursor[i];
            KeyBits tag = getPart(PASS, swapper);
            if (tag != i)
            {
                /// Push the misplaced element into its own bucket and pick up the element that
                /// was there; repeat until an element of the current bucket is picked up.
                do
                {
                    std::swap(swapper, *cursor[tag]++);
                } while ((tag = getPart(PASS, swapper)) != i);
                *cursor[i] = swapper;
            }
            ++cursor[i];
        }

        bucket_start = bucket_end;
    }

    if constexpr (PASS > 0)
    {
        /// Buckets before the boundary one lie entirely inside the limit: sort them completely.
        Element * start = arr;
        for (size_t i = 0; i + 1 < bucketsForRecursion; ++i)
        {
            const size_t subsize = count[i];
            msdRadixSortInternalHelper<PASS - 1>(start, subsize, subsize);
            start += subsize;
        }

        /// The boundary bucket begins before position `limit` and ends at or after it,
        /// so only its first (limit - offset of the bucket) elements are required.
        const size_t boundary = bucketsForRecursion - 1;
        const size_t sublimit = limit - static_cast<size_t>(start - arr);
        msdRadixSortInternalHelper<PASS - 1>(start, count[boundary], sublimit);
    }
}
/// Chooses the algorithm for one sub-range of the MSD sort:
/// ranges longer than INSERT_SORT_THRESHOLD get another radix pass on digit PASS,
/// shorter ones are sorted completely by insertion sort (the limit is not used for them).
template <int PASS>
static inline void msdRadixSortInternalHelper(Element * arr, size_t size, size_t limit)
{
    const bool needs_radix_pass = size > INSERT_SORT_THRESHOLD;
    if (needs_radix_pass)
    {
        msdRadixSortInternal<PASS>(arr, size, limit);
        return;
    }

    insertSortInternal(arr, size);
}
public: public:
static void execute(Element * arr, size_t size) /* Least significant digit radix sort
* The most efficient stable general-purpose sorting algorithm
*/
static void executeLsd(Element * arr, size_t size)
{ {
/// If the array is smaller than 256, then it is better to use another algorithm. /// If the array is smaller than 256, then it is better to use another algorithm.
...@@ -247,6 +359,16 @@ public: ...@@ -247,6 +359,16 @@ public:
allocator.deallocate(swap_buffer, size * sizeof(Element)); allocator.deallocate(swap_buffer, size * sizeof(Element));
} }
/** Most significant digit (MSD) radix sort.
  * Usually slower than the LSD variant and not stable, but it allows partial sorting:
  * `limit` is the number of leading elements the caller needs in order.
  * Based on https://github.com/voutcn/kxsort
  */
static void executeMsd(Element * arr, size_t size, size_t limit)
{
    /// A limit past the end of the array is the same as a limit equal to the array size.
    if (limit > size)
        limit = size;

    /// Start from the most significant digit.
    msdRadixSortInternalHelper<NUM_PASSES - 1>(arr, size, limit);
}
}; };
...@@ -254,7 +376,13 @@ public: ...@@ -254,7 +376,13 @@ public:
/// Use RadixSort with custom traits for complex types instead. /// Use RadixSort with custom traits for complex types instead.
template <typename T> template <typename T>
void radixSort(T * arr, size_t size) void lsdRadixSort(T *arr, size_t size)
{
RadixSort<RadixSortNumTraits<T>>::executeLsd(arr, size);
}
template <typename T>
void msdRadixSort(T * arr, size_t size, size_t limit)
{ {
RadixSort<RadixSortNumTraits<T>>::execute(arr, size); RadixSort<RadixSortNumTraits<T>>::executeMsd(arr, size, limit);
} }
...@@ -16,7 +16,7 @@ void NO_INLINE sort1(Key * data, size_t size) ...@@ -16,7 +16,7 @@ void NO_INLINE sort1(Key * data, size_t size)
void NO_INLINE sort2(Key * data, size_t size) void NO_INLINE sort2(Key * data, size_t size)
{ {
radixSort(data, size); lsdRadixSort(data, size);
} }
void NO_INLINE sort3(Key * data, size_t size) void NO_INLINE sort3(Key * data, size_t size)
......
...@@ -92,7 +92,7 @@ private: ...@@ -92,7 +92,7 @@ private:
{ {
/// TODO: It has been tested only for UInt32 yet. It needs to check UInt64, Float32/64. /// TODO: It has been tested only for UInt32 yet. It needs to check UInt64, Float32/64.
if constexpr (std::is_same_v<TKey, UInt32>) if constexpr (std::is_same_v<TKey, UInt32>)
RadixSort<RadixSortTraits>::execute(&array[0], array.size()); RadixSort<RadixSortTraits>::executeLsd(&array[0], array.size());
else else
std::sort(array.begin(), array.end()); std::sort(array.begin(), array.end());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册