diff --git a/dnn/src/rocm/argsort/argsort.cpp.hip b/dnn/src/rocm/argsort/argsort.cpp.hip new file mode 100644 index 0000000000000000000000000000000000000000..6ca67a1b0aa4ebe8208cccf17df0d9cfc5cb20af --- /dev/null +++ b/dnn/src/rocm/argsort/argsort.cpp.hip @@ -0,0 +1,183 @@ +/** + * \file dnn/src/rocm/argsort/argsort.cpp.hip + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "src/rocm/utils.h.hip" +#include "./argsort.h.hip" +#include "./bitonic_sort.h.hip" +#include "megdnn/basic_types.h" + +#include "hipcub/device/device_radix_sort.hpp" +#include "hipcub/device/device_segmented_radix_sort.hpp" + +using namespace megdnn; +using namespace rocm; + +namespace { +struct StridedOffsetIterator { + int bias, stride; + + StridedOffsetIterator(int bias_, int stride_) + : bias(bias_), stride(stride_) {} + + __device__ __forceinline__ int operator[](int i) const { + return stride * i + bias; + } +}; + +bool use_bitonic(uint32_t /*M*/, uint32_t N) { + // bitonic sort is preferred when N is small (alwyas faster than radix sort) + return N <= BITONIC_SORT_MAX_LENGTH; +} + +bool use_segmented(uint32_t M, uint32_t /*N*/) { + // an empirical value: + // sort(1, 1e6): 0.574ms + // segsort({1,2,8,16}, 1e6): 7-8ms + // sort(1, 1e7): 3.425ms + // segsort({1,2,8,16}, 1e7): 71-84ms + // + // segsort is about 7x-10x slower than sort on small batches, so we can + // expect it to be faster than sort when batch is large enough. + return M >= 8; +} + +__global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) { + uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < n) { + dst[i] = i % mod; + } +} + +template +size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) { + if (use_bitonic(M, N)) { + return 0; + } + return argsort::cub_sort_pairs(is_ascending, NULL, 0, NULL, NULL, NULL, NULL, + M, N, 0, sizeof(float)*8, NULL); +} +} // anonymous namespace + +template +MEGDNN_NOINLINE size_t argsort::cub_sort_pairs( + bool is_ascending, void* workspace, size_t workspace_size, + const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, + ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,hipStream_t stream){ + hipError_t err; + if (use_segmented(M, N)) { + if (is_ascending) { + err = hipcub::DeviceSegmentedRadixSort::SortPairs( + workspace, workspace_size, keys_in, keys_out, values_in, + values_out, N * M, M, StridedOffsetIterator(0, N), + StridedOffsetIterator(N, N), begin_bit, end_bit, stream); + hip_check(err); + } else { + err = hipcub::DeviceSegmentedRadixSort::SortPairsDescending( + workspace, workspace_size, keys_in, keys_out, values_in, + values_out, N * M, M, StridedOffsetIterator(0, N), + StridedOffsetIterator(N, N), begin_bit, end_bit, stream); + hip_check(err); + } + } else { + if (is_ascending) { + for (size_t i = 0; i < M; ++i) { + err = hipcub::DeviceRadixSort::SortPairs( + workspace, workspace_size, keys_in + N * i, + keys_out + N * i, values_in + N * i, values_out + N * i, + N, begin_bit, end_bit, stream); + hip_check(err); + if (!keys_in) { + return workspace_size; + } + } + } else { + for (size_t i = 0; i < M; ++i) { + err = hipcub::DeviceRadixSort::SortPairsDescending( + workspace, workspace_size, keys_in + N * i, + keys_out + N * i, values_in + N * i, values_out + N * i, + N, begin_bit, end_bit, stream); + hip_check(err); + if (!keys_in) { + return workspace_size; + } + } + } + } + return workspace_size; +} + +size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, + bool is_ascending, + bool iptr_src_given) { + size_t size = 0; + switch (dtype.enumv().ev) { +#define cb(ctype) \ + case DTypeTrait::enumv: \ + size = get_sort_workspace(M, N, is_ascending); \ + break; + ARGSORT_FOREACH_CTYPE(cb) +#undef cb + default: + megdnn_throw("argsort only supports float, int32 and float16"); + } + if (!iptr_src_given) { + size = DIVUP(size, sizeof(float)) * sizeof(float) + M * N * sizeof(int); + } + return size; +} + +template +void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr, + void* workspace, uint32_t M, uint32_t N, + bool is_ascending, hipStream_t stream, + const int* iptr_src) { + size_t wk_size = get_sort_workspace(M, N, is_ascending); + if (!iptr_src) { + int* ptr = reinterpret_cast(static_cast(workspace) + + DIVUP(wk_size, sizeof(float)) * + sizeof(float)); + kern_arange<<>>(ptr, M * N, N); + iptr_src = ptr; + } + + if (use_bitonic(M, N)) { + hip_check(bitonic_sort(M, N, sptr, iptr_src, dptr, iptr, is_ascending, + stream)); + } else { + cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src, + iptr, M, N, 0, sizeof(float)*8, stream); + } +} + +namespace megdnn { +namespace rocm { + +#define INST_CUB_SORT(dtype) \ +template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs(bool, \ + void*, size_t, const dtype*, dtype*, \ + const dtype*, dtype*, uint32_t, uint32_t,\ + int, int, hipStream_t); + +#define INST_FORWARD(dtype) \ +template void argsort::forward(const dtype*, dtype*, int*, void*, \ + uint32_t, uint32_t, bool, hipStream_t, \ + const int*); + +ARGSORT_FOREACH_CTYPE(INST_FORWARD) +INST_CUB_SORT(uint32_t) +// INST_CUB_SORT(uint64_t) +#undef INST_CUB_SORT +#undef INST_FORWARD +} +} // namespace megdnn +// vim: ft=rocm syntax=rocm.doxygen + diff --git a/dnn/src/rocm/argsort/argsort.h.hip b/dnn/src/rocm/argsort/argsort.h.hip new file mode 100644 index 0000000000000000000000000000000000000000..f9ca27cca214736e986161702fbe1ac7ee58cd25 --- /dev/null +++ b/dnn/src/rocm/argsort/argsort.h.hip @@ -0,0 +1,50 @@ +/** + * \file dnn/src/rocm/argsort/argsort.h.hip + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "hcc_detail/hcc_defs_prologue.h" +#include "hip_header.h" +#include +#include +#include "megdnn/dtype.h" + +namespace megdnn { +namespace rocm { +namespace argsort { + +size_t get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, + bool is_ascending, + bool iptr_src_given = false); + +template +size_t cub_sort_pairs( + bool is_ascending, void* workspace, size_t workspace_size, + const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, + ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,hipStream_t stream); + +/*! + * \param iptr_src pointer to indices; a range would be generated if it is null + */ +template +void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace, + uint32_t M, uint32_t N, bool is_ascending, hipStream_t stream, + const int* iptr_src = NULL); + +//! iterate over all supported data types +#define ARGSORT_FOREACH_CTYPE(cb) \ + cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16)) + +} // namespace argsort +} // namespace rocm +} // namespace megdnn + +// vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/rocm/argsort/backward.cpp.hip b/dnn/src/rocm/argsort/backward.cpp.hip new file mode 100644 index 0000000000000000000000000000000000000000..d7befd58420eca24dcdc1a0e969af02809cc18cc --- /dev/null +++ b/dnn/src/rocm/argsort/backward.cpp.hip @@ -0,0 +1,67 @@ +/** + * \file dnn/src/rocm/argsort/backward.cpp.hip + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "src/rocm/utils.h.hip" +#include "./argsort.h.hip" +#include "./backward.h.hip" + +// #include "src/rocm/utils.h" + +using namespace megdnn; +using namespace rocm; +using namespace argsort; + +namespace { + +template +__global__ void backward_kernel(uint32_t dst_w, uint32_t src_w, + uint32_t src_size, T* dst, const T* src_data, + const int* src_idx) { + uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < src_size) { + uint32_t r = idx / src_w; + dst[r * dst_w + src_idx[idx]] = src_data[idx]; + } +} + +} // namespace + +template +void argsort::backward_proxy(uint32_t dst_h, uint32_t dst_w, uint32_t src_w, + T* dst, const T* src_data, const int* src_idx, + hipStream_t stream) { + if (dst_w != src_w) { + hipMemsetAsync(dst, 0, dst_h * dst_w * sizeof(T), stream); + } + + uint32_t src_size = dst_h * src_w; + backward_kernel<<>>( + dst_w, src_w, src_size, dst, src_data, src_idx); + after_kernel_launch(); +} + +namespace megdnn { +namespace rocm { +namespace argsort { + +#define INST(T) \ + template void backward_proxy(uint32_t dst_h, uint32_t dst_w, \ + uint32_t src_w, T* dst, const T* src_data, \ + const int* src_idx, hipStream_t stream); +ARGSORT_FOREACH_CTYPE(INST) +#undef INST + +} // namespace argsort +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/argsort/backward.h.hip b/dnn/src/rocm/argsort/backward.h.hip new file mode 100644 index 0000000000000000000000000000000000000000..34403bc95fdc761572680efa399e60d38681ba30 --- /dev/null +++ b/dnn/src/rocm/argsort/backward.h.hip @@ -0,0 +1,29 @@ +/** + * \file dnn/src/rocm/argsort/backward.h.hip + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "hip_header.h" +#include + +namespace megdnn { +namespace rocm { +namespace argsort { + +template +void backward_proxy(uint32_t dst_h, uint32_t dst_w, uint32_t src_w, T* dst, + const T* src_data, const int* src_idx, hipStream_t stream); + +} // namespace argsort +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/rocm/argsort/bitonic_sort.cpp.hip b/dnn/src/rocm/argsort/bitonic_sort.cpp.hip new file mode 100644 index 0000000000000000000000000000000000000000..3f93d44ed5c42777be2191917f60b8eb7b85078a --- /dev/null +++ b/dnn/src/rocm/argsort/bitonic_sort.cpp.hip @@ -0,0 +1,320 @@ +/** + * \file dnn/src/rocm/argsort/bitonic_sort.cpp.hip + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "./bitonic_sort.h.hip" +// #include "src/cuda/query_blocksize.cuh" +// #include "megdnn/dtype.h" + +// #if __CUDACC_VER_MAJOR__ < 9 +// #pragma message "warp sync disabled due to insufficient cuda version" +#define __syncwarp __syncthreads +// #endif + +#include +#include + +using namespace megdnn; +using namespace rocm; + +namespace bitonic_sort_impl { + +//! load keys and init idx +template +__device__ __forceinline__ void safe_load0(T* dst, uint16_t* idx, const T* src, + uint32_t id, uint32_t size) { + dst[id] = id < size ? src[id] : CompareLess::template max(); + idx[id] = id; +} + +//! load values +template +__device__ __forceinline__ void safe_load1(T* dst, const T* src, uint32_t id, + uint32_t size) { + // broadcast last value to avoid out-of-bound values (for example, when + // input contains NaN) + dst[id] = src[min(id, size - 1)]; +} + +//! write keys +template +__device__ __forceinline__ void safe_write0(T* dst, const T* src, uint32_t id, + uint32_t size) { + if (id < size) { + dst[id] = src[id]; + } +} + +//! write values +template +__device__ __forceinline__ void safe_write1(T* dst, const T* src, + const uint16_t* remap, uint32_t id, + uint32_t size) { + if (id < size) { + dst[id] = src[remap[id]]; + } +} + +struct SyncWarp { + static __device__ __forceinline__ void s() { __syncwarp(); } +}; +struct SyncBlock { + static __device__ __forceinline__ void s() { __syncthreads(); } +}; + +template +struct NumTrait; +template <> +struct NumTrait { + static __device__ __forceinline__ float max() { return INFINITY; } + static __device__ __forceinline__ float min() { return -INFINITY; } +}; + +template <> +struct NumTrait { + static __device__ __forceinline__ int32_t max() { return INT_MAX; } + static __device__ __forceinline__ int32_t min() { return INT_MIN; } +}; + +// #if !MEGDNN_DISABLE_FLOAT16 +// template <> +// struct NumTrait { +// static __device__ __forceinline__ dt_float16 max() { +// return std::numeric_limits::max(); +// } +// static __device__ __forceinline__ dt_float16 min() { +// return std::numeric_limits::lowest(); +// } +// }; +// #endif + +struct LessThan { + template + static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1, + Value v1) { + return (k0 < k1) | ((k0 == k1) & (v0 < v1)); + } + + template + static __device__ __forceinline__ T max() { + return NumTrait::max(); + } +}; + +struct GreaterThan { + template + static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1, + Value v1) { + return (k0 > k1) | ((k0 == k1) & (v0 < v1)); + } + + template + static __device__ __forceinline__ T max() { + return NumTrait::min(); + } +}; + +template +union KVUnion { + Key key; + Value value; +}; + +template +static int get_shmem(int block_size, void* = NULL) { + return (sizeof(KVUnion) + sizeof(uint16_t)) * block_size * 4; +} + +/*! + * \brief batched bitonic sort (M, N) for small N + * + * launch configuration: + * grid(X) + * block(N/4, Y) + * + * where N / 4 == 1 << nr_th_log2 + */ +template +static __global__ void kern(uint32_t batch, uint32_t length, const Key* key_inp, + const Value* value_inp, Key* key_out, + Value* value_out) { + const uint32_t nr_th = 1 << nr_th_log2; + + // 24KiB shared memory for 4-byte keys for 1024 threads + extern __shared__ uint8_t smem_storage[]; + uint16_t* idx_storage = reinterpret_cast(smem_storage); + KVUnion* keys_storage = reinterpret_cast*>( + idx_storage + blockDim.y * (nr_th * 4)); + + uint32_t cur_batch = blockIdx.x * blockDim.y + threadIdx.y, + off = cur_batch * length; + key_inp += off; + key_out += off; + value_inp += off; + value_out += off; + + uint32_t storage_offset = threadIdx.y * (nr_th * 4); + uint16_t* values = idx_storage + storage_offset; + Key* keys = reinterpret_cast(keys_storage + storage_offset); + uint32_t tid0 = threadIdx.x, tid1 = tid0 + nr_th, + cur_length = cur_batch < batch ? length : 0; + safe_load0(keys, values, key_inp, tid0, cur_length); + safe_load0(keys, values, key_inp, tid0 + nr_th, cur_length); + safe_load0(keys, values, key_inp, tid0 + nr_th * 2, + cur_length); + safe_load0(keys, values, key_inp, tid0 + nr_th * 3, + cur_length); + + Sync::s(); + +#define WORK(_idx, _asc) \ + do { \ + uint32_t _id0 = (_idx), _id1 = _id0 + step; \ + Key _k0 = keys[_id0], _k1 = keys[_id1]; \ + uint16_t _v0 = values[_id0], _v1 = values[_id1]; \ + if (CompareLess::cmp(_k0, _v0, _k1, _v1) != _asc) { \ + keys[_id0] = _k1; \ + keys[_id1] = _k0; \ + values[_id0] = _v1; \ + values[_id1] = _v0; \ + } \ + } while (0) + +#pragma unroll + for (uint32_t slen_log = 0; slen_log <= (nr_th_log2 + 1); ++slen_log) { + // log2 of half of current bitonic sequence (i.e. length of its + // monotonic part) + uint32_t asc0 = !((tid0 >> slen_log) & 1), + asc1 = !((tid1 >> slen_log) & 1); +#pragma unroll + for (uint32_t j = 0; j <= slen_log; ++j) { + uint32_t step = 1 << (slen_log - j), xmask = step - 1, + ymask = ~xmask; + WORK((tid0 & xmask) + ((tid0 & ymask) << 1), asc0); + WORK((tid1 & xmask) + ((tid1 & ymask) << 1), asc1); + Sync::s(); + } + } + +#undef WORK + + if (cur_batch < batch) { + safe_write0(key_out, keys, tid0, length); + safe_write0(key_out, keys, tid0 + nr_th, length); + safe_write0(key_out, keys, tid0 + nr_th * 2, length); + safe_write0(key_out, keys, tid0 + nr_th * 3, length); + + // permute values according to sorted indices + Value* copied_values = reinterpret_cast(keys); + safe_load1(copied_values, value_inp, tid0, cur_length); + safe_load1(copied_values, value_inp, tid0 + nr_th, cur_length); + safe_load1(copied_values, value_inp, tid0 + nr_th * 2, cur_length); + safe_load1(copied_values, value_inp, tid0 + nr_th * 3, cur_length); + Sync::s(); + + safe_write1(value_out, copied_values, values, tid0, length); + safe_write1(value_out, copied_values, values, tid0 + nr_th, length); + safe_write1(value_out, copied_values, values, tid0 + nr_th * 2, length); + safe_write1(value_out, copied_values, values, tid0 + nr_th * 3, length); + } +} + +} // namespace bitonic_sort_impl + +template +hipError_t rocm::bitonic_sort(uint32_t batch, uint32_t length, + const Key* key_inp, const Value* value_inp, + Key* key_out, Value* value_out, bool ascending, + hipStream_t stream) { + using namespace bitonic_sort_impl; + if (length == 1) { + if (key_inp != key_out) { + hipMemcpyAsync(key_out, key_inp, sizeof(Key) * batch, + hipMemcpyDeviceToDevice, stream); + } + if (value_inp != value_out) { + hipMemcpyAsync(value_out, value_inp, sizeof(Value) * batch, + hipMemcpyDeviceToDevice, stream); + } + return hipGetLastError(); + } + + void (*kptr)(uint32_t, uint32_t, const Key*, const Value*, Key*, Value*) = + NULL; + uint32_t l4 = (length + 3) / 4; + dim3 block; + +#define chk(s) \ + do { \ + if (!kptr && l4 <= (1 << s)) { \ + block.x = 1 << s; \ + if ((1 << s) <= 32) { \ + if (ascending) { \ + kptr = kern; \ + } else { \ + kptr = kern; \ + } \ + } else { \ + if (ascending) { \ + kptr = kern; \ + } else { \ + kptr = kern; \ + } \ + } \ + } \ + } while (0) + + chk(0); + chk(1); + chk(2); + chk(3); + chk(4); + chk(5); + chk(6); + chk(7); + chk(8); + chk(9); + + if (!kptr) { + return hipErrorInvalidConfiguration; + } + + // TODO: this is randomly choosed + int suggested_block_size = 128; + // query_launch_config_for_kernel(reinterpret_cast(kptr), + // get_shmem) + // .block_size; + block.y = std::max(suggested_block_size / block.x, 1); + int shmem = get_shmem(block.y * block.x); + kptr<<<(batch - 1) / block.y + 1, block, shmem, stream>>>( + batch, length, key_inp, value_inp, key_out, value_out); + return hipGetLastError(); +} + +namespace megdnn { +namespace rocm { + +#define INST(k, v) \ + template hipError_t bitonic_sort(uint32_t, uint32_t, const k*, \ + const v*, k*, v*, bool, \ + hipStream_t) + +INST(float, int); +INST(int32_t, int); +// DNN_INC_FLOAT16(INST(dt_float16, int)); +#undef INST + +} // namespace megdnn +} // namespace megdnn + +// vim: ft=rocm syntax=rocm.doxygen + diff --git a/dnn/src/rocm/argsort/bitonic_sort.h.hip b/dnn/src/rocm/argsort/bitonic_sort.h.hip new file mode 100644 index 0000000000000000000000000000000000000000..8f1ab3b4fefde30283d53d80d897ac5a0a24aca1 --- /dev/null +++ b/dnn/src/rocm/argsort/bitonic_sort.h.hip @@ -0,0 +1,38 @@ +/** + * \file dnn/src/rocm/argsort/bitonic_sort.h.hip + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "hip_header.h" +#include + +namespace megdnn { +namespace rocm { + +const uint32_t BITONIC_SORT_MAX_LENGTH = 1024; +// cub radix sort seems to be faster with lengths > 1024 + +/*! + * \brief bitonic sort for k/v pairs + * + * Requires \p length no larger than 4 times of cuda thread num. \p key_inp + * and \p key_out can be identical, and so are \p value_inp and \p value_out. + */ +template +hipError_t bitonic_sort(uint32_t batch, uint32_t length, const Key* key_inp, + const Value* value_inp, Key* key_out, Value* value_out, + bool ascending, hipStream_t stream); + +} // namespace rocm +} // namespace megdnn + +// vim: ft=cpp syntax=cpp.doxygen + diff --git a/dnn/src/rocm/argsort/opr_impl.cpp b/dnn/src/rocm/argsort/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..944fc213119325612ebeb7d2378a91eef2c36914 --- /dev/null +++ b/dnn/src/rocm/argsort/opr_impl.cpp @@ -0,0 +1,79 @@ +/** + * \file dnn/src/rocm/argsort/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./opr_impl.h" +#include "./argsort.h.hip" +#include "./backward.h.hip" + +#include "src/common/utils.h" +#include "src/rocm/utils.h" + +using namespace megdnn; +using namespace rocm; + +void ArgsortForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_tensor_out indices, + _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, indices.layout, workspace.size); + auto M = src.layout.shape[0], N = src.layout.shape[1]; + auto iptr = indices.ptr(); + auto wptr = static_cast(workspace.raw_ptr); + bool is_ascending = (param().order == Order::ASCENDING); + auto stream = hip_stream(handle()); + switch (src.layout.dtype.enumv()) { +#define cb(t) \ + case DTypeTrait::enumv: \ + argsort::forward(src.ptr(), dst.ptr(), iptr, wptr, M, N, \ + is_ascending, stream); \ + break; + ARGSORT_FOREACH_CTYPE(cb); +#undef cb + default: + megdnn_throw(ssprintf("unsupported argsort dtype on cuda: %s", + src.layout.dtype.name())); + } +} + +size_t ArgsortForwardImpl::get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout&, + const TensorLayout&) { + megdnn_assert(src.ndim == 2, "invalid src layout: %s", + src.to_string().c_str()); + auto M = src.shape[0], N = src.shape[1]; + auto&& dtype = src.dtype; + megdnn_assert(std::max(M, N) <= + static_cast(std::numeric_limits::max())); + return argsort::get_fwd_workspace_in_bytes( + M, N, dtype, param().order == Param::Order::ASCENDING); +} + +void ArgsortBackwardImpl::exec(_megdnn_tensor_in diff, + _megdnn_tensor_in indices, + _megdnn_tensor_out grad, + _megdnn_workspace workspace) { + check_exec(diff.layout, indices.layout, grad.layout, workspace.size); + auto stream = hip_stream(handle()); + switch (diff.layout.dtype.enumv()) { +#define cb(t) \ + case DTypeTrait::enumv: \ + argsort::backward_proxy(grad.layout[0], grad.layout[1], \ + diff.layout[1], grad.ptr(), diff.ptr(), \ + indices.ptr(), stream); \ + break; + ARGSORT_FOREACH_CTYPE(cb); +#undef cb + default: + megdnn_throw(ssprintf("unsupported argsort dtype on cuda: %s", + diff.layout.dtype.name())); + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/argsort/opr_impl.h b/dnn/src/rocm/argsort/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..1a2d87cc7593ca2a0725429bead93018acb10565 --- /dev/null +++ b/dnn/src/rocm/argsort/opr_impl.h @@ -0,0 +1,47 @@ +/** + * \file dnn/src/rocm/argsort/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace rocm { + +class ArgsortForwardImpl final: public ArgsortForward { + public: + using ArgsortForward::ArgsortForward; + void exec(_megdnn_tensor_in src, + _megdnn_tensor_out dst, + _megdnn_tensor_out indices, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout &src, + const TensorLayout &dst, + const TensorLayout &indices) override; +}; + +class ArgsortBackwardImpl final: public ArgsortBackward { + public: + using ArgsortBackward::ArgsortBackward; + void exec(_megdnn_tensor_in diff, + _megdnn_tensor_in indices, + _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout &, + const TensorLayout &, + const TensorLayout &) override { + return 0; + } +}; + +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/rocm/handle.cpp b/dnn/src/rocm/handle.cpp index b3752e76346b263690d71b798d4644bbc8590a2a..d1ec445d6b97e40cacb0bdf16e8832d82d3ae18d 100644 --- a/dnn/src/rocm/handle.cpp +++ b/dnn/src/rocm/handle.cpp @@ -33,6 +33,7 @@ #include "src/rocm/powc/opr_impl.h" #include "src/rocm/indexing_multi_axis_vec/opr_impl.h" #include "src/rocm/linspace/opr_impl.h" +#include "src/rocm/argsort/opr_impl.h" #include "src/rocm/argmxx/opr_impl.h" #include "src/rocm/sleep/opr_impl.h" #include "src/rocm/batch_normalization/opr_impl.h" @@ -148,6 +149,8 @@ bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) { return src.is_contiguous() || src.stride[src.ndim - 1] == 1; } +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortBackward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardFilter); diff --git a/dnn/test/rocm/argsort.cpp b/dnn/test/rocm/argsort.cpp new file mode 100644 index 0000000000000000000000000000000000000000..57da80fa0cb8b6ce962d770791da2298e945bacc --- /dev/null +++ b/dnn/test/rocm/argsort.cpp @@ -0,0 +1,124 @@ +/** + * \file dnn/test/rocm/argsort.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/rocm/fixture.h" + +#include "test/common/checker.h" +#include "test/common/rng.h" +#include "test/common/tensor.h" + +#include "../src/rocm/argsort/opr_impl.h" + +using namespace megdnn; +using namespace test; + +namespace { +class ArgsortRNG final : public RNG { + bool m_rev_order = false; + DType m_dtype; + + template + void fill(T* ptr, int n) { + if (m_rev_order) { + for (int i = 0; i < n; ++i) + ptr[i] = static_cast(n / 2 - i); + } else { + for (int i = 0; i < n; ++i) + ptr[i] = static_cast(i - n / 2); + COMPAT_RANDOM(ptr, ptr + n); + } + } + + void gen(const TensorND& tensor) override { + auto n = tensor.layout.total_nr_elems(); + if (m_dtype == dtype::Float32{}) { + fill(tensor.ptr(), n); + } else { + megdnn_assert(m_dtype == dtype::Int32{}); + fill(tensor.ptr(), n); + } + } + +public: + ArgsortRNG(DType dt) : m_dtype{dt} {} + + void set_rev_order(bool flag) { m_rev_order = flag; } +}; +void run_forward_test(Handle* handle, DType dtype) { + Checker checker(handle); + using Param = Argsort::Param; + using Order = Param::Order; + ArgsortRNG rng{dtype}; + checker.set_dtype(2, dtype::Int32()); + checker.set_dtype(0, dtype).set_rng(0, &rng); + for (size_t i = 3; i < 10240; i *= 2) { + Param param; + + param.order = Order::ASCENDING; + checker.set_param(param).execs({{3, i + 1}, {}, {}}); + param.order = Order::DESCENDING; + checker.set_param(param).execs({{3, i - 1}, {}, {}}); + checker.set_param(param).execs({{13, i + 3}, {}, {}}); + } + { + // reverse sort large array + constexpr size_t N = 200003; + rng.set_rev_order(true); + Param param; + param.order = Order::ASCENDING; + checker.set_param(param).execs({{1, N}, {}, {}}); + } +} + +void run_backward_test(Handle* handle, DType dtype) { + class IdxRng final : public RNG { + void gen(const TensorND& tensor) override { + auto ptr = tensor.ptr(); + auto m = tensor.layout[0], n = tensor.layout[1]; + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + ptr[j] = j; + } + COMPAT_RANDOM(ptr, ptr + n); + ptr += n; + } + } + } rng; + Checker checker(handle); + checker.set_dtype(1, dtype::Int32()).set_rng(1, &rng); + checker.set_dtype(0, dtype); + checker.set_dtype(2, dtype); + for (size_t i = 16; i < 4096; i *= 2) { + checker.execs({{3, i}, {3, i}, {3, i}}); + checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 3}}); + checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 7}}); + } +} + +} // anonymous namespace + +TEST_F(ROCM, ARGSORT_FORWARD_F32) { + run_forward_test(handle_rocm(), dtype::Float32{}); +} + +TEST_F(ROCM, ARGSORT_FORWARD_I32) { + run_forward_test(handle_rocm(), dtype::Int32{}); +} + +TEST_F(ROCM, ARGSORT_BACKWARD_F32) { + run_backward_test(handle_rocm(), dtype::Float32{}); +} + +TEST_F(ROCM, ARGSORT_BACKWARD_I32) { + run_backward_test(handle_rocm(), dtype::Int32{}); +} + +// vim: syntax=cpp.doxygen +