提交 fd6f8e58 编写于 作者: M Megvii Engine Team

feat(mgb/dtype): add dtype qint1

GitOrigin-RevId: abe9fb68b15aa2a73ffa3e2c44175639d5307a57
上级 616352b0
...@@ -62,7 +62,7 @@ namespace megdnn { ...@@ -62,7 +62,7 @@ namespace megdnn {
#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb) \ #define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb) \
cb(QuantizedS32) cb(QuantizedS8) cb(Quantized4Asymm) cb(QuantizedS4) \ cb(QuantizedS32) cb(QuantizedS8) cb(Quantized4Asymm) cb(QuantizedS4) \
cb(QuantizedS16) cb(QuantizedS16) cb(QuantizedS1)
#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(cb_first, cb_others) \ #define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(cb_first, cb_others) \
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb_first) \ MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb_first) \
...@@ -112,7 +112,7 @@ namespace megdnn { ...@@ -112,7 +112,7 @@ namespace megdnn {
#define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \ #define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \
cb(::megdnn::dtype::QuantizedS32) cb(::megdnn::dtype::QuantizedS8) \ cb(::megdnn::dtype::QuantizedS32) cb(::megdnn::dtype::QuantizedS8) \
cb(::megdnn::dtype::QuantizedS4) cb(::megdnn::dtype::QuantizedS4) cb(::megdnn::dtype::QuantizedS1)
#define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \ #define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \
cb(::megdnn::dtype::Quantized8Asymm) cb(::megdnn::dtype::Quantized4Asymm) cb(::megdnn::dtype::Quantized8Asymm) cb(::megdnn::dtype::Quantized4Asymm)
...@@ -292,10 +292,27 @@ public: ...@@ -292,10 +292,27 @@ public:
}; };
using dt_qint4 = dt_qlowbit<4>; using dt_qint4 = dt_qlowbit<4>;
class dt_qint1 {
int8_t _;
public:
MEGDNN_DEVICE int8_t as_int8() const { return _; }
MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint1(int8_t val) : _(val) {}
#ifdef MEGDNN_CC_HOST
explicit operator int8_t() { return _; }
#endif
bool operator<(const dt_qint1& b) const { return _ < b._; }
bool operator>(const dt_qint1& b) const { return _ > b._; }
bool operator==(const dt_qint1& b) const { return _ == b._; }
bool operator!=(const dt_qint1& b) const { return _ != b._; }
} MEGDNN_PACKED;
#ifdef __clang__ #ifdef __clang__
#pragma clang diagnostic pop #pragma clang diagnostic pop
#endif #endif
MEGDNN_STATIC_ASSERT(sizeof(dt_byte) == 1, "bad dt_byte size"); MEGDNN_STATIC_ASSERT(sizeof(dt_byte) == 1, "bad dt_byte size");
MEGDNN_STATIC_ASSERT(sizeof(dt_qint1) == 1, "bad dt_qint1 size");
MEGDNN_STATIC_ASSERT(sizeof(dt_quint8) == 1, "bad dt_quint8 size"); MEGDNN_STATIC_ASSERT(sizeof(dt_quint8) == 1, "bad dt_quint8 size");
MEGDNN_STATIC_ASSERT(sizeof(dt_qint16) == 2, "bad dt_qint16 size"); MEGDNN_STATIC_ASSERT(sizeof(dt_qint16) == 2, "bad dt_qint16 size");
MEGDNN_STATIC_ASSERT(sizeof(dt_qint32) == 4, "bad dt_qint32 size"); MEGDNN_STATIC_ASSERT(sizeof(dt_qint32) == 4, "bad dt_qint32 size");
...@@ -677,7 +694,7 @@ MEGDNN_FOREACH_LOWBIT_DTYPE(MEGDNN_DEF_FRACTION_DT) ...@@ -677,7 +694,7 @@ MEGDNN_FOREACH_LOWBIT_DTYPE(MEGDNN_DEF_FRACTION_DT)
return static_cast<_itype>(_maxval); \ return static_cast<_itype>(_maxval); \
} \ } \
}; };
MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS1, dt_qint1, int8_t, QUANTIZED, SIGNED, 0, 1, 0);
MEGDNN_DEF_PARAMETERIZED_DT( MEGDNN_DEF_PARAMETERIZED_DT(
Quantized4Asymm, dt_quint4, uint8_t, QUANTIZED, SIGNED, 0, 15, 4); Quantized4Asymm, dt_quint4, uint8_t, QUANTIZED, SIGNED, 0, 15, 4);
MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS4, dt_qint4, int8_t, QUANTIZED, SIGNED, -8, 7, 4); MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS4, dt_qint4, int8_t, QUANTIZED, SIGNED, -8, 7, 4);
...@@ -876,6 +893,26 @@ struct DTypeParamImpl<dt_quint4> { ...@@ -876,6 +893,26 @@ struct DTypeParamImpl<dt_quint4> {
} }
}; };
template <>
struct DTypeParamImpl<dt_qint1> {
float scale;
DTypeParamImpl<dt_qint1>() = default;
MGE_WIN_DECLSPEC_FUC DTypeParamImpl<dt_qint1>(float scale);
#ifdef MEGDNN_CC_HOST
std::size_t hash() const;
#endif
bool operator==(const DTypeParam<dt_qint1>& rhs) const;
MEGDNN_DEVICE dt_qint1 quantize(float in) const {
float v = in / scale;
v = roundf(v);
v = fmin(fmax(0.f, v), 1.f);
return static_cast<dt_qint1>(v);
}
MEGDNN_DEVICE float dequantize(int8_t in) const { return in * scale; }
MEGDNN_DEVICE float dequantize(dt_qint1 in) const { return in.as_int8() * scale; }
};
template <> template <>
struct DTypeParamImpl<dt_qint4> { struct DTypeParamImpl<dt_qint4> {
float scale; float scale;
......
...@@ -142,6 +142,19 @@ inline bool DTypeParam<dt_qint32>::operator==(const DTypeParam<dt_qint32>& rhs) ...@@ -142,6 +142,19 @@ inline bool DTypeParam<dt_qint32>::operator==(const DTypeParam<dt_qint32>& rhs)
return scale == rhs.scale; return scale == rhs.scale;
} }
DTypeParam<dt_qint1>::DTypeParamImpl(float scale) : scale{scale} {
//! As the nan is not equal to any value
megdnn_assert(!std::isnan(scale), "nan number compare is not support");
}
inline std::size_t DTypeParam<dt_qint1>::hash() const {
return std::hash<float>()(scale);
}
inline bool DTypeParam<dt_qint1>::operator==(const DTypeParam<dt_qint1>& rhs) const {
return scale == rhs.scale;
}
DTypeParam<dt_quint4>::DTypeParamImpl(float scale, uint8_t zero_point) DTypeParam<dt_quint4>::DTypeParamImpl(float scale, uint8_t zero_point)
: scale{scale}, zero_point{zero_point} { : scale{scale}, zero_point{zero_point} {
//! As the nan is not equal to any value //! As the nan is not equal to any value
......
...@@ -241,6 +241,7 @@ float megdnn::mul_scale(DType lhs, DType rhs) { ...@@ -241,6 +241,7 @@ float megdnn::mul_scale(DType lhs, DType rhs) {
return lhs.param<dt>().scale * rhs.param<dt>().scale; return lhs.param<dt>().scale * rhs.param<dt>().scale;
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::QuantizedS1)
#undef cb #undef cb
megdnn_assert_internal(0); megdnn_assert_internal(0);
} }
...@@ -253,8 +254,9 @@ float megdnn::get_scale(DType dt) { ...@@ -253,8 +254,9 @@ float megdnn::get_scale(DType dt) {
return dt.param<_dt>().scale; return dt.param<_dt>().scale;
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::QuantizedS1)
#undef cb #undef cb
megdnn_assert_internal(0); megdnn_assert_internal(0);
} }
bool megdnn::dtype_almost_equal(DType lhs, DType rhs) { bool megdnn::dtype_almost_equal(DType lhs, DType rhs) {
......
...@@ -160,6 +160,9 @@ INST_FOR_CTYPE ...@@ -160,6 +160,9 @@ INST_FOR_CTYPE
#define ct dt_bool #define ct dt_bool
INST_FOR_CTYPE INST_FOR_CTYPE
#undef ct #undef ct
#define ct dt_qint1
INST_FOR_CTYPE
#undef ct
#undef INST_FOR_CTYPE #undef INST_FOR_CTYPE
#undef INST #undef INST
...@@ -210,6 +213,9 @@ INST_FOR_CTYPE ...@@ -210,6 +213,9 @@ INST_FOR_CTYPE
#define ct dt_bool #define ct dt_bool
INST_FOR_CTYPE INST_FOR_CTYPE
#undef ct #undef ct
#define ct dt_qint1
INST_FOR_CTYPE
#undef ct
#undef ndim_cb #undef ndim_cb
...@@ -221,6 +227,7 @@ INST(dt_int8); ...@@ -221,6 +227,7 @@ INST(dt_int8);
INST(dt_uint8); INST(dt_uint8);
INST(dt_bool); INST(dt_bool);
INST(dt_qint8); INST(dt_qint8);
INST(dt_qint1);
INST(dt_quint8); INST(dt_quint8);
#undef dt_ibyte #undef dt_ibyte
......
...@@ -96,6 +96,7 @@ INST(dt_bool, uchar4); ...@@ -96,6 +96,7 @@ INST(dt_bool, uchar4);
#undef as_raw #undef as_raw
#define as_raw(x) x.as_int8() #define as_raw(x) x.as_int8()
INST(dt_qint8, char4); INST(dt_qint8, char4);
INST(dt_qint1, char4);
#undef as_raw #undef as_raw
#define as_raw(x) x.as_uint8() #define as_raw(x) x.as_uint8()
INST(dt_quint8, uchar4); INST(dt_quint8, uchar4);
...@@ -466,6 +467,7 @@ INST_PARAM_VECT_VISITOR; ...@@ -466,6 +467,7 @@ INST_PARAM_VECT_VISITOR;
INST_DT_IBYTE(dt_int8); INST_DT_IBYTE(dt_int8);
INST_DT_IBYTE(dt_uint8); INST_DT_IBYTE(dt_uint8);
INST_DT_IBYTE(dt_qint8); INST_DT_IBYTE(dt_qint8);
INST_DT_IBYTE(dt_qint1);
INST_DT_IBYTE(dt_quint8); INST_DT_IBYTE(dt_quint8);
INST_DT_IBYTE(dt_bool); INST_DT_IBYTE(dt_bool);
#undef INST_DT_IBYTE #undef INST_DT_IBYTE
...@@ -1299,6 +1301,7 @@ private: ...@@ -1299,6 +1301,7 @@ private:
INST_DT_IBYTE(dt_int8); INST_DT_IBYTE(dt_int8);
INST_DT_IBYTE(dt_uint8); INST_DT_IBYTE(dt_uint8);
INST_DT_IBYTE(dt_qint8); INST_DT_IBYTE(dt_qint8);
INST_DT_IBYTE(dt_qint1);
INST_DT_IBYTE(dt_quint8); INST_DT_IBYTE(dt_quint8);
INST_DT_IBYTE(dt_bool); INST_DT_IBYTE(dt_bool);
#undef INST_DT_IBYTE #undef INST_DT_IBYTE
...@@ -1649,6 +1652,7 @@ public: ...@@ -1649,6 +1652,7 @@ public:
INST_DT_IBYTE(dt_int8); INST_DT_IBYTE(dt_int8);
INST_DT_IBYTE(dt_uint8); INST_DT_IBYTE(dt_uint8);
INST_DT_IBYTE(dt_qint8); INST_DT_IBYTE(dt_qint8);
INST_DT_IBYTE(dt_qint1);
INST_DT_IBYTE(dt_quint8); INST_DT_IBYTE(dt_quint8);
INST_DT_IBYTE(dt_bool); INST_DT_IBYTE(dt_bool);
#undef INST_DT_IBYTE #undef INST_DT_IBYTE
......
...@@ -88,6 +88,7 @@ struct TypeCvtOpToQuantized< ...@@ -88,6 +88,7 @@ struct TypeCvtOpToQuantized<
typename std::enable_if< typename std::enable_if<
std::is_same<ctype_src, dt_int8>::value || std::is_same<ctype_src, dt_int8>::value ||
std::is_same<ctype_src, dt_uint8>::value || std::is_same<ctype_src, dt_uint8>::value ||
std::is_same<ctype_src, dt_qint1>::value ||
std::is_same<ctype_src, dt_bool>::value>::type> { std::is_same<ctype_src, dt_bool>::value>::type> {
ctype_dest* dest; ctype_dest* dest;
CudaDTypeParam<ctype_dest> param; CudaDTypeParam<ctype_dest> param;
...@@ -111,6 +112,7 @@ struct TypeCvtOpFromQuantized< ...@@ -111,6 +112,7 @@ struct TypeCvtOpFromQuantized<
ctype_dest, ctype_src, ctype_dest, ctype_src,
typename std::enable_if< typename std::enable_if<
std::is_same<ctype_src, dt_qint8>::value || std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_qint1>::value ||
std::is_same<ctype_src, dt_quint8>::value>::type> { std::is_same<ctype_src, dt_quint8>::value>::type> {
ctype_dest* dest; ctype_dest* dest;
CudaDTypeParam<ctype_src> param; CudaDTypeParam<ctype_src> param;
...@@ -134,7 +136,8 @@ struct TypeCvtOpBetweenQuantized< ...@@ -134,7 +136,8 @@ struct TypeCvtOpBetweenQuantized<
ctype_dest, ctype_src, ctype_dest, ctype_src,
typename std::enable_if< typename std::enable_if<
(std::is_same<ctype_src, dt_qint8>::value || (std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_quint8>::value) && std::is_same<ctype_src, dt_quint8>::value ||
std::is_same<ctype_src, dt_qint1>::value) &&
IsNotTypeQ4<ctype_dest>::value>::type> { IsNotTypeQ4<ctype_dest>::value>::type> {
ctype_dest* dest; ctype_dest* dest;
CudaDTypeParam<ctype_src> src_param; CudaDTypeParam<ctype_src> src_param;
...@@ -306,6 +309,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st ...@@ -306,6 +309,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
cb(dtype_src, dt_quint8) \ cb(dtype_src, dt_quint8) \
cb(dtype_src, dt_qint32) \ cb(dtype_src, dt_qint32) \
cb(dtype_src, dt_qint8) \ cb(dtype_src, dt_qint8) \
cb(dtype_src, dt_qint1) \
#define INST_SRC_QUANTIZED(dtype_src) \ #define INST_SRC_QUANTIZED(dtype_src) \
MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \ MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \
...@@ -330,7 +334,8 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st ...@@ -330,7 +334,8 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
cb(dt_qint32) \ cb(dt_qint32) \
cb(dt_qint8) \ cb(dt_qint8) \
cb(dt_qint4) \ cb(dt_qint4) \
cb(dt_quint4) cb(dt_quint4) \
cb(dt_qint1)
MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED) MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED)
MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL) MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL)
......
...@@ -50,6 +50,7 @@ void exec_src_quantized( ...@@ -50,6 +50,7 @@ void exec_src_quantized(
return; \ return; \
} }
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb); MEGDNN_FOREACH_QUANTIZED_DTYPE(cb);
cb(::megdnn::dtype::QuantizedS1);
default: default:
megdnn_assert_internal(0); megdnn_assert_internal(0);
#undef cb #undef cb
...@@ -101,6 +102,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre ...@@ -101,6 +102,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre
return; \ return; \
} }
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb); MEGDNN_FOREACH_QUANTIZED_DTYPE(cb);
cb(::megdnn::dtype::QuantizedS1);
#undef cb #undef cb
default: default:
megdnn_assert_internal(0); megdnn_assert_internal(0);
...@@ -150,9 +152,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -150,9 +152,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
} }
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::QuantizedS1)
#undef cb #undef cb
default: default : megdnn_assert_internal(0);
megdnn_assert_internal(0);
} }
} }
} }
......
...@@ -241,6 +241,23 @@ struct CudaDTypeParamImpl<dt_qint4> : DTypeParamImpl<dt_qint4> { ...@@ -241,6 +241,23 @@ struct CudaDTypeParamImpl<dt_qint4> : DTypeParamImpl<dt_qint4> {
} }
}; };
template <>
struct CudaDTypeParamImpl<dt_qint1> : DTypeParamImpl<dt_qint1> {
float inv_scale;
CudaDTypeParamImpl() = default;
CudaDTypeParamImpl(float scale)
: DTypeParamImpl<dt_qint1>(scale), inv_scale(1.0f / scale) {}
CudaDTypeParamImpl(const DTypeParamImpl<dt_qint1>& param)
: CudaDTypeParamImpl(param.scale) {}
__device__ dt_qint1 quantize(float in) const {
float v = in * inv_scale;
v = roundf(v);
v = fmin(fmax(0.f, v), 1.f);
return static_cast<dt_qint1>(v);
}
};
#if MEGDNN_CC_CUDA #if MEGDNN_CC_CUDA
static inline MEGDNN_DEVICE void dot_prod(int a, int b, int c, int& d) { static inline MEGDNN_DEVICE void dot_prod(int a, int b, int c, int& d) {
#if __CUDA_ARCH__ >= 610 #if __CUDA_ARCH__ >= 610
......
...@@ -510,7 +510,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -510,7 +510,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
}; };
if (src.layout.is_contiguous() && dst.layout.is_contiguous() && if (src.layout.is_contiguous() && dst.layout.is_contiguous() &&
!is_quantize_lowbit(src.layout.dtype) && !is_quantize_lowbit(src.layout.dtype) &&
!is_quantize_lowbit(dst.layout.dtype)) { !is_quantize_lowbit(dst.layout.dtype) &&
dst.layout.dtype.enumv() != DTypeEnum::QuantizedS1 &&
src.layout.dtype.enumv() != DTypeEnum::QuantizedS1) {
MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst));
} else { } else {
naive::TypeCvtImpl::exec(src, dst); naive::TypeCvtImpl::exec(src, dst);
......
...@@ -79,8 +79,9 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src ...@@ -79,8 +79,9 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
cb(::megdnn::dtype::QuantizedS1)
#undef cb #undef cb
default : megdnn_throw("bad dtype"); default : megdnn_throw("bad dtype");
} }
} }
...@@ -100,8 +101,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -100,8 +101,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
cb(::megdnn::dtype::QuantizedS1)
#undef cb #undef cb
default : megdnn_throw("bad dtype"); default : megdnn_throw("bad dtype");
} }
} }
......
...@@ -79,7 +79,8 @@ template <typename ctype> ...@@ -79,7 +79,8 @@ template <typename ctype>
const char* expr0, const char* expr1, const TensorND& v0, const TensorND& v1, const char* expr0, const char* expr1, const TensorND& v0, const TensorND& v1,
float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) {
if (!std::is_same<ctype, dt_qint4>::value && if (!std::is_same<ctype, dt_qint4>::value &&
!std::is_same<ctype, dt_quint4>::value) { !std::is_same<ctype, dt_quint4>::value &&
!std::is_same<ctype, dt_qint1>::value) {
if (v0.layout.is_physical_contiguous() && v1.layout.is_physical_contiguous()) { if (v0.layout.is_physical_contiguous() && v1.layout.is_physical_contiguous()) {
return assert_tensor_eq_with_iter<ctype>( return assert_tensor_eq_with_iter<ctype>(
expr0, expr1, v0.ptr<ctype>(), v1.ptr<ctype>(), v0.layout, maxerr, expr0, expr1, v0.ptr<ctype>(), v1.ptr<ctype>(), v0.layout, maxerr,
...@@ -158,7 +159,7 @@ void copy_tensors( ...@@ -158,7 +159,7 @@ void copy_tensors(
//! In order to avoid an unnecessary increase in binary size, we just //! In order to avoid an unnecessary increase in binary size, we just
//! use QuantizedS16 dtype in winograd_filter_preprocess now. //! use QuantizedS16 dtype in winograd_filter_preprocess now.
cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::QuantizedS1)
#undef cb #undef cb
default : megdnn_trap(); default : megdnn_trap();
} }
......
...@@ -71,6 +71,32 @@ TEST(TestDType, TestQuantized8Asymm) { ...@@ -71,6 +71,32 @@ TEST(TestDType, TestQuantized8Asymm) {
EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::Quantized8Asymm)); EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::Quantized8Asymm));
} }
TEST(TestDType, QuantizedS1) {
using namespace megdnn;
dtype::QuantizedS1 qint1(0.1f);
EXPECT_EQ(qint1.size(1), 1u);
EXPECT_FLOAT_EQ(qint1.param().scale, 0.1f);
dtype::QuantizedS1 qint1_copy = qint1;
EXPECT_NO_THROW(qint1_copy.assert_is(qint1));
EXPECT_FLOAT_EQ(qint1_copy.param().scale, 0.1f);
dtype::QuantizedS1 qint1_reconstruct_with_same_param(0.1f);
EXPECT_NO_THROW(qint1_reconstruct_with_same_param.assert_is(qint1));
dtype::QuantizedS1 qint1_diff(0.2f);
EXPECT_ANY_THROW(qint1_diff.assert_is(qint1));
DType parent = qint1;
ASSERT_NO_THROW(dtype::QuantizedS1::downcast_from(parent));
auto param = dtype::QuantizedS1::downcast_from(parent).param();
EXPECT_FLOAT_EQ(param.scale, 0.1f);
EXPECT_ANY_THROW(dtype::QuantizedS1::downcast_from(dtype::IntB1()));
EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::QuantizedS1));
}
TEST(TestDType, TestQuantizedS4) { TEST(TestDType, TestQuantizedS4) {
using namespace megdnn; using namespace megdnn;
......
...@@ -149,7 +149,7 @@ void IIDRNG::gen(const TensorND& tensor) { ...@@ -149,7 +149,7 @@ void IIDRNG::gen(const TensorND& tensor) {
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
//! In order to avoid an unnecessary increase in binary size, we just //! In order to avoid an unnecessary increase in binary size, we just
//! use QuantizedS16 dtype in winograd_filter_preprocess now. //! use QuantizedS16 dtype in winograd_filter_preprocess now.
cb(::megdnn::dtype::QuantizedS16) cb(::megdnn::dtype::QuantizedS16) cb(::megdnn::dtype::QuantizedS1)
#undef cb #undef cb
if (tensor.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) { if (tensor.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
auto ptr = static_cast<uint8_t*>(tensor.raw_ptr()); auto ptr = static_cast<uint8_t*>(tensor.raw_ptr());
......
...@@ -226,6 +226,10 @@ static inline int diff(dt_qint4 x, dt_qint4 y) { ...@@ -226,6 +226,10 @@ static inline int diff(dt_qint4 x, dt_qint4 y) {
return x.as_int8() - y.as_int8(); return x.as_int8() - y.as_int8();
} }
static inline int diff(dt_qint1 x, dt_qint1 y) {
return x.as_int8() - y.as_int8();
}
static inline int diff(dt_quint4 x, dt_quint4 y) { static inline int diff(dt_quint4 x, dt_quint4 y) {
return x.as_uint8() - y.as_uint8(); return x.as_uint8() - y.as_uint8();
} }
...@@ -339,6 +343,10 @@ static inline bool good_float(dt_qint4) { ...@@ -339,6 +343,10 @@ static inline bool good_float(dt_qint4) {
return true; return true;
} }
static inline bool good_float(dt_qint1) {
return true;
}
static inline bool good_float(dt_quint4) { static inline bool good_float(dt_quint4) {
return true; return true;
} }
...@@ -373,6 +381,11 @@ static inline int operator+(dt_qint4 lhs, int rhs) { ...@@ -373,6 +381,11 @@ static inline int operator+(dt_qint4 lhs, int rhs) {
megdnn_assert(rhs == 0, "unexpected rhs"); megdnn_assert(rhs == 0, "unexpected rhs");
return lhs.as_int8(); return lhs.as_int8();
} }
static inline int operator+(dt_qint1 lhs, int rhs) {
megdnn_assert(rhs == 0, "unexpected rhs");
return lhs.as_int8();
}
} // namespace test } // namespace test
static inline bool operator==(const TensorLayout& a, const TensorLayout& b) { static inline bool operator==(const TensorLayout& a, const TensorLayout& b) {
......
...@@ -77,16 +77,19 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) { ...@@ -77,16 +77,19 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) {
}; };
run(dtype::Float32(), dtype::QuantizedS8(3.0f)); run(dtype::Float32(), dtype::QuantizedS8(3.0f));
run(dtype::Float32(), dtype::QuantizedS1(3.0f));
run(dtype::Float16(), dtype::QuantizedS8(3.0f)); run(dtype::Float16(), dtype::QuantizedS8(3.0f));
run(dtype::Int32(), dtype::QuantizedS32(5.0f)); run(dtype::Int32(), dtype::QuantizedS32(5.0f));
run(dtype::Int8(), dtype::QuantizedS32(10.0f)); run(dtype::Int8(), dtype::QuantizedS32(10.0f));
run(dtype::Float32(), dtype::QuantizedS8(2e-3f)); run(dtype::Float32(), dtype::QuantizedS8(2e-3f));
run(dtype::Float32(), dtype::QuantizedS1(2e-3f));
run(dtype::Float16(), dtype::QuantizedS8(1e-3f)); run(dtype::Float16(), dtype::QuantizedS8(1e-3f));
run(dtype::Int32(), dtype::QuantizedS32(1e-3f)); run(dtype::Int32(), dtype::QuantizedS32(1e-3f));
run(dtype::Int8(), dtype::QuantizedS32(7e-4f)); run(dtype::Int8(), dtype::QuantizedS32(7e-4f));
run(dtype::QuantizedS8(3.0f), dtype::QuantizedS8(10.0f)); run(dtype::QuantizedS8(3.0f), dtype::QuantizedS8(10.0f));
run(dtype::QuantizedS1(3.0f), dtype::QuantizedS1(10.0f));
run(dtype::QuantizedS32(3.0f), dtype::QuantizedS8(10.0f)); run(dtype::QuantizedS32(3.0f), dtype::QuantizedS8(10.0f));
run(dtype::QuantizedS8(3.0f), dtype::QuantizedS32(10.0f)); run(dtype::QuantizedS8(3.0f), dtype::QuantizedS32(10.0f));
run(dtype::QuantizedS32(3.0f), dtype::QuantizedS32(10.0f)); run(dtype::QuantizedS32(3.0f), dtype::QuantizedS32(10.0f));
...@@ -95,6 +98,7 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) { ...@@ -95,6 +98,7 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) {
run(dtype::QuantizedS32(2e-3f), dtype::QuantizedS8(9e-4f)); run(dtype::QuantizedS32(2e-3f), dtype::QuantizedS8(9e-4f));
run(dtype::QuantizedS8(9e-4f), dtype::QuantizedS32(7e-4f)); run(dtype::QuantizedS8(9e-4f), dtype::QuantizedS32(7e-4f));
run(dtype::QuantizedS32(5e-3f), dtype::QuantizedS32(1e-3f)); run(dtype::QuantizedS32(5e-3f), dtype::QuantizedS32(1e-3f));
run(dtype::QuantizedS1(1e-3f), dtype::Float32());
run(dtype::Quantized8Asymm(5.0f, (uint8_t)128), dtype::Float32()); run(dtype::Quantized8Asymm(5.0f, (uint8_t)128), dtype::Float32());
run(dtype::Quantized8Asymm(5.0f, (uint8_t)124), dtype::Float16()); run(dtype::Quantized8Asymm(5.0f, (uint8_t)124), dtype::Float16());
......
...@@ -94,6 +94,7 @@ _builtin_quant_dtypes = { ...@@ -94,6 +94,7 @@ _builtin_quant_dtypes = {
"qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127), "qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127),
"quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15), "quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15),
"qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7), "qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7),
"qint1": QuantDtypeMeta("qint1", "QuantizedS1", "int8", 0, 1),
"qint32": QuantDtypeMeta( "qint32": QuantDtypeMeta(
"qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1, "qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1,
), ),
...@@ -192,6 +193,13 @@ def qint4(scale): ...@@ -192,6 +193,13 @@ def qint4(scale):
return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None) return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None)
def qint1(scale):
r"""Construct a quantized int1 data type with ``scale`` (float). The real value
represented by a qint1 data type is float_val = scale * int1_val
"""
return create_quantized_dtype(_builtin_quant_dtypes["qint1"], scale, None)
def _convert_to_quantized_dtype( def _convert_to_quantized_dtype(
arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta
): ):
...@@ -335,3 +343,22 @@ def convert_from_qint4(arr: np.ndarray): ...@@ -335,3 +343,22 @@ def convert_from_qint4(arr: np.ndarray):
arr: Input ndarray. arr: Input ndarray.
""" """
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"]) return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"])
def convert_to_qint1(arr: np.ndarray, q: np.dtype):
r"""Quantize a float NumPy ndarray into a qint1 one with specified params.
Args:
arr: Input ndarray.
q: Target data type, should be a qint1.
"""
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint1"])
def convert_from_qint1(arr: np.ndarray):
r"""Dequantize a qint1 NumPy ndarray into a float one.
Args:
arr: Input ndarray.
"""
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint1"])
...@@ -214,6 +214,14 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(DType dty ...@@ -214,6 +214,14 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(DType dty
if (dtype.has_param()) { if (dtype.has_param()) {
PyArray_Descr* type_descr; PyArray_Descr* type_descr;
switch (dtype.enumv()) { switch (dtype.enumv()) {
case DTypeEnum::QuantizedS1: {
auto& param = dtype.param<dtype::QuantizedS1>();
type_descr = PyArray_DescrNewFromType(NPY_INT8);
type_descr->metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::QuantizedS1>::name,
{{"scale", PyFloat_FromDouble(param.scale)}});
break;
}
case DTypeEnum::Quantized4Asymm: { case DTypeEnum::Quantized4Asymm: {
auto& param = dtype.param<dtype::Quantized4Asymm>(); auto& param = dtype.param<dtype::Quantized4Asymm>();
type_descr = PyArray_DescrNewFromType(NPY_UINT8); type_descr = PyArray_DescrNewFromType(NPY_UINT8);
...@@ -354,7 +362,7 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { ...@@ -354,7 +362,7 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
static_cast<uint8_t>(zero_point)); static_cast<uint8_t>(zero_point));
} }
if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" || if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" ||
dtype_name == "QuantizedS4") { dtype_name == "QuantizedS4" || dtype_name == "QuantizedS1") {
PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
mgb_assert(scale_py, "Invalid metadata: missing scale"); mgb_assert(scale_py, "Invalid metadata: missing scale");
mgb_assert( mgb_assert(
...@@ -364,8 +372,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { ...@@ -364,8 +372,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
return dtype::QuantizedS32(scale); return dtype::QuantizedS32(scale);
} else if (dtype_name == "QuantizedS8") { } else if (dtype_name == "QuantizedS8") {
return dtype::QuantizedS8(scale); return dtype::QuantizedS8(scale);
} else { } else if (dtype_name == "QuantizedS4") {
return dtype::QuantizedS4(scale); return dtype::QuantizedS4(scale);
} else if (dtype_name == "QuantizedS1") {
return dtype::QuantizedS1(scale);
} }
} }
throw ConversionError( throw ConversionError(
......
...@@ -15,10 +15,12 @@ import megengine.core.tensor.megbrain_graph as G ...@@ -15,10 +15,12 @@ import megengine.core.tensor.megbrain_graph as G
from megengine.core.ops import builtin as ops from megengine.core.ops import builtin as ops
from megengine.core.tensor.dtype import ( from megengine.core.tensor.dtype import (
_builtin_quant_dtypes, _builtin_quant_dtypes,
convert_from_qint1,
convert_from_qint4, convert_from_qint4,
convert_from_qint8, convert_from_qint8,
convert_from_quint4, convert_from_quint4,
convert_from_quint8, convert_from_quint8,
convert_to_qint1,
convert_to_qint4, convert_to_qint4,
convert_to_qint8, convert_to_qint8,
convert_to_quint4, convert_to_quint4,
...@@ -26,6 +28,7 @@ from megengine.core.tensor.dtype import ( ...@@ -26,6 +28,7 @@ from megengine.core.tensor.dtype import (
get_scale, get_scale,
get_zero_point, get_zero_point,
is_quantize, is_quantize,
qint1,
qint4, qint4,
qint8, qint8,
quint4, quint4,
...@@ -113,9 +116,20 @@ def test_dtype_qint4(): ...@@ -113,9 +116,20 @@ def test_dtype_qint4():
np.testing.assert_allclose(get_scale(dt), 0.01) np.testing.assert_allclose(get_scale(dt), 0.01)
def test_dtype_qint1():
dt = qint1(0.01)
assert isinstance(dt, np.dtype)
assert "mgb_dtype" in dt.metadata
np.testing.assert_allclose(dt.metadata["mgb_dtype"]["scale"], 0.01)
assert is_quantize(dt)
np.testing.assert_allclose(get_scale(dt), 0.01)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dtype, dtype_name", "dtype, dtype_name",
[ [
(qint1(0.01), "qint1"),
(quint4(0.01, 5), "quint4"), (quint4(0.01, 5), "quint4"),
(qint4(0.01), "qint4"), (qint4(0.01), "qint4"),
(quint8(0.01, 135), "quint8"), (quint8(0.01, 135), "quint8"),
...@@ -141,6 +155,7 @@ def test_dtype_qint_mgb_ffi_handle(dtype, dtype_name): ...@@ -141,6 +155,7 @@ def test_dtype_qint_mgb_ffi_handle(dtype, dtype_name):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dtype, dtype_name", "dtype, dtype_name",
[ [
(qint1(0.01), "qint1"),
(quint4(0.01, 5), "quint4"), (quint4(0.01, 5), "quint4"),
(qint4(0.01), "qint4"), (qint4(0.01), "qint4"),
(quint8(0.01, 135), "quint8"), (quint8(0.01, 135), "quint8"),
...@@ -178,6 +193,7 @@ def test_qint_typecvt(dtype, dtype_name): ...@@ -178,6 +193,7 @@ def test_qint_typecvt(dtype, dtype_name):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dtype, dtype_name", "dtype, dtype_name",
[ [
(qint1(0.01), "qint1"),
(quint4(0.01, 5), "quint4"), (quint4(0.01, 5), "quint4"),
(qint4(0.01), "qint4"), (qint4(0.01), "qint4"),
(quint8(0.01, 135), "quint8"), (quint8(0.01, 135), "quint8"),
...@@ -207,6 +223,7 @@ def test_qint_astype(dtype, dtype_name): ...@@ -207,6 +223,7 @@ def test_qint_astype(dtype, dtype_name):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dtype, dtype_name", "dtype, dtype_name",
[ [
(qint1(0.01), "qint1"),
(quint4(0.01, 5), "quint4"), (quint4(0.01, 5), "quint4"),
(qint4(0.01), "qint4"), (qint4(0.01), "qint4"),
(quint8(0.01, 135), "quint8"), (quint8(0.01, 135), "quint8"),
......
...@@ -42,6 +42,10 @@ double as_double(megdnn::dt_qint4& a) { ...@@ -42,6 +42,10 @@ double as_double(megdnn::dt_qint4& a) {
return static_cast<double>(a.as_int8()); return static_cast<double>(a.as_int8());
} }
template <> template <>
double as_double(megdnn::dt_qint1& a) {
return static_cast<double>(a.as_int8());
}
template <>
double as_double(megdnn::dt_qint32& a) { double as_double(megdnn::dt_qint32& a) {
return static_cast<double>(a.as_int32()); return static_cast<double>(a.as_int32());
} }
...@@ -111,7 +115,7 @@ void print_host_val( ...@@ -111,7 +115,7 @@ void print_host_val(
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(dtype::Bool) cb(dtype::Bool) cb(::megdnn::dtype::QuantizedS1)
#undef cb #undef cb
default : mgb_throw( default : mgb_throw(
MegBrainError, MegBrainError,
......
...@@ -23,6 +23,7 @@ enum DTypeEnum : byte { ...@@ -23,6 +23,7 @@ enum DTypeEnum : byte {
BFloat16, BFloat16,
Bool, Bool,
Uint16, Uint16,
QuantizedS1,
} }
table LinearQuantizationParam { table LinearQuantizationParam {
......
...@@ -55,6 +55,8 @@ megdnn::DType load_dtype(const fbs::DType* dtype) { ...@@ -55,6 +55,8 @@ megdnn::DType load_dtype(const fbs::DType* dtype) {
return dtype::_dt{}; return dtype::_dt{};
MEGDNN_FOREACH_DTYPE_NAME(cb) MEGDNN_FOREACH_DTYPE_NAME(cb)
#undef cb #undef cb
case DTypeEnum_QuantizedS1:
return dtype::QuantizedS1{param->scale()};
case DTypeEnum_QuantizedS4: case DTypeEnum_QuantizedS4:
return dtype::QuantizedS4{param->scale()}; return dtype::QuantizedS4{param->scale()};
case DTypeEnum_QuantizedS8: case DTypeEnum_QuantizedS8:
...@@ -113,6 +115,7 @@ flatbuffers::Offset<fbs::DType> build_dtype( ...@@ -113,6 +115,7 @@ flatbuffers::Offset<fbs::DType> build_dtype(
break; break;
CASE_ASYMMETRIC(Quantized4Asymm) CASE_ASYMMETRIC(Quantized4Asymm)
CASE_ASYMMETRIC(Quantized8Asymm) CASE_ASYMMETRIC(Quantized8Asymm)
CASE_SYMMETRIC(QuantizedS1)
CASE_SYMMETRIC(QuantizedS4) CASE_SYMMETRIC(QuantizedS4)
CASE_SYMMETRIC(QuantizedS8) CASE_SYMMETRIC(QuantizedS8)
CASE_SYMMETRIC(QuantizedS16) CASE_SYMMETRIC(QuantizedS16)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册