From b87af9f77f3fa7a169506f7125682371d9e63591 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 28 May 2021 17:34:25 +0800 Subject: [PATCH] feat(dnn/cuda): topk support fp16 GitOrigin-RevId: c6610d4cf04f258fd13f420e5301c273e399d67f --- dnn/src/cuda/argsort/argsort.cu | 2 +- dnn/src/cuda/argsort/argsort.cuh | 3 ++- dnn/src/cuda/argsort/bitonic_sort.cu | 14 ++++++++++++++ dnn/src/cuda/cub/util_type.cuh | 2 +- dnn/src/cuda/topk/opr_impl.cpp | 11 ++++++++++- dnn/src/cuda/topk/topk_radix.cu | 3 ++- dnn/src/cuda/topk/topk_radix.cuh | 25 ++++++++++++++++++++++++- dnn/test/cuda/topk.cpp | 6 +++++- 8 files changed, 59 insertions(+), 7 deletions(-) diff --git a/dnn/src/cuda/argsort/argsort.cu b/dnn/src/cuda/argsort/argsort.cu index 631ad160c..a8781cf52 100644 --- a/dnn/src/cuda/argsort/argsort.cu +++ b/dnn/src/cuda/argsort/argsort.cu @@ -124,7 +124,7 @@ size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, ARGSORT_FOREACH_CTYPE(cb) #undef cb default: - megdnn_throw("argsort only supports float and int32"); + megdnn_throw("argsort only supports float, int32 and float16"); } if (!iptr_src_given) { size = DIVUP(size, sizeof(float)) * sizeof(float) + M * N * sizeof(int); diff --git a/dnn/src/cuda/argsort/argsort.cuh b/dnn/src/cuda/argsort/argsort.cuh index 8814b88ee..d6301d65a 100644 --- a/dnn/src/cuda/argsort/argsort.cuh +++ b/dnn/src/cuda/argsort/argsort.cuh @@ -33,7 +33,8 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace, const int* iptr_src = NULL); //! iterate over all supported data types -#define ARGSORT_FOREACH_CTYPE(cb) cb(float) cb(int32_t) +#define ARGSORT_FOREACH_CTYPE(cb) \ + cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16)) } // namespace argsort } // namespace cuda diff --git a/dnn/src/cuda/argsort/bitonic_sort.cu b/dnn/src/cuda/argsort/bitonic_sort.cu index 598ad02d4..386e4a577 100644 --- a/dnn/src/cuda/argsort/bitonic_sort.cu +++ b/dnn/src/cuda/argsort/bitonic_sort.cu @@ -11,6 +11,7 @@ #include "./bitonic_sort.cuh" #include "src/cuda/query_blocksize.cuh" +#include "megdnn/dtype.h" #if __CUDACC_VER_MAJOR__ < 9 #pragma message "warp sync disabled due to insufficient cuda version" @@ -82,6 +83,18 @@ struct NumTrait { static __device__ __forceinline__ int32_t min() { return INT_MIN; } }; +#if !MEGDNN_DISABLE_FLOAT16 +template <> +struct NumTrait { + static __device__ __forceinline__ dt_float16 max() { + return std::numeric_limits::max(); + } + static __device__ __forceinline__ dt_float16 min() { + return std::numeric_limits::lowest(); + } +}; +#endif + struct LessThan { template static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1, @@ -295,6 +308,7 @@ namespace cuda { INST(float, int); INST(int32_t, int); +DNN_INC_FLOAT16(INST(dt_float16, int)); #undef INST } // namespace megdnn diff --git a/dnn/src/cuda/cub/util_type.cuh b/dnn/src/cuda/cub/util_type.cuh index 0ba41e1ed..154f03278 100644 --- a/dnn/src/cuda/cub/util_type.cuh +++ b/dnn/src/cuda/cub/util_type.cuh @@ -1146,7 +1146,7 @@ template <> struct NumericTraits : BaseTraits= 9) template <> struct NumericTraits<__half> : BaseTraits {}; #endif - +template <> struct NumericTraits : BaseTraits {}; template <> struct NumericTraits : BaseTraits::VolatileWord, bool> {}; diff --git a/dnn/src/cuda/topk/opr_impl.cpp b/dnn/src/cuda/topk/opr_impl.cpp index 7f9bd2cd8..43e5a86dd 100644 --- a/dnn/src/cuda/topk/opr_impl.cpp +++ b/dnn/src/cuda/topk/opr_impl.cpp @@ -81,9 +81,18 @@ void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values, values.ptr(), indices, workspace.raw_ptr); return; +#if !MEGDNN_DISABLE_FLOAT16 + case DTypeEnum::Float16: + dispatch_with_ctype(k, data.layout[0], data.layout[1], + data.layout.stride[0], data.ptr(), + values.ptr(), indices, + workspace.raw_ptr); + return; +#endif default: megdnn_throw( - ssprintf("only float32 and int32 supported for cuda topk, got: %s", + ssprintf("only float32, int32 and float16 supported for " + "cuda topk, got: %s", data.layout.dtype.name())); } } diff --git a/dnn/src/cuda/topk/topk_radix.cu b/dnn/src/cuda/topk/topk_radix.cu index 700f0ea09..585deeee8 100644 --- a/dnn/src/cuda/topk/topk_radix.cu +++ b/dnn/src/cuda/topk/topk_radix.cu @@ -489,7 +489,7 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output, if (k < 0) { k = length + k + 1; } - if (!(BUCKET_BITS == 8 && sizeof(ctype) == 4)) { + if (!(BUCKET_BITS == 8 && (sizeof(ctype) == 4 || sizeof(ctype) == 2))) { // no c++11 in megdnn cuda; so we just trap instead of using static // assert megdnn_trap(); @@ -668,6 +668,7 @@ namespace topk { int32_t, uint32_t, cudaStream_t) INST(float); INST(int32_t); +DNN_INC_FLOAT16(INST(dt_float16)); #undef INST } // namespace topk diff --git a/dnn/src/cuda/topk/topk_radix.cuh b/dnn/src/cuda/topk/topk_radix.cuh index 94c24d351..375a96265 100644 --- a/dnn/src/cuda/topk/topk_radix.cuh +++ b/dnn/src/cuda/topk/topk_radix.cuh @@ -10,7 +10,7 @@ */ #pragma once - +#include "megdnn/dtype.h" #include #include @@ -60,6 +60,29 @@ struct RadixConverter { } }; +#if !MEGDNN_DISABLE_FLOAT16 +template <> +struct RadixConverter { + union FIunion { + FIunion() {} + dt_float16 fv; + uint16_t iv; + }; + static __forceinline__ __device__ __host__ uint16_t to_radix(dt_float16 val) { + FIunion fi; + fi.fv = val; + return fi.iv ^ (((!(fi.iv >> 15u)) - 1u) | 0x8000u); + } + static __forceinline__ __device__ __host__ dt_float16 from_radix(uint16_t val) { + FIunion fi; + // do not write as to_radix() to work around a compiler bug in cuda-9.0 + uint16_t m = 0x8000u; + fi.iv = val ^ (m | (m - !(val >> 15u))); + return fi.fv; + } +}; +#endif + } // namespace internal /*! diff --git a/dnn/test/cuda/topk.cpp b/dnn/test/cuda/topk.cpp index 33036442b..70249df05 100644 --- a/dnn/test/cuda/topk.cpp +++ b/dnn/test/cuda/topk.cpp @@ -27,6 +27,10 @@ TEST_F(CUDA, TOP_K) { TEST_F(CUDA, TOP_K_I32) { run_topk_test(handle_cuda()); } - +#if !MEGDNN_DISABLE_FLOAT16 +TEST_F(CUDA, TOP_K_F16) { + run_topk_test(handle_cuda()); +} +#endif // vim: syntax=cpp.doxygen -- GitLab