提交 fc6aa12e 编写于 作者: M Megvii Engine Team

feat(mge):add qint4/quint4 to python

GitOrigin-RevId: f94609db00fcaaa9ca249eb61639eb0482705f79
上级 a94fb7b1
......@@ -451,7 +451,12 @@ namespace fallback {
void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
check_exec(src.layout, dst.layout);
if (src.layout.is_contiguous() && dst.layout.is_contiguous()) {
auto is_quantize_lowbit = [](const DType& dt) {
return dt.category() == DTypeCategory::QUANTIZED && dt.is_low_bit();
};
if (src.layout.is_contiguous() && dst.layout.is_contiguous() &&
!is_quantize_lowbit(src.layout.dtype) &&
!is_quantize_lowbit(dst.layout.dtype)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst));
} else {
naive::TypeCvtImpl::exec(src, dst);
......
......@@ -32,7 +32,7 @@ def get_scale(dtype):
def get_zero_point(dtype):
    """
    Return the ``zero_point`` stored in the metadata of an asymmetric
    quantized data type.

    :param dtype: a quantized :class:`np.dtype` carrying ``"mgb_dtype"``
        metadata whose name is ``Quantized8Asymm`` or ``Quantized4Asymm``.
    :return: the ``zero_point`` entry of the ``"mgb_dtype"`` metadata.
    """
    assert is_quantize(dtype)
    metadata = dtype.metadata["mgb_dtype"]
    # Both asymmetric quantized types carry a zero_point; an 8-bit-only check
    # here would reject the 4-bit type before it could be handled.
    assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm")
    return metadata["zero_point"]
......@@ -79,6 +79,38 @@ def qint32(scale):
)
def quint4(scale, zero_point):
    """
    Construct a quantized unsigned int4 data type with ``scale`` (float) and
    ``zero_point`` (uint8). The real value represented by a quint4 data type is
    float_val = scale * (uint4_val - zero_point)

    :param scale: quantization scale; stored as a Python ``float``.
    :param zero_point: integral zero point, must lie within ``[0, 15]``.
    :return: a ``np.uint8``-based :class:`np.dtype` whose ``"mgb_dtype"``
        metadata holds the name ``Quantized4Asymm``, the scale and the
        zero point.
    :raises AssertionError: if ``zero_point`` is not an integral value.
    :raises ValueError: if ``zero_point`` is outside ``[0, 15]``.
    """
    int_zp = int(zero_point)
    assert int_zp == zero_point, "zero_point should be an integer"
    if int_zp < 0 or int_zp > 15:
        raise ValueError("zero_point should be within [0, 15] for quint4")
    return np.dtype(
        np.uint8,
        metadata={
            "mgb_dtype": {
                "name": "Quantized4Asymm",
                "scale": float(scale),
                # Reuse the already validated integer instead of converting again.
                "zero_point": int_zp,
            }
        },
    )
def qint4(scale):
    """
    Build a quantized signed int4 data type from ``scale`` (float). A stored
    int4 value maps to the real value float_val = scale * int4_val.

    :param scale: quantization scale; stored as a Python ``float``.
    :return: a ``np.int8``-based :class:`np.dtype` whose ``"mgb_dtype"``
        metadata holds the name ``QuantizedS4`` and the scale.
    """
    mgb_info = {"name": "QuantizedS4", "scale": float(scale)}
    return np.dtype(np.int8, metadata={"mgb_dtype": mgb_info})
def convert_to_quint8(arr, q):
"""
Quantize a float NumPy ndarray into a quint8 one with specified params.
......@@ -177,3 +209,71 @@ def convert_from_qint32(arr):
), "arr should be a ndarray with qint8 dtype"
scale = arr.dtype.metadata["mgb_dtype"]["scale"]
return arr.astype(np.float32) * scale
def convert_to_quint4(arr, q):
    """
    Quantize a float NumPy ndarray into a quint4 one with specified params.

    Each element is computed as ``round(arr / scale) + zero_point`` and then
    clipped to ``[0, 15]`` before being cast to ``q``.

    :param arr: Input ndarray.
    :type arr: :class:`np.ndarray`
    :param q: Target data type, should be a quint4.
    :type q: :class:`np.dtype`
    :return: ndarray with dtype ``q``.
    :raises AssertionError: if ``arr`` is not an ndarray or ``q`` is not a
        quint4 dtype (including a dtype that carries no metadata at all).
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None for a dtype created without metadata; check it first
    # so that such a dtype fails this assertion instead of raising TypeError.
    assert (
        q.metadata is not None
        and "mgb_dtype" in q.metadata
        and q.metadata["mgb_dtype"]["name"] == "Quantized4Asymm"
    ), "q should be a quint4 dtype"
    scale, zp = q.metadata["mgb_dtype"]["scale"], q.metadata["mgb_dtype"]["zero_point"]
    return (np.round(arr / scale) + zp).clip(0, 15).astype(q)
