diff --git a/dnn/src/rocm/argsort/argsort.cpp.hip b/dnn/src/rocm/argsort/argsort.cpp.hip index 6ca67a1b0aa4ebe8208cccf17df0d9cfc5cb20af..65e21bb4dc2b4094dd300db4849d24fa28a891ca 100644 --- a/dnn/src/rocm/argsort/argsort.cpp.hip +++ b/dnn/src/rocm/argsort/argsort.cpp.hip @@ -174,7 +174,7 @@ template void argsort::forward(const dtype*, dtype*, int*, void*, \ ARGSORT_FOREACH_CTYPE(INST_FORWARD) INST_CUB_SORT(uint32_t) -// INST_CUB_SORT(uint64_t) +INST_CUB_SORT(uint64_t) #undef INST_CUB_SORT #undef INST_FORWARD } diff --git a/dnn/src/rocm/argsort/argsort.h.hip b/dnn/src/rocm/argsort/argsort.h.hip index f9ca27cca214736e986161702fbe1ac7ee58cd25..9b3e07c1732741648091843e72e1cb5035ee9573 100644 --- a/dnn/src/rocm/argsort/argsort.h.hip +++ b/dnn/src/rocm/argsort/argsort.h.hip @@ -40,6 +40,7 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace, const int* iptr_src = NULL); //! iterate over all supported data types +// device_radix_sort does not support dt_float16 dtype(half_float::half in rocm) #define ARGSORT_FOREACH_CTYPE(cb) \ cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16)) diff --git a/dnn/src/rocm/argsort/backward.cpp.hip b/dnn/src/rocm/argsort/backward.cpp.hip index d7befd58420eca24dcdc1a0e969af02809cc18cc..7428a5ffe602afaae710974e2ff9054e2b1378ca 100644 --- a/dnn/src/rocm/argsort/backward.cpp.hip +++ b/dnn/src/rocm/argsort/backward.cpp.hip @@ -14,8 +14,6 @@ #include "./argsort.h.hip" #include "./backward.h.hip" -// #include "src/rocm/utils.h" - using namespace megdnn; using namespace rocm; using namespace argsort; diff --git a/dnn/src/rocm/argsort/bitonic_sort.cpp.hip b/dnn/src/rocm/argsort/bitonic_sort.cpp.hip index 3f93d44ed5c42777be2191917f60b8eb7b85078a..095f4b712bbfc479fd2852644318ded3ee7cf79c 100644 --- a/dnn/src/rocm/argsort/bitonic_sort.cpp.hip +++ b/dnn/src/rocm/argsort/bitonic_sort.cpp.hip @@ -11,13 +11,9 @@ #include "hcc_detail/hcc_defs_prologue.h" #include "./bitonic_sort.h.hip" -// #include "src/cuda/query_blocksize.cuh" -// #include "megdnn/dtype.h" +#include "megdnn/dtype.h" -// #if __CUDACC_VER_MAJOR__ < 9 -// #pragma message "warp sync disabled due to insufficient cuda version" #define __syncwarp __syncthreads -// #endif #include #include @@ -84,17 +80,17 @@ struct NumTrait { static __device__ __forceinline__ int32_t min() { return INT_MIN; } }; -// #if !MEGDNN_DISABLE_FLOAT16 -// template <> -// struct NumTrait { -// static __device__ __forceinline__ dt_float16 max() { -// return std::numeric_limits::max(); -// } -// static __device__ __forceinline__ dt_float16 min() { -// return std::numeric_limits::lowest(); -// } -// }; -// #endif +#if !MEGDNN_DISABLE_FLOAT16 +template <> +struct NumTrait { + static __device__ __forceinline__ dt_float16 max() { + return std::numeric_limits::max(); + } + static __device__ __forceinline__ dt_float16 min() { + return std::numeric_limits::lowest(); + } +}; +#endif struct LessThan { template @@ -310,7 +306,7 @@ namespace rocm { INST(float, int); INST(int32_t, int); -// DNN_INC_FLOAT16(INST(dt_float16, int)); +DNN_INC_FLOAT16(INST(dt_float16, int)); #undef INST } // namespace megdnn diff --git a/dnn/src/rocm/topk/topk_radix.cpp.hip b/dnn/src/rocm/topk/topk_radix.cpp.hip index 2cae0bfb8b7d21ba116f4e14c1740e67c2630146..a26b3351da6aeb81b75a377b887b641a766f9854 100644 --- a/dnn/src/rocm/topk/topk_radix.cpp.hip +++ b/dnn/src/rocm/topk/topk_radix.cpp.hip @@ -18,13 +18,7 @@ #include #include -#if __CUDACC_VER_MAJOR__ < 9 -#pragma message "topk is a little slower on cuda earlier than 9.0" -// on cuda 9.0 and later, due to thread-divergent branches we should use -// __syncwarp; and I am too lazy to implement a correct legacy version, so just -// use __syncthreads instead for older cuda #define __syncwarp __syncthreads -#endif using namespace megdnn; using namespace rocm; @@ -256,12 +250,12 @@ static __global__ void update_prefix_and_k(const uint32_t* bucket_cnt, } } - //if ((cumsum_bucket_cnt[NR_BUCKET] < kv) | - // (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) { - // // impossible - // int* bad = 0x0; - // *bad = 23; - //} + if ((cumsum_bucket_cnt[NR_BUCKET] < kv) | + (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) { + // impossible + int* bad = 0x0; + *bad = 23; + } } static uint32_t get_grid_dim_x(uint32_t length) {