Commit b87af9f7 authored by Megvii Engine Team, committed by huangxinda

feat(dnn/cuda): topk support fp16

GitOrigin-RevId: c6610d4cf04f258fd13f420e5301c273e399d67f
Parent 34262d90
@@ -124,7 +124,7 @@ size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
         ARGSORT_FOREACH_CTYPE(cb)
 #undef cb
         default:
-            megdnn_throw("argsort only supports float and int32");
+            megdnn_throw("argsort only supports float, int32 and float16");
     }
     if (!iptr_src_given) {
         size = DIVUP(size, sizeof(float)) * sizeof(float) + M * N * sizeof(int);
......
@@ -33,7 +33,8 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace,
              const int* iptr_src = NULL);

 //! iterate over all supported data types
-#define ARGSORT_FOREACH_CTYPE(cb) cb(float) cb(int32_t)
+#define ARGSORT_FOREACH_CTYPE(cb) \
+    cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))

 } // namespace argsort
 } // namespace cuda
......
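Side note: ARGSORT_FOREACH_CTYPE is an X-macro, so extending its list with DNN_INC_FLOAT16(cb(dt_float16)) is enough to stamp out dt_float16 handling at every call site, such as the workspace-size switch in the first hunk. Below is a minimal, self-contained sketch of that pattern; FOREACH_CTYPE, TagOf and workspace_bytes are illustrative stand-ins rather than MegDNN's actual API, and uint16_t stands in for dt_float16:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative stand-in for ARGSORT_FOREACH_CTYPE: one cb() entry per
// supported element type.
#define FOREACH_CTYPE(cb) cb(float) cb(int32_t) cb(uint16_t)

enum class Tag { F32, I32, F16 };
template <typename T> struct TagOf;
template <> struct TagOf<float>    { static constexpr Tag value = Tag::F32; };
template <> struct TagOf<int32_t>  { static constexpr Tag value = Tag::I32; };
template <> struct TagOf<uint16_t> { static constexpr Tag value = Tag::F16; };

// Each call site defines cb() locally, expands the list, then #undefs it,
// so adding one entry to the list updates every such switch at once.
size_t workspace_bytes(Tag tag, size_t m, size_t n) {
    switch (tag) {
#define cb(ctype) \
    case TagOf<ctype>::value: return m * n * sizeof(ctype);
        FOREACH_CTYPE(cb)
#undef cb
    }
    return 0;  // unreachable for supported tags
}

int main() {
    std::printf("%zu\n", workspace_bytes(Tag::F16, 4, 8));  // prints 64
}
```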
@@ -11,6 +11,7 @@
 #include "./bitonic_sort.cuh"
 #include "src/cuda/query_blocksize.cuh"
+#include "megdnn/dtype.h"

 #if __CUDACC_VER_MAJOR__ < 9
 #pragma message "warp sync disabled due to insufficient cuda version"
@@ -82,6 +83,18 @@ struct NumTrait<int32_t> {
     static __device__ __forceinline__ int32_t min() { return INT_MIN; }
 };

+#if !MEGDNN_DISABLE_FLOAT16
+template <>
+struct NumTrait<dt_float16> {
+    static __device__ __forceinline__ dt_float16 max() {
+        return std::numeric_limits<dt_float16>::max();
+    }
+    static __device__ __forceinline__ dt_float16 min() {
+        return std::numeric_limits<dt_float16>::lowest();
+    }
+};
+#endif
+
 struct LessThan {
     template <typename Key, typename Value>
     static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1,
@@ -295,6 +308,7 @@ namespace cuda {
 INST(float, int);
 INST(int32_t, int);
+DNN_INC_FLOAT16(INST(dt_float16, int));
 #undef INST

 } // namespace megdnn
......
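Side note: the new trait relies on dt_float16 (half_float::half) providing a std::numeric_limits specialization for max()/lowest(). NumTrait's max()/min() presumably supply the sentinel values used to pad a row up to a power-of-two length before bitonic sorting; that role is an assumption, not stated in the diff. A host-side sketch of the idea, with float standing in for dt_float16 and std::sort for the GPU kernel:

```cpp
#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

// Host-side sketch only: NumTraitSketch mirrors the NumTrait interface from
// the diff; the pad-to-power-of-two role is an assumption about how the
// bitonic sort kernel consumes it.
template <typename T>
struct NumTraitSketch {
    static T max() { return std::numeric_limits<T>::max(); }
    static T min() { return std::numeric_limits<T>::lowest(); }
};

int main() {
    std::vector<float> row = {3.f, 1.f, 2.f};
    row.resize(4, NumTraitSketch<float>::max());  // pad 3 -> 4 elements
    std::sort(row.begin(), row.end());            // padding sinks to the tail
    for (float v : row) std::printf("%g ", v);    // 1 2 3 3.40282e+38
    std::printf("\n");
}
```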
@@ -1146,7 +1146,7 @@ template <> struct NumericTraits<double> : BaseTraits<FLOATING_POIN
 #if (__CUDACC_VER_MAJOR__ >= 9)
 template <> struct NumericTraits<__half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
 #endif
+template <> struct NumericTraits<half_float::half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, half_float::half> {};

 template <> struct NumericTraits<bool> : BaseTraits<UNSIGNED_INTEGER, true, false, typename UnitWord<bool>::VolatileWord, bool> {};
......
@@ -81,9 +81,18 @@ void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
                                          values.ptr<int32_t>(), indices,
                                          workspace.raw_ptr);
             return;
+#if !MEGDNN_DISABLE_FLOAT16
+        case DTypeEnum::Float16:
+            dispatch_with_ctype<dt_float16>(k, data.layout[0], data.layout[1],
+                                            data.layout.stride[0], data.ptr<dt_float16>(),
+                                            values.ptr<dt_float16>(), indices,
+                                            workspace.raw_ptr);
+            return;
+#endif
         default:
             megdnn_throw(
-                    ssprintf("only float32 and int32 supported for cuda topk, got: %s",
+                    ssprintf("only float32, int32 and float16 supported for "
+                             "cuda topk, got: %s",
                              data.layout.dtype.name()));
     }
 }
......
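Side note: the new DTypeEnum::Float16 case reuses the existing runtime-enum-to-template dispatch; only the element type changes. A self-contained sketch of that dispatch shape follows; the names and signatures are illustrative, with uint16_t again standing in for dt_float16:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative only: mirrors the shape of TopKImpl::do_exec's switch, not
// MegDNN's real signatures. dispatch_with_ctype<ctype> is where the typed
// kernel launch would happen.
enum class DTypeEnum { Float32, Int32, Float16 };

template <typename ctype>
void dispatch_with_ctype(int k, size_t rows, size_t cols) {
    std::printf("top-%d over %zux%zu, %zu-byte elements\n",
                k, rows, cols, sizeof(ctype));
}

void do_exec(DTypeEnum dtype, int k, size_t rows, size_t cols) {
    switch (dtype) {
        case DTypeEnum::Float32: dispatch_with_ctype<float>(k, rows, cols);    return;
        case DTypeEnum::Int32:   dispatch_with_ctype<int32_t>(k, rows, cols);  return;
        case DTypeEnum::Float16: dispatch_with_ctype<uint16_t>(k, rows, cols); return;
    }
}

int main() { do_exec(DTypeEnum::Float16, 5, 8, 128); }
```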
@@ -489,7 +489,7 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
     if (k < 0) {
         k = length + k + 1;
     }
-    if (!(BUCKET_BITS == 8 && sizeof(ctype) == 4)) {
+    if (!(BUCKET_BITS == 8 && (sizeof(ctype) == 4 || sizeof(ctype) == 2))) {
         // no c++11 in megdnn cuda; so we just trap instead of using static
         // assert
         megdnn_trap();
@@ -668,6 +668,7 @@ namespace topk {
                  int32_t, uint32_t, cudaStream_t)
 INST(float);
 INST(int32_t);
+DNN_INC_FLOAT16(INST(dt_float16));
 #undef INST

 } // namespace topk
......
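Side note: the relaxed sizeof check admits 2-byte keys because, with BUCKET_BITS == 8, the radix select consumes the key one byte per pass, so fp16 keys need two passes where float/int32 keys need four. That pass-per-byte structure is an inference drawn from the check itself, sketched below:

```cpp
#include <cstdint>
#include <cstdio>

constexpr int BUCKET_BITS = 8;  // same bucket width the kernel check names

// Pass count if the key is scanned one BUCKET_BITS-wide digit at a time.
template <typename KeyType>
constexpr int num_passes() {
    return static_cast<int>(sizeof(KeyType)) * 8 / BUCKET_BITS;
}

int main() {
    std::printf("32-bit keys (float/int32): %d passes\n", num_passes<uint32_t>());  // 4
    std::printf("16-bit keys (fp16 radix):  %d passes\n", num_passes<uint16_t>());  // 2
}
```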
@@ -10,7 +10,7 @@
  */
 #pragma once
+#include "megdnn/dtype.h"

 #include <cuda_runtime.h>
 #include <stdint.h>
@@ -60,6 +60,29 @@ struct RadixConverter<int32_t> {
     }
 };

+#if !MEGDNN_DISABLE_FLOAT16
+template <>
+struct RadixConverter<dt_float16> {
+    union FIunion {
+        FIunion() {}
+        dt_float16 fv;
+        uint16_t iv;
+    };
+    static __forceinline__ __device__ __host__ uint16_t to_radix(dt_float16 val) {
+        FIunion fi;
+        fi.fv = val;
+        return fi.iv ^ (((!(fi.iv >> 15u)) - 1u) | 0x8000u);
+    }
+    static __forceinline__ __device__ __host__ dt_float16 from_radix(uint16_t val) {
+        FIunion fi;
+        // do not write as to_radix() to work around a compiler bug in cuda-9.0
+        uint16_t m = 0x8000u;
+        fi.iv = val ^ (m | (m - !(val >> 15u)));
+        return fi.fv;
+    }
+};
+#endif
+
 } // namespace internal

 /*!
......
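Side note: to_radix applies the standard order-preserving float-to-unsigned twiddle — set the sign bit for non-negative values, invert every bit for negative ones — so the radix passes can compare keys as plain unsigned integers; from_radix undoes the mapping. A host-side sketch of the same mapping, written for 32-bit float so it stays self-contained (identical logic to the 16-bit version above, just wider masks):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Same twiddle as RadixConverter, but for float/uint32_t: non-negative values
// get the sign bit set, negative values are fully inverted, so unsigned
// comparison of the resulting keys matches IEEE floating-point ordering.
static uint32_t to_radix32(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return u ^ (((!(u >> 31u)) - 1u) | 0x80000000u);
}

int main() {
    const float vals[] = {-3.5f, -0.0f, 0.0f, 1.25f, 7.0f};  // already sorted
    uint32_t prev = 0;
    for (float v : vals) {
        uint32_t r = to_radix32(v);
        std::printf("%8.2f -> 0x%08x%s\n", v, r,
                    r < prev ? "  (order violated!)" : "");
        prev = r;  // keys come out monotonically non-decreasing
    }
}
```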
@@ -27,6 +27,10 @@ TEST_F(CUDA, TOP_K) {
 TEST_F(CUDA, TOP_K_I32) {
     run_topk_test<dtype::Int32>(handle_cuda());
 }
+#if !MEGDNN_DISABLE_FLOAT16
+TEST_F(CUDA, TOP_K_F16) {
+    run_topk_test<dtype::Float16>(handle_cuda());
+}
+#endif

 // vim: syntax=cpp.doxygen