diff --git a/python_module/megengine/_internal/dtype.py b/python_module/megengine/_internal/dtype.py index 7d0eb61a9a111b19fa0c383ae4bb27fc10940c9e..6bb32f8634a32ae2e964f4f10b9817e5a7b7b42b 100644 --- a/python_module/megengine/_internal/dtype.py +++ b/python_module/megengine/_internal/dtype.py @@ -25,6 +25,9 @@ _metadata_dict = { "qint32": _QuantDtypeMetadata( "QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1, ), + # NOTE: int2 is not supported for model dump yet + "quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3), + "qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1), } diff --git a/python_module/megengine/quantization/__init__.py b/python_module/megengine/quantization/__init__.py index 82feced1cc3738598f6664f4079e23a6d4854f75..9c8a0e0da5f9f7c8609584653f68b1d3ab584c85 100644 --- a/python_module/megengine/quantization/__init__.py +++ b/python_module/megengine/quantization/__init__.py @@ -13,6 +13,7 @@ from .qconfig import ( QConfig, calibration_qconfig, ema_fakequant_qconfig, + ema_lowbit_fakequant_qconfig, min_max_fakequant_qconfig, tqt_quant_qconfig, ) diff --git a/python_module/megengine/quantization/qconfig.py b/python_module/megengine/quantization/qconfig.py index 4a7b75ecb12dbdb746823609de12f9977219a86f..6606c1a513be2cf3d1a766a7c044f550b6c8480d 100644 --- a/python_module/megengine/quantization/qconfig.py +++ b/python_module/megengine/quantization/qconfig.py @@ -92,6 +92,15 @@ ema_fakequant_qconfig = QConfig( act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), ) +ema_lowbit_fakequant_qconfig = QConfig( + weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False), + act_observer=partial( + ExponentialMovingAverageObserver, dtype="qint4", narrow_range=False + ), + weight_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False), + act_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False), +) + calibration_qconfig = QConfig( weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False),