/**
 * \file dnn/src/rocm/argsort/argsort.cpp.hip
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "hcc_detail/hcc_defs_prologue.h"

#include "src/rocm/utils.h.hip"

#include "./argsort.h.hip"
#include "./bitonic_sort.h.hip"
#include "megdnn/basic_types.h"

#include "hipcub/device/device_radix_sort.hpp"
#include "hipcub/device/device_segmented_radix_sort.hpp"

using namespace megdnn;
using namespace rocm;

namespace {

//! maps a segment index to the offset of its first element; describes the M
//! contiguous rows of length N to the segmented radix sort
struct StridedOffsetIterator {
    int bias, stride;

    StridedOffsetIterator(int bias_, int stride_)
            : bias(bias_), stride(stride_) {}

    __device__ __forceinline__ int operator[](int i) const {
        return stride * i + bias;
    }
};

bool use_bitonic(uint32_t /*M*/, uint32_t N) {
    // bitonic sort is preferred when N is small (always faster than radix
    // sort)
    return N <= BITONIC_SORT_MAX_LENGTH;
}

bool use_segmented(uint32_t M, uint32_t /*N*/) {
    // an empirical value:
    // sort(1, 1e6): 0.574ms
    // segsort({1,2,8,16}, 1e6): 7-8ms
    //
    // sort(1, 1e7): 3.425ms
    // segsort({1,2,8,16}, 1e7): 71-84ms
    //
    // segsort is about 7x-10x slower than sort on small batches, so we can
    // expect it to be faster than sort when batch is large enough.
    return M >= 8;
}

//! fill dst[i] with i % mod, i.e. the index 0..mod-1 within each row
__global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) {
    uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < n) {
        dst[i] = i % mod;
    }
}

template <typename ctype>
size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) {
    if (use_bitonic(M, N)) {
        return 0;
    }
    // NULL buffers turn the call into a pure workspace-size query
    return argsort::cub_sort_pairs<ctype, int>(is_ascending, NULL, 0, NULL,
                                               NULL, NULL, NULL, M, N, 0,
                                               sizeof(float) * 8, NULL);
}
}  // anonymous namespace

template <typename KeyType, typename ValueType>
MEGDNN_NOINLINE size_t argsort::cub_sort_pairs(
        bool is_ascending, void* workspace, size_t workspace_size,
        const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in,
        ValueType* values_out, uint32_t M, uint32_t N, int begin_bit,
        int end_bit, hipStream_t stream) {
    hipError_t err;
    if (use_segmented(M, N)) {
        if (is_ascending) {
            err = hipcub::DeviceSegmentedRadixSort::SortPairs(
                    workspace, workspace_size, keys_in, keys_out, values_in,
                    values_out, N * M, M, StridedOffsetIterator(0, N),
                    StridedOffsetIterator(N, N), begin_bit, end_bit, stream);
            hip_check(err);
        } else {
            err = hipcub::DeviceSegmentedRadixSort::SortPairsDescending(
                    workspace, workspace_size, keys_in, keys_out, values_in,
                    values_out, N * M, M, StridedOffsetIterator(0, N),
                    StridedOffsetIterator(N, N), begin_bit, end_bit, stream);
            hip_check(err);
        }
    } else {
        if (is_ascending) {
            for (size_t i = 0; i < M; ++i) {
                err = hipcub::DeviceRadixSort::SortPairs(
                        workspace, workspace_size, keys_in + N * i,
                        keys_out + N * i, values_in + N * i,
                        values_out + N * i, N, begin_bit, end_bit, stream);
                hip_check(err);
                if (!keys_in) {
                    // workspace-size query: one call is enough
                    return workspace_size;
                }
            }
        } else {
            for (size_t i = 0; i < M; ++i) {
                err = hipcub::DeviceRadixSort::SortPairsDescending(
                        workspace, workspace_size, keys_in + N * i,
                        keys_out + N * i, values_in + N * i,
                        values_out + N * i, N, begin_bit, end_bit, stream);
                hip_check(err);
                if (!keys_in) {
                    // workspace-size query: one call is enough
                    return workspace_size;
                }
            }
        }
    }
    return workspace_size;
}
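// Note: like the hipcub calls it wraps, cub_sort_pairs doubles as a
// workspace-size query. With NULL buffers each hipcub call only writes the
// required byte count into `workspace_size` instead of sorting, and the
// `if (!keys_in)` early return above stops the per-batch loop after the
// first query. A hypothetical caller-side sketch (names are illustrative,
// not part of this file):
//
//     size_t wk_bytes = argsort::cub_sort_pairs<float, int>(
//             /*is_ascending=*/true, NULL, 0, NULL, NULL, NULL, NULL, M, N,
//             0, sizeof(float) * 8, NULL);  // query only, sorts nothing
//     // allocate wk_bytes of device memory, then repeat the call with real
//     // key/value pointers to perform the actual sort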
size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
                                           bool is_ascending,
                                           bool iptr_src_given) {
    size_t size = 0;
    switch (dtype.enumv().ev) {
#define cb(ctype)                                             \
    case DTypeTrait<ctype>::enumv:                            \
        size = get_sort_workspace<ctype>(M, N, is_ascending); \
        break;
        ARGSORT_FOREACH_CTYPE(cb)
#undef cb
        default:
            megdnn_throw("argsort only supports float, int32 and float16");
    }
    if (!iptr_src_given) {
        // reserve extra space for the generated index array, aligned to
        // sizeof(float)
        size = DIVUP(size, sizeof(float)) * sizeof(float) +
               M * N * sizeof(int);
    }
    return size;
}

template <typename dtype>
void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr,
                      void* workspace, uint32_t M, uint32_t N,
                      bool is_ascending, hipStream_t stream,
                      const int* iptr_src) {
    size_t wk_size = get_sort_workspace<dtype>(M, N, is_ascending);
    if (!iptr_src) {
        // no source indices given: generate 0..N-1 for each row in the tail
        // of the workspace
        int* ptr = reinterpret_cast<int*>(
                static_cast<uint8_t*>(workspace) +
                DIVUP(wk_size, sizeof(float)) * sizeof(float));
        // one thread per element, 512 threads per block
        kern_arange<<<DIVUP(M * N, 512), 512, 0, stream>>>(ptr, M * N, N);
        iptr_src = ptr;
    }

    if (use_bitonic(M, N)) {
        hip_check(bitonic_sort(M, N, sptr, iptr_src, dptr, iptr, is_ascending,
                               stream));
    } else {
        cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src,
                       iptr, M, N, 0, sizeof(float) * 8, stream);
    }
}

namespace megdnn {
namespace rocm {

#define INST_CUB_SORT(dtype)                                               \
    template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs<dtype, dtype>( \
            bool, void*, size_t, const dtype*, dtype*, const dtype*,       \
            dtype*, uint32_t, uint32_t, int, int, hipStream_t);

#define INST_FORWARD(dtype)                                                  \
    template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \
                                          uint32_t, uint32_t, bool,          \
                                          hipStream_t, const int*);

ARGSORT_FOREACH_CTYPE(INST_FORWARD)
INST_CUB_SORT(uint32_t)
// INST_CUB_SORT(uint64_t)
#undef INST_CUB_SORT
#undef INST_FORWARD

}  // namespace rocm
}  // namespace megdnn

// vim: ft=rocm syntax=rocm.doxygen
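// Hypothetical host-side usage sketch (illustrative only; the buffer names
// and the float dtype below are assumptions, not part of this file):
//
//     size_t wk = argsort::get_fwd_workspace_in_bytes(
//             M, N, dtype::Float32(), /*is_ascending=*/true,
//             /*iptr_src_given=*/false);
//     // allocate `wk` bytes of device memory as `workspace`, then:
//     argsort::forward<float>(src_dev, dst_dev, idx_dev, workspace, M, N,
//                             /*is_ascending=*/true, stream,
//                             /*iptr_src=*/nullptr);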