def convert_from_quint4(arr):
    """
    Dequantize a quint4 NumPy ndarray into a float one.

    Each element is computed as ``(value - zero_point) * scale`` in float32,
    with ``scale`` and ``zero_point`` read from the array's dtype metadata.

    :param arr: Input ndarray.
    :type arr: :class:`np.ndarray`
    :return: the dequantized ndarray.
    :raises AssertionError: if ``arr`` is not an ndarray or its dtype is not a
        quint4 dtype (including a dtype that carries no metadata at all).
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None for a plain dtype; check it first so that such an
    # array fails this assertion instead of raising TypeError.
    assert (
        arr.dtype.metadata is not None
        and "mgb_dtype" in arr.dtype.metadata
        and arr.dtype.metadata["mgb_dtype"]["name"] == "Quantized4Asymm"
    ), "arr should be a ndarray with quint4 dtype"
    scale, zp = (
        arr.dtype.metadata["mgb_dtype"]["scale"],
        arr.dtype.metadata["mgb_dtype"]["zero_point"],
    )
    return (arr.astype(np.float32) - zp) * scale
def convert_to_qint4(arr, q):
    """
    Quantize a float NumPy ndarray into a qint4 one with specified params.

    Each element is computed as ``round(arr / scale)`` and then clipped to
    ``[-8, 7]`` before being cast to ``q``.

    :param arr: Input ndarray.
    :type arr: :class:`np.ndarray`
    :param q: Target data type, should be a qint4.
    :type q: :class:`np.dtype`
    :return: ndarray with dtype ``q``.
    :raises AssertionError: if ``arr`` is not an ndarray or ``q`` is not a
        qint4 dtype (including a dtype that carries no metadata at all).
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None for a dtype created without metadata; check it first
    # so that such a dtype fails this assertion instead of raising TypeError.
    assert (
        q.metadata is not None
        and "mgb_dtype" in q.metadata
        and q.metadata["mgb_dtype"]["name"] == "QuantizedS4"
    ), "q should be a qint4 dtype"
    scale = q.metadata["mgb_dtype"]["scale"]
    return (np.round(arr / scale)).clip(-8, 7).astype(q)
