Commit b87af9f7 authored by Megvii Engine Team, committed by huangxinda

feat(dnn/cuda): topk support fp16

GitOrigin-RevId: c6610d4cf04f258fd13f420e5301c273e399d67f
Parent 34262d90
......@@ -124,7 +124,7 @@ size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
ARGSORT_FOREACH_CTYPE(cb)
#undef cb
default:
megdnn_throw("argsort only supports float and int32");
megdnn_throw("argsort only supports float, int32 and float16");
}
if (!iptr_src_given) {
size = DIVUP(size, sizeof(float)) * sizeof(float) + M * N * sizeof(int);
......
......@@ -33,7 +33,8 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace,
const int* iptr_src = NULL);
//! iterate over all supported data types
-#define ARGSORT_FOREACH_CTYPE(cb) cb(float) cb(int32_t)
+#define ARGSORT_FOREACH_CTYPE(cb) \
+    cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))
} // namespace argsort
} // namespace cuda
......
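The ARGSORT_FOREACH_CTYPE X-macro keeps the list of supported element types in one place: every consumer expands it with its own `cb`, and `DNN_INC_FLOAT16` presumably drops its argument when fp16 support is compiled out. Below is a minimal, self-contained sketch of the same pattern; the `DEMO_*` names are illustrative only (not the MegDNN macros), and `double` stands in for `dt_float16`.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Stand-in for DNN_INC_FLOAT16: keep the entry only when fp16 is enabled.
#ifndef DEMO_DISABLE_FLOAT16
#define DEMO_INC_FLOAT16(x) x
#else
#define DEMO_INC_FLOAT16(x)
#endif

// One list of supported element types; every consumer expands it with its own cb.
#define DEMO_FOREACH_CTYPE(cb) cb(float) cb(int32_t) DEMO_INC_FLOAT16(cb(double))

// Example consumer: emit one function per listed type.
#define DEMO_DEFINE_SIZEOF(T) \
    std::size_t size_of_##T() { return sizeof(T); }
DEMO_FOREACH_CTYPE(DEMO_DEFINE_SIZEOF)
#undef DEMO_DEFINE_SIZEOF

int main() {
    std::printf("%zu %zu %zu\n", size_of_float(), size_of_int32_t(), size_of_double());
    return 0;
}
```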
......@@ -11,6 +11,7 @@
#include "./bitonic_sort.cuh"
#include "src/cuda/query_blocksize.cuh"
#include "megdnn/dtype.h"
#if __CUDACC_VER_MAJOR__ < 9
#pragma message "warp sync disabled due to insufficient cuda version"
......@@ -82,6 +83,18 @@ struct NumTrait<int32_t> {
static __device__ __forceinline__ int32_t min() { return INT_MIN; }
};
+#if !MEGDNN_DISABLE_FLOAT16
+template <>
+struct NumTrait<dt_float16> {
+static __device__ __forceinline__ dt_float16 max() {
+return std::numeric_limits<dt_float16>::max();
+}
+static __device__ __forceinline__ dt_float16 min() {
+return std::numeric_limits<dt_float16>::lowest();
+}
+};
+#endif
struct LessThan {
template <typename Key, typename Value>
static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1,
......@@ -295,6 +308,7 @@ namespace cuda {
INST(float, int);
INST(int32_t, int);
+DNN_INC_FLOAT16(INST(dt_float16, int));
#undef INST
} // namespace megdnn
......
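NumTrait<T> supplies the extreme key values the sort kernels can use as neutral padding. A plausible reading (an assumption here, not something shown in this diff) is that rows shorter than the sort width get filled with NumTrait<T>::max() so the filler sorts behind every real element in an ascending pass. A small host-side sketch of that idea, with `float` standing in for `dt_float16` and `std::sort` standing in for the device-side bitonic sort:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

// Simplified host-side stand-in for the device-side NumTrait specializations.
template <typename T>
struct NumTraitDemo {
    static T max() { return std::numeric_limits<T>::max(); }
    static T min() { return std::numeric_limits<T>::lowest(); }
};

// Pad a row to the next power of two with the "largest" key so that, after an
// ascending sort, every real element ends up in front of the padding.
template <typename T>
std::vector<T> pad_and_sort(std::vector<T> row) {
    std::size_t n = 1;
    while (n < row.size()) n *= 2;
    row.resize(n, NumTraitDemo<T>::max());
    std::sort(row.begin(), row.end());  // the bitonic sort plays this role on the GPU
    return row;
}

int main() {
    auto r = pad_and_sort<float>({3.f, -1.f, 2.f});
    for (float v : r) std::printf("%g ", v);  // -1 2 3 3.40282e+38
    std::printf("\n");
    return 0;
}
```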
......@@ -1146,7 +1146,7 @@ template <> struct NumericTraits<double> : BaseTraits<FLOATING_POIN
#if (__CUDACC_VER_MAJOR__ >= 9)
template <> struct NumericTraits<__half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
#endif
+template <> struct NumericTraits<half_float::half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, half_float::half> {};
template <> struct NumericTraits<bool> : BaseTraits<UNSIGNED_INTEGER, true, false, typename UnitWord<bool>::VolatileWord, bool> {};
......
......@@ -81,9 +81,18 @@ void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
values.ptr<int32_t>(), indices,
workspace.raw_ptr);
return;
+#if !MEGDNN_DISABLE_FLOAT16
+case DTypeEnum::Float16:
+dispatch_with_ctype<dt_float16>(k, data.layout[0], data.layout[1],
+data.layout.stride[0], data.ptr<dt_float16>(),
+values.ptr<dt_float16>(), indices,
+workspace.raw_ptr);
+return;
+#endif
default:
megdnn_throw(
ssprintf("only float32 and int32 supported for cuda topk, got: %s",
ssprintf("only float32, int32 and float16 supported for "
"cuda topk, got: %s",
data.layout.dtype.name()));
}
}
......
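For orientation only, here is a host-side reference of the sorted value/index top-k that the dispatcher above hands to the per-dtype kernels. This is not the MegDNN operator API: the tensor and workspace plumbing is omitted, the sign convention for `k` is not asserted, and `float` stands in for `dt_float16`.

```cpp
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

// Smallest-k reference for one row: returns the k smallest values together with
// their source indices, mirroring the (values, indices) output pair of the CUDA
// kernels. A "largest-k" variant would only flip the comparator.
void topk_row(const std::vector<float>& row, int k,
              std::vector<float>& values, std::vector<int>& indices) {
    std::vector<int> order(row.size());
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + k, order.end(),
                      [&](int a, int b) { return row[a] < row[b]; });
    indices.assign(order.begin(), order.begin() + k);
    values.clear();
    for (int i : indices) values.push_back(row[i]);
}

int main() {
    std::vector<float> vals;
    std::vector<int> idx;
    topk_row({0.5f, -2.0f, 3.0f, 1.0f}, 2, vals, idx);
    std::printf("%g@%d %g@%d\n", vals[0], idx[0], vals[1], idx[1]);  // -2@1 0.5@0
    return 0;
}
```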
......@@ -489,7 +489,7 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
if (k < 0) {
k = length + k + 1;
}
if (!(BUCKET_BITS == 8 && sizeof(ctype) == 4)) {
if (!(BUCKET_BITS == 8 && (sizeof(ctype) == 4 || sizeof(ctype) == 2))) {
// no c++11 in megdnn cuda; so we just trap instead of using static
// assert
megdnn_trap();
......@@ -668,6 +668,7 @@ namespace topk {
int32_t, uint32_t, cudaStream_t)
INST(float);
INST(int32_t);
+DNN_INC_FLOAT16(INST(dt_float16));
#undef INST
} // namespace topk
......
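The relaxed size check in find_kth_radix above admits 2-byte keys. With 8-bit buckets, a radix select presumably just consumes the key one digit per pass, so a float16 key needs two bucket passes where a float32 key needs four. A trivial illustration of that arithmetic (demo code, not the kernel):

```cpp
#include <cstdio>

// Pass count for a radix select over a key of `key_bytes` bytes, assuming the
// key width divides evenly into BUCKET_BITS-wide digits.
constexpr int BUCKET_BITS = 8;
constexpr int radix_passes(int key_bytes) { return key_bytes * 8 / BUCKET_BITS; }

int main() {
    std::printf("float32/int32: %d passes, float16: %d passes\n",
                radix_passes(4), radix_passes(2));  // 4 and 2
    return 0;
}
```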
......@@ -10,7 +10,7 @@
*/
#pragma once
#include "megdnn/dtype.h"
#include <cuda_runtime.h>
#include <stdint.h>
......@@ -60,6 +60,29 @@ struct RadixConverter<int32_t> {
}
};
+#if !MEGDNN_DISABLE_FLOAT16
+template <>
+struct RadixConverter<dt_float16> {
+union FIunion {
+FIunion() {}
+dt_float16 fv;
+uint16_t iv;
+};
+static __forceinline__ __device__ __host__ uint16_t to_radix(dt_float16 val) {
+FIunion fi;
+fi.fv = val;
+return fi.iv ^ (((!(fi.iv >> 15u)) - 1u) | 0x8000u);
+}
+static __forceinline__ __device__ __host__ dt_float16 from_radix(uint16_t val) {
+FIunion fi;
+// do not write as to_radix() to work around a compiler bug in cuda-9.0
+uint16_t m = 0x8000u;
+fi.iv = val ^ (m | (m - !(val >> 15u)));
+return fi.fv;
+}
+};
+#endif
} // namespace internal
/*!
......
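The fp16 converter uses the classic order-preserving mapping for IEEE floats: non-negative values get only the sign bit flipped, negative values get every bit flipped, so unsigned comparison of the radix agrees with floating-point ordering (with the usual caveats around NaN). Below is a host-only sketch of the same transform applied to raw 16-bit patterns, with no dt_float16 dependency, to check monotonicity and round-tripping; the sample bit patterns are hand-encoded fp16 values and the function names are illustrative.

```cpp
#include <cstdint>
#include <cstdio>

// Same bit trick as RadixConverter<dt_float16>, on raw fp16 bit patterns:
// sign bit 0 -> flip only the sign bit; sign bit 1 -> flip every bit.
static uint16_t to_radix_bits(uint16_t bits) {
    return static_cast<uint16_t>(bits ^ (((!(bits >> 15u)) - 1u) | 0x8000u));
}

static uint16_t from_radix_bits(uint16_t val) {
    uint16_t m = 0x8000u;
    return static_cast<uint16_t>(val ^ (m | (m - !(val >> 15u))));
}

int main() {
    // Hand-encoded fp16 patterns for -2.0, -0.5, 0.0, 0.5, 2.0 (increasing order).
    const uint16_t samples[] = {0xC000u, 0xB800u, 0x0000u, 0x3800u, 0x4000u};
    uint16_t prev = 0;
    for (int i = 0; i < 5; ++i) {
        uint16_t r = to_radix_bits(samples[i]);
        bool roundtrip_ok = from_radix_bits(r) == samples[i];
        bool order_ok = (i == 0) || r > prev;  // radix order must match float order
        std::printf("%04x -> %04x  roundtrip=%d order=%d\n",
                    (unsigned)samples[i], (unsigned)r, (int)roundtrip_ok, (int)order_ok);
        prev = r;
    }
    return 0;
}
```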
......@@ -27,6 +27,10 @@ TEST_F(CUDA, TOP_K) {
TEST_F(CUDA, TOP_K_I32) {
run_topk_test<dtype::Int32>(handle_cuda());
}
+#if !MEGDNN_DISABLE_FLOAT16
+TEST_F(CUDA, TOP_K_F16) {
+run_topk_test<dtype::Float16>(handle_cuda());
+}
+#endif
// vim: syntax=cpp.doxygen