From 7066ad5ba63b551c21ebe659261c538c1cc88a3a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 23 Nov 2020 14:45:41 +0800 Subject: [PATCH] feat(dnn): add uint16 support GitOrigin-RevId: f4c4b1c7b9076dd8c4c13b23753d6f1aa5570ae5 --- dnn/include/megdnn/dtype.h | 5 +++++ dnn/test/common/dtype.cpp | 1 + src/core/include/megbrain/dtype.h | 1 + src/opr/impl/loop/forward.cpp | 6 ++++++ src/opr/impl/loop/impl.cpp | 2 ++ src/serialization/impl/dtype.fbs | 1 + 6 files changed, 16 insertions(+) diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h index 69d1f2f9d..9bd117a41 100644 --- a/dnn/include/megdnn/dtype.h +++ b/dnn/include/megdnn/dtype.h @@ -53,6 +53,7 @@ namespace megdnn { MEGDNN_INC_FLOAT16(cb(BFloat16)) \ cb(UintB4) \ cb(Bool) \ + cb(Uint16) \ /*! * \brief iterate through each full byte dtype @@ -67,6 +68,7 @@ namespace megdnn { MEGDNN_INC_FLOAT16(cb(Float16)) \ MEGDNN_INC_FLOAT16(cb(BFloat16)) \ cb(Bool) \ + cb(Uint16) \ /*! * \brief iterate through each fractional byte dtype @@ -353,6 +355,7 @@ typedef int16_t dt_int16; typedef int8_t dt_int8; typedef uint8_t dt_uint8; typedef bool dt_bool; +typedef uint16_t dt_uint16; MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;) MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) @@ -381,6 +384,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) BFloat16 = 11, #endif Bool = 12, + Uint16 = 13, #define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE, #define D(_name) _name, MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D) @@ -713,6 +717,7 @@ MEGDNN_DEF_DT(Int16, dt_int16, INT, SIGNED, INT16_MIN, INT16_MAX); MEGDNN_DEF_DT(Int8, dt_int8, INT, SIGNED, INT8_MIN, INT8_MAX); MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX); MEGDNN_DEF_DT(Bool, dt_bool, BOOL, UNSIGNED, false, true); +MEGDNN_DEF_DT(Uint16, dt_uint16, INT, UNSIGNED, 0, UINT16_MAX); MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED, std::numeric_limits::lowest(), std::numeric_limits::max())); diff --git a/dnn/test/common/dtype.cpp b/dnn/test/common/dtype.cpp index 92f1a8dcd..cc6ab2963 100644 --- a/dnn/test/common/dtype.cpp +++ b/dnn/test/common/dtype.cpp @@ -26,6 +26,7 @@ TEST(TestDType, SizeCheck) { ASSERT_EQ(static_cast(2), ::megdnn::dtype::IntB4().size(3)); ASSERT_EQ(static_cast(2), ::megdnn::dtype::IntB4().size(4)); ASSERT_EQ(static_cast(3), ::megdnn::dtype::IntB4().size(5)); + ASSERT_EQ(static_cast(2), ::megdnn::dtype::Uint16().size(1)); ASSERT_EQ(static_cast(2), ::megdnn::dtype::Quantized4Asymm(1.0f, static_cast(12)) .size(3)); diff --git a/src/core/include/megbrain/dtype.h b/src/core/include/megbrain/dtype.h index 52fd71db8..6e98c36d5 100644 --- a/src/core/include/megbrain/dtype.h +++ b/src/core/include/megbrain/dtype.h @@ -28,6 +28,7 @@ using ::megdnn::dt_quint8; using ::megdnn::dt_qint8; using ::megdnn::dt_qint32; using ::megdnn::dt_bool; +using ::megdnn::dt_uint16; using ::megdnn::DType; using ::megdnn::DTypeEnum; using ::megdnn::DTypeTrait; diff --git a/src/opr/impl/loop/forward.cpp b/src/opr/impl/loop/forward.cpp index b8593c739..bce04c774 100644 --- a/src/opr/impl/loop/forward.cpp +++ b/src/opr/impl/loop/forward.cpp @@ -370,6 +370,12 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const { iv += contain_eq; return std::max(iv, 0); } + case DTypeEnum::Uint16: + { + auto iv = val.ptr()[0]; + iv += contain_eq; + return std::max(iv, 0); + } case DTypeEnum::Float32: #if !MEGDNN_DISABLE_FLOAT16 case DTypeEnum::Float16: diff --git a/src/opr/impl/loop/impl.cpp b/src/opr/impl/loop/impl.cpp index 305ca8175..e34b59713 100644 --- a/src/opr/impl/loop/impl.cpp +++ b/src/opr/impl/loop/impl.cpp @@ -249,6 +249,8 @@ MGB_DEFINE_OPR_CLASS(LoopImpl::DescImplBase::LoopCondManager::GetCondOpr, break; case DTypeEnum::Bool: break; + case DTypeEnum::Uint16: + break; #define cb(_dt) \ case DTypeEnum::_dt: \ break; diff --git a/src/serialization/impl/dtype.fbs b/src/serialization/impl/dtype.fbs index 6fe3ec23a..6e239c6a6 100644 --- a/src/serialization/impl/dtype.fbs +++ b/src/serialization/impl/dtype.fbs @@ -22,6 +22,7 @@ enum DTypeEnum : byte { QuantizedS16, BFloat16, Bool, + Uint16, } table LinearQuantizationParam { -- GitLab