提交 19af2688 编写于 作者: M Megvii Engine Team

fix(mge/tools): fix module_stats for duplicated module

GitOrigin-RevId: a15f17d6160a7ce80ec3f15b81822b476a7166c1
上级 4cd4a38a
......@@ -457,6 +457,7 @@ def module_stats(
log_activations = False
disable_receptive_field()
recorded_parameters = set()
def module_stats_hook(module, inputs, outputs, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
......@@ -468,17 +469,27 @@ def module_stats(
flops.append(flops_stats)
if cal_params:
if hasattr(module, "weight") and module.weight is not None:
if (
hasattr(module, "weight")
and (module.weight is not None)
and module.weight not in recorded_parameters
):
w = module.weight
param_stats = get_param_stats(w)
param_stats["name"] = name + "-w"
params.append(param_stats)
recorded_parameters.add(w)
if hasattr(module, "bias") and module.bias is not None:
if (
hasattr(module, "bias")
and module.bias is not None
and module.bias not in recorded_parameters
):
b = module.bias
param_stats = get_param_stats(b)
param_stats["name"] = name + "-b"
params.append(param_stats)
recorded_parameters.add(b)
if cal_activations:
if not isinstance(outputs, (tuple, list)):
......@@ -504,7 +515,6 @@ def module_stats(
hooks.append(
module.register_forward_hook(partial(module_stats_hook, name=name))
)
with set_module_mode_safe(model, training=False) as model:
model(*inputs)
......
......@@ -42,6 +42,65 @@ def test_other_input_module_state():
net(_nt)
@pytest.mark.skipif(
use_symbolic_shape(), reason="This test do not support symbolic shape.",
)
def test_duplicated_module():
input_shape = (1, 3, 224, 224)
net0 = TestNet0()
net0_stats, _ = module_stats(net0, input_shapes=input_shape)
net1 = TestNet1()
net1_stats, _ = module_stats(net1, input_shapes=input_shape)
net2 = TestNet2()
net2_stats, _ = module_stats(net2, input_shapes=input_shape)
assert net0_stats.param_dims == net1_stats.param_dims
assert net0_stats.param_size == net1_stats.param_size
assert net0_stats.param_dims == net2_stats.param_dims
assert net0_stats.param_size == net2_stats.param_size
class TestNet0(M.Module):
def __init__(self):
super().__init__()
self.conv = M.Conv2d(3, 3, 3, padding=(1, 1))
self.conv.bias = mge.Parameter(
np.random.random(self.conv.bias.shape).astype(np.float32)
)
def forward(self, x):
x = self.conv(x)
return x
class TestNet1(TestNet0):
def __init__(self):
super().__init__()
self.conv1 = self.conv
def forward(self, x):
x = self.conv(x)
x = self.conv1(x)
return x
class TestNet2(TestNet0):
def __init__(self):
super().__init__()
self.conv1 = M.Conv2d(3, 3, 3, padding=(1, 1))
self.conv1.weight = self.conv.weight
self.conv1.bias = self.conv.bias
def forward(self, x):
x = self.conv(x)
x = self.conv1(x)
return x
class FakeNet(M.Module):
def __init__(self):
super().__init__()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册