提交 5c232352 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

test(mge/quantization): add `quantize_disabled` related test

GitOrigin-RevId: f62ba600c537f60e1229b852b97cd67897bf236d
上级 ab913025
......@@ -318,6 +318,7 @@ class Module(metaclass=ABCMeta):
Set ``module``'s ``quantize_diabled`` attribute and return ``module``.
Could be used as a decorator.
"""
def fn(module: Module) -> None:
module.quantize_diabled = value
......
......@@ -80,9 +80,7 @@ def quantize(module: Module, inplace: bool = True):
def quantize_qat(
module: Module,
inplace: bool = True,
qconfig: QConfig = ema_fakequant_qconfig,
module: Module, inplace: bool = True, qconfig: QConfig = ema_fakequant_qconfig,
):
r"""
Recursively convert float :class:`~.Module` to :class:`~.QATModule`
......
......@@ -7,7 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import module as Float
from megengine.module import qat as QAT
from megengine.quantization.quantize import _get_quantable_module_names
from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat
def test_get_quantable_module_names():
......@@ -36,3 +36,19 @@ def test_get_quantable_module_names():
and issubclass(value, Float.Module)
and value != Float.Module
)
def test_disable_quantize():
class Net(Float.Module):
def __init__(self):
super().__init__()
self.conv = Float.ConvBnRelu2d(3, 3, 3)
self.conv.disable_quantize()
def forward(self, x):
return self.conv(x)
net = Net()
qat_net = quantize_qat(net, inplace=False)
assert isinstance(qat_net.conv, Float.ConvBnRelu2d)
assert isinstance(qat_net.conv.conv, Float.Conv2d)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册