提交 7dc34769 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add typecvt uint16

GitOrigin-RevId: d1368c414e99e15d6fb93273b5051832d1995dea
上级 b92866d2
......@@ -148,6 +148,9 @@ INST_FOR_CTYPE
#define ct dt_int16
INST_FOR_CTYPE
#undef ct
#define ct dt_uint16
INST_FOR_CTYPE
#undef ct
#define ct dt_quint8
INST_FOR_CTYPE
#undef ct
......@@ -201,6 +204,9 @@ INST_FOR_CTYPE
#define ct dt_int16
INST_FOR_CTYPE
#undef ct
#define ct dt_uint16
INST_FOR_CTYPE
#undef ct
#define ct dt_quint8
INST_FOR_CTYPE
#undef ct
......
......@@ -92,6 +92,7 @@ INST(dt_float16, half4);
INST(dt_bfloat16, bhalf4);
INST(dt_int32, int4);
INST(dt_int16, short4);
INST(dt_uint16, ushort4);
INST(dt_bool, uchar4);
#undef as_raw
#define as_raw(x) x.as_int8()
......
......@@ -247,6 +247,19 @@ struct TypeCvtOpFromQuantizedToQuantized4bit<
namespace megdnn {
namespace cuda {
// currently only typecvt_kern_{n2q,n2q4} respect this. change others typecvt_kern_* if
// needed.
template <typename dtype_src, typename dtype_dest, typename sfinae = void>
struct enable_typecvt_kern {
static constexpr bool value = true;
};
#define MEGDNN_DISABLE_CUDA_TYPECVT_KERN(dtype_src, dtype_dest) \
template <> \
struct enable_typecvt_kern<dtype_src, dtype_dest, void> { \
static constexpr bool value = false; \
};
template <typename dtype_src, typename dtype_dest>
void typecvt_kern_q2q(
const TensorND& dest, const TensorND& src,
......@@ -257,12 +270,28 @@ void typecvt_kern_q2q(
}
template <typename dtype_src, typename dtype_dest>
void typecvt_kern_n2q(
typename std::enable_if<enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
typecvt_kern_n2q_impl(
const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
main_func(TypeCvtOpToQuantized, op.param = dst_param;);
}
template <typename dtype_src, typename dtype_dest>
typename std::enable_if<!enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
typecvt_kern_n2q_impl(
const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
megdnn_throw("TypeCvt: CUDA kernel for this dtype pair is disabled");
}
template <typename dtype_src, typename dtype_dest>
void typecvt_kern_n2q(
const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
typecvt_kern_n2q_impl<dtype_src, dtype_dest>(dest, src, dst_param, stream);
}
template <typename dtype_src, typename dtype_dest>
void typecvt_kern_q2n(
const TensorND& dest, const TensorND& src,
......@@ -312,12 +341,15 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
cb(dtype_src, dt_qint8) \
cb(dtype_src, dt_qint1) \
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dt_uint16, MEGDNN_DISABLE_CUDA_TYPECVT_KERN)
#define INST_SRC_QUANTIZED(dtype_src) \
MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2Q) \
#define INST_SRC_NORMAL(dtype_src) \
MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2N) \
INST_N2N(dtype_src, dt_uint16) \
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2Q) \
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \
......@@ -340,6 +372,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED)
MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL)
INST_SRC_NORMAL(dt_uint16)
// clang-format on
template void typecvt_kern_n2q<dtype::Int8, dtype::QuantizedS8>(
......@@ -377,12 +410,28 @@ void typecvt_kern_q2q4(
}
template <typename dtype_src, typename dtype_dest>
void typecvt_kern_n2q4(
typename std::enable_if<enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
typecvt_kern_n2q4_impl(
const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
main_func_to_q4(TypeCvtOpFromNormalToQuantized4bit, op.dst_param = dst_param;)
}
template <typename dtype_src, typename dtype_dest>
typename std::enable_if<!enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
typecvt_kern_n2q4_impl(
const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
megdnn_throw("TypeCvt: CUDA kernel for this dtype pair is disabled");
}
template <typename dtype_src, typename dtype_dest>
void typecvt_kern_n2q4(
const TensorND& dest, const TensorND& src,
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
typecvt_kern_n2q4_impl<dtype_src, dtype_dest>(dest, src, dst_param, stream);
}
#define INST_Q2Q4(dtype_src, dtype_dest) \
template void typecvt_kern_q2q4<dtype_src, dtype_dest>( \
const TensorND& dest, const TensorND& src, \
......@@ -399,6 +448,8 @@ void typecvt_kern_n2q4(
cb(dtype_src, dt_qint4) \
cb(dtype_src, dt_quint4) \
MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dt_uint16, MEGDNN_DISABLE_CUDA_TYPECVT_KERN)
#define INST_SRC_QUANTIZED_LOWBIT(dtype_src) \
MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dtype_src, INST_Q2Q4) \
......@@ -407,6 +458,7 @@ void typecvt_kern_n2q4(
MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED_LOWBIT)
MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL_LOWBIT)
INST_SRC_NORMAL_LOWBIT(dt_uint16)
} // namespace cuda
} // namespace megdnn
......
......@@ -12,6 +12,8 @@
#include "./opr_impl.h"
#include "./kern.cuh"
#include "megdnn/dtype.h"
#include "src/common/utils.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"
......@@ -87,10 +89,9 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
cb(::megdnn::dtype::Bool);
cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
#undef cb
default:
megdnn_assert_internal(0);
default : megdnn_assert_internal(0);
}
} else if (!is_dst_lowbit) {
switch (dst.layout.dtype.enumv()) {
......@@ -138,7 +139,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
#undef cb
default : megdnn_assert_internal(0);
}
......
......@@ -19,7 +19,8 @@ using namespace test;
TEST_F(CUDA, TYPE_CVT) {
UniformFloatRNG init(0, 20);
std::vector<DType> dtypes = {dtype::Float32(), dtype::Float16(), dtype::Int32(),
dtype::Int16(), dtype::Int8(), dtype::Uint8()};
dtype::Int16(), dtype::Int8(), dtype::Uint8(),
dtype::Uint16()};
for (auto sdtype : dtypes)
for (auto ddtype : dtypes) {
TensorLayout src({10, 10}, sdtype), dst({10, 10}, ddtype);
......
......@@ -210,6 +210,7 @@ typename ctype_enable_if<ctype>::type DTypeScalar::set_retain_dtype(ctype val) {
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
cb(dt_bool);
cb(dt_uint16);
#undef cb
default:
mgb_throw(ConversionError, "can not assign to dtype %s", m_dtype.name());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册