def convert_from_qint4(arr):
    """
    Dequantize a qint4 NumPy ndarray into a float one.

    Each element is computed as ``value * scale`` in float32, with ``scale``
    read from the array's dtype metadata.

    :param arr: Input ndarray.
    :type arr: :class:`np.ndarray`
    :return: the dequantized ndarray.
    :raises AssertionError: if ``arr`` is not an ndarray or its dtype is not a
        qint4 dtype (including a dtype that carries no metadata at all).
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None for a plain dtype; check it first so that such an
    # array fails this assertion instead of raising TypeError.
    assert (
        arr.dtype.metadata is not None
        and "mgb_dtype" in arr.dtype.metadata
        and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS4"
    ), "arr should be a ndarray with qint4 dtype"
    scale = arr.dtype.metadata["mgb_dtype"]["scale"]
    return arr.astype(np.float32) * scale
......@@ -452,6 +452,23 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(
{{"scale", PyFloat_FromDouble(param.scale)}});
break;
}
case DTypeEnum::Quantized4Asymm: {
auto& param = dtype.param<dtype::Quantized4Asymm>();
type_descr = PyArray_DescrNewFromType(NPY_UINT8);
type_descr->metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::Quantized4Asymm>::name,
{{"scale", PyFloat_FromDouble(param.scale)},
{"zero_point", PyLong_FromLong(param.zero_point)}});
break;
}
case DTypeEnum::QuantizedS4: {
auto& param = dtype.param<dtype::QuantizedS4>();
type_descr = PyArray_DescrNewFromType(NPY_INT8);
type_descr->metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::QuantizedS4>::name,
{{"scale", PyFloat_FromDouble(param.scale)}});
break;
}
case DTypeEnum::QuantizedS32: {
auto& param = dtype.param<dtype::QuantizedS32>();
type_descr = PyArray_DescrNewFromType(NPY_INT32);
......@@ -529,7 +546,29 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
static_cast<uint8_t>(zero_point));
}
if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8") {
if (dtype_name == "Quantized4Asymm") {
    // Rebuild a Quantized4Asymm dtype from the "scale" and "zero_point"
    // entries of the numpy dtype metadata dict.
    PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
    PyObject* zero_point_py =
            PyDict_GetItemString(metadata, "zero_point");
    mgb_assert(scale_py && zero_point_py,
               "Invalid Quantized4Asymm metadata: missing scale or "
               "zero_point.");
    mgb_assert(
            PyFloat_Check(scale_py),
            "Invalid Quantized4Asymm metadata: scale should be float");
    mgb_assert(PyLong_Check(zero_point_py),
               "Invalid Quantized4Asymm metadata: zero_point should be "
               "integer");
    auto zero_point = PyLong_AS_LONG(zero_point_py);
    // An unsigned 4-bit value spans 0..15 inclusive, and the Python-side
    // quint4() constructor accepts zero_point == 15, so the upper bound
    // must be inclusive here as well.
    mgb_assert(zero_point >= 0 && zero_point <= 15,
               "Invalid Quantized4Asymm metadata: zero_point should be "
               "in [0, 15]");
    return dtype::Quantized4Asymm(
            static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
            static_cast<uint8_t>(zero_point));
}
if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" ||
dtype_name == "QuantizedS4") {
PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
mgb_assert(scale_py, "Invalid metadata: missing scale");
mgb_assert(PyFloat_Check(scale_py),
......@@ -537,8 +576,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py));
if (dtype_name == "QuantizedS32") {
return dtype::QuantizedS32(scale);
} else {
} else if (dtype_name == "QuantizedS8"){
return dtype::QuantizedS8(scale);
} else {
return dtype::QuantizedS4(scale);
}
}
throw ConversionError(
......
......@@ -14,6 +14,7 @@
#include "megbrain/exception.h"
#include "megbrain/utils/metahelper.h"
#include "megbrain/utils/arith_helper.h"
#include "megdnn/dtype.h"
#include <cmath>
#include <cstring>
......@@ -357,6 +358,52 @@ struct LowbitMemcpy<bits, true> {
}
}
};
// Per-dtype bias used by QuantizedLowbitMemcpy: SHIFT is added to every byte
// value before it is packed and subtracted again after it is unpacked.
template <typename DT>
struct QuantizedLowbitTrait;

// Quantized4Asymm: byte values are packed unchanged (no bias).
template <>
struct QuantizedLowbitTrait<dtype::Quantized4Asymm> {
    static constexpr int8_t SHIFT{0};
};

// QuantizedS4: a bias of 8 moves byte values in [-8, 7] to [0, 15].
template <>
struct QuantizedLowbitTrait<dtype::QuantizedS4> {
    static constexpr int8_t SHIFT{8};
};
// Converts between one-value-per-byte storage and bit-packed storage for
// quantized low-bit dtypes. Only the specialization for bit widths that
// divide 8 evenly is defined.
template <typename DT, bool div_byte = (DTypeTrait<DT>::category ==
                                        DTypeCategory::QUANTIZED) &&
                                       (8 % DTypeTrait<DT>::low_bit == 0)>
struct QuantizedLowbitMemcpy;

template <typename DT>
struct QuantizedLowbitMemcpy<DT, true> {
    // bit width of one element; 8 % bits == 0 holds for this specialization
    static constexpr uint16_t bits = DTypeTrait<DT>::low_bit;
    // mask selecting the low `bits` bits of a byte
    static constexpr uint8_t MASK = (1 << bits) - 1;
    using Trait = QuantizedLowbitTrait<DT>;

    // Pack n byte-sized values from src_raw into dest_raw, `bits` bits each.
    // Every value, after adding Trait::SHIFT, must lie in [0, 2^bits).
    static void byte2compact(void* dest_raw, const void* src_raw, size_t n) {
        auto packed = static_cast<uint8_t*>(dest_raw);
        auto bytes = static_cast<const int8_t*>(src_raw);
        // destination bytes are OR-ed into below, so clear them first
        memset(packed, 0, divup<size_t>(n * bits, 8));
        for (size_t idx = 0; idx < n; ++idx) {
            const size_t bit_pos = idx * bits;
            int8_t biased = bytes[idx] + Trait::SHIFT;
            mgb_assert(biased >= 0 && biased < (1 << bits));
            packed[bit_pos / 8] |= biased << (bit_pos % 8);
        }
    }

    // Unpack n values of `bits` bits each from src_raw into one byte per
    // value in dest_raw, subtracting Trait::SHIFT from each.
    static void compact2byte(void* dest_raw, const void* src_raw, size_t n) {
        auto bytes = static_cast<int8_t*>(dest_raw);
        auto packed = static_cast<const uint8_t*>(src_raw);
        for (size_t idx = 0; idx < n; ++idx) {
            const size_t bit_pos = idx * bits;
            int8_t biased = ((packed[bit_pos / 8] >> (bit_pos % 8)) & MASK);
            bytes[idx] = biased - Trait::SHIFT;
        }
    }
};
} // anonymous namespace
void mgb::lowbit_memcpy_byte2compact(
......@@ -365,6 +412,11 @@ void mgb::lowbit_memcpy_byte2compact(
if (dtype == mgb::dtype::name##bits()) \
return LowbitMemcpy<bits>::byte2compact(dest, src, n);
MEGDNN_FOREACH_LOWBIT_DTYPE(cb)
#undef cb
#define cb(dt) \
if (dtype.enumv() == DTypeTrait<dt>::enumv) \
return QuantizedLowbitMemcpy<dt>::byte2compact(dest, src, n);
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name());
}
......@@ -375,6 +427,11 @@ void mgb::lowbit_memcpy_compact2byte(
if (dtype == mgb::dtype::name##bits()) \
return LowbitMemcpy<bits>::compact2byte(dest, src, n);
MEGDNN_FOREACH_LOWBIT_DTYPE(cb)
#undef cb
#define cb(dt) \
if (dtype.enumv() == DTypeTrait<dt>::enumv) \
return QuantizedLowbitMemcpy<dt>::compact2byte(dest, src, n);
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册