From 5c2323529d02cc8cf7d1975ddb7518ccac00a8f0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 12 Jun 2020 12:00:24 +0800 Subject: [PATCH] test(mge/quantization): add `quantize_disabled` related test GitOrigin-RevId: f62ba600c537f60e1229b852b97cd67897bf236d --- python_module/megengine/module/module.py | 1 + .../megengine/quantization/quantize.py | 4 +--- .../test/unit/quantization/quantize.py | 18 +++++++++++++++++- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index 105488254..d91fa94ec 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -318,6 +318,7 @@ class Module(metaclass=ABCMeta): Set ``module``'s ``quantize_diabled`` attribute and return ``module``. Could be used as a decorator. """ + def fn(module: Module) -> None: module.quantize_diabled = value diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index 1d2e6b1fb..36d6cc00f 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -80,9 +80,7 @@ def quantize(module: Module, inplace: bool = True): def quantize_qat( - module: Module, - inplace: bool = True, - qconfig: QConfig = ema_fakequant_qconfig, + module: Module, inplace: bool = True, qconfig: QConfig = ema_fakequant_qconfig, ): r""" Recursively convert float :class:`~.Module` to :class:`~.QATModule` diff --git a/python_module/test/unit/quantization/quantize.py b/python_module/test/unit/quantization/quantize.py index 36cb5279e..14e9acb0a 100644 --- a/python_module/test/unit/quantization/quantize.py +++ b/python_module/test/unit/quantization/quantize.py @@ -7,7 +7,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from megengine import module as Float from megengine.module import qat as QAT -from megengine.quantization.quantize import _get_quantable_module_names +from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat def test_get_quantable_module_names(): @@ -36,3 +36,19 @@ def test_get_quantable_module_names(): and issubclass(value, Float.Module) and value != Float.Module ) + + +def test_disable_quantize(): + class Net(Float.Module): + def __init__(self): + super().__init__() + self.conv = Float.ConvBnRelu2d(3, 3, 3) + self.conv.disable_quantize() + + def forward(self, x): + return self.conv(x) + + net = Net() + qat_net = quantize_qat(net, inplace=False) + assert isinstance(qat_net.conv, Float.ConvBnRelu2d) + assert isinstance(qat_net.conv.conv, Float.Conv2d) -- GitLab