diff --git a/dnn/src/cuda/elemwise_helper.cpp b/dnn/src/cuda/elemwise_helper.cpp index 33349679c1c65b28d1ea7e600cc20eab6f8e7677..61bbd90b3ece67a0bad3e25201a4dc575ba20e41 100644 --- a/dnn/src/cuda/elemwise_helper.cpp +++ b/dnn/src/cuda/elemwise_helper.cpp @@ -148,6 +148,9 @@ INST_FOR_CTYPE #define ct dt_int16 INST_FOR_CTYPE #undef ct +#define ct dt_uint16 +INST_FOR_CTYPE +#undef ct #define ct dt_quint8 INST_FOR_CTYPE #undef ct @@ -201,6 +204,9 @@ INST_FOR_CTYPE #define ct dt_int16 INST_FOR_CTYPE #undef ct +#define ct dt_uint16 +INST_FOR_CTYPE +#undef ct #define ct dt_quint8 INST_FOR_CTYPE #undef ct diff --git a/dnn/src/cuda/elemwise_helper.cuh b/dnn/src/cuda/elemwise_helper.cuh index e098d8d50a2be70cb01301b2373f0afd21d75043..4a9c67f5e56934e52224e7a1edc067f3998dd46b 100644 --- a/dnn/src/cuda/elemwise_helper.cuh +++ b/dnn/src/cuda/elemwise_helper.cuh @@ -92,6 +92,7 @@ INST(dt_float16, half4); INST(dt_bfloat16, bhalf4); INST(dt_int32, int4); INST(dt_int16, short4); +INST(dt_uint16, ushort4); INST(dt_bool, uchar4); #undef as_raw #define as_raw(x) x.as_int8() diff --git a/dnn/src/cuda/type_cvt/kern.cu b/dnn/src/cuda/type_cvt/kern.cu index 007fd1bef1d5da4d15b7447e37a9af607aac205e..fd591d20142f0efc682bafc4e618137b5cd6f683 100644 --- a/dnn/src/cuda/type_cvt/kern.cu +++ b/dnn/src/cuda/type_cvt/kern.cu @@ -247,6 +247,19 @@ struct TypeCvtOpFromQuantizedToQuantized4bit< namespace megdnn { namespace cuda { +// currently only typecvt_kern_{n2q,n2q4} respect this. change others typecvt_kern_* if +// needed. +template +struct enable_typecvt_kern { + static constexpr bool value = true; +}; + +#define MEGDNN_DISABLE_CUDA_TYPECVT_KERN(dtype_src, dtype_dest) \ + template <> \ + struct enable_typecvt_kern { \ + static constexpr bool value = false; \ + }; + template void typecvt_kern_q2q( const TensorND& dest, const TensorND& src, @@ -257,12 +270,28 @@ void typecvt_kern_q2q( } template -void typecvt_kern_n2q( +typename std::enable_if::value>::type +typecvt_kern_n2q_impl( const TensorND& dest, const TensorND& src, const CudaDTypeParam& dst_param, cudaStream_t stream) { main_func(TypeCvtOpToQuantized, op.param = dst_param;); } +template +typename std::enable_if::value>::type +typecvt_kern_n2q_impl( + const TensorND& dest, const TensorND& src, + const CudaDTypeParam& dst_param, cudaStream_t stream) { + megdnn_throw("TypeCvt: CUDA kernel for this dtype pair is disabled"); +} + +template +void typecvt_kern_n2q( + const TensorND& dest, const TensorND& src, + const CudaDTypeParam& dst_param, cudaStream_t stream) { + typecvt_kern_n2q_impl(dest, src, dst_param, stream); +} + template void typecvt_kern_q2n( const TensorND& dest, const TensorND& src, @@ -312,12 +341,15 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st cb(dtype_src, dt_qint8) \ cb(dtype_src, dt_qint1) \ +MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dt_uint16, MEGDNN_DISABLE_CUDA_TYPECVT_KERN) + #define INST_SRC_QUANTIZED(dtype_src) \ MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \ MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2Q) \ #define INST_SRC_NORMAL(dtype_src) \ MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2N) \ + INST_N2N(dtype_src, dt_uint16) \ MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2Q) \ #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ @@ -340,6 +372,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED) MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL) +INST_SRC_NORMAL(dt_uint16) // clang-format on template void typecvt_kern_n2q( @@ -377,12 +410,28 @@ void typecvt_kern_q2q4( } template -void typecvt_kern_n2q4( +typename std::enable_if::value>::type +typecvt_kern_n2q4_impl( const TensorND& dest, const TensorND& src, const CudaDTypeParam& dst_param, cudaStream_t stream) { main_func_to_q4(TypeCvtOpFromNormalToQuantized4bit, op.dst_param = dst_param;) } +template +typename std::enable_if::value>::type +typecvt_kern_n2q4_impl( + const TensorND& dest, const TensorND& src, + const CudaDTypeParam& dst_param, cudaStream_t stream) { + megdnn_throw("TypeCvt: CUDA kernel for this dtype pair is disabled"); +} + +template +void typecvt_kern_n2q4( + const TensorND& dest, const TensorND& src, + const CudaDTypeParam& dst_param, cudaStream_t stream) { + typecvt_kern_n2q4_impl(dest, src, dst_param, stream); +} + #define INST_Q2Q4(dtype_src, dtype_dest) \ template void typecvt_kern_q2q4( \ const TensorND& dest, const TensorND& src, \ @@ -399,6 +448,8 @@ void typecvt_kern_n2q4( cb(dtype_src, dt_qint4) \ cb(dtype_src, dt_quint4) \ +MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dt_uint16, MEGDNN_DISABLE_CUDA_TYPECVT_KERN) + #define INST_SRC_QUANTIZED_LOWBIT(dtype_src) \ MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dtype_src, INST_Q2Q4) \ @@ -407,6 +458,7 @@ void typecvt_kern_n2q4( MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED_LOWBIT) MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL_LOWBIT) +INST_SRC_NORMAL_LOWBIT(dt_uint16) } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/type_cvt/opr_impl.cpp b/dnn/src/cuda/type_cvt/opr_impl.cpp index 31013dc55afa90f66b7cd43246e766ec7a1c05b2..eddbdf74ac2e0762e1f8fc660b4dd8f2ed0745c2 100644 --- a/dnn/src/cuda/type_cvt/opr_impl.cpp +++ b/dnn/src/cuda/type_cvt/opr_impl.cpp @@ -12,6 +12,8 @@ #include "./opr_impl.h" #include "./kern.cuh" +#include "megdnn/dtype.h" +#include "src/common/utils.cuh" #include "src/cuda/utils.cuh" #include "src/cuda/utils.h" @@ -87,10 +89,9 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb); - cb(::megdnn::dtype::Bool); + cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) #undef cb - default: - megdnn_assert_internal(0); + default : megdnn_assert_internal(0); } } else if (!is_dst_lowbit) { switch (dst.layout.dtype.enumv()) { @@ -138,7 +139,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) - cb(::megdnn::dtype::Bool) + cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) #undef cb default : megdnn_assert_internal(0); } diff --git a/dnn/test/cuda/type_cvt.cpp b/dnn/test/cuda/type_cvt.cpp index 14feae907dace40296469c34c7145aae34391ea9..9f6adc5ae4faa766fa6cc79dce8ff9c1e871ed19 100644 --- a/dnn/test/cuda/type_cvt.cpp +++ b/dnn/test/cuda/type_cvt.cpp @@ -19,7 +19,8 @@ using namespace test; TEST_F(CUDA, TYPE_CVT) { UniformFloatRNG init(0, 20); std::vector dtypes = {dtype::Float32(), dtype::Float16(), dtype::Int32(), - dtype::Int16(), dtype::Int8(), dtype::Uint8()}; + dtype::Int16(), dtype::Int8(), dtype::Uint8(), + dtype::Uint16()}; for (auto sdtype : dtypes) for (auto ddtype : dtypes) { TensorLayout src({10, 10}, sdtype), dst({10, 10}, ddtype); diff --git a/src/core/impl/dtype.cpp b/src/core/impl/dtype.cpp index 494dbcd0bfda1f99c771e2cdea971c71e30e183c..6bb70e03c39f84133a315ea8a55a4481ac7dec74 100644 --- a/src/core/impl/dtype.cpp +++ b/src/core/impl/dtype.cpp @@ -210,6 +210,7 @@ typename ctype_enable_if::type DTypeScalar::set_retain_dtype(ctype val) { MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) cb(dt_bool); + cb(dt_uint16); #undef cb default: mgb_throw(ConversionError, "can not assign to dtype %s", m_dtype.name());