提交 7dc34769 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add typecvt uint16

GitOrigin-RevId: d1368c414e99e15d6fb93273b5051832d1995dea
上级 b92866d2
...@@ -148,6 +148,9 @@ INST_FOR_CTYPE ...@@ -148,6 +148,9 @@ INST_FOR_CTYPE
#define ct dt_int16 #define ct dt_int16
INST_FOR_CTYPE INST_FOR_CTYPE
#undef ct #undef ct
#define ct dt_uint16
INST_FOR_CTYPE
#undef ct
#define ct dt_quint8 #define ct dt_quint8
INST_FOR_CTYPE INST_FOR_CTYPE
#undef ct #undef ct
...@@ -201,6 +204,9 @@ INST_FOR_CTYPE ...@@ -201,6 +204,9 @@ INST_FOR_CTYPE
#define ct dt_int16 #define ct dt_int16
INST_FOR_CTYPE INST_FOR_CTYPE
#undef ct #undef ct
#define ct dt_uint16
INST_FOR_CTYPE
#undef ct
#define ct dt_quint8 #define ct dt_quint8
INST_FOR_CTYPE INST_FOR_CTYPE
#undef ct #undef ct
......
...@@ -92,6 +92,7 @@ INST(dt_float16, half4); ...@@ -92,6 +92,7 @@ INST(dt_float16, half4);
INST(dt_bfloat16, bhalf4); INST(dt_bfloat16, bhalf4);
INST(dt_int32, int4); INST(dt_int32, int4);
INST(dt_int16, short4); INST(dt_int16, short4);
INST(dt_uint16, ushort4);
INST(dt_bool, uchar4); INST(dt_bool, uchar4);
#undef as_raw #undef as_raw
#define as_raw(x) x.as_int8() #define as_raw(x) x.as_int8()
......
...@@ -247,6 +247,19 @@ struct TypeCvtOpFromQuantizedToQuantized4bit< ...@@ -247,6 +247,19 @@ struct TypeCvtOpFromQuantizedToQuantized4bit<
namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {
// currently only typecvt_kern_{n2q,n2q4} respect this. change others typecvt_kern_* if
// needed.
template <typename dtype_src, typename dtype_dest, typename sfinae = void>
struct enable_typecvt_kern {
static constexpr bool value = true;
};
#define MEGDNN_DISABLE_CUDA_TYPECVT_KERN(dtype_src, dtype_dest) \
template <> \
struct enable_typecvt_kern<dtype_src, dtype_dest, void> { \
static constexpr bool value = false; \
};
template <typename dtype_src, typename dtype_dest> template <typename dtype_src, typename dtype_dest>
void typecvt_kern_q2q( void typecvt_kern_q2q(
const TensorND& dest, const TensorND& src, const TensorND& dest, const TensorND& src,
...@@ -257,12 +270,28 @@ void typecvt_kern_q2q( ...@@ -257,12 +270,28 @@ void typecvt_kern_q2q(
} }
template <typename dtype_src, typename dtype_dest> template <typename dtype_src, typename dtype_dest>
void typecvt_kern_n2q( typename std::enable_if<enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
typecvt_kern_n2q_impl(
const TensorND& dest, const TensorND& src, const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) { const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
main_func(TypeCvtOpToQuantized, op.param = dst_param;); main_func(TypeCvtOpToQuantized, op.param = dst_param;);
} }
template <typename dtype_src, typename dtype_dest>
typename std::enable_if<!enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
typecvt_kern_n2q_impl(
const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
megdnn_throw("TypeCvt: CUDA kernel for this dtype pair is disabled");
}
template <typename dtype_src, typename dtype_dest>
void typecvt_kern_n2q(
const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
typecvt_kern_n2q_impl<dtype_src, dtype_dest>(dest, src, dst_param, stream);
}
template <typename dtype_src, typename dtype_dest> template <typename dtype_src, typename dtype_dest>
void typecvt_kern_q2n( void typecvt_kern_q2n(
const TensorND& dest, const TensorND& src, const TensorND& dest, const TensorND& src,
...@@ -312,12 +341,15 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st ...@@ -312,12 +341,15 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
cb(dtype_src, dt_qint8) \ cb(dtype_src, dt_qint8) \
cb(dtype_src, dt_qint1) \ cb(dtype_src, dt_qint1) \
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dt_uint16, MEGDNN_DISABLE_CUDA_TYPECVT_KERN)
#define INST_SRC_QUANTIZED(dtype_src) \ #define INST_SRC_QUANTIZED(dtype_src) \
MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \ MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2Q) \ MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2Q) \
#define INST_SRC_NORMAL(dtype_src) \ #define INST_SRC_NORMAL(dtype_src) \
MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2N) \ MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2N) \
INST_N2N(dtype_src, dt_uint16) \
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2Q) \ MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2Q) \
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \
...@@ -340,6 +372,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st ...@@ -340,6 +372,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED) MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED)
MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL) MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL)
INST_SRC_NORMAL(dt_uint16)
// clang-format on // clang-format on
template void typecvt_kern_n2q<dtype::Int8, dtype::QuantizedS8>( template void typecvt_kern_n2q<dtype::Int8, dtype::QuantizedS8>(
...@@ -377,12 +410,28 @@ void typecvt_kern_q2q4( ...@@ -377,12 +410,28 @@ void typecvt_kern_q2q4(
} }
template <typename dtype_src, typename dtype_dest> template <typename dtype_src, typename dtype_dest>
void typecvt_kern_n2q4( typename std::enable_if<enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
typecvt_kern_n2q4_impl(
const TensorND& dest, const TensorND& src, const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) { const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
main_func_to_q4(TypeCvtOpFromNormalToQuantized4bit, op.dst_param = dst_param;) main_func_to_q4(TypeCvtOpFromNormalToQuantized4bit, op.dst_param = dst_param;)
} }
template <typename dtype_src, typename dtype_dest>
typename std::enable_if<!enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
typecvt_kern_n2q4_impl(
const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
megdnn_throw("TypeCvt: CUDA kernel for this dtype pair is disabled");
}
template <typename dtype_src, typename dtype_dest>
void typecvt_kern_n2q4(
const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
typecvt_kern_n2q4_impl<dtype_src, dtype_dest>(dest, src, dst_param, stream);
}
#define INST_Q2Q4(dtype_src, dtype_dest) \ #define INST_Q2Q4(dtype_src, dtype_dest) \
template void typecvt_kern_q2q4<dtype_src, dtype_dest>( \ template void typecvt_kern_q2q4<dtype_src, dtype_dest>( \
const TensorND& dest, const TensorND& src, \ const TensorND& dest, const TensorND& src, \
...@@ -399,6 +448,8 @@ void typecvt_kern_n2q4( ...@@ -399,6 +448,8 @@ void typecvt_kern_n2q4(
cb(dtype_src, dt_qint4) \ cb(dtype_src, dt_qint4) \
cb(dtype_src, dt_quint4) \ cb(dtype_src, dt_quint4) \
MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dt_uint16, MEGDNN_DISABLE_CUDA_TYPECVT_KERN)
#define INST_SRC_QUANTIZED_LOWBIT(dtype_src) \ #define INST_SRC_QUANTIZED_LOWBIT(dtype_src) \
MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dtype_src, INST_Q2Q4) \ MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dtype_src, INST_Q2Q4) \
...@@ -407,6 +458,7 @@ void typecvt_kern_n2q4( ...@@ -407,6 +458,7 @@ void typecvt_kern_n2q4(
MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED_LOWBIT) MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED_LOWBIT)
MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL_LOWBIT) MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL_LOWBIT)
INST_SRC_NORMAL_LOWBIT(dt_uint16)
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include "./opr_impl.h" #include "./opr_impl.h"
#include "./kern.cuh" #include "./kern.cuh"
#include "megdnn/dtype.h"
#include "src/common/utils.cuh"
#include "src/cuda/utils.cuh" #include "src/cuda/utils.cuh"
#include "src/cuda/utils.h" #include "src/cuda/utils.h"
...@@ -87,10 +89,9 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre ...@@ -87,10 +89,9 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre
return; \ return; \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb); MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
cb(::megdnn::dtype::Bool); cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
#undef cb #undef cb
default: default : megdnn_assert_internal(0);
megdnn_assert_internal(0);
} }
} else if (!is_dst_lowbit) { } else if (!is_dst_lowbit) {
switch (dst.layout.dtype.enumv()) { switch (dst.layout.dtype.enumv()) {
...@@ -138,7 +139,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -138,7 +139,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
return; \ return; \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
#undef cb #undef cb
default : megdnn_assert_internal(0); default : megdnn_assert_internal(0);
} }
......
...@@ -19,7 +19,8 @@ using namespace test; ...@@ -19,7 +19,8 @@ using namespace test;
TEST_F(CUDA, TYPE_CVT) { TEST_F(CUDA, TYPE_CVT) {
UniformFloatRNG init(0, 20); UniformFloatRNG init(0, 20);
std::vector<DType> dtypes = {dtype::Float32(), dtype::Float16(), dtype::Int32(), std::vector<DType> dtypes = {dtype::Float32(), dtype::Float16(), dtype::Int32(),
dtype::Int16(), dtype::Int8(), dtype::Uint8()}; dtype::Int16(), dtype::Int8(), dtype::Uint8(),
dtype::Uint16()};
for (auto sdtype : dtypes) for (auto sdtype : dtypes)
for (auto ddtype : dtypes) { for (auto ddtype : dtypes) {
TensorLayout src({10, 10}, sdtype), dst({10, 10}, ddtype); TensorLayout src({10, 10}, sdtype), dst({10, 10}, ddtype);
......
...@@ -210,6 +210,7 @@ typename ctype_enable_if<ctype>::type DTypeScalar::set_retain_dtype(ctype val) { ...@@ -210,6 +210,7 @@ typename ctype_enable_if<ctype>::type DTypeScalar::set_retain_dtype(ctype val) {
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
cb(dt_bool); cb(dt_bool);
cb(dt_uint16);
#undef cb #undef cb
default: default:
mgb_throw(ConversionError, "can not assign to dtype %s", m_dtype.name()); mgb_throw(ConversionError, "can not assign to dtype %s", m_dtype.name());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册