diff --git a/dnn/src/cuda/argsort/argsort.cu b/dnn/src/cuda/argsort/argsort.cu
index 631ad160c09b4ed35c1f4bab72894c6e2ef1b5b9..a8781cf523115569a9d5f5ded34083b27b6f23d5 100644
--- a/dnn/src/cuda/argsort/argsort.cu
+++ b/dnn/src/cuda/argsort/argsort.cu
@@ -124,7 +124,7 @@ size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
         ARGSORT_FOREACH_CTYPE(cb)
 #undef cb
         default:
-            megdnn_throw("argsort only supports float and int32");
+            megdnn_throw("argsort only supports float, int32 and float16");
     }
     if (!iptr_src_given) {
         size = DIVUP(size, sizeof(float)) * sizeof(float) + M * N * sizeof(int);
diff --git a/dnn/src/cuda/argsort/argsort.cuh b/dnn/src/cuda/argsort/argsort.cuh
index 8814b88ee7cacb5d89c08c5e05d1f1a903c64e0a..d6301d65ace7db8361bad677ed4724095572cd48 100644
--- a/dnn/src/cuda/argsort/argsort.cuh
+++ b/dnn/src/cuda/argsort/argsort.cuh
@@ -33,7 +33,8 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace,
              const int* iptr_src = NULL);
 
 //! iterate over all supported data types
-#define ARGSORT_FOREACH_CTYPE(cb) cb(float) cb(int32_t)
+#define ARGSORT_FOREACH_CTYPE(cb) \
+    cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))
 
 } // namespace argsort
 } // namespace cuda
diff --git a/dnn/src/cuda/argsort/bitonic_sort.cu b/dnn/src/cuda/argsort/bitonic_sort.cu
index 598ad02d4842f6c41baf49c11b095af04abc0b9c..386e4a5779ea73ef5d3781e5de9464a49c8a5c53 100644
--- a/dnn/src/cuda/argsort/bitonic_sort.cu
+++ b/dnn/src/cuda/argsort/bitonic_sort.cu
@@ -11,6 +11,7 @@
 
 #include "./bitonic_sort.cuh"
 #include "src/cuda/query_blocksize.cuh"
+#include "megdnn/dtype.h"
 
 #if __CUDACC_VER_MAJOR__ < 9
 #pragma message "warp sync disabled due to insufficient cuda version"
@@ -82,6 +83,18 @@ struct NumTrait<int32_t> {
     static __device__ __forceinline__ int32_t min() { return INT_MIN; }
 };
 
+#if !MEGDNN_DISABLE_FLOAT16
+template <>
+struct NumTrait<dt_float16> {
+    static __device__ __forceinline__ dt_float16 max() {
+        return std::numeric_limits<dt_float16>::max();
+    }
+    static __device__ __forceinline__ dt_float16 min() {
+        return std::numeric_limits<dt_float16>::lowest();
+    }
+};
+#endif
+
 struct LessThan {
     template <typename Key, typename Value>
     static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1,
@@ -295,6 +308,7 @@ namespace cuda {
 
 INST(float, int);
 INST(int32_t, int);
+DNN_INC_FLOAT16(INST(dt_float16, int));
 #undef INST
 
 } // namespace megdnn
diff --git a/dnn/src/cuda/cub/util_type.cuh b/dnn/src/cuda/cub/util_type.cuh
index 0ba41e1ed26e56c11f373fd235fc9dee88fd213c..154f03278617b97a5e527e1544a65008305db662 100644
--- a/dnn/src/cuda/cub/util_type.cuh
+++ b/dnn/src/cuda/cub/util_type.cuh
@@ -1146,7 +1146,7 @@
 template <> struct NumericTraits<double> :             BaseTraits<FLOATING_POINT, true, false, unsigned long long, double> {};
 #if (__CUDACC_VER_MAJOR__ >= 9)
     template <> struct NumericTraits<__half> :         BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
 #endif
-
+template <> struct NumericTraits<half_float::half> :   BaseTraits<FLOATING_POINT, true, false, unsigned short, half_float::half> {};
 
 template <> struct NumericTraits<bool> :               BaseTraits<UNSIGNED_INTEGER, true, false, typename UnitWord<bool>::VolatileWord, bool> {};
diff --git a/dnn/src/cuda/topk/opr_impl.cpp b/dnn/src/cuda/topk/opr_impl.cpp
index 7f9bd2cd8c360803523e466ea0a1ebb9d3764aca..43e5a86ddd3da366364e4ef64a3f55918125481e 100644
--- a/dnn/src/cuda/topk/opr_impl.cpp
+++ b/dnn/src/cuda/topk/opr_impl.cpp
@@ -81,9 +81,18 @@ void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
                     values.ptr<int32_t>(), indices,
                     workspace.raw_ptr);
             return;
