提交 5c232352 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

test(mge/quantization): add `quantize_disabled` related test

GitOrigin-RevId: f62ba600c537f60e1229b852b97cd67897bf236d
上级 ab913025
...@@ -318,6 +318,7 @@ class Module(metaclass=ABCMeta): ...@@ -318,6 +318,7 @@ class Module(metaclass=ABCMeta):
Set ``module``'s ``quantize_diabled`` attribute and return ``module``. Set ``module``'s ``quantize_diabled`` attribute and return ``module``.
Could be used as a decorator. Could be used as a decorator.
""" """
def fn(module: Module) -> None: def fn(module: Module) -> None:
module.quantize_diabled = value module.quantize_diabled = value
......
...@@ -80,9 +80,7 @@ def quantize(module: Module, inplace: bool = True): ...@@ -80,9 +80,7 @@ def quantize(module: Module, inplace: bool = True):
def quantize_qat( def quantize_qat(
module: Module, module: Module, inplace: bool = True, qconfig: QConfig = ema_fakequant_qconfig,
inplace: bool = True,
qconfig: QConfig = ema_fakequant_qconfig,
): ):
r""" r"""
Recursively convert float :class:`~.Module` to :class:`~.QATModule` Recursively convert float :class:`~.Module` to :class:`~.QATModule`
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import module as Float from megengine import module as Float
from megengine.module import qat as QAT from megengine.module import qat as QAT
from megengine.quantization.quantize import _get_quantable_module_names from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat
def test_get_quantable_module_names(): def test_get_quantable_module_names():
...@@ -36,3 +36,19 @@ def test_get_quantable_module_names(): ...@@ -36,3 +36,19 @@ def test_get_quantable_module_names():
and issubclass(value, Float.Module) and issubclass(value, Float.Module)
and value != Float.Module and value != Float.Module
) )
def test_disable_quantize():
class Net(Float.Module):
def __init__(self):
super().__init__()
self.conv = Float.ConvBnRelu2d(3, 3, 3)
self.conv.disable_quantize()
def forward(self, x):
return self.conv(x)
net = Net()
qat_net = quantize_qat(net, inplace=False)
assert isinstance(qat_net.conv, Float.ConvBnRelu2d)
assert isinstance(qat_net.conv.conv, Float.Conv2d)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册