From a1ca50c9232441812ffbc3663d8a67bfa46e099d Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 25 Feb 2021 13:32:54 +0800
Subject: [PATCH] feat(mge/quantization): add name for quantized module

GitOrigin-RevId: edefbec7b70953144105c558bae34e3f792c02ec
---
 imperative/python/megengine/module/conv.py         |   2 +
 .../module/deformable_psroi_pooling.py             |   3 +-
 imperative/python/megengine/module/module.py       |  14 +-
 .../module/qat/batch_matmul_activation.py          |   1 +
 .../python/megengine/module/qat/concat.py          |   2 +-
 .../python/megengine/module/qat/conv.py            |   1 +
 .../python/megengine/module/qat/conv_bn.py         |   1 +
 .../python/megengine/module/qat/elemwise.py        |   2 +-
 .../python/megengine/module/qat/linear.py          |   4 +-
 .../python/megengine/module/qat/module.py          |   4 +-
 .../megengine/module/qat/quant_dequant.py          |   4 +-
 .../quantized/batch_matmul_activation.py           |   5 +-
 .../megengine/module/quantized/concat.py           |   6 +-
 .../python/megengine/module/quantized/conv.py      |   6 +-
 .../megengine/module/quantized/conv_bn.py          |   5 +-
 .../megengine/module/quantized/elemwise.py         |   8 +-
 .../megengine/module/quantized/linear.py           |  10 +-
 .../module/quantized/quant_dequant.py              |   8 +-
 .../python/test/unit/test_dump_naming.py           | 136 +++++++++++++++---
 19 files changed, 172 insertions(+), 50 deletions(-)

diff --git a/imperative/python/megengine/module/conv.py b/imperative/python/megengine/module/conv.py
index 0d598e126..1ace62a6e 100644
--- a/imperative/python/megengine/module/conv.py
+++ b/imperative/python/megengine/module/conv.py
@@ -641,6 +641,7 @@ class DeformableConv2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
         stride = _pair_nonzero(stride)
@@ -657,6 +658,7 @@ class DeformableConv2d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )
 
     def _get_fanin(self):
diff --git a/imperative/python/megengine/module/deformable_psroi_pooling.py b/imperative/python/megengine/module/deformable_psroi_pooling.py
index 2791eddee..c3b0d5271 100644
--- a/imperative/python/megengine/module/deformable_psroi_pooling.py
+++ b/imperative/python/megengine/module/deformable_psroi_pooling.py
@@ -21,8 +21,9 @@ class DeformablePSROIPooling(Module):
         sample_per_part,
         spatial_scale,
         trans_std: float = 0.1,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.no_trans = no_trans
         self.part_size = part_size
         self.pooled_h = pooled_h
diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py
index 1861ee287..df857626f 100644
--- a/imperative/python/megengine/module/module.py
+++ b/imperative/python/megengine/module/module.py
@@ -69,7 +69,17 @@ class Module(metaclass=ABCMeta):
     Base Module class.
     """
 
-    def __init__(self, name=""):
+    def __init__(self, name=None):
+        """
+        :param name: module's name, which can be initialized by the ``kwargs``
+            parameter of a child class.
+ """ + + if name is not None: + assert ( + isinstance(name, str) and name.strip() + ), "Module's name must be a non-empty string" + self.name = name # runtime attributes @@ -109,7 +119,7 @@ class Module(metaclass=ABCMeta): return HookHandler(self._forward_hooks, hook) def __call__(self, *inputs, **kwargs): - auto_naming.push_scope(self.name if self.name else self._name) + auto_naming.push_scope(self.name if self.name is not None else self._name) for hook in self._forward_pre_hooks.values(): modified_inputs = hook(self, inputs) if modified_inputs is not None: diff --git a/imperative/python/megengine/module/qat/batch_matmul_activation.py b/imperative/python/megengine/module/qat/batch_matmul_activation.py index 3de9e1753..e3e2a0b25 100644 --- a/imperative/python/megengine/module/qat/batch_matmul_activation.py +++ b/imperative/python/megengine/module/qat/batch_matmul_activation.py @@ -28,6 +28,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule): float_module.in_features, float_module.out_features, float_module.bias is not None, + name=float_module.name, ) qat_module.weight = float_module.weight qat_module.bias = float_module.bias diff --git a/imperative/python/megengine/module/qat/concat.py b/imperative/python/megengine/module/qat/concat.py index 818462ca3..bfcca787a 100644 --- a/imperative/python/megengine/module/qat/concat.py +++ b/imperative/python/megengine/module/qat/concat.py @@ -27,4 +27,4 @@ class Concat(Float.Concat, QATModule): Return a :class:`~.QATModule` instance converted from a float :class:`~.Module` instance. """ - return cls() + return cls(name=float_module.name) diff --git a/imperative/python/megengine/module/qat/conv.py b/imperative/python/megengine/module/qat/conv.py index c5f842dce..f8205c955 100644 --- a/imperative/python/megengine/module/qat/conv.py +++ b/imperative/python/megengine/module/qat/conv.py @@ -43,6 +43,7 @@ class Conv2d(Float.Conv2d, QATModule): float_module.bias is not None, float_module.conv_mode, float_module.compute_mode, + name=float_module.name, ) qat_module.weight = float_module.weight qat_module.bias = float_module.bias diff --git a/imperative/python/megengine/module/qat/conv_bn.py b/imperative/python/megengine/module/qat/conv_bn.py index 409c49b5d..2cc5be088 100644 --- a/imperative/python/megengine/module/qat/conv_bn.py +++ b/imperative/python/megengine/module/qat/conv_bn.py @@ -155,6 +155,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): float_module.conv.bias is not None, float_module.conv.conv_mode, float_module.conv.compute_mode, + name=float_module.name, ) qat_module.conv.weight = float_module.conv.weight qat_module.conv.bias = float_module.conv.bias diff --git a/imperative/python/megengine/module/qat/elemwise.py b/imperative/python/megengine/module/qat/elemwise.py index 3692bdf18..956bf4fa1 100644 --- a/imperative/python/megengine/module/qat/elemwise.py +++ b/imperative/python/megengine/module/qat/elemwise.py @@ -28,4 +28,4 @@ class Elemwise(Float.Elemwise, QATModule): Return a :class:`~.QATModule` instance converted from a float :class:`~.Module` instance. 
""" - return cls(float_module.method) + return cls(float_module.method, name=float_module.name) diff --git a/imperative/python/megengine/module/qat/linear.py b/imperative/python/megengine/module/qat/linear.py index 98bf2452d..8647fc919 100644 --- a/imperative/python/megengine/module/qat/linear.py +++ b/imperative/python/megengine/module/qat/linear.py @@ -36,7 +36,9 @@ class Linear(Float.Linear, QATModule): Return a :class:`~.QATModule` instance converted from a float :class:`~.Module` instance. """ - qmod = cls(float_module.in_features, float_module.out_features) + qmod = cls( + float_module.in_features, float_module.out_features, name=float_module.name + ) qmod.weight = float_module.weight qmod.bias = float_module.bias return qmod diff --git a/imperative/python/megengine/module/qat/module.py b/imperative/python/megengine/module/qat/module.py index 85eca4495..82b2911cc 100644 --- a/imperative/python/megengine/module/qat/module.py +++ b/imperative/python/megengine/module/qat/module.py @@ -26,8 +26,8 @@ class QATModule(Module): with_weight = True with_act = True - def __init__(self): - super().__init__() + def __init__(self, **kwargs): + super().__init__(**kwargs) self.weight_observer = None # type: Observer self.act_observer = None # type: Observer diff --git a/imperative/python/megengine/module/qat/quant_dequant.py b/imperative/python/megengine/module/qat/quant_dequant.py index e75a35b8e..580b5f916 100644 --- a/imperative/python/megengine/module/qat/quant_dequant.py +++ b/imperative/python/megengine/module/qat/quant_dequant.py @@ -26,7 +26,7 @@ class QuantStub(Float.QuantStub, QATModule): Return a :class:`~.QATModule` instance converted from a float :class:`~.Module` instance. """ - return cls() + return cls(name=float_module.name) class DequantStub(Float.DequantStub, QATModule): @@ -47,4 +47,4 @@ class DequantStub(Float.DequantStub, QATModule): Return a :class:`~.QATModule` instance converted from a float :class:`~.Module` instance. """ - return cls() + return cls(name=float_module.name) diff --git a/imperative/python/megengine/module/quantized/batch_matmul_activation.py b/imperative/python/megengine/module/quantized/batch_matmul_activation.py index 0ce763f59..e115c1463 100644 --- a/imperative/python/megengine/module/quantized/batch_matmul_activation.py +++ b/imperative/python/megengine/module/quantized/batch_matmul_activation.py @@ -61,13 +61,14 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule): qat_module.out_features, qat_module.bias is not None, dtype=output_dtype, + name=qat_module.name, ) weight = qat_module.weight.astype(qat_module.get_weight_dtype()) weight = expand_dims(weight, [-1, -2]) - qbmm.weight = Parameter(weight.numpy()) + qbmm.weight = Parameter(weight.numpy(), name=qat_module.weight.name) if qat_module.bias is not None: bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1)) - qbmm.bias = Parameter(bias.numpy()) + qbmm.bias = Parameter(bias.numpy(), name=qat_module.bias.name) else: qbmm.bias = Parameter( np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32) diff --git a/imperative/python/megengine/module/quantized/concat.py b/imperative/python/megengine/module/quantized/concat.py index f8cc6b8a1..11af85291 100644 --- a/imperative/python/megengine/module/quantized/concat.py +++ b/imperative/python/megengine/module/quantized/concat.py @@ -18,8 +18,8 @@ class Concat(QuantizedModule): A :class:`~.QuantizedModule` to do quantized :func:`~.concat`, used for inference only. 
""" - def __init__(self, dtype=None): - super().__init__() + def __init__(self, dtype=None, **kwargs): + super().__init__(**kwargs) self.output_dtype = dtype def forward(self, inps: Iterable[Tensor], axis: int = 0): @@ -32,4 +32,4 @@ class Concat(QuantizedModule): Return a :class:`~.QuantizedModule` instance converted from a :class:`~.QATModule` instance. """ - return cls(qat_module.get_activation_dtype()) + return cls(qat_module.get_activation_dtype(), name=qat_module.name) diff --git a/imperative/python/megengine/module/quantized/conv.py b/imperative/python/megengine/module/quantized/conv.py index 34e51a726..0b2ad2fa8 100644 --- a/imperative/python/megengine/module/quantized/conv.py +++ b/imperative/python/megengine/module/quantized/conv.py @@ -37,6 +37,7 @@ class Conv2d(Float.Conv2d, QuantizedModule): conv_mode: str = "CROSS_CORRELATION", compute_mode: str = "DEFAULT", dtype=None, + **kwargs ): super().__init__( in_channels, @@ -86,11 +87,12 @@ class Conv2d(Float.Conv2d, QuantizedModule): qat_module.dilation, qat_module.groups, dtype=output_dtype, + name=qat_module.name, ) weight = qat_module.weight.astype(qat_module.get_weight_dtype()) - qconv.weight = Parameter(weight.numpy()) + qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name) if qat_module.bias is not None: - qconv.bias = Parameter(qat_module.bias.numpy()) + qconv.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) else: qconv.bias = Parameter( np.zeros(qat_module._infer_bias_shape(), dtype=np.float32) diff --git a/imperative/python/megengine/module/quantized/conv_bn.py b/imperative/python/megengine/module/quantized/conv_bn.py index 55b9466a0..4b3b5772b 100644 --- a/imperative/python/megengine/module/quantized/conv_bn.py +++ b/imperative/python/megengine/module/quantized/conv_bn.py @@ -33,13 +33,14 @@ class _ConvBnActivation2d(Conv2d): qat_module.conv.dilation, qat_module.conv.groups, dtype=output_dtype, + name=qat_module.name, ) w_fold, b_fold = qat_module.fold_weight_bias( qat_module.bn.running_mean, qat_module.bn.running_var ) weight = w_fold.astype(qat_module.get_weight_dtype()) - qconv.weight = Parameter(weight.numpy()) - qconv.bias = Parameter(b_fold.numpy()) + qconv.weight = Parameter(weight.numpy(), name=qat_module.conv.weight.name) + qconv.bias = Parameter(b_fold.numpy(), name=qat_module.conv.bias.name) return qconv diff --git a/imperative/python/megengine/module/quantized/elemwise.py b/imperative/python/megengine/module/quantized/elemwise.py index 6a76c7b84..46950c8f3 100644 --- a/imperative/python/megengine/module/quantized/elemwise.py +++ b/imperative/python/megengine/module/quantized/elemwise.py @@ -14,8 +14,8 @@ from .module import QuantizedModule class Elemwise(QuantizedModule): r"""Quantized version of :class:`~.qat.Elemwise`.""" - def __init__(self, method, dtype=None): - super().__init__() + def __init__(self, method, dtype=None, **kwargs): + super().__init__(**kwargs) self.method = "Q" + method self.output_dtype = dtype @@ -30,4 +30,6 @@ class Elemwise(QuantizedModule): Return a :class:`~.QuantizedModule` instance converted from a :class:`~.QATModule` instance. 
""" - return cls(qat_module.method, qat_module.get_activation_dtype()) + return cls( + qat_module.method, qat_module.get_activation_dtype(), name=qat_module.name + ) diff --git a/imperative/python/megengine/module/quantized/linear.py b/imperative/python/megengine/module/quantized/linear.py index 51a32581d..7ecc55a84 100644 --- a/imperative/python/megengine/module/quantized/linear.py +++ b/imperative/python/megengine/module/quantized/linear.py @@ -17,8 +17,8 @@ from .module import QuantizedModule class Linear(QuantizedModule): r"""Quantized version of :class:`~.qat.Linear`.""" - def __init__(self, dtype: np.dtype = None): - super().__init__() + def __init__(self, dtype: np.dtype = None, **kwargs): + super().__init__(**kwargs) self.weight = None self.bias = None self.output_dtype = dtype @@ -44,9 +44,9 @@ class Linear(QuantizedModule): :class:`~.QATModule` instance. """ output_dtype = qat_module.get_activation_dtype() - qmod = cls(dtype=output_dtype) + qmod = cls(dtype=output_dtype, name=qat_module.name) weight = qat_module.weight.astype(qat_module.get_weight_dtype()) - qmod.weight = Parameter(weight.numpy()) + qmod.weight = Parameter(weight.numpy(), name=qat_module.weight.name) if qat_module.bias is not None: - qmod.bias = Parameter(qat_module.bias.numpy()) + qmod.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) return qmod diff --git a/imperative/python/megengine/module/quantized/quant_dequant.py b/imperative/python/megengine/module/quantized/quant_dequant.py index c8eadafee..d17ca0de3 100644 --- a/imperative/python/megengine/module/quantized/quant_dequant.py +++ b/imperative/python/megengine/module/quantized/quant_dequant.py @@ -15,8 +15,8 @@ class QuantStub(QuantizedModule): will convert input to quantized dtype. """ - def __init__(self, dtype=None): - super().__init__() + def __init__(self, dtype=None, **kwargs): + super().__init__(**kwargs) self.output_dtype = dtype def forward(self, inp): @@ -28,7 +28,7 @@ class QuantStub(QuantizedModule): Return a :class:`~.QuantizedModule` instance converted from a :class:`~.QATModule` instance. """ - return cls(qat_module.get_activation_dtype()) + return cls(qat_module.get_activation_dtype(), name=qat_module.name) class DequantStub(QuantizedModule): @@ -46,4 +46,4 @@ class DequantStub(QuantizedModule): Return a :class:`~.QuantizedModule` instance converted from a :class:`~.QATModule` instance. 
""" - return cls() + return cls(name=qat_module.name) diff --git a/imperative/python/test/unit/test_dump_naming.py b/imperative/python/test/unit/test_dump_naming.py index 019d88547..eabab1b44 100644 --- a/imperative/python/test/unit/test_dump_naming.py +++ b/imperative/python/test/unit/test_dump_naming.py @@ -17,6 +17,7 @@ import megengine.utils.comp_graph_tools as cgtools from megengine import Parameter, Tensor from megengine.core.tensor import megbrain_graph as G from megengine.jit.tracing import trace +from megengine.quantization.quantize import quantize, quantize_qat from megengine.utils.naming import auto_naming @@ -29,14 +30,14 @@ def _dump_and_load(func, symbolic, keep_opr_name=True): func.dump( file, optimize_for_inference=False, - arg_names="x", + arg_names=("x",), keep_opr_name=keep_opr_name, keep_var_name=2, ) file.seek(0) *_, outputs = G.load_graph(file) - op = cgtools.get_oprs_seq(outputs)[-1] - return op + ops = cgtools.get_oprs_seq(outputs) + return ops @pytest.mark.parametrize("symbolic", [False, True]) @@ -50,7 +51,7 @@ def test_auto_naming(symbolic): return x + x m = Simple("simple") - op = _dump_and_load(m, symbolic) + op = _dump_and_load(m, symbolic)[-1] assert op.name == "simple.ADD" assert op.outputs[0].name == "simple.ADD" @@ -70,7 +71,7 @@ def test_user_named_tensor(symbolic): m = Simple("simple") - op = _dump_and_load(m, symbolic) + op = _dump_and_load(m, symbolic)[-1] assert op.name == "simple.ADD" assert op.outputs[0].name == "o_x" @@ -88,7 +89,7 @@ def test_user_named_param(symbolic): m = Simple("simple") - op = _dump_and_load(m, symbolic) + op = _dump_and_load(m, symbolic)[-1] assert op.inputs[0].name == "x" assert op.inputs[1].name == "simple.k" @@ -98,7 +99,7 @@ def test_without_module(symbolic): def f(x): return 2 * x - op = _dump_and_load(f, symbolic) + op = _dump_and_load(f, symbolic)[-1] assert op.name == "MUL" @@ -116,10 +117,10 @@ def test_with_submodule(symbolic): m = Simple("simple") - op = _dump_and_load(m, symbolic) - assert op.name == "simple.linear.ADD" - assert op.inputs[0].owner.name == "simple.linear.MatrixMul" - assert op.outputs[0].name == "simple.linear.ADD" + ops = _dump_and_load(m, symbolic) + assert ops[-1].name == "simple.linear.ADD" + assert ops[-2].name == "simple.linear.MatrixMul" + assert ops[-1].outputs[0].name == "simple.linear.ADD" @pytest.mark.parametrize("symbolic", [False, True]) @@ -136,10 +137,10 @@ def test_named_submodule(symbolic): m = Simple("simple") - op = _dump_and_load(m, symbolic) - assert op.name == "simple.x.ADD" - assert op.inputs[0].owner.name == "simple.x.MatrixMul" - assert op.outputs[0].name == "simple.x.ADD" + ops = _dump_and_load(m, symbolic) + assert ops[-1].name == "simple.x.ADD" + assert ops[-2].name == "simple.x.MatrixMul" + assert ops[-1].outputs[0].name == "simple.x.ADD" @pytest.mark.parametrize("symbolic", [False, True]) @@ -156,14 +157,111 @@ def test_with_same_operators(symbolic): m = Simple("simple") - op = _dump_and_load(m, symbolic) - assert op.name == "simple.RELU[1]" - assert op.inputs[0].owner.name == "simple.RELU[0]" + ops = _dump_and_load(m, symbolic) + assert ops[-1].name == "simple.RELU[1]" + assert ops[-2].name == "simple.RELU[0]" def test_not_keep_opr_name(): def f(x): return 2 * x - op = _dump_and_load(f, True, False) + op = _dump_and_load(f, True, False)[-1] assert op.name == "MUL(x,2[2])[4]" + + +@pytest.mark.parametrize("symbolic", [False, True]) +def test_quantized_module_auto_naming(symbolic): + class Simple(M.Module): + def __init__(self, name): + super().__init__(name=name) + 
+            self.quant = M.QuantStub()
+            self.linear = M.Linear(3, 3, bias=True)
+            self.dequant = M.DequantStub()
+
+        def forward(self, x):
+            out = self.quant(x)
+            out = self.linear(out)
+            out = self.dequant(out)
+            return out
+
+    m = Simple("simple")
+    quantize_qat(m)
+    quantize(m)
+    m.eval()
+
+    ops = _dump_and_load(m, symbolic)
+    ops_name = (
+        "x",
+        "simple.quant.TypeCvt",
+        "simple.linear.MatrixMul",
+        "simple.linear.ADD",
+        "simple.linear.TypeCvt",
+        "simple.dequant.TypeCvt",
+    )
+    for op, name in zip(ops, ops_name):
+        assert op.name == name
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_quantized_module_user_naming(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__(name=name)
+            self.quant = M.QuantStub()
+            self.linear = M.Linear(3, 3, bias=True, name="user-linear")
+            self.dequant = M.DequantStub()
+
+        def forward(self, x):
+            out = self.quant(x)
+            out = self.linear(out)
+            out = self.dequant(out)
+            return out
+
+    m = Simple("simple")
+    quantize_qat(m)
+    quantize(m)
+    m.eval()
+
+    ops = _dump_and_load(m, symbolic)
+    ops_name = (
+        "x",
+        "simple.quant.TypeCvt",
+        "simple.user-linear.MatrixMul",
+        "simple.user-linear.ADD",
+        "simple.user-linear.TypeCvt",
+        "simple.dequant.TypeCvt",
+    )
+    for op, name in zip(ops, ops_name):
+        assert op.name == name
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_quantized_module_user_naming_param(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__(name=name)
+            self.quant = M.QuantStub()
+            self.linear = M.Linear(3, 3, bias=True)
+            self.dequant = M.DequantStub()
+
+            self.linear.weight.name = "user-weight"
+            self.linear.bias.name = "user-bias"
+
+        def forward(self, x):
+            out = self.quant(x)
+            out = self.linear(out)
+            out = self.dequant(out)
+            return out
+
+    m = Simple("simple")
+    quantize_qat(m)
+    quantize(m)
+    m.eval()
+
+    ops = _dump_and_load(m, symbolic)
+
+    (matrix_mul_op,) = [op for op in ops if op.name == "simple.linear.MatrixMul"]
+    for var in matrix_mul_op.inputs:
+        assert var.name in ("simple.quant.TypeCvt", "simple.linear.user-weight")
+    # BUG: the bias' name does not meet expectations because of the astype
+    # operator applied after quantization
--
GitLab
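
A minimal usage sketch of the naming feature this patch adds, based on the tests above. The class name Net and the names "net" and "fc" are illustrative only; the APIs (Module's name kwarg, quantize_qat, quantize) are the ones exercised in the patch.

    import megengine.module as M
    from megengine.quantization.quantize import quantize, quantize_qat

    class Net(M.Module):
        def __init__(self):
            # `name` is forwarded through **kwargs to Module.__init__
            super().__init__(name="net")
            self.quant = M.QuantStub()
            self.linear = M.Linear(3, 3, bias=True, name="fc")  # user-chosen name
            self.dequant = M.DequantStub()

        def forward(self, x):
            return self.dequant(self.linear(self.quant(x)))

    net = Net()
    quantize_qat(net)  # float -> QAT: names are carried over via from_float_module(..., name=...)
    quantize(net)      # QAT -> quantized: names are carried over via from_qat_module(..., name=...)
    # Operators dumped from this module now carry names like "net.fc.MatrixMul".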