+#if !MEGDNN_DISABLE_FLOAT16
+        case DTypeEnum::Float16:
+            dispatch_with_ctype<dt_float16>(k, data.layout[0], data.layout[1],
+                    data.layout.stride[0], data.ptr<dt_float16>(),
+                    values.ptr<dt_float16>(), indices,
+                    workspace.raw_ptr);
+            return;
+#endif
         default:
             megdnn_throw(
-                    ssprintf("only float32 and int32 supported for cuda topk, got: %s",
+                    ssprintf("only float32, int32 and float16 supported for "
+                             "cuda topk, got: %s",
                              data.layout.dtype.name()));
     }
 }
diff --git a/dnn/src/cuda/topk/topk_radix.cu b/dnn/src/cuda/topk/topk_radix.cu
index 700f0ea09280db7615d4f707328ae4098fccfce3..585deeee8ac74534f031668e0cb9798e47720a05 100644
--- a/dnn/src/cuda/topk/topk_radix.cu
+++ b/dnn/src/cuda/topk/topk_radix.cu
@@ -489,7 +489,7 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
     if (k < 0) {
        k = length + k + 1;
     }
-    if (!(BUCKET_BITS == 8 && sizeof(ctype) == 4)) {
+    if (!(BUCKET_BITS == 8 && (sizeof(ctype) == 4 || sizeof(ctype) == 2))) {
         // no c++11 in megdnn cuda; so we just trap instead of using static
         // assert
         megdnn_trap();
@@ -668,6 +668,7 @@
                     int32_t, uint32_t, cudaStream_t)
 INST(float);
 INST(int32_t);
+DNN_INC_FLOAT16(INST(dt_float16));
 #undef INST
 
 } // namespace topk
diff --git a/dnn/src/cuda/topk/topk_radix.cuh b/dnn/src/cuda/topk/topk_radix.cuh
index 94c24d3516b359295d87eb844eb72c289189e5d7..375a96265b1add0a4760df53dcef37bd4a35a2b5 100644
--- a/dnn/src/cuda/topk/topk_radix.cuh
+++ b/dnn/src/cuda/topk/topk_radix.cuh
@@ -10,7 +10,7 @@
  */
 
 #pragma once
-
+#include "megdnn/dtype.h"
 #include <cuda_runtime_api.h>
 #include <stdint.h>
 
@@ -60,6 +60,29 @@ struct RadixConverter<float> {
     }
 };
 
+#if !MEGDNN_DISABLE_FLOAT16
+template <>
+struct RadixConverter<dt_float16> {
+    union FIunion {
+        FIunion() {}
+        dt_float16 fv;
+        uint16_t iv;
+    };
+    static __forceinline__ __device__ __host__ uint16_t to_radix(dt_float16 val) {
+        FIunion fi;
+        fi.fv = val;
+        return fi.iv ^ (((!(fi.iv >> 15u)) - 1u) | 0x8000u);
+    }
+    static __forceinline__ __device__ __host__ dt_float16 from_radix(uint16_t val) {
+        FIunion fi;
+        // do not write as to_radix() to work around a compiler bug in cuda-9.0
+        uint16_t m = 0x8000u;
+        fi.iv = val ^ (m | (m - !(val >> 15u)));
+        return fi.fv;
+    }
+};
+#endif
+
 } // namespace internal
 
 /*!
diff --git a/dnn/test/cuda/topk.cpp b/dnn/test/cuda/topk.cpp
index 33036442b0c4f6a03076d3bd368bf3ac6bdad5d0..70249df0534bd0de4f16a617bdfe53ba30896b2a 100644
--- a/dnn/test/cuda/topk.cpp
+++ b/dnn/test/cuda/topk.cpp
@@ -27,6 +27,10 @@ TEST_F(CUDA, TOP_K) {
 TEST_F(CUDA, TOP_K_I32) {
     run_topk_test<dtype::Int32>(handle_cuda());
 }
-
+#if !MEGDNN_DISABLE_FLOAT16
+TEST_F(CUDA, TOP_K_F16) {
+    run_topk_test<dtype::Float16>(handle_cuda());
+}
+#endif
 
 // vim: syntax=cpp.doxygen