From 83cf4ee64e8295d5f5889e87a905bb98e1b6ed6a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 25 Aug 2021 12:33:12 +0800 Subject: [PATCH] refactor(dnn/rocm): remove some useless includes GitOrigin-RevId: 3d2c315a368f7307a88ba37f0674a36072281578 --- dnn/src/rocm/argsort/argsort.cpp.hip | 2 +- dnn/src/rocm/argsort/argsort.h.hip | 1 + dnn/src/rocm/argsort/backward.cpp.hip | 2 -- dnn/src/rocm/argsort/bitonic_sort.cpp.hip | 30 ++++++++++------------- dnn/src/rocm/topk/topk_radix.cpp.hip | 18 +++++--------- 5 files changed, 21 insertions(+), 32 deletions(-) diff --git a/dnn/src/rocm/argsort/argsort.cpp.hip b/dnn/src/rocm/argsort/argsort.cpp.hip index 6ca67a1b0..65e21bb4d 100644 --- a/dnn/src/rocm/argsort/argsort.cpp.hip +++ b/dnn/src/rocm/argsort/argsort.cpp.hip @@ -174,7 +174,7 @@ template void argsort::forward(const dtype*, dtype*, int*, void*, \ ARGSORT_FOREACH_CTYPE(INST_FORWARD) INST_CUB_SORT(uint32_t) -// INST_CUB_SORT(uint64_t) +INST_CUB_SORT(uint64_t) #undef INST_CUB_SORT #undef INST_FORWARD } diff --git a/dnn/src/rocm/argsort/argsort.h.hip b/dnn/src/rocm/argsort/argsort.h.hip index f9ca27cca..9b3e07c17 100644 --- a/dnn/src/rocm/argsort/argsort.h.hip +++ b/dnn/src/rocm/argsort/argsort.h.hip @@ -40,6 +40,7 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace, const int* iptr_src = NULL); //! 
iterate over all supported data types +// device_radix_sort does not support dt_float16 dtype(half_float::half in rocm) #define ARGSORT_FOREACH_CTYPE(cb) \ cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16)) diff --git a/dnn/src/rocm/argsort/backward.cpp.hip b/dnn/src/rocm/argsort/backward.cpp.hip index d7befd584..7428a5ffe 100644 --- a/dnn/src/rocm/argsort/backward.cpp.hip +++ b/dnn/src/rocm/argsort/backward.cpp.hip @@ -14,8 +14,6 @@ #include "./argsort.h.hip" #include "./backward.h.hip" -// #include "src/rocm/utils.h" - using namespace megdnn; using namespace rocm; using namespace argsort; diff --git a/dnn/src/rocm/argsort/bitonic_sort.cpp.hip b/dnn/src/rocm/argsort/bitonic_sort.cpp.hip index 3f93d44ed..095f4b712 100644 --- a/dnn/src/rocm/argsort/bitonic_sort.cpp.hip +++ b/dnn/src/rocm/argsort/bitonic_sort.cpp.hip @@ -11,13 +11,9 @@ #include "hcc_detail/hcc_defs_prologue.h" #include "./bitonic_sort.h.hip" -// #include "src/cuda/query_blocksize.cuh" -// #include "megdnn/dtype.h" +#include "megdnn/dtype.h" -// #if __CUDACC_VER_MAJOR__ < 9 -// #pragma message "warp sync disabled due to insufficient cuda version" #define __syncwarp __syncthreads -// #endif #include #include @@ -84,17 +80,17 @@ struct NumTrait<int32_t> { static __device__ __forceinline__ int32_t min() { return INT_MIN; } }; -// #if !MEGDNN_DISABLE_FLOAT16 -// template <> -// struct NumTrait<dt_float16> { -// static __device__ __forceinline__ dt_float16 max() { -// return std::numeric_limits<dt_float16>::max(); -// } -// static __device__ __forceinline__ dt_float16 min() { -// return std::numeric_limits<dt_float16>::lowest(); -// } -// }; -// #endif +#if !MEGDNN_DISABLE_FLOAT16 +template <> +struct NumTrait<dt_float16> { + static __device__ __forceinline__ dt_float16 max() { + return std::numeric_limits<dt_float16>::max(); + } + static __device__ __forceinline__ dt_float16 min() { + return std::numeric_limits<dt_float16>::lowest(); + } +}; +#endif struct LessThan { template @@ -310,7 +306,7 @@ namespace rocm { INST(float, int); INST(int32_t, int); -// 
DNN_INC_FLOAT16(INST(dt_float16, int)); +DNN_INC_FLOAT16(INST(dt_float16, int)); #undef INST } // namespace megdnn diff --git a/dnn/src/rocm/topk/topk_radix.cpp.hip b/dnn/src/rocm/topk/topk_radix.cpp.hip index 2cae0bfb8..a26b3351d 100644 --- a/dnn/src/rocm/topk/topk_radix.cpp.hip +++ b/dnn/src/rocm/topk/topk_radix.cpp.hip @@ -18,13 +18,7 @@ #include #include -#if __CUDACC_VER_MAJOR__ < 9 -#pragma message "topk is a little slower on cuda earlier than 9.0" -// on cuda 9.0 and later, due to thread-divergent branches we should use -// __syncwarp; and I am too lazy to implement a correct legacy version, so just -// use __syncthreads instead for older cuda #define __syncwarp __syncthreads -#endif using namespace megdnn; using namespace rocm; @@ -256,12 +250,12 @@ static __global__ void update_prefix_and_k(const uint32_t* bucket_cnt, } } - //if ((cumsum_bucket_cnt[NR_BUCKET] < kv) | - // (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) { - // // impossible - // int* bad = 0x0; - // *bad = 23; - //} + if ((cumsum_bucket_cnt[NR_BUCKET] < kv) | + (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) { + // impossible + int* bad = 0x0; + *bad = 23; + } } static uint32_t get_grid_dim_x(uint32_t length) { -- GitLab