提交 7066ad5b 编写于 作者: M Megvii Engine Team

feat(dnn): add uint16 support

GitOrigin-RevId: f4c4b1c7b9076dd8c4c13b23753d6f1aa5570ae5
上级 a1877ee0
......@@ -53,6 +53,7 @@ namespace megdnn {
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
cb(UintB4) \
cb(Bool) \
cb(Uint16) \
/*!
* \brief iterate through each full byte dtype
......@@ -67,6 +68,7 @@ namespace megdnn {
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
cb(Bool) \
cb(Uint16) \
/*!
* \brief iterate through each fractional byte dtype
......@@ -353,6 +355,7 @@ typedef int16_t dt_int16;
typedef int8_t dt_int8;
typedef uint8_t dt_uint8;
typedef bool dt_bool;
typedef uint16_t dt_uint16;
MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
......@@ -381,6 +384,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
BFloat16 = 11,
#endif
Bool = 12,
Uint16 = 13,
#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
#define D(_name) _name,
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D)
......@@ -713,6 +717,7 @@ MEGDNN_DEF_DT(Int16, dt_int16, INT, SIGNED, INT16_MIN, INT16_MAX);
MEGDNN_DEF_DT(Int8, dt_int8, INT, SIGNED, INT8_MIN, INT8_MAX);
MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX);
MEGDNN_DEF_DT(Bool, dt_bool, BOOL, UNSIGNED, false, true);
MEGDNN_DEF_DT(Uint16, dt_uint16, INT, UNSIGNED, 0, UINT16_MAX);
MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED,
std::numeric_limits<dt_float16>::lowest(),
std::numeric_limits<dt_float16>::max()));
......
......@@ -26,6 +26,7 @@ TEST(TestDType, SizeCheck) {
ASSERT_EQ(static_cast<size_t>(2), ::megdnn::dtype::IntB4().size(3));
ASSERT_EQ(static_cast<size_t>(2), ::megdnn::dtype::IntB4().size(4));
ASSERT_EQ(static_cast<size_t>(3), ::megdnn::dtype::IntB4().size(5));
ASSERT_EQ(static_cast<size_t>(2), ::megdnn::dtype::Uint16().size(1));
ASSERT_EQ(static_cast<size_t>(2),
::megdnn::dtype::Quantized4Asymm(1.0f, static_cast<uint8_t>(12))
.size(3));
......
......@@ -28,6 +28,7 @@ using ::megdnn::dt_quint8;
using ::megdnn::dt_qint8;
using ::megdnn::dt_qint32;
using ::megdnn::dt_bool;
using ::megdnn::dt_uint16;
using ::megdnn::DType;
using ::megdnn::DTypeEnum;
using ::megdnn::DTypeTrait;
......
......@@ -370,6 +370,12 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const {
iv += contain_eq;
return std::max(iv, 0);
}
case DTypeEnum::Uint16:
{
auto iv = val.ptr<dt_uint16>()[0];
iv += contain_eq;
return std::max<int>(iv, 0);
}
case DTypeEnum::Float32:
#if !MEGDNN_DISABLE_FLOAT16
case DTypeEnum::Float16:
......
......@@ -249,6 +249,8 @@ MGB_DEFINE_OPR_CLASS(LoopImpl::DescImplBase::LoopCondManager::GetCondOpr,
break;
case DTypeEnum::Bool:
break;
case DTypeEnum::Uint16:
break;
#define cb(_dt) \
case DTypeEnum::_dt: \
break;
......
......@@ -22,6 +22,7 @@ enum DTypeEnum : byte {
QuantizedS16,
BFloat16,
Bool,
Uint16,
}
table LinearQuantizationParam {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册