diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h index 69d1f2f9d1088041656816782ec1d6a2f35a9768..9bd117a413eb563dda253726c6a4502b2c7015b0 100644 --- a/dnn/include/megdnn/dtype.h +++ b/dnn/include/megdnn/dtype.h @@ -53,6 +53,7 @@ namespace megdnn { MEGDNN_INC_FLOAT16(cb(BFloat16)) \ cb(UintB4) \ cb(Bool) \ + cb(Uint16) \ /*! * \brief iterate through each full byte dtype @@ -67,6 +68,7 @@ namespace megdnn { MEGDNN_INC_FLOAT16(cb(Float16)) \ MEGDNN_INC_FLOAT16(cb(BFloat16)) \ cb(Bool) \ + cb(Uint16) \ /*! * \brief iterate through each fractional byte dtype @@ -353,6 +355,7 @@ typedef int16_t dt_int16; typedef int8_t dt_int8; typedef uint8_t dt_uint8; typedef bool dt_bool; +typedef uint16_t dt_uint16; MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;) MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) @@ -381,6 +384,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) BFloat16 = 11, #endif Bool = 12, + Uint16 = 13, #define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE, #define D(_name) _name, MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D) @@ -713,6 +717,7 @@ MEGDNN_DEF_DT(Int16, dt_int16, INT, SIGNED, INT16_MIN, INT16_MAX); MEGDNN_DEF_DT(Int8, dt_int8, INT, SIGNED, INT8_MIN, INT8_MAX); MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX); MEGDNN_DEF_DT(Bool, dt_bool, BOOL, UNSIGNED, false, true); +MEGDNN_DEF_DT(Uint16, dt_uint16, INT, UNSIGNED, 0, UINT16_MAX); MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED, std::numeric_limits::lowest(), std::numeric_limits::max())); diff --git a/dnn/test/common/dtype.cpp b/dnn/test/common/dtype.cpp index 92f1a8dcda6485b4beb6373674590e913749c50b..cc6ab29633a9fabf3f18ec9919893aa98ea51da5 100644 --- a/dnn/test/common/dtype.cpp +++ b/dnn/test/common/dtype.cpp @@ -26,6 +26,7 @@ TEST(TestDType, SizeCheck) { ASSERT_EQ(static_cast(2), ::megdnn::dtype::IntB4().size(3)); ASSERT_EQ(static_cast(2), ::megdnn::dtype::IntB4().size(4)); ASSERT_EQ(static_cast(3), ::megdnn::dtype::IntB4().size(5)); + ASSERT_EQ(static_cast(2), ::megdnn::dtype::Uint16().size(1)); ASSERT_EQ(static_cast(2), ::megdnn::dtype::Quantized4Asymm(1.0f, static_cast(12)) .size(3)); diff --git a/src/core/include/megbrain/dtype.h b/src/core/include/megbrain/dtype.h index 52fd71db893019d13cb10097d02779da6bfb9999..6e98c36d5f9f332838928d77c0259a799cb78682 100644 --- a/src/core/include/megbrain/dtype.h +++ b/src/core/include/megbrain/dtype.h @@ -28,6 +28,7 @@ using ::megdnn::dt_quint8; using ::megdnn::dt_qint8; using ::megdnn::dt_qint32; using ::megdnn::dt_bool; +using ::megdnn::dt_uint16; using ::megdnn::DType; using ::megdnn::DTypeEnum; using ::megdnn::DTypeTrait; diff --git a/src/opr/impl/loop/forward.cpp b/src/opr/impl/loop/forward.cpp index b8593c739c3509ad0fc748abd4f6d425de6785fb..bce04c774cfd4e6eefb53e6cec5773429f2801a6 100644 --- a/src/opr/impl/loop/forward.cpp +++ b/src/opr/impl/loop/forward.cpp @@ -370,6 +370,12 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const { iv += contain_eq; return std::max(iv, 0); } + case DTypeEnum::Uint16: + { + auto iv = val.ptr()[0]; + iv += contain_eq; + return std::max(iv, 0); + } case DTypeEnum::Float32: #if !MEGDNN_DISABLE_FLOAT16 case DTypeEnum::Float16: diff --git a/src/opr/impl/loop/impl.cpp b/src/opr/impl/loop/impl.cpp index 305ca81756c39c0e02030b3ff8595c32cd85fbb5..e34b59713a1660e99411044e663a9434b2635c6c 100644 --- a/src/opr/impl/loop/impl.cpp +++ b/src/opr/impl/loop/impl.cpp @@ -249,6 +249,8 @@ MGB_DEFINE_OPR_CLASS(LoopImpl::DescImplBase::LoopCondManager::GetCondOpr, break; case DTypeEnum::Bool: break; + case DTypeEnum::Uint16: + break; #define cb(_dt) \ case DTypeEnum::_dt: \ break; diff --git a/src/serialization/impl/dtype.fbs b/src/serialization/impl/dtype.fbs index 6fe3ec23ad15665cadc17f0be0fc8b34034fd006..6e239c6a6e28e7432ea9f03bdb0429ab3fbd9e49 100644 --- a/src/serialization/impl/dtype.fbs +++ b/src/serialization/impl/dtype.fbs @@ -22,6 +22,7 @@ enum DTypeEnum : byte { QuantizedS16, BFloat16, Bool, + Uint16, } table LinearQuantizationParam {