diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h
index 5bd360b81b322f805d1576359fc9df78d24c7090..7a8c876bdc998866b79bf660cbae3a119f0e266f 100644
--- a/dnn/include/megdnn/dtype.h
+++ b/dnn/include/megdnn/dtype.h
@@ -62,7 +62,7 @@ namespace megdnn {
 #define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb)                    \
     cb(QuantizedS32) cb(QuantizedS8) cb(Quantized4Asymm) cb(QuantizedS4) \
-            cb(QuantizedS16)
+            cb(QuantizedS16) cb(QuantizedS1)
 
 #define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(cb_first, cb_others) \
     MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb_first)            \
@@ -112,7 +112,7 @@ namespace megdnn {
 #define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb)                        \
     cb(::megdnn::dtype::QuantizedS32) cb(::megdnn::dtype::QuantizedS8) \
-            cb(::megdnn::dtype::QuantizedS4)
+            cb(::megdnn::dtype::QuantizedS4) cb(::megdnn::dtype::QuantizedS1)
 
 #define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \
     cb(::megdnn::dtype::Quantized8Asymm) cb(::megdnn::dtype::Quantized4Asymm)
@@ -292,10 +292,27 @@ public:
 };
 using dt_qint4 = dt_qlowbit<4>;
 
+class dt_qint1 {
+    int8_t _;
+
+public:
+    MEGDNN_DEVICE int8_t as_int8() const { return _; }
+
+    MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint1(int8_t val) : _(val) {}
+#ifdef MEGDNN_CC_HOST
+    explicit operator int8_t() { return _; }
+#endif
+    bool operator<(const dt_qint1& b) const { return _ < b._; }
+    bool operator>(const dt_qint1& b) const { return _ > b._; }
+    bool operator==(const dt_qint1& b) const { return _ == b._; }
+    bool operator!=(const dt_qint1& b) const { return _ != b._; }
+} MEGDNN_PACKED;
+
 #ifdef __clang__
 #pragma clang diagnostic pop
 #endif
 MEGDNN_STATIC_ASSERT(sizeof(dt_byte) == 1, "bad dt_byte size");
+MEGDNN_STATIC_ASSERT(sizeof(dt_qint1) == 1, "bad dt_qint1 size");
 MEGDNN_STATIC_ASSERT(sizeof(dt_quint8) == 1, "bad dt_quint8 size");
 MEGDNN_STATIC_ASSERT(sizeof(dt_qint16) == 2, "bad dt_qint16 size");
 MEGDNN_STATIC_ASSERT(sizeof(dt_qint32) == 4, "bad dt_qint32 size");
@@ -677,7 +694,7 @@ MEGDNN_FOREACH_LOWBIT_DTYPE(MEGDNN_DEF_FRACTION_DT)
         return static_cast<_itype>(_maxval); \
     }                                        \
     };
-
+MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS1, dt_qint1, int8_t, QUANTIZED, SIGNED, 0, 1, 0);
 MEGDNN_DEF_PARAMETERIZED_DT(
         Quantized4Asymm, dt_quint4, uint8_t, QUANTIZED, SIGNED, 0, 15, 4);
 MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS4, dt_qint4, int8_t, QUANTIZED, SIGNED, -8, 7, 4);
@@ -876,6 +893,26 @@ struct DTypeParamImpl {
     }
 };
 
+template <>
+struct DTypeParamImpl<dt_qint1> {
+    float scale;
+
+    DTypeParamImpl() = default;
+    MGE_WIN_DECLSPEC_FUC DTypeParamImpl(float scale);
+#ifdef MEGDNN_CC_HOST
+    std::size_t hash() const;
+#endif
+    bool operator==(const DTypeParam<dt_qint1>& rhs) const;
+    MEGDNN_DEVICE dt_qint1 quantize(float in) const {
+        float v = in / scale;
+        v = roundf(v);
+        v = fmin(fmax(0.f, v), 1.f);
+        return static_cast<dt_qint1>(v);
+    }
+    MEGDNN_DEVICE float dequantize(int8_t in) const { return in * scale; }
+    MEGDNN_DEVICE float dequantize(dt_qint1 in) const { return in.as_int8() * scale; }
+};
+
 template <>
 struct DTypeParamImpl<dt_quint4> {
     float scale;
diff --git a/dnn/src/common/dtype.cpp b/dnn/src/common/dtype.cpp
index 417e5613984e7922ad30e7b1a855fedd3999c4dc..431346d6607e898973fc83afa39f08ba964278c1 100644
--- a/dnn/src/common/dtype.cpp
+++ b/dnn/src/common/dtype.cpp
@@ -142,6 +142,19 @@ inline bool DTypeParam::operator==(const DTypeParam& rhs)
     return scale == rhs.scale;
 }
 
+DTypeParam<dt_qint1>::DTypeParamImpl(float scale) : scale{scale} {
+    //! As the nan is not equal to any value
+    megdnn_assert(!std::isnan(scale), "nan number compare is not support");
+}
+
+inline std::size_t DTypeParam<dt_qint1>::hash() const {
+    return std::hash<float>()(scale);
+}
+
+inline bool DTypeParam<dt_qint1>::operator==(const DTypeParam<dt_qint1>& rhs) const {
+    return scale == rhs.scale;
+}
+
 DTypeParam<dt_quint4>::DTypeParamImpl(float scale, uint8_t zero_point)
         : scale{scale}, zero_point{zero_point} {
     //! As the nan is not equal to any value
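Note on the semantics introduced above: unlike the other signed quantized dtypes, `QuantizedS1` is registered with the representable range `[0, 1]` (see `MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS1, dt_qint1, int8_t, QUANTIZED, SIGNED, 0, 1, 0)`), and `DTypeParamImpl<dt_qint1>::quantize` clamps to that range after rounding. A minimal standalone sketch of the same arithmetic; the helper names are hypothetical and not part of the patch:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Standalone mirror of DTypeParamImpl<dt_qint1>::quantize/dequantize above.
// Only the arithmetic matches the patch; the names are illustrative.
int8_t quantize_s1(float in, float scale) {
    float v = std::round(in / scale);
    v = std::fmin(std::fmax(0.f, v), 1.f);  // clamp to the qint1 range [0, 1]
    return static_cast<int8_t>(v);
}

float dequantize_s1(int8_t in, float scale) {
    return in * scale;
}

int main() {
    const float scale = 0.1f;
    for (float x : {-0.30f, 0.04f, 0.06f, 2.50f}) {
        int8_t q = quantize_s1(x, scale);
        std::printf("%+.2f -> %d -> %.2f\n", x, q, dequantize_s1(q, scale));
    }
    // prints: -0.30 -> 0 -> 0.00, +0.04 -> 0 -> 0.00,
    //         +0.06 -> 1 -> 0.10, +2.50 -> 1 -> 0.10
}
```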
diff --git a/dnn/src/common/utils.cpp b/dnn/src/common/utils.cpp
index 28501839cc72bebc76b277bbbc0e000c2456eb38..bd1565a612d51a27505790017ba5685e7c421fe8 100644
--- a/dnn/src/common/utils.cpp
+++ b/dnn/src/common/utils.cpp
@@ -241,6 +241,7 @@ float megdnn::mul_scale(DType lhs, DType rhs) {
         return lhs.param<_dt>().scale * rhs.param<_dt>().scale;
     MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
     MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
+    cb(::megdnn::dtype::QuantizedS1)
 #undef cb
     megdnn_assert_internal(0);
 }
@@ -253,8 +254,9 @@ float megdnn::get_scale(DType dt) {
         return dt.param<_dt>().scale;
     MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
     MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
+    cb(::megdnn::dtype::QuantizedS1)
 #undef cb
-    megdnn_assert_internal(0);
+    megdnn_assert_internal(0);
 }
 
 bool megdnn::dtype_almost_equal(DType lhs, DType rhs) {
diff --git a/dnn/src/cuda/elemwise_helper.cpp b/dnn/src/cuda/elemwise_helper.cpp
index 90ad483f217544ab84a7caaa1193d467cb67e30d..33349679c1c65b28d1ea7e600cc20eab6f8e7677 100644
--- a/dnn/src/cuda/elemwise_helper.cpp
+++ b/dnn/src/cuda/elemwise_helper.cpp
@@ -160,6 +160,9 @@ INST_FOR_CTYPE
 #define ct dt_bool
 INST_FOR_CTYPE
 #undef ct
+#define ct dt_qint1
+INST_FOR_CTYPE
+#undef ct
 
 #undef INST_FOR_CTYPE
 #undef INST
@@ -210,6 +213,9 @@ INST_FOR_CTYPE
 #define ct dt_bool
 INST_FOR_CTYPE
 #undef ct
+#define ct dt_qint1
+INST_FOR_CTYPE
+#undef ct
 
 #undef ndim_cb
 
@@ -221,6 +227,7 @@ INST(dt_int8);
 INST(dt_uint8);
 INST(dt_bool);
 INST(dt_qint8);
+INST(dt_qint1);
 INST(dt_quint8);
 
 #undef dt_ibyte
diff --git a/dnn/src/cuda/elemwise_helper.cuh b/dnn/src/cuda/elemwise_helper.cuh
index c71b5f2e0b57f62a8f480dd22c9038bdbfdd7174..e098d8d50a2be70cb01301b2373f0afd21d75043 100644
--- a/dnn/src/cuda/elemwise_helper.cuh
+++ b/dnn/src/cuda/elemwise_helper.cuh
@@ -96,6 +96,7 @@ INST(dt_bool, uchar4);
 #undef as_raw
 #define as_raw(x) x.as_int8()
 INST(dt_qint8, char4);
+INST(dt_qint1, char4);
 #undef as_raw
 #define as_raw(x) x.as_uint8()
 INST(dt_quint8, uchar4);
@@ -466,6 +467,7 @@ INST_PARAM_VECT_VISITOR;
 INST_DT_IBYTE(dt_int8);
 INST_DT_IBYTE(dt_uint8);
 INST_DT_IBYTE(dt_qint8);
+INST_DT_IBYTE(dt_qint1);
 INST_DT_IBYTE(dt_quint8);
 INST_DT_IBYTE(dt_bool);
 #undef INST_DT_IBYTE
@@ -1299,6 +1301,7 @@ private:
 INST_DT_IBYTE(dt_int8);
 INST_DT_IBYTE(dt_uint8);
 INST_DT_IBYTE(dt_qint8);
+INST_DT_IBYTE(dt_qint1);
 INST_DT_IBYTE(dt_quint8);
 INST_DT_IBYTE(dt_bool);
 #undef INST_DT_IBYTE
@@ -1649,6 +1652,7 @@ public:
 INST_DT_IBYTE(dt_int8);
 INST_DT_IBYTE(dt_uint8);
 INST_DT_IBYTE(dt_qint8);
+INST_DT_IBYTE(dt_qint1);
 INST_DT_IBYTE(dt_quint8);
 INST_DT_IBYTE(dt_bool);
 #undef INST_DT_IBYTE
diff --git a/dnn/src/cuda/type_cvt/kern.cu b/dnn/src/cuda/type_cvt/kern.cu
index d786df19793515d0297113eecd09b013fd2aaa20..94ec995ba9cff3c63d75c35b2059d66c6d5b28f5 100644
--- a/dnn/src/cuda/type_cvt/kern.cu
+++ b/dnn/src/cuda/type_cvt/kern.cu
@@ -88,6 +88,7 @@ struct TypeCvtOpToQuantized<
         typename std::enable_if<
                 std::is_same<ctype_dest, dt_quint8>::value ||
                 std::is_same<ctype_dest, dt_qint8>::value ||
+                std::is_same<ctype_dest, dt_qint1>::value ||
                 std::is_same<ctype_dest, dt_qint32>::value>::type> {
     ctype_dest* dest;
     CudaDTypeParam<ctype_dest> param;
@@ -111,6 +112,7 @@ struct TypeCvtOpFromQuantized<
         ctype_dest, ctype_src,
         typename std::enable_if<
                 std::is_same<ctype_src, dt_quint8>::value ||
+                std::is_same<ctype_src, dt_qint1>::value ||
                 std::is_same<ctype_src, dt_qint8>::value>::type> {
     ctype_dest* dest;
     CudaDTypeParam<ctype_src> param;
@@ -134,7 +136,8 @@ struct TypeCvtOpBetweenQuantized<
         ctype_dest, ctype_src,
         typename std::enable_if<
                 (std::is_same<ctype_src, dt_quint8>::value ||
-                 std::is_same<ctype_src, dt_qint8>::value) &&
+                 std::is_same<ctype_src, dt_qint8>::value ||
+                 std::is_same<ctype_src, dt_qint1>::value) &&
                         IsNotTypeQ4<ctype_dest, ctype_src>::value>::type> {
     ctype_dest* dest;
     CudaDTypeParam<ctype_src> src_param;
@@ -306,6 +309,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
             cb(dtype_src, dt_quint8) \
             cb(dtype_src, dt_qint32) \
             cb(dtype_src, dt_qint8)  \
+            cb(dtype_src, dt_qint1)  \
 
 #define INST_SRC_QUANTIZED(dtype_src)                                  \
     MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \
@@ -330,7 +334,8 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
     cb(dt_qint32) \
     cb(dt_qint8)  \
     cb(dt_qint4)  \
-    cb(dt_quint4)
+    cb(dt_quint4) \
+    cb(dt_qint1)
 
 MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED)
 MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL)
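The three `TypeCvtOp*` specializations above pick up `dt_qint1` by extending their `enable_if` trait lists, so the existing byte-wise conversion functors are reused instead of adding a new kernel. A simplified sketch of that dispatch pattern, using hypothetical stand-in types rather than the real megdnn ones:

```cpp
#include <cstdio>
#include <type_traits>

// Simplified sketch of the enable_if dispatch used by TypeCvtOpToQuantized /
// TypeCvtOpFromQuantized above. Adding dt_qint1 to the trait routes it onto
// the specialization that already serves dt_qint8.
struct dt_qint8 {};
struct dt_qint1 {};
struct dt_float32 {};

template <typename T>
struct is_byte_quantized
        : std::integral_constant<
                  bool, std::is_same<T, dt_qint8>::value ||
                                std::is_same<T, dt_qint1>::value> {};

template <typename ctype_dest, typename Enable = void>
struct TypeCvtOp {
    static const char* path() { return "generic elementwise path"; }
};

template <typename ctype_dest>
struct TypeCvtOp<
        ctype_dest,
        typename std::enable_if<is_byte_quantized<ctype_dest>::value>::type> {
    static const char* path() { return "quantizing path (scale applied)"; }
};

int main() {
    std::printf("%s\n", TypeCvtOp<dt_float32>::path());  // generic elementwise path
    std::printf("%s\n", TypeCvtOp<dt_qint1>::path());    // quantizing path (scale applied)
}
```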
diff --git a/dnn/src/cuda/type_cvt/opr_impl.cpp b/dnn/src/cuda/type_cvt/opr_impl.cpp
index 13a19e46d2124e0b936e5fc0778d74d5d76c9301..31013dc55afa90f66b7cd43246e766ec7a1c05b2 100644
--- a/dnn/src/cuda/type_cvt/opr_impl.cpp
+++ b/dnn/src/cuda/type_cvt/opr_impl.cpp
@@ -50,6 +50,7 @@ void exec_src_quantized(
             return;                                       \
         }
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb);
+        cb(::megdnn::dtype::QuantizedS1);
         default:
             megdnn_assert_internal(0);
 #undef cb
@@ -101,6 +102,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre
             return;                                       \
         }
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb);
+        cb(::megdnn::dtype::QuantizedS1);
 #undef cb
         default:
             megdnn_assert_internal(0);
@@ -150,9 +152,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
         }
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
+        cb(::megdnn::dtype::QuantizedS1)
 #undef cb
-        default:
-            megdnn_assert_internal(0);
+        default : megdnn_assert_internal(0);
     }
 }
 
diff --git a/dnn/src/cuda/utils.cuh b/dnn/src/cuda/utils.cuh
index 732f531ad853c611aa9a7674e05efaa2eace71c6..3c6caa06b158d11ef08607e8369e3c91278999bd 100644
--- a/dnn/src/cuda/utils.cuh
+++ b/dnn/src/cuda/utils.cuh
@@ -241,6 +241,23 @@ struct CudaDTypeParamImpl : DTypeParamImpl {
     }
 };
 
+template <>
+struct CudaDTypeParamImpl<dt_qint1> : DTypeParamImpl<dt_qint1> {
+    float inv_scale;
+    CudaDTypeParamImpl() = default;
+    CudaDTypeParamImpl(float scale)
+            : DTypeParamImpl<dt_qint1>(scale), inv_scale(1.0f / scale) {}
+    CudaDTypeParamImpl(const DTypeParamImpl<dt_qint1>& param)
+            : CudaDTypeParamImpl(param.scale) {}
+
+    __device__ dt_qint1 quantize(float in) const {
+        float v = in * inv_scale;
+        v = roundf(v);
+        v = fmin(fmax(0.f, v), 1.f);
+        return static_cast<dt_qint1>(v);
+    }
+};
+
 #if MEGDNN_CC_CUDA
 static inline MEGDNN_DEVICE void dot_prod(int a, int b, int c, int& d) {
 #if __CUDA_ARCH__ >= 610
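Design note on `CudaDTypeParamImpl<dt_qint1>` above: the reciprocal `inv_scale` is computed once at construction on the host, so the per-element device-side `quantize()` performs a multiply instead of a comparatively slow divide. A plain-C++ stand-in for the same idea (hypothetical names, no CUDA required):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Plain-C++ stand-in for CudaDTypeParamImpl<dt_qint1>: 1/scale is precomputed
// once, so each quantize() (device side in the real code) only multiplies.
struct QuantParamS1 {
    float scale, inv_scale;
    explicit QuantParamS1(float s) : scale(s), inv_scale(1.0f / s) {}

    int8_t quantize(float in) const {
        float v = std::round(in * inv_scale);  // multiply by the reciprocal
        return static_cast<int8_t>(std::fmin(std::fmax(0.f, v), 1.f));
    }
};

int main() {
    QuantParamS1 p(0.5f);
    std::printf("%d %d\n", p.quantize(0.2f), p.quantize(0.9f));  // 0 1
}
```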
diff --git a/dnn/src/fallback/type_cvt/opr_impl.cpp b/dnn/src/fallback/type_cvt/opr_impl.cpp
index 090aeac3155dbdeefbda804eca87af92e4779a93..6f002b417247c433e225b08179643fc3026de512 100644
--- a/dnn/src/fallback/type_cvt/opr_impl.cpp
+++ b/dnn/src/fallback/type_cvt/opr_impl.cpp
@@ -510,7 +510,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     };
     if (src.layout.is_contiguous() && dst.layout.is_contiguous() &&
         !is_quantize_lowbit(src.layout.dtype) &&
-        !is_quantize_lowbit(dst.layout.dtype)) {
+        !is_quantize_lowbit(dst.layout.dtype) &&
+        dst.layout.dtype.enumv() != DTypeEnum::QuantizedS1 &&
+        src.layout.dtype.enumv() != DTypeEnum::QuantizedS1) {
         MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst));
     } else {
         naive::TypeCvtImpl::exec(src, dst);
diff --git a/dnn/src/naive/type_cvt/opr_impl.cpp b/dnn/src/naive/type_cvt/opr_impl.cpp
index 4ab170d047432669ba0e781c2403664f570793a6..ba2f3e42bfd33bfef2f27f6b55037b6eb7b64664 100644
--- a/dnn/src/naive/type_cvt/opr_impl.cpp
+++ b/dnn/src/naive/type_cvt/opr_impl.cpp
@@ -79,8 +79,9 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
         cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
+        cb(::megdnn::dtype::QuantizedS1)
 #undef cb
-        default : megdnn_throw("bad dtype");
+        default : megdnn_throw("bad dtype");
     }
 }
 
@@ -100,8 +101,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
         cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
+        cb(::megdnn::dtype::QuantizedS1)
 #undef cb
-        default : megdnn_throw("bad dtype");
+        default : megdnn_throw("bad dtype");
     }
 }
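The fallback backend above deliberately keeps `QuantizedS1` out of its fast contiguous path and defers to the naive implementation, the same treatment the 4-bit lowbit dtypes receive, because the vectorized `run_contiguous` kernel has no specialization for it. A distilled, hypothetical model of that gate (the real `is_quantize_lowbit` may cover more dtypes):

```cpp
#include <cstdio>

// Hypothetical distilled form of the fallback gate: QuantizedS1 tensors are
// routed to the naive per-element implementation, like the 4-bit dtypes.
enum class DTypeEnum { Float32, QuantizedS8, QuantizedS4, Quantized4Asymm, QuantizedS1 };

bool is_quantize_lowbit(DTypeEnum e) {
    return e == DTypeEnum::QuantizedS4 || e == DTypeEnum::Quantized4Asymm;
}

bool use_fast_contiguous_path(DTypeEnum src, DTypeEnum dst) {
    return !is_quantize_lowbit(src) && !is_quantize_lowbit(dst) &&
           src != DTypeEnum::QuantizedS1 && dst != DTypeEnum::QuantizedS1;
}

int main() {
    std::printf("%d\n", use_fast_contiguous_path(
                                DTypeEnum::Float32, DTypeEnum::QuantizedS8));  // 1
    std::printf("%d\n", use_fast_contiguous_path(
                                DTypeEnum::Float32, DTypeEnum::QuantizedS1));  // 0
}
```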
diff --git a/dnn/test/common/checker.cpp b/dnn/test/common/checker.cpp
index ab519410af3b7294233f4d2f08f01298dae3fd73..66965e9f8707b6f51c467e6a3c1e0207d846ce6b 100644
--- a/dnn/test/common/checker.cpp
+++ b/dnn/test/common/checker.cpp
@@ -79,7 +79,8 @@ template <typename ctype>
         const char* expr0, const char* expr1, const TensorND& v0, const TensorND& v1,
         float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) {
     if (!std::is_same<ctype, dt_qint4>::value &&
-        !std::is_same<ctype, dt_quint4>::value) {
+        !std::is_same<ctype, dt_quint4>::value &&
+        !std::is_same<ctype, dt_qint1>::value) {
         if (v0.layout.is_physical_contiguous() && v1.layout.is_physical_contiguous()) {
             return assert_tensor_eq_with_iter<ctype>(
                     expr0, expr1, v0.ptr<ctype>(), v1.ptr<ctype>(), v0.layout, maxerr,
@@ -158,7 +159,7 @@ void copy_tensors(
         //! In order to avoid an unnecessary increase in binary size, we just
         //! use QuantizedS16 dtype in winograd_filter_preprocess now.
         cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
-        cb(::megdnn::dtype::Uint16)
+        cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::QuantizedS1)
 #undef cb
         default : megdnn_trap();
     }
diff --git a/dnn/test/common/dtype.cpp b/dnn/test/common/dtype.cpp
index 9ba34db40465dfedb78a1d8d307f3b2f246a4cb4..df763ccb1ab177b995c7becef8c232beb187bfe4 100644
--- a/dnn/test/common/dtype.cpp
+++ b/dnn/test/common/dtype.cpp
@@ -71,6 +71,32 @@ TEST(TestDType, TestQuantized8Asymm) {
     EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::Quantized8Asymm));
 }
 
+TEST(TestDType, QuantizedS1) {
+    using namespace megdnn;
+
+    dtype::QuantizedS1 qint1(0.1f);
+    EXPECT_EQ(qint1.size(1), 1u);
+    EXPECT_FLOAT_EQ(qint1.param().scale, 0.1f);
+
+    dtype::QuantizedS1 qint1_copy = qint1;
+    EXPECT_NO_THROW(qint1_copy.assert_is(qint1));
+    EXPECT_FLOAT_EQ(qint1_copy.param().scale, 0.1f);
+
+    dtype::QuantizedS1 qint1_reconstruct_with_same_param(0.1f);
+    EXPECT_NO_THROW(qint1_reconstruct_with_same_param.assert_is(qint1));
+
+    dtype::QuantizedS1 qint1_diff(0.2f);
+    EXPECT_ANY_THROW(qint1_diff.assert_is(qint1));
+
+    DType parent = qint1;
+    ASSERT_NO_THROW(dtype::QuantizedS1::downcast_from(parent));
+    auto param = dtype::QuantizedS1::downcast_from(parent).param();
+    EXPECT_FLOAT_EQ(param.scale, 0.1f);
+
+    EXPECT_ANY_THROW(dtype::QuantizedS1::downcast_from(dtype::IntB1()));
+    EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::QuantizedS1));
+}
+
 TEST(TestDType, TestQuantizedS4) {
     using namespace megdnn;
 
diff --git a/dnn/test/common/rng.cpp b/dnn/test/common/rng.cpp
index bd4f54deb17c17f62b6c3b0d2d602cb23e1c5331..1c921940ebcde7584d5ec33efc38c5301e597fe1 100644
--- a/dnn/test/common/rng.cpp
+++ b/dnn/test/common/rng.cpp
@@ -149,7 +149,7 @@ void IIDRNG::gen(const TensorND& tensor) {
     MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
     //! In order to avoid an unnecessary increase in binary size, we just
     //! use QuantizedS16 dtype in winograd_filter_preprocess now.
-    cb(::megdnn::dtype::QuantizedS16)
+    cb(::megdnn::dtype::QuantizedS16) cb(::megdnn::dtype::QuantizedS1)
 #undef cb
     if (tensor.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
         auto ptr = static_cast<uint8_t*>(tensor.raw_ptr());
diff --git a/dnn/test/common/utils.h b/dnn/test/common/utils.h
index f70418efebfba25620ddd0d7307eb226df3c4a91..2bd376208f6d9679dc553a6cd6cab489cabc4af2 100644
--- a/dnn/test/common/utils.h
+++ b/dnn/test/common/utils.h
@@ -226,6 +226,10 @@ static inline int diff(dt_qint4 x, dt_qint4 y) {
     return x.as_int8() - y.as_int8();
 }
 
+static inline int diff(dt_qint1 x, dt_qint1 y) {
+    return x.as_int8() - y.as_int8();
+}
+
 static inline int diff(dt_quint4 x, dt_quint4 y) {
     return x.as_uint8() - y.as_uint8();
 }
@@ -339,6 +343,10 @@ static inline bool good_float(dt_qint4) {
     return true;
 }
 
+static inline bool good_float(dt_qint1) {
+    return true;
+}
+
 static inline bool good_float(dt_quint4) {
     return true;
 }
@@ -373,6 +381,11 @@ static inline int operator+(dt_qint4 lhs, int rhs) {
     megdnn_assert(rhs == 0, "unexpected rhs");
     return lhs.as_int8();
 }
+
+static inline int operator+(dt_qint1 lhs, int rhs) {
+    megdnn_assert(rhs == 0, "unexpected rhs");
+    return lhs.as_int8();
+}
 }  // namespace test
 
 static inline bool operator==(const TensorLayout& a, const TensorLayout& b) {
diff --git a/dnn/test/cuda/type_cvt.cpp b/dnn/test/cuda/type_cvt.cpp
index 11c07e232b1bc7473c6b07504091c873b7444324..6cd7e2a97a9c3994a7dc993869e8de4e6f62e900 100644
--- a/dnn/test/cuda/type_cvt.cpp
+++ b/dnn/test/cuda/type_cvt.cpp
@@ -77,16 +77,19 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) {
     };
 
     run(dtype::Float32(), dtype::QuantizedS8(3.0f));
+    run(dtype::Float32(), dtype::QuantizedS1(3.0f));
     run(dtype::Float16(), dtype::QuantizedS8(3.0f));
     run(dtype::Int32(), dtype::QuantizedS32(5.0f));
     run(dtype::Int8(), dtype::QuantizedS32(10.0f));
 
     run(dtype::Float32(), dtype::QuantizedS8(2e-3f));
+    run(dtype::Float32(), dtype::QuantizedS1(2e-3f));
     run(dtype::Float16(), dtype::QuantizedS8(1e-3f));
     run(dtype::Int32(), dtype::QuantizedS32(1e-3f));
     run(dtype::Int8(), dtype::QuantizedS32(7e-4f));
 
     run(dtype::QuantizedS8(3.0f), dtype::QuantizedS8(10.0f));
+    run(dtype::QuantizedS1(3.0f), dtype::QuantizedS1(10.0f));
     run(dtype::QuantizedS32(3.0f), dtype::QuantizedS8(10.0f));
     run(dtype::QuantizedS8(3.0f), dtype::QuantizedS32(10.0f));
     run(dtype::QuantizedS32(3.0f), dtype::QuantizedS32(10.0f));
@@ -95,6 +98,7 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) {
     run(dtype::QuantizedS32(2e-3f), dtype::QuantizedS8(9e-4f));
     run(dtype::QuantizedS8(9e-4f), dtype::QuantizedS32(7e-4f));
     run(dtype::QuantizedS32(5e-3f), dtype::QuantizedS32(1e-3f));
+    run(dtype::QuantizedS1(1e-3f), dtype::Float32());
 
     run(dtype::Quantized8Asymm(5.0f, (uint8_t)128), dtype::Float32());
     run(dtype::Quantized8Asymm(5.0f, (uint8_t)124), dtype::Float16());
diff --git a/imperative/python/megengine/core/tensor/dtype.py b/imperative/python/megengine/core/tensor/dtype.py
index 0fccae18b0519f0ec17a715d908e52d68d3c574d..0496d0ab6df11425b54ec1c9b1f49f7d687fb777 100644
--- a/imperative/python/megengine/core/tensor/dtype.py
+++ b/imperative/python/megengine/core/tensor/dtype.py
@@ -94,6 +94,7 @@ _builtin_quant_dtypes = {
     "qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127),
     "quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15),
     "qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7),
+    "qint1": QuantDtypeMeta("qint1", "QuantizedS1", "int8", 0, 1),
     "qint32": QuantDtypeMeta(
         "qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1,
     ),
@@ -192,6 +193,13 @@ def qint4(scale):
     return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None)
 
 
+def qint1(scale):
+    r"""Construct a quantized int1 data type with ``scale`` (float). The real value
+    represented by a qint1 data type is float_val = scale * int1_val
+    """
+    return create_quantized_dtype(_builtin_quant_dtypes["qint1"], scale, None)
+
+
 def _convert_to_quantized_dtype(
     arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta
 ):
@@ -335,3 +343,22 @@ def convert_from_qint4(arr: np.ndarray):
         arr: Input ndarray.
     """
     return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"])
+
+
+def convert_to_qint1(arr: np.ndarray, q: np.dtype):
+    r"""Quantize a float NumPy ndarray into a qint1 one with specified params.
+
+    Args:
+        arr: Input ndarray.
+        q: Target data type, should be a qint1.
+    """
+    return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint1"])
+
+
+def convert_from_qint1(arr: np.ndarray):
+    r"""Dequantize a qint1 NumPy ndarray into a float one.
+
+    Args:
+        arr: Input ndarray.
+    """
+    return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint1"])
diff --git a/imperative/python/src/helper.cpp b/imperative/python/src/helper.cpp
index 94a51f9eb38cc181c64ef9ae6b46cd1f9f0993ae..f9290df560fac92ee645f9b79909a1b9855524bf 100644
--- a/imperative/python/src/helper.cpp
+++ b/imperative/python/src/helper.cpp
@@ -214,6 +214,14 @@ std::unique_ptr dtype_mgb2np_descr(DType dty
     if (dtype.has_param()) {
         PyArray_Descr* type_descr;
         switch (dtype.enumv()) {
+            case DTypeEnum::QuantizedS1: {
+                auto& param = dtype.param<dtype::QuantizedS1>();
+                type_descr = PyArray_DescrNewFromType(NPY_INT8);
+                type_descr->metadata = build_mgb_dtype_dict(
+                        DTypeTrait<dtype::QuantizedS1>::name,
+                        {{"scale", PyFloat_FromDouble(param.scale)}});
+                break;
+            }
             case DTypeEnum::Quantized4Asymm: {
                 auto& param = dtype.param<dtype::Quantized4Asymm>();
                 type_descr = PyArray_DescrNewFromType(NPY_UINT8);
@@ -354,7 +362,7 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
                 static_cast<uint8_t>(zero_point));
     }
     if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" ||
-        dtype_name == "QuantizedS4") {
+        dtype_name == "QuantizedS4" || dtype_name == "QuantizedS1") {
         PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
         mgb_assert(scale_py, "Invalid metadata: missing scale");
         mgb_assert(
@@ -364,8 +372,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
             return dtype::QuantizedS32(scale);
         } else if (dtype_name == "QuantizedS8") {
             return dtype::QuantizedS8(scale);
-        } else {
+        } else if (dtype_name == "QuantizedS4") {
             return dtype::QuantizedS4(scale);
+        } else if (dtype_name == "QuantizedS1") {
+            return dtype::QuantizedS1(scale);
         }
     }
     throw ConversionError(
diff --git a/imperative/python/test/unit/core/test_dtype_quant.py b/imperative/python/test/unit/core/test_dtype_quant.py
index dc58a1d82c7e1f1474ab0673f61b3f0ce5dc6d38..9cec2da9c8dd80d650e6c985df3830047beba844 100644
--- a/imperative/python/test/unit/core/test_dtype_quant.py
+++ b/imperative/python/test/unit/core/test_dtype_quant.py
@@ -15,10 +15,12 @@ import megengine.core.tensor.megbrain_graph as G
 from megengine.core.ops import builtin as ops
 from megengine.core.tensor.dtype import (
     _builtin_quant_dtypes,
+    convert_from_qint1,
     convert_from_qint4,
     convert_from_qint8,
     convert_from_quint4,
     convert_from_quint8,
+    convert_to_qint1,
     convert_to_qint4,
     convert_to_qint8,
     convert_to_quint4,
@@ -26,6 +28,7 @@ from megengine.core.tensor.dtype import (
     get_scale,
     get_zero_point,
     is_quantize,
+    qint1,
     qint4,
     qint8,
     quint4,
@@ -113,9 +116,20 @@ def test_dtype_qint4():
     np.testing.assert_allclose(get_scale(dt), 0.01)
 
 
+def test_dtype_qint1():
+    dt = qint1(0.01)
+    assert isinstance(dt, np.dtype)
+    assert "mgb_dtype" in dt.metadata
+    np.testing.assert_allclose(dt.metadata["mgb_dtype"]["scale"], 0.01)
+
+    assert is_quantize(dt)
+    np.testing.assert_allclose(get_scale(dt), 0.01)
+
+
 @pytest.mark.parametrize(
     "dtype, dtype_name",
     [
+        (qint1(0.01), "qint1"),
         (quint4(0.01, 5), "quint4"),
         (qint4(0.01), "qint4"),
         (quint8(0.01, 135), "quint8"),
@@ -141,6 +155,7 @@ def test_dtype_qint_mgb_ffi_handle(dtype, dtype_name):
 @pytest.mark.parametrize(
     "dtype, dtype_name",
     [
+        (qint1(0.01), "qint1"),
         (quint4(0.01, 5), "quint4"),
         (qint4(0.01), "qint4"),
         (quint8(0.01, 135), "quint8"),
@@ -178,6 +193,7 @@ def test_qint_typecvt(dtype, dtype_name):
 @pytest.mark.parametrize(
     "dtype, dtype_name",
    [
+        (qint1(0.01), "qint1"),
         (quint4(0.01, 5), "quint4"),
         (qint4(0.01), "qint4"),
         (quint8(0.01, 135), "quint8"),
@@ -207,6 +223,7 @@ def test_qint_astype(dtype, dtype_name):
 @pytest.mark.parametrize(
     "dtype, dtype_name",
     [
+        (qint1(0.01), "qint1"),
         (quint4(0.01, 5), "quint4"),
         (qint4(0.01), "qint4"),
         (quint8(0.01, 135), "quint8"),
diff --git a/src/plugin/impl/opr_io_dump.cpp b/src/plugin/impl/opr_io_dump.cpp
index d964dafe344cdf59f10e93bd1a24b6c4c09014c6..e9c5a182ae0559f2d35f1870f0586560b9c6c2c6 100644
--- a/src/plugin/impl/opr_io_dump.cpp
+++ b/src/plugin/impl/opr_io_dump.cpp
@@ -42,6 +42,10 @@ double as_double(megdnn::dt_qint4& a) {
     return static_cast<double>(a.as_int8());
 }
 template <>
+double as_double(megdnn::dt_qint1& a) {
+    return static_cast<double>(a.as_int8());
+}
+template <>
 double as_double(megdnn::dt_qint32& a) {
     return static_cast<double>(a.as_int32());
 }
@@ -111,7 +115,7 @@ void print_host_val(
         MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
-        cb(dtype::Bool)
+        cb(dtype::Bool) cb(::megdnn::dtype::QuantizedS1)
 #undef cb
         default : mgb_throw(
                 MegBrainError,
diff --git a/src/serialization/impl/dtype.fbs b/src/serialization/impl/dtype.fbs
index 6e239c6a6e28e7432ea9f03bdb0429ab3fbd9e49..a387d05c0bd270bdb84c0d6ffeec3cd079f59c25 100644
--- a/src/serialization/impl/dtype.fbs
+++ b/src/serialization/impl/dtype.fbs
@@ -23,6 +23,7 @@ enum DTypeEnum : byte {
     BFloat16,
     Bool,
     Uint16,
+    QuantizedS1,
 }
 
 table LinearQuantizationParam {
diff --git a/src/serialization/impl/flatbuffers_helper.cpp b/src/serialization/impl/flatbuffers_helper.cpp
index d1de002b6647977bee290ca8ea88ac84db194021..1bce1d0207cd6081959d7bbc84f4305b53a44e45 100644
--- a/src/serialization/impl/flatbuffers_helper.cpp
+++ b/src/serialization/impl/flatbuffers_helper.cpp
@@ -55,6 +55,8 @@ megdnn::DType load_dtype(const fbs::DType* dtype) {
         return dtype::_dt{};
         MEGDNN_FOREACH_DTYPE_NAME(cb)
 #undef cb
+        case DTypeEnum_QuantizedS1:
+            return dtype::QuantizedS1{param->scale()};
         case DTypeEnum_QuantizedS4:
             return dtype::QuantizedS4{param->scale()};
         case DTypeEnum_QuantizedS8:
@@ -113,6 +115,7 @@ flatbuffers::Offset build_dtype(
             break;
             CASE_ASYMMETRIC(Quantized4Asymm)
             CASE_ASYMMETRIC(Quantized8Asymm)
+            CASE_SYMMETRIC(QuantizedS1)
             CASE_SYMMETRIC(QuantizedS4)
             CASE_SYMMETRIC(QuantizedS8)
             CASE_SYMMETRIC(QuantizedS16)
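Serialization note on the last two files: `QuantizedS1` is handled by `CASE_SYMMETRIC`, so only `scale` is persisted (no `zero_point`), and it is added at the end of the flatbuffers `DTypeEnum`. The append-only placement plausibly matters because flatbuffers serializes the numeric enum value, so previously dumped models keep decoding correctly only if existing entries never shift. A sketch with hypothetical concrete values, for illustration only:

```cpp
#include <cstdio>

// Why appending to the .fbs enum preserves compatibility: the numeric value is
// what gets serialized. The values below are hypothetical; only the
// append-only rule is the point.
enum class DTypeEnum : signed char {
    // ... earlier entries keep their historical values ...
    BFloat16 = 11,
    Bool = 12,
    Uint16 = 13,
    QuantizedS1 = 14,  // new entry, appended at the end
};

int main() {
    std::printf("QuantizedS1 = %d\n", static_cast<int>(DTypeEnum::QuantizedS1));
}
```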