diff --git a/dnn/src/fallback/type_cvt/opr_impl.cpp b/dnn/src/fallback/type_cvt/opr_impl.cpp index 26602c4d524b86ed2be47e2d731b3074d588e194..c63645b512fd64b348b05bb109fe8b62e308005e 100644 --- a/dnn/src/fallback/type_cvt/opr_impl.cpp +++ b/dnn/src/fallback/type_cvt/opr_impl.cpp @@ -451,7 +451,12 @@ namespace fallback { void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { check_exec(src.layout, dst.layout); - if (src.layout.is_contiguous() && dst.layout.is_contiguous()) { + auto is_quantize_lowbit = [](const DType& dt) { + return dt.category() == DTypeCategory::QUANTIZED && dt.is_low_bit(); + }; + if (src.layout.is_contiguous() && dst.layout.is_contiguous() && + !is_quantize_lowbit(src.layout.dtype) && + !is_quantize_lowbit(dst.layout.dtype)) { MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); } else { naive::TypeCvtImpl::exec(src, dst); diff --git a/python_module/megengine/_internal/dtype.py b/python_module/megengine/_internal/dtype.py index b9d700fef459a1151e820dbd939e80b1d101f80d..2bb41a19a855ffc23ce183484a6666cd74c4743c 100644 --- a/python_module/megengine/_internal/dtype.py +++ b/python_module/megengine/_internal/dtype.py @@ -32,7 +32,7 @@ def get_scale(dtype): def get_zero_point(dtype): assert is_quantize(dtype) metadata = dtype.metadata["mgb_dtype"] - assert metadata["name"] == "Quantized8Asymm" + assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm") return metadata["zero_point"] @@ -79,6 +79,38 @@ def qint32(scale): ) +def quint4(scale, zero_point): + """ + Consturct a quantized unsigned int4 data type with ``scale`` (float) and + ``zero_point`` (uint8). The real value represented by a quint4 data type is + float_val = scale * (uint4_val - zero_point) + """ + int_zp = int(zero_point) + assert int_zp == zero_point, "zero_point should be an integer" + if int_zp < 0 or int_zp > 15: + raise ValueError("zero_point should be within [0, 15] for quint4") + return np.dtype( + np.uint8, + metadata={ + "mgb_dtype": { + "name": "Quantized4Asymm", + "scale": float(scale), + "zero_point": int(zero_point), + } + }, + ) + + +def qint4(scale): + """ + Construct a quantized int4 data type with ``scale`` (float). The real value + represented by a qint4 data type is float_val = scale * int4_val + """ + return np.dtype( + np.int8, metadata={"mgb_dtype": {"name": "QuantizedS4", "scale": float(scale)}} + ) + + def convert_to_quint8(arr, q): """ Quantize a float NumPy ndarray into a quint8 one with specified params. @@ -177,3 +209,71 @@ def convert_from_qint32(arr): ), "arr should be a ndarray with qint8 dtype" scale = arr.dtype.metadata["mgb_dtype"]["scale"] return arr.astype(np.float32) * scale + + +def convert_to_quint4(arr, q): + """ + Quantize a float NumPy ndarray into a quint4 one with specified params. + + :param arr: Input ndarray. + :type arr: :class:`np.ndarray` + :param q: Target data type, should be a quint4. + :type q: :class:`np.dtype` + """ + assert isinstance(arr, np.ndarray) + assert ( + "mgb_dtype" in q.metadata + and q.metadata["mgb_dtype"]["name"] == "Quantized4Asymm" + ), "q should be a quint4 dtype" + scale, zp = q.metadata["mgb_dtype"]["scale"], q.metadata["mgb_dtype"]["zero_point"] + return (np.round(arr / scale) + zp).clip(0, 15).astype(q) + + +def convert_from_quint4(arr): + """ + Dequantize a quint4 NumPy ndarray into a float one. + + :param arr: Input ndarray. + """ + assert isinstance(arr, np.ndarray) + assert ( + "mgb_dtype" in arr.dtype.metadata + and arr.dtype.metadata["mgb_dtype"]["name"] == "Quantized4Asymm" + ), "arr should be a ndarray with quint4 dtype" + scale, zp = ( + arr.dtype.metadata["mgb_dtype"]["scale"], + arr.dtype.metadata["mgb_dtype"]["zero_point"], + ) + return (arr.astype(np.float32) - zp) * scale + + +def convert_to_qint4(arr, q): + """ + Quantize a float NumPy ndarray into a qint4 one with specified params. + + :param arr: Input ndarray. + :type arr: :class:`np.ndarray` + :param q: Target data type, should be a qint4. + :type q: :class:`np.dtype` + """ + assert isinstance(arr, np.ndarray) + assert ( + "mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS4" + ), "q should be a qint4 dtype" + scale = q.metadata["mgb_dtype"]["scale"] + return (np.round(arr / scale)).clip(-8, 7).astype(q) + + +def convert_from_qint4(arr): + """ + Dequantize a qint4 NumPy ndarray into a float one. + + :param arr: Input ndarray. + """ + assert isinstance(arr, np.ndarray) + assert ( + "mgb_dtype" in arr.dtype.metadata + and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS4" + ), "arr should be a ndarray with qint4 dtype" + scale = arr.dtype.metadata["mgb_dtype"]["scale"] + return arr.astype(np.float32) * scale diff --git a/python_module/src/cpp/python_helper.cpp b/python_module/src/cpp/python_helper.cpp index 32469e2106a3c13ac003c44bef7b42385da2a521..7aa48574327949f8abae20216efc799170d0dad2 100644 --- a/python_module/src/cpp/python_helper.cpp +++ b/python_module/src/cpp/python_helper.cpp @@ -452,6 +452,23 @@ std::unique_ptr dtype_mgb2np_descr( {{"scale", PyFloat_FromDouble(param.scale)}}); break; } + case DTypeEnum::Quantized4Asymm: { + auto& param = dtype.param(); + type_descr = PyArray_DescrNewFromType(NPY_UINT8); + type_descr->metadata = build_mgb_dtype_dict( + DTypeTrait::name, + {{"scale", PyFloat_FromDouble(param.scale)}, + {"zero_point", PyLong_FromLong(param.zero_point)}}); + break; + } + case DTypeEnum::QuantizedS4: { + auto& param = dtype.param(); + type_descr = PyArray_DescrNewFromType(NPY_INT8); + type_descr->metadata = build_mgb_dtype_dict( + DTypeTrait::name, + {{"scale", PyFloat_FromDouble(param.scale)}}); + break; + } case DTypeEnum::QuantizedS32: { auto& param = dtype.param(); type_descr = PyArray_DescrNewFromType(NPY_INT32); @@ -529,7 +546,29 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { static_cast(PyFloat_AS_DOUBLE(scale_py)), static_cast(zero_point)); } - if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8") { + if (dtype_name == "Quantized4Asymm") { + PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); + PyObject* zero_point_py = + PyDict_GetItemString(metadata, "zero_point"); + mgb_assert(scale_py && zero_point_py, + "Invalid Quantized4Asymm metadata: missing scale or " + "zero_point."); + mgb_assert( + PyFloat_Check(scale_py), + "Invalid Quantized4Asymm metadata: scale should be float"); + mgb_assert(PyLong_Check(zero_point_py), + "Invalid Quantized4Asymm metadata: zero_point should be " + "integer"); + auto zero_point = PyLong_AS_LONG(zero_point_py); + mgb_assert(zero_point >= 0 && zero_point < 15, + "Invalid Quantized4Asymm metadata: zero_point should be " + "in [0, 15)"); + return dtype::Quantized4Asymm( + static_cast(PyFloat_AS_DOUBLE(scale_py)), + static_cast(zero_point)); + } + if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" || + dtype_name == "QuantizedS4") { PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); mgb_assert(scale_py, "Invalid metadata: missing scale"); mgb_assert(PyFloat_Check(scale_py), @@ -537,8 +576,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { float scale = static_cast(PyFloat_AS_DOUBLE(scale_py)); if (dtype_name == "QuantizedS32") { return dtype::QuantizedS32(scale); - } else { + } else if (dtype_name == "QuantizedS8"){ return dtype::QuantizedS8(scale); + } else { + return dtype::QuantizedS4(scale); } } throw ConversionError( diff --git a/src/core/impl/dtype.cpp b/src/core/impl/dtype.cpp index a0030dbb3157ff12fe530ddd0e194b3379581d15..a91a9f9f7f01fe900658c87a170d2380c74d34b3 100644 --- a/src/core/impl/dtype.cpp +++ b/src/core/impl/dtype.cpp @@ -14,6 +14,7 @@ #include "megbrain/exception.h" #include "megbrain/utils/metahelper.h" #include "megbrain/utils/arith_helper.h" +#include "megdnn/dtype.h" #include #include @@ -357,6 +358,52 @@ struct LowbitMemcpy { } } }; + +template +struct QuantizedLowbitTrait; + +template<> +struct QuantizedLowbitTrait { + static constexpr int8_t SHIFT = 0; +}; + +template<> +struct QuantizedLowbitTrait { + static constexpr int8_t SHIFT = 8; +}; + +template ::category == + DTypeCategory::QUANTIZED) && + (8 % DTypeTrait
::low_bit == 0)> +struct QuantizedLowbitMemcpy; + +template +struct QuantizedLowbitMemcpy { + // cast with bits that 8 % bits == 0 + static constexpr uint16_t bits = DTypeTrait
::low_bit; + static constexpr uint8_t MASK = (1 << bits) - 1; + using Trait = QuantizedLowbitTrait
; + + static void byte2compact(void* dest_raw, const void* src_raw, size_t n) { + auto dest = static_cast(dest_raw); + auto src = static_cast(src_raw); + memset(dest, 0, divup(n * bits, 8)); + for (size_t i = 0; i < n; ++i) { + int8_t val = src[i] + Trait::SHIFT; + mgb_assert(val >= 0 && val < (1 << bits)); + dest[i * bits / 8] |= val << (i * bits % 8); + } + } + static void compact2byte(void* dest_raw, const void* src_raw, size_t n) { + auto dest = static_cast(dest_raw); + auto src = static_cast(src_raw); + for (size_t i = 0; i < n; ++i) { + int8_t val = ((src[i * bits / 8] >> (i * bits % 8)) & MASK); + dest[i] = val - Trait::SHIFT; + } + } +}; + } // anonymous namespace void mgb::lowbit_memcpy_byte2compact( @@ -365,6 +412,11 @@ void mgb::lowbit_memcpy_byte2compact( if (dtype == mgb::dtype::name##bits()) \ return LowbitMemcpy::byte2compact(dest, src, n); MEGDNN_FOREACH_LOWBIT_DTYPE(cb) +#undef cb +#define cb(dt) \ + if (dtype.enumv() == DTypeTrait
::enumv) \ + return QuantizedLowbitMemcpy
::byte2compact(dest, src, n); + MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) #undef cb mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); } @@ -375,6 +427,11 @@ void mgb::lowbit_memcpy_compact2byte( if (dtype == mgb::dtype::name##bits()) \ return LowbitMemcpy::compact2byte(dest, src, n); MEGDNN_FOREACH_LOWBIT_DTYPE(cb) +#undef cb +#define cb(dt) \ + if (dtype.enumv() == DTypeTrait
::enumv) \ + return QuantizedLowbitMemcpy
::compact2byte(dest, src, n); + MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) #undef cb mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); }