diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py
index 7e06a355d152de5d7053638fffa0e3f3bd2a8745..852d55185d2315d1576ffee10f108e544382b8b5 100644
--- a/imperative/python/megengine/utils/module_stats.py
+++ b/imperative/python/megengine/utils/module_stats.py
@@ -457,6 +457,7 @@ def module_stats(
         log_activations = False
 
     disable_receptive_field()
+    recorded_parameters = set()
 
     def module_stats_hook(module, inputs, outputs, name=""):
         class_name = str(module.__class__).split(".")[-1].split("'")[0]
@@ -468,17 +469,27 @@ def module_stats(
             flops.append(flops_stats)
 
         if cal_params:
-            if hasattr(module, "weight") and module.weight is not None:
+            if (
+                hasattr(module, "weight")
+                and (module.weight is not None)
+                and module.weight not in recorded_parameters
+            ):
                 w = module.weight
                 param_stats = get_param_stats(w)
                 param_stats["name"] = name + "-w"
                 params.append(param_stats)
+                recorded_parameters.add(w)
 
-            if hasattr(module, "bias") and module.bias is not None:
+            if (
+                hasattr(module, "bias")
+                and module.bias is not None
+                and module.bias not in recorded_parameters
+            ):
                 b = module.bias
                 param_stats = get_param_stats(b)
                 param_stats["name"] = name + "-b"
                 params.append(param_stats)
+                recorded_parameters.add(b)
 
         if cal_activations:
             if not isinstance(outputs, (tuple, list)):
@@ -504,7 +515,6 @@ def module_stats(
         hooks.append(
             module.register_forward_hook(partial(module_stats_hook, name=name))
         )
-
     with set_module_mode_safe(model, training=False) as model:
         model(*inputs)
 
diff --git a/imperative/python/test/unit/utils/test_module_stats.py b/imperative/python/test/unit/utils/test_module_stats.py
index f7d748a586bc52320849291883e2dddf96a47edc..1abc2eddedfc8075ee7dfe184304c6d77b32171f 100644
--- a/imperative/python/test/unit/utils/test_module_stats.py
+++ b/imperative/python/test/unit/utils/test_module_stats.py
@@ -42,6 +42,65 @@ def test_other_input_module_state():
     net(_nt)
 
 
+@pytest.mark.skipif(
+    use_symbolic_shape(), reason="This test does not support symbolic shape.",
+)
+def test_duplicated_module():
+    input_shape = (1, 3, 224, 224)
+
+    net0 = TestNet0()
+    net0_stats, _ = module_stats(net0, input_shapes=input_shape)
+
+    net1 = TestNet1()
+    net1_stats, _ = module_stats(net1, input_shapes=input_shape)
+
+    net2 = TestNet2()
+    net2_stats, _ = module_stats(net2, input_shapes=input_shape)
+
+    assert net0_stats.param_dims == net1_stats.param_dims
+    assert net0_stats.param_size == net1_stats.param_size
+
+    assert net0_stats.param_dims == net2_stats.param_dims
+    assert net0_stats.param_size == net2_stats.param_size
+
+
+class TestNet0(M.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = M.Conv2d(3, 3, 3, padding=(1, 1))
+        self.conv.bias = mge.Parameter(
+            np.random.random(self.conv.bias.shape).astype(np.float32)
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class TestNet1(TestNet0):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = self.conv
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.conv1(x)
+        return x
+
+
+class TestNet2(TestNet0):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = M.Conv2d(3, 3, 3, padding=(1, 1))
+        self.conv1.weight = self.conv.weight
+        self.conv1.bias = self.conv.bias
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.conv1(x)
+        return x
+
+
 class FakeNet(M.Module):
     def __init__(self):
         super().__init__()
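
Reviewer note: below is a minimal sketch of the de-duplication idea behind
`recorded_parameters`, written with plain Python objects rather than MegEngine
modules/tensors (the `Param` class and its `.size` attribute are hypothetical,
for illustration only). Parameters already visited by the stats hook are kept
in a set, so a weight or bias object shared between two modules, as in
TestNet1/TestNet2 above, is counted only once.

    # hypothetical stand-in for a parameter tensor
    class Param:
        def __init__(self, size):
            self.size = size

    seen = set()
    total_dims = 0

    def count_params(params):
        global total_dims
        for p in params:
            if p in seen:        # shared/aliased parameter: already counted
                continue
            seen.add(p)          # remember the object itself (identity-based here)
            total_dims += p.size

    w, b = Param(81), Param(3)   # e.g. a 3x3x3x3 conv weight and its bias
    count_params([w, b])         # first module: both counted
    count_params([w, b])         # module aliasing the same objects: skipped
    assert total_dims == 84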