diff --git a/imperative/python/megengine/module/conv.py b/imperative/python/megengine/module/conv.py
index 0d598e1260c271f86bb5bcc62a0ee45d1a236f81..1ace62a6e03dd09a3d4ae5729dd3ea39bc2c68a1 100644
--- a/imperative/python/megengine/module/conv.py
+++ b/imperative/python/megengine/module/conv.py
@@ -641,6 +641,7 @@ class DeformableConv2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
         stride = _pair_nonzero(stride)
@@ -657,6 +658,7 @@ class DeformableConv2d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )

     def _get_fanin(self):
diff --git a/imperative/python/megengine/module/deformable_psroi_pooling.py b/imperative/python/megengine/module/deformable_psroi_pooling.py
index 2791eddee66038830878dc625ed6c11f08fb532f..c3b0d52716153c3c8498656354ded2e2a5b08481 100644
--- a/imperative/python/megengine/module/deformable_psroi_pooling.py
+++ b/imperative/python/megengine/module/deformable_psroi_pooling.py
@@ -21,8 +21,9 @@ class DeformablePSROIPooling(Module):
         sample_per_part,
         spatial_scale,
         trans_std: float = 0.1,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.no_trans = no_trans
         self.part_size = part_size
         self.pooled_h = pooled_h
diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py
index 1861ee287062d89e630ff716aefb73ded5b8ea8c..df857626f5477c1fc41230b1732cfef16d22e4ce 100644
--- a/imperative/python/megengine/module/module.py
+++ b/imperative/python/megengine/module/module.py
@@ -69,7 +69,17 @@ class Module(metaclass=ABCMeta):
     Base Module class.
     """

-    def __init__(self, name=""):
+    def __init__(self, name=None):
+        """
+        :param name: module's name, can be initialized by the ``kwargs`` parameter
+            of child class.
+        """
+
+        if name is not None:
+            assert (
+                isinstance(name, str) and name.strip()
+            ), "Module's name must be a non-empty string"
+
         self.name = name

         # runtime attributes
@@ -109,7 +119,7 @@ class Module(metaclass=ABCMeta):
         return HookHandler(self._forward_hooks, hook)

     def __call__(self, *inputs, **kwargs):
-        auto_naming.push_scope(self.name if self.name else self._name)
+        auto_naming.push_scope(self.name if self.name is not None else self._name)
         for hook in self._forward_pre_hooks.values():
             modified_inputs = hook(self, inputs)
             if modified_inputs is not None:
diff --git a/imperative/python/megengine/module/qat/batch_matmul_activation.py b/imperative/python/megengine/module/qat/batch_matmul_activation.py
index 3de9e1753dcb369b6f444d0ba28b2474c964905a..e3e2a0b2579d772f9241053b766897e3d31fdfd1 100644
--- a/imperative/python/megengine/module/qat/batch_matmul_activation.py
+++ b/imperative/python/megengine/module/qat/batch_matmul_activation.py
@@ -28,6 +28,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):
             float_module.in_features,
             float_module.out_features,
             float_module.bias is not None,
+            name=float_module.name,
         )
         qat_module.weight = float_module.weight
         qat_module.bias = float_module.bias
diff --git a/imperative/python/megengine/module/qat/concat.py b/imperative/python/megengine/module/qat/concat.py
index 818462ca3d452d33a9c83829cfea7a334429e3cc..bfcca787ab403c0e5c0b18aa2f0dddd994cac7e6 100644
--- a/imperative/python/megengine/module/qat/concat.py
+++ b/imperative/python/megengine/module/qat/concat.py
@@ -27,4 +27,4 @@ class Concat(Float.Concat, QATModule):
         Return a :class:`~.QATModule` instance converted from a float
         :class:`~.Module` instance.
""" - return cls() + return cls(name=float_module.name) diff --git a/imperative/python/megengine/module/qat/conv.py b/imperative/python/megengine/module/qat/conv.py index c5f842dcea904c124d42415a82e487354a73a4ac..f8205c955c2e329ef6b3c1c451e5eb28c00d5ac7 100644 --- a/imperative/python/megengine/module/qat/conv.py +++ b/imperative/python/megengine/module/qat/conv.py @@ -43,6 +43,7 @@ class Conv2d(Float.Conv2d, QATModule): float_module.bias is not None, float_module.conv_mode, float_module.compute_mode, + name=float_module.name, ) qat_module.weight = float_module.weight qat_module.bias = float_module.bias diff --git a/imperative/python/megengine/module/qat/conv_bn.py b/imperative/python/megengine/module/qat/conv_bn.py index 409c49b5d7637a39a008f2501d57cf83b3445598..2cc5be08899a7b89b9dbd88414ea763c883f00b1 100644 --- a/imperative/python/megengine/module/qat/conv_bn.py +++ b/imperative/python/megengine/module/qat/conv_bn.py @@ -155,6 +155,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): float_module.conv.bias is not None, float_module.conv.conv_mode, float_module.conv.compute_mode, + name=float_module.name, ) qat_module.conv.weight = float_module.conv.weight qat_module.conv.bias = float_module.conv.bias diff --git a/imperative/python/megengine/module/qat/elemwise.py b/imperative/python/megengine/module/qat/elemwise.py index 3692bdf18bb7d74652c9bb06ede5be6f9b266436..956bf4fa1611ea78de5df5f26b3b0c80d862bb31 100644 --- a/imperative/python/megengine/module/qat/elemwise.py +++ b/imperative/python/megengine/module/qat/elemwise.py @@ -28,4 +28,4 @@ class Elemwise(Float.Elemwise, QATModule): Return a :class:`~.QATModule` instance converted from a float :class:`~.Module` instance. """ - return cls(float_module.method) + return cls(float_module.method, name=float_module.name) diff --git a/imperative/python/megengine/module/qat/linear.py b/imperative/python/megengine/module/qat/linear.py index 98bf2452d78727c1d807b1402e827fe7bd8e8f10..8647fc9191ede54a45184360bf225d3c18882b10 100644 --- a/imperative/python/megengine/module/qat/linear.py +++ b/imperative/python/megengine/module/qat/linear.py @@ -36,7 +36,9 @@ class Linear(Float.Linear, QATModule): Return a :class:`~.QATModule` instance converted from a float :class:`~.Module` instance. 
""" - qmod = cls(float_module.in_features, float_module.out_features) + qmod = cls( + float_module.in_features, float_module.out_features, name=float_module.name + ) qmod.weight = float_module.weight qmod.bias = float_module.bias return qmod diff --git a/imperative/python/megengine/module/qat/module.py b/imperative/python/megengine/module/qat/module.py index 85eca44951290af5a7bf3a85ddb30f3533caf84b..82b2911cce169937b8377912b23ca4267db3429d 100644 --- a/imperative/python/megengine/module/qat/module.py +++ b/imperative/python/megengine/module/qat/module.py @@ -26,8 +26,8 @@ class QATModule(Module): with_weight = True with_act = True - def __init__(self): - super().__init__() + def __init__(self, **kwargs): + super().__init__(**kwargs) self.weight_observer = None # type: Observer self.act_observer = None # type: Observer diff --git a/imperative/python/megengine/module/qat/quant_dequant.py b/imperative/python/megengine/module/qat/quant_dequant.py index e75a35b8e6faba5da222bcdd451023d86166762a..580b5f9162a764f5251ccd5f28df12d8ae632f5b 100644 --- a/imperative/python/megengine/module/qat/quant_dequant.py +++ b/imperative/python/megengine/module/qat/quant_dequant.py @@ -26,7 +26,7 @@ class QuantStub(Float.QuantStub, QATModule): Return a :class:`~.QATModule` instance converted from a float :class:`~.Module` instance. """ - return cls() + return cls(name=float_module.name) class DequantStub(Float.DequantStub, QATModule): @@ -47,4 +47,4 @@ class DequantStub(Float.DequantStub, QATModule): Return a :class:`~.QATModule` instance converted from a float :class:`~.Module` instance. """ - return cls() + return cls(name=float_module.name) diff --git a/imperative/python/megengine/module/quantized/batch_matmul_activation.py b/imperative/python/megengine/module/quantized/batch_matmul_activation.py index 0ce763f5903bddfc5d0288ad5af7f272b3905ca9..e115c1463fa3e0f95723d2271f1846037df15905 100644 --- a/imperative/python/megengine/module/quantized/batch_matmul_activation.py +++ b/imperative/python/megengine/module/quantized/batch_matmul_activation.py @@ -61,13 +61,14 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule): qat_module.out_features, qat_module.bias is not None, dtype=output_dtype, + name=qat_module.name, ) weight = qat_module.weight.astype(qat_module.get_weight_dtype()) weight = expand_dims(weight, [-1, -2]) - qbmm.weight = Parameter(weight.numpy()) + qbmm.weight = Parameter(weight.numpy(), name=qat_module.weight.name) if qat_module.bias is not None: bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1)) - qbmm.bias = Parameter(bias.numpy()) + qbmm.bias = Parameter(bias.numpy(), name=qat_module.bias.name) else: qbmm.bias = Parameter( np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32) diff --git a/imperative/python/megengine/module/quantized/concat.py b/imperative/python/megengine/module/quantized/concat.py index f8cc6b8a13cdd98d0fb8624ad682710018cae6a9..11af852913f7430b27688663eb65808f39159a54 100644 --- a/imperative/python/megengine/module/quantized/concat.py +++ b/imperative/python/megengine/module/quantized/concat.py @@ -18,8 +18,8 @@ class Concat(QuantizedModule): A :class:`~.QuantizedModule` to do quantized :func:`~.concat`, used for inference only. 
""" - def __init__(self, dtype=None): - super().__init__() + def __init__(self, dtype=None, **kwargs): + super().__init__(**kwargs) self.output_dtype = dtype def forward(self, inps: Iterable[Tensor], axis: int = 0): @@ -32,4 +32,4 @@ class Concat(QuantizedModule): Return a :class:`~.QuantizedModule` instance converted from a :class:`~.QATModule` instance. """ - return cls(qat_module.get_activation_dtype()) + return cls(qat_module.get_activation_dtype(), name=qat_module.name) diff --git a/imperative/python/megengine/module/quantized/conv.py b/imperative/python/megengine/module/quantized/conv.py index 34e51a726e908250b234bf3bf78d452e46d0dc52..0b2ad2fa88ae67b6640ba4e163d4e1b1e94eefc4 100644 --- a/imperative/python/megengine/module/quantized/conv.py +++ b/imperative/python/megengine/module/quantized/conv.py @@ -37,6 +37,7 @@ class Conv2d(Float.Conv2d, QuantizedModule): conv_mode: str = "CROSS_CORRELATION", compute_mode: str = "DEFAULT", dtype=None, + **kwargs ): super().__init__( in_channels, @@ -86,11 +87,12 @@ class Conv2d(Float.Conv2d, QuantizedModule): qat_module.dilation, qat_module.groups, dtype=output_dtype, + name=qat_module.name, ) weight = qat_module.weight.astype(qat_module.get_weight_dtype()) - qconv.weight = Parameter(weight.numpy()) + qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name) if qat_module.bias is not None: - qconv.bias = Parameter(qat_module.bias.numpy()) + qconv.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) else: qconv.bias = Parameter( np.zeros(qat_module._infer_bias_shape(), dtype=np.float32) diff --git a/imperative/python/megengine/module/quantized/conv_bn.py b/imperative/python/megengine/module/quantized/conv_bn.py index 55b9466a01a9e066a69ecbce26549771c45c266f..4b3b5772b65988135d10e1056366d63be98d0669 100644 --- a/imperative/python/megengine/module/quantized/conv_bn.py +++ b/imperative/python/megengine/module/quantized/conv_bn.py @@ -33,13 +33,14 @@ class _ConvBnActivation2d(Conv2d): qat_module.conv.dilation, qat_module.conv.groups, dtype=output_dtype, + name=qat_module.name, ) w_fold, b_fold = qat_module.fold_weight_bias( qat_module.bn.running_mean, qat_module.bn.running_var ) weight = w_fold.astype(qat_module.get_weight_dtype()) - qconv.weight = Parameter(weight.numpy()) - qconv.bias = Parameter(b_fold.numpy()) + qconv.weight = Parameter(weight.numpy(), name=qat_module.conv.weight.name) + qconv.bias = Parameter(b_fold.numpy(), name=qat_module.conv.bias.name) return qconv diff --git a/imperative/python/megengine/module/quantized/elemwise.py b/imperative/python/megengine/module/quantized/elemwise.py index 6a76c7b84911fcc21182e832c511370a417fcc9c..46950c8f33a2d9f5ff67b451e7cc6774771b7aab 100644 --- a/imperative/python/megengine/module/quantized/elemwise.py +++ b/imperative/python/megengine/module/quantized/elemwise.py @@ -14,8 +14,8 @@ from .module import QuantizedModule class Elemwise(QuantizedModule): r"""Quantized version of :class:`~.qat.Elemwise`.""" - def __init__(self, method, dtype=None): - super().__init__() + def __init__(self, method, dtype=None, **kwargs): + super().__init__(**kwargs) self.method = "Q" + method self.output_dtype = dtype @@ -30,4 +30,6 @@ class Elemwise(QuantizedModule): Return a :class:`~.QuantizedModule` instance converted from a :class:`~.QATModule` instance. 
""" - return cls(qat_module.method, qat_module.get_activation_dtype()) + return cls( + qat_module.method, qat_module.get_activation_dtype(), name=qat_module.name + ) diff --git a/imperative/python/megengine/module/quantized/linear.py b/imperative/python/megengine/module/quantized/linear.py index 51a32581dfcb753c9a36b9eb0a80ac040950864a..7ecc55a84167acc6ce5633e79d8b8f45382c15e2 100644 --- a/imperative/python/megengine/module/quantized/linear.py +++ b/imperative/python/megengine/module/quantized/linear.py @@ -17,8 +17,8 @@ from .module import QuantizedModule class Linear(QuantizedModule): r"""Quantized version of :class:`~.qat.Linear`.""" - def __init__(self, dtype: np.dtype = None): - super().__init__() + def __init__(self, dtype: np.dtype = None, **kwargs): + super().__init__(**kwargs) self.weight = None self.bias = None self.output_dtype = dtype @@ -44,9 +44,9 @@ class Linear(QuantizedModule): :class:`~.QATModule` instance. """ output_dtype = qat_module.get_activation_dtype() - qmod = cls(dtype=output_dtype) + qmod = cls(dtype=output_dtype, name=qat_module.name) weight = qat_module.weight.astype(qat_module.get_weight_dtype()) - qmod.weight = Parameter(weight.numpy()) + qmod.weight = Parameter(weight.numpy(), name=qat_module.weight.name) if qat_module.bias is not None: - qmod.bias = Parameter(qat_module.bias.numpy()) + qmod.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) return qmod diff --git a/imperative/python/megengine/module/quantized/quant_dequant.py b/imperative/python/megengine/module/quantized/quant_dequant.py index c8eadafeef62f49fb68432f5921f5df154c9e700..d17ca0de361fbf3d5259af26139000bc3dcbb1fd 100644 --- a/imperative/python/megengine/module/quantized/quant_dequant.py +++ b/imperative/python/megengine/module/quantized/quant_dequant.py @@ -15,8 +15,8 @@ class QuantStub(QuantizedModule): will convert input to quantized dtype. """ - def __init__(self, dtype=None): - super().__init__() + def __init__(self, dtype=None, **kwargs): + super().__init__(**kwargs) self.output_dtype = dtype def forward(self, inp): @@ -28,7 +28,7 @@ class QuantStub(QuantizedModule): Return a :class:`~.QuantizedModule` instance converted from a :class:`~.QATModule` instance. """ - return cls(qat_module.get_activation_dtype()) + return cls(qat_module.get_activation_dtype(), name=qat_module.name) class DequantStub(QuantizedModule): @@ -46,4 +46,4 @@ class DequantStub(QuantizedModule): Return a :class:`~.QuantizedModule` instance converted from a :class:`~.QATModule` instance. 
""" - return cls() + return cls(name=qat_module.name) diff --git a/imperative/python/test/unit/test_dump_naming.py b/imperative/python/test/unit/test_dump_naming.py index 019d885479e2d1bbd8f8fd2ea3f7363e39948d6e..eabab1b44f44d80c12868614e7591326edc7874d 100644 --- a/imperative/python/test/unit/test_dump_naming.py +++ b/imperative/python/test/unit/test_dump_naming.py @@ -17,6 +17,7 @@ import megengine.utils.comp_graph_tools as cgtools from megengine import Parameter, Tensor from megengine.core.tensor import megbrain_graph as G from megengine.jit.tracing import trace +from megengine.quantization.quantize import quantize, quantize_qat from megengine.utils.naming import auto_naming @@ -29,14 +30,14 @@ def _dump_and_load(func, symbolic, keep_opr_name=True): func.dump( file, optimize_for_inference=False, - arg_names="x", + arg_names=("x",), keep_opr_name=keep_opr_name, keep_var_name=2, ) file.seek(0) *_, outputs = G.load_graph(file) - op = cgtools.get_oprs_seq(outputs)[-1] - return op + ops = cgtools.get_oprs_seq(outputs) + return ops @pytest.mark.parametrize("symbolic", [False, True]) @@ -50,7 +51,7 @@ def test_auto_naming(symbolic): return x + x m = Simple("simple") - op = _dump_and_load(m, symbolic) + op = _dump_and_load(m, symbolic)[-1] assert op.name == "simple.ADD" assert op.outputs[0].name == "simple.ADD" @@ -70,7 +71,7 @@ def test_user_named_tensor(symbolic): m = Simple("simple") - op = _dump_and_load(m, symbolic) + op = _dump_and_load(m, symbolic)[-1] assert op.name == "simple.ADD" assert op.outputs[0].name == "o_x" @@ -88,7 +89,7 @@ def test_user_named_param(symbolic): m = Simple("simple") - op = _dump_and_load(m, symbolic) + op = _dump_and_load(m, symbolic)[-1] assert op.inputs[0].name == "x" assert op.inputs[1].name == "simple.k" @@ -98,7 +99,7 @@ def test_without_module(symbolic): def f(x): return 2 * x - op = _dump_and_load(f, symbolic) + op = _dump_and_load(f, symbolic)[-1] assert op.name == "MUL" @@ -116,10 +117,10 @@ def test_with_submodule(symbolic): m = Simple("simple") - op = _dump_and_load(m, symbolic) - assert op.name == "simple.linear.ADD" - assert op.inputs[0].owner.name == "simple.linear.MatrixMul" - assert op.outputs[0].name == "simple.linear.ADD" + ops = _dump_and_load(m, symbolic) + assert ops[-1].name == "simple.linear.ADD" + assert ops[-2].name == "simple.linear.MatrixMul" + assert ops[-1].outputs[0].name == "simple.linear.ADD" @pytest.mark.parametrize("symbolic", [False, True]) @@ -136,10 +137,10 @@ def test_named_submodule(symbolic): m = Simple("simple") - op = _dump_and_load(m, symbolic) - assert op.name == "simple.x.ADD" - assert op.inputs[0].owner.name == "simple.x.MatrixMul" - assert op.outputs[0].name == "simple.x.ADD" + ops = _dump_and_load(m, symbolic) + assert ops[-1].name == "simple.x.ADD" + assert ops[-2].name == "simple.x.MatrixMul" + assert ops[-1].outputs[0].name == "simple.x.ADD" @pytest.mark.parametrize("symbolic", [False, True]) @@ -156,14 +157,111 @@ def test_with_same_operators(symbolic): m = Simple("simple") - op = _dump_and_load(m, symbolic) - assert op.name == "simple.RELU[1]" - assert op.inputs[0].owner.name == "simple.RELU[0]" + ops = _dump_and_load(m, symbolic) + assert ops[-1].name == "simple.RELU[1]" + assert ops[-2].name == "simple.RELU[0]" def test_not_keep_opr_name(): def f(x): return 2 * x - op = _dump_and_load(f, True, False) + op = _dump_and_load(f, True, False)[-1] assert op.name == "MUL(x,2[2])[4]" + + +@pytest.mark.parametrize("symbolic", [False, True]) +def test_quantized_module_auto_naming(symbolic): + class Simple(M.Module): + 
+        def __init__(self, name):
+            super().__init__(name=name)
+            self.quant = M.QuantStub()
+            self.linear = M.Linear(3, 3, bias=True)
+            self.dequant = M.DequantStub()
+
+        def forward(self, x):
+            out = self.quant(x)
+            out = self.linear(out)
+            out = self.dequant(out)
+            return out
+
+    m = Simple("simple")
+    quantize_qat(m)
+    quantize(m)
+    m.eval()
+
+    ops = _dump_and_load(m, symbolic)
+    ops_name = (
+        "x",
+        "simple.quant.TypeCvt",
+        "simple.linear.MatrixMul",
+        "simple.linear.ADD",
+        "simple.linear.TypeCvt",
+        "simple.dequant.TypeCvt",
+    )
+    for op, name in zip(ops, ops_name):
+        assert op.name == name
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_quantized_module_user_naming(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__(name=name)
+            self.quant = M.QuantStub()
+            self.linear = M.Linear(3, 3, bias=True, name="user-linear")
+            self.dequant = M.DequantStub()
+
+        def forward(self, x):
+            out = self.quant(x)
+            out = self.linear(out)
+            out = self.dequant(out)
+            return out
+
+    m = Simple("simple")
+    quantize_qat(m)
+    quantize(m)
+    m.eval()
+
+    ops = _dump_and_load(m, symbolic)
+    ops_name = (
+        "x",
+        "simple.quant.TypeCvt",
+        "simple.user-linear.MatrixMul",
+        "simple.user-linear.ADD",
+        "simple.user-linear.TypeCvt",
+        "simple.dequant.TypeCvt",
+    )
+    for op, name in zip(ops, ops_name):
+        assert op.name == name
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_quantized_module_user_naming_param(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__(name=name)
+            self.quant = M.QuantStub()
+            self.linear = M.Linear(3, 3, bias=True)
+            self.dequant = M.DequantStub()
+
+            self.linear.weight.name = "user-weight"
+            self.linear.bias.name = "user-bias"
+
+        def forward(self, x):
+            out = self.quant(x)
+            out = self.linear(out)
+            out = self.dequant(out)
+            return out
+
+    m = Simple("simple")
+    quantize_qat(m)
+    quantize(m)
+    m.eval()
+
+    ops = _dump_and_load(m, symbolic)
+
+    (matrix_mul_op,) = [op for op in ops if op.name == "simple.linear.MatrixMul"]
+    for var in matrix_mul_op.inputs:
+        assert var.name in ("simple.quant.TypeCvt", "simple.linear.user-weight")
+    # BUG bias' name does not meet expectations because of astype operator after quantization
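
Usage sketch (reviewer note, not part of the patch): with this change, a name passed
through the new ``name`` keyword is validated by ``Module.__init__`` and carried across
the ``quantize_qat``/``quantize`` conversions, so dumped operator names keep the
user-chosen prefix. A minimal example mirroring the tests above; the ``Net`` class and
the "net"/"fc" names are illustrative only:

    import megengine.module as M
    from megengine.quantization.quantize import quantize, quantize_qat

    class Net(M.Module):
        def __init__(self):
            super().__init__(name="net")
            # "fc" reaches Module.__init__ via the new **kwargs plumbing and is
            # later copied by from_float_module / from_qat_module
            self.linear = M.Linear(3, 3, bias=True, name="fc")

        def forward(self, x):
            return self.linear(x)

    m = Net()
    quantize_qat(m)  # QAT Linear is built with name=float_module.name
    quantize(m)      # quantized Linear is built with name=qat_module.name
    m.eval()
    # dumping m now yields operator names like "net.fc.MatrixMul"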